In [7]:
import os
import random
import pandas as pd
from pathlib import Path
import numpy as np
import torch

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable

from scipy.ndimage import gaussian_filter1d
from scipy.stats import pointbiserialr
from sklearn.model_selection import StratifiedKFold, KFold

from sklearn.metrics import (
    accuracy_score, 
    roc_auc_score, 
    roc_curve, 
)

from side_info_decoding.utils import (
    set_seed, 
    load_data_from_pids, 
    sliding_window_over_trials
)
from side_info_decoding.reduced_rank import (
    Reduced_Rank_Model, 
    train_single_task, 
    model_eval,
    Multi_Task_Reduced_Rank_Model,
    train_multi_task
)

import pymc3 as pm
from hmmlearn import vhmm
from side_info_decoding.bmm_hmm import (
    BetaProcess, Constrained_BMM_HMM, posterior_inference
)
from side_info_decoding.viz import plot_multi_session_hmm_results, plot_bmm_hmm_results

# Notebook-wide seed. It is handed to the project helper `set_seed`
# (presumably seeds python/numpy/torch — confirm in side_info_decoding.utils)
# and is reused later as the CV splitter's random_state.
seed: int = 666
set_seed(seed)

In [2]:
%matplotlib inline
plt.rc("figure", dpi=200)
SMALL_SIZE = 10
BIGGER_SIZE = 10
plt.rc('font', size=BIGGER_SIZE)
plt.rc('axes', titlesize=BIGGER_SIZE)
plt.rc('axes', labelsize=BIGGER_SIZE)
plt.rc('axes', linewidth=1)
plt.rc('xtick', labelsize=BIGGER_SIZE)
plt.rc('ytick', labelsize=BIGGER_SIZE)
plt.rc('legend', fontsize=SMALL_SIZE)
plt.rc('figure', titlesize=1)
plt.rcParams['xtick.major.size'] = 3
plt.rcParams['xtick.minor.size'] = 3
plt.rcParams['ytick.major.size'] = 3
plt.rcParams['ytick.minor.size'] = 3

#### Reduced-rank regression (RRR): single-session vs. multi-session

In [3]:
# Session / probe identifiers to decode from. Each entry becomes one key in
# X_dict / Y_dict and one task of the multi-task model (n_tasks=len(pids)).
pids = [
    "dab512bd-a02d-4c1f-8dbc-9155a163efc0",
    "febb430e-2d50-4f83-87a0-b5ffbb9a4943",
    "6fc4d73c-2071-43ec-a756-c6c6d8322c8b",
    "523f8301-4f56-4faf-ab33-a9ff11331118",
    "143dd7cf-6a47-47a1-906d-927ad7fe9117",
    "1a60a6e1-da99-4d4e-a734-39b1d4544fad",
    "0b8ea3ec-e75b-41a1-9442-64f5fbc11a5a",
    "1e176f17-d00f-49bb-87ff-26d237b525f1",
    "16799c7a-e395-435d-a4c4-a678007e1550",
    "ad714133-1e03-4d3a-8427-33fc483daf1a",
    "31f3e083-a324-4b88-b0a4-7788ec37b191",
]

In [None]:
# Load every session in `pids`. Both results are dicts keyed by pid
# (indexed as X_dict[pid] / Y_dict[pid] in the training cell).
X_dict, Y_dict = load_data_from_pids(
    pids,
    behavior="choice",    # decoding target
    brain_region="all",
    data_type="all_ks",
    n_t_bins=40,          # number of time bins per trial
    # Trial window bounds; presumably seconds before/after the alignment
    # event — confirm in side_info_decoding.utils.load_data_from_pids.
    t_before=0.5,
    t_after=1.5,
)

In [5]:
# Candidate ranks for the rank-sensitivity sweep. The training cell below
# trains a single rank `R` and saves to sensitivity_rank/choice/rank_{R};
# presumably it is re-run with R set to each value here. This list is not
# read by that cell.
rank = [2, 5, 10, 15]

In [21]:
# Multi-session reduced-rank model decoding `choice`, evaluated with
# stratified K-fold CV. Per-fold test metrics are written to
# ./sensitivity_rank/choice/rank_{R}/fold_{k}.npy.
R = 5  # rank
d = 0  # half window size (sliding_window_over_trials and the model)
n_epochs = 1000
n_splits = 5

skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

n_units = []
t_bins_per_session = []  # the multi-task model takes one n_t_bins for all tasks
train_X_dict, test_X_dict, train_Y_dict, test_Y_dict = {}, {}, {}, {}
for pid in pids:
    X, Y = X_dict[pid], Y_dict[pid]
    # C -> n_units, T -> n_t_bins below; K presumably = trials (TODO confirm)
    K, C, T = X.shape
    n_units.append(C)
    t_bins_per_session.append(T)
    X = sliding_window_over_trials(X, half_window_size=d)
    Y = sliding_window_over_trials(Y, half_window_size=d)
    X, Y = torch.tensor(X), torch.tensor(Y)
    # Materialize the folds once and index X and Y with the same indices.
    # Separate skf.split calls per array only agree because random_state is
    # a fixed int; with None or a RandomState they would silently diverge.
    folds = list(skf.split(X, Y))
    train_X_dict[pid] = [X[train] for train, _ in folds]
    test_X_dict[pid] = [X[test] for _, test in folds]
    train_Y_dict[pid] = [Y[train] for train, _ in folds]
    test_Y_dict[pid] = [Y[test] for _, test in folds]

# Every task shares one n_t_bins, so sessions must agree on T rather than
# silently using whichever session happened to be loaded last.
assert len(set(t_bins_per_session)) == 1, (
    f"sessions disagree on the number of time bins: {t_bins_per_session}"
)
T = t_bins_per_session[0]

save_path = Path("./sensitivity_rank") / "choice" / f"rank_{R}"
save_path.mkdir(parents=True, exist_ok=True)

for fold_idx in range(skf.get_n_splits()):

    print(f"start training fold {fold_idx+1} on {len(pids)} sessions ..")
    train_X_lst = [train_X_dict[pid][fold_idx] for pid in pids]
    test_X_lst = [test_X_dict[pid][fold_idx] for pid in pids]
    train_Y_lst = [train_Y_dict[pid][fold_idx] for pid in pids]
    test_Y_lst = [test_Y_dict[pid][fold_idx] for pid in pids]

    multi_task_rrm = Multi_Task_Reduced_Rank_Model(
        n_tasks=len(pids),
        n_units=n_units,
        n_t_bins=T,
        rank=R,
        half_window_size=d
    )

    # training
    multi_task_rrm, train_losses = train_multi_task(
        model=multi_task_rrm,
        train_dataset=(train_X_lst, train_Y_lst),
        test_dataset=(test_X_lst, test_Y_lst),
        loss_function=torch.nn.BCELoss(),
        learning_rate=1e-2,
        weight_decay=1e-1,
        n_epochs=n_epochs,
    )

    # eval
    test_U, test_V, test_metrics, _ = model_eval(
        multi_task_rrm,
        train_dataset=(train_X_lst, train_Y_lst),
        test_dataset=(test_X_lst, test_Y_lst),
        behavior="choice"
    )

    # test_metrics is saved as-is; if it is a dict, np.load needs allow_pickle=True.
    np.save(save_path / f"fold_{fold_idx}.npy", test_metrics)

start training fold 1 on 11 sessions ..
Epoch [100/1000], Loss: 21.68406470180929
Epoch [200/1000], Loss: 1.0371153225414398
Epoch [300/1000], Loss: 0.37614249174877545
Epoch [400/1000], Loss: 0.2145229218276752
Epoch [500/1000], Loss: 0.18490575001629478
Epoch [600/1000], Loss: 0.18095661123248744
Epoch [700/1000], Loss: 0.1811074051736613
Epoch [800/1000], Loss: 0.18071352847883032
Epoch [900/1000], Loss: 0.18237798755952406
Epoch [1000/1000], Loss: 0.18397621909760975
task 0 train accuracy: 0.993 auc: 1.000
task 0 test accuracy: 0.878 auc: 0.941
task 1 train accuracy: 0.993 auc: 1.000
task 1 test accuracy: 0.957 auc: 0.996
task 2 train accuracy: 0.959 auc: 0.988
task 2 test accuracy: 0.814 auc: 0.902
task 3 train accuracy: 0.994 auc: 0.999
task 3 test accuracy: 0.900 auc: 0.971
task 4 train accuracy: 0.926 auc: 0.977
task 4 test accuracy: 0.785 auc: 0.846
task 5 train accuracy: 0.992 auc: 1.000
task 5 test accuracy: 0.933 auc: 0.976
task 6 train accuracy: 0.993 auc: 1.000
task 6 tes