In [1]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from aind_dynamic_foraging_basic_analysis.plot import plot_fip as pf
from aind_dynamic_foraging_basic_analysis.plot import plot_foraging_session as pb
import numpy as np
import pandas as pd
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
import seaborn as sns
from scipy import stats

In [2]:
from importlib import reload


# Load the input job models and the analysis specification

In [3]:

import json
import os
from pathlib import Path

from analysis_pipeline_utils.analysis_dispatch_model import AnalysisDispatchModel
import utils as utils
from analysis_model import (
    SummaryPlotsAnalysisSpecification, SummaryPlotsAnalysisSpecificationCLI
)

# Folder that holds the job models (`job_dict/*`) and `analysis_parameters.json`.
DATA_PATH: Path = Path("/data")  # TODO: don't hardcode
# Read from the environment; None when the variable is not set.
ANALYSIS_BUCKET = os.getenv("ANALYSIS_BUCKET")

# Sorted so that the iteration order (and therefore which job model ends up in
# `analysis_dispatch_inputs`) is the same on every run; glob order is unspecified.
input_model_paths = tuple(sorted(DATA_PATH.glob('job_dict/*')))
print(f"Found {len(input_model_paths)} input job models to run analysis on.")
if not input_model_paths:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise FileNotFoundError(f"No input job models found under {DATA_PATH / 'job_dict'}")

analysis_spec_path = tuple(DATA_PATH.glob("analysis_parameters.json"))
if not analysis_spec_path:
    # Without the file there is nothing to validate; previously this surfaced
    # as an opaque validation error on `None`.
    raise FileNotFoundError(f"No analysis_parameters.json found in {DATA_PATH}")
with open(analysis_spec_path[0], "r") as f:
    analysis_specs = json.load(f)

# The specification does not depend on the job model, so validate it once.
analysis_specification = SummaryPlotsAnalysisSpecification.model_validate(analysis_specs).model_dump()

# NOTE(review): each iteration overwrites `analysis_dispatch_inputs`, so only the
# last job model is used by the cells below — confirm that a single job is expected.
for model_path in input_model_paths:
    with open(model_path, "r") as f:
        analysis_dispatch_inputs = AnalysisDispatchModel.model_validate(json.load(f))

Found 1 input job models to run analysis on.


# Load the session data and prepare it for analysis

In [4]:
import rachel_analysis_framework_utils as r_utils
import analysis_util
from plots import summary_plots
from aind_dynamic_foraging_basic_analysis.metrics import trial_metrics



In [5]:


# The validated analysis specification is forwarded as keyword arguments.
parameters = analysis_specification

# Load the session, trial, event and FIP tables for the dispatched file location.
df_sess, df_trials, df_events, df_fip = r_utils.get_nwb_processed(
    analysis_dispatch_inputs.file_location,
    **parameters,
)


Saving channels: ['R_0', 'G_0', 'G_1']
CURRENTLY RUNNING 1/2: 726649_2024-09-12
--------------------------------------------------
Timestamps are adjusted such that `_in_session` timestamps start at the first go cue
Timestamps are adjusted such that `_in_session` timestamps start at the first go cue
Timestamps are adjusted such that `_in_session` timestamps start at the first go cue
CURRENTLY RUNNING 2/2: 726649_2024-09-16
--------------------------------------------------




Timestamps are adjusted such that `_in_session` timestamps start at the first go cue
Timestamps are adjusted such that `_in_session` timestamps start at the first go cue
Timestamps are adjusted such that `_in_session` timestamps start at the first go cue
Retrieving foraging model QLearning_L1F1_CK1_softmax
Query: {'analysis_spec.analysis_name': 'MLE fitting', 'analysis_spec.analysis_ver': 'first version @ 0.10.0', 'subject_id': '726649', 'session_date': '2024-09-12', 'analysis_results.fit_settings.agent_alias': 'QLearning_L1F1_CK1_softmax'}
Found 1 MLE fitting records!
Found 1 successful MLE fitting!


Get latent variables from s3: 100%|██████████| 1/1 [00:00<00:00,  3.53it/s]


Query: {'analysis_spec.analysis_name': 'MLE fitting', 'analysis_spec.analysis_ver': 'first version @ 0.10.0', 'subject_id': '726649', 'session_date': '2024-09-16', 'analysis_results.fit_settings.agent_alias': 'QLearning_L1F1_CK1_softmax'}
Found 1 MLE fitting records!
Found 1 successful MLE fitting!


Get latent variables from s3: 100%|██████████| 1/1 [00:00<00:00,  6.84it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_ses.loc[:, "Q_chosen"] = chosen_values
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_ses.loc[:, "Q_unchosen"] = unchosen_values
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_ses.loc[:, "Q_sum"] = df_ses["Q_left"]

In [6]:
# TODO: refactor this out to r_utils. 

# Combined reward per trial (earned + extra).
df_trials['reward_all'] = df_trials['earned_reward'] + df_trials['extra_reward']

# Signed outcome streak: position of each trial inside its run of identical
# `reward_all` values, positive for rewarded runs and negative for unrewarded runs.
# The previous outcome is looked up within each session (`ses_idx`), so the first
# trial of a session always starts a new run.
prev_reward = df_trials.groupby('ses_idx')['reward_all'].shift(1)
reward_run_id = (prev_reward != df_trials['reward_all']).cumsum().rename('reward_run_id')
streak_length = df_trials.groupby(reward_run_id).cumcount() + 1
# Flip the sign on unrewarded trials.
df_trials['num_reward_past'] = streak_length.mask(df_trials['reward_all'] == 0, -streak_length)


# Six equal-width RPE bins spanning [-1, 1]; each label is the rounded left edge of its bin.
# NOTE(review): rounding yields the label '-0.0' for the bin starting at zero; it is
# left unchanged here in case downstream code keys on the label strings — confirm.
RPE_binned3_label_names = [str(np.round(i,2)) for i in np.arange(-1,0.99,1/3)]

bins = np.arange(-1,1.01,1/3)
# Nudge the top edge above 1 so float error in arange cannot exclude RPE == 1.
bins[-1] = 1.001

# include_lowest=True keeps RPE == -1 in the first bin; with right-closed bins
# alone that value would fall outside every interval and become NaN.
df_trials['RPE-binned3'] = pd.cut(df_trials['RPE_earned'],# all versus earned not a huge difference
                    bins = bins, right = True, include_lowest = True,
                    labels=RPE_binned3_label_names)

(df_sess, nwbs_by_week) = analysis_util.get_dummy_nwbs_by_week(df_sess, df_trials, df_events, df_fip) 


# TODO: will need to refactor code so there's flexibility on the plots that come out
#       consult alex? or figure it out on my own. 
# get average activity 
data_column = 'data_z_norm'
alignment_event='choice_time_in_session'
# Window (relative to the alignment event) over which the signal is averaged.
avg_window_offsets = [0.33, 1]
rpe_slope_dict = {}
for channel in analysis_specification["channels"]:
    avg_signal_col = summary_plots.output_col_name(channel, data_column, alignment_event)
    for nwb_week in nwbs_by_week:
        # Pass the same `alignment_event` that was used to build the column name
        # so the two cannot drift apart.
        # NOTE(review): rebinding `nwb_week` does not update `nwbs_by_week`; the new
        # columns show up there only because the call appears to modify its input
        # in place — confirm.
        nwb_week = trial_metrics.get_average_signal_window_multi(
                        nwb_week,
                        alignment_event=alignment_event,
                        offsets=avg_window_offsets,
                        channel=channel,
                        data_column=data_column,
                        output_col = avg_signal_col
                    )

In [7]:

# Plot the foraging session for the first entry of the first week group.
# FIXME: the recorded output of this cell is `KeyError: 'animal_response'`;
#        presumably the df_trials of this object lacks that column — confirm and
#        fix so the notebook survives Restart & Run All.
pb.plot_foraging_session_nwb(nwbs_by_week[0][0])

KeyError: 'animal_response'

In [None]:
# Display the trials table of the first entry of the first week group, including
# the derived columns (reward_all, num_reward_past, RPE-binned3, avg_* signals).
nwbs_by_week[0][0].df_trials

Unnamed: 0,trial,choice,rewarded_historyL,rewarded_historyR,side_bias,side_bias_confidence_interval,bait_left,bait_right,base_reward_probability_sum,reward_probabilityL,...,data_z_G_1_baseline,data_R_0_baseline,data_z_R_0_baseline,reward_all,num_reward_past,RPE-binned3,week_interval,avg_data_z_norm_R_0_choice_time,avg_data_z_norm_G_0_choice_time,avg_data_z_norm_G_1_choice_time
0,0,0.0,True,False,,"[nan, nan]",True,True,0.8,0.7,...,-0.253120,2485.121382,0.714040,True,1,0.67,1,0.078969,1.050893,1.122472
1,1,0.0,True,False,,"[nan, nan]",True,True,0.8,0.7,...,-0.834826,2428.392225,0.596711,True,2,0.33,1,1.150323,3.940241,4.558299
2,2,0.0,True,False,,"[nan, nan]",True,True,0.8,0.7,...,1.712887,3389.024622,2.583516,True,3,-0.0,1,0.350892,1.052745,1.570548
3,3,0.0,True,False,,"[nan, nan]",True,True,0.8,0.7,...,1.205047,3328.141685,2.457596,True,4,-0.0,1,0.756732,2.356876,1.727903
4,4,0.0,True,False,,"[nan, nan]",True,True,0.8,0.7,...,0.635668,3169.831533,2.130175,True,5,-0.0,1,0.570183,2.047199,2.226870
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
192,276,1.0,False,True,,"[nan, nan]",False,True,0.8,0.1,...,-1.325052,1585.814255,-1.145931,True,1,0.67,1,0.014622,0.155704,0.609415
193,285,1.0,False,True,,"[nan, nan]",False,True,0.8,0.1,...,-1.360709,1543.471274,-1.233506,True,1,0.33,1,0.053791,0.063956,0.284147
194,287,1.0,False,False,,"[nan, nan]",True,False,0.8,0.7,...,-1.879338,1636.039741,-1.042053,False,-2,-1.0,1,-0.077024,-0.618143,0.101645
195,296,1.0,False,False,,"[nan, nan]",True,False,0.8,0.7,...,-1.634625,1652.575378,-1.007854,False,-11,-0.67,1,-0.056925,-0.727662,0.025982
