In [1]:
import os

import numpy as np
import pandas as pd
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from tqdm.notebook import tqdm

In [2]:
data_dir = '/Users/rebekahzhang/data/neural_data'
figure_folder = os.path.join(data_dir, 'figures')

### get data pickle and filter out sessions with no units

In [3]:
# load the session_keys pickle. its a list of dictionaries, each dict representing one session, 
# with a event df, and a list of units, each unit is an array of spike times
# with open(os.path.join(data_dir, 'session_keys_0120.pkl'), 'rb') as f:
#     session_data_list = pickle.load(f)
with open(os.path.join(data_dir, 'session_keys.pkl'), 'rb') as f:
    session_data_list = pickle.load(f)

In [4]:
len(session_data_list)

21

In [5]:
# filter out the sessions with no units
def get_seeions_with_units(session_data_list):
    sessions_with_units = []
    for session in session_data_list:
        num_neurons = len(session['spikes'])
        if num_neurons > 0:
            sessions_with_units.append(session)
    return sessions_with_units

session_with_units = get_seeions_with_units(session_data_list)
len(session_with_units)

12

In [6]:
total_cell_count = 0
for session in session_with_units:
    total_cell_count+=len(session['spikes'])
    print(session['subject'], session['session_datetime'], len(session['spikes']))
print("total cells:", total_cell_count)

RZ034 2024-07-14 12:52:46 31
RZ036 2024-07-12 12:50:31 15
RZ036 2024-07-12 12:50:31 45
RZ036 2024-07-13 14:29:03 30
RZ037 2024-07-16 11:33:07 61
RZ037 2024-07-17 17:09:40 60
RZ037 2024-07-18 12:39:17 13
RZ037 2024-07-18 12:39:17 18
RZ038 2024-07-17 12:01:45 2
RZ038 2024-07-18 15:07:36 32
RZ038 2024-07-19 12:38:14 19
RZ039 2024-07-17 14:45:27 16
total cells: 342


In [7]:
def generate_session_identity(session_dict):
    mouse = session_dict['subject']
    date = session_dict['session_datetime'].strftime("%Y-%m-%d")
    insertion = session_dict['insertion_number']+1
    session_identity = f'{mouse}_{date}_probe-{insertion}'
    return session_identity

### add trial time to events
each row is an event

In [8]:
def add_trial_time(trial):
    trial_start_time = trial.loc[trial['event_type'] == 'trial', 'event_start_time'].iloc[0]
    trial['event_start_trial_time'] = trial['event_start_time'] - trial_start_time
    trial['event_end_trial_time'] = trial['event_end_time'] - trial_start_time
    return trial

### generate trials and add trial data based on events
each row is a trial

In [9]:
def get_trial_data(trial):
    missed = False
    rewarded = False
    wait_length = np.nan
    
    cue_on_time = trial.loc[trial['event_type'] == 'visual', 'event_start_trial_time'].iloc[0]
    cue_off_time = trial.loc[trial['event_type'] == 'visual', 'event_end_trial_time'].iloc[0]
    cons_time = np.nan

    if 'reward' in trial.event_type.unique():
        missed = False
        wait_length = trial.loc[trial['event_type'] == 'reward', 'event_start_trial_time'].iloc[0] - cue_off_time

        if 'cons_reward' in trial.event_type.unique():
            rewarded = True
            cons_time = trial.loc[trial['event_type'] == 'cons_reward', 'event_start_trial_time'].iloc[0]
        elif 'cons_no_reward' in trial.event_type.unique():
            rewarded = False
            cons_time = trial.loc[trial['event_type'] == 'cons_no_reward', 'event_start_trial_time'].iloc[0]

    elif 'reward' not in trial.event_type.unique():
        missed = True
        wait_length = 60

    trial_data = {
        'missed': missed,
        'rewarded': rewarded,
        'cue_on_time': cue_on_time,
        'cue_off_time': cue_off_time,
        'cons_time': cons_time,
        'bg_length': cue_off_time-cue_on_time,
        'wait_length': wait_length,
    }

    return trial_data

In [10]:
def generate_trials(events):
    trials = events.loc[events['event_type'] == 'trial'].copy().reset_index()
    trial_data_list = []
    for t, trial in events.groupby("trial_id"):
        trial_data = {'trial_id': t} | get_trial_data(trial)
        trial_data_list.append(trial_data)
    trial_data_df = pd.DataFrame(trial_data_list)
    trials = pd.merge(trials, trial_data_df, on='trial_id')
    trial['cons_length'] = trials['event_end_time'] - trials['cons_time']
    return trials

### generate spikes df for each unit
get an array of spike times, making it into a df, then adding trial times to it

In [11]:
def generate_spikes(spikes, trials):
    spikes = pd.DataFrame(spikes, columns=['spike_time'])
    for _, trial_basics in trials.iterrows():
        trial_start_time = trial_basics['event_start_time']
        trial_end_time = trial_basics['event_end_time']
        spikes.loc[spikes['spike_time'].between(trial_start_time, trial_end_time), 
                'trial_id'] = trial_basics['trial_id']
        spikes.loc[spikes['spike_time'].between(trial_start_time, trial_end_time), 
                'trial_time'] = spikes['spike_time'] - trial_start_time
    return spikes

In [12]:
def generate_sorted_trial_id_dict(trials):
    trial_id_by_time_waited = trials.sort_values(['wait_length']).trial_id.tolist()
    trial_id_by_reward_and_time_watied = trials.sort_values(by=['missed','rewarded','wait_length']).trial_id.tolist()
    trial_id_by_bg_length = trials.sort_values(['bg_length']).trial_id.tolist()
    trial_id_by_trial_num = trials.trial_id.tolist()
    sorted_trial_id_dict = {
        "time_waited": trial_id_by_time_waited, 
        "reward_time_waited": trial_id_by_reward_and_time_watied, 
        "bg_length": trial_id_by_bg_length,
        "trial_num": trial_id_by_trial_num
    }
    return sorted_trial_id_dict

## Plotting raster only

In [13]:
def create_figure():
    fig = Figure(figsize=(12, 8))
    ax = fig.add_subplot(111)
    return fig, ax

def align_times(trial_data, spikes, events, anchor):
    anchor_time = trial_data[anchor].iloc[0]
    aligned_spike_times = spikes['trial_time'] - anchor_time
    
    aligned_events = {}
    for event_type in ['visual', 'wait', 'cons_reward', 'cons_no_reward']:
        event_times = events.loc[events['event_type'] == event_type, 'event_start_trial_time']
        aligned_events[event_type] = event_times - anchor_time
    
    return aligned_spike_times, aligned_events

def plot_trial(ax, aligned_spike_times, aligned_events, trial_offset):
    ax.eventplot(aligned_spike_times, lineoffsets=trial_offset, color='k', linelengths=1.0, linewidths=0.3)
    
    event_colors = {'visual': 'orange', 'wait': 'g', 'cons_reward': 'b', 'cons_no_reward': 'r'}
    for event_type, times in aligned_events.items():
        ax.eventplot(times, lineoffsets=trial_offset, color=event_colors[event_type], linelengths=1.0, linewidths=1)

def add_legend(ax):
    legend_handles = [
        plt.Line2D([0], [0], color=c, lw=2) 
        for c in ['orange', 'g', 'b', 'r']
    ]
    legend_labels = ['Visual', 'Wait Start', 'Reward', 'No Reward']
    ax.legend(legend_handles, legend_labels, loc='center left', bbox_to_anchor=(1, 0.95))

def save_and_close_figure(fig, title, figure_folder):
    fig.tight_layout()
    fig.savefig(f'{figure_folder}/{title}.png', bbox_inches='tight', dpi=300, format='png')
    fig.clf()
    plt.close(fig)

def plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, sorter, anchor, save_fig=True):
    fig, ax = create_figure()
    
    trial_id_sorted = sorted_trial_id_dict[sorter]
    plotted_trials = 0
    
    for t in trial_id_sorted:
        if t in trials_by_trial.groups and t in spikes_by_trial.groups:
            trial_data = trials_by_trial.get_group(t)
            spikes = spikes_by_trial.get_group(t)
            events = events_by_trial.get_group(t)
            
            aligned_spike_times, aligned_events = align_times(trial_data, spikes, events, anchor)
            plot_trial(ax, aligned_spike_times, aligned_events, plotted_trials)
            
            if plotted_trials == 0:
                add_legend(ax)
            
            plotted_trials += 1
    
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Trials')
    title = f'{unit_identity} sorted by {sorter} aligned to {anchor}'
    ax.set_title(title)
    if save_fig:
        save_and_close_figure(fig, title, figure_folder)

In [14]:
def generate_all_plots(trials_by_trial, sorted_trial_id_dict, events_by_trial, spikes_by_trial, unit_identity):
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'time_waited', 'cue_on_time')
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'bg_length', 'cue_on_time')
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'trial_num', 'cue_on_time')
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'reward_time_waited', 'cue_on_time')
    
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'time_waited', 'cue_off_time')
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'bg_length', 'cue_off_time')
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'trial_num', 'cue_off_time')
    plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'reward_time_waited', 'cue_off_time')

    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'time_waited', 'cons_time')
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'bg_length', 'cons_time')
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'trial_num', 'cons_time')
    # plot_raster_per_neuron(unit_identity, sorted_trial_id_dict, trials_by_trial, events_by_trial, spikes_by_trial, 'reward_time_waited', 'cons_time')


loop through all sessions all units

In [15]:
figure_folder = os.path.join(figure_folder, "raster_to_cue_off")

In [16]:
# claude code implementing tqdm
total_units = sum(len(session['spikes']) for session in session_with_units)
failed_units = []

with tqdm(total=total_units, desc="Processing Units") as pbar:
    for session in session_with_units:
        session_identity = generate_session_identity(session)
        events = session['events']
        events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
        events_by_trial = events.groupby('trial_id')
        
        trials = generate_trials(events)
        sorted_trial_id_dict = generate_sorted_trial_id_dict(trials)
        trials_by_trial = trials.groupby('trial_id')
        
        units = session['spikes']
        for i, unit_spikes in enumerate(units):
            unit_identity = session_identity + '_unit-' + str(i)
            spikes = generate_spikes(unit_spikes, trials)
            spikes_by_trial = spikes.groupby('trial_id')
            try:
                generate_all_plots(trials_by_trial, sorted_trial_id_dict, events_by_trial, spikes_by_trial, unit_identity)
            except:
                print(f"Failed for: {unit_identity}")
                failed_units.append(unit_identity)
            pbar.update(1)

Processing Units:   0%|          | 0/342 [00:00<?, ?it/s]

  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)
  events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)


## Plot raster with histogram

In [None]:
def align_to_anchor_times(df):
    """Add columns for different alignment times."""
    df['to_cue_on'] = df['trial_time'] - df['cue_on_time']
    df['to_cue_off'] = df['trial_time'] - df['cue_off_time']
    df['to_cons'] = df['trial_time'] - df['cons_time']
    return df

def prepare_aligned_data(events, trials, spikes):
    """Prepare aligned events and spikes data."""
    # Prepare events data
    events_needed = ['visual', 'wait', 'cons_reward', 'cons_no_reward']
    events_to_plot = events.loc[events['event_type'].isin(events_needed),
                              ['trial_id', 'event_type', 'event_start_trial_time']]
    
    # Prepare trials data
    trial_columns = ['trial_id', 'missed', 'rewarded', 'wait_length', 
                    'cue_on_time', 'cue_off_time', 'cons_time', 'bg_length']
    trials_to_merge = trials[trial_columns].copy()
    
    # Align events
    events_to_align = trials_to_merge.merge(events_to_plot, on='trial_id', how='inner')
    events_to_align = events_to_align.rename(columns={'event_start_trial_time': 'trial_time'})
    events_aligned = align_to_anchor_times(events_to_align)
    
    # Align spikes
    spikes_to_align = trials_to_merge.merge(spikes, on='trial_id', how='inner')
    spikes_aligned = align_to_anchor_times(spikes_to_align)
    
    return events_aligned, spikes_aligned

In [None]:
def add_raster_legend(ax, event_colors):
    """Add legend to the raster plot."""
    legend_handles = [
        plt.Line2D([0], [0], color=event_colors['visual'], lw=2),
        plt.Line2D([0], [0], color=event_colors['wait'], lw=2),
        plt.Line2D([0], [0], color=event_colors['cons_reward'], lw=2),
        plt.Line2D([0], [0], color=event_colors['cons_no_reward'], lw=2)
    ]
    
    legend_labels = ['Visual', 'Wait Start', 'Reward', 'No Reward']
    
    ax.legend(legend_handles, legend_labels, loc='center left', bbox_to_anchor=(1, 0.9))

def plot_raster(ax, spikes_aligned, events_aligned, sorted_trial_id_dict, 
                anchor, sorter, event_colors, histo_colors):
    """
    Plot raster plot for spike times and events.
    """
    spikes_aligned_by_trial = spikes_aligned.groupby('trial_id')
    events_aligned_by_trial = events_aligned.groupby('trial_id')
    plotted_trials = 0
    
    for t in sorted_trial_id_dict[sorter]:
        if t in spikes_aligned_by_trial.groups:
            # Plot spike times
            trial_spike_times = spikes_aligned_by_trial.get_group(t)[anchor]
            ax.eventplot(trial_spike_times, lineoffsets=plotted_trials, 
                        color='k', linelengths=1.0, linewidths=0.3)
            
            # Plot event times
            trial_event_times = events_aligned_by_trial.get_group(t)[['event_type', anchor]]
            trial_event_times_dict = trial_event_times.set_index('event_type')[anchor].to_dict()
            
            for event_type, event_time in trial_event_times_dict.items():
                ax.eventplot([event_time], lineoffsets=plotted_trials,
                           color=event_colors[event_type], linelengths=1.0, linewidths=1)
            
            if plotted_trials == 0:
                add_raster_legend(ax, event_colors)
            plotted_trials += 1
    
    # Add vertical line
    ax.axvline(0, color=histo_colors[anchor], linestyle='--', alpha=0.5)
    ax.set_ylabel('Trial number')

def plot_kde(ax, spikes_aligned, anchor, histo_colors, event_colors):
    """
    Plot KDE for spike times with common normalization between rewarded and non-rewarded.
    """
    # Shared KDE parameters
    kde_params = {
        'bw_adjust': 0.5,
        'common_norm': True,
        'ax': ax
    }
    
    # Plot KDEs
    sns.kdeplot(spikes_aligned[anchor], color='black', label='All Spikes', 
                common_norm=False, bw_adjust=0.5, ax=ax)
    sns.kdeplot(spikes_aligned.loc[spikes_aligned['rewarded']==True, anchor], 
                color=event_colors['cons_reward'], label='Rewarded', **kde_params)
    sns.kdeplot(spikes_aligned.loc[spikes_aligned['rewarded']==False, anchor], 
                color=event_colors['cons_no_reward'], label='Non-Rewarded', **kde_params)
    
    # Add vertical line
    ax.axvline(0, color=histo_colors[anchor], linestyle='--', alpha=0.5)
    
    # Add legend and labels
    ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Density')

def save_and_close_figure(fig, title, figure_folder):
    fig.tight_layout()
    fig.savefig(f'{figure_folder}/{title}.png', bbox_inches='tight', dpi=300, format='png')
    fig.clf()
    plt.close(fig)

In [None]:
def plot_raster_and_kde(events, trials, spikes, sorted_trial_id_dict, 
                        unit_identity, anchor, sorter, figure_folder):
    events_aligned, spikes_aligned = prepare_aligned_data(events, trials, spikes)

    # Color schemes
    histo_colors = {
    'to_cue_on': 'lightcoral',
    'to_cue_off': 'g',
    'to_cons': 'mediumorchid'
    }
    event_colors = {
    'visual': 'lightcoral',
    'wait': 'g',
    'cons_reward': 'tab:blue',
    'cons_no_reward': 'tab:orange'
    }

    # Create figure
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), height_ratios=[3, 1], sharex=True)
    plt.subplots_adjust(hspace=0.1)
    plot_raster(ax1, spikes_aligned, events_aligned, sorted_trial_id_dict, anchor, sorter, event_colors, histo_colors)
    plot_kde(ax2, spikes_aligned, anchor, histo_colors, event_colors)

        # Set title
    title = f'{unit_identity} sorted by {sorter} aligned {anchor}'
    ax1.set_title(title)

    save_and_close_figure(fig, title, figure_folder)

In [None]:
def generate_all_plots(events, trials, spikes, sorted_trial_id_dict, unit_identity, figure_folder):
    plot_raster_and_kde(events, trials, spikes, sorted_trial_id_dict, unit_identity, 
                   'to_cue_on', 'reward_time_waited', figure_folder)
    plot_raster_and_kde(events, trials, spikes, sorted_trial_id_dict, unit_identity, 
                   'to_cue_off', 'reward_time_waited', figure_folder)
    plot_raster_and_kde(events, trials, spikes, sorted_trial_id_dict, unit_identity, 
                   'to_cons', 'reward_time_waited', figure_folder)

single session test

In [None]:
session = session_with_units[0]
session_identity = generate_session_identity(session)
events = session['events']
events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)

trials = generate_trials(events)
sorted_trial_id_dict = generate_sorted_trial_id_dict(trials)

units = session['spikes']
i=4
unit_spikes = units[i]
unit_identity = session_identity + '_unit-' + str(i)
spikes = generate_spikes(unit_spikes, trials)
events_aligned, spikes_aligned = prepare_aligned_data(events, trials, spikes)

In [None]:
# histo_colors = {
#     'to_cue_on': 'lightcoral',
#     'to_cue_off': 'g',
#     'to_cons': 'mediumorchid'
# }
# event_colors = {
#     'visual': 'lightcoral',
#     'wait': 'g',
#     'cons_reward': 'tab:blue',
#     'cons_no_reward': 'tab:orange'
# }
# anchor = 'to_cue_on'
# sorter = 'reward_time_waited'

# fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), height_ratios=[3, 1], sharex=True)
# plt.subplots_adjust(hspace=0.1)

# plot_raster(ax1, spikes_aligned, events_aligned, sorted_trial_id_dict, anchor, sorter, event_colors, histo_colors)
# plot_kde(ax2, spikes_aligned, anchor, histo_colors, event_colors)

loop through all sessions

In [None]:
total_units = sum(len(session['spikes']) for session in session_with_units)
failed_units = []

with tqdm(total=total_units, desc="Processing Units") as pbar:
    for session in session_with_units:
        session_identity = generate_session_identity(session)
        events = session['events']
        events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)

        trials = generate_trials(events)
        sorted_trial_id_dict = generate_sorted_trial_id_dict(trials)
        
        units = session['spikes']
        for i, unit_spikes in enumerate(units):
            unit_identity = session_identity + '_unit-' + str(i)
            spikes = generate_spikes(unit_spikes, trials)
            try:
                generate_all_plots(events, trials, spikes, sorted_trial_id_dict, unit_identity, figure_folder)
            except:
                print(f"Failed for: {unit_identity}")
                failed_units.append(unit_identity)
            pbar.update(1)

# single session test

In [None]:
import neo
import quantities as pq
from elephant.statistics import instantaneous_rate
from elephant.kernels import GaussianKernel

from scipy.ndimage import gaussian_filter1d

In [None]:
def align_to_anchor_times(df):
    """Add columns for different alignment times."""
    df['to_cue_on'] = df['trial_time'] - df['cue_on_time']
    df['to_cue_off'] = df['trial_time'] - df['cue_off_time']
    df['to_cons'] = df['trial_time'] - df['cons_time']
    return df

def assign_period(row):
    if row['cue_on_time'] <= row['trial_time'] < row['cue_off_time']:
        return 'background'
    elif row['cue_off_time'] <= row['trial_time'] < row['cons_time']:
        return 'wait'
    elif row['cons_time'] <= row['trial_time']:
        return 'cons'

def prepare_aligned_data(events, trials, spikes):
    """Prepare aligned events and spikes data."""
    # Prepare events data
    events_needed = ['visual', 'wait', 'cons_reward', 'cons_no_reward']
    events_to_plot = events.loc[events['event_type'].isin(events_needed),
                              ['trial_id', 'event_type', 'event_start_trial_time']]
    
    # Prepare trials data
    trial_columns = ['trial_id', 'trial_length', 'missed', 'rewarded',
                     'bg_length', 'wait_length', 
                    'cue_on_time', 'cue_off_time', 'cons_time']
    trials_to_merge = trials[trial_columns].copy()
    
    # Align events
    events_to_align = trials_to_merge.merge(events_to_plot, on='trial_id', how='inner')
    events_to_align = events_to_align.rename(columns={'event_start_trial_time': 'trial_time'})
    events_aligned = align_to_anchor_times(events_to_align)
    
    # Align spikes
    spikes_to_align = trials_to_merge.merge(spikes, on='trial_id', how='inner')
    spikes_aligned = align_to_anchor_times(spikes_to_align)
    spikes_aligned['period'] = spikes_aligned.apply(assign_period, axis=1)
    
    return events_aligned, spikes_aligned

In [None]:
session = session_with_units[0]
session_identity = generate_session_identity(session)
events = session['events']
events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)

# trials = generate_trials(events)
# sorted_trial_id_dict = generate_sorted_trial_id_dict(trials)

# units = session['spikes']
# i=2
# unit_spikes = units[i]
# unit_identity = session_identity + '_unit-' + str(i)
# spikes = generate_spikes(unit_spikes, trials)

# events_aligned, spikes_aligned = prepare_aligned_data(events, trials, spikes)

In [None]:
events

## plot sessiong FR

In [None]:
# Your raw spike times (in seconds)
spike_times = spikes.spike_time.tolist()  # e.g., [0.1, 0.3, 0.7, 0.9]

# Define the trial duration (t_stop) for the SpikeTrain
t_stop = max(spike_times) + 1.0  # Add buffer (adjust as needed)

# Create a neo.SpikeTrain
spike_train = neo.SpikeTrain(
    spike_times,
    units=pq.s,        # Units of spike times (seconds)
    t_start=0.0 * pq.s,  # Optional: session start time
    t_stop=t_stop * pq.s  # session end time
)

# Define kernel (e.g., Gaussian with σ=100ms)
kernel = GaussianKernel(sigma=2 * pq.s)

# Calculate instantaneous rate
rate = instantaneous_rate(
    spike_train,
    sampling_period=10 * pq.ms,  # 10ms bins
    kernel=kernel
)

# Get the correct time axis from the rate object itself
time_axis = rate.times.rescale('s').magnitude  # Convert to seconds as numpy array
rate_array = rate.magnitude.squeeze()

# Now plot (dimensions will match)
plt.figure(figsize=(20, 4))
plt.plot(time_axis, rate_array)
plt.xlabel("Time (s)")
plt.xlim(0, time_axis.max())
plt.ylabel("Firing rate (Hz)")
plt.title("Instantaneous Firing Rate")
plt.show()

In [None]:
fr_session_mean=rate_array.mean()
fr_session_max=rate_array.max()

## Plotting histo with corrected bins

histo without error bars

In [None]:
def calculate_instantaneous_rate(spike_train, sampling_period=0.01, sigma=0.3, bounds=None):
    """
    Calculate instantaneous firing rate using Gaussian kernel smoothing.
    
    Parameters:
        spike_train (array): Array of spike times
        sampling_period (float): Time resolution for rate calculation (in seconds)
        sigma (float): Standard deviation of Gaussian kernel (in seconds)
        bounds (tuple): (t_start, t_stop) time window to analyze
        test_plot (bool): Whether to plot a test snippet
        
    Returns:
        rate_array (array): Instantaneous firing rate (Hz)
        time_axis (array): Corresponding time points
    """
    # Set time bounds
    t_start, t_stop = bounds if bounds is not None else (0, spike_train[-1])
    
    # Create time axis (centered on bins)
    time_axis = np.arange(t_start, t_stop, sampling_period) + sampling_period / 2
    
    # Filter spikes within bounds and create SpikeTrain object
    spike_train = spike_train[(t_start < spike_train) & (spike_train < t_stop)]
    spike_train = neo.SpikeTrain(spike_train, t_stop=t_stop, units=pq.s)
    
    # Calculate instantaneous rate
    kernel = GaussianKernel(sigma=sigma * pq.s)
    rate = instantaneous_rate(
        spike_train,
        sampling_period=sampling_period * pq.s,
        kernel=kernel,
        t_start=t_start * pq.s,
        t_stop=t_stop * pq.s,
        border_correction=True
    )
    
    return time_axis, rate.as_array().squeeze()

def count_active_trials(trial_durations, bin_edges):
    """
    Count how many trials are active (ongoing) at each time bin.
    
    Parameters:
        trial_durations (array): Duration of each trial
        bin_edges (array): Edges of time bins
        
    Returns:
        active_counts (array): Number of active trials per bin
    """
    return [np.sum(trial_durations > edge) for edge in bin_edges[:-1]]

In [None]:
anchor_event = 'to_cue_on'
max_trial_length = round(spikes_aligned.trial_length.max(), 1)
time_step = 0.1

# Define analysis window
analysis_bounds = (0, max_trial_length)
time_bins = np.arange(analysis_bounds[0], analysis_bounds[1] + time_step, time_step)

# Get trial durations and active trial counts
trial_durations = spikes_aligned.groupby('trial_id')['trial_length'].first()
active_trials = count_active_trials(trial_durations, time_bins)

# Calculate and normalize firing rate
time_axis, firing_rate = calculate_instantaneous_rate(
    spikes_aligned[anchor_event],
    bounds=analysis_bounds,
    sampling_period=time_step,
    sigma=0.3
)

normalized_rate = firing_rate / active_trials

# Plot results
plt.figure(figsize=(12, 4))
plt.plot(time_axis, normalized_rate)
plt.xlabel('Time from cue onset (s)')
plt.ylabel('Normalized firing rate (Hz)')
plt.title('Firing rate aligned to cue onset')
plt.tight_layout()
plt.show()

In [None]:
# rate not corrected
bin_org = np.full(len(time_bins)-1, spikes_aligned.trial_id.max()+1)

plt.figure(figsize=(12, 4))
plt.plot(time_axis, firing_rate / bin_org)
plt.xlabel('Time from cue onset (s)')
plt.ylabel('Normalized firing rate (Hz)')
plt.title('Firing rate aligned to cue onset')
plt.tight_layout()
plt.show()

histo with error bars

In [None]:
spikes_by_trial = spikes_aligned.groupby('trial_id')
bin_centers = time_bins[:-1] + time_step/2
active_trials = count_active_trials(trial_durations, time_bins)
active_trials = np.array(active_trials)

all_trial_spike_times = []
for t, trial in spikes_by_trial:
    all_trial_spike_times.append(trial[anchor_event].tolist())

# Initialize empty array (trials × bins)
counts_per_trial = np.zeros((len(all_trial_spike_times), len(time_bins)-1))
rates_per_trial = np.zeros_like(counts_per_trial) #[trial, time_bin]

for i, trial_spikes in enumerate(all_trial_spike_times):
    counts_per_trial[i], _ = np.histogram(trial_spikes, bins=time_bins)
    rates_per_trial[i] = counts_per_trial[i] / time_step

In [None]:
# Your existing code up to SEM calculation
rates_all_trials = np.nansum(rates_per_trial, 0)/active_trials 
rates_all_trials_sem = np.std(rates_per_trial, axis=0)/np.sqrt(active_trials)

# Create mask for bins with at least 5 active trials
trial_duration_mask = 5

valid_bins_mask = active_trials >= trial_duration_mask  # Boolean array where True means ≥5 trials

# Apply smoothing only to valid data to avoid edge artifacts
smoothed_rates = np.full_like(rates_all_trials, np.nan)  # Initialize with NaNs
smoothed_sem = np.full_like(rates_all_trials_sem, np.nan)

# Smooth only the valid portions (preserves original data length)
sigma = 0.5
smoothed_rates[valid_bins_mask] = gaussian_filter1d(rates_all_trials[valid_bins_mask], sigma=sigma)
smoothed_sem[valid_bins_mask] = gaussian_filter1d(rates_all_trials_sem[valid_bins_mask], sigma=sigma)

# Get bin centers only for valid bins
valid_bin_centers = bin_centers[valid_bins_mask]
valid_smoothed_rates = smoothed_rates[valid_bins_mask]
valid_smoothed_sem = smoothed_sem[valid_bins_mask]

# Plot only valid bins
plt.figure(figsize=(12, 4))
plt.plot(valid_bin_centers, valid_smoothed_rates, 'b-', lw=2, label='Mean rate (≥5 trials)')
plt.fill_between(
    valid_bin_centers,
    valid_smoothed_rates - valid_smoothed_sem,
    valid_smoothed_rates + valid_smoothed_sem,
    color='blue',
    alpha=0.3,
    label='± SEM'
)
plt.xlabel('Time (s)')
plt.ylabel('Firing rate (Hz)')
plt.legend()
plt.show()

fix it to sweep both forward and backward

In [None]:
CUE_ON = 'cue_on_time'
CUE_OFF = 'cue_off_time'
CONS = 'cons_time'

In [None]:
session = session_with_units[0]
session_identity = generate_session_identity(session)

events = session['events']
events = events.groupby('trial_id', group_keys=False).apply(add_trial_time)

trials = generate_trials(events)

units = session['spikes']
i=2
unit_spikes = units[i]
unit_identity = session_identity + '_unit-' + str(i)
spikes = generate_spikes(unit_spikes, trials)

In [None]:
anchor = CUE_ON
time_step = 0.1

In [None]:
trials['aligned_start_time']=trials['event_start_trial_time']-trials[anchor]
trials['aligned_end_time']=trials['event_end_trial_time']-trials[anchor]
bounds = (
    round(trials.aligned_start_time.min(), 1), 
    round(trials.aligned_end_time.max(), 1)
    )
bin_edges = np.arange(bounds[0] - time_step, bounds[1] + time_step, time_step) 

In [None]:
trials

In [None]:
active_trials_per_bin = np.zeros(len(bin_edges) - 1, dtype=int)

for _, trial in trials.iterrows():
    start = trial['aligned_start_time']
    end = trial['aligned_end_time']
    
    # Find bins that overlap with this trial's duration
    overlaps = (bin_edges[:-1] < end) & (bin_edges[1:] > start)
    active_trials_per_bin[overlaps] += 1

In [None]:
bin_centers = bin_edges[:-1] + time_step/2
bin_centers.shape

In [None]:
trial_columns = ['trial_id', anchor]
trials_to_merge = trials[trial_columns].copy()

# Align spikes
spikes = trials_to_merge.merge(spikes, on='trial_id', how='inner')
spikes['aligned_time'] = spikes['trial_time'] - spikes[anchor]

In [None]:
spikes_by_trial = spikes.groupby('trial_id')

all_trial_spike_times = []
for t, trial in spikes_by_trial:
    all_trial_spike_times.append(trial.aligned_time.tolist())

# Initialize empty array (trials × bins)
counts_per_trial = np.zeros((len(all_trial_spike_times), len(bin_edges)-1))
rates_per_trial = np.zeros_like(counts_per_trial) #[trial, time_bin]

for i, trial_spikes in enumerate(all_trial_spike_times):
    counts_per_trial[i], _ = np.histogram(trial_spikes, bins=bin_edges)
    rates_per_trial[i] = counts_per_trial[i] / time_step

In [None]:
# Your existing code up to SEM calculation
rates_all_trials = np.nansum(rates_per_trial, 0)/active_trials_per_bin 
rates_all_trials_sem = np.std(rates_per_trial, axis=0)/np.sqrt(active_trials_per_bin)

# Create mask for bins with at least 5 active trials
trial_duration_mask = 5

valid_bins_mask = active_trials_per_bin >= trial_duration_mask  # Boolean array where True means ≥5 trials

# Apply smoothing only to valid data to avoid edge artifacts
smoothed_rates = np.full_like(rates_all_trials, np.nan)  # Initialize with NaNs
smoothed_sem = np.full_like(rates_all_trials_sem, np.nan)

# Smooth only the valid portions (preserves original data length)
sigma = 0.5
smoothed_rates[valid_bins_mask] = gaussian_filter1d(rates_all_trials[valid_bins_mask], sigma=sigma)
smoothed_sem[valid_bins_mask] = gaussian_filter1d(rates_all_trials_sem[valid_bins_mask], sigma=sigma)

# Get bin centers only for valid bins
valid_bin_centers = bin_centers[valid_bins_mask]
valid_smoothed_rates = smoothed_rates[valid_bins_mask]
valid_smoothed_sem = smoothed_sem[valid_bins_mask]

# Plot only valid bins
plt.figure(figsize=(12, 4))
plt.plot(valid_bin_centers, valid_smoothed_rates, 'b-', lw=2, label='Mean rate (≥5 trials)')
plt.fill_between(
    valid_bin_centers,
    valid_smoothed_rates - valid_smoothed_sem,
    valid_smoothed_rates + valid_smoothed_sem,
    color='blue',
    alpha=0.3,
    label='± SEM'
)
plt.xlabel('Time (s)')
plt.ylabel('Firing rate (Hz)')
plt.legend()
plt.show()

organization attempt

In [None]:
# ==============================================
# 1. Configuration
# ==============================================
CUE_ON = 'cue_on_time'
CUE_OFF = 'cue_off_time'
CONS = 'cons_time'

# ==============================================
# 2. Data Preparation
# ==============================================
def prepare_data(session, unit_idx=2):
    """Organizes spikes and trials into aligned time bins."""
    # Get session data
    events = session['events'].groupby('trial_id', group_keys=False).apply(add_trial_time)
    trials = generate_trials(events)
    spikes = generate_spikes(session['spikes'][unit_idx], trials)
    
    # Align times to anchor event
    trials['aligned_start_time'] = trials['event_start_trial_time'] - trials[anchor]
    trials['aligned_end_time'] = trials['event_end_trial_time'] - trials[anchor]

    trial_columns = ['trial_id', anchor]
    trials_to_merge = trials[trial_columns].copy()

    # Align spikes
    spikes = trials_to_merge.merge(spikes, on='trial_id', how='inner')
    spikes['aligned_time'] = spikes['trial_time'] - spikes[anchor]
    
    # Calculate bounds rounded to time_step precision
    bounds = (
        np.round(trials.aligned_start_time.min(), decimals=1),
        np.round(trials.aligned_end_time.max(), decimals=1)
    )
    
    # Create bins with buffer for edge cases
    bin_edges = np.arange(
        bounds[0] - time_step,
        bounds[1] + 2*time_step,  # Extra buffer for safety
        time_step
    )
    
    return trials, spikes, bin_edges

# ==============================================
# 3. Active Trial Calculation
# ==============================================
def calculate_active_trials(trials, bin_edges):
    """Counts active trials for each time bin."""
    active_trials = np.zeros(len(bin_edges) - 1, dtype=int)
    
    for _, trial in trials.iterrows():
        overlaps = (bin_edges[:-1] < trial['aligned_end_time']) & \
                   (bin_edges[1:] > trial['aligned_start_time'])
        active_trials[overlaps] += 1
    
    return active_trials

# ==============================================
# 4. Rate Calculation
# ==============================================
def calculate_rates(spikes, bin_edges, active_trials):
    """Calculates firing rates with safe division."""
    # Bin spikes across trials
    counts = np.array([
        np.histogram(trial['aligned_time'], bins=bin_edges)[0]
        for _, trial in spikes.groupby('trial_id')
    ])
    
    # Convert to rates (Hz)
    rates = counts / time_step
    
    # Safe division with error handling
    with np.errstate(divide='ignore', invalid='ignore'):
        mean_rates = np.nansum(rates, axis=0) / active_trials
        sem_rates = np.std(rates, axis=0) / np.sqrt(active_trials)
        
        mean_rates = np.nan_to_num(mean_rates, nan=0.0)
        sem_rates = np.nan_to_num(sem_rates, nan=0.0)
    
    return mean_rates, sem_rates

# ==============================================
# 5. Main Analysis Pipeline
# ==============================================
def analyze_unit(session, unit_idx=2):
    # Prepare data
    trials, spikes, bin_edges = prepare_data(session, unit_idx)
    
    # Calculate active trials
    active_trials = calculate_active_trials(trials, bin_edges)
    
    # Calculate rates
    mean_rates, sem_rates = calculate_rates(spikes, bin_edges, active_trials)
    
    # Apply smoothing to valid bins only
    valid_mask = active_trials >= min_trials
    bin_centers = bin_edges[:-1] + time_step/2
    
    smoothed_rates = np.full_like(mean_rates, np.nan)
    smoothed_sem = np.full_like(sem_rates, np.nan)
    
    smoothed_rates[valid_mask] = gaussian_filter1d(mean_rates[valid_mask], sigma=sigma)
    smoothed_sem[valid_mask] = gaussian_filter1d(sem_rates[valid_mask], sigma=sigma)
    
    # Plot results
    plt.figure(figsize=(12, 4))
    plt.plot(bin_centers[valid_mask], smoothed_rates[valid_mask], 'b-', lw=2, 
             label=f'Mean rate (≥{min_trials} trials)')
    plt.fill_between(
        bin_centers[valid_mask],
        smoothed_rates[valid_mask] - smoothed_sem[valid_mask],
        smoothed_rates[valid_mask] + smoothed_sem[valid_mask],
        color='blue', alpha=0.3, label='± SEM'
    )
    plt.axvline(0, color='k', linestyle='--', label=anchor)
    plt.xlabel('Time relative to event (s)')
    plt.ylabel('Firing rate (Hz)')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()

# # ==============================================
# # Execute Analysis
# # ==============================================
# analyze_unit(session_with_units[0], unit_idx=2)

In [None]:
anchor = CUE_ON
time_step = 0.1
trials, spikes, bin_edges = prepare_data(session, unit_idx=2)

In [None]:
anchor = CUE_ON
time_step = 0.1
min_trials = 5
sigma = 3
analyze_unit(session_with_units[0], unit_idx=2)

In [None]:
anchor = CUE_OFF
time_step = 0.1
min_trials = 5
sigma = 3
analyze_unit(session_with_units[0], unit_idx=2)

In [None]:
anchor = CONS
time_step = 0.1
min_trials = 5
sigma = 3
analyze_unit(session_with_units[0], unit_idx=2)