In [5]:
import os
import random
import sys
import time
import traceback
from pathlib import Path

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from scipy.linalg import svd

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable

from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import RidgeClassifierCV
from sklearn.metrics import accuracy_score, roc_auc_score

from one.api import ONE

from side_info_decoding.utils import (
    set_seed, 
    load_data_from_pids, 
    sliding_window_over_trials
)
from side_info_decoding.reduced_rank import (
    Full_Rank_Model,
    Multi_Task_Reduced_Rank_Model, 
    train_multi_task, 
    model_eval
)

# Global random seed for the notebook. It is handed to the project helper
# `set_seed` here and reused later as `random_state` for the stratified K-fold splitter.
# NOTE(review): `set_seed` presumably seeds the RNGs used downstream (numpy / torch / random) — confirm in side_info_decoding.utils.
seed = 666
set_seed(seed)

In [2]:
%matplotlib inline
plt.rc("figure", dpi=100)
SMALL_SIZE = 10
BIGGER_SIZE = 15
plt.rc('font', size=BIGGER_SIZE)
plt.rc('axes', titlesize=BIGGER_SIZE)
plt.rc('axes', labelsize=BIGGER_SIZE)
plt.rc('axes', linewidth=2)
plt.rc('xtick', labelsize=BIGGER_SIZE)
plt.rc('ytick', labelsize=BIGGER_SIZE)
plt.rc('legend', fontsize=SMALL_SIZE)
plt.rc('figure', titlesize=2)
plt.rcParams['xtick.major.size'] = 10
plt.rcParams['xtick.minor.size'] = 10
plt.rcParams['ytick.major.size'] = 10
plt.rcParams['ytick.minor.size'] = 10

In [3]:
# Brain regions to decode, given as atlas acronyms (each one is passed to
# `one.search_insertions(atlas_acronym=[roi])` by the region loop).
#
# Candidate list kept for reference; uncomment to run more than one region:
# regions = [
#     "CP", "GPe", "LSr",
#     "PO", "DG", "LP",
#     "NI", "PB", "PAG",
#     "SCm", "SNr", "IRN",
#     "SPIV", "LGv", "LIN", "MDRN",
#     "PYR", "COPY", "VAL",
#     "ORBvl", "Alv", "FRP",
#     "STN", "APr"
# ]

# Region(s) actually processed in this run.
regions = ["LP"]

In [4]:
# ---------------------------------------------------------------------------
# Per-region choice decoding: multi-task reduced-rank model vs. ridge baseline
# ---------------------------------------------------------------------------
# For every region in `regions` this cell
#   1. looks up the probe insertions (PIDs) for that atlas acronym,
#   2. loads neural data X and choice labels Y for the first few PIDs,
#   3. indexes trials by signed contrast level and by choice (only consumed by
#      the disabled timescale analysis kept as comments below),
#   4. runs stratified K-fold cross-validation of the multi-task reduced-rank
#      model and of a per-PID RidgeClassifierCV baseline,
#   5. saves the fold-averaged metrics per PID to
#      ../biorxiv_plots/results/{roi}_metrics.npy.
# A failure in one region is reported (with traceback) and the loop moves on.

# --- loop-invariant setup (previously redone for every region) -------------
# NOTE(review): machine-specific absolute path; `bwm_df` is not used anywhere
# in this cell — confirm whether a later cell needs it.
bwm_session_file = "/mnt/3TB/yizi/decode-paper-brain-wide-map/decoding/bwm_cache_sessions.pqt"
bwm_df = pd.read_parquet(bwm_session_file)

one = ONE(base_url="https://openalyx.internationalbrainlab.org", mode='remote')

n_pids_per_region = 2   # only the first PIDs returned for a region are used
min_units = 10          # PIDs with fewer units (2nd axis of X) are skipped

R = 2 # rank
d = 0 # half window size
n_epochs = 7000
n_folds = 5

for re_idx, roi in enumerate(regions):

    print(f"Region {re_idx+1}/{len(regions)}: {roi}")

    try:
        pids_per_region = one.search_insertions(atlas_acronym=[roi], query_type='remote')
        print(f"{roi}: {len(pids_per_region)} PIDs")

        pids = list(pids_per_region)[:n_pids_per_region]
        print(f"using {len(pids)} of {len(pids_per_region)} PIDs ...")

        # Filled only by the disabled timescale analysis below.
        all_svd_V = {}

        # Both loaders return dicts keyed by PID.
        X_dict, Y_dict = load_data_from_pids(
            pids,
            brain_region=roi.lower(),
            behavior="choice",
            data_type="all_ks",
            n_t_bins = 40,
            t_before=0.5,
            t_after=1.5,
            align_time_type='stimOn_times',
        )

        # load contrast
        # NOTE(review): this call uses data_type="good_ks" while the decoding data
        # above uses "all_ks"; only the behavior output is kept — confirm that the
        # set of returned PIDs always matches X_dict (a missing PID raises KeyError below).
        _, contrast_dict = load_data_from_pids(
            pids,
            brain_region=roi.lower(),
            behavior="contrast",
            data_type="good_ks",
            n_t_bins = 40,
            t_before=0.5,
            t_after=1.5,
            align_time_type='stimOn_times',
        )

        loaded_pids = list(X_dict.keys())

        # Trial-index lookup per PID: keys are the unique signed contrast levels
        # plus "L"/"R" for the two choice values.
        contrast_level_dict = {}
        filter_trials_dict = {}
        for pid in loaded_pids:
            # `nan=` must be passed by keyword: the second positional argument of
            # np.nan_to_num is `copy`, so the former `nan_to_num(x, 0)` meant copy=False.
            contrast_dict[pid] = np.nan_to_num(contrast_dict[pid], nan=0.0)
            # Negate the first column, then sum the columns into one signed level per trial.
            # NOTE(review): presumably the columns are (left, right) contrast — confirm.
            contrast_dict[pid].T[0] *= -1
            contrast_level_dict[pid] = contrast_dict[pid].sum(1)
            filter_trials_dict[pid] = {}
            for level in np.unique(contrast_level_dict[pid]):
                filter_trials_dict[pid].update({level: np.argwhere(contrast_level_dict[pid] == level).flatten()})
            for val in np.unique(Y_dict[pid]):
                # Label value 0 is mapped to "L", every other value to "R".
                direc = "L" if val == 0 else "R"
                filter_trials_dict[pid].update({direc: np.argwhere(Y_dict[pid] == val).flatten()})

        # Full-data tensors per usable PID (consumed by the disabled analysis below
        # and reused for the cross-validation splits further down).
        train_pids, n_units = [], []
        train_X_dict, train_Y_dict = {}, {}
        for pid in loaded_pids:
            X, Y = X_dict[pid], Y_dict[pid]
            K, C, T = X.shape
            if C < min_units:
                continue
            train_pids.append(pid)
            n_units.append(C)
            X = sliding_window_over_trials(X, half_window_size=d)
            Y = sliding_window_over_trials(Y, half_window_size=d)
            X, Y = torch.tensor(X), torch.tensor(Y)
            train_X_dict.update({pid: X})
            train_Y_dict.update({pid: Y})

        if not train_pids:
            # Without this guard `T` would be undefined (or stale from a previous
            # region) and the model would be built with zero tasks.
            print(f"{roi}: no PID with at least {min_units} units, skipping region")
            continue

        start_time = time.time()

        # NOTE(review): the commented-out block below is the disabled "timescale"
        # analysis; it is kept verbatim. Consider moving it into a function or its
        # own cell instead of keeping it as comments.

#         # extract V from trials with different contrasts
#         print("Start extracting V from trials with different contrasts ...")
#         train_X_lst = [train_X_dict[pid] for pid in train_pids]
#         train_Y_lst = [train_Y_dict[pid] for pid in train_pids]

#         multi_task_rrm = Multi_Task_Reduced_Rank_Model(
#             n_tasks=len(train_pids),
#             n_units=n_units, 
#             n_t_bins=T, 
#             rank=R, 
#             half_window_size=d,
#             init_Us = None,
#             init_V = None,
#         )

#         # training on all data
#         rrm, train_losses = train_multi_task(
#             model=multi_task_rrm,
#             train_dataset=(train_X_lst, train_Y_lst),
#             test_dataset=(train_X_lst, train_Y_lst),
#             loss_function=torch.nn.BCELoss(),
#             learning_rate=1e-3,
#             weight_decay=1e-1,
#             n_epochs=n_epochs,
#         )

#         init_Us = np.array([multi_task_rrm.Us[pid_idx].detach().numpy() for pid_idx in range(len(train_pids))])
#         init_V = multi_task_rrm.V.detach().numpy()
#         Us, Vs = {}, {}
#         for pid_idx, pid in enumerate(train_pids):
#             Us.update({pid: init_Us[pid_idx]})
#             Vs.update({pid: init_V})

#         svd_W, svd_U, svd_S, svd_VT, S_mul_VT, W_reduced = [], [], [], [], [], []
#         for pid in train_pids:
#             W = np.array(Us[pid]) @ np.array(Vs[pid]).squeeze()
#             U, S, VT = svd(W)
#             svd_W.append(W)
#             svd_U.append(U[:, :R])
#             svd_S.append(S[:R])
#             svd_VT.append(VT[:R, :])
#             if len(S) == 1:
#                 S_mul_VT.append(np.diag(S) @ VT[:1, :])
#             else:
#                 S_mul_VT.append(np.diag(S[:R]) @ VT[:R, :])

#         all_svd_V.update({"all": S_mul_VT})

#         plt.figure()
#         plt.plot(np.abs(np.array(all_svd_V["all"])).mean(0).T[:,0])
#         plt.show()


#         # trials with diff choices
#         for direc in ["L", "R"]:
#             test_X_lst = [train_X_dict[pid][filter_trials_dict[pid][direc]] for pid in train_pids]
#             test_Y_lst = [train_Y_dict[pid][filter_trials_dict[pid][direc]] for pid in train_pids]

#             proj_lst = []
#             for pid_idx, pid in enumerate(train_pids):
#                 proj = (test_X_lst[pid_idx].squeeze().numpy().transpose(0,-1,1) @ svd_U[pid_idx]) * S_mul_VT[pid_idx].T
#                 proj_lst.append(proj.mean(0))
#             all_svd_V.update({direc: np.array(proj_lst)})

#         plt.figure()
#         for direc in ["L", "R"]:
#             plt.plot(np.abs(np.array(all_svd_V[direc])).mean(0)[:,0], label=direc)
#         plt.legend()
#         plt.show()

#         # trials with diff contrasts
#         for level in [-1, -.25, -.125, -.0625, .0625, .125, .25, 1.]:
#         # for level in [.0625, .125, .25, 1.]:
#             try:
#                 test_X_lst = [train_X_dict[pid][filter_trials_dict[pid][level]] for pid in train_pids]
#                 test_Y_lst = [train_Y_dict[pid][filter_trials_dict[pid][level]] for pid in train_pids]
#             except:
#                 continue

#             multi_task_rrm = Multi_Task_Reduced_Rank_Model(
#                 n_tasks=len(train_pids),
#                 n_units=n_units, 
#                 n_t_bins=T, 
#                 rank=R, 
#                 half_window_size=d,
#                 # init_Us = init_Us,
#                 # init_V = init_V
#             )

#             rrm, train_losses = train_multi_task(
#                 model=multi_task_rrm,
#                 train_dataset=(test_X_lst, test_Y_lst),
#                 test_dataset=(test_X_lst, test_Y_lst),
#                 loss_function=torch.nn.BCELoss(),
#                 learning_rate=1e-3,
#                 weight_decay=1e-1,
#                 n_epochs=n_epochs,
#             )

#             test_U, test_V, _, _ = model_eval(
#                 multi_task_rrm, 
#                 train_dataset=(test_X_lst, test_Y_lst),
#                 test_dataset=(test_X_lst, test_Y_lst),
#                 behavior="choice"
#             )

#             Us, Vs = {}, {}
#             for pid_idx, pid in enumerate(train_pids):
#                 Us.update({pid: test_U[pid_idx]})
#                 Vs.update({pid: test_V})

#             svd_W, svd_U, svd_S, svd_VT, S_mul_VT, W_reduced = [], [], [], [], [], []
#             for pid in train_pids:
#                 W = np.array(Us[pid]) @ np.array(Vs[pid]).squeeze()
#                 U, S, VT = svd(W)
#                 svd_W.append(W)
#                 svd_U.append(U[:, :R])
#                 svd_S.append(S[:R])
#                 svd_VT.append(VT[:R, :])
#                 if len(S) == 1:
#                     S_mul_VT.append(np.diag(S) @ VT[:1, :])
#                 else:
#                     S_mul_VT.append(np.diag(S[:R]) @ VT[:R, :])

#             all_svd_V.update({level: S_mul_VT})

#         # for level in [-1, -.25, -.125, -.0625, .0625, .125, .25, 1.]:
#         for level in [.0625, .125, .25, 1.]:
#             plt.figure()
#             plt.plot(np.array(all_svd_V[-level]).mean(0).T[:,0], label=level)
#             plt.plot(np.array(all_svd_V[level]).mean(0).T[:,0], label=level)
#             plt.legend()
#             plt.show()

#         np.save(f"../biorxiv_plots/results/{roi}_timescale.npy", all_svd_V)

        # --- stratified K-fold cross-validation --------------------------------
        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

        # Keep a handle on the full-data tensors, then rebind the train/test dicts
        # to per-fold lists: <dict>[pid][fold_idx].
        full_X_dict, full_Y_dict = train_X_dict, train_Y_dict
        train_X_dict, test_X_dict, train_Y_dict, test_Y_dict = {}, {}, {}, {}
        for pid in train_pids:
            X, Y = full_X_dict[pid], full_Y_dict[pid]
            # Split once per PID and reuse the indices for X and Y, train and test.
            fold_indices = list(skf.split(X, Y))
            train_X_dict[pid] = [X[train] for train, _ in fold_indices]
            test_X_dict[pid] = [X[test] for _, test in fold_indices]
            train_Y_dict[pid] = [Y[train] for train, _ in fold_indices]
            test_Y_dict[pid] = [Y[test] for _, test in fold_indices]

        metrics_per_fold = []
        for fold_idx in range(n_folds):

            print(f"Fold {fold_idx+1}/{n_folds} ...")
            train_X_lst = [train_X_dict[pid][fold_idx] for pid in train_pids]
            test_X_lst = [test_X_dict[pid][fold_idx] for pid in train_pids]
            train_Y_lst = [train_Y_dict[pid][fold_idx] for pid in train_pids]
            test_Y_lst = [test_Y_dict[pid][fold_idx] for pid in train_pids]

            # One task per PID.
            multi_task_rrm = Multi_Task_Reduced_Rank_Model(
                n_tasks=len(train_pids),
                n_units=n_units, 
                n_t_bins=T, 
                rank=R, 
                half_window_size=d
            )

            rrm, train_losses = train_multi_task(
                model=multi_task_rrm,
                train_dataset=(train_X_lst, train_Y_lst),
                test_dataset=(test_X_lst, test_Y_lst),
                loss_function=torch.nn.BCELoss(),
                learning_rate=1e-3,
                weight_decay=1e-1,
                n_epochs=n_epochs,
            )

            # NOTE(review): rrm_metrics is presumably one [accuracy, auc] row per task — confirm in model_eval.
            _, _, rrm_metrics, _ = model_eval(
                multi_task_rrm, 
                train_dataset=(train_X_lst, train_Y_lst),
                test_dataset=(test_X_lst, test_Y_lst),
                behavior="choice"
            )

            # Baseline: one ridge classifier per PID on the flattened per-trial features.
            frm_metrics = []
            for pid_idx, pid in enumerate(train_pids):
                train_feats = train_X_lst[pid_idx].squeeze().numpy().reshape((len(train_X_lst[pid_idx]), -1))
                test_feats = test_X_lst[pid_idx].squeeze().numpy().reshape((len(test_X_lst[pid_idx]), -1))
                test_labels = test_Y_lst[pid_idx].numpy()

                # NOTE(review): the alpha grid skips 1e1 — confirm whether that is intended.
                clf = RidgeClassifierCV(alphas=[1e-3, 1e-2, 1e-1, 1, 1e2, 1e3]).fit(
                    train_feats,
                    train_Y_lst[pid_idx].numpy()
                )
                test_pred = clf.predict(test_feats)
                # AUC needs continuous scores; feeding hard labels collapses the ROC
                # curve to a single operating point and understates the baseline.
                test_score = clf.decision_function(test_feats)
                frm_metrics.append(
                    [accuracy_score(test_labels, test_pred),
                     roc_auc_score(test_labels, test_score)]
                )

            # Columns: reduced-rank metrics first, then ridge [accuracy, auc].
            metrics_per_fold.append(np.c_[rrm_metrics, frm_metrics])

        # Average over folds once, then index the row of each PID.
        mean_metrics = np.mean(metrics_per_fold, 0)
        metrics_dict = {pid: mean_metrics[pid_idx] for pid_idx, pid in enumerate(train_pids)}

        metrics_path = f"../biorxiv_plots/results/{roi}_metrics.npy"
        # Make sure the output directory exists so a long run is not lost at save time.
        os.makedirs(os.path.dirname(metrics_path), exist_ok=True)
        # np.save pickles the dict; read it back with np.load(..., allow_pickle=True).item().
        np.save(metrics_path, metrics_dict)

        end_time = time.time()
        print(f"time spent: {end_time - start_time: .3f} seconds")

    except Exception as e:
        # Best effort per region: report what failed and where, then carry on.
        print(f"[{roi}] failed with {type(e).__name__}: {e}")
        traceback.print_exc()
        continue

1/1 regions remaining ...
LP: 111 PIDs
111 regions available ...
pulling data from ibl database ..
eid: ebce500b-c530-47de-8cb1-963c552703ea
pid: 8c732bf2-639d-496c-bf82-464bc9c2d54b
number of trials found: 470
found 470 trials from 13.74 to 5761.52 sec.
found 139 Kilosort units in region lp


Compute spike count: 100%|███████████████████| 470/470 [00:01<00:00, 353.23it/s]


pulling data from ibl database ..
eid: 15b69921-d471-4ded-8814-2adad954bcd8
pid: 7a620688-66cb-44d3-b79b-ccac1c8ba23e
number of trials found: 715
found 715 trials from 28.03 to 3547.82 sec.
found 47 Kilosort units in region lp


Compute spike count: 100%|██████████████████| 715/715 [00:00<00:00, 2825.40it/s]


pulling data from ibl database ..
eid: ebce500b-c530-47de-8cb1-963c552703ea
pid: 8c732bf2-639d-496c-bf82-464bc9c2d54b
number of trials found: 470
found 470 trials from 13.74 to 5761.52 sec.
found 34 good units in region lp


Compute spike count: 100%|██████████████████| 470/470 [00:00<00:00, 1674.71it/s]


pulling data from ibl database ..
eid: 15b69921-d471-4ded-8814-2adad954bcd8
pid: 7a620688-66cb-44d3-b79b-ccac1c8ba23e
number of trials found: 715
found 715 trials from 28.03 to 3547.82 sec.
found 0 good units in region lp


Compute spike count: 100%|█████████████████| 715/715 [00:00<00:00, 93939.59it/s]

1/5 folds remaining ...





Epoch [700/7000], Loss: 9.140467170355361
Epoch [1400/7000], Loss: 1.1869052034502268
Epoch [2100/7000], Loss: 0.49463793638201414
Epoch [2800/7000], Loss: 0.4287902326694561
Epoch [3500/7000], Loss: 0.4120512491609508
Epoch [4200/7000], Loss: 0.3971540389101153
Epoch [4900/7000], Loss: 0.3870029015131694
Epoch [5600/7000], Loss: 0.38272623933359967
Epoch [6300/7000], Loss: 0.38121903186876477
Epoch [7000/7000], Loss: 0.38008190209686565
task 0 train accuracy: 0.947 auc: 0.992
task 0 test accuracy: 0.670 auc: 0.706
task 1 train accuracy: 0.752 auc: 0.855
task 1 test accuracy: 0.706 auc: 0.718
Epoch [700/7000], Loss: 9.86449275978176
Epoch [1400/7000], Loss: 3.173992456499351
Epoch [2100/7000], Loss: 0.6266558590645952
Epoch [2800/7000], Loss: 0.2937889797970517
Epoch [3500/7000], Loss: 0.19406069972534917
Epoch [4200/7000], Loss: 0.15836737360349404
Epoch [4900/7000], Loss: 0.1499203515810612
Epoch [5600/7000], Loss: 0.14876895269221502
Epoch [6300/7000], Loss: 0.1486678034547296
Epoch