In [3]:
import json
import os
import pickle
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
from tqdm.notebook import tqdm

import constants as k
import utils

In [2]:
# Root of the local neural-data tree. Set the NEURAL_DATA_DIR environment variable to
# point elsewhere instead of editing this absolute path.
data_dir = os.environ.get('NEURAL_DATA_DIR', '/Users/rebekahzhang/data/neural_data')
pickle_dir = Path(data_dir) / 'session_pickles'
figure_folder = os.path.join(data_dir, 'figures')
# Figures and failed_units.json are written here later; savefig does not create folders.
os.makedirs(figure_folder, exist_ok=True)

In [None]:
# NOTE(review): the file name has no extension — confirm it is 'session' and not
# 'session.csv'. session_log is not used elsewhere in the visible notebook.
session_log = pd.read_csv(Path(data_dir) / 'session')

In [None]:
# Pick one session at random for the exploratory plots below. Sorting makes the pick
# independent of filesystem listing order, and the seed makes it reproducible under
# Restart & Run All (set SESSION_SEED = None for a fresh random pick on every run).
SESSION_SEED = 0
pickle_files = sorted(f.name for f in pickle_dir.glob("*.pkl"))
if not pickle_files:
    raise FileNotFoundError(f"No .pkl session files found in {pickle_dir}")
filename = random.Random(SESSION_SEED).choice(pickle_files)

In [None]:
# Load the chosen session. `idx` is used later as the session label in figure names.
# NOTE(review): `units` is indexed by position and each element is grouped by 'trial_id'
# below, so presumably a list of per-unit spike DataFrames — confirm against utils.
events, trials, units, idx = utils.get_session_data(filename, pickle_dir)
spikes = units[0]  # first unit only, for the exploratory plots below

### Raster plot: spikes and task events per trial

In [None]:
def prepare_data_for_raster(events, trials, spikes):
    """
    Group a session's events and spikes by trial and fix the raster row order.

    Rows are ordered by ``missed``, then ``rewarded``, then ``wait_length``
    (ascending), so trials with the same outcome end up next to each other.

    Returns
    -------
    (events grouped by ``trial_id``, spikes grouped by ``trial_id``, list of trial ids in row order)
    """
    trial_order = (
        trials
        .sort_values(by=['missed', 'rewarded', 'wait_length'])
        ['trial_id']
        .tolist()
    )
    return events.groupby('trial_id'), spikes.groupby('trial_id'), trial_order

def plot_raster(ax, sorted_trial_id, events_raster, spikes_raster, anchor, show_legend=True):
    """
    Draw a trial-by-trial raster of task events and spikes on ``ax``.

    Row ``i`` (y = i) shows trial ``sorted_trial_id[i]``. Each task event type is a
    coloured tick and every spike is a thin black tick, with x taken from the
    ``anchor`` column of the event/spike tables.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    sorted_trial_id : trial ids in the order the rows should be drawn.
    events_raster, spikes_raster : objects grouped by ``trial_id`` (as returned by
        ``prepare_data_for_raster``). A trial absent from a grouping gets no ticks
        of that kind.
    anchor : name of the time column to plot; also the label of the dashed x = 0 line.
    show_legend : if True, add a legend of the event colours to the right of the axes.
    """
    palette = {
        'visual': 'orange',
        'wait': 'g',
        'cons_reward': 'b',
        'cons_no_reward': 'r'
    }
    ax.axvline(0, color='tab:gray', linestyle='--', alpha=0.5, label=anchor)

    event_groups = events_raster.groups
    spike_groups = spikes_raster.groups
    for row, trial in enumerate(sorted_trial_id):
        if trial in event_groups:
            trial_events = events_raster.get_group(trial)
            kinds = trial_events['event_type']
            # One eventplot per event type so each type keeps its own colour.
            for name, colour in palette.items():
                ax.eventplot(
                    trial_events.loc[kinds == name, anchor],
                    lineoffsets=row, color=colour,
                    linelengths=1.0, linewidths=2, alpha=1,
                )
        if trial in spike_groups:
            ax.eventplot(
                spikes_raster.get_group(trial)[anchor],
                lineoffsets=row, color='k', linelengths=0.8, linewidths=0.2,
            )

    if not show_legend:
        return
    # Proxy artists: eventplot collections are per row, so build one legend entry per colour.
    legend_handles = [
        plt.Line2D([0], [0], color=colour, lw=2, label=name)
        for name, colour in palette.items()
    ]
    ax.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.92))

In [None]:
anchor = k.TO_CUE_OFF
events_raster, spikes_raster, sorted_trial_id = prepare_data_for_raster(events, trials, spikes)
fig, ax = plt.subplots(figsize=(10, 5))
plot_raster(ax, sorted_trial_id, events_raster, spikes_raster, anchor)
# Label the figure so it stands alone when the notebook is skimmed.
ax.set(title=f"Aligned to {anchor}", xlabel="Time (s)", ylabel="Trial #");

### Firing-rate histogram (PSTH) by trial outcome

In [None]:
def prepare_data_for_histogram(trials, spikes, anchor):
    """
    Return the trials that contain spikes, with their window expressed relative to ``anchor``.

    A trial spans 0 .. ``trial_length`` in trial time. Relative to the anchor event at
    time ``t_anchor``, that window becomes
    ``aligned_start_time = 0 - t_anchor`` and ``aligned_end_time = trial_length - t_anchor``.

    Parameters
    ----------
    trials : pd.DataFrame
        Needs ``trial_id``, ``trial_length`` and the anchor's time column
        (``cue_on_time``, ``cue_off_time`` or ``consumption_time``). It is not modified.
    spikes : pd.DataFrame
        Needs a ``trial_id`` column; only trials that appear here are kept.
    anchor : one of ``k.TO_CUE_ON``, ``k.TO_CUE_OFF``, ``k.TO_CONSUMPTION``.

    Returns
    -------
    pd.DataFrame
        A copy of the matching ``trials`` rows with ``aligned_start_time`` and
        ``aligned_end_time`` added.

    Raises
    ------
    ValueError
        If ``anchor`` is not one of the supported anchors. Failing loudly avoids
        reusing aligned columns left over from an earlier call.
    """
    anchor_time_column = {
        k.TO_CUE_ON: 'cue_on_time',
        k.TO_CUE_OFF: 'cue_off_time',
        k.TO_CONSUMPTION: 'consumption_time',
    }
    if anchor not in anchor_time_column:
        raise ValueError(
            f"Unsupported anchor {anchor!r}; expected one of {list(anchor_time_column)}"
        )

    # Filter to trials with spikes and work on a copy, so the caller's table is untouched.
    has_spikes = trials['trial_id'].isin(spikes['trial_id'].unique())
    trials_histo = trials.loc[has_spikes].copy()

    anchor_time = trials_histo[anchor_time_column[anchor]]
    trials_histo['aligned_start_time'] = 0 - anchor_time
    trials_histo['aligned_end_time'] = trials_histo['trial_length'] - anchor_time
    return trials_histo

def generate_time_frame(trials, time_step, trial_count_mask=1):
    """
    Build the PSTH time-bin grid and count how many trials cover each bin.

    Parameters
    ----------
    trials : pd.DataFrame
        Must contain ``aligned_start_time`` and ``aligned_end_time``: each trial's
        window relative to the alignment event (see ``prepare_data_for_histogram``).
    time_step : float
        Bin width.
    trial_count_mask : int, default 1
        If > 0, keep only bins covered by at least this many trials.

    Returns
    -------
    bin_edges : np.ndarray, shape (n_bins + 1,)
    bin_centers : np.ndarray, shape (n_bins,)
    active_trials : np.ndarray of int, shape (n_bins,)
        Number of trial windows overlapping each bin.

    All three arrays are empty if ``trials`` is empty, has no finite window
    (e.g. NaN alignment times), or no bin passes ``trial_count_mask``.
    """
    if len(trials) == 0:
        return np.array([]), np.array([]), np.array([])

    start_min = trials['aligned_start_time'].min()  # pandas min skips NaN
    end_max = trials['aligned_end_time'].max()
    if not (np.isfinite(start_min) and np.isfinite(end_max)):
        return np.array([]), np.array([]), np.array([])

    # Snap the bounds to 0.1 and pad by one bin on each side
    # (np.arange excludes its stop value, hence 2 * time_step at the end).
    first_edge = np.round(start_min, decimals=1) - time_step
    stop = np.round(end_max, decimals=1) + 2 * time_step
    # Snapping can move a bound inward by up to 0.05. For time_step < 0.05 the one-bin
    # padding no longer absorbs that, so widen the grid (same step) to cover all data.
    if first_edge > start_min:
        first_edge -= np.ceil((first_edge - start_min) / time_step) * time_step
    if stop - time_step < end_max:
        stop = end_max + 2 * time_step

    bin_edges = np.arange(first_edge, stop, time_step)
    bin_centers = bin_edges[:-1] + time_step / 2

    # Count overlapping trial windows per bin, vectorised over (trial, bin).
    # Comparisons with NaN are False, so trials without a window never count.
    starts = trials['aligned_start_time'].to_numpy(dtype=float)
    ends = trials['aligned_end_time'].to_numpy(dtype=float)
    left, right = bin_edges[:-1], bin_edges[1:]
    occupied = (left[None, :] < ends[:, None]) & (right[None, :] > starts[:, None])
    active_trials = occupied.sum(axis=0).astype(int)

    if trial_count_mask > 0:
        valid_mask = active_trials >= trial_count_mask
        if not np.any(valid_mask):
            return np.array([]), np.array([]), np.array([])
        # Rebuild edges from the kept left edges plus one closing edge.
        # NOTE(review): this assumes the kept bins are contiguous, which holds whenever
        # every trial window contains t = 0 (anchor inside the trial). Otherwise a
        # rebuilt bin would span the gap.
        kept_left = left[valid_mask]
        bin_edges = np.append(kept_left, kept_left[-1] + time_step)
        bin_centers = bin_centers[valid_mask]
        active_trials = active_trials[valid_mask]

    return bin_edges, bin_centers, active_trials

def calculate_firing_rates(trials, spikes, anchor, time_step, trial_count_mask=5, sigma=0.01):
    """
    Trial-averaged firing rate and its SEM for one unit, in fixed-width time bins.

    Parameters
    ----------
    trials : pd.DataFrame
        One row per trial with ``trial_id``, ``aligned_start_time`` and
        ``aligned_end_time`` (e.g. from ``prepare_data_for_histogram``).
    spikes : pd.DataFrame
        One row per spike with ``trial_id`` and a column named ``anchor`` holding
        the spike time relative to the alignment event.
    anchor : str
        Column of ``spikes`` used as spike time.
    time_step : float
        Bin width. Rates are ``counts / time_step``; the notebook labels the axes
        in seconds and Hz.
    trial_count_mask : int
        Bins covered by fewer than this many trials are dropped
        (see ``generate_time_frame``).
    sigma : float
        Gaussian smoothing width in *bins* (``gaussian_filter1d`` works in array
        samples), applied to both mean and SEM. ``0`` disables smoothing.

    Returns
    -------
    bin_centers, mean_fr, sem_fr : np.ndarray
        All empty when no bin satisfies ``trial_count_mask``.
        ``mean_fr`` is the summed rate divided by the number of trials covering
        each bin. ``sem_fr`` is the population std (ddof=0) of the per-trial rates
        around ``mean_fr``, computed only over trials whose window covers the bin,
        divided by ``sqrt(active_trials)``.
    """
    bin_edges, bin_centers, active_trials = generate_time_frame(trials, time_step, trial_count_mask)

    # No bin meets the trial-count criterion.
    if len(bin_edges) == 0:
        return np.array([]), np.array([]), np.array([])
    n_bins = len(bin_edges) - 1

    # Spike counts per (spiking trial, bin). Row order follows spikes.groupby('trial_id').
    row_trial_ids, row_counts = [], []
    for trial_id, trial_spikes in spikes.groupby('trial_id'):
        row_trial_ids.append(trial_id)
        row_counts.append(np.histogram(trial_spikes[anchor], bins=bin_edges)[0])
    # reshape keeps the (trials, bins) shape even when there are no spikes at all.
    rates = np.array(row_counts, dtype=float).reshape(len(row_counts), n_bins) / time_step

    mean_fr = np.nansum(rates, axis=0) / active_trials

    # SEM only over trials whose window overlaps the bin. Trials that have not started
    # or have already ended contribute no artificial zeros to the spread.
    windows = (
        trials.drop_duplicates('trial_id')
        .set_index('trial_id')
        .reindex(row_trial_ids)[['aligned_start_time', 'aligned_end_time']]
        .to_numpy(dtype=float)
    )
    starts, ends = windows[:, 0], windows[:, 1]
    in_window = (bin_edges[:-1][None, :] < ends[:, None]) & (bin_edges[1:][None, :] > starts[:, None])
    n_in_window = in_window.sum(axis=0)
    squared_dev = np.where(in_window, (rates - mean_fr) ** 2, 0.0)
    with np.errstate(divide='ignore', invalid='ignore'):
        variance = np.where(n_in_window > 0, squared_dev.sum(axis=0) / n_in_window, 0.0)
        sem_fr = np.sqrt(variance) / np.sqrt(active_trials)

    if sigma > 0:
        mean_fr = gaussian_filter1d(mean_fr, sigma=sigma)
        sem_fr = gaussian_filter1d(sem_fr, sigma=sigma)

    return bin_centers, mean_fr, sem_fr

def plot_firing_rates(ax, trials, spikes, anchor, time_step, sigma, trial_count_mask, show_legend=True):
    """
    Plot smoothed trial-averaged firing rates (mean ± SEM band) on ``ax``.

    Curves are drawn in this order:
    - all trials (black),
    - missed trials (orange), skipped when ``anchor == k.TO_CONSUMPTION``,
    - non-missed unrewarded trials (red),
    - non-missed rewarded trials (blue).

    A curve is omitted when ``calculate_firing_rates`` returns no bins for it. A
    dashed vertical line labelled ``anchor`` marks t = 0. If ``show_legend``, the
    legend is placed to the right of the axes.
    """
    ax.axvline(0, color='tab:gray', linestyle='--', alpha=0.5, label=anchor)

    missed = trials['missed']
    rewarded = trials['rewarded']

    def _subset(mask):
        # Trials selected by ``mask`` plus the spikes belonging to those trials.
        chosen = trials.loc[mask]
        return chosen, spikes.loc[spikes['trial_id'].isin(chosen['trial_id'])]

    # (label, trials, spikes, positional plot format, plot kwargs, band colour, band alpha)
    curves = [('All trials', trials, spikes, ('k-',), {'lw': 1.5}, 'gray', 0.4)]
    # NOTE(review): presumably missed trials have no consumption event, hence no curve here.
    if anchor != k.TO_CONSUMPTION:
        curves.append(('Missed', *_subset(missed.eq(True)), (),
                       {'color': '#FFA500', 'linestyle': '-', 'lw': 1}, '#FFA500', 0.3))
    curves.append(('Unrewarded', *_subset(missed.eq(False) & rewarded.eq(False)), ('r-',),
                   {'lw': 1}, 'r', 0.3))
    curves.append(('Rewarded', *_subset(missed.eq(False) & rewarded.eq(True)), ('b-',),
                   {'lw': 1}, 'b', 0.3))

    for label, curve_trials, curve_spikes, fmt, line_kw, band_color, band_alpha in curves:
        centers, mean_fr, sem_fr = calculate_firing_rates(
            curve_trials, curve_spikes, anchor, time_step, trial_count_mask, sigma
        )
        if len(centers) == 0:
            continue
        ax.plot(centers, mean_fr, *fmt, alpha=0.8, label=label, **line_kw)
        ax.fill_between(centers, mean_fr - sem_fr, mean_fr + sem_fr, color=band_color, alpha=band_alpha)

    if show_legend:
        ax.legend(bbox_to_anchor=(1, 1.05), loc='upper left')


In [None]:
anchor = k.TO_CONSUMPTION
time_step = 0.1       # bin width (s)
sigma = 5             # Gaussian smoothing width in bins (here 5 * time_step = 0.5 s)
trial_count_mask = 5  # drop bins covered by fewer than this many trials

trials_histo = prepare_data_for_histogram(trials, spikes, anchor)
fig, ax = plt.subplots(figsize=(10, 5))
plot_firing_rates(ax, trials_histo, spikes, anchor, time_step, sigma, trial_count_mask)
# Label the figure so it stands alone when the notebook is skimmed.
ax.set(title=f"Aligned to {anchor}", xlabel="Time (s)", ylabel="Firing Rate (Hz)");

### Combined figure: raster and firing rate for three alignment anchors

In [None]:
def save_and_close_figure(fig, title, figure_folder, regenerate=False):
    """
    Save ``fig`` as ``<figure_folder>/<title>.png`` (300 dpi) and release it.

    An existing PNG is left untouched unless ``regenerate`` is True. The figure is
    cleared and closed in every case, including when ``savefig`` raises. The
    exception still propagates, but the figure's memory is not leaked across a
    batch loop.
    """
    fig_path = f'{figure_folder}/{title}.png'
    try:
        if regenerate or not os.path.exists(fig_path):
            fig.savefig(fig_path, bbox_inches='tight', dpi=300, format='png')
    finally:
        fig.clf()
        plt.close(fig)

def plot_raster_histo_with_3_anchors(unit_id, events, trials, spikes, sorted_trial_id, 
                                     anchors, time_step, sigma, trial_count_mask, 
                                     save_fig=False, figure_folder=figure_folder, regenerate=False):
    """
    One-unit summary figure with a raster (top row) and a firing-rate plot (bottom row)
    for each alignment anchor (one column per anchor).

    Parameters
    ----------
    unit_id : str
        Figure suptitle and, when saved, the PNG file stem.
    events, trials, spikes : pd.DataFrame
        Session tables for this unit (see ``prepare_data_for_raster`` and
        ``prepare_data_for_histogram``).
    sorted_trial_id :
        Ignored; kept for call compatibility. The raster row order is always
        recomputed from ``trials``, so a value left over from another session
        cannot leak in.
    anchors : sequence
        Anchor column names, one figure column each. The name refers to 3 anchors,
        but any non-empty sequence works.
    time_step, sigma, trial_count_mask :
        Forwarded to ``plot_firing_rates``.
    save_fig, figure_folder, regenerate :
        If ``save_fig`` or ``regenerate`` is set, the figure goes to
        ``save_and_close_figure``; otherwise it is shown inline and closed.

    Raises
    ------
    ValueError
        If ``anchors`` is empty.
    """
    n_cols = len(anchors)
    if n_cols == 0:
        raise ValueError("anchors must contain at least one anchor")

    fig, axes = plt.subplots(
        2, n_cols,
        figsize=(16, 9),
        sharex='col',
        gridspec_kw={'height_ratios': [3, 1]},
        squeeze=False,  # keep axes 2-D even for a single anchor
    )
    fig.subplots_adjust(right=0.85)

    # The raster grouping and row order do not depend on the anchor: build them once.
    events_raster, spikes_raster, trial_order = prepare_data_for_raster(events, trials, spikes)

    for i, anchor in enumerate(anchors):
        ax_raster, ax_rate = axes[0, i], axes[1, i]
        is_last_column = i == n_cols - 1  # legends only on the right-most column

        plot_raster(ax_raster, trial_order, events_raster, spikes_raster, anchor,
                    show_legend=is_last_column)

        trials_histo = prepare_data_for_histogram(trials, spikes, anchor)
        plot_firing_rates(ax_rate, trials_histo, spikes, anchor,
                          time_step, sigma, trial_count_mask, show_legend=is_last_column)

        ax_raster.set_title(f"Aligned to {anchor}")
        if i == 0:
            ax_raster.set_ylabel("Trial #")
            ax_rate.set_ylabel("Firing Rate (Hz)")
        ax_rate.set_xlabel("Time (s)")

    fig.suptitle(unit_id)
    fig.tight_layout()

    if save_fig or regenerate:
        save_and_close_figure(fig, unit_id, figure_folder, regenerate)
    else:
        plt.show()
        plt.close(fig)

In [None]:
anchors = [k.TO_CUE_ON, k.TO_CUE_OFF, k.TO_CONSUMPTION]
time_step = 0.1
sigma = 3
trial_count_mask = 5
events, trials, units, idx = utils.get_session_data(filename, pickle_dir)
spikes = units[0]

# The plotting function recomputes the raster row order from `trials` itself, so pass
# None rather than depending on `sorted_trial_id` left over from the raster cell.
plot_raster_histo_with_3_anchors(
    "test", events, trials, spikes, None,
    anchors, time_step, sigma, trial_count_mask,
    save_fig=True, figure_folder=figure_folder, regenerate=True,
)

## Batch processing: save a figure for every unit in every session

In [None]:
# Every session pickle, in a deterministic (sorted) processing order.
pickle_files = sorted(f.name for f in pickle_dir.glob("*.pkl"))
# For a quick trial run, limit the batch (e.g. MAX_SESSIONS = 2); None processes all.
MAX_SESSIONS = None
if MAX_SESSIONS is not None:
    pickle_files = pickle_files[:MAX_SESSIONS]

In [None]:
# Settings for the batch figures below.
anchors = [k.TO_CUE_ON, k.TO_CUE_OFF, k.TO_CONSUMPTION]  # one figure column per anchor
# bin width (s), Gaussian smoothing width (in bins), minimum trials per bin
time_step, sigma, trial_count_mask = 0.1, 3, 5
# output flags passed to the plotting function
regenerate, save_fig = False, True

In [None]:
# Plot every unit of every session. Failures are collected instead of aborting the run.
# `regenerate`, `save_fig`, `anchors`, `time_step`, `sigma` and `trial_count_mask` come
# from the configuration cell above and are deliberately not reassigned here.
failed_units = []

for filename in tqdm(pickle_files, desc="Processing sessions"):
    try:
        events, trials, units, idx = utils.get_session_data(filename, pickle_dir)
    except Exception as e:
        # One unreadable session should not stop the whole batch.
        failed_units.append({
            'unit_id': f"{filename} (session load)",
            'error': str(e),
            'error_type': type(e).__name__,
        })
        continue

    # Context manager closes the inner progress bar even if the loop is interrupted.
    with tqdm(total=len(units), desc=f"Units in session {idx}", leave=False) as units_pbar:
        for i, spikes in enumerate(units):
            unit_id = f"{idx}-unit_{i}"

            # Skip units whose figure already exists unless regenerating.
            figure_path = os.path.join(figure_folder, f"{unit_id}.png")
            if os.path.exists(figure_path) and not regenerate:
                units_pbar.set_description(f"Unit {i} (skipped)")
                units_pbar.update(1)
                continue

            try:
                # The raster row order is recomputed inside the plotting function, so
                # nothing session-specific needs to be passed for sorted_trial_id.
                plot_raster_histo_with_3_anchors(
                    unit_id, events, trials, spikes, None,
                    anchors, time_step, sigma, trial_count_mask,
                    save_fig=save_fig, figure_folder=figure_folder, regenerate=regenerate,
                )
                units_pbar.set_description(f"Unit {i} (processed)")
            except Exception as e:
                failed_units.append({
                    'unit_id': unit_id,
                    'error': str(e),
                    'error_type': type(e).__name__,
                })
                units_pbar.set_description(f"Unit {i} (failed)")
                # The plotting call may have raised before closing its figure; free it
                # so memory does not grow over a long batch.
                plt.close('all')

            units_pbar.update(1)

# Summarise failures and persist them for later inspection.
if failed_units:
    print(f"\nTotal of {len(failed_units)} units failed:")
    for n, unit in enumerate(failed_units, start=1):
        print(f"{n}. {unit['unit_id']} - Error: {unit['error']}")

    failed_path = os.path.join(figure_folder, 'failed_units.json')
    with open(failed_path, 'w') as f:
        json.dump(failed_units, f, indent=4)
    print(f"Failed units list saved to {failed_path}")
else:
    print("All units processed successfully!")