In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and use mode 2 so edited project modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec  # Import for custom grid layout
from scipy import stats
import pickle
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.utils import resample
from joblib import Parallel, delayed
import plotly.graph_objects as go
import plotly.express as px

from notebooks.imports import *
from config import dir_config, main_config, ephys_config
from src.utils import pmf_utils, plot_utils, ephys_utils


In [3]:
# Base folders read from the project's dir_config: `compiled_dir` is used below
# for the metadata CSVs, `processed_dir` for the pickled inputs/outputs.
# NOTE(review): `Path` is not imported explicitly in this notebook; presumably it
# arrives through `from notebooks.imports import *` — confirm.
compiled_dir = Path(dir_config.data.compiled)
processed_dir = Path(dir_config.data.processed)

## Utils

In [None]:
def plot_PSTH(neuron_id, epochs=["target_onset", "stimulus_onset", "response_onset"]):
    """
    Plot trial-averaged activity of one neuron for the biased and unbiased states.

    One panel is created per entry in `epochs`; panels are filled in the fixed
    order target -> stimulus -> response for the epochs that are requested.

    Parameters
    ----------
    neuron_id : key used to index `<epoch>["biased_state"]` / `<epoch>["unbiased_state"]`.
    epochs : sequence of epoch names to show (any of "target_onset",
        "stimulus_onset", "response_onset").

    Notes
    -----
    Reads the notebook-level variables `target_onset`, `stimulus_onset` and
    `response_onset`. NOTE(review): none of them is defined in the visible
    cells, so this function only works if they are created elsewhere.
    """
    # squeeze=False always returns a 2-D axes array; without it a single epoch
    # yields a bare Axes object and indexing it raises a TypeError.
    _, axes_grid = plt.subplots(1, len(epochs), figsize=(15, 5), squeeze=False)
    panels = axes_grid[0]

    def _draw_panel(panel, aligned, marker_x, title):
        # Mean over axis 0 of each state's array, plus a dashed marker line at marker_x.
        panel.plot(np.nanmean(aligned["biased_state"][neuron_id], axis=0), label="biased")
        panel.plot(np.nanmean(aligned["unbiased_state"][neuron_id], axis=0), label="unbiased")
        panel.vlines(marker_x, 0, 10, color="black", linestyle="--", linewidth=1)
        panel.set_title(title)

    ax_idx = 0
    # The globals are looked up lazily inside each branch so that only the
    # requested epochs need to exist.
    if "target_onset" in epochs:
        _draw_panel(panels[ax_idx], target_onset, 200, "Target Onset")
        ax_idx += 1
    if "stimulus_onset" in epochs:
        _draw_panel(panels[ax_idx], stimulus_onset, 100, "Stimulus Onset")
        ax_idx += 1
    if "response_onset" in epochs:
        _draw_panel(panels[ax_idx], response_onset, 300, "Response Onset")
        ax_idx += 1
    # Legend is attached to the right-most panel, outside the plotting area.
    panels[-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

def get_neuron_condition_trials(sessions, trial_info):
    """
    Build a {neuron_id: {condition_name: trial-number array}} lookup.

    For every session the trial numbers of each condition are computed once via
    `ephys_utils.get_trial_num(trial_info[session_id], *args)` and then attached
    to every neuron that `neuron_metadata` lists for that session.

    Each neuron receives its own (shallow) dict; the trial arrays themselves are
    shared between the neurons of one session.
    """
    # Positional arguments forwarded to ephys_utils.get_trial_num per condition.
    # The coh_0 entries pass two values, all other entries pass three.
    selector_args = {
        "coh_0_choice_toRF_corr": (0, 1),
        "coh_6_choice_toRF_corr": (0.06, 1, 1),
        "coh_20_choice_toRF_corr": (0.2, 1, 1),
        "coh_50_choice_toRF_corr": (0.5, 1, 1),
        "coh_0_choice_awayRF_corr": (0, 0),
        "coh_6_choice_awayRF_corr": (0.06, 0, 1),
        "coh_20_choice_awayRF_corr": (0.2, 0, 1),
        "coh_50_choice_awayRF_corr": (0.5, 0, 1),
    }

    trials_by_neuron = {}
    for session_id in sessions:
        in_session = neuron_metadata.session_id == session_id
        session_neurons = neuron_metadata.neuron_id[in_session].values

        session_trials = {}
        for label, args in selector_args.items():
            session_trials[label] = np.array(
                ephys_utils.get_trial_num(trial_info[session_id], *args)
            )

        for neuron_id in session_neurons:
            # Fresh dict per neuron so later per-neuron edits do not leak across neurons.
            trials_by_neuron[neuron_id] = dict(session_trials)

    return trials_by_neuron

def create_pca_matrix_by_condition(neuron_condition_dict, normalize = True):
    """
    Build, per alignment, a (neurons x concatenated time) matrix of trial-averaged activity.

    For every alignment in `ephys_config["alignment_settings_GP"]` and each of the
    eight coherence/choice conditions, the convolved spike trains of the selected
    trials are averaged per neuron. Time points that are NaN for any neuron are
    dropped, and the per-condition matrices are concatenated along time.

    Parameters
    ----------
    neuron_condition_dict : {neuron_id: {condition_name: trial numbers}}.
    normalize : if True, each neuron's row is standardised across the
        concatenated time axis with `StandardScaler`.

    Returns
    -------
    PCA_data : {alignment: array of shape (n_neurons, total kept time points)}.
    condition_len : {alignment: {condition_name: number of kept time points}}.

    Notes
    -----
    Reads the notebook-level `ephys`, `ephys_config` and `ephys_utils`.
    Alignments other than "stimulus_onset", "response_onset" and "target_onset"
    leave every row NaN and therefore contribute zero columns.
    """
    conditions = [
        "coh_0_choice_toRF_corr",
        "coh_6_choice_toRF_corr",
        "coh_20_choice_toRF_corr",
        "coh_50_choice_toRF_corr",
        "coh_0_choice_awayRF_corr",
        "coh_6_choice_awayRF_corr",
        "coh_20_choice_awayRF_corr",
        "coh_50_choice_awayRF_corr"
        ]
    alignment_settings = ephys_config["alignment_settings_GP"]
    PCA_data = {event: [] for event in alignment_settings.keys()}
    condition_len = {event: {} for event in alignment_settings.keys()}
    neuron_ids = list(neuron_condition_dict.keys())

    for alignment in alignment_settings.keys():
        # Window length in samples; it does not depend on the condition.
        time_duration = alignment_settings[alignment]["end_time_ms"] - alignment_settings[alignment]["start_time_ms"] + 1

        for condition in conditions:
            all_neuron_data = np.full((len(neuron_ids), time_duration), np.nan)

            for neuron_idx, neuron_id in enumerate(neuron_ids):
                trials = neuron_condition_dict[neuron_id][condition]
                condition_data = ephys_utils.get_neural_data_from_trial_num(ephys[alignment][neuron_id], trials, type="convolved_spike_trains")

                if alignment == "stimulus_onset":
                    # Keep the leading samples before more than half of the trials are NaN.
                    mostly_nan = np.where(np.sum(np.isnan(condition_data), axis=0)/len(trials) > 0.5)[0]
                    if mostly_nan.size == 0:
                        cutoff = condition_data.shape[1]
                    else:
                        # Clamp at 0: without it a first mostly-NaN sample at index 0
                        # gives -1, and `[:-1]` would keep almost the whole window.
                        # NOTE(review): the extra "- 1" also drops the last sample that
                        # is still mostly valid; kept as in the original — confirm intent.
                        cutoff = max(mostly_nan[0] - 1, 0)
                    all_neuron_data[neuron_idx, :cutoff] = np.nanmean(condition_data[:, :cutoff], axis=0)

                elif alignment == "response_onset":
                    # Keep the trailing samples after the last mostly-NaN sample.
                    mostly_nan = np.where(np.sum(np.isnan(condition_data), axis=0)/len(trials) > 0.5)[0]
                    if mostly_nan.size == 0:
                        first_valid = 0
                    else:
                        first_valid = mostly_nan[-1] + 1
                    all_neuron_data[neuron_idx, first_valid:] = np.nanmean(condition_data[:, first_valid:], axis=0)

                elif alignment == "target_onset":
                    all_neuron_data[neuron_idx, :] = np.nanmean(condition_data, axis=0)

            # Keep only time points that are defined for every neuron.
            all_non_nan_mask = ~np.any(np.isnan(all_neuron_data), axis=0)
            all_neuron_data = all_neuron_data[:, all_non_nan_mask]
            PCA_data[alignment].append(all_neuron_data)
            condition_len[alignment][condition] = all_neuron_data.shape[1]

        PCA_data[alignment] = np.hstack(PCA_data[alignment])

        if normalize:
            # StandardScaler works column-wise, so transpose to scale each neuron.
            scaler = StandardScaler()
            PCA_data[alignment] = scaler.fit_transform(PCA_data[alignment].T).T

    return PCA_data, condition_len

def go_scatter(fig, x, y, z=None, name=" ", color=None, markersize=2, linestyle="solid", linewidth=2, mode="lines+markers", opacity=0.8, showlegend=False):
    """
    Append one scatter trace to `fig`.

    A `go.Scatter3d` trace is added when `z` is given, otherwise a 2-D
    `go.Scatter` trace. Marker colour falls back to black when `color` is None;
    `opacity` is applied both to the markers and to the trace as a whole.
    """
    marker_color = 'black' if color is None else color

    # Styling shared by the 2-D and 3-D variants.
    shared_style = dict(
        mode=mode,
        marker=dict(size=markersize, color=marker_color, opacity=opacity),
        line=dict(width=linewidth, dash=linestyle),
        opacity=opacity,
        name=name,
        showlegend=showlegend,
    )

    if z is not None:
        trace = go.Scatter3d(x=x, y=y, z=z, **shared_style)
    else:
        trace = go.Scatter(x=x, y=y, **shared_style)
    fig.add_trace(trace)

def plot_pca_projection(fig, conditions, state_values,
                        proj_biased_mean, proj_unbiased_mean,
                        timepoints_list, onset_time, axes=(0, 1)):
    """
    Draw mean PCA trajectories per condition and state (no error bands).

    Parameters
    ----------
    fig : figure that `go_scatter` adds traces to.
    conditions : condition names to draw; each must be one of the eight known
        coherence/choice conditions.
    state_values : state names; "biased_state" selects `proj_biased_mean`, any
        other value selects `proj_unbiased_mean`.
    proj_biased_mean, proj_unbiased_mean : {condition: array indexed as
        [component, time]}.
    timepoints_list : time indices highlighted with black markers (the first
        one is drawn larger).
    onset_time : time index highlighted with a violet marker.
    axes : two or three component indices -> 2-D or 3-D plot.

    Trace order: trajectories, time-point markers, onset markers, legend proxies.
    The "<step>ms spacing" legend entry is only added when at least two time
    points are given, since the step is taken from the first two entries.
    """
    condition_dict = {
        "coh_0_choice_toRF_corr": {"color": "blue", "biased_lw": 5, "unbiased_lw": 2, "opacity": 1},
        "coh_6_choice_toRF_corr": {"color": "green", "biased_lw": 5, "unbiased_lw": 2, "opacity": 1},
        "coh_20_choice_toRF_corr": {"color": "orange", "biased_lw": 5, "unbiased_lw": 2, "opacity": 1},
        "coh_50_choice_toRF_corr": {"color": "red", "biased_lw": 5, "unbiased_lw": 2, "opacity": 1},
        "coh_0_choice_awayRF_corr": {"color": "blue", "biased_lw": 5, "unbiased_lw": 2, "opacity": 0.3},
        "coh_6_choice_awayRF_corr": {"color": "green", "biased_lw": 5, "unbiased_lw": 2, "opacity": 0.3},
        "coh_20_choice_awayRF_corr": {"color": "orange", "biased_lw": 5, "unbiased_lw": 2, "opacity": 0.3},
        "coh_50_choice_awayRF_corr": {"color": "red", "biased_lw": 5, "unbiased_lw": 2, "opacity": 0.3},
    }
    is_3d = len(axes) == 3

    def _mean_for(state, condition):
        # Anything that is not "biased_state" is treated as the unbiased state.
        source = proj_biased_mean if state == "biased_state" else proj_unbiased_mean
        return source[condition]

    def _add_marker(mean, timepoint, color, size):
        # Skip time points where every component is NaN (padding).
        if np.all(np.isnan(mean[:, timepoint])):
            return
        x = mean[axes[0], timepoint]
        y = mean[axes[1], timepoint]
        z = [mean[axes[2], timepoint]] if is_3d else None
        go_scatter(fig, [x], [y], z, color=color, markersize=size, mode="markers")

    # Mean trajectories; the biased state is drawn with the thicker line.
    for condition in conditions:
        condition_params = condition_dict[condition]
        for state in state_values:
            mean = _mean_for(state, condition)
            if state == "biased_state":
                linewidth = condition_params["biased_lw"]
            else:
                linewidth = condition_params["unbiased_lw"]

            if len(axes) in (2, 3):
                coords = [mean[axis] for axis in axes]
                go_scatter(fig, *coords,
                           name=f"{condition}_{state}", color=condition_params["color"],
                           mode="lines", linewidth=linewidth, opacity=condition_params["opacity"], showlegend=True)

    # Highlight key time points
    for idx, timepoint in enumerate(timepoints_list):
        marker_size = 8 if idx == 0 else 5
        for condition in conditions:
            for state in state_values:
                _add_marker(_mean_for(state, condition), timepoint, "black", marker_size)

    # Mark onset time with violet dots
    for condition in conditions:
        for state in state_values:
            _add_marker(_mean_for(state, condition), onset_time, "violet", 10)

    # Empty proxy traces that only create legend entries.
    go_scatter(fig, [None], [None], [None] if is_3d else None, name="Onset Time", color="violet", markersize=10, mode="markers", showlegend=True)
    if len(timepoints_list) >= 2:
        go_scatter(fig, [None], [None], [None] if is_3d else None, name=f"{timepoints_list[1]-timepoints_list[0]}ms spacing", color="black", markersize=8, mode="markers", showlegend=True)

def plot_pca_projection_and_sem(fig, conditions, state_values,
                        proj_biased_mean, proj_unbiased_mean,
                        proj_biased_std, proj_unbiased_std,
                        timepoints_list, onset_time, axes=(0, 1), band_scale=None):
    """
    Draw mean PCA trajectories together with an error band per condition and state.

    Parameters
    ----------
    fig : figure receiving the traces (via `go_scatter` and `fig.add_trace`).
    conditions : condition names to draw (must be among the eight known ones).
    state_values : state names; "biased_state" selects the biased inputs, any
        other value selects the unbiased inputs.
    proj_biased_mean, proj_unbiased_mean : {condition: array [component, time]}.
    proj_biased_std, proj_unbiased_std : {condition: array [component, time]}
        with the spread drawn around the mean.
    timepoints_list : time indices highlighted with black markers.
    onset_time : time index highlighted with a violet marker.
    axes : two or three component indices -> 2-D or 3-D plot.
    band_scale : multiplier applied to the spread. None keeps the historical
        defaults (5 for the 2-D filled band, 1 for the 3-D offset lines).

    In 2-D the band is a filled polygon running along mean + scale*spread and
    back along mean - scale*spread (offsets applied to both plotted components
    at once). In 3-D two offset lines are drawn instead of a filled shape.
    The "<step>ms spacing" legend entry is only added when at least two time
    points are given.
    """
    condition_dict = {
        "coh_0_choice_toRF_corr": {"color": "blue", "opacity": 1},
        "coh_6_choice_toRF_corr": {"color": "green", "opacity": 1},
        "coh_20_choice_toRF_corr": {"color": "orange", "opacity": 1},
        "coh_50_choice_toRF_corr": {"color": "red", "opacity": 1},
        "coh_0_choice_awayRF_corr": {"color": "blue", "opacity": 0.3},
        "coh_6_choice_awayRF_corr": {"color": "green", "opacity": 0.3},
        "coh_20_choice_awayRF_corr": {"color": "orange", "opacity": 0.3},
        "coh_50_choice_awayRF_corr": {"color": "red", "opacity": 0.3},
    }
    is_3d = len(axes) == 3
    if band_scale is None:
        # Historical hard-coded multipliers of the two branches.
        band_scale = 1 if is_3d else 5

    def _mean_for(state, condition):
        source = proj_biased_mean if state == "biased_state" else proj_unbiased_mean
        return source[condition]

    def _add_marker(mean, timepoint, color, size):
        # Skip time points where every component is NaN (padding).
        if np.all(np.isnan(mean[:, timepoint])):
            return
        x = mean[axes[0], timepoint]
        y = mean[axes[1], timepoint]
        z = [mean[axes[2], timepoint]] if is_3d else None
        go_scatter(fig, [x], [y], z, color=color, markersize=size, mode="markers")

    for condition in conditions:
        color = condition_dict[condition]["color"]
        opacity = condition_dict[condition]["opacity"]

        for state in state_values:
            name = f"{condition}_{state}"
            if state == "biased_state":
                mean, std = proj_biased_mean[condition], proj_biased_std[condition]
            else:
                mean, std = proj_unbiased_mean[condition], proj_unbiased_std[condition]

            if len(axes) == 2:
                mean_x, mean_y = mean[axes[0]], mean[axes[1]]
                std_x, std_y = std[axes[0]], std[axes[1]]
                # Only time points where mean and spread are defined on both axes.
                valid_mask = ~np.isnan(mean_x) & ~np.isnan(std_x) & ~np.isnan(mean_y) & ~np.isnan(std_y)

                # Plot mean
                go_scatter(fig, mean_x[valid_mask], mean_y[valid_mask],
                           name=name, color=color, mode="lines", linewidth=2, opacity=opacity, showlegend=True)

                # Closed polygon: forward along the upper edge, back along the lower edge.
                half_x = std_x[valid_mask] * band_scale
                half_y = std_y[valid_mask] * band_scale
                x_valid = np.concatenate([mean_x[valid_mask] + half_x, (mean_x[valid_mask] - half_x)[::-1]])
                y_valid = np.concatenate([mean_y[valid_mask] + half_y, (mean_y[valid_mask] - half_y)[::-1]])
                fig.add_trace(go.Scatter(
                    x=x_valid,
                    y=y_valid,
                    fill='toself',
                    fillcolor=color,
                    opacity=0.2,
                    line=dict(width=0),
                    showlegend=False,
                    name=f"{name}_std"
                ))

            elif len(axes) == 3:
                # Plot mean
                go_scatter(fig, mean[axes[0]], mean[axes[1]], mean[axes[2]],
                           name=name, color=color, mode="lines", linewidth=2, opacity=opacity, showlegend=True)
                # Spread as two offset lines (no filled band in 3-D).
                for std_offset in [-1, 1]:
                    go_scatter(fig,
                               mean[axes[0]] + std_offset * band_scale * std[axes[0]],
                               mean[axes[1]] + std_offset * band_scale * std[axes[1]],
                               mean[axes[2]] + std_offset * band_scale * std[axes[2]],
                               name=f"{name}_std", color=color, mode="lines", opacity=0.2, showlegend=False)

    # Highlight key time points
    for idx, timepoint in enumerate(timepoints_list):
        marker_size = 8 if idx == 0 else 5
        for condition in conditions:
            for state in state_values:
                _add_marker(_mean_for(state, condition), timepoint, "black", marker_size)

    # Mark onset time with violet dots
    for condition in conditions:
        for state in state_values:
            _add_marker(_mean_for(state, condition), onset_time, "violet", 10)

    # Empty proxy traces that only create legend entries.
    go_scatter(fig, [None], [None], [None] if is_3d else None, name="Onset Time", color="violet", markersize=10, mode="markers", showlegend=True)
    if len(timepoints_list) >= 2:
        go_scatter(fig, [None], [None], [None] if is_3d else None, name=f"{timepoints_list[1]-timepoints_list[0]}ms spacing", color="black", markersize=8, mode="markers", showlegend=True)

def bootstrap_iteration(idx_bootstrap, neuron_condition_dict, pc_weights, normalize): 
    """
    Run one bootstrap replicate and return its projection onto the first three PCs.

    For `idx_bootstrap > 0` the trials of every neuron/condition are resampled
    with `resample(..., random_state=idx_bootstrap)`; index 0 keeps the original
    trials. The resulting trial sets are turned into the PCA matrix, projected
    with the first three columns of `pc_weights[alignment]`, and split back into
    per-condition arrays.

    Returns {alignment: {condition: projected array}}.
    """
    draw_with_replacement = idx_bootstrap > 0

    bootstrapped_dict = {}
    for neuron_id, conditions in neuron_condition_dict.items():
        per_condition = {}
        for cond, trials in conditions.items():
            if draw_with_replacement:
                # The replicate index doubles as the seed for every draw.
                per_condition[cond] = resample(trials, random_state=idx_bootstrap)
            else:
                per_condition[cond] = trials
        bootstrapped_dict[neuron_id] = per_condition

    PCA_data, cond_len = create_pca_matrix_by_condition(bootstrapped_dict, normalize=normalize)

    projections = {}
    for alignment in ephys_config["alignment_settings_GP"]:
        top3_projection = pc_weights[alignment][:, :3].T @ PCA_data[alignment]
        projections[alignment] = split_array_by_cond(top3_projection, cond_len[alignment], alignment)

    return projections

def bootstrap_PCA(sessions, trial_info, pc_weights, normalize=True, n_bootstraps=1000):
    """
    Bootstrap the PC projections over trials, in parallel.

    Iteration 0 uses the original trials and provides the point estimate;
    iterations 1..n_bootstraps use resampled trials (see `bootstrap_iteration`).

    Parameters
    ----------
    sessions : session ids whose neurons are included.
    trial_info : {session_id: trial table} used to select trials per condition.
    pc_weights : {alignment: array}; the first three columns are used for projection.
    normalize : forwarded to the PCA-matrix construction.
    n_bootstraps : number of resampled replicates.

    Returns
    -------
    dict with, per alignment:
      "mean"          : {condition: projection of the original trials},
      "bootstrap"     : array holding the per-replicate {condition: projection} dicts,
      "bootstrap_sem" : {condition: bootstrap standard error}, i.e. the
                        NaN-aware standard deviation (ddof=1) across replicates.

    The standard error is deliberately NOT divided by sqrt(n_bootstraps): the
    spread of the bootstrap replicates already estimates the standard error of
    the statistic, and dividing again would make it shrink with the number of
    resamples. With fewer than two replicates the entry is all-NaN.
    """
    import warnings

    alignments = list(ephys_config["alignment_settings_GP"])
    pc_projection = {
        "mean": {event: None for event in alignments},
        "bootstrap_sem": {event: {} for event in alignments},
        "bootstrap": {event: [] for event in alignments}
    }

    # Precompute neuron condition trials once to avoid redundancy
    neuron_condition_dict = get_neuron_condition_trials(sessions, trial_info)

    # Index 0 is the un-resampled reference run, hence n_bootstraps + 1 tasks.
    results = Parallel(n_jobs=-1, backend="loky", batch_size=10, verbose=True)(
        delayed(bootstrap_iteration)(idx, neuron_condition_dict, pc_weights, normalize)
        for idx in range(n_bootstraps + 1)
    )

    for alignment in alignments:
        first_result = results[0][alignment]
        bootstrap_results = np.array([proj[alignment] for proj in results[1:]])

        pc_projection["mean"][alignment] = first_result
        pc_projection["bootstrap"][alignment] = bootstrap_results

        for condition in first_result.keys():
            if len(bootstrap_results) < 2:
                # A spread needs at least two replicates.
                spread = np.full(np.shape(first_result[condition]), np.nan)
            else:
                replicates = np.array([replicate[condition] for replicate in bootstrap_results])
                with warnings.catch_warnings():
                    # Padded time points are NaN in every replicate; silence the
                    # resulting "degrees of freedom" warnings.
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    spread = np.nanstd(replicates, axis=0, ddof=1)
            pc_projection["bootstrap_sem"][alignment][condition] = spread

    return pc_projection

def split_array_by_cond(data, array_lengths,alignment):
    """
    Split a (components x concatenated time) array back into per-condition arrays.

    Parameters
    ----------
    data : 2-D array whose columns are the per-condition segments concatenated
        in the order of `array_lengths`.
    array_lengths : {condition: number of columns belonging to that condition}.
    alignment : key into `ephys_config["alignment_settings_GP"]`; determines the
        full window length and the padding side.

    Returns
    -------
    {condition: array of shape (data.shape[0], duration)} padded with NaN.
    For "response_onset" the segment is right-aligned (NaN padding at the
    start); for every other alignment it is left-aligned.

    Conditions with zero columns yield an all-NaN array.
    """
    settings = ephys_config["alignment_settings_GP"][alignment]
    duration = settings["end_time_ms"] - settings["start_time_ms"] + 1
    # Follow the input instead of assuming exactly three components.
    n_components = data.shape[0]
    split_arrays = {key: np.full((n_components, duration), np.nan) for key in array_lengths.keys()}

    start_idx = 0
    for key, length in array_lengths.items():
        end_idx = start_idx + length
        # Guard the empty segment: `[:, -0:]` selects the whole row, so assigning
        # a zero-width slice to it would raise a broadcasting error.
        if length > 0:
            if alignment == "response_onset":
                split_arrays[key][:, -length:] = data[:, start_idx:end_idx]
            else:
                split_arrays[key][:, :length] = data[:, start_idx:end_idx]
        start_idx = end_idx
    return split_arrays


## Load Data

In [5]:
# Pickled inputs produced by earlier pipeline steps (trusted, project-local files).
glm_hmm_path = Path(processed_dir, 'glm_hmm_all_trials_prior_based_initialization_final.pkl')
with open(glm_hmm_path, 'rb') as f:
    glm_hmm = pickle.load(f)

ephys_path = Path(processed_dir, 'ephys_neuron_wise.pkl')
with open(ephys_path, 'rb') as f:
    ephys = pickle.load(f)

# Session- and neuron-level metadata tables.
session_metadata = pd.read_csv(Path(compiled_dir, 'sessions_metadata.csv'))
neuron_metadata = pd.read_csv(Path(compiled_dir, 'neuron_metadata.csv'))

# Drop excluded sessions from both tables.
sessions_to_remove = ["241209_GP_TZ"]
keep_sessions = ~session_metadata["session_id"].isin(sessions_to_remove)
keep_neurons = ~neuron_metadata["session_id"].isin(sessions_to_remove)
session_metadata = session_metadata[keep_sessions]
neuron_metadata = neuron_metadata[keep_neurons]

## Extract biased and unbiased states

In [6]:
# Session ids grouped by the `prior_direction` column of the session metadata
# ("toRF" vs "awayRF"); used below to build one PCA space per group.
toRF_sessions = session_metadata["session_id"][session_metadata.prior_direction == "toRF"]
awayRF_sessions = session_metadata["session_id"][session_metadata.prior_direction == "awayRF"]

In [7]:
# Assign every trial to the biased or unbiased GLM-HMM state using the posterior
# state probabilities (> 0.5), restricted to unmasked trials.
# NOTE(review): column 1 of the posterior is treated as the biased state and
# column 0 as the unbiased state — confirm against the fitted models.
state_occupancy = {}
for session_id in glm_hmm["session_wise"]["data"]:
    session_data = glm_hmm["session_wise"]["data"][session_id]
    model = glm_hmm["session_wise"]["models"][session_id]

    choices = session_data["choices"].values.reshape(-1, 1)
    # Named `model_input` so the builtin `input` is not shadowed for the rest of the notebook.
    model_input = np.array(session_data[["normalized_stimulus","bias","previous_choice","previous_target"]])

    mask = session_data["mask"]
    if mask is None:
        mask = np.ones(len(choices), dtype=bool)
    # Keep the mask 1-D: an (n, 1) mask would broadcast against the (n,)
    # posterior comparison below and produce an (n, n) index.
    mask = np.array(mask).reshape(-1)

    posterior_probs = model.expected_states(data=choices, input=model_input, mask=mask.reshape(-1, 1))[0]
    biased_idx = (posterior_probs[:, 1] > 0.5) & mask
    unbiased_idx = (posterior_probs[:, 0] > 0.5) & mask
    state_occupancy[session_id] = {
        "biased_state_trials": session_data["trial_num"][biased_idx],
        "unbiased_state_trials": session_data["trial_num"][unbiased_idx]}


In [8]:
# Per-session trial tables restricted to the trials of each state.
biased_state_trial_info = {}
unbiased_state_trial_info = {}
for session_id in session_metadata.session_id:
    occupancy = state_occupancy[session_id]
    session_trials = glm_hmm["session_wise"]["data"][session_id]

    in_biased = np.isin(session_trials["trial_num"], occupancy["biased_state_trials"])
    biased_state_trial_info[session_id] = session_trials.loc[in_biased]

    in_unbiased = np.isin(session_trials["trial_num"], occupancy["unbiased_state_trials"])
    unbiased_state_trial_info[session_id] = session_trials.loc[in_unbiased]

## Ephys analysis

### Preprocessing for PCA

In [9]:
# PCA input matrices built from all trials, one per prior direction.
all_trial_info = glm_hmm["session_wise"]["data"]

toRF_neuron_trials = get_neuron_condition_trials(toRF_sessions, all_trial_info)
PCA_data_toRF, _ = create_pca_matrix_by_condition(toRF_neuron_trials, normalize=True)

awayRF_neuron_trials = get_neuron_condition_trials(awayRF_sessions, all_trial_info)
PCA_data_awayRF, _ = create_pca_matrix_by_condition(awayRF_neuron_trials, normalize=True)

### Compute PCA

In [10]:
n_components = 10
toRF_pc, toRF_pc_weights = {}, {}
awayRF_pc, awayRF_pc_weights = {}, {}
for alignment in ephys_config["alignment_settings_GP"].keys():
    # One estimator per dataset: `fit` returns the estimator itself, so reusing a
    # single PCA object would make every stored entry point at the last fit.
    # random_state pins the result if the randomized SVD solver is selected.
    toRF_pc[alignment] = PCA(n_components=n_components, random_state=0)
    # fit_transform fits once and returns the per-neuron scores, which the
    # cells below use as projection weights.
    toRF_pc_weights[alignment] = toRF_pc[alignment].fit_transform(PCA_data_toRF[alignment])

    awayRF_pc[alignment] = PCA(n_components=n_components, random_state=0)
    awayRF_pc_weights[alignment] = awayRF_pc[alignment].fit_transform(PCA_data_awayRF[alignment])

### Projections

In [11]:
# Same matrix construction as for the PCA fit, but restricted to the trials of
# each behavioural state (toRF-prior sessions only).
toRF_biased_neuron_trials = get_neuron_condition_trials(toRF_sessions, biased_state_trial_info)
PCA_data_prior_toRF_biased_state, cond_len_prior_to_RF_biased_state = create_pca_matrix_by_condition(
    toRF_biased_neuron_trials, normalize=True
)

toRF_unbiased_neuron_trials = get_neuron_condition_trials(toRF_sessions, unbiased_state_trial_info)
PCA_data_prior_toRF_unbiased_state, cond_len_prior_to_RF_unbiased_state = create_pca_matrix_by_condition(
    toRF_unbiased_neuron_trials, normalize=True
)

In [12]:
# project neural data onto first 3 PCs
pc_biased_state_conds, pc_unbiased_state_conds = {},{}
for alignment in ephys_config["alignment_settings_GP"].keys():
    pc_proj_biased_state = (toRF_pc_weights[alignment][:,:3].T @ PCA_data_prior_toRF_biased_state[alignment])
    pc_proj_unbiased_state = (toRF_pc_weights[alignment][:,:3].T @ PCA_data_prior_toRF_unbiased_state[alignment])
    # split by conditions
    pc_biased_state_conds[alignment] = split_array_by_cond(pc_proj_biased_state, cond_len_prior_to_RF_biased_state[alignment], alignment)
    pc_unbiased_state_conds[alignment] = split_array_by_cond(pc_proj_unbiased_state, cond_len_prior_to_RF_unbiased_state[alignment], alignment)

#### Bootstrap

In [21]:
# Bootstrap (default 1000 resamples) of the biased-state projections for the
# toRF-prior sessions, using the toRF PC weights. Long-running: the recorded
# run below took about 32 minutes.
pc_projection_biased_state_toRF_prior = bootstrap_PCA(toRF_sessions, biased_state_trial_info, toRF_pc_weights, normalize=True)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done 212 tasks      | elapsed: 31.2min
[Parallel(n_jobs=-1)]: Done 970 out of 1001 | elapsed: 32.3min remaining:  1.0min
[Parallel(n_jobs=-1)]: Done 1001 out of 1001 | elapsed: 32.4min finished


In [20]:
# Bootstrap (default 1000 resamples) of the unbiased-state projections for the
# toRF-prior sessions, using the toRF PC weights.
pc_projection_unbiased_state_toRF_prior = bootstrap_PCA(toRF_sessions, unbiased_state_trial_info, toRF_pc_weights, normalize=True)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done 212 tasks      | elapsed:   55.7s
[Parallel(n_jobs=-1)]: Done 1001 out of 1001 | elapsed:  2.2min finished


In [22]:
# Bootstrap (default 1000 resamples) of the biased-state projections for the
# awayRF-prior sessions, using the awayRF PC weights. Long-running: the
# recorded run below took about 33 minutes.
pc_projection_biased_state_awayRF_prior = bootstrap_PCA(awayRF_sessions, biased_state_trial_info, awayRF_pc_weights, normalize=True)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done 212 tasks      | elapsed: 32.1min
[Parallel(n_jobs=-1)]: Done 1001 out of 1001 | elapsed: 33.3min finished


In [23]:
# Bootstrap (default 1000 resamples) of the unbiased-state projections for the
# awayRF-prior sessions, using the awayRF PC weights. Long-running: the
# recorded run below took about 32 minutes.
pc_projection_unbiased_state_awayRF_prior = bootstrap_PCA(awayRF_sessions, unbiased_state_trial_info, awayRF_pc_weights, normalize=True)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done 212 tasks      | elapsed: 31.1min
[Parallel(n_jobs=-1)]: Done 970 out of 1001 | elapsed: 32.2min remaining:  1.0min
[Parallel(n_jobs=-1)]: Done 1001 out of 1001 | elapsed: 32.2min finished


In [24]:
# Collect the four bootstrap results, keyed by prior direction and state.
pc_projection = {"toRF_prior": {
                    "biased_state": pc_projection_biased_state_toRF_prior, 
                    "unbiased_state": pc_projection_unbiased_state_toRF_prior
                    },
                "awayRF_prior": {
                    "biased_state": pc_projection_biased_state_awayRF_prior,
                    "unbiased_state": pc_projection_unbiased_state_awayRF_prior
                    },
                }

# Persist the results so the long-running bootstrap cells need not be re-run.
# `pickle` is already imported with the other dependencies at the top of the notebook.
with open(Path(processed_dir, 'pc_projection.pkl'), 'wb') as f:
    pickle.dump(pc_projection, f)

## Plot neural trajectories

In [85]:
alignment = "stimulus_onset"
state_values = ["biased_state", "unbiased_state"]  # States
alignment_dict = ephys_config["alignment_settings_GP"][alignment]

# Conditions to draw. The remaining coherence/choice conditions
# (coh_6 / coh_20 toRF and all awayRF ones) can be added to this list.
conditions = ["coh_0_choice_toRF_corr", "coh_50_choice_toRF_corr"]

# Marker spacing along the trajectory: 50 samples for response-aligned data, 100 otherwise.
marker_step = 50 if alignment == "response_onset" else 100
window_length = alignment_dict["end_time_ms"] - alignment_dict["start_time_ms"] + 1
timepoints_list = np.arange(0, window_length, marker_step)

# 3-D view of the bootstrap reference ("mean") projections for toRF-prior sessions.
# The single-run projections in pc_biased_state_conds / pc_unbiased_state_conds
# can be passed instead, e.g. with axes=(0, 2) for a 2-D view.
fig = go.Figure()
plot_pca_projection(
    fig, conditions, state_values,
    pc_projection_biased_state_toRF_prior["mean"][alignment],
    pc_projection_unbiased_state_toRF_prior["mean"][alignment],
    timepoints_list=timepoints_list,
    onset_time=-alignment_dict["start_time_ms"],
    axes=(0, 1, 2),
)

# Set plot labels
scene_labels = dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3")
fig.update_layout(
    width=1600,
    height=1200,
    scene=scene_labels,
    title="Interactive 3D PCA Projection",
)

fig.show()

### SEM plotting

In [93]:
alignment = "stimulus_onset"
state_values = ["biased_state","unbiased_state"]  # States

alignment_dict = ephys_config["alignment_settings_GP"][alignment]
conditions = [
        "coh_0_choice_toRF_corr",
        # "coh_6_choice_toRF_corr",
        # "coh_20_choice_toRF_corr",
        "coh_50_choice_toRF_corr",
        # "coh_0_choice_awayRF_corr",
        # "coh_6_choice_awayRF_corr",
        # "coh_20_choice_awayRF_corr",
        # "coh_50_choice_awayRF_corr"
        ]

# Marker spacing along the trajectory: 50 samples for response-aligned data, 100 otherwise.
if alignment == "response_onset":
    timepoints_list = np.arange(0, alignment_dict["end_time_ms"] - alignment_dict["start_time_ms"] + 1, 50)
else:
    timepoints_list=np.arange(0, alignment_dict["end_time_ms"] - alignment_dict["start_time_ms"] + 1, 100)

# Components to plot (0-based); the axis labels and title below follow this choice.
plot_axes = (1, 2)
pc_labels = [f"PC{axis + 1}" for axis in plot_axes]

# Create a figure
fig = go.Figure()
plot_pca_projection_and_sem(fig, conditions, state_values,
                    pc_projection_biased_state_toRF_prior["mean"][alignment], pc_projection_unbiased_state_toRF_prior["mean"][alignment], pc_projection_biased_state_toRF_prior["bootstrap_sem"][alignment], pc_projection_unbiased_state_toRF_prior["bootstrap_sem"][alignment],
                    timepoints_list=timepoints_list, onset_time = - alignment_dict["start_time_ms"], axes=plot_axes)

# Set plot labels. `scene` only applies to 3-D traces; a 2-D figure needs the
# plain x/y axis titles, otherwise the axes stay unlabeled.
fig.update_layout(width = 1600, height= 1200)
if len(plot_axes) == 3:
    fig.update_layout(
        scene=dict(
            xaxis_title=pc_labels[0],
            yaxis_title=pc_labels[1],
            zaxis_title=pc_labels[2]
        ),
        title="Interactive 3D PCA Projection"
    )
else:
    fig.update_layout(
        xaxis_title=pc_labels[0],
        yaxis_title=pc_labels[1],
        title=f"PCA Projection with bootstrap error band ({pc_labels[0]} vs {pc_labels[1]})"
    )

fig.show()
