In [1]:
import sys
import numpy as np
import pandas as pd

from one.api import ONE
from brainbox.io.one import SessionLoader

from brainwidemap.bwm_loading import load_good_units, load_all_units, load_trials_and_mask, merge_probes
from brainwidemap.decoding.functions.decoding import fit_eid
from brainwidemap.decoding.functions.process_targets import load_behavior
from brainwidemap.decoding.settings_for_BWM_figure.settings_choice import params
from brainwidemap.decoding.settings_for_BWM_figure.settings_choice import RESULTS_DIR
# from brainwidemap.decoding.settings_for_BWM_figure.settings_wheel_speed import params
# from brainwidemap.decoding.settings_for_BWM_figure.settings_wheel_speed import RESULTS_DIR



In [2]:
# Establish output directories (created on demand) and the file names used by this analysis
decoding_dir = RESULTS_DIR.joinpath('decoding')
for fit_key, fit_subdir in (('behfit_path', 'behavioral'), ('neuralfit_path', 'neural')):
    params[fit_key] = decoding_dir.joinpath('results', fit_subdir)
    params[fit_key].mkdir(parents=True, exist_ok=True)

# File-name suffix encoding 1000 * binsize (ms if binsize is in seconds), lags and probe merging
params['add_to_saving_path'] = (
    f"_binsize={1000 * params['binsize']}"
    f"_lags={params['n_bins_lag']}"
    f"_mergedProbes_{params['merged_probes']}"
)
imposter_file = decoding_dir.joinpath(f"imposterSessions_{params['target']}.pqt")
bwm_session_file = decoding_dir.joinpath('bwm_cache_sessions.pqt')

In [3]:
# Override the bin size from the settings file (presumably seconds, see the ×1000 in the
# file-name suffix). NOTE: params['add_to_saving_path'] was built from the previous value
# in the cell above and is not updated here.
BINSIZE = 0.1
params['binsize'] = BINSIZE

In [4]:
# Sanity check: confirm the bin-size override above took effect
params["binsize"]

0.1

In [5]:
# Cached table of BWM sessions / probe insertions (eid, subject, pid, probe_name, ...),
# then a remote-only ONE client pointed at the public Alyx server
bwm_df = pd.read_parquet(bwm_session_file)
one = ONE(mode='remote', base_url="https://openalyx.internationalbrainlab.org")

In [6]:
# Optional feature for script runs: restrict the BWM dataset to the session ids given on the
# command line. To use it, add the ids to the end of the line that calls this script in
# 03_slurm*.sh (see the commented-out example there, or the `03_*` section of the README).
# Inside a Jupyter kernel, sys.argv holds the kernel's own launch arguments (typically
# `-f <connection file>`). Those must not be mistaken for session ids, so skip the feature there.
if len(sys.argv) > 2 and 'ipykernel' not in sys.modules:
    print('using a subset of bwm dataset')
    # mysubs = sys.argv[2:]
    # bwm_df = bwm_df[bwm_df["subject"].isin(mysubs)]
    myeids = sys.argv[2:]
    # NOTE(review): the eid filter below is disabled, so `myeids` is currently unused
    # bwm_df = bwm_df[bwm_df["eid"].isin(myeids)]

using a subset of bwm dataset


In [7]:
# Example probe insertion and the session it belongs to.
# NOTE: `eid` is reassigned from `bwm_df` in the session-selection cell below, so this lookup
# does not determine which session is decoded; `pname` is not used afterwards.
pid = '3675290c-8134-4598-b924-83edb7940269'
eid, pname = one.pid2eid(pid)

In [8]:
# Position of the item to decode: among the unique eids of bwm_df when probes are merged,
# otherwise among the rows (single insertions) of bwm_df
idx: int = 1

In [9]:
if not params['merged_probes']:
    # single insertion: `idx` is a row position in bwm_df
    probe_row = bwm_df.iloc[idx]
    eid = probe_row['eid']
    subject = probe_row['subject']
    pid = probe_row['pid']
    probe_name = probe_row['probe_name']
    print(f"Running probe pid: {pid}")
else:
    # whole session: `idx` is a position among the unique eids; gather every probe it recorded
    eid = bwm_df['eid'].unique()[idx]
    session_probes_df = bwm_df.set_index(['eid', 'subject']).xs(eid, level='eid')
    subject = session_probes_df.index[0]
    pids = session_probes_df['pid'].to_list()
    probe_names = session_probes_df['probe_name'].to_list()
    print(f"Running merged probes for session eid: {eid}")

Running merged probes for session eid: 56956777-dca5-468c-87cb-78150432cc57


In [10]:
# No pseudo sessions: decode the real session only. With n_pseudo == 0 the imposter-session
# table is also not loaded in the next cell.
params.update(n_pseudo=0)

In [11]:
# Load the trials table and the trial-inclusion masks used for decoding and diagnostics
sess_loader = SessionLoader(one, eid)
sess_loader.load_trials()


def _load_trials_mask(**overrides):
    """Call `load_trials_and_mask` for this session with the decoding trial criteria.

    The default criteria come from `params`; keyword arguments replace individual criteria
    (e.g. ``min_rt=None``), which is how the diagnostic masks below are built.
    Returns the ``(trials_df, trials_mask)`` pair produced by `load_trials_and_mask`.
    """
    criteria = dict(
        min_rt=params['min_rt'], max_rt=params['max_rt'],
        min_trial_len=params['min_len'], max_trial_len=params['max_len'],
        exclude_nochoice=True, exclude_unbiased=params['exclude_unbiased_trials'])
    criteria.update(overrides)
    return load_trials_and_mask(one=one, eid=eid, sess_loader=sess_loader, **criteria)


# mask actually used for decoding
trials_df, trials_mask = _load_trials_mask()
# diagnostic masks, each relaxing some of the criteria
_, trials_mask_without_minrt = _load_trials_mask(min_rt=None)
_, trials_mask_without_maxrt = _load_trials_mask(max_rt=None)
_, trials_mask_withonly_nochoice = _load_trials_mask(
    min_rt=None, max_rt=None, min_trial_len=None, max_trial_len=None, exclude_unbiased=False)

params['trials_mask_diagnostics'] = [trials_mask,
                                     trials_mask_without_minrt,
                                     trials_mask_without_maxrt,
                                     trials_mask_withonly_nochoice]

# load target data if necessary (will probably put this into a function eventually)
if params['target'] in ['wheel-vel', 'wheel-speed', 'l-whisker-me', 'r-whisker-me']:
    # wheel / whisker targets are loaded through the session loader
    dlc_dict = load_behavior(params['target'], sess_loader)
    # imposter sessions are only needed when pseudo sessions are requested
    params['imposter_df'] = pd.read_parquet(imposter_file) if params['n_pseudo'] > 0 else None
else:
    dlc_dict = None
    params['imposter_df'] = None

# Load spike sorting data
# NOTE(review): all units are loaded (load_all_units); the commented-out alternative is
# load_good_units. Confirm which selection this analysis should use.
if params['merged_probes']:
    clusters_list = []
    spikes_list = []
    # distinct loop names so the single-probe globals `pid` / `probe_name` are not clobbered
    for probe_pid, probe_pname in zip(pids, probe_names):
        # tmp_spikes, tmp_clusters = load_good_units(one, probe_pid, eid=eid, pname=probe_pname)
        tmp_spikes, tmp_clusters = load_all_units(one, probe_pid, eid=eid, pname=probe_pname)
        tmp_clusters['pid'] = probe_pid
        spikes_list.append(tmp_spikes)
        clusters_list.append(tmp_clusters)
    spikes, clusters = merge_probes(spikes_list, clusters_list)
else:
    # spikes, clusters = load_good_units(one, pid, eid=eid, pname=probe_name)
    spikes, clusters = load_all_units(one, pid, eid=eid, pname=probe_name)

# Put everything into the input format fit_eid still expects at this point
neural_dict = {
    'spk_times': spikes['times'],
    'spk_clu': spikes['clusters'],
    'clu_regions': clusters['acronym'],
    'clu_qc': {k: np.asarray(v) for k, v in clusters.to_dict('list').items()},
    'clu_df': clusters
}

metadata = {
    'subject': subject,
    'eid': eid,
    # with merged probes there is no single probe; label the session accordingly instead of
    # inheriting whichever probe name happens to be bound last
    'probe_name': 'merged_probes' if params['merged_probes'] else probe_name,
}

In [13]:
# `kwargs` is an alias of `params` (the same dict object, not a copy): the cells below read
# their settings from `kwargs`, and any write to `kwargs` also modifies `params`.
kwargs = params

In [14]:
# Pseudo-session ids to process: only -1, i.e. the real session
pseudo_ids = np.array([-1], dtype=np.int64)

In [15]:
import os
import pandas as pd
from sklearn import linear_model as sklm
from sklearn.metrics import accuracy_score, balanced_accuracy_score, r2_score
from sklearn.model_selection import KFold, train_test_split
from tqdm import tqdm
from behavior_models.utils import format_data as format_data_mut
from behavior_models.utils import format_input as format_input_mut

from iblatlas.regions import BrainRegions

from brainwidemap.decoding.functions.balancedweightings import balanced_weighting
from brainwidemap.decoding.functions.process_inputs import build_predictor_matrix
from brainwidemap.decoding.functions.process_inputs import select_ephys_regions
from brainwidemap.decoding.functions.process_inputs import preprocess_ephys
from brainwidemap.decoding.functions.process_targets import compute_beh_target
from brainwidemap.decoding.functions.process_targets import compute_target_mask
from brainwidemap.decoding.functions.process_targets import transform_data_for_decoding
from brainwidemap.decoding.functions.process_targets import logisticreg_criteria
from brainwidemap.decoding.functions.process_targets import get_target_data_per_trial_wrapper
from brainwidemap.decoding.functions.utils import save_region_results
from brainwidemap.decoding.functions.utils import get_save_path
from brainwidemap.decoding.functions.nulldistributions import generate_null_distribution_session
from brainwidemap.decoding.functions.process_targets import check_bhv_fit_exists
from brainwidemap.decoding.functions.process_targets import optimal_Bayesian

from brainwidemap.decoding.functions.decoding import decode_cv

In [16]:
print(f'Working on eid: {metadata["eid"]}')
# collects paths to the decoding results saved for this eid
filenames = []

# Without stitching for imposter sessions, keep only the first N trials (N from the settings).
# NOTE(review): trials_mask was computed on the full table and is not truncated here; confirm
# its length still matches target_mask when the two are combined below.
if kwargs['use_imposter_session'] and not kwargs['stitching_for_imposter_session']:
    max_trials = int(kwargs['max_number_trials_when_no_stitching_for_imposter_session'])
    trials_df = trials_df[:max_trials]

# pseudo ids: -1 is the real session, strictly positive values are pseudo sessions
if 0 in pseudo_ids:
    raise ValueError(
        'pseudo id can be -1 (actual session) or strictly greater than 0 (pseudo session)')
if np.any(np.sort(pseudo_ids) != pseudo_ids):
    raise ValueError('pseudo_ids must be sorted')

# the optimal Bayesian model needs the first 90 trials to be unbiased (probabilityLeft == 0.5)
if kwargs['model'] == optimal_Bayesian and np.any(trials_df.probabilityLeft.values[:90] != 0.5):
    raise ValueError(
        'The optimal Bayesian model assumes 90 unbiased trials at the beginning of the '
        'session, which is not the case here.')

Working on eid: 56956777-dca5-468c-87cb-78150432cc57


In [17]:
# Behavioral models are trained on the sessions listed in metadata['eids_train']. Only training
# on the decoded session itself is supported, so a caller-provided list must equal [eid].
default_eids_train = [metadata['eid']]
if 'eids_train' not in metadata:
    metadata['eids_train'] = default_eids_train
eids_train = metadata['eids_train']
# Compare against the session's own eid (not against itself) so that a custom list is actually
# rejected; list() also accepts tuples / arrays.
if list(eids_train) != default_eids_train:
    raise ValueError(
        'eids_train are not supported yet. If you do not understand this error, '
        'just take out the eids_train key in the metadata to solve it')

In [18]:
if isinstance(kwargs['model'], str):
    # `model` names a file of inter-individual model fits: use this session's winning model
    import pickle
    from braindelphi.params import INTER_INDIVIDUAL_PATH
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
    # `with` closes the file handle instead of leaving it to the garbage collector.
    with open(INTER_INDIVIDUAL_PATH.joinpath(kwargs['model']), 'rb') as fit_file:
        inter_individual = pickle.load(fit_file)
    # Raise rather than log-and-continue: a notebook cannot return early, and execution would
    # otherwise go on with a missing or unusable model.
    if metadata['eid'] not in inter_individual:
        raise KeyError(f"no inter individual model found for eid {metadata['eid']}")
    inter_indiv_model_specifications = inter_individual[metadata['eid']]
    print('winning interindividual model is %s' % inter_indiv_model_specifications['model_name'])
    if inter_indiv_model_specifications['model_name'] not in kwargs['modeldispatcher'].values():
        raise ValueError('winning inter individual model is LeftKernel or RightKernel')
    # modeldispatcher maps model -> name; invert it to recover the model from its name
    model_by_name = {v: k for k, v in kwargs['modeldispatcher'].items()}
    kwargs['model'] = model_by_name[inter_indiv_model_specifications['model_name']]
    kwargs['model_parameters'] = inter_indiv_model_specifications['model_parameters']
else:
    kwargs['model_parameters'] = None
    # train the behavioral model on eids_train unless a fit already exists
    if kwargs['model'] != optimal_Bayesian and kwargs['model'] is not None:
        side, stim, act, _ = format_data_mut(trials_df)
        stimuli, actions, stim_side = format_input_mut([stim], [act], [side])
        behmodel = kwargs['model'](
            kwargs['behfit_path'], np.array(metadata['eids_train']), metadata['subject'],
            actions, stimuli, trials_df, stim_side, single_zeta=True)
        istrained, _ = check_bhv_fit_exists(
            metadata['subject'], kwargs['model'], metadata['eids_train'],
            kwargs['behfit_path'], modeldispatcher=kwargs['modeldispatcher'], single_zeta=True)
        if not istrained:
            behmodel.load_or_train(remove_old=False)

# balanced weighting of continuous targets is not supported in this version
if kwargs['balanced_weight'] and kwargs['balanced_continuous_target']:
    raise NotImplementedError("see tag `decoding_biasCWnull` for a previous implementation.")
else:
    target_distribution = None

[36m2023-11-06 18:16:59.878 INFO     [base_models.py:  289]   results found and loaded from /mnt/3TB/yizi/decode-paper-brain-wide-map/decoding/results/behavioral/NYU-11/model_actKernel_single_zeta/train_56956777.pkl[0m


In [19]:
# Per-trial target values plus a mask of trials whose target is usable
behavioral_targets = ['pLeft', 'signcont', 'strengthcont', 'choice', 'feedback']
if kwargs['target'] not in behavioral_targets:
    # targets loaded earlier into dlc_dict, aligned to `align_time` and binned per trial
    dlc_data_missing = (
        dlc_dict is None or dlc_dict['times'] is None or dlc_dict['values'] is None)
    if dlc_data_missing:
        raise ValueError('dlc_dict does not contain any data')
    _, target_vals_list, target_mask = get_target_data_per_trial_wrapper(
        target_times=dlc_dict['times'],
        target_vals=dlc_dict['values'],
        trials_df=trials_df,
        align_event=kwargs['align_time'],
        align_interval=kwargs['time_window'],
        binsize=kwargs['binsize'])
else:
    # trial-level targets computed from the trials table
    target_vals_list, target_vals_to_mask = compute_beh_target(
        trials_df, metadata, return_raw=True, **kwargs)
    target_mask = compute_target_mask(
        target_vals_to_mask, kwargs['exclude_trials_within_values'])

In [20]:
# Trials used for decoding must pass both the trial-level criteria and the target mask
mask = trials_mask & target_mask

# Stop if too few trials remain. Raise rather than log: a notebook cannot return early, and
# the cells below would otherwise decode an undersized session.
n_decodable_trials = int(np.sum(mask))
if n_decodable_trials <= kwargs['min_behav_trials']:
    raise ValueError(
        'session contains %i trials, not above the threshold of %i' % (
            n_decodable_trials, kwargs['min_behav_trials']))

In [21]:
# Map each cluster to its Beryl-atlas region, then decide what to loop over:
# one entry per region, or a single entry holding every region
brainreg = BrainRegions()
beryl_reg = brainreg.acronym2acronym(neural_dict['clu_regions'], mapping='Beryl')
unique_regions = np.unique(beryl_reg)
if kwargs['single_region']:
    regions = [[k] for k in unique_regions]
else:
    regions = [unique_regions]

In [22]:
# Notebook overrides of the decoding settings
kwargs['n_runs'] = 1  # a single decoding run per region
# n_bins_lag is passed to build_predictor_matrix below; 0 presumably means no lagged copies.
# Alternative (lag over the whole window):
# int((kwargs['time_window'][1] - kwargs['time_window'][0]) // kwargs['binsize'])
kwargs['n_bins_lag'] = 0

# `add_to_saving_path` was derived from the settings before these overrides (and before the
# bin-size override); rebuild it so any saved file name matches the configuration in use.
kwargs['add_to_saving_path'] = (
    f"_binsize={1000 * kwargs['binsize']}_lags={kwargs['n_bins_lag']}_"
    f"mergedProbes_{kwargs['merged_probes']}")

In [23]:
# Id of the session being decoded: -1 denotes the actual (non-pseudo) session
pseudo_id: int = -1

In [24]:
# Sanity check: only the real session (-1) is processed in this notebook
pseudo_ids

array([-1])

In [26]:
# Decode the target for each entry of `regions` (one Beryl region each, or all regions at once).
# Only the real session is handled here: the pseudo-session branch relies on a control-session
# dataframe (`controlsess_df`) that this notebook never builds.
if pseudo_id != -1:
    raise NotImplementedError('this notebook only decodes the real session (pseudo_id == -1)')

# Indices of the trials kept for decoding. They do not depend on the region, so compute them
# (and the masked targets) once. np.flatnonzero always returns a 1-D array, whereas
# np.squeeze(np.where(mask)) collapses to a non-iterable 0-d array when exactly one trial is kept.
kept_trials = np.flatnonzero(mask)
ys_wmask = [target_vals_list[m] for m in kept_trials]

region_results = {}
for region in tqdm(regions, desc='Region: ', leave=False):

    # pull spikes from this region out of the neural data
    reg_clu_ids = select_ephys_regions(neural_dict, beryl_reg, region, **kwargs)

    # skip region if there are not enough units
    n_units = len(reg_clu_ids)
    if n_units < kwargs['min_units']:
        continue

    # bin spikes from this region for each trial
    msub_binned, cl_inds_used = preprocess_ephys(reg_clu_ids, neural_dict, trials_df, **kwargs)
    cl_uuids_used = list(neural_dict['clu_df'].iloc[cl_inds_used]['uuids'])

    # make design matrix; build_predictor_matrix (with n_bins_lag) is only applied when a
    # trial has more than one bin
    bins_per_trial = msub_binned[0].shape[0]
    Xs = (
        msub_binned if bins_per_trial == 1
        else [build_predictor_matrix(s, kwargs['n_bins_lag']) for s in msub_binned]
    )
    Xs_wmask = [Xs[m] for m in kept_trials]

    fit_results = []
    for i_run in range(kwargs['n_runs']):
        fit_result = decode_cv(
            ys=ys_wmask,
            Xs=Xs_wmask,
            estimator=kwargs['estimator'],
            use_openturns=kwargs['use_openturns'],
            target_distribution=target_distribution,
            balanced_continuous_target=kwargs['balanced_continuous_target'],
            estimator_kwargs=kwargs['estimator_kwargs'],
            hyperparam_grid=kwargs['hyperparam_grid'],
            save_binned=kwargs['save_binned'],
            save_predictions=kwargs['save_predictions'],
            shuffle=kwargs['shuffle'],
            balanced_weight=kwargs['balanced_weight'],
            rng_seed=i_run,  # the run index is the seed, for reproducibility
        )
        fit_result['mask'] = mask
        fit_result['mask_trials_and_targets'] = [trials_mask, target_mask]
        fit_result['mask_diagnostics'] = kwargs['trials_mask_diagnostics']
        fit_result['df'] = trials_df
        fit_result['pseudo_id'] = pseudo_id
        fit_result['run_id'] = i_run
        fit_result['cluster_uuids'] = cl_uuids_used
        fit_results.append(fit_result)

    # Key results by region name; when all regions are decoded together their acronyms are
    # joined (region[0] alone would label the combined fit with its first region only).
    # NOTE(review): only the first run's accuracy is reported; average over runs if n_runs > 1.
    region_results['_'.join(map(str, region))] = fit_results[0]['acc_test_full']

print(f'Finished eid: {metadata["eid"]}')

                                                                                

Finished eid: 56956777-dca5-468c-87cb-78150432cc57




In [27]:
# Decoding accuracy per region, labelled "no time binning". Presumably this is from the run with
# binsize = 0.1 set at the top, with one bin spanning the whole decoding window. Confirm against
# params['time_window'].
region_results

{'BMA': 0.6244897959183674,
 'CA1': 0.6,
 'CA2': 0.5428571428571428,
 'CA3': 0.5918367346938775,
 'CEA': 0.6040816326530613,
 'COAp': 0.8081632653061225,
 'GPe': 0.7020408163265306,
 'IA': 0.6204081632653061,
 'LGd': 0.7510204081632653,
 'PA': 0.46530612244897956,
 'SSp-bfd': 0.6,
 'SSp-tr': 0.6285714285714286,
 'VISa': 0.5755102040816327,
 'VPM': 0.5755102040816327,
 'root': 0.7510204081632653,
 'void': 0.6204081632653061}

In [40]:
# Decoding accuracy per region from a later run with a bin size of 0.02.
# NOTE(review): the cells that produced this output (In [28]-[39]) are not in the notebook, and
# the bin size set at the top is 0.1. Restart & Run All therefore cannot reproduce this output;
# rerun with binsize = 0.02 to regenerate it.
region_results

{'BMA': 0.5142857142857142,
 'CA1': 0.5714285714285714,
 'CA2': 0.5755102040816327,
 'CA3': 0.5387755102040817,
 'CEA': 0.5387755102040817,
 'COAp': 0.8,
 'GPe': 0.6408163265306123,
 'IA': 0.5551020408163265,
 'LGd': 0.726530612244898,
 'PA': 0.5346938775510204,
 'SSp-bfd': 0.6,
 'SSp-tr': 0.6612244897959184,
 'VISa': 0.5591836734693878,
 'VPM': 0.6040816326530613,
 'root': 0.710204081632653,
 'void': 0.6081632653061224}