### Extract spike time series aligned to behavioural events for every unit

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
from notebooks.imports import *
import scipy.io
from config import dir_config, ephys_config
# Project-configured folders: per-session compiled inputs are read from the
# first, and the pickled result is written to the second.
compiled_dir, processed_dir = (
    Path(dir_config.data.compiled),
    Path(dir_config.data.processed),
)

In [4]:
import numpy as np
import pandas as pd
from pathlib import Path
import scipy.io
from scipy.ndimage import gaussian_filter1d

# Function to align and convolve spike trains
def get_aligned_spike_trains(cluster_spike_time, timestamps, trial_info, alignment_settings, alignment_buffer,
                             sampling_rate=30, sigma=10, pre_saccade_ms=50):
    """Bin and smooth one unit's spikes in windows aligned to behavioural events.

    For every event in ``alignment_settings`` and every trial in ``trial_info``,
    spikes falling between ``start_time_ms - alignment_buffer`` and
    ``end_time_ms + alignment_buffer`` (relative to the event) are binned into
    1 ms bins and smoothed with a Gaussian kernel. The buffer is trimmed off
    afterwards, so the returned arrays cover ``[start_time_ms, end_time_ms]``.

    Parameters
    ----------
    cluster_spike_time : numpy.ndarray
        Spike times of the unit, in samples (same clock as ``timestamps``).
    timestamps : pandas.DataFrame
        Event times in samples, indexed like ``trial_info``, with one column per
        event name. ``'stimulus_onset'`` and ``'response_onset'`` are also read
        for the masking described below.
    trial_info : pandas.DataFrame
        Trials to align. Rows whose ``reaction_time`` is NaN are filled with NaN.
    alignment_settings : dict
        ``{event_name: {'start_time_ms': int, 'end_time_ms': int}}``.
    alignment_buffer : int
        Extra milliseconds added on both sides before smoothing (to limit edge
        effects) and trimmed off afterwards. May be 0.
    sampling_rate : int, default 30
        Samples per millisecond (e.g. 30 for a 30 kHz recording).
    sigma : float, default 10
        Gaussian kernel standard deviation, in 1 ms bins.
    pre_saccade_ms : float, default 50
        For ``'stimulus_onset'`` windows, bins from ``pre_saccade_ms`` before
        ``response_onset`` onward are set to NaN.

    Returns
    -------
    dict
        ``{event_name: {'spike_trains': ndarray, 'convolved_spike_trains': ndarray}}``,
        each float32 of shape ``(n_trials, end_time_ms - start_time_ms + 1)``.
        Spike trains hold 0/1 per 1 ms bin; convolved trains are multiplied by
        1000 (spikes per ms -> spikes per second).
    """
    n_trials = len(trial_info)
    result = {}

    for event_name, window in alignment_settings.items():
        start_ms = window['start_time_ms']
        end_ms = window['end_time_ms']
        # +1 because both window edges are included as bins.
        duration = end_ms - start_ms + 2 * alignment_buffer + 1

        spike_trains = np.zeros((n_trials, duration), dtype=np.float32)
        convolved_spike_trains = np.zeros((n_trials, duration), dtype=np.float32)

        for idx_trial, trial_num in enumerate(trial_info.index):
            # Trials without a reaction time are excluded entirely.
            if np.isnan(trial_info.loc[trial_num, 'reaction_time']):
                spike_trains[idx_trial, :] = np.nan
                convolved_spike_trains[idx_trial, :] = np.nan
                continue

            aligned_event_time = timestamps.loc[trial_num, event_name]
            start_timestamp = aligned_event_time + (start_ms - alignment_buffer) * sampling_rate
            end_timestamp = aligned_event_time + (end_ms + alignment_buffer) * sampling_rate

            # Spikes inside the buffered window, expressed as samples since its start.
            in_window = (cluster_spike_time >= start_timestamp) & (cluster_spike_time <= end_timestamp)
            temp_spike_times = cluster_spike_time[in_window] - start_timestamp

            spike_idx = np.ceil(temp_spike_times / sampling_rate).astype(int)
            spike_trains[idx_trial, spike_idx] = 1
            convolved_spike_trains[idx_trial, :] = gaussian_filter1d(
                spike_trains[idx_trial, :], sigma=sigma, truncate=3
            )

            if event_name == 'stimulus_onset':
                # Exclude everything from pre_saccade_ms before response onset onward.
                response_onset_time = timestamps.loc[trial_num, 'response_onset']
                unbuffered_end = end_timestamp - alignment_buffer * sampling_rate
                if unbuffered_end > response_onset_time - pre_saccade_ms * sampling_rate:
                    # Index relative to this trial's window start (not the absolute
                    # recording time); clamp so an early cutoff masks the whole
                    # trial rather than wrapping to a negative index.
                    pre_saccade_idx = max(0, int(np.ceil(
                        (response_onset_time - start_timestamp) / sampling_rate - pre_saccade_ms
                    )))
                    spike_trains[idx_trial, pre_saccade_idx:] = np.nan
                    convolved_spike_trains[idx_trial, pre_saccade_idx:] = np.nan

            elif event_name == 'response_onset':
                # Exclude bins before stimulus onset. The mask extends a further
                # alignment_buffer bins past stimulus onset.
                # NOTE(review): presumably to drop smoothed values contaminated by
                # pre-stimulus spikes - confirm.
                stim_onset_time = timestamps.loc[trial_num, 'stimulus_onset']
                stim_on_idx = max(0, np.ceil((stim_onset_time - start_timestamp) / sampling_rate
                                             + alignment_buffer).astype(int))
                spike_trains[idx_trial, :stim_on_idx] = np.nan
                convolved_spike_trains[idx_trial, :stim_on_idx] = np.nan

        # Trim the buffers; an explicit stop index keeps alignment_buffer == 0 working
        # (a `-0` stop would yield an empty slice).
        keep = slice(alignment_buffer, duration - alignment_buffer)
        result[event_name] = {
            'spike_trains': spike_trains[:, keep],
            'convolved_spike_trains': convolved_spike_trains[:, keep] * 1000,
        }

    return result



def _load_session_data(session_name):
    """Load one session's behavioural tables and spike data from ``compiled_dir``.

    Returns ``(timestamps, trial_info, spike_times, spike_clusters)``. All four
    are None when either behavioural CSV is missing (spike files are then not
    read). The two spike arrays are None when neither the ``.npy`` pair nor the
    ``.mat`` pair exists. Spike arrays are flattened to 1-D.
    """
    session_dir = Path(compiled_dir, session_name)
    timestamps_path = Path(session_dir, f"{session_name}_timestamps.csv")
    trial_info_path = Path(session_dir, f"{session_name}_trial.csv")
    if not (timestamps_path.is_file() and trial_info_path.is_file()):
        return None, None, None, None

    timestamps = pd.read_csv(timestamps_path, index_col=None)
    trial_info = pd.read_csv(trial_info_path, index_col=None)

    spike_times_path = Path(session_dir, "spike_times.npy")
    spike_clusters_path = Path(session_dir, "spike_clusters.npy")
    spike_times_mat_path = Path(session_dir, "spike_times.mat")
    spike_clusters_mat_path = Path(session_dir, "spike_clusters.mat")

    if spike_times_path.is_file() and spike_clusters_path.is_file():
        # ravel() so .npy and .mat inputs yield the same 1-D layout.
        spike_times = np.load(spike_times_path).ravel()
        spike_clusters = np.load(spike_clusters_path).ravel()
    elif spike_times_mat_path.is_file() and spike_clusters_mat_path.is_file():
        spike_times = scipy.io.loadmat(spike_times_mat_path)['spike_times'].ravel()
        spike_clusters = scipy.io.loadmat(spike_clusters_mat_path)['spike_clusters'].ravel()
    else:
        spike_times = spike_clusters = None

    return timestamps, trial_info, spike_times, spike_clusters


# Load neuron metadata
neuron_metadata = pd.read_csv(Path(compiled_dir, "neuron_metadata.csv"), index_col=None)
ephys_neuron_wise = {event: {} for event in ephys_config.alignment_settings_GP.keys()}

# Session data is reloaded only when the session changes, so consecutive
# neurons from one session share a single read of the (large) spike files.
loaded_session = None
session_data = (None, None, None, None)

# Main loop for each neuron. Session and cluster are read from the same row
# rather than looked up via `neuron - 1`, which would pick the wrong row
# whenever neuron_id is not the 1-based row position.
for row in neuron_metadata.itertuples(index=False):
    neuron = row.neuron_id
    session_name = row.session_id
    cluster_id = row.cluster

    if session_name != loaded_session:
        session_data = _load_session_data(session_name)
        loaded_session = session_name
    timestamps, trial_info, spike_times, spike_clusters = session_data

    if timestamps is None:
        print(f"Missing files for session: {session_name}")
        continue
    if spike_times is None:
        print(f"Spike times and clusters not found in {session_name} for neuron {neuron}")
        continue

    # Filter spike times for the current cluster; keep task_type == 1 trials only.
    cluster_spike_time = spike_times[spike_clusters == cluster_id]
    GP_trial_info = trial_info[trial_info.task_type == 1]

    # Get aligned and convolved spike trains
    results = get_aligned_spike_trains(
        cluster_spike_time, timestamps, GP_trial_info,
        ephys_config.alignment_settings_GP, ephys_config.alignment_buffer
    )

    # Save results
    for event_name in ephys_config.alignment_settings_GP.keys():
        ephys_neuron_wise[event_name][neuron] = {
            "spike_trains": results[event_name]['spike_trains'],
            "convolved_spike_trains": results[event_name]['convolved_spike_trains'],
            "trial_number": GP_trial_info.trial_number
        }

In [12]:
import pickle

# Persist the per-event, per-neuron results for downstream notebooks.
output_path = Path(processed_dir) / "ephys_neuron_wise.pkl"
with output_path.open('wb') as handle:
    pickle.dump(ephys_neuron_wise, handle, protocol=pickle.HIGHEST_PROTOCOL)