# Spike Plotting

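This notebook aligns Phy-curated spike times to the tone onsets of the reward-competition recordings, labels each trial by outcome (win, lose, or the recorded condition), and plots per-unit firing-rate heatmaps and a PCA trajectory of the population activity for each outcome.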

In [1]:
import glob
import re
import os

In [2]:
# Imports of all used packages and libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage.filters import gaussian_filter1d


  from scipy.ndimage.filters import gaussian_filter1d


## Inputs & Data

Explanation of each input and where it comes from.

In [4]:
# Inputs and required data loading.
# Input variable names are in all-caps snake case.
# Whenever an input changes or is used for processing,
# the derived variables are in lowercase snake case.

TRIAL_LENGTH = 10  # presumably seconds; TODO confirm (not referenced in the visible cells)
SAMPLING_RATE = 20000  # ephys sampling rate in Hz (20 kHz)
TONE_TIMESTAMP_DF = pd.read_excel("../../data/rce_tone_timestamp.xlsx", index_col=0)
OUTPUT_DIR = r"./proc" # where data is saved should always be shown in the inputs

# One path per entry in the Phy curation folder (a list, despite the singular name).
# glob returns matches in arbitrary order, so sort them to keep every later
# "first recording" peek and concatenation order reproducible between runs.
INPUT_DIR = sorted(glob.glob("/scratch/back_up/reward_competition_extention/proc/phy_curation/*"))

FileNotFoundError: [Errno 2] No such file or directory: '../../data/rce_tone_timestamp.xlsx'

In [None]:
TRIAL_NUMBER_COL = "trial_number"

## Outputs

Describe each output that the notebook creates. 

- Is it a plot or is it data?

- How valuable is the output and why is it valuable or useful?

## Processing

Describe what is done to the data here and how inputs are manipulated to generate outputs. 

In [None]:
# As much code and as many cells as required
# includes EDA and playing with data
# GO HAM!

# Ideally functions are defined here first and then data is processed using the functions

# function names are short and in snake case all lowercase
# a function name should be unique but does not have to describe the function
# doc strings describe functions not function names

def calc_bmi(weight, height):
    """
    Computes a body-mass index as weight divided by height squared.

    Parameters:
    - weight (float or int): The body weight.
    - height (float or int): The body height. Must be non-zero.

    Returns:
    - float: weight / height**2. No unit conversion is applied, so the result
      is only a conventional BMI if the caller passes compatible units.
    """
    return weight / height ** 2


In [None]:
def find_closest(target, reference_list):
    """
    Returns the entry of a reference collection that lies nearest to a target.

    Parameters:
    - target (float or int): The number to match against.
    - reference_list (iterable of float or int): The candidate numbers.

    Returns:
    - float or int: The candidate with the smallest absolute difference to the
      target. When several candidates are equally close, the one that appears
      first in reference_list is returned.

    Raises:
    - ValueError: If reference_list is empty.
    """

    def distance_to_target(candidate):
        # Absolute gap between one candidate and the target
        return abs(candidate - target)

    return min(reference_list, key=distance_to_target)

In [None]:
def find_index_in_group(group, value_column, new_column_name):
    """
    Numbers each row by the rank of its value among the group's distinct values.

    The smallest distinct value in value_column gets 1, the next one 2, and so
    on; rows that share a value share a number.

    Parameters:
    - group (pandas.DataFrame): The rows to number (e.g. one groupby group).
    - value_column (str): The column whose values determine the numbering.
    - new_column_name (str): The name of the column that receives the numbers.

    Returns:
    - pandas.DataFrame: A copy of group with new_column_name added. The input
      frame is left untouched, because mutating the object handed over by
      groupby.apply can produce unexpected results.
    """
    # One dictionary lookup per row instead of a linear list.index() scan
    rank_by_value = {
        value: rank
        for rank, value in enumerate(sorted(set(group[value_column])), start=1)
    }
    ranked_group = group.copy()
    ranked_group[new_column_name] = ranked_group[value_column].map(rank_by_value)
    return ranked_group

### Getting the subject IDs from the file name

- Dropping all unlabeled trials

In [None]:
all_trials_df = TONE_TIMESTAMP_DF.dropna(subset="condition").sort_values(by=["recording_file", "time_stamp_index"]).reset_index(drop=True)

- Making sure that all timestamps are integers

In [None]:
all_trials_df["time"] = all_trials_df["time"].astype(int)
all_trials_df["time_stamp_index"] = all_trials_df["time_stamp_index"].astype(int)
all_trials_df["video_frame"] = all_trials_df["video_frame"].astype(int)

In [None]:
all_trials_df.head()

- Original timestamps are based on ephys recordings at 20kHz. The LFP will be at 1kHz, so we will need to divide all the timestamps by 20

In [None]:
all_trials_df["resampled_index"] = all_trials_df["time_stamp_index"] // 20

In [None]:
all_trials_df["recording_dir"].unique()

- Getting a list of all the subjects through the recording name

In [None]:
all_trials_df["all_subjects"] = all_trials_df["recording_dir"].apply(lambda x: ["{}.{}".format(tup[0],tup[1]) for tup in re.findall(r'(\d+)-(\d+)', x.replace("_", "-"))[1:]])

In [None]:
all_trials_df["all_subjects"].head()

- Getting the current subject of the recording through the ending of the recording name file

In [None]:
all_trials_df["subject_info"].head()

In [None]:
all_trials_df["current_subject"] = all_trials_df["subject_info"].apply(lambda x: ".".join(x.replace("-","_").split("_")[:2]))

In [None]:
all_trials_df.head()

- Labeling the trial as a winner or loser if the winner matches the subject id or not

In [None]:
def _label_trial_outcome(row):
    """
    Labels one trial from the point of view of the recorded subject.

    Returns "win" when the condition names the current subject, "lose" when it
    names another subject of the recording, and the original condition value
    otherwise.
    """
    # Strip once so the win and lose checks see the same text; previously only
    # the win check ignored surrounding whitespace.
    condition = str(row["condition"]).strip()
    if condition == str(row["current_subject"]):
        return "win"
    if condition in row["all_subjects"]:
        return "lose"
    return row["condition"]


all_trials_df["trial_outcome"] = all_trials_df.apply(_label_trial_outcome, axis=1)

In [None]:
all_trials_df.head()

In [None]:
all_trials_df

In [None]:
all_trials_df["competition_closeness"] = all_trials_df["competition_closeness"].fillna("")

In [None]:
all_trials_df["competition_closeness"].unique()

In [None]:
all_trials_df = all_trials_df[~all_trials_df["competition_closeness"].str.contains("Only")]

# Reading in Phy

- Reading in a spreadsheet of all the unit classifications
    - They are divided up into good units, multi-units, and noise

In [None]:
INPUT_DIR

In [None]:
# Map each recording name to its Phy cluster_info table.
recording_to_cluster_info = {}
for recording_dir in INPUT_DIR:
    # str.strip(".rec") removes any of the characters ".", "r", "e", "c" from
    # both ends of the name, so remove the ".rec" suffix explicitly instead.
    recording_basename = os.path.basename(recording_dir)
    if recording_basename.endswith(".rec"):
        recording_basename = recording_basename[:-len(".rec")]
    file_path = os.path.join(recording_dir, "phy", "cluster_info.tsv")
    try:
        recording_to_cluster_info[recording_basename] = pd.read_csv(file_path, sep="\t")
    except Exception as e:
        # Best effort: report the problem and skip recordings that cannot be read
        print(e)

In [None]:
recording_to_cluster_info[list(recording_to_cluster_info.keys())[0]]

- Combining all the unit info dataframes and adding the recording name

In [None]:
recording_to_cluster_info_df = pd.concat(recording_to_cluster_info, names=['recording_name']).reset_index(level=1, drop=True).reset_index()


In [None]:
recording_to_cluster_info_df.head()

- Filtering for the good units

In [None]:
good_unit_cluster_info_df = recording_to_cluster_info_df[recording_to_cluster_info_df["group"] == "good"].reset_index(drop=True)

In [None]:
good_unit_cluster_info_df.head()

In [None]:
recording_to_good_unit_ids = good_unit_cluster_info_df.groupby('recording_name')['cluster_id'].apply(list).to_dict()


- A list of all the unit IDs that each spike came from in order
    - First item is first spike, second item is second spike, etc.

In [None]:
# Map each recording name to the array of cluster IDs, one entry per spike.
recording_to_spike_clusters = {}
for recording_dir in INPUT_DIR:
    # str.strip(".rec") removes any of the characters ".", "r", "e", "c" from
    # both ends of the name, so remove the ".rec" suffix explicitly instead.
    recording_basename = os.path.basename(recording_dir)
    if recording_basename.endswith(".rec"):
        recording_basename = recording_basename[:-len(".rec")]
    file_path = os.path.join(recording_dir, "phy", "spike_clusters.npy")
    try:
        recording_to_spike_clusters[recording_basename] = np.load(file_path)
    except Exception as e:
        # Best effort: report the problem and skip recordings that cannot be read
        print(e)

In [None]:
recording_to_spike_clusters[list(recording_to_spike_clusters.keys())[0]]

In [None]:
recording_to_spike_clusters[list(recording_to_spike_clusters.keys())[0]].shape

- The times that all the spikes happened

In [None]:
# Map each recording name to the array of spike times, one entry per spike.
recording_to_spike_times = {}
for recording_dir in INPUT_DIR:
    # str.strip(".rec") removes any of the characters ".", "r", "e", "c" from
    # both ends of the name, so remove the ".rec" suffix explicitly instead.
    recording_basename = os.path.basename(recording_dir)
    if recording_basename.endswith(".rec"):
        recording_basename = recording_basename[:-len(".rec")]
    file_path = os.path.join(recording_dir, "phy", "spike_times.npy")
    try:
        recording_to_spike_times[recording_basename] = np.load(file_path)
    except Exception as e:
        # Best effort: report the problem and skip recordings that cannot be read
        print(e)

In [None]:
recording_to_spike_times[list(recording_to_spike_times.keys())[0]]

In [None]:
recording_to_spike_times[list(recording_to_spike_times.keys())[0]].shape

### Combining everything into a dataframe

In [None]:
# Build one spike-level table per recording: every spike with its cluster's
# metadata and the inter-spike interval to the previous spike of that cluster.
recording_to_spike_df = {}
for recording_dir in INPUT_DIR:
    # str.strip(".rec") removes any of the characters ".", "r", "e", "c" from
    # both ends of the name, so remove the ".rec" suffix explicitly instead.
    recording_basename = os.path.basename(recording_dir)
    if recording_basename.endswith(".rec"):
        recording_basename = recording_basename[:-len(".rec")]
    phy_dir = os.path.join(recording_dir, "phy")

    try:
        cluster_info_df = pd.read_csv(os.path.join(phy_dir, "cluster_info.tsv"), sep="\t")

        # Accept both (n,) and (n, 1) arrays. With a plain `.T[0]` a 1-D array
        # would yield a single scalar that pandas then broadcasts to every row.
        spike_clusters = np.atleast_2d(np.load(os.path.join(phy_dir, "spike_clusters.npy")).T)[0]
        spike_times = np.atleast_2d(np.load(os.path.join(phy_dir, "spike_times.npy")).T)[0]

        spike_df = pd.DataFrame({'spike_clusters': spike_clusters, 'spike_times': spike_times})

        merged_df = spike_df.merge(cluster_info_df, left_on='spike_clusters', right_on='cluster_id', how="left")
        merged_df["recording_name"] = recording_basename

        # Difference to the previous spike of the same cluster, in samples and
        # then in seconds. Assumes spikes are stored in chronological order.
        merged_df["timestamp_isi"] = merged_df.groupby('spike_clusters')["spike_times"].diff()
        merged_df["current_isi"] = merged_df["timestamp_isi"] / SAMPLING_RATE

        if not merged_df.empty:
            recording_to_spike_df[recording_basename] = merged_df

    except Exception as e:
        # Best effort: report the problem and skip recordings that cannot be read
        print(e)

- Combining the spike time df for all recordings

In [None]:
all_spike_time_df = pd.concat(recording_to_spike_df.values())

In [None]:
all_spike_time_df = all_spike_time_df[all_spike_time_df["group"] == "good"].reset_index(drop=True)

In [None]:
all_spike_time_df.head()

In [None]:
all_spike_time_df.tail()

# Merging the trial information

- Adding a column that is the trial number

In [None]:


# Number the trials of each recording by tone-onset time (1 = earliest) and
# store the result in the trial-number column.
# `drop` must be the boolean True; the string "True" only worked because any
# non-empty string is truthy.
all_trials_df = (
    all_trials_df
    .groupby(["recording_file"])
    .apply(lambda recording_trials: find_index_in_group(recording_trials, "time", TRIAL_NUMBER_COL))
    .reset_index(drop=True)
)



In [None]:
all_trials_df.head()

In [None]:
all_trials_df.tail()

- Creating 100 ms time bins for each trial (201 bin edges spanning 10 s before to 10 s after the tone onset)

In [None]:
all_trials_df["trial_chunked_ephys_timestamp"] = all_trials_df["time"].apply(lambda x: [int(x +  SAMPLING_RATE * num * 0.1) for num in range(-100,101)])

In [None]:
all_trials_df.head()

- Getting the closest trial number for each spike

In [None]:
# Map every recording file to the list of its tone-onset timestamps,
# in the order the rows appear in all_trials_df
recording_to_trials = {
    key: all_trials_df.loc[all_trials_df['recording_file'] == key, 'time'].tolist()
    for key in all_trials_df['recording_file'].unique()
}


In [None]:
recording_to_trials[list(recording_to_trials.keys())[0]]

In [None]:
recording_to_trials.keys()

In [None]:
all_spike_time_df = all_spike_time_df[all_spike_time_df["recording_name"].isin(recording_to_trials.keys())]

In [None]:
# Calculating the timestamp of the closest tone onset for each spike 
all_spike_time_df["closest_trial"] = all_spike_time_df.apply(lambda row: find_closest(row["spike_times"], recording_to_trials[row["recording_name"]]), axis=1)

- Filtering out all spikes that are not within 10 seconds before or after the closest tone onset

In [None]:
# Half-width of the window kept around each tone onset: 10 s expressed in samples
window_samples = 10 * SAMPLING_RATE
is_within_window = (
    (all_spike_time_df["spike_times"] > all_spike_time_df["closest_trial"] - window_samples)
    & (all_spike_time_df["spike_times"] < all_spike_time_df["closest_trial"] + window_samples)
)
# Take an explicit copy so that columns added afterwards are written to an
# independent frame rather than to a slice (avoids SettingWithCopyWarning).
all_spike_time_df = all_spike_time_df[is_within_window].copy()

- Classifying each spike as being before or after the trial

In [None]:
# Spikes at or after the tone onset belong to the trial, earlier ones to the
# baseline. Vectorised comparison: same labels as a row-wise apply, without a
# Python call per spike, and it also works when the frame is empty.
all_spike_time_df["trial_or_baseline"] = np.where(
    all_spike_time_df["spike_times"] >= all_spike_time_df["closest_trial"],
    "trial",
    "baseline",
)

In [None]:
all_spike_time_df.head()

- Removing duplicate trials (rows that share the same recording file and tone timestamp)

In [None]:
all_trials_df = all_trials_df.drop_duplicates(subset=["recording_file", "time"], keep="first").reset_index()

- Merging the trial and spike timestamp dataframe based on shared recording and trial timestamp

In [None]:
merged_spike_trial_df = pd.merge(left=all_spike_time_df, right=all_trials_df, left_on=["recording_name", "closest_trial"], right_on=["recording_file", "time"], how="inner")

In [None]:
merged_spike_trial_df["timestamp_bin"] = merged_spike_trial_df.apply(lambda row: np.digitize(row["spike_times"], row["trial_chunked_ephys_timestamp"]) - 101, axis=1)

In [None]:
merged_spike_trial_df.head()

In [None]:
merged_spike_trial_df["timestamp_bin"].max()

In [None]:
merged_spike_trial_df["relative_time_to_tone"] = merged_spike_trial_df["spike_times"] - merged_spike_trial_df["closest_trial"]

In [None]:
merged_spike_trial_df["relative_time_to_tone"]

In [None]:
merged_spike_trial_df["recording_file"].unique()[0] 

In [None]:
from collections import defaultdict

In [None]:
# Number of trials per (recording_file, trial_outcome) pair.
# size() counts rows directly, so this no longer relies on the "index" column
# that an earlier reset_index() happens to leave behind.
total_number_of_trials_dict = all_trials_df.groupby(["recording_file", "trial_outcome"]).size().to_dict()

In [None]:
total_number_of_trials_dict

In [None]:
grouped_df = merged_spike_trial_df.groupby(["recording_file", "trial_number", "spike_clusters", "trial_outcome"])['relative_time_to_tone'].agg(list).reset_index()

# Rename the aggregated column

In [None]:
grouped_df["total_number_of_trials"] = grouped_df.apply(lambda row: total_number_of_trials_dict[(row["recording_file"], row["trial_outcome"])], axis=1)

In [None]:
grouped_df

In [None]:
example_file = grouped_df["recording_file"].unique()[1]
example_trial = grouped_df["trial_number"].unique()[3]

In [None]:
example_df = grouped_df[(grouped_df["recording_file"] == example_file) & (grouped_df["trial_number"] == example_trial)]

In [None]:
plt.eventplot(example_df["relative_time_to_tone"])

# Calculating average firing rate

In [None]:
merged_spike_trial_df

In [None]:
grouped_df = merged_spike_trial_df.groupby(["recording_file", "timestamp_bin", "spike_clusters", "trial_outcome", "fr"]).count()[["spike_times"]].reset_index()
# Rename the aggregated column

In [None]:
grouped_df["total_number_of_trials"] = grouped_df.apply(lambda row: total_number_of_trials_dict[(row["recording_file"], row["trial_outcome"])], axis=1)

In [None]:
grouped_df["spike_times"] = grouped_df["spike_times"] / grouped_df["total_number_of_trials"] / grouped_df["fr"]

In [None]:
grouped_df

In [None]:
grouped_df["spike_times"].mean()

In [None]:
pivot_df = grouped_df.pivot_table(index=['recording_file', 'spike_clusters', 'trial_outcome'], columns='timestamp_bin', values='spike_times', fill_value=0).reset_index().set_index("spike_clusters")

In [None]:
pivot_df

In [None]:
from matplotlib.colors import LogNorm, Normalize


In [None]:
for recording_file in pivot_df["recording_file"].unique():
    recording_df = pivot_df[pivot_df["recording_file"] == recording_file].copy()
    for outcome in recording_df["trial_outcome"].unique():
        outcome_df = recording_df[recording_df["trial_outcome"] == outcome].drop(columns=["recording_file", "trial_outcome"])
        
        sns.heatmap(outcome_df, annot=False, cmap='rocket_r', cbar_kws={'label': 'Firing Rate'}, norm=LogNorm())
        
        # Customizing the plot
        plt.title('Neuronal Firing Rates {} {}'.format(outcome, recording_file))
        plt.xlabel('Time Bin')
        plt.ylabel('Neuron ID')
        
        # Show the plot
        plt.show()


# Plotting PCA

In [None]:
merged_spike_trial_df["unique_neuron"] = "id" + merged_spike_trial_df["spike_clusters"].astype(str) + "_" + merged_spike_trial_df["recording_file"]

merged_spike_trial_df["time_bin_and_outcome"] = merged_spike_trial_df["timestamp_bin"].astype(str) + "_" + merged_spike_trial_df["trial_outcome"]

In [None]:
grouped_df = merged_spike_trial_df.groupby(["time_bin_and_outcome", "unique_neuron", "recording_file", "trial_outcome"]).count()[["spike_times"]].reset_index()
# Rename the aggregated column

In [None]:
grouped_df

In [None]:
grouped_df["total_number_of_trials"] = grouped_df.apply(lambda row: total_number_of_trials_dict[(row["recording_file"], row["trial_outcome"])], axis=1)

In [None]:
grouped_df["spike_times"] = grouped_df["spike_times"] / grouped_df["total_number_of_trials"]

In [None]:
grouped_df["spike_times"].mean()

In [None]:
pivot_df = grouped_df.pivot_table(index=['unique_neuron'], columns='time_bin_and_outcome', values='spike_times', fill_value=0).reset_index().set_index("unique_neuron")

In [None]:
pivot_df.head()

In [None]:
# performing preprocessing part
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()

In [None]:
pivot_df.to_numpy()

In [None]:
scaled_firing_rates = sc.fit_transform(pivot_df.to_numpy())


In [None]:
# Applying PCA function on training
# and testing set of X component
from sklearn.decomposition import PCA
 


In [None]:
pca = PCA(n_components = 2)



In [None]:
pca_firing_rates = pca.fit_transform(scaled_firing_rates)


In [None]:
principal_df = pd.DataFrame(data = pca_firing_rates
             , columns = ['principal component 1', 'principal component 2'])

In [None]:
principal_df.head()

In [None]:
pc1_product = pivot_df.reset_index(drop=True).multiply(principal_df["principal component 1"], axis="index")

In [None]:
pc2_product = pivot_df.reset_index(drop=True).multiply(principal_df["principal component 2"], axis="index")

In [None]:
pc_product = pd.concat([pc1_product.mean(), pc2_product.mean()], axis=1).reset_index()

In [None]:
pc_product

In [None]:
# "time_bin_and_outcome" is "<bin>_<outcome>". Split on the first underscore
# only, so an outcome label that itself contains "_" is kept whole.
pc_product["bin_time"] = pc_product["time_bin_and_outcome"].apply(lambda x: int(x.split("_", 1)[0]))
pc_product["trial_type"] = pc_product["time_bin_and_outcome"].apply(lambda x: x.split("_", 1)[1])

In [None]:
outcome_to_color = {"lose": "orange", "rewarded": "green", "win": "blue", "omission": "red"}

In [None]:
pc_product["color"] = pc_product["trial_type"].map(outcome_to_color)

In [None]:
pc_product

In [None]:
pc_product["bin_time"] = pc_product["bin_time"].astype(int)

# For every outcome: order its rows by time bin and add a 10-bin rolling mean
# of each component. The means go into the string-named columns '0' and '1',
# which are separate from the integer-named columns 0 and 1 they are computed
# from; the first 9 rows of each outcome are NaN in the new columns.
all_outcome_df = []
for outcome in pc_product["trial_type"].unique():
    is_outcome = pc_product["trial_type"] == outcome
    outcome_df = pc_product.loc[is_outcome].sort_values(["bin_time"])
    outcome_df = outcome_df.assign(**{
        '0': outcome_df[0].rolling(10).mean(),
        '1': outcome_df[1].rolling(10).mean(),
    })
    all_outcome_df.append(outcome_df)

In [None]:
pc_product = pd.concat(all_outcome_df).dropna()

In [None]:
pc_product

In [None]:
sigma = 3
divider = 5

In [None]:
# Plot each outcome's path through PC space: faint raw points per time bin
# plus a down-sampled, Gaussian-smoothed trajectory.
for outcome in pc_product["trial_type"].unique():
    outcome_df = pc_product[pc_product["trial_type"] == outcome].sort_values("bin_time")
    outcome_color = outcome_to_color[outcome]
    plt.scatter(outcome_df[0], outcome_df[1], color=outcome_color, alpha=0.1)

    # Keep every `divider`-th time bin, then smooth along time
    smoothed_x = gaussian_filter1d(outcome_df[0][::divider], sigma=sigma)
    smoothed_y = gaussian_filter1d(outcome_df[1][::divider], sigma=sigma)

    # Landmarks in time order: circle = first bin, triangle = middle bin,
    # square = last bin
    middle = len(smoothed_x) // 2
    plt.scatter(smoothed_x[0], smoothed_y[0], color=outcome_color, marker='o', s=100)
    plt.scatter(smoothed_x[middle], smoothed_y[middle], color=outcome_color, marker='^', s=100)
    plt.scatter(smoothed_x[-1], smoothed_y[-1], color=outcome_color, marker='s', s=100)

    # Smoothed trajectory
    plt.plot(smoothed_x, smoothed_y, color=outcome_color, ls=':', label=outcome, linewidth=3)

plt.legend()
plt.xlabel("PC1")
plt.ylabel("PC2")
# The title lists the markers in the order they occur in time; the previous
# wording gave them in reverse.
plt.title("PCA projection of neural activity (Circle >> Triangle >> Square)")
plt.savefig("./pca_rce.png")
# Show plot
plt.show()

In [None]:
# NOTE(review): unconditional error with no message. Presumably a deliberate
# hard stop so that "Run All" does not continue past this cell. TODO confirm,
# then either remove it or raise with a message that says why execution stops.
raise ValueError()