In [27]:
import os
import json
import shutil

import session_processing_helper_c5 as helper
import utils_c5 as utils

import pandas as pd
import statistics

Enter the cohort name and the name of the data folder to analyze.

In [28]:
# cohort to process, and the sub-folder of that cohort holding its sessions
cohort, to_analyze = 'cohort_5', 'full_clean'

In [29]:
# Root of all behavior data. Set the BEHAVIOR_DATA_DIR environment variable to
# run the notebook on another machine; the default is the original local path.
data_dir = os.environ.get('BEHAVIOR_DATA_DIR', '/Users/rebekahzhang/data/behavior_data')
# folder holding the sessions of this cohort that will be analyzed
data_folder = os.path.join(data_dir, cohort, to_analyze)
print(data_folder)

/Users/rebekahzhang/data/behavior_data/cohort_5/full_clean


# The code cannot handle stitching if the first session didn't end properly


Ideally, session processing would happen as part of the Pi code, so that raw sessions would not need to be processed afterwards to align trial numbers and trial states. The Pi code should also record an exit code. Then, when the maximum number of missed trials is reached, those trials can be cut out, and trials after an improper end can be cut out as well.
<br>
The ideal workflow is:
<br>
 - set the data folder, initialize the path-generator class, and load all helper functions
 - examine the quality of sessions: delete test and short sessions
 - trim the beginning and end of the events
 - pad or stitch if there is a mismatch in the number of sessions
 - generate trials from the event info
 - analyze trials

# Generate sessions log

Generate the session log from each session's metadata, and add columns of basic info to each session.

In [30]:
def generate_sessions_all(data_folder):
    """Build the session log from every ``meta_*.json`` file under ``data_folder``.

    Each metadata file becomes one row. When at least one session is found,
    these columns are derived:

    - ``dir``: ``<date>_<time>_<mouse>``, the session's folder name;
    - ``exp``: the digit parsed from ``exp<digit>_(short|long)`` (replaces the
      original ``exp`` string; NaN when it does not match);
    - ``group``: ``'s'`` for short or ``'l'`` for long.

    Args:
        data_folder: root directory searched recursively for metadata files.

    Returns:
        DataFrame with one row per session, sorted by ``dir``; an empty
        DataFrame when no metadata file is found.
    """
    records = []
    for root, _, files in os.walk(data_folder):
        for file in files:
            if file.startswith("meta_") and file.endswith(".json"):
                with open(os.path.join(root, file)) as f:
                    records.append(json.load(f))

    # Build the frame once instead of concatenating row by row (quadratic).
    sessions_all = pd.DataFrame(records)
    if sessions_all.empty:
        # No 'dir' column exists to sort on; sorting here used to raise KeyError.
        return sessions_all

    sessions_all['dir'] = sessions_all['date'] + '_' + sessions_all['time'] + '_' + sessions_all['mouse']
    sessions_all[['exp', 'group']] = sessions_all['exp'].str.extract(r'exp(\d)_(short|long)')
    sessions_all['group'] = sessions_all['group'].map({'short': 's', 'long': 'l'})
    return sessions_all.sort_values('dir')

In [31]:
sessions_all = generate_sessions_all(data_folder)

## Examine session quality
Doesn't need to run once the data folder has been cleaned.
<br>
`sessions_all` needs to be regenerated after every cleaning step.

### Remove test sessions

In [32]:
def remove_sessions(sessions_to_remove, data_folder):
    """Delete the folder of every session listed in ``sessions_to_remove``.

    Args:
        sessions_to_remove: session log whose ``dir`` column names the session
            folders (relative to ``data_folder``) to delete.
        data_folder: directory that contains the session folders.

    Irreversible: each folder is removed recursively with ``shutil.rmtree``.
    """
    for session in sessions_to_remove.itertuples():
        target = os.path.join(data_folder, session.dir)
        shutil.rmtree(target)

In [33]:
# sessions recorded under the mouse name 'test' are deleted
sessions_test = sessions_all[sessions_all['mouse'] == 'test']
if sessions_test.empty:
    print("no test sessions")
else:
    remove_sessions(sessions_test, data_folder)

no test sessions


### Check for short sessions

In [34]:
# sessions with fewer than this many trials are flagged as short
min_trials = 50
sessions_short = sessions_all[sessions_all['total_trial'] < min_trials]
if sessions_short.empty:
    print('no short sessions!')
else:
    display(sessions_short)

no short sessions!


Remove the short sessions (this permanently deletes their folders).

In [35]:
# irreversibly deletes the folders of the sessions flagged above; no-op when none
if not sessions_short.empty:
    remove_sessions(sessions_short, data_folder)

### Regenerate sessions_all

In [36]:
# Rebuild the log from disk so it no longer lists the removed sessions.
sessions_all = generate_sessions_all(data_folder)

### Check for session number mismatch for each day

In [37]:
# Generate a dict of the number of rounds for each date: the most common number
# of sessions per mouse on that date (a mouse with no session that day counts as 0).
mouse_list = utils.generate_mouse_list(sessions_all)
sessions_by_date = sessions_all.groupby('date')
rounds_dict = {
    date: statistics.mode(
        [len(day_sessions.loc[day_sessions['mouse'] == mouse]) for mouse in mouse_list]
    )
    for date, day_sessions in sessions_by_date
}

Check for missing sessions; pad them by copying the session from the previous day.

In [38]:
# A mouse is missing sessions on a date when it has fewer sessions than that
# date's expected number of rounds.
missing = []
for date, day_sessions in sessions_by_date:
    expected_rounds = rounds_dict[date]
    for mouse in mouse_list:
        if len(day_sessions.loc[day_sessions['mouse'] == mouse]) < expected_rounds:
            missing.append((date, mouse))
            print(f"on {date}, {mouse} has missing sessions")
no_missing_sessions = not missing
if no_missing_sessions:
    print("no missing sessions!")

on 2024-01-24, RZ045 has missing sessions
on 2024-03-31, RZ034 has missing sessions
on 2024-03-31, RZ036 has missing sessions
on 2024-03-31, RZ037 has missing sessions
on 2024-03-31, RZ038 has missing sessions
on 2024-03-31, RZ039 has missing sessions


Pad any missing sessions.

In [39]:
# Kept for inspection: uncomment and set the date to view one day's sessions,
# sorted by mouse, when deciding which sessions and mice need to be handled.
# day = sessions_by_date.get_group('2024-04-02').sort_values('mouse')
# display(day)

Check for multiple sessions from the same mouse on the same day, decide whether stitching is needed, and stitch.

In [40]:
# A mouse needs stitching on a date when it has more sessions than the date's
# expected number of rounds (and more than one session).
days_to_stitch, mice_to_stitch = [], []
for date, day_sessions in sessions_by_date:
    expected_rounds = rounds_dict[date]
    for mouse in mouse_list:
        n_sessions = len(day_sessions.loc[day_sessions['mouse'] == mouse])
        if n_sessions > expected_rounds and n_sessions > 1:
            days_to_stitch.append(date)
            mice_to_stitch.append(mouse)
            print(f"on {date}, {mouse} has {n_sessions} sessions")
if not days_to_stitch:
    print("no sessions to stitch!")

on 2024-04-15, RZ038 has 2 sessions


In [41]:
# Stitch every extra session of a mouse on a day into its first session of that
# day, then delete the extra session folders. Nothing happens when there is
# nothing to stitch.
# NOTE(review): this reads/writes the events_processed files, while the
# trial-generation step below rebuilds events_processed from the raw events
# path; confirm the stitched result is not overwritten there.
if not days_to_stitch:
    print("no sessions to stitch!")
else:
    for d, m in zip(days_to_stitch, mice_to_stitch):
        day = sessions_by_date.get_group(d)
        # rows keep sessions_all's order, i.e. sorted by dir (date_time_mouse)
        sessions_to_stitch = day[day['mouse'] == m]
        session_paths = [utils.generate_events_processed_path(data_folder, info)
                         for _, info in sessions_to_stitch.iterrows()]

        if not all(os.path.exists(path) for path in session_paths):
            print("one of the sessions do not exist")
            continue

        # Fold every later session into the first one; previously only the
        # second session was stitched and any further ones were ignored.
        # TODO: switch to a stitch_events helper once helper.stitch_sessions is retired.
        stitched_session = pd.read_csv(session_paths[0])
        for path in session_paths[1:]:
            stitched_session = helper.stitch_sessions(stitched_session, pd.read_csv(path))

        stitched_session.to_csv(session_paths[0], index=False)
        for n, (_, extra) in enumerate(sessions_to_stitch.iloc[1:].iterrows(), start=2):
            shutil.rmtree(os.path.join(data_folder, extra.dir))
            print(f"{d} {m} session {n} deleted")

2024-04-15 RZ038 session 2 deleted


Make a copy of the cleaned data before proceeding, and regenerate the session log!

## Add additional info to session logs and save them

In [42]:
# Rebuild the log from disk after removing test/short sessions and stitching.
sessions_all = generate_sessions_all(data_folder)

In [43]:
# Add per-session summary values computed from each session's events file.
# Sessions whose events cannot be read or summarized are reported, and are then
# dropped by the inner merge below.
session_basics_list = []
for _, session_info in sessions_all.iterrows():
    try:
        events = pd.read_csv(utils.generate_events_path(data_folder, session_info))
        session_basics_list.append({'dir': session_info.dir} | helper.get_session_basics(events))
    except Exception as err:  # a bare except would also swallow KeyboardInterrupt
        print(f"{session_info.dir}: {err!r}")
session_basics_df = pd.DataFrame(session_basics_list)
sessions_all = pd.merge(sessions_all, session_basics_df, on='dir')
# drop these metadata columns (presumably superseded by the computed basics)
sessions_all = sessions_all.drop(['pump_ul_per_turn', 'total_trial', 'total_reward'], axis=1)

Add the number of days trained to the training sessions.

In [44]:
# keep only regular training sessions; reset_index() without drop=True keeps the
# old row labels as an 'index' column
is_regular = sessions_all['training'] == 'regular'
sessions_training = sessions_all[is_regular].reset_index()

In [45]:
# Number each mouse's training sessions via helper.assign_session_numbers
# (presumably the day-of-training count described above — see helper).
sessions_training = sessions_training.groupby('mouse', group_keys=False).apply(helper.assign_session_numbers)

### Saves all sessions log and training session log

In [46]:
for log, log_name in ((sessions_all, 'sessions_all.csv'),
                      (sessions_training, 'sessions_training.csv')):
    utils.save_as_csv(df=log, folder=data_folder, filename=log_name)

# Process raw session and generate all_trials

load session log

In [47]:
# Reload the training log saved above from data_folder.
sessions_training = utils.load_data(os.path.join(data_folder, 'sessions_training.csv'))

## Generate all trials, align trial number, trial state, and trial time for raw session

In [None]:
# Set to True to skip sessions whose events_processed and trials files already exist.
skip_existing = False

for _, session_info in sessions_training.iterrows():
    try:
        events_processed_path = utils.generate_events_processed_path(data_folder, session_info)
        trials_path = utils.generate_trials_path(data_folder, session_info)
        if skip_existing and os.path.isfile(events_processed_path) and os.path.isfile(trials_path):
            continue

        events = pd.read_csv(utils.generate_events_path(data_folder, session_info))
        trials = helper.generate_trials(session_info, events)

        # align trial number
        events = helper.align_trial_number(events, trials)

        events = utils.trim_session(session_info, events)

        # align trial state, then add trial_time, one trial at a time
        events = events.groupby('session_trial_num', group_keys=False).apply(helper.align_trial_states)
        events = events.groupby('session_trial_num', group_keys=False).apply(helper.add_trial_time)

        events.to_csv(events_processed_path)
        trials.to_csv(trials_path)
    except Exception as err:
        # report the failing session (and why), then continue with the others;
        # a bare except also hid the error and swallowed KeyboardInterrupt
        print(f"failed to process {session_info.dir}: {err!r}")
        display(session_info)

## Analyze trials

In [50]:
# Set to True to skip sessions whose trials_analyzed file already exists.
skip_existing = False

for _, session_info in sessions_training.iterrows():
    trials_analyzed_path = utils.generate_trials_analyzed_path(data_folder, session_info)
    if skip_existing and os.path.isfile(trials_analyzed_path):
        continue

    # per-trial data from the processed events, merged onto the trials table
    session_by_trial = utils.load_data(utils.generate_events_processed_path(data_folder, session_info)).groupby('session_trial_num')
    trials = utils.load_data(utils.generate_trials_path(data_folder, session_info))
    trials_data = helper.get_trial_data_df(session_by_trial)
    trials_analyzed = pd.merge(trials, trials_data, on='session_trial_num')
    trials_analyzed['group'] = session_info.group  # assigning trial type manually
    trials_analyzed.to_csv(trials_analyzed_path)

# Sort sessions by experiment
Copy the exp2 session folders into the exp2 folder, and likewise for exp3 (the folders are copied, not moved).

In [60]:
exp2_folder, exp3_folder = (os.path.join(data_dir, cohort, exp) for exp in ('exp2', 'exp3'))

load session log

In [61]:
# reload the logs saved earlier in this notebook
sessions_all, sessions_training = (
    utils.load_data(os.path.join(data_folder, log_name))
    for log_name in ('sessions_all.csv', 'sessions_training.csv')
)

In [62]:
def move_dir_to_new_folder(session_log, data_folder, target_folder):
    """Copy each session folder listed in ``session_log`` into ``target_folder``.

    Despite the name, folders are copied with ``shutil.copytree``, not moved;
    the source stays in ``data_folder``. Sessions already present in
    ``target_folder`` are skipped (not refreshed), and sessions missing from
    ``data_folder`` are reported.

    Args:
        session_log: log whose ``dir`` column names the session folders.
        data_folder: directory currently holding the session folders.
        target_folder: destination directory.
    """
    for session in session_log.itertuples():
        source = os.path.join(data_folder, session.dir)
        destination = os.path.join(target_folder, session.dir)
        if os.path.exists(destination):
            continue
        if not os.path.exists(source):
            print(f"didn't work for: {session.dir}")
            continue
        shutil.copytree(source, destination)

In [63]:
# Slice out the exp2 logs, copy their session folders from full_clean into
# exp2/full_clean, then save the sliced logs to that exp folder.
sessions_all_exp2 = sessions_all[sessions_all['exp'] == 2]
sessions_training_exp2 = sessions_training[sessions_training['exp'] == 2]
data_folder_exp2 = os.path.join(exp2_folder, 'full_clean')
move_dir_to_new_folder(sessions_all_exp2, data_folder, data_folder_exp2)
utils.save_as_csv(df=sessions_all_exp2, folder=data_folder_exp2, filename='sessions_all_exp2.csv')
utils.save_as_csv(df=sessions_training_exp2, folder=data_folder_exp2, filename='sessions_training_exp2.csv')

In [64]:
# Slice out the exp3 logs, copy their session folders from full_clean into
# exp3/full_clean, then save the sliced logs to that exp folder.
sessions_all_exp3 = sessions_all[sessions_all['exp'] == 3]
sessions_training_exp3 = sessions_training[sessions_training['exp'] == 3]
data_folder_exp3 = os.path.join(exp3_folder, 'full_clean')
move_dir_to_new_folder(sessions_all_exp3, data_folder, data_folder_exp3)
utils.save_as_csv(df=sessions_all_exp3, folder=data_folder_exp3, filename='sessions_all_exp3.csv')
utils.save_as_csv(df=sessions_training_exp3, folder=data_folder_exp3, filename='sessions_training_exp3.csv')

# Stitch sessions from the same day

This needs to be run separately for exp2 and exp3; adjust it for the next experiment.

In [65]:
# re-derive the per-experiment folders so this section does not depend on the previous one
exp2_folder, exp3_folder = (os.path.join(data_dir, cohort, exp) for exp in ('exp2', 'exp3'))
data_folder_exp2 = os.path.join(exp2_folder, 'full_clean')
data_folder_exp3 = os.path.join(exp3_folder, 'full_clean')

In [66]:
# output folders for the day-level stitched files
stitched_folder_exp2, stitched_folder_exp3 = (
    os.path.join(folder, "stitched") for folder in (exp2_folder, exp3_folder)
)

load session log

In [67]:
sessions_training_exp2, sessions_training_exp3 = (
    utils.load_data(os.path.join(folder, log_name))
    for folder, log_name in ((data_folder_exp2, 'sessions_training_exp2.csv'),
                             (data_folder_exp3, 'sessions_training_exp3.csv'))
)

Generate a session log with one entry per training day.

In [68]:
def generate_stitched_all_mice_session_log(session_log):
    """Collapse a per-session log into one row per date.

    Keeps the ``date``, ``training`` and ``exp`` columns of the first row seen
    for each date, in the log's original order and with its original labels.

    Args:
        session_log: per-session log with ``date``, ``training`` and ``exp`` columns.

    Returns:
        A new DataFrame; ``session_log`` is not modified.
    """
    first_of_date = ~session_log['date'].duplicated(keep='first')
    return session_log.loc[first_of_date, ['date', 'training', 'exp']].copy()

In [69]:
sessions_training_stitched_exp2, sessions_training_stitched_exp3 = map(
    generate_stitched_all_mice_session_log, (sessions_training_exp2, sessions_training_exp3)
)

Create an empty directory for each training day in the stitched folder.

In [70]:
# one folder per exp2 training day to hold its stitched files
for day_dir in (os.path.join(stitched_folder_exp2, d) for d in sessions_training_stitched_exp2['date']):
    if not os.path.exists(day_dir):
        os.makedirs(day_dir)

In [71]:
# one folder per exp3 training day to hold its stitched files
for day_dir in (os.path.join(stitched_folder_exp3, d) for d in sessions_training_stitched_exp3['date']):
    if not os.path.exists(day_dir):
        os.makedirs(day_dir)

Loop through each day to stitch the events_processed and trials_analyzed files of all mice.

In [72]:
# group each experiment's training sessions by day for stitching
sessions_by_date_exp2, sessions_by_date_exp3 = (
    log.groupby('date') for log in (sessions_training_exp2, sessions_training_exp3)
)

In [73]:
def generate_trials_analyzed_stitched_path(data_folder, session_info):
    """Return the path of one day's stitched trials_analyzed CSV.

    The path is ``<data_folder>/<date>/trials_analyzed_stitched_<date>.csv``.

    Args:
        data_folder: stitched-output folder of the experiment.
        session_info: row/record with a ``date`` attribute; a non-string date
            (e.g. a ``datetime.date``) is formatted with its string form.
    """
    # Coerce to str like the events_processed path helper does; os.path.join
    # raises TypeError on a non-str date.
    date = f"{session_info.date}"
    return os.path.join(data_folder, date, f'trials_analyzed_stitched_{date}.csv')

In [74]:
def generate_events_processed_stitched_path(data_folder, session_info):
    """Return the path of one day's stitched events_processed CSV.

    The path is ``<data_folder>/<date>/events_processed_stitched_<date>.csv``,
    with the date formatted as a string.

    Args:
        data_folder: stitched-output folder of the experiment.
        session_info: row/record with a ``date`` attribute.
    """
    day = f"{session_info.date}"
    file_name = f'events_processed_stitched_{day}.csv'
    return os.path.join(data_folder, day, file_name)

In [75]:
def stitch_events(events_1, events_2):
    """Append ``events_2`` after ``events_1`` as if it were one continuous session.

    ``session_time``, ``block_num`` and ``session_trial_num`` of ``events_2`` are
    shifted by the ``session_time``, ``num_blocks`` and ``num_trials`` values
    that ``helper.get_session_basics`` reports for ``events_1``.

    Args:
        events_1: events placed first.
        events_2: events placed second; not modified.

    Returns:
        The concatenated events with a fresh 0..n-1 index.
    """
    session_1_basics = helper.get_session_basics(events_1)
    # NOTE(review): stitch_trials offsets trial numbers by max()+1 while this
    # uses num_trials; confirm both give the same offset.

    # Shift a copy; the previous version wrote the offsets into the caller's events_2.
    shifted = events_2.assign(
        session_time=events_2['session_time'] + session_1_basics['session_time'],
        block_num=events_2['block_num'] + session_1_basics['num_blocks'],
        session_trial_num=events_2['session_trial_num'] + session_1_basics['num_trials'],
    )
    # ignore_index avoids duplicate row labels across the stitched parts
    return pd.concat([events_1, shifted], ignore_index=True)

In [76]:
def stitch_trials(trials_1, trials_2):
    """Append ``trials_2`` after ``trials_1`` as if it were one continuous session.

    ``session_trial_num`` and ``block_num`` of ``trials_2`` are shifted to start
    right after the largest values in ``trials_1``; ``start_time`` and
    ``end_time`` are shifted by the largest ``end_time`` of ``trials_1``.

    Args:
        trials_1: trials placed first.
        trials_2: trials placed second; not modified.

    Returns:
        The concatenated trials with a fresh 0..n-1 index.
    """
    trial_offset = trials_1['session_trial_num'].max() + 1
    block_offset = trials_1['block_num'].max() + 1
    time_offset = trials_1['end_time'].max()

    # Shift a copy; the previous version wrote the offsets into the caller's trials_2.
    shifted = trials_2.assign(
        session_trial_num=trials_2['session_trial_num'] + trial_offset,
        block_num=trials_2['block_num'] + block_offset,
        start_time=trials_2['start_time'] + time_offset,
        end_time=trials_2['end_time'] + time_offset,
    )
    # ignore_index avoids duplicate row labels across the stitched parts
    return pd.concat([trials_1, shifted], ignore_index=True)

In [77]:
# Run this between reruns to clear the stitched frames left over from a
# previous run of the stitching cells below.
events_stitched = None
trials_stitched = None

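**FIXME:** verify that the `mouse` column of the stitched events and trials comes from each mouse's own session row (not the first mouse of the day), and that it is set on the trials as well as on the events.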

In [None]:
# Stitch all mice of each exp2 training day into one events file and one trials
# file, labelling every row with the mouse it came from.
# NOTE(review): reads from the cohort-level data_folder rather than
# data_folder_exp2 (copies there are not refreshed if they already exist).
for _, day_sessions in sessions_by_date_exp2:
    session_info = day_sessions.iloc[0]  # first session of the day; also names the outputs
    events_stitched = utils.load_data(utils.generate_events_processed_path(data_folder, session_info))
    events_stitched['mouse'] = session_info['mouse']
    trials_stitched = utils.load_data(utils.generate_trials_analyzed_path(data_folder, session_info))
    trials_stitched['mouse'] = session_info['mouse']

    for i in range(1, len(day_sessions)):
        mouse_info = day_sessions.iloc[i]
        # Label with this session's own mouse. Previously the first mouse of the
        # day was used for every row, and trials_2 was never labelled.
        events_2 = utils.load_data(utils.generate_events_processed_path(data_folder, mouse_info))
        events_2['mouse'] = mouse_info['mouse']
        trials_2 = utils.load_data(utils.generate_trials_analyzed_path(data_folder, mouse_info))
        trials_2['mouse'] = mouse_info['mouse']
        events_stitched = stitch_events(events_stitched, events_2)
        trials_stitched = stitch_trials(trials_stitched, trials_2)

    events_stitched.to_csv(generate_events_processed_stitched_path(stitched_folder_exp2, session_info), index=False)
    trials_stitched.to_csv(generate_trials_analyzed_stitched_path(stitched_folder_exp2, session_info), index=False)

In [None]:
# Stitch all mice of each exp3 training day into one events file and one trials
# file, labelling every row with the mouse it came from.
# NOTE(review): reads from the cohort-level data_folder rather than
# data_folder_exp3 (copies there are not refreshed if they already exist).
for _, day_sessions in sessions_by_date_exp3:
    session_info = day_sessions.iloc[0]  # first session of the day; also names the outputs
    events_stitched = utils.load_data(utils.generate_events_processed_path(data_folder, session_info))
    events_stitched['mouse'] = session_info['mouse']
    trials_stitched = utils.load_data(utils.generate_trials_analyzed_path(data_folder, session_info))
    trials_stitched['mouse'] = session_info['mouse']

    for i in range(1, len(day_sessions)):
        mouse_info = day_sessions.iloc[i]
        # Label with this session's own mouse. Previously the first mouse of the
        # day was used for every row, and trials_2 was never labelled.
        events_2 = utils.load_data(utils.generate_events_processed_path(data_folder, mouse_info))
        events_2['mouse'] = mouse_info['mouse']
        trials_2 = utils.load_data(utils.generate_trials_analyzed_path(data_folder, mouse_info))
        trials_2['mouse'] = mouse_info['mouse']
        events_stitched = stitch_events(events_stitched, events_2)
        trials_stitched = stitch_trials(trials_stitched, trials_2)

    events_stitched.to_csv(generate_events_processed_stitched_path(stitched_folder_exp3, session_info), index=False)
    trials_stitched.to_csv(generate_trials_analyzed_stitched_path(stitched_folder_exp3, session_info), index=False)

Add info to the stitched session logs and save them.

In [None]:
# One row per exp2 training day: its log entry plus basics computed from the
# day's stitched events.
stitched_session_basics = pd.DataFrame([
    {'date': day['date']} | helper.get_session_basics(
        pd.read_csv(generate_events_processed_stitched_path(stitched_folder_exp2, day)))
    for _, day in sessions_training_stitched_exp2.iterrows()
])
stitched_session_log_2 = pd.merge(sessions_training_stitched_exp2, stitched_session_basics, on='date')
utils.save_as_csv(df=stitched_session_log_2, folder=stitched_folder_exp2, filename='sessions_training_stitched_exp2.csv')

In [None]:
# One row per exp3 training day: its log entry plus basics computed from the
# day's stitched events.
stitched_session_basics = pd.DataFrame([
    {'date': day['date']} | helper.get_session_basics(
        pd.read_csv(generate_events_processed_stitched_path(stitched_folder_exp3, day)))
    for _, day in sessions_training_stitched_exp3.iterrows()
])
stitched_session_log_3 = pd.merge(sessions_training_stitched_exp3, stitched_session_basics, on='date')
utils.save_as_csv(df=stitched_session_log_3, folder=stitched_folder_exp3, filename='sessions_training_stitched_exp3.csv')