In [154]:
import os
import json
import shutil

import session_processing_helper_c5 as helper
import utils_c5 as utils

import pandas as pd

In [155]:
# Root of the behavior data on the local machine.
# NOTE(review): absolute, user-specific path — has to be edited before running elsewhere.
data_dir = '/Users/rebekahzhang/data/behavior_data'
# Sub-folders selecting which recording period / experiment / cohort to process.
period = '20240123-20240415'
exp = "exp2"
cohort = "cohort_5"
# Every later cell reads from and writes to this folder.
data_folder = os.path.join(data_dir, period, exp, cohort)
print(data_folder)

/Users/rebekahzhang/data/behavior_data/20240123-20240415/exp2/cohort_5


# Quality Control

## Check session folders have both meta and events

In [156]:
def check_session_files(data_folder):
    """Audit every session folder for its meta and events files.

    Each sub-directory of ``data_folder`` is treated as one session. Inside it,
    non-hidden files whose names start with ``events_`` / ``meta_`` are looked
    for and checked for a non-zero size.

    Args:
        data_folder (str): Directory whose sub-directories are session folders.

    Returns:
        tuple of pd.DataFrame: ``(missing_meta, missing_events, empty_meta,
        empty_events)`` — in that order. Each frame has the columns ``dir``,
        ``events``, ``meta``, ``events_empty`` and ``meta_empty`` and is sorted
        by ``dir``. All four frames are empty (but keep their columns) when
        ``data_folder`` contains no session folders.
    """
    columns = ['dir', 'events', 'meta', 'events_empty', 'meta_empty']
    files_check = []

    # scandir iterators hold an OS handle; the context manager releases it promptly
    with os.scandir(data_folder) as session_entries:
        for entry in session_entries:
            if not entry.is_dir():
                continue

            events_found = False
            meta_found = False
            events_empty = True
            meta_empty = True

            with os.scandir(entry.path) as session_files:
                for file in session_files:
                    # skip sub-folders and hidden files such as .DS_Store
                    if not file.is_file() or file.name.startswith('.'):
                        continue
                    if file.name.startswith("events_"):
                        events_found = True
                        if file.stat().st_size > 0:
                            events_empty = False
                    elif file.name.startswith("meta_"):
                        meta_found = True
                        if file.stat().st_size > 0:
                            meta_empty = False

            files_check.append({
                'dir': entry.name,
                'events': events_found,
                'meta': meta_found,
                # emptiness is only meaningful when the file exists
                'events_empty': events_empty if events_found else None,
                'meta_empty': meta_empty if meta_found else None
            })

    # passing explicit columns keeps sort_values from raising when no session folder exists
    files_check_df = pd.DataFrame(files_check, columns=columns).sort_values("dir")
    # `== True` / `== False` are deliberate: the *_empty columns may hold None
    missing_meta = files_check_df[files_check_df['meta'] == False]
    missing_events = files_check_df[files_check_df['events'] == False]
    empty_meta = files_check_df[(files_check_df['meta'] == True) & (files_check_df['meta_empty'] == True)]
    empty_events = files_check_df[(files_check_df['events'] == True) & (files_check_df['events_empty'] == True)]

    return missing_meta, missing_events, empty_meta, empty_events

In [157]:
# check_session_files returns (missing_meta, missing_events, empty_meta, empty_events);
# unpack in exactly that order so each report below shows the right sessions.
missing_meta, missing_events, empty_meta, empty_events = check_session_files(data_folder)
if not (missing_meta.empty and missing_events.empty and empty_meta.empty and empty_events.empty):
    print("\nFile check results:")
    if not missing_meta.empty:
        print("\nSessions missing meta files:")
        display(missing_meta)
    if not missing_events.empty:
        print("\nSessions missing events files:")
        display(missing_events)
    if not empty_meta.empty:
        print("\nSessions with empty meta files:")
        display(empty_meta)
    if not empty_events.empty:
        print("\nSessions with empty events files:")
        display(empty_events)
else:
    print("\nAll sessions have non-empty meta and events files.")


All sessions have non-empty meta and events files.


In [158]:
def delete_folders(folder_list, data_folder):
    """Delete the named sub-folders of ``data_folder`` and report each outcome.

    Args:
        folder_list (list of str): Folder names relative to ``data_folder``.
        data_folder (str): Parent directory of the folders to delete.

    Prints one line per folder (deleted / not found), or a single notice when
    ``folder_list`` is empty.
    """
    if not folder_list:
        print("no sessions to delete")
        return

    for folder_name in folder_list:
        target = os.path.join(data_folder, folder_name)
        if not os.path.exists(target):
            print(f"Folder not found: {target}")
            continue
        # removes the folder and everything inside it
        shutil.rmtree(target)
        print(f"Deleted folder: {target}")

In [159]:
# permanently removes the session folders listed in missing_meta
delete_folders(missing_meta["dir"].tolist(), data_folder)

no sessions to delete


In [160]:
# permanently removes the session folders listed in missing_events
delete_folders(missing_events["dir"].tolist(), data_folder)

no sessions to delete


In [161]:
# permanently removes the session folders listed in empty_meta
delete_folders(empty_meta["dir"].tolist(), data_folder)

no sessions to delete


In [162]:
# permanently removes the session folders listed in empty_events
delete_folders(empty_events["dir"].tolist(), data_folder)

no sessions to delete


## Generate and save sessions log

Generate the session log from each session's metadata, and add columns of basic info for each session.

In [89]:
def modify_total_trial(row):
    """Return the row's ``total_trial`` adjusted for how the session ended.

    The ending code is compared case-insensitively: ``pygame`` and ``manual``
    subtract 1 trial, ``miss`` subtracts 5, and any other code leaves the
    count unchanged.

    Args:
        row: Mapping-like record with ``ending_code`` (str) and ``total_trial``.

    Returns:
        The adjusted trial count.
    """
    trials_to_subtract = {'pygame': 1, 'manual': 1, 'miss': 5}
    adjustment = trials_to_subtract.get(row['ending_code'].lower())
    if adjustment is None:
        return row['total_trial']
    return row['total_trial'] - adjustment

In [90]:
def modify_sessions_all(sessions_all):
    """Derive the bookkeeping columns of the session log.

    Works on a copy, so the DataFrame passed in by the caller is left untouched.

    Steps:
      * ``dir``  — ``<date>_<time>_<mouse>``; the log is sorted by it.
      * ``exp`` / ``group`` — parsed from the original ``exp`` string with the
        pattern ``exp(\\d)_(short|long)``; ``group`` is abbreviated to ``s`` / ``l``.
      * ``total_trial`` — adjusted per row by ``modify_total_trial``.
      * ``forward_file`` — dropped.

    Args:
        sessions_all (pd.DataFrame): One row per session, built from the meta files.

    Returns:
        pd.DataFrame: The modified log, sorted by ``dir``.
    """
    # copy first: the column assignment below would otherwise modify the caller's frame
    sessions_all = sessions_all.copy()
    sessions_all['dir'] = sessions_all['date'] + '_' + sessions_all['time'] + '_' + sessions_all['mouse']
    sessions_all = sessions_all.sort_values('dir')
    sessions_all[['exp', 'group']] = sessions_all['exp'].str.extract(r'exp(\d)_(short|long)')
    sessions_all['group'] = sessions_all['group'].map({'short': 's', 'long': 'l'})

    sessions_all['total_trial'] = sessions_all.apply(modify_total_trial, axis=1)
    sessions_all = sessions_all.drop(['forward_file'], axis=1)
    return sessions_all

In [91]:
def generate_sessions_all(data_folder):
    """Build the full session log from the ``meta_*.json`` files under a folder.

    The folder is walked recursively; for every matching file the
    ``session_config`` entry is collected. Files that fail to load are
    reported with a printed message and skipped.

    Args:
        data_folder (str): Directory tree containing the session meta files.

    Returns:
        pd.DataFrame: Session metadata after ``modify_sessions_all`` has been applied.
    """
    session_configs = []
    for folder, _, filenames in os.walk(data_folder):
        meta_names = [name for name in filenames
                      if name.startswith("meta_") and name.endswith(".json")]
        for name in meta_names:
            try:
                with open(os.path.join(folder, name)) as handle:
                    config = json.load(handle)['session_config']
            except Exception as e:
                print(f"Error processing file {name}: {e}")
            else:
                session_configs.append(config)

    return modify_sessions_all(pd.DataFrame(session_configs))

In [92]:
def generate_sessions_training(sessions_all):
    """Select the 'regular' training sessions and number them per mouse.

    Args:
        sessions_all (pd.DataFrame): Full session log with ``training`` and ``mouse`` columns.

    Returns:
        pd.DataFrame: Regular sessions after ``helper.assign_session_numbers`` has
        been applied to each mouse's group. The previous index is kept as an
        ``index`` column by ``reset_index``.
    """
    is_regular = sessions_all['training'] == 'regular'
    regular_sessions = sessions_all.loc[is_regular].reset_index()
    sessions_per_mouse = regular_sessions.groupby('mouse', group_keys=False)
    return sessions_per_mouse.apply(helper.assign_session_numbers)

In [93]:
def generate_session_logs(data_folder, save_logs=True):
    """Generate the full and the training session logs, optionally saving both.

    Args:
        data_folder (str): Folder holding the session data; logs are saved here too.
        save_logs (bool): When True, write ``sessions_all.csv`` and
            ``sessions_training.csv`` via ``utils.save_as_csv``.

    Returns:
        tuple: ``(sessions_all, sessions_training)``.
    """
    sessions_all = generate_sessions_all(data_folder)
    sessions_training = generate_sessions_training(sessions_all)

    if save_logs:
        logs_to_save = (
            (sessions_all, 'sessions_all.csv'),
            (sessions_training, 'sessions_training.csv'),
        )
        for log, log_name in logs_to_save:
            utils.save_as_csv(df=log, folder=data_folder, filename=log_name)

    return sessions_all, sessions_training

### Re-run after every quality control step

In [111]:
# Rebuild both logs from the meta files on disk; with the default save_logs=True
# this also overwrites sessions_all.csv and sessions_training.csv in data_folder.
sessions_all, sessions_training = generate_session_logs(data_folder)

sessions_training.tail()

Unnamed: 0,index,date,time,mouse,exp,training,rig,trainer,record,total_reward,total_trial,avg_tw,ending_code,dir,group,session
113,71,2024-10-03,11-38-00,RZ053,2,regular,rig4,Rebekah,False,560,170,12.74,miss,2024-10-03_11-38-00_RZ053,l,19
114,47,2024-10-03,11-52-54,RZ054,2,regular,rig5,Rebekah,False,205,64,22.69,miss,2024-10-03_11-52-54_RZ054,l,14
115,18,2024-10-03,11-54-10,RZ047,2,regular,rig6,Rebekah,False,700,329,2.08,reward,2024-10-03_11-54-10_RZ047,s,12
116,87,2024-10-03,13-02-53,RZ050,2,regular,rig1,Rebekah,False,700,430,1.91,reward,2024-10-03_13-02-53_RZ050,s,7
117,30,2024-10-03,13-59-56,RZ051,2,regular,rig1,Rebekah,False,700,240,3.65,reward,2024-10-03_13-59-56_RZ051,s,7


## Remove unwanted sessions
doesn't need to run when data folder is cleaned 
<br>
sessions_all needs to be regenerated after every cleaning step

### Remove test sessions

In [95]:
def remove_sessions(sessions_to_remove, data_folder):
    """Delete the folder of every session listed in ``sessions_to_remove``.

    Folders that no longer exist are reported and skipped instead of raising,
    so the cell can be re-run against a session log that has not been
    regenerated since the last clean-up.

    Args:
        sessions_to_remove (pd.DataFrame): Rows of the session log; the ``dir``
            column names each session folder.
        data_folder (str): Parent directory of the session folders.
    """
    for session_dir in sessions_to_remove['dir']:
        session_path = os.path.join(data_folder, session_dir)
        if os.path.isdir(session_path):
            shutil.rmtree(session_path)
        else:
            print(f"Folder not found: {session_path}")

In [96]:
# sessions recorded under the placeholder mouse name 'test'
sessions_test = sessions_all.loc[sessions_all['mouse'] == 'test']
if len(sessions_test) == 0:
    print("no test sessions to be deleted!")
else:
    remove_sessions(sessions_test, data_folder)
    print("test sessions removed")

no test sessions to be deleted!


### Check for short sessions

In [97]:
# sessions with fewer trials than this are flagged as short in the next cell
short_threshold = 20

In [98]:
# flag sessions with too few trials, or with no trial count at all
trial_counts = sessions_all['total_trial']
sessions_short = sessions_all[(trial_counts < short_threshold) | trial_counts.isna()]
if len(sessions_short) == 0:
    print('no short sessions to be checked!')
else:
    display(sessions_short)

Unnamed: 0,date,time,mouse,exp,training,rig,trainer,record,total_reward,total_trial,avg_tw,ending_code,dir,group
106,2024-09-18,11-53-54,RZ051,2,regular,rig4,Rebekah,False,30,18,34.23,miss,2024-09-18_11-53-54_RZ051,s
53,2024-09-18,13-01-40,RZ052,2,regular,rig6,Rebekah,False,10,1,48.76,miss,2024-09-18_13-01-40_RZ052,l
36,2024-09-18,13-02-33,RZ055,2,regular,rig4,Rebekah,False,25,11,36.21,miss,2024-09-18_13-02-33_RZ055,l
69,2024-09-18,14-41-44,RZ056,2,regular,rig2,Rebekah,False,5,1,56.62,miss,2024-09-18_14-41-44_RZ056,l
119,2024-09-21,15-06-44,RZ054,2,regular,rig4,Rebekah,False,25,6,33.5,miss,2024-09-21_15-06-44_RZ054,l
73,2024-09-24,12-10-36,RZ052,2,regular,rig2,Rebekah,False,0,-1,0.0,manual,2024-09-24_12-10-36_RZ052,l
40,2024-09-24,13-47-40,RZ055,2,regular,rig5,Rebekah,False,50,15,23.6,miss,2024-09-24_13-47-40_RZ055,l
19,2024-09-25,11-46-49,RZ052,2,regular,rig4,Rebekah,False,5,0,50.32,miss,2024-09-25_11-46-49_RZ052,l
74,2024-10-01,10-56-01,RZ049,2,regular,rig3,Rebekah,False,5,1,1.07,manual,2024-10-01_10-56-01_RZ049,s
44,2024-10-01,12-07-18,RZ056,2,regular,rig3,Rebekah,False,15,6,33.3,miss,2024-10-01_12-07-18_RZ056,l


remove short sessions if needed

In [99]:
# WARNING: permanently deletes the folder of every session in sessions_short.
# Only run after reviewing the table above; regenerate the session logs afterwards.
remove_sessions(sessions_short, data_folder)

# Process Events

load session log

In [100]:
# Reload the training log from disk. It reflects the last time generate_session_logs
# saved it, so regenerate the logs first if sessions were removed since then.
sessions_training = utils.load_data(os.path.join(data_folder, 'sessions_training.csv'))

In [101]:
# trim events to the trials counted in the session log
def process_events(session_info, events):
    """Keep only events whose trial number lies in ``[0, total_trial]`` (inclusive).

    Args:
        session_info: Session log row providing ``total_trial``.
        events (pd.DataFrame): Raw events with a ``session_trial_num`` column.

    Returns:
        pd.DataFrame: The subset of ``events`` inside the valid trial range.
    """
    trial_nums = events['session_trial_num']
    last_valid_trial = session_info['total_trial']
    within_session = (trial_nums >= 0) & (trial_nums <= last_valid_trial)
    return events.loc[within_session]

In [102]:
# Build the processed events file for every training session that lacks one.
# Sessions that fail are collected (and the error printed) instead of stopping the loop.
problematic_records = []

for _, session_info in sessions_training.iterrows():
    try:
        events_processed_path = utils.generate_events_processed_path(data_folder, session_info)
        # already processed on a previous run
        if os.path.isfile(events_processed_path):
            continue
        events = pd.read_csv(utils.generate_events_path(data_folder, session_info), low_memory=False)
        events = process_events(session_info, events)
        events_processed = events.groupby('session_trial_num', group_keys=False).apply(helper.add_trial_time)
        events_processed.to_csv(events_processed_path)
    except Exception as error:
        # `except Exception` lets KeyboardInterrupt through; the message keeps the cause visible
        print(f"failed to process events for {session_info.get('dir')}: {error!r}")
        problematic_records.append(session_info.to_dict())

# build the frame once, rather than concatenating row by row inside the loop
problematic_sessions = pd.DataFrame(problematic_records, columns=sessions_training.columns)

In [77]:
# show the sessions whose events could not be processed, if any
if len(problematic_sessions) == 0:
    print("all sessions are perfect! woohoo!")
else:
    display(problematic_sessions)

all sessions are perfect! woohoo!


# Data set curation 

In [130]:
# The rounds dict is deprecated because no more than one session is run per day.
# Alternative that derives the list from the log instead of hard-coding it:
# mouse_list = utils.generate_mouse_list(sessions_all)
mouse_list = ['RZ047','RZ049','RZ050','RZ051','RZ052','RZ053','RZ054','RZ055','RZ056']
# one group of training sessions per date; iterated by the checks below
sessions_by_date = sessions_training.groupby('date')

### Deal with missing sessions
This is not the proper way to deal with missing sessions; they should instead be populated with the mean and variation.

In [131]:
# report every (date, mouse) pair that has no training session
no_missing_sessions = True
for date, data in sessions_by_date:
    for mouse in mouse_list:
        mouse_by_date = data.loc[data['mouse'] == mouse]
        if len(mouse_by_date) >= 1:
            continue
        no_missing_sessions = False
        print(f"on {date}, {mouse} has missing sessions")
if no_missing_sessions:
    print("no missing sessions!")

on 2024-09-20, RZ050 has missing sessions
on 2024-09-20, RZ051 has missing sessions
on 2024-09-21, RZ050 has missing sessions
on 2024-09-21, RZ051 has missing sessions
on 2024-09-24, RZ050 has missing sessions
on 2024-09-24, RZ051 has missing sessions
on 2024-09-25, RZ054 has missing sessions
on 2024-09-30, RZ047 has missing sessions
on 2024-09-30, RZ049 has missing sessions
on 2024-09-30, RZ052 has missing sessions
on 2024-09-30, RZ053 has missing sessions
on 2024-09-30, RZ054 has missing sessions
on 2024-09-30, RZ055 has missing sessions
on 2024-09-30, RZ056 has missing sessions
on 2024-10-03, RZ049 has missing sessions


In [132]:
# inspect all training sessions recorded on one date
sessions_by_date.get_group('2024-09-24')

Unnamed: 0,date,time,mouse,exp,training,rig,trainer,record,avg_tw,ending_code,dir,group,session,num_blocks,num_trials,rewards,session_time,proper_end
23,2024-09-24,12-04-52,RZ049,2,regular,rig5,Rebekah,False,11.64,reward,2024-09-24_12-04-52_RZ049,s,3,1,225,700.0,3519.94,True
24,2024-09-24,12-07-31,RZ052,2,regular,rig6,Rebekah,False,2.36,miss,2024-09-24_12-07-31_RZ052,l,3,2,414,855.0,12090.37,False
25,2024-09-24,12-11-31,RZ053,2,regular,rig2,Rebekah,False,1.42,reward,2024-09-24_12-11-31_RZ053,l,3,2,844,1400.0,16350.81,True
26,2024-09-24,12-12-51,RZ054,2,regular,rig3,Rebekah,False,1.64,reward,2024-09-24_12-12-51_RZ054,l,3,2,593,1140.0,16964.39,False
27,2024-09-24,13-06-25,RZ055,2,regular,rig5,Rebekah,False,4.32,miss,2024-09-24_13-06-25_RZ055,l,3,1,144,230.0,1936.44,False
28,2024-09-24,13-11-53,RZ056,2,regular,rig4,Rebekah,False,1.36,reward,2024-09-24_13-11-53_RZ056,l,3,2,547,945.0,11873.72,False
29,2024-09-24,15-08-25,RZ047,2,regular,rig2,Rebekah,False,3.17,reward,2024-09-24_15-08-25_RZ047,s,3,1,303,700.0,2244.8,True


Duplicate (back up) the data folder first if you are worried about having to redo this.

In [133]:
# put it into utils when you clean this up.
def backup(source_dir):
    """Create a copy of ``source_dir`` next to it, named ``<source name>_copy``.

    Nothing is copied when the destination already exists.

    Args:
        source_dir (str): Directory to back up. A trailing path separator is
            tolerated.
    """
    # Without normalising, a trailing separator makes basename() return '' and the
    # "copy" would be created inside the source folder as '_copy'.
    source_dir = os.path.normpath(source_dir)
    data_folder = os.path.dirname(source_dir)
    source_name = os.path.basename(source_dir)
    destination_dir = os.path.join(data_folder, f"{source_name}_copy")
    if not os.path.isdir(destination_dir):
        shutil.copytree(source_dir, destination_dir)
        print(f"{source_name} backed up")
    else:
        print(f"{os.path.basename(destination_dir)} already exist")

In [134]:
# copy the whole cohort folder before the destructive stitching step below
backup(data_folder)

cohort_6_copy already exist


### Deal with multiple sessions

In [135]:
# collect every (date, mouse) pair that has more than one training session
days_to_stitch, mice_to_stitch = [], []
for date, data in sessions_by_date:
    for mouse in mouse_list:
        mouse_by_date = data.loc[data['mouse'] == mouse]
        if len(mouse_by_date) <= 1:
            continue
        days_to_stitch.append(date)
        mice_to_stitch.append(mouse)
        print(f"on {date}, {mouse} has {len(mouse_by_date)} sessions")
if len(days_to_stitch) == 0:
    print("no sessions to stitch!")

no sessions to stitch!


In [136]:
def get_session_basics(session_df):
    """Summarise a session's events into block, trial, reward and duration totals.

    Args:
        session_df (pd.DataFrame): Events with the columns ``session_trial_num``,
            ``key``, ``value``, ``block_num``, ``reward_size`` and ``session_time``.

    Returns:
        dict: ``num_blocks`` (block number at the start of the last trial, plus 1),
        ``num_trials`` (highest trial number, plus 1), ``rewards`` (sum of
        ``reward_size``) and ``session_time`` (max minus min ``session_time``),
        the last two rounded to 2 decimals.
    """
    last_trial_num = session_df['session_trial_num'].max()
    last_trial_events = session_df[session_df['session_trial_num'] == last_trial_num]

    # the row marking the start of the last trial carries its block number
    is_trial_start = (last_trial_events['key'] == 'trial') & (last_trial_events['value'] == 1)
    last_block_num = last_trial_events.loc[is_trial_start, 'block_num'].iloc[0]

    timestamps = session_df['session_time']
    return {
        'num_blocks': last_block_num + 1,
        'num_trials': last_trial_num + 1,
        'rewards': round(session_df['reward_size'].sum(), 2),
        'session_time': round((timestamps.max() - timestamps.min()), 2),
    }

def assign_session_numbers(group):
    """Return ``group`` sorted by mouse, dir and date with a 0-based ``session`` column.

    The input frame is not modified: a sorted copy is returned. This matters
    because the function is meant to be passed to ``groupby(...).apply``, where
    pandas does not support mutating the frame handed to the function.

    Args:
        group (pd.DataFrame): Sessions with ``mouse``, ``dir`` and ``date`` columns.

    Returns:
        pd.DataFrame: Sorted copy with ``session`` numbered 0..len(group)-1.
    """
    ordered = group.sort_values(by=['mouse', 'dir', 'date'])
    return ordered.assign(session=list(range(len(ordered))))

# Stitch sessions from the same mouse on the same day
def stitch_sessions(session_1, session_2):
    """Append ``session_2``'s events to ``session_1`` as one continuous session.

    ``session_2``'s time, block number and trial number are shifted by the
    totals of ``session_1`` (taken from ``get_session_basics``) so they continue
    where ``session_1`` ended. Neither input frame is modified.

    Args:
        session_1 (pd.DataFrame): Events of the earlier session.
        session_2 (pd.DataFrame): Events of the later session.

    Returns:
        pd.DataFrame: ``session_1`` followed by the shifted ``session_2``; the
        original row indices of both frames are kept.
    """
    session_1_basics = get_session_basics(session_1)

    # assign() returns a new frame, so the caller's session_2 keeps its original values
    session_2_shifted = session_2.assign(
        session_time=session_2['session_time'] + session_1_basics['session_time'],
        block_num=session_2['block_num'] + session_1_basics['num_blocks'],
        session_trial_num=session_2['session_trial_num'] + session_1_basics['num_trials'],
    )

    return pd.concat([session_1, session_2_shifted])

In [137]:
# Run it if session stitching is needed; nothing happens otherwise.
# All sessions of a mouse on a given day are stitched in one pass (in `dir` order):
# the merged events are written to the first session's processed events file and
# the folders of the later sessions are deleted.
# TODO: rename to "stitch events" to follow the events/sessions nomenclature.
if not days_to_stitch:
    print("no sessions to stitch!")
else:
    for d, m in zip(days_to_stitch, mice_to_stitch):
        day = sessions_by_date.get_group(d)
        # dir starts with date and time, so sorting by it orders the sessions chronologically
        sessions_to_stitch = day[day['mouse'] == m].sort_values('dir')

        events_paths = [
            utils.generate_events_processed_path(data_folder, session_row)
            for _, session_row in sessions_to_stitch.iterrows()
        ]

        # only touch the data when every piece is available
        if not all(os.path.exists(path) for path in events_paths):
            print("one of the sessions do not exist")
            continue

        stitched_session = pd.read_csv(events_paths[0])
        for later_path in events_paths[1:]:
            stitched_session = stitch_sessions(stitched_session, pd.read_csv(later_path))

        stitched_session.to_csv(events_paths[0], index=False)
        for position in range(1, len(sessions_to_stitch)):
            shutil.rmtree(os.path.join(data_folder, sessions_to_stitch.iloc[position]['dir']))
            print(f"{d} {m} session {position + 1} deleted")

no sessions to stitch!


In [138]:
# Regenerate (and re-save) both logs so they match the folders left on disk after stitching.
sessions_all, sessions_training = generate_session_logs(data_folder)

sessions_training.tail()

Unnamed: 0,index,date,time,mouse,exp,training,rig,trainer,record,total_reward,total_trial,avg_tw,ending_code,dir,group,session
79,51,2024-10-03,10-44-14,RZ053,2,regular,rig4,Rebekah,False,700,215,4.84,reward,2024-10-03_10-44-14_RZ053,l,9
80,42,2024-10-03,10-46-02,RZ052,2,regular,rig3,Rebekah,False,570,355,3.94,miss,2024-10-03_10-46-02_RZ052,l,9
81,15,2024-10-03,11-54-10,RZ047,2,regular,rig6,Rebekah,False,700,329,2.08,reward,2024-10-03_11-54-10_RZ047,s,9
82,60,2024-10-03,13-02-53,RZ050,2,regular,rig1,Rebekah,False,700,430,1.91,reward,2024-10-03_13-02-53_RZ050,s,7
83,20,2024-10-03,13-59-56,RZ051,2,regular,rig1,Rebekah,False,700,240,3.65,reward,2024-10-03_13-59-56_RZ051,s,7


### Correct sessions log

In [139]:
def correct_sessions_training(data_folder, save_log=True):
    """Rebuild the training log with totals computed from the processed events.

    The log is regenerated from the meta files (without saving), each session's
    processed events are summarised with ``helper.get_session_basics``, and the
    summaries are merged in on ``dir``. The meta-derived ``index``,
    ``total_reward`` and ``total_trial`` columns are dropped and sessions are
    renumbered per mouse.

    Args:
        data_folder (str): Folder holding the session data.
        save_log (bool): When True, overwrite ``sessions_training.csv``.

    Returns:
        pd.DataFrame: The corrected training log.
    """
    _, sessions_training = generate_session_logs(data_folder, save_logs=False)

    def summarize(session_info):
        # one summary dict per session, keyed by its folder name for the merge below
        events_path = utils.generate_events_processed_path(data_folder, session_info)
        basics = helper.get_session_basics(pd.read_csv(events_path, low_memory=False))
        basics['dir'] = session_info['dir']
        return basics

    sessions_info = pd.DataFrame([summarize(row) for _, row in sessions_training.iterrows()])

    superseded_columns = ['index', 'total_reward', 'total_trial']
    corrected_sessions_training = (
        pd.merge(sessions_training, sessions_info, on="dir")
        .drop(columns=superseded_columns)
        .groupby('mouse', group_keys=False)
        .apply(helper.assign_session_numbers)
    )

    if save_log:
        utils.save_as_csv(df=corrected_sessions_training, folder=data_folder, filename='sessions_training.csv')
    return corrected_sessions_training

In [140]:
# replace the in-memory training log with the corrected one (also saved to sessions_training.csv)
sessions_training = correct_sessions_training(data_folder)

# Analyze trials

## Generate Trials

In [141]:
def get_trial_basics(trial):
    """Extract trial/block numbering and start/end times from one trial's events.

    The first row with ``key == 'trial'`` and ``value == 1`` is taken as the
    trial start, the first with ``value == 0`` as the trial end.

    Args:
        trial (pd.DataFrame): Events belonging to a single trial.

    Returns:
        dict: ``session_trial_num``, ``block_trial_num``, ``block_num``,
        ``start_time`` and ``end_time``.
    """
    trial_markers = trial[trial['key'] == 'trial']
    start_row = trial_markers[trial_markers['value'] == 1].iloc[0]
    end_row = trial_markers[trial_markers['value'] == 0].iloc[0]

    # numbering comes from the start marker, timing from both markers
    numbering = ('session_trial_num', 'block_trial_num', 'block_num')
    basics = {field: start_row[field] for field in numbering}
    basics['start_time'] = start_row['session_time']
    basics['end_time'] = end_row['session_time']
    return basics

In [142]:
def generate_trials(session_info, events):
    """Build one row of trial basics for every trial of a session.

    Args:
        session_info: Session log row; ``num_trials`` gives the number of trials.
        events (pd.DataFrame): Processed events with a ``session_trial_num`` column.

    Returns:
        pd.DataFrame: One row per trial number ``0 .. num_trials - 1``, as
        produced by ``get_trial_basics``.
    """
    trial_count = int(session_info.num_trials)
    return pd.DataFrame([
        get_trial_basics(events.loc[events['session_trial_num'] == trial_num])
        for trial_num in range(trial_count)
    ])

In [144]:
# Generate the trials file for every training session, based on its processed events.
# Sessions that fail are collected (and the error printed) instead of stopping the loop.
problematic_records = []
for _, session_info in sessions_training.iterrows():
    try:
        trials_path = utils.generate_trials_path(data_folder, session_info)
        # already generated on a previous run
        if os.path.isfile(trials_path):
            continue

        events_processed = pd.read_csv(utils.generate_events_processed_path(data_folder, session_info))
        trials = generate_trials(session_info, events_processed)

        trials.to_csv(trials_path)
    except Exception as error:
        # `except Exception` lets KeyboardInterrupt through; the message keeps the cause visible
        print(f"failed to generate trials for {session_info.get('dir')}: {error!r}")
        problematic_records.append(session_info.to_dict())

# build the frame once, rather than concatenating row by row inside the loop
problematic_sessions = pd.DataFrame(problematic_records, columns=sessions_training.columns)

In [145]:
# show the sessions whose trials could not be generated, if any
if len(problematic_sessions) == 0:
    print("all sessions are perfect! woohoo!")
else:
    display(problematic_sessions)

all sessions are perfect! woohoo!


## Analyze trials

In [146]:
# Merge each session's trials with the per-trial data computed from its processed events
# and write the result to the session's trials-analyzed file.
for _, session_info in sessions_training.iterrows():
    try:
        trials_analyzed_path = utils.generate_trials_analyzed_path(data_folder, session_info)
        # already analyzed on a previous run
        if os.path.isfile(trials_analyzed_path):
            continue

        session_by_trial = utils.load_data(utils.generate_events_processed_path(data_folder, session_info)).groupby('session_trial_num')
        trials = utils.load_data(utils.generate_trials_path(data_folder, session_info))
        trials_data = helper.get_trial_data_df(session_by_trial)
        trials_analyzed = pd.merge(trials, trials_data, on='session_trial_num')
        trials_analyzed['group'] = session_info.group #assigning trial type manually
        trials_analyzed.to_csv(trials_analyzed_path)
    except Exception as error:
        # report why the session failed, then show which session it was
        print(f"failed to analyze trials for {session_info.get('dir')}: {error!r}")
        display(session_info)