In [1]:
import os

import helpers_for_processing as helper
import utils_new as utils

import pandas as pd

In [2]:
# Root directory of the behavior data; cohort folders are resolved beneath it.
data_dir = '/Users/rebekahzhang/Documents/shuler_lab/behavior_data'
# Cohort and dataset variant that this notebook run operates on.
cohort, to_analyze = 'cohort_3_v23', 'full_clean'

# Folder used by the session-log and per-session processing cells below.
data_folder = os.path.join(data_dir, cohort, to_analyze)
print(data_folder)

/Users/rebekahzhang/Documents/shuler_lab/behavior_data/cohort_3_v23/full_clean


# Generate all session logs

Generate the session log using metadata from each session.

In [None]:
# Build the log of all sessions found under data_folder. The columns of the
# result are determined by helper.generate_all_session_log (not visible here);
# later cells rely on at least a 'dir' column being present.
session_log = helper.generate_all_session_log(data_folder)

all columns of basic info for each session

In [None]:
# Collect one dict of basic info per session, tagged with the session's 'dir'
# so it can be joined back onto the session log.
session_basics_list = []
for _, session_info in session_log.iterrows():
    try:
        session = utils.load_session(data_folder, session_info)
        session_basics = {'dir': session_info.dir} | helper.get_session_basics(session)
        session_basics_list.append(session_basics)
    except Exception as err:
        # Best-effort: report the failing session and keep going. Catching
        # Exception (not a bare except) lets KeyboardInterrupt stop the loop,
        # and printing the error shows why the session failed.
        print(f'{session_info.dir}: {type(err).__name__}: {err}')
session_basics_df = pd.DataFrame(session_basics_list)
# pd.merge defaults to an inner join, so any session reported above is
# dropped from session_log here.
session_log = pd.merge(session_log, session_basics_df, on='dir')

focus only on regular training sessions

In [None]:
# Keep only the sessions whose 'training' value is 'regular', ordered by 'dir'.
training_session_log = (
    session_log
    .loc[lambda log: log['training'] == 'regular']
    .sort_values('dir')
    # NOTE(review): no drop=True, so the previous index is kept as an 'index'
    # column -- confirm that this is intended.
    .reset_index()
)

### Examine quality of sessions
doesn't need to run when data folder is cleaned

check for short sessions

In [None]:
# Regular-training sessions with fewer than 50 trials.
short_session = training_session_log.loc[
    lambda log: log['training'].eq('regular') & log['num_trials'].lt(50)
]
display(short_session)

check for missing sessions by the number of sessions in each training day

In [None]:
# Number of sessions expected on every training date (one per mouse).
num_mice = 8
date_list = training_session_log['date'].unique().tolist()
for date in date_list:
    sessions_on_date = training_session_log.loc[training_session_log['date'] == date]
    if len(sessions_on_date) >= num_mice:
        continue
    # Fewer sessions than mice were logged on this date.
    print(date)

check if mice ran correct experiments for each session

In [None]:
# The two groups of mice; their 's'/'l' assignment is swapped between the
# pre-switch and post-switch mappings.
first_group = ['RZ018', 'RZ019', 'RZ022', 'RZ023']
second_group = ['RZ020', 'RZ021', 'RZ024', 'RZ025']

# 'exp' label -> mice expected to have run that experiment.
mouse_dict_pre = {'s': list(first_group), 'l': list(second_group)}
mouse_dict_post = {'l': list(first_group), 's': list(second_group)}

In [None]:
# Split the training log at the switch date: strictly before it vs. on or after it.
switch_date = '2023-10-05'
session_dates = training_session_log['date']
pre_switch_log = training_session_log.loc[session_dates < switch_date].copy()
post_switch_log = training_session_log.loc[session_dates >= switch_date].copy()

In [None]:
def _in_pre_switch_group(row):
    """Return True when the row's mouse is listed under the row's 'exp' in mouse_dict_pre."""
    return row['mouse'] in mouse_dict_pre.get(row['exp'], [])


pre_switch_log['correct_exp'] = pre_switch_log.apply(_in_pre_switch_group, axis=1)
incorrect_exp_df = pre_switch_log[~pre_switch_log['correct_exp']]
# Show the offending sessions, or confirm that there are none.
if len(incorrect_exp_df) == 0:
    print('all correct!')
else:
    display(incorrect_exp_df)

In [None]:
def _in_post_switch_group(row):
    """Return True when the row's mouse is listed under the row's 'exp' in mouse_dict_post."""
    return row['mouse'] in mouse_dict_post.get(row['exp'], [])


post_switch_log['correct_exp'] = post_switch_log.apply(_in_post_switch_group, axis=1)
incorrect_exp_df = post_switch_log[~post_switch_log['correct_exp']]
# Show the offending sessions, or confirm that there are none.
if len(incorrect_exp_df) == 0:
    print('all correct!')
else:
    display(incorrect_exp_df)

Make a copy of the cleaned data before proceeding!

### Add training session number to training log

In [None]:
# Apply helper.assign_session_numbers to each mouse's rows separately;
# group_keys=False keeps 'mouse' out of the result's index.
log_by_mouse = training_session_log.groupby('mouse', group_keys=False)
training_session_log = log_by_mouse.apply(helper.assign_session_numbers)

### Saves all sessions log and training session log

In [None]:
# Write both logs into data_folder, one CSV per log.
logs_to_save = (
    (session_log, 'all_sessions.csv'),
    (training_session_log, 'training_sessions.csv'),
)
for log_df, csv_name in logs_to_save:
    utils.save_as_csv(df=log_df, folder=data_folder, filename=csv_name)

# Process raw session and generate all_trials

load session log

In [None]:
# Reload the training log from 'training_sessions.csv' in data_folder
# (the file written by the save cell above), presumably so this section can
# be run without regenerating the logs.
training_session_log = utils.load_session_log(data_folder, 'training_sessions.csv')

### Generate all trials, align trial number, trial state, and trial time for raw session

In [None]:
# For every training session: build all_trials, align the raw session to it,
# and write both results to disk. Sessions whose two output files already
# exist are skipped, so the cell can be re-run after a partial failure.
for _, session_info in training_session_log.iterrows():
    try:
        processed_session_path = utils.generate_processed_session_path(data_folder, session_info)
        all_trials_path = utils.generate_all_trials_path(data_folder, session_info)
        if os.path.isfile(all_trials_path) and os.path.isfile(processed_session_path):
            continue

        session = utils.load_session(data_folder, session_info)
        # make all_trials
        all_trials = helper.generate_all_trials(session_info, session)

        # align trial number
        session = helper.align_trial_number(session, all_trials)
        session = utils.trim_session(session_info, session)

        # align trial state (applied separately to each trial's rows)
        session = session.groupby('session_trial_num', group_keys=False).apply(helper.align_trial_states)
        # add trial_time (applied separately to each trial's rows)
        session = session.groupby('session_trial_num', group_keys=False).apply(helper.add_trial_time)

        session.to_csv(processed_session_path)
        all_trials.to_csv(all_trials_path)
    except Exception as err:
        # Best-effort: report the failing session and continue with the next.
        # Catching Exception (not a bare except) lets KeyboardInterrupt stop
        # this long loop, and the message shows why the session failed.
        print(f'failed to process session: {type(err).__name__}: {err}')
        display(session_info)

### Adding analyzed trial data to all trials df

In [None]:
# For every training session, merge the per-trial analysis onto all_trials
# and save it. A session is skipped only when its own analyzed file exists.
for _, session_info in training_session_log.iterrows():
    all_trials_analyzed_path = utils.generate_all_trials_analyzed_path(data_folder, session_info)
    # Check this session's output path. The previous check used
    # all_trials_path / processed_session_path, which are leftovers from the
    # preceding cell's loop and do not refer to the current session.
    if os.path.isfile(all_trials_analyzed_path):
        continue

    session_by_trial = utils.load_processed_session(data_folder, session_info).groupby('session_trial_num')
    all_trials = utils.load_all_trials(data_folder, session_info)
    all_trials_data = helper.get_trial_data_df(session_by_trial)
    # Inner join (pd.merge default) on the trial number.
    all_trials_analyzed = pd.merge(all_trials, all_trials_data, on='session_trial_num')
    all_trials_analyzed.to_csv(all_trials_analyzed_path)

# Combine sessions

In [3]:
# Output folders for the two stitching stages, both inside the cohort folder.
cohort_dir = os.path.join(data_dir, cohort)
stitched_folder = os.path.join(cohort_dir, 'full_clean_stitched')
stitched_all_mice_folder = os.path.join(cohort_dir, 'full_clean_stitched_all_mice')

## Stitch sessions from the same mouse on the same day

load session log and generate lists for looping

In [46]:
# Reload the training log from 'training_sessions.csv' in data_folder; the
# stitching cells below iterate over it.
training_session_log = utils.load_session_log(data_folder, 'training_sessions.csv')

makes the stitched session log with columns of date, mouse, dir, filename

In [58]:
# Derive the stitched-session log from the training log. Its schema is set by
# helper.generate_stitched_session_log (not visible here); the cells below
# use its 'dir' and 'mouse' columns.
stitched_session_log = helper.generate_stitched_session_log(training_session_log)

### Stitch!

creates empty directories in the stitched folder

In [48]:
# Make sure every directory named in the stitched log exists under stitched_folder.
for session_dir in stitched_session_log['dir']:
    target_dir = os.path.join(stitched_folder, session_dir)
    if os.path.exists(target_dir):
        continue
    os.makedirs(target_dir)

loop through each day and each mouse to stitch processed_session and all_trials_analyzed

In [49]:
# Sort by 'dir' first: groupby keeps the row order inside each group, so each
# date's sessions are then visited in 'dir' order.
training_session_log = training_session_log.sort_values(by='dir')
log_by_date = training_session_log.groupby(by='date')

In [50]:
# For each date and each mouse, fold all of that mouse's sessions from the
# day into one stitched session and one stitched all_trials table.
for _, day_log in log_by_date:
    for _, mouse_log in day_log.groupby('mouse'):
        # The first session of the day seeds the result and names the output files.
        session_info = mouse_log.iloc[0]
        stitched_session = utils.load_processed_session(data_folder, session_info)
        stitched_all_trials = utils.load_all_trials_analyzed(data_folder, session_info)

        # Append any further sessions from the same day, in log order.
        for pos in range(1, len(mouse_log)):
            later_info = mouse_log.iloc[pos]
            later_session = utils.load_processed_session(data_folder, later_info)
            later_all_trials = utils.load_all_trials_analyzed(data_folder, later_info)

            stitched_session = helper.stitch_sessions(stitched_session, later_session)
            stitched_all_trials = helper.stitch_all_trials(stitched_all_trials, later_all_trials)

        session_out_path = utils.generate_stitched_processed_session_path(stitched_folder, session_info)
        all_trials_out_path = utils.generate_stitched_all_trials_path(stitched_folder, session_info)
        stitched_session.to_csv(session_out_path)
        stitched_all_trials.to_csv(all_trials_out_path)

### Add info to stitched session log and save it

In [59]:
# One record of basic info per stitched session, tagged with its 'dir'.
stitched_session_basics_list = [
    {'dir': session_info.dir}
    | helper.get_session_basics(
        utils.load_stitched_processed_session(stitched_folder, session_info)
    )
    for _, session_info in stitched_session_log.iterrows()
]
stitched_session_basics = pd.DataFrame(stitched_session_basics_list)
# Join the basics back onto the log by 'dir' (inner join, the pd.merge default).
stitched_session_log = pd.merge(stitched_session_log, stitched_session_basics, on='dir')

In [61]:
def assign_session_numbers(group):
    """Return a copy of `group` sorted by mouse and date with a 0-based 'session' column.

    Intended for use with ``groupby('mouse').apply``. The input frame is left
    untouched: sorting in place and adding a column on the object handed out
    by ``apply`` mutates the caller's data and is discouraged by pandas.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows to number; must contain 'mouse' and 'date' columns.

    Returns
    -------
    pandas.DataFrame
        New frame ordered by ['mouse', 'date'], with 'session' counting rows
        0, 1, 2, ... in that order.
    """
    ordered = group.sort_values(by=['mouse', 'date'])
    return ordered.assign(session=list(range(len(ordered))))

In [64]:
# Number each mouse's stitched sessions, then write the finished log to stitched_folder.
stitched_log_by_mouse = stitched_session_log.groupby('mouse', group_keys=False)
stitched_session_log = stitched_log_by_mouse.apply(assign_session_numbers)
utils.save_as_csv(
    stitched_session_log,
    stitched_folder,
    'stitched_training_session_log.csv',
)

## Stitch all Sessions from the same day

load stitched session log

In [11]:
# Reload the per-mouse stitched log from 'stitched_training_session_log.csv'
# in stitched_folder (the file written at the end of the previous section).
stitched_session_log = utils.load_session_log(stitched_folder, 'stitched_training_session_log.csv')

generate stitched all mice log with columns of date, mouse, dir, filename

In [20]:
# Derive the all-mice log from the per-mouse stitched log. Its schema is set
# by helper.generate_stitched_all_mice_session_log (not visible here); the
# cells below use its 'date' column.
stitched_all_mice_session_log = helper.generate_stitched_all_mice_session_log(stitched_session_log)

### Stitch!

creates empty directories, one per date, in the stitched all-mice folder

In [23]:
# Make sure there is one directory per logged date under stitched_all_mice_folder.
for session_date in stitched_all_mice_session_log['date']:
    date_dir = os.path.join(stitched_all_mice_folder, session_date)
    if os.path.exists(date_dir):
        continue
    os.makedirs(date_dir)

loop through each day to stitch processed_session and all_trials_analyzed from all mice

In [25]:
# Group the per-mouse stitched log by date; each group holds the rows of all
# mice logged on that date. Note this rebinds log_by_date, which an earlier
# cell built from training_session_log.
log_by_date = stitched_session_log.groupby('date')

In [38]:
# For each date, fold every mouse's stitched session from that day into one
# combined session and one combined all_trials table.
for _, day_log in log_by_date:
    # The first row of the day seeds the result and names the output files.
    session_info = day_log.iloc[0]
    stitched_session = utils.load_stitched_processed_session(stitched_folder, session_info)
    stitched_all_trials = utils.load_stitched_all_trials_analyzed(stitched_folder, session_info)

    # Append the remaining rows of the day, in log order.
    for pos in range(1, len(day_log)):
        later_info = day_log.iloc[pos]
        later_session = utils.load_stitched_processed_session(stitched_folder, later_info)
        later_all_trials = utils.load_stitched_all_trials_analyzed(stitched_folder, later_info)
        stitched_session = helper.stitch_sessions(stitched_session, later_session)
        stitched_all_trials = helper.stitch_all_trials(stitched_all_trials, later_all_trials)

    session_out_path = utils.generate_stitched_all_mice_processed_session_path(stitched_all_mice_folder, session_info)
    all_trials_out_path = utils.generate_stitched_all_mice_all_trials_analyzed_path(stitched_all_mice_folder, session_info)
    stitched_session.to_csv(session_out_path)
    stitched_all_trials.to_csv(all_trials_out_path)

### Add info to stitched session log and save it

In [41]:
# Collect one record of basic info per all-mice stitched session, tagged
# with its 'date' so it can be joined back onto the all-mice log.
stitched_session_basics_list = []
for _, session_info in stitched_all_mice_session_log.iterrows():
    session = utils.load_stitched_all_mice_processed_session(stitched_all_mice_folder, session_info)
    session_basics = {'date': session_info.date} | helper.get_session_basics(session)
    stitched_session_basics_list.append(session_basics)
stitched_session_basics = pd.DataFrame(stitched_session_basics_list)
# Store the merge result in the all-mice log itself. Assigning it to
# stitched_session_log overwrote the per-mouse log and left the all-mice log
# (the one that is saved) without the basics columns.
stitched_all_mice_session_log = pd.merge(stitched_all_mice_session_log, stitched_session_basics, on='date')

In [43]:
# Number the rows 0..total_days-1 in the log's current row order and store
# the numbering in a 'days' column.
# NOTE(review): this relies on the log already being in date order -- confirm.
total_days = len(stitched_all_mice_session_log)
stitched_all_mice_session_log['days'] = list(range(total_days))

In [45]:
# Write the all-mice log into stitched_all_mice_folder as
# 'stitched_all_mice_training_session_log.csv'.
utils.save_as_csv(stitched_all_mice_session_log, stitched_all_mice_folder, 'stitched_all_mice_training_session_log.csv')