In [None]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats

from notebooks.imports import *
from config import dir_config, main_config
from src.utils import pmf_utils, plot_utils
import pickle


In [None]:
# Resolve the configured data directories into Path objects so later cells can
# join file names onto them.
compiled_dir, processed_dir = (
    Path(dir_config.data.compiled),
    Path(dir_config.data.processed),
)

In [None]:
# Load the cross-validated GLM-HMM results and the per-session metadata table.
# NOTE: pickle.load can execute arbitrary code -- only point this at trusted files.
glm_hmm_path = processed_dir / 'glm_hmm_all_trials.pkl'
with glm_hmm_path.open('rb') as fh:
    glm_hmm = pickle.load(fh)

session_metadata = pd.read_csv(compiled_dir / 'sessions_metadata.csv')

### Best number of states

In [None]:
# Average held-out log-likelihood for every candidate model size. Index k along
# axis 1 of test_ll is reported as a (k + 2)-state model, which is why every
# printed / downstream state count adds 2 to the index.
session_test_ll = glm_hmm["session_wise"]["test_ll"]
for n_state in range(session_test_ll.shape[1]):
    mean_ll = np.mean(session_test_ll[:, n_state, :])
    print(f"Test likelihood for {n_state+2} states: {mean_ll}")

# best_state stays an *index* (not a state count); later cells add 2 themselves.
best_state = np.argmax(np.mean(session_test_ll, axis=(0, 2)))
print(f"Best state is {best_state+2}")

### Best fold for each session

In [None]:
# For each session (axis 0), the CV fold (axis 2) with the highest held-out
# log-likelihood for the selected model size.
best_fold_session_wise = list(
    np.argmax(glm_hmm["session_wise"]["test_ll"][:, best_state, :], axis=1)
)

### Model verification


#### utils

In [None]:

def psychometric_fit(model, choices, prob_choice_hat, stimulus, ax, color, label, n_sample=10, rng=None):
    """Plot observed and model-simulated psychometric curves on ``ax``.

    The observed choices are fitted with ``pmf_utils.get_psychometric_data``
    and drawn as markers (data points) plus a solid line (fitted curve).
    Then ``n_sample`` synthetic choice sequences are drawn, one Bernoulli draw
    per trial with success probability ``prob_choice_hat``. Each sequence is
    fitted the same way. The mean of those fitted curves is drawn dashed,
    with a +/- 1 standard-deviation band.

    Parameters
    ----------
    model : object
        Unused; kept so existing call sites keep working.
    choices : array-like
        Observed choices, one per trial, passed unchanged to
        ``get_psychometric_data``.
    prob_choice_hat : array-like
        Model probability of choice == 1 for each trial (used as ``p`` in a
        binomial draw with ``n=1``).
    stimulus : array-like
        Per-trial stimulus; multiplied by 100 before fitting (presumably a
        fraction converted to percent coherence -- TODO confirm).
    ax : matplotlib Axes
        Axes to draw on.
    color : matplotlib color
        Colour shared by all artists of this condition.
    label : str
        Legend label attached to the observed-data fitted curve.
    n_sample : int, default 10
        Number of simulated choice sequences. ``0`` plots only the observed fit.
    rng : numpy.random.Generator, optional
        Source of randomness for the simulations. ``None`` (default) uses the
        legacy global ``numpy.random`` state, as before; pass a seeded
        Generator for reproducible figures.

    Raises
    ------
    ValueError
        If ``n_sample`` is negative.
    """
    if n_sample < 0:
        raise ValueError(f"n_sample must be >= 0, got {n_sample}")
    sampler = npr if rng is None else rng

    data = {
        "signed_coherence": np.array(stimulus) * 100,
        "choice": choices,
    }
    x_data, y_data, _, x_model, y_model = pmf_utils.get_psychometric_data(data)

    ax.plot(x_data, y_data, 'o', color=color)
    ax.plot(x_model, y_model, color=color, label=label)

    if n_sample > 0:
        x_model_hat = np.full((n_sample, len(x_model)), np.nan)
        y_model_hat = np.full((n_sample, len(y_model)), np.nan)
        for idx_sample in range(n_sample):
            data_fitted = {
                "signed_coherence": np.array(stimulus) * 100,
                "choice": sampler.binomial(1, prob_choice_hat),
            }
            # Only the fitted curve of each simulation is needed. Keep the
            # observed x_data/y_data so the markers and x-limits are never
            # paired with a simulated fit's x-values.
            _, _, _, x_model_hat[idx_sample, :], y_model_hat[idx_sample, :] = pmf_utils.get_psychometric_data(data_fitted)

        x_hat_mean = np.mean(x_model_hat, axis=0)
        y_hat_mean = np.mean(y_model_hat, axis=0)
        y_hat_std = np.std(y_model_hat, axis=0)
        ax.plot(x_hat_mean, y_hat_mean, color=color, linestyle='--')
        ax.fill_between(x_hat_mean, y_hat_mean - y_hat_std, y_hat_mean + y_hat_std,
                        color=color, alpha=0.3)

    ax.set_xlim(min(x_data), max(x_data))
    ax.set_xlabel('Coherence')
    ax.set_ylabel('choices toRF')
    ax.set_title('Psychometric fits', fontsize=15)
    ax.legend()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec  # Import for custom grid layout

def plot_model_fits(model, choices, input, stimulus, mask, n_states, session_name, task_switch, covariate_names=None):
    """Draw a four-panel summary of one session's fitted GLM-HMM.

    Panels, left to right:
      1. psychometric curves before ("Equal") and from ``task_switch`` onward
         ("Unequal"), observed vs. model-simulated (via ``psychometric_fit``);
      2. GLM weights of each state, one line per state;
      3. the fitted transition matrix, annotated with its entries;
      4. (double width) posterior state probabilities over the masked trials,
         with a dashed vertical line at ``task_switch``.

    Parameters
    ----------
    model : object
        Fitted model exposing ``transitions.params``, ``observations.params`` and
        ``expected_states(data=..., input=..., mask=...)`` (looks like the ``ssm``
        package's HMM API -- TODO confirm).
    choices : np.ndarray
        Per-trial choices; this notebook passes shape (n_trials, 1).
    input : np.ndarray, shape (n_trials, n_covariates)
        Design matrix, one column per covariate (name kept for existing callers
        even though it shadows the builtin).
    stimulus : array-like or pandas Series, length n_trials
        Signed stimulus forwarded to ``psychometric_fit``.
    mask : 1-D boolean array-like, length n_trials
        Trials kept for the psychometric and posterior panels.
    n_states : int
        Number of hidden states in ``model``.
    session_name : str
        Figure super-title.
    task_switch : int
        Position, counted among the masked trials, of the first "Unequal" trial.
    covariate_names : sequence of str, optional
        x-tick labels for the weights panel, one per column of ``input``.
        Defaults to the four covariates used in this notebook.

    Raises
    ------
    ValueError
        If ``covariate_names`` does not have one entry per column of ``input``.
    """
    n_covariates = input.shape[1]
    if covariate_names is None:
        covariate_names = ['stimulus', 'bias', 'previous choice', 'previous target']
    if len(covariate_names) != n_covariates:
        raise ValueError(
            f"covariate_names has {len(covariate_names)} entries but input has {n_covariates} columns"
        )

    # Presumably log transition probabilities; exp() maps them to probabilities
    # and [0] drops the leading wrapper axis.
    transition_matrix = np.exp(model.transitions.params)[0]
    # Sign flip kept from the original analysis (presumably so positive weights
    # favour choice == 1 -- TODO confirm against the model's parameterisation).
    weights = -model.observations.params
    # expected_states expects a column-shaped mask; element [0] is indexed below
    # as [trial, state].
    posterior_probs = model.expected_states(data=choices, input=input, mask=np.array(mask).reshape(-1, 1))[0]

    # First three states keep their original colours; any further states fall
    # back to the tab10 palette instead of raising IndexError.
    base_colors = ['#ff7f00', '#4daf4a', '#377eb8']
    state_colors = [base_colors[k] if k < len(base_colors) else plt.cm.tab10(k % 10)
                    for k in range(n_states)]

    fig = plt.figure(figsize=(14, 4))
    gs = gridspec.GridSpec(1, 4, width_ratios=[1, 1, 1, 2])  # last panel double width
    fig.suptitle(session_name)

    # ---- Panel 1: psychometric curves ----
    ax1 = fig.add_subplot(gs[0])
    # Per-state logistic prediction, then mixed across states by the posterior.
    weighted_sum = np.sum(weights * input[None, :, :], axis=-1).T
    sigmoid_output = 1 / (1 + np.exp(-weighted_sum))
    prob_choice_hat = np.sum(sigmoid_output * posterior_probs, axis=1, keepdims=True)

    choices_kept = choices[mask]
    prob_kept = prob_choice_hat[mask]
    stimulus_kept = stimulus[mask]
    psychometric_fit(model, choices_kept[:task_switch, :], prob_kept[:task_switch, :],
                     stimulus_kept[:task_switch], ax1, '#377eb8', label="Equal")
    psychometric_fit(model, choices_kept[task_switch:, :], prob_kept[task_switch:, :],
                     stimulus_kept[task_switch:], ax1, '#974810', label="Unequal")

    # ---- Panel 2: GLM weights ----
    ax2 = fig.add_subplot(gs[1])
    for k in range(n_states):
        ax2.plot(np.arange(n_covariates), weights[k][0], marker='o',
                 color=state_colors[k], linestyle='-', lw=1.5, label=f"State {k+1}")
    ax2.tick_params(axis='y', labelsize=10)
    ax2.set_ylabel("GLM weight", fontsize=15)
    ax2.set_xlabel("covariate", fontsize=15)
    ax2.set_xticks(range(n_covariates))
    ax2.set_xticklabels(list(covariate_names), fontsize=12, rotation=15)
    ax2.axhline(y=0, color="k", alpha=0.5, ls="--")
    ax2.legend()
    ax2.set_title("GLM weights", fontsize=15)

    # ---- Panel 3: fitted transition matrix ----
    ax3 = fig.add_subplot(gs[2])
    # vmin < 0 presumably keeps low-probability cells light enough for the
    # black annotations to stay legible on the 'bone' colormap.
    ax3.imshow(transition_matrix, vmin=-0.8, vmax=1, cmap='bone')
    for i in range(transition_matrix.shape[0]):
        for j in range(transition_matrix.shape[1]):
            ax3.text(j, i, str(np.around(transition_matrix[i, j], decimals=2)),
                     ha="center", va="center", color="k", fontsize=12)
    ax3.set_xlim(-0.5, n_states - 0.5)
    ax3.set_ylim(n_states - 0.5, -0.5)
    ax3.set_xticks(range(n_states))
    ax3.set_yticks(range(n_states))
    ax3.set_xlabel("state t+1", fontsize=15)
    ax3.set_ylabel("state t", fontsize=15)
    # The matrix comes from the fitted model, not a ground-truth generative one.
    ax3.set_title("Fitted transition matrix", fontsize=15)

    # ---- Panel 4: posterior state probabilities ----
    ax4 = fig.add_subplot(gs[3])
    for k in range(n_states):
        ax4.plot(posterior_probs[mask, k], label=f"State {k + 1}", lw=2, color=state_colors[k])
    ax4.set_ylim(-0.01, 1.01)
    ax4.set_yticks([0, 0.5, 1])
    ax4.tick_params(axis='y', labelsize=10)
    ax4.set_xlabel("trial #", fontsize=15)
    ax4.set_ylabel("p(state)", fontsize=15)
    ax4.axvline(x=task_switch, color="k", alpha=0.5, ls="--")
    ax4.legend()
    ax4.set_title("Posterior Probabilities", fontsize=15)

    fig.tight_layout()
    plt.show()


#### GLM weights, transition matrix, p(state)

### Session-wise fits

In [None]:
# best_state=1
for idx_session, session in enumerate(session_metadata['session_id']):
    
    model = glm_hmm["session_wise"]['models'][idx_session][best_state+2][best_fold_session_wise[idx_session]]
    choices = glm_hmm["session_wise"]["data"][session]["choices"].values.reshape(-1, 1)
    input = np.array(glm_hmm["session_wise"]["data"][session][["normalized_stimulus","bias","previous_choice","previous_target"]])


    stimulus = glm_hmm["session_wise"]["data"][session]["stimulus"]

    
    if glm_hmm["session_wise"]["data"][session]["mask"] is None:
        mask = None
    else:
        mask = glm_hmm["session_wise"]["data"][session]["mask"]
    mask = np.ones_like(choices, dtype=bool) if mask is None else mask

    prob_toRF = glm_hmm["session_wise"]["data"][session]["prob_toRF"]
    prob_toRF = prob_toRF[mask]
    task_switch = np.where((prob_toRF != 50) & ~np.isnan(prob_toRF))[0][0]
    plot_model_fits(model, choices, input, stimulus, mask, best_state+2, session, task_switch)

    


### data recovery (train and test)