In [1]:
import os
import json
import shutil

import session_processing_helper_c5 as helper
import utils_c5 as utils

import pandas as pd
import statistics

Enter the cohort name and the name of the data folder to analyze.

In [2]:
# Cohort to analyze, and the sub-folder inside the cohort folder that holds its sessions.
cohort, to_analyze = '20240531_to_20240703', 'full_clean'

In [3]:
# Root of all behavior data. Set the BEHAVIOR_DATA_DIR environment variable to
# run this notebook on another machine; the default keeps the original location.
data_dir = os.environ.get('BEHAVIOR_DATA_DIR', '/Users/rebekahzhang/data/behavior_data')
data_folder = os.path.join(data_dir, cohort, to_analyze)
print(data_folder)

/Users/rebekahzhang/data/behavior_data/20240531_to_20240703/full_clean


# Quality control

## Check for missing files in data dir

In [7]:
# Scan every session folder and flag any that lacks an events_* or a meta_* file.
# Hidden files (names starting with '.') are ignored.
missing_count = 0
subentry_list = []
with os.scandir(data_folder) as entries:
    for entry in entries:
        if not entry.is_dir():
            continue
        session_folder = entry.name
        session_path = os.path.join(data_folder, session_folder)

        with os.scandir(session_path) as subentries:
            filenames = [sub.name for sub in subentries
                         if sub.is_file() and not sub.name.startswith('.')]
        events_found = any(name.startswith("events_") for name in filenames)
        meta_found = any(name.startswith("meta_") for name in filenames)

        if not (events_found and meta_found):
            subentry_list.append(session_folder)
            missing_count += 1
            print(f"Session '{session_folder}' is missing one or both required files.")

if missing_count > 0:
    print(f"{missing_count} sessions have missing files")
    subentry_list.sort()
    # A bare expression inside an if-block is never rendered by Jupyter, so the
    # offending sessions must be displayed explicitly.
    display(subentry_list)
else:
    print("all sessions have both events and meta")

all sessions have both events and meta


## Generate sessions log

Generate the session log from each session's metadata, and add columns with basic information about each session.

In [10]:
def modify_total_trial(row):
    """Return the usable trial count for one session-log row.

    Sessions ending with 'pygame' lose one trial and sessions ending with
    'miss' lose five; any other ending code keeps ``total_trial`` unchanged.

    Args:
        row: A session-log row, indexable by 'ending_code' and 'total_trial'.
    Returns:
        The adjusted trial count.
    """
    trials_to_drop = {'pygame': 1, 'miss': 5}
    ending_code = row['ending_code']
    if ending_code not in trials_to_drop:
        return row['total_trial']
    return row['total_trial'] - trials_to_drop[ending_code]

# The ending-code adjustment of total_trial (pygame -> minus 1, miss -> minus 5)
# is applied inside modify_sessions_all(), which generate_sessions_all() calls.
# It must not be applied again here. On a fresh kernel sessions_all does not
# exist yet at this point (NameError), and re-running this cell would subtract
# the offset twice.
# The adjustment currently does not account for system crashes.

In [11]:
def modify_sessions_all(sessions_all):
    """Derive experiment/group columns and correct trial counts in a session log.

    ``exp`` values matching ``exp<digit>_<short|long>`` are split so that ``exp``
    keeps only the digit and ``group`` holds 's' (short) or 'l' (long).
    ``total_trial`` is then corrected per row with modify_total_trial.
    The frame is modified in place and also returned.
    """
    exp_parts = sessions_all['exp'].str.extract(r'exp(\d)_(short|long)')
    sessions_all[['exp', 'group']] = exp_parts
    group_codes = {'short': 's', 'long': 'l'}
    sessions_all['group'] = sessions_all['group'].map(group_codes)
    sessions_all['total_trial'] = sessions_all.apply(modify_total_trial, axis=1)
    return sessions_all

In [12]:
def generate_sessions_all(data_folder):
    """Build the session log from every ``meta_*.json`` file under ``data_folder``.

    The ``session_config`` dict of each meta file becomes one row. A ``dir``
    column (``<date>_<time>_<mouse>``, the session's folder name) is added, the
    rows are sorted by it, and the log is post-processed by modify_sessions_all.
    Meta files that cannot be read are reported and skipped.

    Args:
        data_folder (str): Directory searched recursively for meta JSON files.
    Returns:
        pd.DataFrame: Session metadata, sorted by the 'dir' column.
    Raises:
        FileNotFoundError: If no meta file could be loaded (the log would be empty).
    """
    data = []
    for root, _, files in os.walk(data_folder):
        for file in files:
            if not (file.startswith("meta_") and file.endswith(".json")):
                continue
            path = os.path.join(root, file)
            try:
                with open(path) as f:
                    data.append(json.load(f)['session_config'])
            except Exception as e:
                # Best effort: report the broken file and keep going.
                print(f"Error processing file {file}: {e}")

    if not data:
        # Without this guard, the 'date' lookup below fails with an opaque KeyError.
        raise FileNotFoundError(f"No readable meta_*.json files found under {data_folder}")

    sessions_all = pd.DataFrame(data)
    sessions_all['dir'] = sessions_all['date'] + '_' + sessions_all['time'] + '_' + sessions_all['mouse']
    sessions_all = sessions_all.sort_values('dir')
    sessions_all = modify_sessions_all(sessions_all)
    return sessions_all

In [13]:
# Build the session log from the meta files on disk.
sessions_all = generate_sessions_all(data_folder=data_folder)

In [14]:
display(sessions_all)  # pandas truncates long frames in the rendered output

Unnamed: 0,date,time,mouse,exp,training,rig,trainer,record,forward_file,total_reward,total_trial,avg_tw,ending_code,dir,group
15,2024-05-31,09-46-24,RZ034,2,regular,rig2,Lianne,False,True,700,417,1.60,reward,2024-05-31_09-46-24_RZ034,s
102,2024-05-31,09-47-49,RZ036,2,regular,rig3,Lianne,False,True,700,423,1.58,reward,2024-05-31_09-47-49_RZ036,s
133,2024-05-31,10-31-20,RZ037,2,regular,rig2,Lianne,False,True,700,472,1.60,reward,2024-05-31_10-31-20_RZ037,l
126,2024-05-31,10-34-49,RZ038,2,regular,rig3,Lianne,False,True,530,273,2.57,miss,2024-05-31_10-34-49_RZ038,l
148,2024-05-31,11-45-06,RZ039,2,regular,rig3,Lianne,False,True,700,386,1.92,reward,2024-05-31_11-45-06_RZ039,l
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
61,2024-07-03,13-16-12,RZ055,2,regular,rig3,Lianne,False,True,420,175,4.29,miss,2024-07-03_13-16-12_RZ055,l
225,2024-07-03,14-09-14,RZ056,2,regular,rig3,Lianne,False,True,700,356,5.15,reward,2024-07-03_14-09-14_RZ056,l
238,2024-07-03,15-16-45,RZ055,2,regular,rig2,Lianne,False,True,290,93,8.08,reward,2024-07-03_15-16-45_RZ055,l
2,2024-07-03,15-52-34,RZ054,2,regular,rig2,Lianne,False,True,160,84,2.62,reward,2024-07-03_15-52-34_RZ054,l


## Examine session quality
These cells don't need to be run once the data folder has been cleaned.
<br>
`sessions_all` must be regenerated after every cleaning step.

### Remove test sessions

In [15]:
def remove_sessions(sessions_to_remove, data_folder):
    """Delete the folder of every listed session from disk.

    Args:
        sessions_to_remove (pd.DataFrame): Session-log rows; each row's 'dir'
            value names a session folder inside ``data_folder``.
        data_folder (str): Folder that contains the session folders.
    """
    for _, row in sessions_to_remove.iterrows():
        session_path = os.path.join(data_folder, row['dir'])
        shutil.rmtree(session_path)

In [16]:
# Sessions recorded under the mouse name 'test' are not real data; delete them.
is_test = sessions_all['mouse'] == 'test'
sessions_test = sessions_all[is_test]
if sessions_test.empty:
    print("no test sessions")
else:
    remove_sessions(sessions_test, data_folder)

no test sessions


### Check for short sessions

In [18]:
min_trials = 30  # sessions with fewer trials than this count as too short
sessions_short = sessions_all[sessions_all['total_trial'].lt(min_trials)]
if sessions_short.empty:
    print('no short sessions!')
else:
    display(sessions_short)

no short sessions!


Remove the short sessions.

In [19]:
# Iterates over the rows of sessions_short, so this does nothing when it is empty.
remove_sessions(sessions_to_remove=sessions_short, data_folder=data_folder)

### Regenerate sessions_all

In [20]:
# Rebuild the session log so it reflects the cleaning steps above.
sessions_all = generate_sessions_all(data_folder=data_folder)

### Save logs

In [21]:
# Keep regular training sessions only; reset_index() keeps the old row labels as an 'index' column.
is_regular = sessions_all['training'].eq('regular')
sessions_training = sessions_all[is_regular].reset_index()

In [22]:
for log_df, log_name in ((sessions_all, 'sessions_all.csv'),
                         (sessions_training, 'sessions_training.csv')):
    utils.save_as_csv(df=log_df, folder=data_folder, filename=log_name)

# Code to be ignored for now

### Check for session-count mismatches for each mouse on each date
**TODO:** this check may not need to happen here.

In [None]:
# generate a dict of number or rounds for each date
mouse_list = utils.generate_mouse_list(sessions_all)
sessions_by_date = sessions_all.groupby('date')
rounds_dict = {}
for date, data in sessions_by_date:
    num_session_list = []
    for mouse in mouse_list:
        mouse_by_date = data.loc[data['mouse'] == mouse]
        num_session_list.append(len(mouse_by_date))
    rounds_dict[date] = statistics.mode(num_session_list)

Check for missing sessions; pad any gap by copying the session from the previous day.

In [None]:
# Report every mouse that ran fewer sessions on a date than that date's expected rounds.
no_missing_sessions = True
for date, data in sessions_by_date:
    expected_rounds = rounds_dict[date]
    for mouse in mouse_list:
        n_sessions = int((data['mouse'] == mouse).sum())
        if n_sessions >= expected_rounds:
            continue
        no_missing_sessions = False
        print(f"on {date}, {mouse} has missing sessions")
if no_missing_sessions:
    print("no missing sessions!")

Pad if a missing session exists.

In [None]:
# Left here to inspect the sessions and mice that need manual handling:
# uncomment and set the date to view one day's sessions, sorted by mouse.
# day = sessions_by_date.get_group('2024-04-02').sort_values('mouse')
# display(day)

Check for multiple sessions by the same mouse on the same day, decide whether stitching is needed, and stitch them if so.

In [None]:
# A mouse needs stitching on a date when it has more sessions than that day's
# expected rounds and more than one session in total.
days_to_stitch = []
mice_to_stitch = []
for date, data in sessions_by_date:
    expected_rounds = rounds_dict[date]
    for mouse in mouse_list:
        n_sessions = int((data['mouse'] == mouse).sum())
        if n_sessions > expected_rounds and n_sessions > 1:
            days_to_stitch.append(date)
            mice_to_stitch.append(mouse)
            print(f"on {date}, {mouse} has {n_sessions} sessions")
if not days_to_stitch:
    print("no sessions to stitch!")

In [None]:
# Stitch the first two same-day sessions of each flagged mouse. The stitched
# events overwrite session 1's processed-events file and session 2's folder is
# deleted. When days_to_stitch is empty this only prints a message.
# NOTE(review): only sessions_to_stitch.iloc[0] and .iloc[1] are combined; a
# third same-day session would be left untouched.
if not days_to_stitch:
    print("no sessions to stitch!")
else:
    for stitch_date, stitch_mouse in zip(days_to_stitch, mice_to_stitch):
        day = sessions_by_date.get_group(stitch_date)
        sessions_to_stitch = day[day['mouse'] == stitch_mouse]
        first_info, second_info = sessions_to_stitch.iloc[0], sessions_to_stitch.iloc[1]

        session_1_dir = utils.generate_events_processed_path(data_folder, first_info)
        session_2_dir = utils.generate_events_processed_path(data_folder, second_info)
        if not (os.path.exists(session_1_dir) and os.path.exists(session_2_dir)):
            print("one of the sessions do not exist")
            continue

        # TODO: switch to stitch_events and delete stitch_sessions; to follow the
        # naming convention, all "session" tables should be renamed to "events".
        stitched_session = helper.stitch_sessions(pd.read_csv(session_1_dir), pd.read_csv(session_2_dir))
        stitched_session.to_csv(session_1_dir, index=False)
        shutil.rmtree(os.path.join(data_folder, second_info.dir))
        print(f"{stitch_date} {stitch_mouse} session 2 deleted")

**Make a copy of the cleaned data before proceeding**, and regenerate the session log afterwards!

## Add additional info to session logs and save them

In [None]:
# Rebuild the session log so it reflects any stitching done above.
sessions_all = generate_sessions_all(data_folder=data_folder)

In [None]:
# Attach per-session summary stats computed from each session's raw events file.
# Sessions that fail are reported and skipped. pd.merge defaults to an inner
# join, so failed sessions drop out of sessions_all.
session_basics_list = []
for _, session_info in sessions_all.iterrows():
    try:
        events = pd.read_csv(utils.generate_events_path(data_folder, session_info))
        session_basics_list.append({'dir': session_info.dir} | helper.get_session_basics(events))
    except Exception as err:
        # A bare `except:` would also swallow KeyboardInterrupt and hide the cause.
        print(f"{session_info.dir}: {err}")
session_basics_df = pd.DataFrame(session_basics_list)
sessions_all = pd.merge(sessions_all, session_basics_df, on='dir')
# 'pump_ul_per_turn' is absent from this cohort's meta data (see the session log
# above), so ignore missing columns instead of raising KeyError.
# NOTE(review): the processing cells below read 'total_trial' from
# sessions_training.csv; confirm get_session_basics provides an equivalent
# before saving logs built from this frame.
sessions_all = sessions_all.drop(['pump_ul_per_turn', 'total_trial', 'total_reward'], axis=1, errors='ignore')

Add the number of days trained to each training session.

In [None]:
# Keep regular training sessions only; reset_index() keeps the old row labels as an 'index' column.
is_regular = sessions_all['training'].eq('regular')
sessions_training = sessions_all[is_regular].reset_index()

In [None]:
# Number each mouse's training sessions by applying helper.assign_session_numbers per mouse.
sessions_by_mouse = sessions_training.groupby('mouse', group_keys=False)
sessions_training = sessions_by_mouse.apply(helper.assign_session_numbers)

### Saves all sessions log and training session log

In [None]:
for log_df, log_name in ((sessions_all, 'sessions_all.csv'),
                         (sessions_training, 'sessions_training.csv')):
    utils.save_as_csv(df=log_df, folder=data_folder, filename=log_name)

# Process raw session

Load the session log.

In [23]:
sessions_training_path = os.path.join(data_folder, 'sessions_training.csv')
sessions_training = utils.load_data(sessions_training_path)

In [24]:
def process_events(session_info, events):
    """Trim a session's raw events to the usable portion, based on its ending code.

    - 'pygame' / 'miss' endings: keep rows whose session_trial_num lies in
      [0, total_trial] (``Series.between`` is inclusive on both ends).
    - 'time' / 'reward' / 'trial' endings: drop the first two rows and the last row.

    Args:
        session_info: Session-log row providing 'ending_code', 'total_trial' and 'dir'.
        events (pd.DataFrame): Raw events with a 'session_trial_num' column.
    Returns:
        pd.DataFrame: The trimmed events.
    Raises:
        ValueError: If the ending code is not one of the known codes.
    """
    ending_to_adjust = ['pygame', 'miss']
    ending_smooth = ['time', 'reward', 'trial']
    ending_code = session_info['ending_code']
    if ending_code in ending_to_adjust:
        # NOTE(review): between() is inclusive, so trial number total_trial is
        # kept although generate_trials only reads trials 0..total_trial-1.
        # Confirm the extra trial is intended.
        return events.loc[events['session_trial_num'].between(0, session_info['total_trial'])]
    if ending_code in ending_smooth:
        return events.iloc[2:-1]
    # Raising a plain string is itself a TypeError in Python 3; raise a real exception.
    raise ValueError(f"ending code unknown: {ending_code!r} (session {session_info['dir']})")

In [25]:
# process events by triming the beginning and the end and adding trial time
for _, session_info in sessions_training.iterrows():
    try:
        events_processed_path = utils.generate_events_processed_path(data_folder, session_info)
        if os.path.isfile(events_processed_path):
            continue

        events = pd.read_csv(utils.generate_events_path(data_folder, session_info), low_memory=False)
        events = process_events(session_info, events)
        events_processed = events.groupby('session_trial_num', group_keys=False).apply(helper.add_trial_time)
        events_processed.to_csv(events_processed_path)
    except:
        print(session_info['dir'])

# Generate all_trials

In [26]:
def get_trial_basics(trial):
    """Summarize one trial's events as a dict of trial/block numbers and start/end times.

    The first row with ``key == 'trial'`` and ``value == 1`` marks the start, and
    the first such row with ``value == 0`` marks the end. If either is missing,
    ``.iloc[0]`` raises IndexError.
    """
    is_trial_event = trial['key'] == 'trial'
    start_row = trial[is_trial_event & (trial['value'] == 1)].iloc[0]
    end_row = trial[is_trial_event & (trial['value'] == 0)].iloc[0]

    basics = {column: start_row[column]
              for column in ('session_trial_num', 'block_trial_num', 'block_num')}
    basics['start_time'] = start_row['session_time']
    basics['end_time'] = end_row['session_time']
    return basics

In [27]:
def generate_trials(session_info, events):
    """Build a per-trial summary table for one session.

    Trials 0 .. total_trial-1 are each summarized with get_trial_basics. Events
    are grouped by trial number once, instead of re-scanning the whole events
    table for every trial.

    Args:
        session_info: Session-log row providing ``total_trial``.
        events (pd.DataFrame): Processed events with a 'session_trial_num' column.
    Returns:
        pd.DataFrame: One row per trial (empty when total_trial is 0).
    """
    events_by_trial = dict(tuple(events.groupby('session_trial_num')))
    # A trial number with no events maps to an empty frame, so get_trial_basics
    # fails on it just as a per-trial boolean mask would.
    no_events = events.iloc[0:0]
    trial_info_list = [get_trial_basics(events_by_trial.get(t, no_events))
                       for t in range(int(session_info.total_trial))]
    return pd.DataFrame(trial_info_list)

In [28]:
# generate all trials based on events processed
for _, session_info in sessions_training.iterrows():
    try: 
        trials_path = utils.generate_trials_path(data_folder, session_info)
        if os.path.isfile(trials_path):
            continue
        
        events_processed = pd.read_csv(utils.generate_events_processed_path(data_folder, session_info))
        trials = generate_trials(session_info, events_processed)

        trials.to_csv(trials_path)
    except:
        display(session_info)

## Analyze trials

In [30]:
# Analyze every training session's trials: per-trial data from the processed
# events (helper.get_trial_data_df) is merged onto the trials table and saved.
# pd.merge defaults to an inner join, so only trials present in both are kept.
# Set skip_existing = True to keep trials_analyzed files that already exist.
skip_existing = False
for _, session_info in sessions_training.iterrows():
    trials_analyzed_path = utils.generate_trials_analyzed_path(data_folder, session_info)
    if skip_existing and os.path.isfile(trials_analyzed_path):
        continue

    session_by_trial = utils.load_data(utils.generate_events_processed_path(data_folder, session_info)).groupby('session_trial_num')
    trials = utils.load_data(utils.generate_trials_path(data_folder, session_info))
    trials_data = helper.get_trial_data_df(session_by_trial)
    trials_analyzed = pd.merge(trials, trials_data, on='session_trial_num')
    trials_analyzed['group'] = session_info.group  # assign the session's group to every trial manually
    trials_analyzed.to_csv(trials_analyzed_path)

Everything up to here appears to work, but the code still needs to be cleaned up.
Open issues: 1. `generate_trials` could use `groupby`. 2. The code is not clean. 3. More rules are needed to catch data corruption.

# Stitch sessions from the same day

This section must be run separately for exp2 and exp3; adjust it for the next experiment.

In [None]:
# Per-experiment folders inside the cohort, each with a full_clean sub-folder.
exp2_folder = os.path.join(data_dir, cohort, 'exp2')
exp3_folder = os.path.join(data_dir, cohort, 'exp3')
data_folder_exp2, data_folder_exp3 = (os.path.join(folder, 'full_clean')
                                      for folder in (exp2_folder, exp3_folder))

In [None]:
# Output folders for the day-stitched (all mice combined) data of each experiment.
stitched_folder_exp2, stitched_folder_exp3 = (os.path.join(folder, "stitched")
                                              for folder in (exp2_folder, exp3_folder))

Load the session logs.

In [None]:
exp2_log_path = os.path.join(data_folder_exp2, 'sessions_training_exp2.csv')
exp3_log_path = os.path.join(data_folder_exp3, 'sessions_training_exp3.csv')
sessions_training_exp2, sessions_training_exp3 = (utils.load_data(path)
                                                  for path in (exp2_log_path, exp3_log_path))

Generate a session log with one entry per training day.

In [None]:
def generate_stitched_all_mice_session_log(session_log):
    """Collapse a per-session log into one row per training day.

    Only the 'date', 'training' and 'exp' columns are kept, and for each date
    the first row is retained (original index labels are preserved).
    The input frame is not modified.
    """
    day_columns = ['date', 'training', 'exp']
    return session_log[day_columns].drop_duplicates(subset=['date'], keep='first')

In [None]:
sessions_training_stitched_exp2, sessions_training_stitched_exp3 = map(
    generate_stitched_all_mice_session_log, (sessions_training_exp2, sessions_training_exp3))

Create an empty directory for each training day in the stitched folders.

In [None]:
# One sub-folder per exp2 training day inside the stitched folder.
for day in sessions_training_stitched_exp2['date']:
    day_dir = os.path.join(stitched_folder_exp2, day)
    if os.path.exists(day_dir):
        continue
    os.makedirs(day_dir)

In [None]:
# One sub-folder per exp3 training day inside the stitched folder.
for day in sessions_training_stitched_exp3['date']:
    day_dir = os.path.join(stitched_folder_exp3, day)
    if os.path.exists(day_dir):
        continue
    os.makedirs(day_dir)

Loop through each day to stitch `events_processed` and `trials_analyzed` from all mice.

In [None]:
sessions_by_date_exp2, sessions_by_date_exp3 = (log.groupby('date')
                                                for log in (sessions_training_exp2, sessions_training_exp3))

In [None]:
def generate_trials_analyzed_stitched_path(data_folder, session_info):
    """Return ``<data_folder>/<date>/trials_analyzed_stitched_<date>.csv`` for the session's date."""
    day = session_info.date
    return os.path.join(data_folder, day, f'trials_analyzed_stitched_{day}.csv')

In [None]:
def generate_events_processed_stitched_path(data_folder, session_info):
    """Return ``<data_folder>/<date>/events_processed_stitched_<date>.csv`` for the session's date.

    The date is formatted to text first, so non-string dates also work as folder names.
    """
    day = format(session_info.date)
    return os.path.join(data_folder, day, f'events_processed_stitched_{day}.csv')

In [None]:
def stitch_events(events_1, events_2):
    """Append ``events_2`` after ``events_1`` as one continuous events table.

    ``events_2``'s session_time, block_num and session_trial_num are shifted by
    the 'session_time', 'num_blocks' and 'num_trials' values that
    helper.get_session_basics reports for ``events_1``. Neither input is modified.

    Returns:
        pd.DataFrame: events_1 followed by the shifted events_2 (row labels are not reset).
    """
    session_1_basics = helper.get_session_basics(events_1)
    # .assign works on a copy, so the caller's events_2 is left unchanged.
    events_2_shifted = events_2.assign(
        session_time=events_2['session_time'] + session_1_basics['session_time'],
        block_num=events_2['block_num'] + session_1_basics['num_blocks'],
        session_trial_num=events_2['session_trial_num'] + session_1_basics['num_trials'],
    )
    return pd.concat([events_1, events_2_shifted])

In [None]:
def stitch_trials(trials_1, trials_2):
    """Append ``trials_2`` after ``trials_1`` as one continuous trial table.

    In ``trials_2``:
    - trial and block numbers continue after the largest ones in ``trials_1``;
    - start/end times are shifted by the largest end_time of ``trials_1``.
    Neither input is modified; all other columns are carried over unchanged.

    Returns:
        pd.DataFrame: trials_1 followed by the shifted trials_2 (row labels are not reset).
    """
    trial_offset = trials_1['session_trial_num'].max() + 1
    block_offset = trials_1['block_num'].max() + 1
    time_offset = trials_1['end_time'].max()

    # .assign works on a copy, so the caller's trials_2 is left unchanged.
    trials_2_shifted = trials_2.assign(
        session_trial_num=trials_2['session_trial_num'] + trial_offset,
        block_num=trials_2['block_num'] + block_offset,
        start_time=trials_2['start_time'] + time_offset,
        end_time=trials_2['end_time'] + time_offset,
    )
    return pd.concat([trials_1, trials_2_shifted])

In [None]:
# Run this before re-running the stitching cells below, so a rerun never starts
# from the previous run's stitched frames.
events_stitched = trials_stitched = None

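**FIXME / check:** mouse names must be preserved when stitching. Every row should carry the `mouse` of the session it came from; verify the `mouse` column in the stitched output.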

In [None]:
# Stitch every exp2 training day into one events table and one trials table that
# cover all mice, labeling each row with the mouse it came from.
# NOTE(review): inputs are read from data_folder (the cohort-level folder), not
# data_folder_exp2; confirm that is where these sessions' processed files live.
for _, day_sessions in sessions_by_date_exp2:
    session_info = day_sessions.iloc[0]
    events_stitched = utils.load_data(utils.generate_events_processed_path(data_folder, session_info))
    events_stitched['mouse'] = session_info['mouse']
    trials_stitched = utils.load_data(utils.generate_trials_analyzed_path(data_folder, session_info))
    trials_stitched['mouse'] = session_info['mouse']

    for _, next_info in day_sessions.iloc[1:].iterrows():
        # Label both tables of each appended session with that session's own
        # mouse, not the day's first mouse.
        events_2 = utils.load_data(utils.generate_events_processed_path(data_folder, next_info))
        events_2['mouse'] = next_info['mouse']
        trials_2 = utils.load_data(utils.generate_trials_analyzed_path(data_folder, next_info))
        trials_2['mouse'] = next_info['mouse']
        events_stitched = stitch_events(events_stitched, events_2)
        trials_stitched = stitch_trials(trials_stitched, trials_2)

    # Output files are named after the day's date, taken from its first session.
    events_stitched.to_csv(generate_events_processed_stitched_path(stitched_folder_exp2, session_info), index=False)
    trials_stitched.to_csv(generate_trials_analyzed_stitched_path(stitched_folder_exp2, session_info), index=False)

In [None]:
# Stitch every exp3 training day into one events table and one trials table that
# cover all mice, labeling each row with the mouse it came from.
# NOTE(review): inputs are read from data_folder (the cohort-level folder), not
# data_folder_exp3; confirm that is where these sessions' processed files live.
for _, day_sessions in sessions_by_date_exp3:
    session_info = day_sessions.iloc[0]
    events_stitched = utils.load_data(utils.generate_events_processed_path(data_folder, session_info))
    events_stitched['mouse'] = session_info['mouse']
    trials_stitched = utils.load_data(utils.generate_trials_analyzed_path(data_folder, session_info))
    trials_stitched['mouse'] = session_info['mouse']

    for _, next_info in day_sessions.iloc[1:].iterrows():
        # Label both tables of each appended session with that session's own
        # mouse, not the day's first mouse.
        events_2 = utils.load_data(utils.generate_events_processed_path(data_folder, next_info))
        events_2['mouse'] = next_info['mouse']
        trials_2 = utils.load_data(utils.generate_trials_analyzed_path(data_folder, next_info))
        trials_2['mouse'] = next_info['mouse']
        events_stitched = stitch_events(events_stitched, events_2)
        trials_stitched = stitch_trials(trials_stitched, trials_2)

    # Output files are named after the day's date, taken from its first session.
    events_stitched.to_csv(generate_events_processed_stitched_path(stitched_folder_exp3, session_info), index=False)
    trials_stitched.to_csv(generate_trials_analyzed_stitched_path(stitched_folder_exp3, session_info), index=False)

Add summary info to the stitched session logs and save them.

In [None]:
# Summarize each stitched exp2 day and save the resulting day-level log.
stitched_session_basics_list = []
for _, day_info in sessions_training_stitched_exp2.iterrows():
    day_events = pd.read_csv(generate_events_processed_stitched_path(stitched_folder_exp2, day_info))
    stitched_session_basics_list.append({'date': day_info.date, **helper.get_session_basics(day_events)})
stitched_session_basics = pd.DataFrame(stitched_session_basics_list)
stitched_session_log_2 = sessions_training_stitched_exp2.merge(stitched_session_basics, on='date')
utils.save_as_csv(df=stitched_session_log_2, folder=stitched_folder_exp2, filename='sessions_training_stitched_exp2.csv')

In [None]:
# Summarize each stitched exp3 day and save the resulting day-level log.
stitched_session_basics_list = []
for _, day_info in sessions_training_stitched_exp3.iterrows():
    day_events = pd.read_csv(generate_events_processed_stitched_path(stitched_folder_exp3, day_info))
    stitched_session_basics_list.append({'date': day_info.date, **helper.get_session_basics(day_events)})
stitched_session_basics = pd.DataFrame(stitched_session_basics_list)
stitched_session_log_3 = sessions_training_stitched_exp3.merge(stitched_session_basics, on='date')
utils.save_as_csv(df=stitched_session_log_3, folder=stitched_folder_exp3, filename='sessions_training_stitched_exp3.csv')