In [1]:
import os
import sys
from tqdm import tqdm
import numpy as np
import pandas as pd

from one.api import ONE
from brainbox.io.one import SessionLoader
from iblatlas.regions import BrainRegions

from sklearn import linear_model as sklm
from sklearn.metrics import accuracy_score, balanced_accuracy_score, r2_score
from sklearn.model_selection import KFold, train_test_split
from behavior_models.utils import format_data as format_data_mut
from behavior_models.utils import format_input as format_input_mut

from brainwidemap.bwm_loading import load_good_units, load_trials_and_mask, merge_probes
from brainwidemap.decoding.functions.process_targets import load_behavior
from brainwidemap.decoding.settings_for_BWM_figure.settings_choice import params
from brainwidemap.decoding.settings_for_BWM_figure.settings_choice import RESULTS_DIR

from brainwidemap.decoding.functions.balancedweightings import balanced_weighting
from brainwidemap.decoding.functions.process_inputs import (
    build_predictor_matrix,
    select_ephys_regions,
    preprocess_ephys
)
from brainwidemap.decoding.functions.process_targets import (
    compute_beh_target,
    compute_target_mask,
    transform_data_for_decoding,
    logisticreg_criteria,
    get_target_data_per_trial_wrapper,
    check_bhv_fit_exists,
    optimal_Bayesian
)
from brainwidemap.decoding.functions.utils import save_region_results, get_save_path
from brainwidemap.decoding.functions.nulldistributions import generate_null_distribution_session
from brainwidemap.decoding.functions.decoding import decode_cv



### load data

In [2]:
params['behfit_path'] = RESULTS_DIR.joinpath('decoding', 'results', 'behavioral')
params['behfit_path'].mkdir(parents=True, exist_ok=True)
params['neuralfit_path'] = RESULTS_DIR.joinpath('decoding', 'results', 'neural')
params['neuralfit_path'].mkdir(parents=True, exist_ok=True)
params['add_to_saving_path'] = (f"_binsize={1000 * params['binsize']}_lags={params['n_bins_lag']}_"
                                f"mergedProbes_{params['merged_probes']}")
imposter_file = RESULTS_DIR.joinpath('decoding', f"imposterSessions_{params['target']}.pqt")
bwm_session_file = RESULTS_DIR.joinpath('decoding', 'bwm_cache_sessions.pqt')

In [3]:
# params["align_time"] = "stimOn_times"
print(params["align_time"])

firstMovement_times


In [4]:
# params["time_window"] = (0., 1.5)
print(params["time_window"])

(-0.1, 0.0)


In [5]:
params["binsize"] = 0.02 #0.1
print(params["binsize"])
params['n_pseudo'] = 0

0.02


In [6]:
one = ONE(base_url="https://openalyx.internationalbrainlab.org", mode='remote')
bwm_df = pd.read_parquet(bwm_session_file)

In [7]:
idx = 1

if params['merged_probes']:
    eid = bwm_df['eid'].unique()[idx]
    tmp_df = bwm_df.set_index(['eid', 'subject']).xs(eid, level='eid')
    subject = tmp_df.index[0]
    pids = tmp_df['pid'].to_list()  # Select all probes of this session
    probe_names = tmp_df['probe_name'].to_list()
    print(f"Running merged probes for session eid: {eid}")
else:
    eid = bwm_df.iloc[idx]['eid']
    subject = bwm_df.iloc[idx]['subject']
    pid = bwm_df.iloc[idx]['pid']
    probe_name = bwm_df.iloc[idx]['probe_name']
    print(f"Running probe pid: {pid}")
    
sess_loader = SessionLoader(one, eid)
sess_loader.load_trials()

trials_df, trials_mask = load_trials_and_mask(
    one=one, eid=eid, sess_loader=sess_loader, min_rt=params['min_rt'], max_rt=params['max_rt'],
    min_trial_len=params['min_len'], max_trial_len=params['max_len'],
    exclude_nochoice=True, exclude_unbiased=params['exclude_unbiased_trials'])
_, trials_mask_without_minrt = load_trials_and_mask(
    one=one, eid=eid, sess_loader=sess_loader, min_rt=None, max_rt=params['max_rt'],
    min_trial_len=params['min_len'], max_trial_len=params['max_len'],
    exclude_nochoice=True, exclude_unbiased=params['exclude_unbiased_trials'])
_, trials_mask_without_maxrt = load_trials_and_mask(
    one=one, eid=eid, sess_loader=sess_loader, min_rt=params['min_rt'], max_rt=None,
    min_trial_len=params['min_len'], max_trial_len=params['max_len'],
    exclude_nochoice=True, exclude_unbiased=params['exclude_unbiased_trials'])
_, trials_mask_withonly_nochoice = load_trials_and_mask(
    one=one, eid=eid, sess_loader=sess_loader, min_rt=None, max_rt=None,
    min_trial_len=None, max_trial_len=None,
    exclude_nochoice=True, exclude_unbiased=False)

params['trials_mask_diagnostics'] = [trials_mask,
                                     trials_mask_without_minrt,
                                     trials_mask_without_maxrt,
                                     trials_mask_withonly_nochoice]

if params['target'] in ['wheel-vel', 'wheel-speed', 'l-whisker-me', 'r-whisker-me']:
    # load target data
    dlc_dict = load_behavior(params['target'], sess_loader)
    # load imposter sessions
    params['imposter_df'] = pd.read_parquet(imposter_file) if params['n_pseudo'] > 0 else None
else:
    dlc_dict = None
    params['imposter_df'] = None

if params['merged_probes']:
    clusters_list = []
    spikes_list = []
    for pid, probe_name in zip(pids, probe_names):
        tmp_spikes, tmp_clusters = load_good_units(one, pid, eid=eid, pname=probe_name)
        tmp_clusters['pid'] = pid
        spikes_list.append(tmp_spikes)
        clusters_list.append(tmp_clusters)
    spikes, clusters = merge_probes(spikes_list, clusters_list)
else:
    spikes, clusters = load_good_units(one, pid, eid=eid, pname=probe_name)

neural_dict = {
    'spk_times': spikes['times'],
    'spk_clu': spikes['clusters'],
    'clu_regions': clusters['acronym'],
    'clu_qc': {k: np.asarray(v) for k, v in clusters.to_dict('list').items()},
    'clu_df': clusters
}

metadata = {
    'subject': subject,
    'eid': eid,
    'probe_name': probe_name
}

In [8]:
kwargs = params
pseudo_id = -1
pseudo_ids = -np.ones(1).astype('int64')
kwargs['n_runs'] = 1
kwargs['n_bins_lag'] = 0

In [9]:
print(f'Working on eid: {metadata["eid"]}')
filenames = []  # this will contain paths to saved decoding results for this eid

if kwargs['use_imposter_session'] and not kwargs['stitching_for_imposter_session']:
    trials_df = trials_df[:int(kwargs['max_number_trials_when_no_stitching_for_imposter_session'])]

if 0 in pseudo_ids:
    raise ValueError(
        'pseudo id can be -1 (actual session) or strictly greater than 0 (pseudo session)')

if not np.all(np.sort(pseudo_ids) == pseudo_ids):
    raise ValueError('pseudo_ids must be sorted')

if kwargs['model'] == optimal_Bayesian and np.any(trials_df.probabilityLeft.values[:90] != 0.5):
    raise ValueError(
        'The optimal Bayesian model assumes 90 unbiased trials at the beginning of the '
        'session, which is not the case here.')
    
# check if is trained
eids_train = (
    [metadata['eid']] if 'eids_train' not in metadata.keys() else metadata['eids_train'])
if 'eids_train' not in metadata.keys():
    metadata['eids_train'] = eids_train
elif metadata['eids_train'] != eids_train:
    raise ValueError(
        'eids_train are not supported yet. If you do not understand this error, '
        'just take out the eids_train key in the metadata to solve it')
    
if isinstance(kwargs['model'], str):
    import pickle
    from braindelphi.params import INTER_INDIVIDUAL_PATH
    inter_individual = pickle.load(open(INTER_INDIVIDUAL_PATH.joinpath(kwargs['model']), 'rb'))
    if metadata['eid'] not in inter_individual.keys():
        logging.exception('no inter individual model found')
        print(filenames)
    inter_indiv_model_specifications = inter_individual[metadata['eid']]
    print('winning interindividual model is %s' % inter_indiv_model_specifications['model_name'])
    if inter_indiv_model_specifications['model_name'] not in kwargs['modeldispatcher'].values():
        logging.exception('winning inter individual model is LeftKernel or RightKernel')
        print(filenames)
    kwargs['model'] = {v: k for k, v in kwargs['modeldispatcher'].items()}[inter_indiv_model_specifications['model_name']]
    kwargs['model_parameters'] = inter_indiv_model_specifications['model_parameters']
else:
    kwargs['model_parameters'] = None
    # train model if not trained already
    if kwargs['model'] != optimal_Bayesian and kwargs['model'] is not None:
        side, stim, act, _ = format_data_mut(trials_df)
        stimuli, actions, stim_side = format_input_mut([stim], [act], [side])
        behmodel = kwargs['model'](
            kwargs['behfit_path'], np.array(metadata['eids_train']), metadata['subject'],
            actions, stimuli, trials_df, stim_side, single_zeta=True)
        istrained, _ = check_bhv_fit_exists(
            metadata['subject'], kwargs['model'], metadata['eids_train'],
            kwargs['behfit_path'], modeldispatcher=kwargs['modeldispatcher'], single_zeta=True)
        if not istrained:
            behmodel.load_or_train(remove_old=False)

if kwargs['balanced_weight'] and kwargs['balanced_continuous_target']:
    raise NotImplementedError("see tag `decoding_biasCWnull` for a previous implementation.")
else:
    target_distribution = None
    
# get target values
if kwargs['target'] in ['pLeft', 'signcont', 'strengthcont', 'choice', 'feedback']:
    target_vals_list, target_vals_to_mask = compute_beh_target(
        trials_df, metadata, return_raw=True, **kwargs)
    target_mask = compute_target_mask(
        target_vals_to_mask, kwargs['exclude_trials_within_values'])

else:
    if dlc_dict is None or dlc_dict['times'] is None or dlc_dict['values'] is None:
        raise ValueError('dlc_dict does not contain any data')
    _, target_vals_list, target_mask = get_target_data_per_trial_wrapper(
        target_times=dlc_dict['times'],
        target_vals=dlc_dict['values'],
        trials_df=trials_df,
        align_event=kwargs['align_time'],
        align_interval=kwargs['time_window'],
        binsize=kwargs['binsize'])
    
mask = trials_mask & target_mask

if sum(mask) <= kwargs['min_behav_trials']:
    msg = 'session contains %i trials, below the threshold of %i' % (
        sum(mask), kwargs['min_behav_trials'])
    logging.exception(msg)
    print(filenames)

Working on eid: 56956777-dca5-468c-87cb-78150432cc57
[36m2023-11-06 21:49:24.573 INFO     [base_models.py:  289]   results found and loaded from /mnt/3TB/yizi/decode-paper-brain-wide-map/decoding/results/behavioral/NYU-11/model_actKernel_single_zeta/train_56956777.pkl[0m


In [10]:
# select brain regions from beryl atlas to loop over
brainreg = BrainRegions()
beryl_reg = brainreg.acronym2acronym(neural_dict['clu_regions'], mapping='Beryl')
regions = (
    [[k] for k in np.unique(beryl_reg)] if kwargs['single_region'] else [np.unique(beryl_reg)])

In [11]:
region_results = {}
for region in tqdm(regions, desc='Region: ', leave=False):

    # pull spikes from this region out of the neural data
    reg_clu_ids = select_ephys_regions(neural_dict, beryl_reg, region, **kwargs)

    # skip region if there are not enough units
    n_units = len(reg_clu_ids)
    if n_units < kwargs['min_units']:
        continue

    # bin spikes from this region for each trial
    msub_binned, cl_inds_used = preprocess_ephys(reg_clu_ids, neural_dict, trials_df, **kwargs)
    cl_uuids_used = list(neural_dict['clu_df'].iloc[cl_inds_used]['uuids'])

    # make design matrix
    bins_per_trial = msub_binned[0].shape[0]
    Xs = (
        msub_binned if bins_per_trial == 1
        else [build_predictor_matrix(s, kwargs['n_bins_lag']) for s in msub_binned]
    )
    
    control_mask = mask
    save_predictions = kwargs['save_predictions']

    # original session
    ys_wmask = [target_vals_list[m] for m in np.squeeze(np.where(mask))]
    Xs_wmask = [Xs[m] for m in np.squeeze(np.where(mask))]
    
    fit_results = []
    for i_run in range(kwargs['n_runs']):

        rng_seed = i_run

        fit_result = decode_cv(
            ys=ys_wmask,
            Xs=Xs_wmask,
            estimator=kwargs['estimator'],
            use_openturns=kwargs['use_openturns'],
            target_distribution=target_distribution,
            balanced_continuous_target=kwargs['balanced_continuous_target'],
            estimator_kwargs=kwargs['estimator_kwargs'],
            hyperparam_grid=kwargs['hyperparam_grid'],
            save_binned=kwargs['save_binned'] if pseudo_id == -1 else False,
            save_predictions=save_predictions,
            shuffle=kwargs['shuffle'],
            balanced_weight=kwargs['balanced_weight'],
            rng_seed=rng_seed,
        )
        fit_result['mask'] = mask
        fit_result['mask_trials_and_targets'] = [trials_mask, target_mask]
        fit_result['mask_diagnostics'] = kwargs['trials_mask_diagnostics']
        fit_result['df'] = trials_df if pseudo_id == -1 else controlsess_df
        fit_result['pseudo_id'] = pseudo_id
        fit_result['run_id'] = i_run
        fit_result['cluster_uuids'] = cl_uuids_used
        fit_results.append(fit_result)
        
    region_results.update({region[0]: fit_results[0]['acc_test_full']})
        
print(f'Finished eid: {metadata["eid"]}')

                                                                                

Finished eid: 56956777-dca5-468c-87cb-78150432cc57




In [12]:
# no time binning
region_results

{'BMA': 0.5836734693877551,
 'CA1': 0.4489795918367347,
 'CA3': 0.5673469387755102,
 'CEA': 0.5755102040816327,
 'COAp': 0.5306122448979592,
 'GPe': 0.6612244897959184,
 'IA': 0.5877551020408164,
 'LGd': 0.7346938775510204,
 'PA': 0.5673469387755102,
 'SSp-bfd': 0.5959183673469388,
 'SSp-tr': 0.636734693877551,
 'VISa': 0.5795918367346938,
 'VPM': 0.563265306122449,
 'root': 0.6571428571428571}

In [12]:
# 0.02 time bin
region_results

{'BMA': 0.5673469387755102,
 'CA1': 0.5591836734693878,
 'CA3': 0.5469387755102041,
 'CEA': 0.5224489795918368,
 'COAp': 0.5142857142857142,
 'GPe': 0.6204081632653061,
 'IA': 0.5551020408163265,
 'LGd': 0.6693877551020408,
 'PA': 0.5755102040816327,
 'SSp-bfd': 0.6285714285714286,
 'SSp-tr': 0.5959183673469388,
 'VISa': 0.5673469387755102,
 'VPM': 0.5877551020408164,
 'root': 0.5836734693877551}

### reduced rank model

In [61]:
Xs = np.array(Xs_wmask).transpose(0,-1,1)
ys = np.array(ys_wmask).astype(float).reshape(-1,1)
_, n_units, n_t_bins = Xs.shape

In [24]:
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import functional as F

In [141]:
class ReducedRank(nn.Module):

    def __init__(self, n_units, n_t_bins, rank):
        super().__init__()
        self.U = nn.Parameter(torch.randn(n_units, rank))
        self.V = nn.Parameter(torch.randn(rank, n_t_bins))
        self.b = nn.Parameter(torch.randn((1,)))
        self.double()

    def forward(self, x):
        batch_size = x.shape[0]
        Beta = torch.einsum("cr,rt->ct", self.U, self.V)
        out = torch.einsum("ct,kct->k", Beta, x)
        out += self.b.tile((batch_size,))
        out = out.reshape(-1,1)
        return out, self.U, self.V

class LightningReducedRank(pl.LightningModule):

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone    

    def cross_entropy_loss(self, preds, labels):
        return F.binary_cross_entropy_with_logits(preds, labels)

    def training_step(self, batch, batch_idx):
        x, y = batch
        out, U, V = self.backbone(x)
        loss = self.cross_entropy_loss(out, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out, U, V = self.backbone(x)
        loss = self.cross_entropy_loss(out, y)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-1)
        return optimizer

In [142]:
class NeuralDataset(Dataset):
    def __init__(self, neural_data, behavioral_data):
        self.neural_data = neural_data
        self.behavioral_data = behavioral_data
        
    def __len__(self):
        return len(self.neural_data)
    
    def __getitem__(self, idx):
        X, y = self.neural_data[idx], self.behavioral_data[idx]
        X = torch.from_numpy(X)
        y = torch.from_numpy(y)
        return X, y

In [143]:
# train
backbone = ReducedRank(n_units, n_t_bins, rank=2)
model = LightningReducedRank(backbone)
trainer = pl.Trainer(devices=1, accelerator="gpu", precision="16-mixed")

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [144]:
dataset = NeuralDataset(Xs, ys)
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.8, 0.1, 0.1])

In [145]:
train_dataloader = DataLoader(train_dataset, batch_size=8, num_workers=1)
val_dataloader = DataLoader(val_dataset, batch_size=8, num_workers=1)
train_dataloader = DataLoader(train_dataset, batch_size=8, num_workers=1)
test_dataloader = DataLoader(test_dataset, batch_size=8, num_workers=1)

In [146]:
trainer.fit(model, train_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type        | Params
-----------------------------------------
0 | backbone | ReducedRank | 93    
-----------------------------------------
93        Trainable params
0         Non-trainable params
93        Total params
0.000     Total estimated model params size (MB)


Epoch 251:   0%|                               | 0/25 [00:00<?, ?it/s, v_num=12]

Exception ignored in: <function _releaseLock at 0x7fc3c094ab90>
Traceback (most recent call last):
  File "/home/yizi/anaconda3/envs/ibl_bwm/lib/python3.10/logging/__init__.py", line 228, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


Epoch 261:  68%|██████████████▎      | 17/25 [00:00<00:00, 102.56it/s, v_num=12]

/home/yizi/anaconda3/envs/ibl_bwm/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
