In [1]:
import os
from pathlib import Path

import numpy as np
import pickle
import pandas as pd

import constants as k

In [2]:
# Name (without extension) of the raw sorted-session pickle inside data_dir.
raw_pickle_name = "session_keys_0120"
# Root data folder. The default is the original local path; set the
# NEURAL_DATA_DIR environment variable to run the notebook on another machine
# without editing this cell.
data_dir = os.environ.get('NEURAL_DATA_DIR', '/Users/rebekahzhang/data/neural_data')
# Folder that receives one processed pickle per session.
pickle_dir = Path(os.path.join(data_dir, 'session_pickles'))
figure_folder = os.path.join(data_dir, 'figures')

# Unify sorted data and recording log

**load sorted session data pickle**
- pickle is a list of dictionaries
- each dict represents one session, with the subject, session datetime, an events df, and a list of units
- each unit is an array of spike times

In [3]:
# Read the list of per-session dicts from <data_dir>/<raw_pickle_name>.pkl.
# NOTE: unpickling can execute arbitrary code - only load pickles you trust.
with open(Path(data_dir) / f'{raw_pickle_name}.pkl', 'rb') as f:
    sorted_sessions_list = pickle.load(f)

**make a log converting the data list into a df**
- temporarily keeping events and units in the df. 
- will drop later

In [4]:
def generate_sessions_sorted(sorted_sessions_list):
    """Turn the list of sorted-session dicts into a one-row-per-session DataFrame.

    Each dict is read for the keys 'subject', 'session_datetime',
    'insertion_number', 'paramset_idx', 'events' and 'spikes'.

    Returns a DataFrame with columns: mouse, datetime, date (the datetime
    formatted as YYYY-MM-DD), insertion_number, paramset_idx, num_units
    (length of 'spikes'), events and units (the 'spikes' entry itself).
    """
    rows = [
        {
            'mouse': record['subject'],
            'datetime': record['session_datetime'],
            'date': record['session_datetime'].strftime("%Y-%m-%d"),
            'insertion_number': record['insertion_number'],
            'paramset_idx': record['paramset_idx'],
            'num_units': len(record['spikes']),
            'events': record['events'],
            'units': record['spikes'],
        }
        for record in sorted_sessions_list
    ]
    return pd.DataFrame(rows)

In [5]:
# One row per sorted session; the 'events' and 'units' columns still hold the raw per-session data.
sorted_sessions = generate_sessions_sorted(sorted_sessions_list)

In [6]:
# Keep only the sessions in which at least one unit was sorted.
has_units = sorted_sessions['num_units'] > 0
sorted_sessions_with_units = sorted_sessions.loc[has_units]

print(f"{len(sorted_sessions_with_units)} sessions with units")
print(f"total cells: {sum(sorted_sessions_with_units.num_units)}")

54 sessions with units
total cells: 1385


**load recording log**
- download from google sheet and save it to data_dir

In [7]:
# Recording-log columns that are discarded right after loading.
dropped_log_columns = [
    'NIDAQ', 'simultaneous', 'probe', 'probe treatment', 'insertion speed',
    'resting time', 'surface', 'extraction speed', 'notes', 'rewards',
    'num trials', 'tw', 'potential problems', 'sorting notes',
]
recording_log = (
    pd.read_csv(os.path.join(data_dir, 'recording_log.csv'))
    .drop(columns=dropped_log_columns)
)

**merge the two df to add region info**

In [8]:
# Inner join: only sessions present in BOTH the recording log and the sorted
# data survive, matched on mouse, date and insertion number.
merge_keys = ['mouse', 'date', 'insertion_number']
sorted_session_all = recording_log.merge(sorted_sessions, how='inner', on=merge_keys)

# Session identifier of the form <mouse>_<date>_<region>.
# NOTE(review): this id is later used as the pickle file name, so two insertions
# sharing mouse, date and region would collide - uniqueness is not checked here.
id_parts = sorted_session_all[['mouse', 'date', 'region']]
sorted_session_all['id'] = id_parts.agg('_'.join, axis=1)

# Process sorted session data

### events processing

In [9]:
def process_raw_events(events_df):
    """Add trial-relative start and end times to every event.

    The start time of each trial is taken from the row whose ``event_type`` is
    'trial' (one such row per ``trial_id`` is assumed) and subtracted from the
    absolute ``event_start_time`` / ``event_end_time`` of all events that share
    that ``trial_id``.

    Parameters
    ----------
    events_df : pd.DataFrame
        Needs the columns 'event_type', 'trial_id', 'event_start_time' and
        'event_end_time'. It is NOT modified; the function works on a copy.

    Returns
    -------
    pd.DataFrame
        A copy of ``events_df`` with two extra columns,
        'event_start_trial_time' and 'event_end_trial_time'.
    """
    # Work on a copy so the caller's frame (e.g. the 'events' cell of the
    # session log) does not silently gain columns.
    events_df = events_df.copy()

    # Get trial start times (assuming one 'trial' event per trial_id)
    is_trial_row = events_df['event_type'] == 'trial'
    trial_starts = events_df.loc[is_trial_row].set_index('trial_id')['event_start_time']

    # Map trial start times to all events
    events_df['trial_start_time'] = events_df['trial_id'].map(trial_starts)

    # Calculate relative times
    events_df['event_start_trial_time'] = events_df['event_start_time'] - events_df['trial_start_time']
    events_df['event_end_trial_time'] = events_df['event_end_time'] - events_df['trial_start_time']

    # The absolute trial start was only a helper for the subtraction above.
    events_df = events_df.drop(columns=['trial_start_time'])

    return events_df

### trials processing

In [10]:
def get_trial_data(trial):
    """Summarise the events of a single trial as a flat dict.

    ``trial`` is a DataFrame holding all events of one trial, with the columns
    'event_type', 'event_start_trial_time' and 'event_end_trial_time'.

    The returned dict has the keys: missed, rewarded, cue_on_time,
    cue_off_time, consumption_time, background_length and wait_length.
    A trial without a 'reward' event is reported as missed, with
    wait_length 60 and consumption_time NaN. Consumption events are only
    inspected when a 'reward' event exists.
    """
    event_types = trial['event_type']

    def first_start_of(kind):
        # Trial-relative start time of the first event of the given type.
        return trial.loc[trial['event_type'] == kind, 'event_start_trial_time'].iloc[0]

    # Cue timing comes from the first 'visual' event.
    first_visual = trial[event_types == 'visual']
    cue_on_time = first_visual['event_start_trial_time'].iloc[0]
    cue_off_time = first_visual['event_end_trial_time'].iloc[0]

    # Start from the "missed trial" picture and refine it below.
    summary = {
        'missed': True,
        'rewarded': False,
        'cue_on_time': cue_on_time,
        'cue_off_time': cue_off_time,
        'consumption_time': np.nan,
        'background_length': cue_off_time - cue_on_time,
        'wait_length': 60,
    }

    if 'reward' not in event_types.values:
        return summary

    summary['missed'] = False
    summary['wait_length'] = first_start_of('reward') - cue_off_time

    # Checked in this order, so 'cons_no_reward' wins if both types are present.
    for cons_type, was_rewarded in (('cons_reward', True), ('cons_no_reward', False)):
        if cons_type in event_types.values:
            summary['consumption_time'] = first_start_of(cons_type)
            summary['rewarded'] = was_rewarded

    return summary

def generate_trials(events):
    """Build a one-row-per-trial table from the processed events.

    The 'trial' rows of ``events`` form the base table. The per-trial summary
    from ``get_trial_data`` is merged onto it by 'trial_id', and
    'consumption_length' is added as trial end minus consumption time.
    'event_end_trial_time' is renamed to 'trial_length', and
    'event_start_trial_time' and 'event_type' are dropped.
    """
    trial_rows = events.loc[events['event_type'] == 'trial'].copy()

    per_trial_summary = pd.DataFrame([
        {'trial_id': trial_id} | get_trial_data(trial_events)
        for trial_id, trial_events in events.groupby("trial_id")
    ])

    merged = pd.merge(trial_rows, per_trial_summary, on='trial_id')
    merged['consumption_length'] = merged['event_end_trial_time'] - merged['consumption_time']

    return (
        merged
        .rename(columns={"event_end_trial_time": "trial_length"})
        .drop(columns=['event_start_trial_time', 'event_type'])
    )

### spikes processing

In [11]:
def add_trial_time_to_spikes(spikes, trials):
    for _, trial_basics in trials.iterrows():
        trial_start_time = trial_basics['event_start_time']
        trial_end_time = trial_basics['event_end_time']
        spikes.loc[spikes['spike_time'].between(trial_start_time, trial_end_time), 
                'trial_id'] = trial_basics['trial_id']
        spikes.loc[spikes['spike_time'].between(trial_start_time, trial_end_time), 
                'trial_time'] = spikes['spike_time'] - trial_start_time
        time_columns = ["cue_on_time", "cue_off_time", "consumption_time"] 
    trials_to_merge = trials[['trial_id']+ time_columns].copy()
    spikes = trials_to_merge.merge(spikes, on='trial_id', how='inner')
    return spikes

def align_spike_time_to_anchors(spikes):
    """Add spike times expressed relative to each trial anchor.

    For cue onset, cue offset and consumption, a column named by the
    matching constant in ``k`` is written as 'trial_time' minus that anchor
    time. The frame is modified in place and also returned.
    """
    anchor_for_column = (
        (k.TO_CUE_ON, "cue_on_time"),
        (k.TO_CUE_OFF, "cue_off_time"),
        (k.TO_CONSUMPTION, "consumption_time"),
    )
    for aligned_column, anchor_column in anchor_for_column:
        spikes[aligned_column] = spikes['trial_time'] - spikes[anchor_column]
    return spikes

def add_period_to_spikes(row):
    """Label one spike row with the trial period it falls in.

    Returns ``k.BACKGROUND`` for cue_on <= t < cue_off, ``k.WAIT`` for
    cue_off <= t < consumption and ``k.CONSUMPTION`` for t >= consumption,
    where t is the row's 'trial_time'. Returns None when no interval matches,
    e.g. for a spike before cue onset, or after cue offset when the
    consumption time is NaN.
    """
    spike_t = row['trial_time']

    if row['cue_on_time'] <= spike_t < row['cue_off_time']:
        return k.BACKGROUND
    if row['cue_off_time'] <= spike_t < row['consumption_time']:
        return k.WAIT
    if row['consumption_time'] <= spike_t:
        return k.CONSUMPTION
    return None

def process_spikes(spikes, trials):
    """Run one unit's spike times through the full alignment pipeline.

    Parameters
    ----------
    spikes : array-like
        Spike times of a single unit; wrapped into a one-column DataFrame
        named 'spike_time'.
    trials : pd.DataFrame
        Trial table passed on to ``add_trial_time_to_spikes``.

    Returns
    -------
    pd.DataFrame
        The spikes that fall inside a trial, with trial-relative and
        anchor-relative times plus a 'period' label per spike.
    """
    spikes = pd.DataFrame(spikes, columns=['spike_time'])
    spikes = add_trial_time_to_spikes(spikes, trials)
    spikes = align_spike_time_to_anchors(spikes)
    if spikes.empty:
        # DataFrame.apply(axis=1) on a frame without rows hands back an empty
        # DataFrame instead of a Series, and assigning that to a single column
        # raises. A unit with no spikes inside any trial therefore gets an
        # empty 'period' column directly.
        spikes['period'] = pd.Series(index=spikes.index, dtype=object)
    else:
        spikes['period'] = spikes.apply(add_period_to_spikes, axis=1)
    return spikes

### align events to anchor times

In [12]:
def align_events(events, trials):
    """Attach per-trial anchor times to the events and add anchor-relative times.

    The cue-on, cue-off and consumption times of each trial are merged onto
    ``events`` by 'trial_id' (inner join). For each anchor, a column named by
    the matching constant in ``k`` is added as 'event_start_trial_time' minus
    that anchor time. A new DataFrame is returned.
    """
    anchor_columns = ["cue_on_time", "cue_off_time", "consumption_time"]
    anchors_per_trial = trials[['trial_id'] + anchor_columns].copy()

    aligned = anchors_per_trial.merge(events, on='trial_id', how='inner')
    targets = (k.TO_CUE_ON, k.TO_CUE_OFF, k.TO_CONSUMPTION)
    for target_column, anchor_column in zip(targets, anchor_columns):
        aligned[target_column] = aligned['event_start_trial_time'] - aligned[anchor_column]
    return aligned

### process a session through all steps

In [13]:
def process_session(session):
    """Process one row of the session log.

    ``session`` must provide 'events' (raw events DataFrame) and 'units'
    (iterable of per-unit spike-time arrays).

    Returns a tuple ``(events_aligned, trials, units_aligned)``, where
    ``units_aligned`` is a list holding one processed spike DataFrame per unit.
    """
    processed_events = process_raw_events(session['events'])
    trial_table = generate_trials(processed_events)
    aligned_events = align_events(processed_events, trial_table)

    aligned_units = []
    for unit_spike_times in session['units']:
        aligned_units.append(process_spikes(unit_spike_times, trial_table))

    return aligned_events, trial_table, aligned_units

In [14]:
# for debugging
# test_log = sorted_session_all.head(2)
# test_session = sorted_session_all.iloc[6]
# test_events, test_trials, test_units = process_session(test_session)

### loop through all sessions

In [15]:
def process_and_save_all_sessions(sorted_session_all, pickle_dir, regenerate=False):
    """Process every session in the log and pickle the result to ``pickle_dir``.

    Parameters
    ----------
    sorted_session_all : pd.DataFrame
        Session log. Each row needs 'id', 'mouse', 'date', 'region' plus the
        'events' and 'units' entries consumed by ``process_session``.
    pickle_dir : str or Path
        Output folder; created (including parents) if it does not exist.
    regenerate : bool, default False
        When False, a session whose ``<id>.pkl`` already exists is skipped.
        When True, every session is reprocessed and its file overwritten.

    Each output file holds a dict with the keys id, mouse, date, region,
    events, trials and units.
    """
    Path(pickle_dir).mkdir(parents=True, exist_ok=True)
    for _, session in sorted_session_all.iterrows():
        output_path = os.path.join(pickle_dir, f"{session['id']}.pkl")

        # Skip if file exists and we're not regenerating
        if os.path.exists(output_path) and not regenerate:
            print(f"Session {session['id']} already exists at {output_path} - skipping")
            continue

        events, trials, units = process_session(session)

        session_data = {
            'id': session['id'],
            'mouse': session['mouse'],
            'date': session['date'],
            'region': session['region'],
            'events': events,
            'trials': trials,
            'units': units,
        }

        # Write to a temporary file and rename it into place, so an
        # interrupted run cannot leave a truncated pickle that a later
        # regenerate=False run would skip as "already exists".
        tmp_path = f"{output_path}.tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(session_data, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, output_path)

        print(f"Saved session {session['id']} to {output_path}")

In [16]:
# With regenerate = False, sessions whose pickle already exists in pickle_dir are skipped;
# set it to True to reprocess and overwrite them.
regenerate = False
process_and_save_all_sessions(sorted_session_all, pickle_dir, regenerate)

Session RZ034_2024-07-13_v1 already exists at - skipping
Session RZ034_2024-07-13_str already exists at - skipping
Session RZ034_2024-07-14_str already exists at - skipping
Session RZ036_2024-07-13_v1 already exists at - skipping
Session RZ036_2024-07-13_str already exists at - skipping
Session RZ036_2024-07-14_str already exists at - skipping
Session RZ037_2024-07-16_v1 already exists at - skipping
Session RZ037_2024-07-16_str already exists at - skipping
Session RZ037_2024-07-17_v1 already exists at - skipping
Session RZ037_2024-07-17_str already exists at - skipping
Session RZ038_2024-07-16_v1 already exists at - skipping
Session RZ038_2024-07-16_str already exists at - skipping
Session RZ038_2024-07-18_str already exists at - skipping
Session RZ038_2024-07-19_str already exists at - skipping
Session RZ039_2024-07-17_str already exists at - skipping
Session RZ053_2024-10-22_v1 already exists at - skipping
Session RZ036_2024-07-12_v1 already exists at - skipping
Session RZ036_2024-07

# Finalize and save session log

In [17]:
# Remove the per-session payload columns before writing the log as CSV.
# errors='ignore' keeps this cell re-runnable: on a second run the columns are
# already gone and a plain drop would raise KeyError.
sorted_session_all = sorted_session_all.drop(columns=['events', 'units'], errors='ignore')
sorted_session_all.to_csv(os.path.join(data_dir, 'sorted_session_all.csv'))

# Same log restricted to sessions with at least one sorted unit.
sorted_sessions_w_units = sorted_session_all.loc[sorted_session_all['num_units'] > 0]
sorted_sessions_w_units.to_csv(os.path.join(data_dir, 'sorted_session_with_units.csv'))