In [1]:
import os
import sys
import random
import h5py
import pandas as pd
from pathlib import Path
import numpy as np
import torch

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split, KFold, StratifiedKFold

from side_info_decoding.utils import (
    set_seed, 
    load_data_from_pids, 
    sliding_window_over_trials
)

from one.api import ONE

# Fixed seed value for this notebook; it is handed to the project helper
# `set_seed` imported from side_info_decoding.utils.
seed = 666
# NOTE(review): presumably seeds the Python / NumPy / torch RNGs used below —
# confirm against the implementation in side_info_decoding.utils.
set_seed(seed)



In [2]:
# Configuration for the download-and-cache stage.

# Client for the public IBL Alyx database, created in remote mode.
one = ONE(
    base_url="https://openalyx.internationalbrainlab.org",
    mode="remote",
)

# Root directory of the cache; one sub-folder per region is created beneath it.
out_path = Path("/mnt/3TB/yizi/cached_ibl_data")

# Atlas acronyms used when searching for probe insertions.
regions = ["LP", "GRN"]

# Maximum number of insertions kept per region.
n_sess = 2

In [20]:
# Download and cache data.
#
# For every region in `regions`: look up at most `n_sess` probe insertions,
# load neural activity + choice and (separately) contrast through
# `load_data_from_pids`, then write one pickled dict per PID to
# `out_path/<region>/pid_<pid>.npy`.
#
# Keys of each saved dict:
#   'neural', 'choice'    - torch tensors built from the windowed arrays
#   'contrast'            - signed per-trial contrast (first column negated)
#   'contrast_mask'       - {abs contrast level: trial indices at that level}
#   'meta'                - n_trials / n_units / n_t_bins of the raw neural array
#   'neural_contrast',
#   'choice_contrast'     - 'neural' / 'choice' split by abs contrast level

for roi in regions:

    print("=================")
    print(f"Downloading data in region {roi} ..")

    pids = one.search_insertions(atlas_acronym=[roi], query_type='remote')
    pids = list(pids)[:n_sess]

    # load choice
    neural_dict, choice_dict = load_data_from_pids(
        pids,
        brain_region=roi.lower(),
        behavior="choice",
        data_type="all_ks",
        n_t_bins=40,
    )
    available_pids = list(neural_dict.keys())

    # load contrast
    # NOTE(review): this call uses data_type="good_ks" while the neural data
    # above uses "all_ks"; only the behavior output is kept here — confirm the
    # mismatch is intentional.
    _, contrast_dict = load_data_from_pids(
        pids,
        brain_region=roi.lower(),
        behavior="contrast",
        data_type="good_ks",
        n_t_bins=40,
    )

    print("=================")
    print(f"Downloaded {len(available_pids)} PIDs in region {roi} ..")

    # Per-region cache directory (loop-invariant, so built once per region).
    path = out_path/roi

    for pid in available_pids:
        xs, ys = neural_dict[pid], choice_dict[pid]
        n_trials, n_units, n_t_bins = xs.shape
        # Skip insertions with too few units to be useful.
        if n_units < 5:
            continue
        # The two loader calls use different unit selections, so a PID present
        # in the neural data is not guaranteed to have contrast data.
        if pid not in contrast_dict:
            print(f"No contrast data for PID {pid}; skipping.")
            continue

        xs = sliding_window_over_trials(xs, half_window_size=0).squeeze()
        ys = sliding_window_over_trials(ys, half_window_size=0).squeeze()
        xs, ys = torch.tensor(xs), torch.tensor(ys)

        # Replace NaNs with 0. The fill value must be passed as `nan=`; the
        # second positional argument of np.nan_to_num is `copy`.
        contrast_dict[pid] = np.nan_to_num(contrast_dict[pid], nan=0.0)
        # Negate the first column so that summing across columns gives a signed
        # contrast (assumes column 0 holds one stimulus side — TODO confirm).
        contrast_dict[pid].T[0] *= -1
        contrast = contrast_dict[pid].sum(1)

        # Levels are absolute contrasts, so the comparison must also be made on
        # absolute values; comparing the signed contrast would silently drop
        # every trial whose signed contrast is negative.
        abs_contrast = np.abs(contrast)
        contrast_mask_dict = {
            lvl: np.flatnonzero(abs_contrast == lvl)
            for lvl in np.unique(abs_contrast)
        }

        # Split trials by contrast level. A level that cannot be indexed is
        # reported and skipped (best effort) instead of being swallowed.
        xs_per_lvl, ys_per_lvl = {}, {}
        for lvl, trial_idx in contrast_mask_dict.items():
            try:
                xs_lvl, ys_lvl = xs[trial_idx], ys[trial_idx]
            except (IndexError, RuntimeError, TypeError) as err:
                print(f"Skipping contrast level {lvl} for PID {pid}: {err}")
                continue
            xs_per_lvl[lvl] = xs_lvl
            ys_per_lvl[lvl] = ys_lvl

        data_dict = {
            'neural': xs,
            'choice': ys,
            'contrast': contrast,
            'contrast_mask': contrast_mask_dict,
            'meta': {
                "n_trials": n_trials, "n_units": n_units, "n_t_bins": n_t_bins
            },
            'neural_contrast': xs_per_lvl,
            'choice_contrast': ys_per_lvl,
        }

        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(path, exist_ok=True)
        # np.save pickles the dict into a 0-d object array; read it back with
        # np.load(..., allow_pickle=True).item().
        np.save(path/f"pid_{pid}.npy", data_dict)

    print("=================")
    print("Successfully cached all data!")

Downloading data in region LP ..




pulling data from ibl database ..
eid: ebce500b-c530-47de-8cb1-963c552703ea
pid: 8c732bf2-639d-496c-bf82-464bc9c2d54b
number of trials found: 470
found 470 trials from 13.74 to 5761.52 sec.
found 139 Kilosort units in region lp


Compute spike count: 100%|███████████████████| 470/470 [00:01<00:00, 348.18it/s]


pulling data from ibl database ..
eid: 15b69921-d471-4ded-8814-2adad954bcd8
pid: 7a620688-66cb-44d3-b79b-ccac1c8ba23e
number of trials found: 715
found 715 trials from 28.03 to 3547.82 sec.
found 47 Kilosort units in region lp


Compute spike count: 100%|██████████████████| 715/715 [00:00<00:00, 2813.38it/s]


pulling data from ibl database ..
eid: ebce500b-c530-47de-8cb1-963c552703ea
pid: 8c732bf2-639d-496c-bf82-464bc9c2d54b
number of trials found: 470
found 470 trials from 13.74 to 5761.52 sec.
found 34 good units in region lp


Compute spike count: 100%|██████████████████| 470/470 [00:00<00:00, 1690.27it/s]


pulling data from ibl database ..
eid: 15b69921-d471-4ded-8814-2adad954bcd8
pid: 7a620688-66cb-44d3-b79b-ccac1c8ba23e
number of trials found: 715
found 715 trials from 28.03 to 3547.82 sec.
found 0 good units in region lp


Compute spike count: 100%|█████████████████| 715/715 [00:00<00:00, 81293.78it/s]

Downloaded 2 PIDs in region LP ..
Successfully cached all data!
Downloading data in region GRN ..





pulling data from ibl database ..
eid: c958919c-2e75-435d-845d-5b62190b520e
pid: cc72fdb7-92e8-47e6-9cea-94f27c0da2d8
number of trials found: 705
found 705 trials from 79.14 to 3939.10 sec.
found 261 Kilosort units in region grn


Compute spike count: 100%|███████████████████| 705/705 [00:06<00:00, 103.83it/s]


pulling data from ibl database ..
eid: 32d27583-56aa-4510-bc03-669036edad20
pid: 2e720cee-05cc-440e-a24b-13794b1ac01d
number of trials found: 682
found 682 trials from 28.94 to 3431.07 sec.
found 81 Kilosort units in region grn


Compute spike count: 100%|███████████████████| 682/682 [00:02<00:00, 292.03it/s]


pulling data from ibl database ..
eid: c958919c-2e75-435d-845d-5b62190b520e
pid: cc72fdb7-92e8-47e6-9cea-94f27c0da2d8
number of trials found: 705
found 705 trials from 79.14 to 3939.10 sec.
found 12 good units in region grn


Compute spike count: 100%|██████████████████| 705/705 [00:00<00:00, 1487.99it/s]


pulling data from ibl database ..
eid: 32d27583-56aa-4510-bc03-669036edad20
pid: 2e720cee-05cc-440e-a24b-13794b1ac01d
number of trials found: 682
found 682 trials from 28.94 to 3431.07 sec.
found 6 good units in region grn


Compute spike count: 100%|██████████████████| 682/682 [00:00<00:00, 1503.17it/s]


Downloaded 2 PIDs in region GRN ..
Successfully cached all data!


In [21]:
# Configuration for the analysis stage.

# Root directory holding the per-region cache folders that are read back.
in_path = Path("/mnt/3TB/yizi/cached_ibl_data")

# Region acronyms to iterate over, and the per-region session count.
n_sess, regions = 2, ["LP", "GRN"]

In [22]:
# run hierarchical RRR model
#
# Read back every cached PID of every region. Results are collected in
# `cached_data[roi][pid]`; `data_dict` is left bound to the most recently
# loaded entry.

cached_data = {}

for roi in regions:

    # Only files written by the caching step ("pid_<pid>.npy") are considered,
    # so stray files in the directory cannot produce bogus PIDs.
    f_names = [
        f_name for f_name in os.listdir(in_path/roi)
        if f_name.startswith("pid_") and f_name.endswith(".npy")
    ]
    pids = [f_name[len("pid_"):-len(".npy")] for f_name in f_names]

    print("=================")
    print(f"Loading {len(pids)} PIDs in region {roi}:")

    cached_data[roi] = {}
    for pid in pids:
        print(pid)
        # The cache holds a pickled dict wrapped in a 0-d object array, so
        # .item() is needed to get the dict back.
        # SECURITY: allow_pickle=True can execute arbitrary code on load; only
        # use it on cache files produced by this notebook.
        data_dict = np.load(
            in_path/roi/f"pid_{pid}.npy", allow_pickle=True
        ).item()
        cached_data[roi][pid] = data_dict
    
    

Loading 2 PIDs in region LP:
7a620688-66cb-44d3-b79b-ccac1c8ba23e
8c732bf2-639d-496c-bf82-464bc9c2d54b
Loading 2 PIDs in region GRN:
cc72fdb7-92e8-47e6-9cea-94f27c0da2d8
2e720cee-05cc-440e-a24b-13794b1ac01d
