In [43]:
import os
from pathlib import Path

import numpy as np
import pickle
import pandas as pd

import constants as k

In [44]:
# Notebook configuration: name of the raw pickle and the folders used below.
raw_pickle_name = "neural_data_0519"
data_dir = '/Users/rebekahzhang/data/neural_data'
# Processed per-session pickles go here (kept as a Path object).
pickle_dir = Path(data_dir) / 'session_pickles'
# Figures folder (kept as a plain string path).
figure_folder = os.path.join(data_dir, "figures")

# Unify sorted data and recording log

**load sorted session data pickle**
- pickle is a list of dictionaries
- each dict represents one session, with subject, session datetime, an events df, and a list of units
- each unit is an array of spike times

In [45]:
# Read '<raw_pickle_name>.pkl' from data_dir into memory.
# NOTE: pickle.load can execute arbitrary code while unpickling; only open
# pickle files that come from a trusted source.
with open(os.path.join(data_dir, f'{raw_pickle_name}.pkl'), 'rb') as f:
    sorted_sessions_list = pickle.load(f)

**make a log converting the data list into a df**
- temporarily keeping events and units in the df. 
- will drop later

In [46]:
def generate_sessions_sorted(sorted_sessions_list):
    """Flatten the list of session dicts into a DataFrame with one row per entry.

    Each input dict must provide 'subject', 'session_datetime',
    'insertion_number', 'paramset_idx', 'unit_spikes', 'events' and
    'unit_ids'. The events object and the spike containers are stored in the
    frame unchanged; 'num_units' is the length of 'unit_spikes' and 'date' is
    the session datetime formatted as YYYY-MM-DD.
    """
    def to_row(entry):
        session_dt = entry['session_datetime']
        unit_spikes = entry['unit_spikes']
        return {
            'mouse': entry['subject'],
            'datetime': session_dt,
            'date': session_dt.strftime("%Y-%m-%d"),
            'insertion_number': entry['insertion_number'],
            'paramset_idx': entry['paramset_idx'],
            'num_units': len(unit_spikes),
            'events': entry['events'],
            'unit_ids': entry['unit_ids'],
            'units': unit_spikes,
        }

    return pd.DataFrame([to_row(entry) for entry in sorted_sessions_list])

In [47]:
# Build the session table (one row per entry of the loaded pickle list).
sorted_sessions_all = generate_sessions_sorted(sorted_sessions_list)

In [48]:
# Display the session table for inspection.
sorted_sessions_all

Unnamed: 0,mouse,datetime,date,insertion_number,paramset_idx,num_units,events,unit_ids,units
0,RZ034,2024-07-13 12:58:26,2024-07-13,0,101,0,subject session_datetime event_type ...,[],[]
1,RZ034,2024-07-13 12:58:26,2024-07-13,1,101,47,subject session_datetime event_type ...,"[7, 8, 13, 14, 16, 19, 20, 26, 34, 43, 44, 50,...","[[0.147313932480591, 0.3241479850050406, 0.941..."
2,RZ034,2024-07-14 12:52:46,2024-07-14,0,101,0,subject session_datetime event_type ...,[],[]
3,RZ034,2024-07-14 12:52:46,2024-07-14,1,101,31,subject session_datetime event_type ...,"[2, 4, 10, 17, 23, 25, 35, 40, 44, 48, 67, 71,...","[[0.8333367225208166, 0.8930702987929775, 0.91..."
4,RZ036,2024-07-12 12:50:31,2024-07-12,0,101,15,subject session_datetime event_type ...,"[2, 6, 7, 9, 11, 13, 17, 19, 20, 22, 24, 28, 3...","[[0.04722898714344431, 0.21768939509224766, 0...."
...,...,...,...,...,...,...,...,...,...
114,RZ070,2025-02-12 14:02:10,2025-02-12,1,101,0,subject session_datetime event_type ...,[],[]
115,RZ070,2025-02-13 11:40:15,2025-02-13,0,101,0,subject session_datetime event_type ...,[],[]
116,RZ070,2025-02-13 11:40:15,2025-02-13,1,101,2,subject session_datetime event_type ...,"[0, 14]","[[0.030466766747558088, 0.03196677167495428, 0..."
117,RZ070,2025-02-14 11:03:49,2025-02-14,0,101,0,subject session_datetime event_type ...,[],[]


In [49]:
# Keep only the rows that have at least one unit.
sorted_sessions_with_units = sorted_sessions_all[sorted_sessions_all['num_units'].gt(0)]

# Quick summary: number of retained rows and total unit count across them.
print(len(sorted_sessions_with_units), "sessions with units")
print("total cells:", sum(sorted_sessions_with_units['num_units']))

96 sessions with units
total cells: 3660


**load recording log**
- download from google sheet and save it to data_dir

In [50]:
# Read the recording log from processing_check/sessions_cross_checked.csv;
# the CSV's first column is used as the DataFrame index.
recording_log = pd.read_csv(os.path.join(data_dir, 'processing_check', 'sessions_cross_checked.csv'), index_col=0)
# Optional pruning of log columns, currently disabled:
# recording_log = recording_log.drop(columns=['NIDAQ', 'simultaneous', 'probe', 'probe treatment', 'insertion speed',
#        'resting time', 'surface', 'extraction speed', 'notes', 'rewards',
#        'num trials', 'tw', 'potential problems', 'sorting notes'])

In [51]:
# Display the recording log for inspection.
recording_log

Unnamed: 0,date,mouse,insertion_number,region,potential problems,sorting notes,First_X_Column
0,2024-07-11,RZ034,0,str,,,SIClustering
1,2024-07-12,RZ034,0,str,D drive ran out of space,,not_uploaded
2,2024-07-13,RZ034,0,v1,,assertion error when lauching phy,ManualCuration
3,2024-07-13,RZ034,1,str,,should be all good now,Done
4,2024-07-14,RZ034,0,v1,,invalid sorting key,SIExport
...,...,...,...,...,...,...,...
166,2025-02-12,RZ070,1,str,pump wasnt properly grounded,,Done
167,2025-02-13,RZ070,0,v1,,"no bug, new phy",Done
168,2025-02-13,RZ070,1,str,,,Done
169,2025-02-14,RZ070,0,v1,,"no bug, new phy",Done


**merge the two df to add region info**

In [52]:
# Inner join: only rows present in BOTH the recording log and the sessions
# with units survive, matched on mouse, date and insertion number.
sorted_sessions = recording_log.merge(
    sorted_sessions_with_units,
    how='inner',
    on=['mouse', 'date', 'insertion_number'],
)
# Session id is "<mouse>_<date>_<region>".
sorted_sessions['id'] = sorted_sessions[['mouse', 'date', 'region']].apply('_'.join, axis=1)

In [53]:
# Display the merged session table for inspection.
sorted_sessions

Unnamed: 0,date,mouse,insertion_number,region,potential problems,sorting notes,First_X_Column,datetime,paramset_idx,num_units,events,unit_ids,units,id
0,2024-07-13,RZ034,1,str,,should be all good now,Done,2024-07-13 12:58:26,101,47,subject session_datetime event_type ...,"[7, 8, 13, 14, 16, 19, 20, 26, 34, 43, 44, 50,...","[[0.147313932480591, 0.3241479850050406, 0.941...",RZ034_2024-07-13_str
1,2024-07-14,RZ034,1,str,,,Done,2024-07-14 12:52:46,101,31,subject session_datetime event_type ...,"[2, 4, 10, 17, 23, 25, 35, 40, 44, 48, 67, 71,...","[[0.8333367225208166, 0.8930702987929775, 0.91...",RZ034_2024-07-14_str
2,2024-07-12,RZ036,0,v1,,,Done,2024-07-12 12:50:31,101,15,subject session_datetime event_type ...,"[2, 6, 7, 9, 11, 13, 17, 19, 20, 22, 24, 28, 3...","[[0.04722898714344431, 0.21768939509224766, 0....",RZ036_2024-07-12_v1
3,2024-07-12,RZ036,1,str,,,Done,2024-07-12 12:50:31,101,45,subject session_datetime event_type ...,"[2, 3, 5, 9, 10, 12, 16, 18, 19, 20, 26, 27, 2...","[[2.1390820330857916, 9.098710338170573, 12.45...",RZ036_2024-07-12_str
4,2024-07-13,RZ036,1,str,,,Done,2024-07-13 14:29:03,101,30,subject session_datetime event_type ...,"[3, 4, 17, 27, 36, 38, 43, 48, 51, 52, 72, 73,...","[[0.06722027337187958, 0.13642055481217313, 0....",RZ036_2024-07-13_str
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90,2025-02-21,RZ065,0,v1,,"no bug, new phy",Done,2025-02-21 11:15:15,101,27,subject session_datetime event_type ...,"[6, 8, 12, 33, 37, 39, 54, 56, 67, 68, 80, 81,...","[[0.04172013704731256, 0.10455367678379723, 0....",RZ065_2025-02-21_v1
91,2025-02-21,RZ065,1,str,,,Done,2025-02-21 11:15:15,101,19,subject session_datetime event_type ...,"[8, 11, 12, 13, 14, 16, 17, 20, 61, 111, 122, ...","[[8.129773372321813, 8.38830188290062, 11.2871...",RZ065_2025-02-21_str
92,2025-02-22,RZ065,1,str,,,Done,2025-02-22 13:03:06,101,162,subject session_datetime event_type ...,"[0, 2, 6, 11, 15, 17, 18, 21, 22, 23, 38, 39, ...","[[0.34590113625756047, 0.4268014020084615, 1.3...",RZ065_2025-02-22_str
93,2025-02-13,RZ070,1,str,,,Done,2025-02-13 11:40:15,101,2,subject session_datetime event_type ...,"[0, 14]","[[0.030466766747558088, 0.03196677167495428, 0...",RZ070_2025-02-13_str


# Process sorted session data

### events processing

In [54]:
def process_raw_events(events_df):
    """Add trial-relative start/end times to every event.

    The start time of each trial is taken from the row whose 'event_type' is
    'trial' (one such row per 'trial_id' is assumed; duplicated trial rows
    make the lookup index non-unique and the mapping will raise).

    Parameters
    ----------
    events_df : pd.DataFrame
        Must contain 'event_type', 'trial_id', 'event_start_time' and
        'event_end_time'. The input frame is NOT modified.

    Returns
    -------
    pd.DataFrame
        A copy of the input with two extra columns:
        'event_start_trial_time' and 'event_end_trial_time', i.e. the event
        start/end times minus the start time of the trial they belong to.
        Events whose trial_id has no 'trial' row get NaN in both columns.
    """
    # Work on a copy so the caller's frame (e.g. the one stored in the session
    # table) is not silently given extra columns.
    events_df = events_df.copy()

    # Lookup table: trial_id -> start time of that trial.
    is_trial_row = events_df['event_type'] == 'trial'
    trial_starts = events_df.loc[is_trial_row].set_index('trial_id')['event_start_time']

    # Start time of the owning trial for every event; kept as a temporary
    # Series rather than a column because it is only needed for the subtraction.
    trial_start_time = events_df['trial_id'].map(trial_starts)

    events_df['event_start_trial_time'] = events_df['event_start_time'] - trial_start_time
    events_df['event_end_trial_time'] = events_df['event_end_time'] - trial_start_time

    # The helper column is never part of the output. It can still be present on
    # input frames that an earlier in-place version of this function modified.
    events_df = events_df.drop(columns=['trial_start_time'], errors='ignore')

    return events_df

### trials processing

In [55]:
def get_trial_data(trial):
    """Summarise the events of a single trial into a flat dict.

    `trial` is a DataFrame of one trial's events with columns 'event_type',
    'event_start_trial_time' and 'event_end_trial_time'. It must contain at
    least one 'visual' event; the first one defines cue on/off.

    Returned keys (in this order): 'missed', 'rewarded', 'cue_on_time',
    'cue_off_time', 'decision_time', 'background_length', 'wait_length'.

    - No 'reward' event: the trial is marked missed, 'wait_length' is 60 and
      'decision_time' is NaN.
    - With a 'reward' event: 'wait_length' is the first reward start minus
      cue off. If a 'cons_reward' / 'cons_no_reward' event exists, its first
      start time becomes 'decision_time'; when both exist, 'cons_no_reward'
      is applied last and wins.
    """
    event_types = trial['event_type']

    def first_start_of(kind):
        # Trial-relative start time of the first event of the given type.
        return trial.loc[event_types == kind, 'event_start_trial_time'].iloc[0]

    # Cue timing comes from the first visual event.
    visual_rows = trial[event_types == 'visual']
    cue_on_time = visual_rows['event_start_trial_time'].iloc[0]
    cue_off_time = visual_rows['event_end_trial_time'].iloc[0]

    # Defaults describe a missed trial.
    missed = True
    rewarded = False
    decision_time = np.nan
    wait_length = 60

    if 'reward' in event_types.values:
        missed = False
        wait_length = first_start_of('reward') - cue_off_time

        # Checked in this order so that 'cons_no_reward' overrides 'cons_reward'.
        for cons_type, is_rewarded in (('cons_reward', True), ('cons_no_reward', False)):
            if cons_type in event_types.values:
                decision_time = first_start_of(cons_type)
                rewarded = is_rewarded

    return {
        'missed': missed,
        'rewarded': rewarded,
        'cue_on_time': cue_on_time,
        'cue_off_time': cue_off_time,
        'decision_time': decision_time,
        'background_length': cue_off_time - cue_on_time,
        'wait_length': wait_length,
    }

def generate_trials(events):
    """Build one row per trial from the processed events table.

    The 'trial' rows of `events` are merged (on 'trial_id') with the per-trial
    summary returned by get_trial_data for each trial_id group. The result
    gains 'consumption_length' (trial end minus decision time), has
    'event_end_trial_time' renamed to 'trial_length', and no longer carries
    'event_start_trial_time' or 'event_type'.
    """
    # Per-trial summary, one record per trial_id group.
    per_trial = pd.DataFrame(
        [
            {'trial_id': trial_id} | get_trial_data(trial_events)
            for trial_id, trial_events in events.groupby("trial_id")
        ]
    )

    # The rows describing the trials themselves.
    trial_rows = events.loc[events['event_type'] == 'trial'].copy()

    trials = pd.merge(trial_rows, per_trial, on='trial_id')
    trials['consumption_length'] = trials['event_end_trial_time'] - trials['decision_time']

    return (
        trials
        .rename(columns={"event_end_trial_time": "trial_length"})
        .drop(columns=['event_start_trial_time', 'event_type'])
    )

### spikes processing

In [56]:
def add_trial_time_to_spikes(spikes, trials):
    for _, trial_basics in trials.iterrows():
        trial_start_time = trial_basics['event_start_time']
        trial_end_time = trial_basics['event_end_time']
        spikes.loc[spikes['spike_time'].between(trial_start_time, trial_end_time), 
                'trial_id'] = trial_basics['trial_id']
        spikes.loc[spikes['spike_time'].between(trial_start_time, trial_end_time), 
                'trial_time'] = spikes['spike_time'] - trial_start_time
        time_columns = ["cue_on_time", "cue_off_time", "decision_time"] 
    trials_to_merge = trials[['trial_id']+ time_columns].copy()
    spikes = trials_to_merge.merge(spikes, on='trial_id', how='inner')
    return spikes

def align_spike_time_to_anchors(spikes):
    """Add spike times re-referenced to each trial anchor.

    For cue on, cue off and decision, writes a column (named by the matching
    constant in `k`) holding 'trial_time' minus that anchor's time. The frame
    is modified in place and also returned.
    """
    anchor_columns = (
        (k.TO_CUE_ON, "cue_on_time"),
        (k.TO_CUE_OFF, "cue_off_time"),
        (k.TO_DECISION, "decision_time"),
    )
    for aligned_name, anchor_name in anchor_columns:
        spikes[aligned_name] = spikes['trial_time'] - spikes[anchor_name]
    return spikes

def add_period_to_spikes(row):
    """Label a spike row with the trial period it falls in.

    Returns k.BACKGROUND for cue_on <= t < cue_off. After cue off the spike
    is k.WAIT until the decision time and k.CONSUMPTION from the decision
    time onward; when 'decision_time' is missing, everything from cue off on
    is k.WAIT. Returns None when no period matches (e.g. before cue on).
    """
    spike_t = row['trial_time']
    cue_off = row['cue_off_time']

    if row['cue_on_time'] <= spike_t < cue_off:
        return k.BACKGROUND

    decision = row['decision_time']
    if pd.isna(decision):
        # No decision recorded: the wait period has no upper bound.
        return k.WAIT if cue_off <= spike_t else None

    if cue_off <= spike_t < decision:
        return k.WAIT
    if decision <= spike_t:
        return k.CONSUMPTION
    return None

def process_spikes(spikes, trials):
    """Turn one unit's spike times into a per-spike table.

    Wraps `spikes` in a single-column 'spike_time' frame, assigns spikes to
    trials, adds the anchor-aligned time columns, and labels each spike with
    its trial period in a 'period' column.
    """
    spike_table = pd.DataFrame(spikes, columns=['spike_time'])
    spike_table = align_spike_time_to_anchors(
        add_trial_time_to_spikes(spike_table, trials)
    )
    spike_table['period'] = spike_table.apply(add_period_to_spikes, axis=1)
    return spike_table

### align events to anchor times

In [57]:
def align_events(events, trials):
    """Attach trial anchor times to events and re-reference event starts.

    Inner-merges the cue on / cue off / decision times from `trials` onto
    `events` by 'trial_id', then adds one column per anchor (named by the
    matching constant in `k`) holding 'event_start_trial_time' minus that
    anchor's time. Returns the merged frame; the inputs are not modified.
    """
    anchor_table = trials[['trial_id', "cue_on_time", "cue_off_time", "decision_time"]].copy()
    aligned = anchor_table.merge(events, on='trial_id', how='inner')

    anchor_columns = (
        (k.TO_CUE_ON, "cue_on_time"),
        (k.TO_CUE_OFF, "cue_off_time"),
        (k.TO_DECISION, "decision_time"),
    )
    for aligned_name, anchor_name in anchor_columns:
        aligned[aligned_name] = aligned['event_start_trial_time'] - aligned[anchor_name]
    return aligned

### process a session through all steps

In [58]:
def process_session(session):
    """Run one session (a row of the session table) through every step.

    Reads session['events'] and session['units'] and returns a tuple
    (aligned events, trials table, list of per-unit spike tables).
    """
    raw_events = process_raw_events(session['events'])
    trials = generate_trials(raw_events)
    events_aligned = align_events(raw_events, trials)

    units_aligned = []
    for unit_spike_times in session['units']:
        units_aligned.append(process_spikes(unit_spike_times, trials))

    return events_aligned, trials, units_aligned

In [59]:
# for debugging
# test_log = sorted_sessions_all.head(2)
# test_session = sorted_sessions_all.iloc[6]
# events = process_raw_events(test_session['events'])
# trials = generate_trials(events)
# events_aligned = align_events(events, trials)
# spikes = test_session['units'][1]
# test_events, test_trials, test_units = process_session(test_session)

### loop through all sessions

In [60]:
def process_and_save_all_sessions(sorted_sessions, pickle_dir, regenerate):
    """Process every session in the table and pickle each result to its own file.

    Parameters
    ----------
    sorted_sessions : pd.DataFrame
        One row per session. Each row must provide 'id', 'mouse', 'date',
        'region' and 'unit_ids', plus whatever process_session reads.
    pickle_dir : str or Path
        Output directory; created (including parents) if it does not exist.
    regenerate : bool
        If False, sessions whose '<id>.pkl' already exists are skipped.
        If True, they are reprocessed and the file is overwritten.

    Each output file holds a dict with keys 'id', 'mouse', 'date', 'region',
    'events', 'trials', 'unit_ids' and 'units'.
    """
    Path(pickle_dir).mkdir(parents=True, exist_ok=True)
    for _, session in sorted_sessions.iterrows():
        output_path = os.path.join(pickle_dir, f"{session['id']}.pkl")
        # Skip if file exists and not regenerating
        if os.path.exists(output_path) and not regenerate:
            print(f"Session {session['id']} already exists at {output_path} - skipping")
            continue

        events, trials, units = process_session(session)

        session_data = {
            'id' : session['id'],
            'mouse': session['mouse'],
            'date': session['date'],
            'region': session['region'],
            'events': events,
            'trials': trials,
            'unit_ids': session['unit_ids'],
            'units': units,
        }

        # Write to a temporary file and move it into place only once the dump
        # has finished. Otherwise an interrupted run would leave a truncated
        # '<id>.pkl' that later runs (regenerate=False) treat as complete.
        tmp_path = f"{output_path}.tmp"
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(session_data, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, output_path)
        finally:
            # Only still present if the dump or the move failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print(f"Saved session {session['id']} to {output_path}")

In [61]:
# Set regenerate = True to reprocess every session and overwrite existing
# pickles; with False, sessions that already have a pickle file are skipped.
# regenerate = True
regenerate = False
process_and_save_all_sessions(sorted_sessions, pickle_dir, regenerate)

Session RZ034_2024-07-13_str already exists at - skipping
Session RZ034_2024-07-14_str already exists at - skipping
Session RZ036_2024-07-12_v1 already exists at - skipping
Session RZ036_2024-07-12_str already exists at - skipping
Session RZ036_2024-07-13_str already exists at - skipping
Session RZ036_2024-07-14_str already exists at - skipping
Session RZ037_2024-07-16_str already exists at - skipping
Session RZ037_2024-07-17_str already exists at - skipping
Session RZ037_2024-07-18_v1 already exists at - skipping
Session RZ037_2024-07-18_str already exists at - skipping
Session RZ038_2024-07-17_v1 already exists at - skipping
Session RZ038_2024-07-18_str already exists at - skipping
Session RZ038_2024-07-19_str already exists at - skipping
Session RZ039_2024-07-17_str already exists at - skipping
Session RZ047_2024-11-19_str already exists at - skipping
Session RZ047_2024-11-20_v1 already exists at - skipping
Session RZ047_2024-11-21_v1 already exists at - skipping
Session RZ047_2024-

# Finalize and save session log

In [62]:
# Drop the bulky per-session payloads before writing the log to CSV.
# errors='ignore' keeps this cell re-runnable: it reassigns `sorted_sessions`,
# so on a second run the two columns are already gone and a plain drop would
# raise KeyError.
sorted_sessions = sorted_sessions.drop(columns=['events', 'units'], errors='ignore')
sorted_sessions.to_csv(os.path.join(data_dir, 'sessions_official_raw.csv'))

In [63]:
# Display the final session log (without the events/units columns).
sorted_sessions

Unnamed: 0,date,mouse,insertion_number,region,potential problems,sorting notes,First_X_Column,datetime,paramset_idx,num_units,unit_ids,id
0,2024-07-13,RZ034,1,str,,should be all good now,Done,2024-07-13 12:58:26,101,47,"[7, 8, 13, 14, 16, 19, 20, 26, 34, 43, 44, 50,...",RZ034_2024-07-13_str
1,2024-07-14,RZ034,1,str,,,Done,2024-07-14 12:52:46,101,31,"[2, 4, 10, 17, 23, 25, 35, 40, 44, 48, 67, 71,...",RZ034_2024-07-14_str
2,2024-07-12,RZ036,0,v1,,,Done,2024-07-12 12:50:31,101,15,"[2, 6, 7, 9, 11, 13, 17, 19, 20, 22, 24, 28, 3...",RZ036_2024-07-12_v1
3,2024-07-12,RZ036,1,str,,,Done,2024-07-12 12:50:31,101,45,"[2, 3, 5, 9, 10, 12, 16, 18, 19, 20, 26, 27, 2...",RZ036_2024-07-12_str
4,2024-07-13,RZ036,1,str,,,Done,2024-07-13 14:29:03,101,30,"[3, 4, 17, 27, 36, 38, 43, 48, 51, 52, 72, 73,...",RZ036_2024-07-13_str
...,...,...,...,...,...,...,...,...,...,...,...,...
90,2025-02-21,RZ065,0,v1,,"no bug, new phy",Done,2025-02-21 11:15:15,101,27,"[6, 8, 12, 33, 37, 39, 54, 56, 67, 68, 80, 81,...",RZ065_2025-02-21_v1
91,2025-02-21,RZ065,1,str,,,Done,2025-02-21 11:15:15,101,19,"[8, 11, 12, 13, 14, 16, 17, 20, 61, 111, 122, ...",RZ065_2025-02-21_str
92,2025-02-22,RZ065,1,str,,,Done,2025-02-22 13:03:06,101,162,"[0, 2, 6, 11, 15, 17, 18, 21, 22, 23, 38, 39, ...",RZ065_2025-02-22_str
93,2025-02-13,RZ070,1,str,,,Done,2025-02-13 11:40:15,101,2,"[0, 14]",RZ070_2025-02-13_str
