In [30]:
import os
import json

import session_processing_helper as helper
import utils_c4 as utils

import pandas as pd
import statistics

In [75]:
# Cohort and cleaning stage to process; together they select the data folder.
cohort = 'cohort_4_v2'
to_analyze = 'full_clean'
# Root of the behavior data. Set the BEHAVIOR_DATA_DIR environment variable to
# run on another machine; the default is the original hard-coded location.
data_dir = os.environ.get('BEHAVIOR_DATA_DIR',
                          '/Users/rebekahzhang/Documents/shuler_lab/behavior_data')
data_folder = os.path.join(data_dir, cohort, to_analyze)
print(data_folder)

/Users/rebekahzhang/Documents/shuler_lab/behavior_data/cohort_4_v2/full_clean


In [76]:
# Mice keyed by experiment code: 's' (short) and 'l' (long).
mouse_dict = {
    's': ['RZ026', 'RZ027', 'RZ030', 'RZ031'],
    'l': ['RZ028', 'RZ029', 'RZ032', 'RZ033'],
}
# Flat list of every mouse, in group order.
mouse_list = []
for group_mice in mouse_dict.values():
    mouse_list.extend(group_mice)

# Generate all session logs

generate session log using meta data from each session and add columns of basic info to each session

In [77]:
def generate_sessions_all(data_folder):
    """Build the log of all sessions from the ``meta_*.json`` files under a folder.

    Walks ``data_folder`` recursively, loads every file whose name starts with
    ``meta_`` and ends with ``.json``, and stacks them into one DataFrame with
    one row per metadata file.

    Parameters
    ----------
    data_folder : str
        Root directory to search.

    Returns
    -------
    pandas.DataFrame
        One row per metadata file, in ``os.walk`` order. When at least one
        file was found, ``date`` is converted to datetime, ``dir`` is added as
        ``<YYYY-MM-DD>_<time>_<mouse>``, and ``exp`` values ``exp1_short`` /
        ``exp1_long`` are recoded to ``s`` / ``l``. Empty when nothing matches.
    """
    # Collect plain dicts and build the frame once: concatenating onto a
    # growing DataFrame re-copies every earlier row on each iteration.
    records = []
    for root, _, files in os.walk(data_folder):
        for file in files:
            if file.startswith("meta_") and file.endswith(".json"):
                with open(os.path.join(root, file)) as f:
                    records.append(json.load(f))

    sessions_all = pd.DataFrame(records)

    if not sessions_all.empty:
        sessions_all['date'] = pd.to_datetime(sessions_all['date'])
        sessions_all['dir'] = sessions_all['date'].dt.strftime('%Y-%m-%d') + '_' + sessions_all['time'] + '_' + sessions_all['mouse']
        sessions_all['exp'] = sessions_all['exp'].replace({'exp1_short': 's', 'exp1_long': 'l'})

    return sessions_all

In [92]:
# Build the all-sessions log from every meta_*.json file found under data_folder.
sessions_all = generate_sessions_all(data_folder)

In [84]:
# Preview the first rows of the all-sessions log.
sessions_all.head()

Unnamed: 0,mouse,date,time,exp,training,rig,pump_ul_per_turn,total_trial,total_reward,avg_tw,dir
0,RZ033,2023-12-18,13-54-44,l,regular,rig3,0.059,171,369.5,2.44,2023-12-18_13-54-44_RZ033
1,RZ027,2023-12-15,11-27-18,s,regular,rig3,0.059,266,868.3,4.03,2023-12-15_11-27-18_RZ027
2,RZ030,2023-12-17,17-09-56,s,regular,rig2,0.066,23,37.8,4.0,2023-12-17_17-09-56_RZ030
3,RZ033,2023-12-19,12-53-14,l,regular,rig3,0.059,220,315.0,1.79,2023-12-19_12-53-14_RZ033
4,RZ031,2023-12-19,12-10-22,s,regular,rig3,0.059,451,272.1,1.31,2023-12-19_12-10-22_RZ031


In [93]:
# Compute basic stats for each session from its events file and merge them
# into the log. The merge is an inner join on 'dir', so sessions whose events
# could not be read or summarised are reported below and dropped from the log.
session_basics_list = []
for _, session_info in sessions_all.iterrows():
    try:
        session = pd.read_csv(utils.generate_events_path(data_folder, session_info))
        session_basics = {'dir': session_info.dir} | helper.get_session_basics(session)
        session_basics_list.append(session_basics)
    except Exception as error:
        # Catch Exception, not everything, so Ctrl-C still stops the loop, and
        # report the cause next to the session name instead of hiding it.
        print(f"{session_info.dir}: {type(error).__name__}: {error}")
session_basics_df = pd.DataFrame(session_basics_list)
sessions_all = pd.merge(sessions_all, session_basics_df, on='dir')

focus only on regular training sessions

In [94]:
# Keep only regular-training sessions, ordered by session directory name.
# reset_index() keeps the previous row labels as an 'index' column.
is_regular = sessions_all['training'] == 'regular'
sessions_training = sessions_all[is_regular].sort_values(by='dir').reset_index()

## Examine quality of sessions
doesn't need to run when data folder is cleaned

### Check for short sessions

In [95]:
# Flag sessions with fewer than 50 trials.
has_few_trials = sessions_training['num_trials'] < 50
short_session = sessions_training.loc[has_few_trials]
if short_session.empty:
    print('no short sessions!')
else:
    display(short_session)

Unnamed: 0,index,mouse,date,time,exp,training,rig,pump_ul_per_turn,total_trial,total_reward,avg_tw,dir,num_blocks,num_trials,rewards,session_time,proper_end
28,2,RZ030,2023-12-17,17-09-56,s,regular,rig2,0.066,23,37.8,4.0,2023-12-17_17-09-56_RZ030,1.0,24.0,37.8,239.64,True
30,41,RZ032,2023-12-17,17-23-19,l,regular,rig3,0.059,35,68.7,3.88,2023-12-17_17-23-19_RZ032,1.0,36.0,68.7,535.16,True
31,16,RZ033,2023-12-17,17-37-24,l,regular,rig3,0.059,28,39.7,3.04,2023-12-17_17-37-24_RZ033,1.0,29.0,39.7,421.21,True
36,24,RZ030,2023-12-18,13-21-02,s,regular,rig2,0.066,28,79.1,4.11,2023-12-18_13-21-02_RZ030,1.0,29.0,79.1,269.42,True


### Check if mice ran correct experiments for each session

In [96]:
def _ran_assigned_exp(row):
    """Return True when the row's mouse is listed under the row's exp code in mouse_dict."""
    return row['mouse'] in mouse_dict.get(row['exp'], [])


# Mark each session, then show any session run under the wrong experiment.
sessions_training['correct_exp'] = sessions_training.apply(_ran_assigned_exp, axis=1)
incorrect_exp_df = sessions_training[~sessions_training['correct_exp']]
if incorrect_exp_df.empty:
    print('all correct!')
else:
    display(incorrect_exp_df)

all correct!


### Check for session number mismatch for each day

In [97]:
# Group the training log by date; the checks below iterate over these per-day groups.
log_by_date = sessions_training.groupby('date')

In [98]:
# Expected number of rounds for each date: the most common per-mouse
# session count on that day.
rounds_dict = {}
for date, data in log_by_date:
    sessions_per_mouse = [len(data.loc[data['mouse'] == mouse]) for mouse in mouse_list]
    rounds_dict[date] = statistics.mode(sessions_per_mouse)

check for missing sessions, copy and paste from previous day to pad

In [91]:
# Report every mouse with fewer sessions on a date than that date's expected rounds.
for date, data in log_by_date:
    expected_rounds = rounds_dict[date]
    sessions_per_mouse = data['mouse'].value_counts()
    for mouse in mouse_list:
        if sessions_per_mouse.get(mouse, 0) < expected_rounds:
            print(f"on {date}, {mouse} has missing sessions")

on 2023-12-14 00:00:00, RZ029 has missing sessions


check for sessions that need patching (two sessions to be combined into one)

In [100]:
# Report every mouse with more sessions on a date than that date's expected rounds.
for date, data in log_by_date:
    expected_rounds = rounds_dict[date]
    sessions_per_mouse = data['mouse'].value_counts()
    for mouse in mouse_list:
        if sessions_per_mouse.get(mouse, 0) > expected_rounds:
            print(f"on {date}, {mouse} has too many sessions")

In [102]:
# day = log_by_date.get_group('2023-11-15')
# display(day)

In [104]:
# sessions_to_stitch = day[day['mouse'] == 'RZ033']

In [105]:
# sessions_to_stitch

In [None]:
# session_1 = utils.load_data(utils.generate_events_path(data_folder, sessions_to_stitch.iloc[0]))
# session_2 = utils.load_data(utils.generate_events_path(data_folder, sessions_to_stitch.iloc[1]))
# stitched_session = helper.stitch_sessions(session_1, session_2)

In [None]:
# session_2

In [None]:
# stitched_session.to_csv(utils.generate_events_path(data_folder, sessions_to_stitch.iloc[0]))

make a copy of the cleaned data before proceeding!

## Add training session number to training log

In [106]:
# Apply helper.assign_session_numbers to each mouse's sessions separately.
# Presumably this adds the per-mouse 'session' index seen in the saved log — confirm in helper.
sessions_training = sessions_training.groupby('mouse', group_keys=False).apply(helper.assign_session_numbers)

## Saves all sessions log and training session log

In [107]:
# Write both logs into the data folder.
logs_to_save = (
    (sessions_all, 'sessions_all.csv'),
    (sessions_training, 'sessions_training.csv'),
)
for log_df, log_filename in logs_to_save:
    utils.save_as_csv(df=log_df, folder=data_folder, filename=log_filename)

# Process raw session and generate all_trials

load session log

In [108]:
# Reload the training log from the CSV saved above so this section can run on its own.
sessions_training = utils.load_data(os.path.join(data_folder, 'sessions_training.csv'))

### Generate all trials, align trial number, trial state, and trial time for raw session

In [118]:
# For every training session, build the trials table and the processed events
# table (trial number, trial state and trial time aligned) and write both to
# disk. Sessions whose two output files already exist are skipped.
for _, session_info in sessions_training.iterrows():
    try:
        events_processed_path = utils.generate_events_processed_path(data_folder, session_info)
        trials_path = utils.generate_trials_path(data_folder, session_info)
        if os.path.isfile(events_processed_path) and os.path.isfile(trials_path):
            continue

        events = pd.read_csv(utils.generate_events_path(data_folder, session_info))
        trials = helper.generate_trials(session_info, events)

        # align trial number
        events = helper.align_trial_number(events, trials)
        events = utils.trim_session(session_info, events)

        # align trial state
        events = events.groupby('session_trial_num', group_keys=False).apply(helper.align_trial_states)
        # add trial_time
        events = events.groupby('session_trial_num', group_keys=False).apply(helper.add_trial_time)

        events.to_csv(events_processed_path)
        trials.to_csv(trials_path)
    except Exception as error:
        # Catch Exception, not everything, so Ctrl-C still stops the loop, and
        # show why the session failed alongside its metadata.
        print(f"failed to process session: {type(error).__name__}: {error}")
        display(session_info)

### Adding analyzed trial data to all trials df

In [119]:
# Merge per-trial analysis into each session's trials table and save it,
# skipping sessions whose analysed-trials file already exists.
for _, session_info in sessions_training.iterrows():
    analyzed_path = utils.generate_trials_analyzed_path(data_folder, session_info)
    if not os.path.isfile(analyzed_path):
        events_processed = utils.load_data(utils.generate_events_processed_path(data_folder, session_info))
        trials = utils.load_data(utils.generate_trials_path(data_folder, session_info))
        trials_data = helper.get_trial_data_df(events_processed.groupby('session_trial_num'))
        pd.merge(trials, trials_data, on='session_trial_num').to_csv(analyzed_path)

# Stitching of padded sessions

In [14]:
# Output folder for the stitched sessions, next to the cleaned data folder.
stitched_folder = os.path.join(data_dir, cohort, 'padded_stitched')

load session log

In [15]:
# Reload the training log saved earlier in the notebook.
sessions_training = utils.load_data(os.path.join(data_folder, 'sessions_training.csv'))

generate the stitched all-mice log with columns of date, mouse, dir, filename

In [17]:
# Display the training log that was just loaded.
sessions_training

Unnamed: 0,index,mouse,date,time,exp,training,rig,dir,num_blocks,num_trials,rewards,session_time,proper_end,correct_exp,session
0,74,RZ026,2023-10-31,11-39-46,s,regular,rig2,2023-10-31_11-39-46_RZ026,1.0,487.0,187.7,2403.29,True,True,0
1,145,RZ028,2023-10-31,12-22-14,l,regular,rig2,2023-10-31_12-22-14_RZ028,1.0,353.0,89.9,2405.82,True,True,0
2,147,RZ029,2023-10-31,12-25-13,l,regular,rig3,2023-10-31_12-25-13_RZ029,1.0,354.0,73.7,2402.37,True,True,0
3,247,RZ030,2023-10-31,13-12-44,s,regular,rig2,2023-10-31_13-12-44_RZ030,1.0,438.0,533.1,2405.46,True,True,0
4,198,RZ031,2023-10-31,13-14-02,s,regular,rig3,2023-10-31_13-14-02_RZ031,1.0,468.0,259.5,2405.59,True,True,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
227,81,RZ029,2023-12-07,11-15-38,l,regular,rig2,2023-12-07_11-15-38_RZ029,1.0,298.0,571.7,2404.07,True,True,28
228,200,RZ032,2023-12-07,14-23-24,l,regular,rig3,2023-12-07_14-23-24_RZ032,1.0,306.0,432.2,2401.58,True,True,28
229,211,RZ033,2023-12-07,14-24-21,l,regular,rig2,2023-12-07_14-24-21_RZ033,1.0,304.0,493.5,2404.83,True,True,28
230,189,RZ030,2023-12-07,15-12-34,s,regular,rig3,2023-12-07_15-12-34_RZ030,1.0,367.0,861.4,2404.54,True,True,28


In [22]:
def generate_sessions_training_stitched(sessions_training):
    """Return one row per session number with its date and training type.

    Keeps only the ``date``, ``training`` and ``session`` columns and, for each
    ``session`` value, the first row in which it appears. The input frame is
    left untouched.
    """
    kept_columns = ['date', 'training', 'session']
    return sessions_training.loc[:, kept_columns].drop_duplicates(subset='session', keep='first')

In [23]:
# One row per session number, taken from the first matching row of the training log.
sessions_training_stitched = generate_sessions_training_stitched(sessions_training)

creates empty directories in the stitched folder

In [25]:
# Display the stitched-session log built above.
sessions_training_stitched

Unnamed: 0,date,training,session
0,2023-10-31,regular,0
8,2023-11-01,regular,1
16,2023-11-02,regular,2
24,2023-11-03,regular,3
32,2023-11-03,regular,4
40,2023-11-04,regular,5
48,2023-11-05,regular,6
56,2023-11-06,regular,7
64,2023-11-07,regular,8
72,2023-11-08,regular,9


In [28]:
# Create one output directory per stitched session, named <date>_<session>.
for _, session_info in sessions_training_stitched.iterrows():
    new_dir = os.path.join(stitched_folder, f"{session_info.date}_{session_info.session}")
    # exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
    os.makedirs(new_dir, exist_ok=True)

In [41]:
def generate_events_processed_stitched_path(data_folder, session_info):
    """Return the path of the stitched processed-events CSV for a session.

    The file lives in ``<data_folder>/<date>_<session>/`` and is named
    ``events_processed_stitched_<date>_<session>.csv``, where ``date`` and
    ``session`` are read from ``session_info``.
    """
    session_tag = f"{session_info.date}_{session_info.session}"
    return os.path.join(data_folder, session_tag, f'events_processed_stitched_{session_tag}.csv')

In [42]:
def generate_trials_analyzed_stitched_path(data_folder, session_info):
    """Return the path of the stitched analysed-trials CSV for a session.

    The file lives in ``<data_folder>/<date>_<session>/`` and is named
    ``trials_analyzed_stitched_<date>_<session>.csv``, where ``date`` and
    ``session`` are read from ``session_info``.
    """
    session_tag = f"{session_info.date}_{session_info.session}"
    return os.path.join(data_folder, session_tag, f'trials_analyzed_stitched_{session_tag}.csv')

In [34]:
# Group the training log by session number; each group holds the rows sharing that number.
log_by_session = sessions_training.groupby('session')

In [51]:
# For each session number, stitch the processed events and analysed trials of
# every row in the group, in row order, onto the first row's data, then write
# the results under the stitched folder using the first row's date/session.
for _, session_rows in log_by_session:
    first_info = session_rows.iloc[0]
    events_stitched = utils.load_data(utils.generate_events_processed_path(data_folder, first_info))
    trials_stitched = utils.load_data(utils.generate_trials_analyzed_path(data_folder, first_info))

    remaining_infos = [session_rows.iloc[k] for k in range(1, len(session_rows))]
    for next_info in remaining_infos:
        next_events = utils.load_data(utils.generate_events_processed_path(data_folder, next_info))
        next_trials = utils.load_data(utils.generate_trials_analyzed_path(data_folder, next_info))
        events_stitched = helper.stitch_sessions(events_stitched, next_events)
        trials_stitched = helper.stitch_all_trials(trials_stitched, next_trials)

    events_stitched.to_csv(generate_events_processed_stitched_path(stitched_folder, first_info))
    trials_stitched.to_csv(generate_trials_analyzed_stitched_path(stitched_folder, first_info))

In [52]:
# Summarise each stitched events file and attach the summary to the stitched
# log, matching rows on the session number.
stitched_session_basics_list = []
for _, stitched_info in sessions_training_stitched.iterrows():
    stitched_events = utils.load_data(generate_events_processed_stitched_path(stitched_folder, stitched_info))
    basics = helper.get_session_basics(stitched_events)
    stitched_session_basics_list.append({'session': stitched_info['session'], **basics})
stitched_session_basics = pd.DataFrame(stitched_session_basics_list)
sessions_training_stitched = sessions_training_stitched.merge(stitched_session_basics, on='session')

  df = pd.read_csv(path, index_col=0)
  df = pd.read_csv(path, index_col=0)
  df = pd.read_csv(path, index_col=0)
  df = pd.read_csv(path, index_col=0)


In [53]:
# Save the stitched-session log, now including the per-session basics, into the stitched folder.
utils.save_as_csv(sessions_training_stitched, stitched_folder, 'sessions_training_stitched.csv')

In [54]:
# Display the stitched-session log with the merged basics columns.
sessions_training_stitched

Unnamed: 0,date,training,session,num_blocks,num_trials,rewards,session_time,proper_end
0,2023-10-31,regular,0,8.0,3225.0,1911.6,62958.68,True
1,2023-11-01,regular,1,8.0,2942.0,4150.2,92528.09,True
2,2023-11-02,regular,2,9.0,3033.0,4887.6,93254.17,True
3,2023-11-03,regular,3,8.0,2619.0,4202.3,48806.09,True
4,2023-11-03,regular,4,8.0,1813.0,4676.3,44786.86,True
5,2023-11-04,regular,5,8.0,2508.0,4368.8,426901.11,True
6,2023-11-05,regular,6,8.0,2772.0,4073.6,54142.84,True
7,2023-11-06,regular,7,8.0,2868.0,3891.9,51936.3,True
8,2023-11-07,regular,8,8.0,2701.0,3843.5,67753.79,True
9,2023-11-08,regular,9,8.0,2677.0,3392.3,49173.7,True


# Combine sessions

In [None]:
# Output folders for this section. Note that stitched_folder is rebound here
# and no longer points at 'padded_stitched'.
cohort_dir = os.path.join(data_dir, cohort)
stitched_folder = os.path.join(cohort_dir, 'full_clean_stitched')
stitched_all_mice_folder = os.path.join(cohort_dir, 'full_clean_stitched_all_mice')

## Stitch sessions from the same mouse on the same day

load session log and generate lists for looping

In [None]:
# Load the training session log from the data folder.
# NOTE(review): the log saved earlier in this notebook is named
# 'sessions_training.csv', not 'training_sessions.csv' — confirm this file exists.
training_session_log = utils.load_session_log(data_folder, 'training_sessions.csv')

makes the stitched session log with columns of date, mouse, dir, filename

In [None]:
# Build the stitched session log from the training log via the helper.
# The cells below read its 'dir' column; other columns are defined in helper — TODO confirm.
stitched_session_log = helper.generate_stitched_session_log(training_session_log)

### Stitch!

creates empty directories in the stitched folder

In [None]:
# Create one output directory per entry in the stitched log's 'dir' column.
for d in stitched_session_log.dir:
    new_dir = os.path.join(stitched_folder, d)
    # exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
    os.makedirs(new_dir, exist_ok=True)

loop through each day and each mouse to stitch processed_session and all_trials_analyzed

In [None]:
# Order sessions by directory name, then group them by date for the stitching loop.
training_session_log = training_session_log.sort_values('dir')
log_by_date = training_session_log.groupby(by='date')

In [None]:
# For each date and each mouse, stitch all of that mouse's sessions from the
# day, in row order, onto the first one, then write the stitched events and
# analysed trials under the stitched folder using the first session's info.
for _, day_log in log_by_date:
    for _, mouse_log in day_log.groupby('mouse'):
        session_info = mouse_log.iloc[0]
        stitched_session = utils.load_processed_session(data_folder, session_info)
        stitched_all_trials = utils.load_all_trials_analyzed(data_folder, session_info)

        later_infos = [mouse_log.iloc[k] for k in range(1, len(mouse_log))]
        for later_info in later_infos:
            later_session = utils.load_processed_session(data_folder, later_info)
            later_trials = utils.load_all_trials_analyzed(data_folder, later_info)

            stitched_session = helper.stitch_sessions(stitched_session, later_session)
            stitched_all_trials = helper.stitch_all_trials(stitched_all_trials, later_trials)

        stitched_session.to_csv(utils.generate_stitched_processed_session_path(stitched_folder, session_info))
        stitched_all_trials.to_csv(utils.generate_stitched_all_trials_path(stitched_folder, session_info))

### Add info to stitched session log and save it

In [None]:
# Summarise each stitched session and attach the summary to the stitched log,
# matching rows on 'dir'.
stitched_session_basics_list = []
for _, stitched_info in stitched_session_log.iterrows():
    stitched_events = utils.load_stitched_processed_session(stitched_folder, stitched_info)
    basics = helper.get_session_basics(stitched_events)
    stitched_session_basics_list.append({'dir': stitched_info.dir, **basics})
stitched_session_basics = pd.DataFrame(stitched_session_basics_list)
stitched_session_log = stitched_session_log.merge(stitched_session_basics, on='dir')

In [None]:
def assign_session_numbers(group):
    """Number the rows of one group consecutively in (mouse, date) order.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows to number; must have ``mouse`` and ``date`` columns. It is not
        modified.

    Returns
    -------
    pandas.DataFrame
        A copy of ``group`` sorted by ``mouse`` then ``date``, with a
        ``session`` column holding 0, 1, 2, ... in that order. Original index
        labels are kept.
    """
    # Sort into a new frame instead of in place: the group handed over by
    # groupby().apply() may be a view of the caller's data, which an in-place
    # sort would mutate (and pandas may warn about).
    ordered = group.sort_values(by=['mouse', 'date'])
    ordered['session'] = list(range(len(ordered)))
    return ordered

In [None]:
# Number each mouse's stitched sessions with the notebook-level assign_session_numbers
# (applied per mouse), then save the log into the stitched folder.
stitched_session_log = stitched_session_log.groupby('mouse', group_keys=False).apply(assign_session_numbers)
utils.save_as_csv(stitched_session_log, stitched_folder, 'stitched_training_session_log.csv')

## Stitch all Sessions from the same day

load stitched session log

In [None]:
# NOTE(review): despite the variable name, this loads 'sessions_training.csv' from
# data_folder (the per-session training log), not a log from the stitched folder — confirm intent.
stitched_session_log = utils.load_data(os.path.join(data_folder, 'sessions_training.csv'))

generate the stitched all-mice log with columns of date, mouse, dir, filename

In [None]:
# Build the all-mice log via the helper. The cells below read its 'date'
# column; presumably it has one row per date — TODO confirm in helper.
stitched_all_mice_session_log = helper.generate_stitched_all_mice_session_log(stitched_session_log)

### Stitch!

creates empty directories in the stitched folder

In [None]:
# Create one output directory per date in the all-mice folder.
for date in stitched_all_mice_session_log.date:
    new_dir = os.path.join(stitched_all_mice_folder, date)
    # exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
    os.makedirs(new_dir, exist_ok=True)

loop through each day to stitch processed_session and all_trials_analyzed from all mice

In [None]:
def generate_trials_analyzed_stitched_path(data_folder, session_info):
    """Return the path of the per-date stitched analysed-trials CSV.

    The file lives in ``<data_folder>/<date>/`` and is named
    ``trials_analyzed_stitched_<date>.csv``, where ``date`` is read from
    ``session_info``.

    Running this cell replaces any function of the same name defined earlier
    in the kernel.
    """
    date = session_info.date
    return os.path.join(data_folder, date, f'trials_analyzed_stitched_{date}.csv')

In [None]:
# Group the loaded log by date; each group holds every session row recorded that day.
log_by_date = stitched_session_log.groupby('date')

In [None]:
# Stitch every session row of a date into one events table and one
# analysed-trials table, then write both under the all-mice folder.
# NOTE(review): the events file is written with generate_events_processed_stitched_path,
# which targets a "<date>_<session>" subdirectory, while the directory-creation
# cell above only creates "<date>" directories — confirm the target directory exists.
for d, date in log_by_date:
    # `date` is the group's DataFrame; num_mice is its row count for that day.
    num_mice = len(date)
    # The first row seeds the stitched tables and names the output files.
    session_info = date.iloc[0]
    stitched_session = utils.load_data(utils.generate_events_processed_path(data_folder, session_info))
    stitched_all_trials = utils.load_data(utils.generate_trials_analyzed_path(data_folder, session_info))

    # Append the remaining rows' data in row order.
    for i in range(1, num_mice):
        session_2 = utils.load_data(utils.generate_events_processed_path(data_folder, date.iloc[i]))
        all_trials_2 = utils.load_data(utils.generate_trials_analyzed_path(data_folder, date.iloc[i]))
        stitched_session = helper.stitch_sessions(stitched_session, session_2)
        stitched_all_trials = helper.stitch_all_trials(stitched_all_trials, all_trials_2)
    
    stitched_session.to_csv(generate_events_processed_stitched_path(stitched_all_mice_folder, session_info))
    stitched_all_trials.to_csv(generate_trials_analyzed_stitched_path(stitched_all_mice_folder, session_info))

In [None]:
# Display the stitched trials left over from the last iteration of the loop above.
stitched_all_trials

### Add info to stitched session log and save it

In [None]:
# Summarise each date's stitched events file and join the summary onto the
# all-mice log by date. The result is stored in stitched_session_log.
stitched_session_basics_list = []
for _, day_info in stitched_all_mice_session_log.iterrows():
    day_events = utils.load_data(generate_events_processed_stitched_path(stitched_all_mice_folder, day_info))
    basics = helper.get_session_basics(day_events)
    stitched_session_basics_list.append({'date': day_info.date, **basics})
stitched_session_basics = pd.DataFrame(stitched_session_basics_list)
stitched_session_log = stitched_all_mice_session_log.merge(stitched_session_basics, on='date')

In [None]:
# Number the days consecutively, in log order.
total_days = len(stitched_all_mice_session_log)
stitched_all_mice_session_log['days'] = list(range(total_days))
# stitched_session_log is the merged copy made in the previous cell and the
# frame that gets saved below; it was created before 'days' existed, so add
# the column there too or the saved CSV would not contain it.
stitched_session_log['days'] = list(range(len(stitched_session_log)))

In [None]:
# Save the merged per-date log (all-mice log plus session basics) into the all-mice folder.
utils.save_as_csv(stitched_session_log, stitched_all_mice_folder, 'sessions_training_stitched.csv')