%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import warnings
from pathlib import Path

import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec  # Import for custom grid layout
from scipy import stats
from sklearn.model_selection import KFold
import ssm

from notebooks.imports import *
from config import dir_config, main_config
from src.utils import pmf_utils, plot_utils


In [None]:
# Input/output data directories, resolved from the project configuration.
compiled_dir, processed_dir = (
    Path(dir_config.data.compiled),
    Path(dir_config.data.processed),
)

## Utils

## Load Data

In [None]:
# Neuron metadata table (compiled data).
neuron_metadata = pd.read_csv(compiled_dir / "neuron_metadata.csv")

# Fitted GLM-HMM results and neuron-wise ephys data.
# NOTE: pickle.load can execute arbitrary code — only load files produced by this
# project's own pipeline.
with (processed_dir / "glm_hmm_all_trials_final.pkl").open("rb") as glm_hmm_file:
    glm_hmm = pickle.load(glm_hmm_file)

with (processed_dir / "ephys_neuron_wise.pkl").open("rb") as ephys_file:
    ephys = pickle.load(ephys_file)

## Extract biased and unbiased states

In [None]:
# Collect, per session, the GLM "bias" weight of each of the 2 GLM-HMM states,
# sign-aligned to the direction of the session's prior.
# NOTE(review): column 1 of the reshaped params is assumed to be the "bias" regressor,
# matching the input column order used when fitting
# (normalized_stimulus, bias, previous_choice, previous_target) — confirm.
bias_weights = []
for session_id in glm_hmm["session_wise"]["data"]:
    session_data = glm_hmm["session_wise"]["data"][session_id]
    model = glm_hmm["session_wise"]["models"][session_id]

    # One row of GLM weights per state; negated to flip ssm's sign convention
    # (assumption — confirm against how the model was fit).
    glm_weights = -np.array(model.observations.params).reshape(2, -1)

    # +1 if the session's last trial has prob_toRF above 50 (presumably a percentage),
    # -1 otherwise.
    prior_direction = 1 if session_data["prob_toRF"].iloc[-1] > 50 else -1
    glm_weights[:, 1] = glm_weights[:, 1] * prior_direction
    if prior_direction == -1:
        # Swap the two state rows as well — presumably so that each row refers to the
        # same kind of state across sessions (TODO confirm intent).
        glm_weights = np.flip(glm_weights, axis=0)

    bias_weights.append(glm_weights[:, 1])

bias_weights = np.array(bias_weights)

In [None]:
# Prior-aligned bias weight of each state row (see cell above) across sessions.
fig, ax = plt.subplots()
state_lines = ax.plot(bias_weights, marker="o")
ax.axhline(0, color="gray", linewidth=0.5)
ax.set_xlabel("Session index")
ax.set_ylabel("Bias weight (aligned to prior direction)")
ax.legend(state_lines, [f"state {k}" for k in range(len(state_lines))])
plt.show()

In [None]:
# Assign each unmasked trial to a GLM-HMM state when that state's posterior probability
# exceeds the threshold; state column 1 is treated as "biased" and column 0 as "unbiased".
# NOTE(review): the bias-weight cell flips the state rows for sessions whose final prior
# points away from the RF; the posterior columns are NOT flipped here — confirm that
# state 1 is the biased state in every session.
STATE_POSTERIOR_THRESHOLD = 0.5
BIASED_STATE, UNBIASED_STATE = 1, 0
GLM_INPUT_COLUMNS = ["normalized_stimulus", "bias", "previous_choice", "previous_target"]

state_occupancy = {}
for session_id in glm_hmm["session_wise"]["data"]:
    session_data = glm_hmm["session_wise"]["data"][session_id]
    model = glm_hmm["session_wise"]["models"][session_id]

    choices = session_data["choices"].values.reshape(-1, 1)
    inputs = np.array(session_data[GLM_INPUT_COLUMNS])

    # Keep the mask 1-D boolean so it combines element-wise with the (T,) posterior
    # comparisons below; a (T, 1) mask would broadcast them to (T, T).
    raw_mask = session_data.get("mask")
    if raw_mask is None:
        mask = np.ones(len(choices), dtype=bool)
    else:
        mask = np.asarray(raw_mask, dtype=bool).reshape(-1)

    posterior_probs = model.expected_states(
        data=choices, input=inputs, mask=mask.reshape(-1, 1)
    )[0]
    biased_idx = (posterior_probs[:, BIASED_STATE] > STATE_POSTERIOR_THRESHOLD) & mask
    unbiased_idx = (posterior_probs[:, UNBIASED_STATE] > STATE_POSTERIOR_THRESHOLD) & mask

    trial_num = session_data["trial_num"]
    state_occupancy[session_id] = {
        "biased_state_trials": trial_num[biased_idx],
        "unbiased_state_trials": trial_num[unbiased_idx],
    }
    

## Ephys: neuronal activity split by GLM-HMM state

In [None]:
# Inspect the per-neuron fields available for the target-onset alignment, using key 1
# (presumably a neuron id, matching how extract_neuronal_data indexes ephys) as an example.
example_neuron_record = ephys["target_onset"][1]
example_neuron_record.keys()

In [None]:
# Inspect the trials assigned to the biased state for one example session.
example_session_id = "210126_GP_JP"
state_occupancy[example_session_id]["biased_state_trials"]

In [None]:
def extract_neuronal_data(alignment, data_type="convolved_spike_trains", *,
                          metadata=None, occupancy=None, ephys_data=None):
    """Split each neuron's trial-wise data by the GLM-HMM state of the neuron's own session.

    Parameters
    ----------
    alignment : str
        Event alignment key into the ephys data (e.g. "target_onset").
    data_type : str
        Per-neuron field holding trial-wise data; row ``i`` is assumed to correspond
        to ``trial_number[i]`` of the same neuron.
    metadata, occupancy, ephys_data : optional
        Overrides for the notebook-level ``neuron_metadata``, ``state_occupancy`` and
        ``ephys`` objects; each defaults to that global when left as None.

    Returns
    -------
    dict
        ``{"biased_state": {neuron: array}, "unbiased_state": {neuron: array}}`` where
        each array holds the ``data_type`` rows of trials assigned to that state.
        Neurons whose session has no state labels are skipped (with a warning).
    """
    metadata = neuron_metadata if metadata is None else metadata
    occupancy = state_occupancy if occupancy is None else occupancy
    ephys_data = ephys if ephys_data is None else ephys_data

    neuronal_data = {
        'biased_state': {},
        'unbiased_state': {},
    }
    missing_sessions = set()

    # zip pairs rows positionally, independent of the metadata frame's index labels.
    for neuron, session_id in zip(metadata.neuron_id, metadata.session_id):
        # Use the state labels of this neuron's own session (previously every neuron
        # used the labels of one hard-coded session).
        if session_id not in occupancy:
            missing_sessions.add(session_id)
            continue

        neuron_record = ephys_data[alignment][neuron]
        trial_numbers = np.asarray(neuron_record["trial_number"])
        trial_data = np.asarray(neuron_record[data_type])
        session_states = occupancy[session_id]

        for state in ('biased_state', 'unbiased_state'):
            state_trials = np.asarray(session_states[f"{state}_trials"])
            neuronal_data[state][neuron] = trial_data[np.isin(trial_numbers, state_trials)]

    if missing_sessions:
        warnings.warn(
            f"No GLM-HMM state occupancy for sessions {sorted(map(str, missing_sessions))}; "
            "their neurons were skipped."
        )

    return neuronal_data


# target_onset is extracted together with the other alignments in the following cell,
# so it is not computed twice here.

In [None]:
# Split every event alignment's trial-wise data by GLM-HMM state.
target_onset, stimulus_onset, response_onset = [
    extract_neuronal_data(alignment_key)
    for alignment_key in ("target_onset", "stimulus_onset", "response_onset")
]

In [None]:
neuron = 170  # example neuron id to inspect

# (panel title, state-split data for that alignment, column index of the aligned event).
# NOTE(review): the event indices 200 / 100 / 300 depend on the windows used when `ephys`
# was built — confirm they match the extraction windows.
panels = [
    ("Target Onset", target_onset, 200),
    ("Stimulus Onset", stimulus_onset, 100),
    ("Response Onset", response_onset, 300),
]

fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
for panel_ax, (title, aligned_data, event_idx) in zip(axes, panels):
    for state_key, label in (("biased_state", "biased"), ("unbiased_state", "unbiased")):
        panel_ax.plot(np.nanmean(aligned_data[state_key][neuron], axis=0), label=label)
    # axvline spans the full axis height regardless of the activity scale.
    panel_ax.axvline(event_idx, color="black", linestyle="--", linewidth=1)
    panel_ax.set_title(title)
    panel_ax.set_xlabel("Time bin (index)")
axes[0].set_ylabel("Trial-averaged activity")
axes[-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()