In [1]:
import os
import sys
from datetime import date
import random
import pandas as pd
from pathlib import Path
import numpy as np

from one.api import ONE

import matplotlib.pyplot as plt
from scipy.linalg import svd
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics.functional import accuracy
from torchmetrics import AUROC
import lightning as L 
from lightning.pytorch.utilities import CombinedLoader

from side_info_decoding.utils import (
    set_seed, 
    load_data_from_pids, 
    sliding_window_over_trials
)

# Fixed seed for this notebook; it is also passed as `random_state` to the
# StratifiedKFold splitter further down.
# NOTE(review): `set_seed` is a project helper -- presumably it seeds the
# python/numpy/torch RNGs; confirm in side_info_decoding.utils.
seed = 666
set_seed(seed)

  warn(f"Failed to load image Python extension: {e}")


In [2]:
from iblatlas.atlas import AllenAtlas

# Count the distinct atlas acronyms whose region entry sits at level 7.
ba = AllenAtlas()
is_level_7 = ba.regions.level == 7
regions = np.unique(ba.regions.acronym[is_level_7])
print(len(regions))

324


In [4]:
# setup
# NOTE(review): `regions` is reassigned by the cell that reads regions.txt;
# confirm which list the download loop is meant to iterate over.
regions = ["LP", "GRN"]
n_sess = 10  # cap on how many insertions (PIDs) are kept per region
out_path = Path("/mnt/3TB/yizi/cached_ibl_data")  # root directory for the cached per-region files
# Remote-mode client pointed at the public openalyx server.
one = ONE(base_url="https://openalyx.internationalbrainlab.org", mode='remote')

In [8]:
# One region acronym per line; trailing whitespace/newlines are stripped.
with open("../biorxiv_plots/regions.txt") as regions_file:
    regions = list(map(str.rstrip, regions_file))
print(len(regions))

76


In [11]:
# download and cache data
#
# For every region: look up insertions, load neural activity + choice and the
# stimulus contrast, split trials by absolute contrast level and save one
# pickled dict per PID under out_path/<region>/pid_<pid>.npy.

for roi in regions:

    print("=================")
    print(f"Downloading data in region {roi} ..")

    pids = one.search_insertions(atlas_acronym=[roi], query_type='remote')
    pids = list(pids)[:n_sess]

    # load choice
    neural_dict, choice_dict = load_data_from_pids(
        pids,
        brain_region=roi.lower(),
        behavior="choice",
        data_type="all_ks",
        n_t_bins = 40,
    )
    available_pids = list(neural_dict.keys())

    # load contrast
    # NOTE(review): this second call re-loads spikes with data_type="good_ks"
    # only to obtain the contrast values; confirm whether the behaviour output
    # depends on data_type or whether the first call could provide it.
    _, contrast_dict = load_data_from_pids(
        pids,
        brain_region=roi.lower(),
        behavior="contrast",
        data_type="good_ks",
        n_t_bins = 40,
    )

    print("=================")
    print(f"Downloaded {len(available_pids)} PIDs in region {roi} ..")

    for pid in available_pids:
        xs, ys = neural_dict[pid], choice_dict[pid]
        n_trials, n_units, n_t_bins = xs.shape
        if n_units < 5:
            # too few units to be worth caching
            continue
        if pid not in contrast_dict:
            # the contrast load is a separate call and may not cover every PID
            print(f"Skipping PID {pid}: no contrast data was loaded")
            continue
        xs = sliding_window_over_trials(xs, half_window_size=0).squeeze()
        ys = sliding_window_over_trials(ys, half_window_size=0).squeeze()
        xs, ys = torch.tensor(xs), torch.tensor(ys)

        # Replace NaNs by 0 (the fill value is passed by keyword: the second
        # positional argument of np.nan_to_num is `copy`, not the fill value).
        contrast_dict[pid] = np.nan_to_num(contrast_dict[pid], nan=0.0)
        # Flip the sign of the first column, then collapse the columns into a
        # single signed contrast per trial.
        # NOTE(review): assumes a 2-D (n_trials, 2) array -- confirm.
        contrast_dict[pid].T[0] *= -1
        contrast = contrast_dict[pid].sum(1)

        # INCORPORATE CHANGE HERE!
        # trial indices for each absolute contrast level
        abs_contrast = np.abs(contrast)
        contrast_mask_dict = {
            lvl: np.argwhere(abs_contrast == lvl).flatten()
            for lvl in np.unique(abs_contrast)
        }

        path = out_path/roi
        path.mkdir(parents=True, exist_ok=True)

        # "all" holds every trial; the remaining keys hold per-level subsets
        xs_per_lvl, ys_per_lvl = {"all": xs}, {"all": ys}
        for lvl, trial_idx in contrast_mask_dict.items():
            try:
                xs_lvl, ys_lvl = xs[trial_idx], ys[trial_idx]
            except Exception as err:
                # best effort: skip this level, but say so instead of hiding it
                print(f"Skipping contrast {lvl} for PID {pid}: {err!r}")
                continue
            xs_per_lvl[lvl] = xs_lvl
            ys_per_lvl[lvl] = ys_lvl

        data_dict = {
            'contrast': contrast,
            'contrast_mask': contrast_mask_dict,
            'meta': {"n_trials": n_trials, "n_units": n_units, "n_t_bins": n_t_bins},
            'neural_contrast': xs_per_lvl,
            'choice_contrast': ys_per_lvl,
        }
        # np.save pickles the dict; readers need allow_pickle=True and .item()
        np.save(path/f"pid_{pid}.npy", data_dict)

    print("=================")
    print("Successfully cached all data!")

Downloading data in region LP ..
pulling data from ibl database ..
eid: ebce500b-c530-47de-8cb1-963c552703ea
pid: 8c732bf2-639d-496c-bf82-464bc9c2d54b
number of trials found: 470
found 470 trials from 13.74 to 5761.52 sec.
found 139 Kilosort units in region lp


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 470/470 [00:01<00:00, 345.37it/s]


pulling data from ibl database ..
eid: 15b69921-d471-4ded-8814-2adad954bcd8
pid: 7a620688-66cb-44d3-b79b-ccac1c8ba23e
number of trials found: 715
found 715 trials from 28.03 to 3547.82 sec.
found 47 Kilosort units in region lp


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 715/715 [00:00<00:00, 2769.69it/s]


pulling data from ibl database ..
eid: caa5dddc-9290-4e27-9f5e-575ba3598614
pid: d0046384-16ea-4f69-bae9-165e8d0aeacf
number of trials found: 358
found 358 trials from 108.84 to 2048.28 sec.
found 27 Kilosort units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 358/358 [00:00<00:00, 16097.18it/s]
  data_norm[:,t*n_units:(t+1)*n_units] = (data_norm[:,t*n_units:(t+1)*n_units] - mean_per_trial) / std_per_trial


pulling data from ibl database ..
eid: 642c97ea-fe89-4ec9-8629-5e492ea4019d
pid: b72b22c2-6e9d-4604-9910-20c0e1a467d7
number of trials found: 432
found 432 trials from 49.33 to 2759.22 sec.
found 22 Kilosort units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 432/432 [00:00<00:00, 10992.38it/s]


pulling data from ibl database ..
eid: f4eb56a4-8bf8-4bbc-a8f3-6e6535134bad
pid: bef05a5c-68c3-4513-87c7-b3151c88da8e
number of trials found: 489
found 489 trials from 101.15 to 3448.76 sec.
found 51 Kilosort units in region lp


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 489/489 [00:00<00:00, 5664.61it/s]


pulling data from ibl database ..
eid: 37ac03f1-9831-4a30-90fc-a59e635b98bd
pid: 8b31b4bd-003e-4816-a3bf-2df4cc3558f8
number of trials found: 432
found 432 trials from 65.86 to 2896.15 sec.
found 96 Kilosort units in region lp


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 432/432 [00:00<00:00, 1605.63it/s]


pulling data from ibl database ..
eid: 94dabed1-741c-4ddd-a6b7-70561e27b750
pid: ec2fbc3e-cb2b-48cb-a521-3a6ca15e244c
number of trials found: 552
found 552 trials from 126.96 to 2975.72 sec.
found 46 Kilosort units in region lp


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 552/552 [00:00<00:00, 1582.83it/s]


pulling data from ibl database ..
eid: 49250fba-801c-4867-a0a7-a1e19538cb61
pid: a6b71993-165b-4c43-845c-c062fe7d7a11
number of trials found: 592
found 592 trials from 60.40 to 2998.35 sec.
found 33 Kilosort units in region lp


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 592/592 [00:00<00:00, 4300.67it/s]


pulling data from ibl database ..
eid: a34b4013-414b-42ed-9318-e93fbbc71e7b
pid: 22f26d69-0b30-450e-9618-ee801b720e0a
number of trials found: 643
found 643 trials from 58.07 to 3454.00 sec.
found 43 Kilosort units in region lp


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 643/643 [00:00<00:00, 2470.49it/s]


pulling data from ibl database ..
eid: 2038e95d-64d4-4ecb-83d0-1308d3c598f8
pid: 1a924329-65aa-465d-b201-c2dd898aebd0
number of trials found: 481
found 481 trials from 61.03 to 2940.03 sec.
found 98 Kilosort units in region lp


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 481/481 [00:00<00:00, 776.25it/s]


pulling data from ibl database ..
eid: ebce500b-c530-47de-8cb1-963c552703ea
pid: 8c732bf2-639d-496c-bf82-464bc9c2d54b
number of trials found: 470
found 470 trials from 13.74 to 5761.52 sec.
found 34 good units in region lp


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 470/470 [00:00<00:00, 1632.12it/s]


pulling data from ibl database ..
eid: 15b69921-d471-4ded-8814-2adad954bcd8
pid: 7a620688-66cb-44d3-b79b-ccac1c8ba23e
number of trials found: 715
found 715 trials from 28.03 to 3547.82 sec.
found 0 good units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 715/715 [00:00<00:00, 86639.14it/s]


pulling data from ibl database ..
eid: caa5dddc-9290-4e27-9f5e-575ba3598614
pid: d0046384-16ea-4f69-bae9-165e8d0aeacf
number of trials found: 358
found 358 trials from 108.84 to 2048.28 sec.
found 2 good units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 358/358 [00:00<00:00, 24264.90it/s]
  data_norm[:,t*n_units:(t+1)*n_units] = (data_norm[:,t*n_units:(t+1)*n_units] - mean_per_trial) / std_per_trial


pulling data from ibl database ..
eid: 642c97ea-fe89-4ec9-8629-5e492ea4019d
pid: b72b22c2-6e9d-4604-9910-20c0e1a467d7
number of trials found: 432
found 432 trials from 49.33 to 2759.22 sec.
found 1 good units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 432/432 [00:00<00:00, 57580.38it/s]
  data_norm[:,t*n_units:(t+1)*n_units] = (data_norm[:,t*n_units:(t+1)*n_units] - mean_per_trial) / std_per_trial


pulling data from ibl database ..
eid: f4eb56a4-8bf8-4bbc-a8f3-6e6535134bad
pid: bef05a5c-68c3-4513-87c7-b3151c88da8e
number of trials found: 489
found 489 trials from 101.15 to 3448.76 sec.
found 2 good units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 489/489 [00:00<00:00, 83289.94it/s]
  data_norm[:,t*n_units:(t+1)*n_units] = (data_norm[:,t*n_units:(t+1)*n_units] - mean_per_trial) / std_per_trial


pulling data from ibl database ..
eid: 37ac03f1-9831-4a30-90fc-a59e635b98bd
pid: 8b31b4bd-003e-4816-a3bf-2df4cc3558f8
number of trials found: 432
found 432 trials from 65.86 to 2896.15 sec.
found 0 good units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 432/432 [00:00<00:00, 86360.96it/s]


pulling data from ibl database ..
eid: 94dabed1-741c-4ddd-a6b7-70561e27b750
pid: ec2fbc3e-cb2b-48cb-a521-3a6ca15e244c
number of trials found: 552
found 552 trials from 126.96 to 2975.72 sec.
found 0 good units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 552/552 [00:00<00:00, 43591.95it/s]


pulling data from ibl database ..
eid: 49250fba-801c-4867-a0a7-a1e19538cb61
pid: a6b71993-165b-4c43-845c-c062fe7d7a11
number of trials found: 592
found 592 trials from 60.40 to 2998.35 sec.
found 0 good units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 592/592 [00:00<00:00, 88128.77it/s]


pulling data from ibl database ..
eid: a34b4013-414b-42ed-9318-e93fbbc71e7b
pid: 22f26d69-0b30-450e-9618-ee801b720e0a
number of trials found: 643
found 643 trials from 58.07 to 3454.00 sec.
found 2 good units in region lp


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 643/643 [00:00<00:00, 21218.02it/s]
  data_norm[:,t*n_units:(t+1)*n_units] = (data_norm[:,t*n_units:(t+1)*n_units] - mean_per_trial) / std_per_trial


pulling data from ibl database ..
eid: 2038e95d-64d4-4ecb-83d0-1308d3c598f8
pid: 1a924329-65aa-465d-b201-c2dd898aebd0
number of trials found: 481
found 481 trials from 61.03 to 2940.03 sec.
found 9 good units in region lp


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 481/481 [00:00<00:00, 4599.74it/s]


Downloaded 10 PIDs in region LP ..
Successfully cached all data!
Downloading data in region GRN ..
pulling data from ibl database ..
eid: c958919c-2e75-435d-845d-5b62190b520e
pid: cc72fdb7-92e8-47e6-9cea-94f27c0da2d8
number of trials found: 705
found 705 trials from 79.14 to 3939.10 sec.
found 261 Kilosort units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 705/705 [00:06<00:00, 105.32it/s]


pulling data from ibl database ..
eid: 32d27583-56aa-4510-bc03-669036edad20
pid: 2e720cee-05cc-440e-a24b-13794b1ac01d
number of trials found: 682
found 682 trials from 28.94 to 3431.07 sec.
found 81 Kilosort units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 682/682 [00:02<00:00, 288.51it/s]


pulling data from ibl database ..
eid: 7cec9792-b8f9-4878-be7e-f08103dc0323
pid: e17db2b6-b778-4e2a-845c-c4d040b0c875
number of trials found: 542
found 542 trials from 307.48 to 2989.17 sec.
found 20 Kilosort units in region grn


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 542/542 [00:00<00:00, 1990.81it/s]


pulling data from ibl database ..
eid: aec5d3cc-4bb2-4349-80a9-0395b76f04e2
pid: 7332e6cf-9847-4aca-b2e3-d864989dd0fb
number of trials found: 581
found 581 trials from 29.78 to 2802.68 sec.
found 261 Kilosort units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 581/581 [00:05<00:00, 105.04it/s]


pulling data from ibl database ..
eid: 25d1920e-a2af-4b6c-9f2e-fc6c65576544
pid: c0e59477-43f0-4441-9f81-3a55ddad9dad
number of trials found: 358
found 358 trials from 9.42 to 2224.60 sec.
found 99 Kilosort units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 358/358 [00:00<00:00, 360.82it/s]


pulling data from ibl database ..
eid: 571d3ffe-54a5-473d-a265-5dc373eb7efc
pid: aecd7612-b5c5-4ad2-9e76-e5b783387e47
number of trials found: 359
found 359 trials from 14.38 to 2766.52 sec.
found 74 Kilosort units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [00:00<00:00, 834.69it/s]


pulling data from ibl database ..
eid: 75b6b132-d998-4fba-8482-961418ac957d
pid: 6a098711-5423-4072-8909-7cff0e2d4531
number of trials found: 403
found 403 trials from 49.05 to 3074.29 sec.
found 71 Kilosort units in region grn


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 403/403 [00:00<00:00, 2521.15it/s]


pulling data from ibl database ..
eid: 746d1902-fa59-4cab-b0aa-013be36060d5
pid: 39883ded-f5a2-4f4f-a98e-fb138eb8433e
number of trials found: 561
found 561 trials from 39.62 to 4058.39 sec.
found 132 Kilosort units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 561/561 [00:02<00:00, 207.72it/s]


pulling data from ibl database ..
eid: 671c7ea7-6726-4fbe-adeb-f89c2c8e489b
pid: 04c9890f-2276-4c20-854f-305ff5c9b6cf
number of trials found: 700
found 700 trials from 109.27 to 3491.16 sec.
found 248 Kilosort units in region grn


Compute spike count: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 700/700 [00:07<00:00, 96.39it/s]


pulling data from ibl database ..
eid: eebacd5a-7dcd-4ba6-9dff-ec2a4d2f19e0
pid: df6012d0-d921-4d0a-af2a-2a91030d0f42
number of trials found: 554
found 554 trials from 56.17 to 2721.08 sec.
found 106 Kilosort units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 554/554 [00:03<00:00, 183.39it/s]


pulling data from ibl database ..
eid: c958919c-2e75-435d-845d-5b62190b520e
pid: cc72fdb7-92e8-47e6-9cea-94f27c0da2d8
number of trials found: 705
found 705 trials from 79.14 to 3939.10 sec.
found 12 good units in region grn


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 705/705 [00:00<00:00, 1482.35it/s]


pulling data from ibl database ..
eid: 32d27583-56aa-4510-bc03-669036edad20
pid: 2e720cee-05cc-440e-a24b-13794b1ac01d
number of trials found: 682
found 682 trials from 28.94 to 3431.07 sec.
found 6 good units in region grn


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 682/682 [00:00<00:00, 1481.14it/s]


pulling data from ibl database ..
eid: 7cec9792-b8f9-4878-be7e-f08103dc0323
pid: e17db2b6-b778-4e2a-845c-c4d040b0c875
number of trials found: 542
found 542 trials from 307.48 to 2989.17 sec.
found 0 good units in region grn


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 542/542 [00:00<00:00, 88815.16it/s]


pulling data from ibl database ..
eid: aec5d3cc-4bb2-4349-80a9-0395b76f04e2
pid: 7332e6cf-9847-4aca-b2e3-d864989dd0fb
number of trials found: 581
found 581 trials from 29.78 to 2802.68 sec.
found 16 good units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 581/581 [00:00<00:00, 845.08it/s]


pulling data from ibl database ..
eid: 25d1920e-a2af-4b6c-9f2e-fc6c65576544
pid: c0e59477-43f0-4441-9f81-3a55ddad9dad
number of trials found: 358
found 358 trials from 9.42 to 2224.60 sec.
found 21 good units in region grn


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 358/358 [00:00<00:00, 1365.74it/s]


pulling data from ibl database ..
eid: 571d3ffe-54a5-473d-a265-5dc373eb7efc
pid: aecd7612-b5c5-4ad2-9e76-e5b783387e47
number of trials found: 359
found 359 trials from 14.38 to 2766.52 sec.
found 10 good units in region grn


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [00:00<00:00, 5521.63it/s]


pulling data from ibl database ..
eid: 75b6b132-d998-4fba-8482-961418ac957d
pid: 6a098711-5423-4072-8909-7cff0e2d4531
number of trials found: 403
found 403 trials from 49.05 to 3074.29 sec.
found 6 good units in region grn


Compute spike count: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 403/403 [00:00<00:00, 16809.75it/s]


pulling data from ibl database ..
eid: 746d1902-fa59-4cab-b0aa-013be36060d5
pid: 39883ded-f5a2-4f4f-a98e-fb138eb8433e
number of trials found: 561
found 561 trials from 39.62 to 4058.39 sec.
found 9 good units in region grn


Compute spike count: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 561/561 [00:00<00:00, 2333.04it/s]


pulling data from ibl database ..
eid: 671c7ea7-6726-4fbe-adeb-f89c2c8e489b
pid: 04c9890f-2276-4c20-854f-305ff5c9b6cf
number of trials found: 700
found 700 trials from 109.27 to 3491.16 sec.
found 24 good units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 700/700 [00:01<00:00, 587.54it/s]


pulling data from ibl database ..
eid: eebacd5a-7dcd-4ba6-9dff-ec2a4d2f19e0
pid: df6012d0-d921-4d0a-af2a-2a91030d0f42
number of trials found: 554
found 554 trials from 56.17 to 2721.08 sec.
found 20 good units in region grn


Compute spike count: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 554/554 [00:00<00:00, 783.15it/s]


Downloaded 10 PIDs in region GRN ..
Successfully cached all data!


In [79]:
class SessionDataset:
    """Map-style dataset over the trials of one recording session.

    Parameters
    ----------
    dataset : tuple
        ``(xs, ys)``. ``xs`` must expose a 3-element ``.shape`` whose first two
        entries are read as (n_trials, n_units); ``ys`` is indexed with the
        same trial index as ``xs``.
    roi_idx, pid_idx :
        Identifiers returned unchanged with every item.
    device : torch.device, optional
        Device the returned tensors are moved to. When omitted, the
        module-level ``DEVICE`` is looked up at access time.
    **kargs :
        Accepted and ignored.
    """
    def __init__(self, dataset, roi_idx, pid_idx, device=None, **kargs):
        self.xs, self.ys = dataset
        self.n_trials, self.n_units, _ = self.xs.shape
        self.roi_idx = roi_idx
        self.pid_idx = pid_idx
        self.device = device

    def __len__(self):
        """Number of trials in the session."""
        return self.n_trials

    def __getitem__(self, index):
        """Return ``(x, y, roi_idx, pid_idx)`` for trial ``index``."""
        device = DEVICE if self.device is None else self.device
        # as_tensor avoids the copy (and the UserWarning) that torch.tensor()
        # produces when the stored data already is a tensor.
        return (
            torch.as_tensor(self.xs[index]).to(device),
            torch.as_tensor(self.ys[index]).to(device),
            self.roi_idx,
            self.pid_idx,
        )
    
def dataloader(datasets, roi_idxs, pid_idxs, batch_size=32):
    """Build one DataLoader per session.

    ``datasets[i]`` is wrapped in a SessionDataset together with
    ``roi_idxs[i]`` and ``pid_idxs[i]``; the loaders are returned as a list in
    the same order as ``datasets``.
    """
    return [
        DataLoader(
            SessionDataset(session_data, roi_idxs[pos], pid_idxs[pos]),
            batch_size=batch_size,
        )
        for pos, session_data in enumerate(datasets)
    ]

class Hier_Reduced_Rank_Model(nn.Module):
    """Hierarchical reduced-rank model for per-trial binary prediction.

    Each session ``s`` owns a unit basis ``Us[s]`` (n_units[s] x rank_V) and an
    intercept. Each region ``r`` owns ``A[r]`` (rank_V x rank_B), and all
    regions share ``B`` (rank_B x n_t_bin). The per-session coefficient matrix
    is ``Us[s] @ (A[r] @ B)``.

    Parameters
    ----------
    n_roi : int
        Number of regions (first dimension of ``A``).
    n_units : sequence of int
        Number of units per session; its length defines the number of sessions.
    n_t_bin : int
        Number of time bins.
    rank_V, rank_B : int
        Inner ranks of the two factorizations.
    """
    def __init__(
        self,
        n_roi,
        n_units,
        n_t_bin,
        rank_V,
        rank_B
    ):
        super(Hier_Reduced_Rank_Model, self).__init__()

        self.n_roi = n_roi
        self.n_sess = len(n_units)
        self.n_units = n_units
        self.n_t_bin = n_t_bin
        self.rank_V = rank_V
        self.rank_B = rank_B

        self.Us = nn.ParameterList(
            [nn.Parameter(torch.randn(self.n_units[i], self.rank_V)) for i in range(self.n_sess)]
        )
        self.A = nn.Parameter(torch.randn(self.n_roi, self.rank_V, self.rank_B))
        self.B = nn.Parameter(
            torch.randn(self.rank_B, self.n_t_bin)
        )
        self.intercepts = nn.ParameterList(
            [nn.Parameter(torch.randn(1,)) for i in range(self.n_sess)]
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, datasets):
        """Predict probabilities for a list of per-session batches.

        Parameters
        ----------
        datasets : iterable of (X, Y, roi_idx, sess_idx)
            ``X`` is (n_trials, n_units, n_t_bins); ``roi_idx`` / ``sess_idx``
            are integer tensors that must each contain one unique value.

        Returns
        -------
        (pred_lst, gt_lst) : two lists with one entry per session batch --
        the sigmoid outputs of shape (n_trials,) and the ``Y`` that was passed in.
        """
        # Region-level bases do not depend on the batch: compute them once.
        # Kept as an attribute because it is read after training.
        self.Vs = torch.einsum("ijk,kt->ijt", self.A, self.B)

        pred_lst, gt_lst = [], []
        for X, Y, roi_idx, sess_idx in datasets:
            # Collapse the per-trial index tensors to plain ints; plain integer
            # indexing keeps the (rank_V, n_t_bin) shape even when rank_V == 1,
            # which an unqualified .squeeze() would destroy.
            roi = int(torch.unique(roi_idx))
            sess = int(torch.unique(sess_idx))
            # Follow the device of the incoming batch instead of a global.
            device = X.device
            Beta = torch.einsum("cr,rt->ct", self.Us[sess], self.Vs[roi]).to(device)
            out = torch.einsum("ct,kct->k", Beta, X)
            out = out + self.intercepts[sess].to(device)  # broadcast over trials
            pred_lst.append(self.sigmoid(out))
            gt_lst.append(Y)
        return pred_lst, gt_lst
    
class LitHierRRR(L.LightningModule):
    """Lightning wrapper around a multi-session model.

    ``model(batch)`` must return ``(pred_lst, gt_lst)`` with one entry per
    session in ``batch``.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch):
        """Mean binary cross-entropy across the sessions in ``batch``."""
        bce = nn.BCELoss()  # one criterion instance for all sessions
        pred_lst, gt_lst = self.model(batch)
        losses = 0
        for i in range(len(batch)):
            losses += bce(pred_lst[i], gt_lst[i])
        loss = losses / len(batch)
        self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch):
        """Log accuracy and AUC averaged over sessions."""
        accs, aucs = self._shared_eval_step(batch)
        # Average with torch: np.mean over a list of tensors cannot convert
        # CUDA tensors and would hand Lightning a non-tensor value.
        metrics = {
            "val_acc": torch.stack(accs).mean(),
            "val_auc": torch.stack(aucs).mean(),
        }
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch):
        """Print per-session accuracy and AUC."""
        accs, aucs = self._shared_eval_step(batch)
        for i in range(len(batch)):
            print(f"session {i} test_acc {accs[i]} test_auc {aucs[i]}")

    def _shared_eval_step(self, batch):
        """Return two lists (accuracy, AUROC) with one tensor per session."""
        pred_lst, gt_lst = self.model(batch)
        accs, aucs = [], []
        for i in range(len(batch)):
            auroc = AUROC(task="binary")
            acc = accuracy(pred_lst[i], gt_lst[i], task="binary")
            auc = auroc(pred_lst[i], gt_lst[i])
            accs.append(acc)
            aucs.append(auc)
        return accs, aucs

    def configure_optimizers(self):
        """Adam over all parameters, lr=1e-2, weight_decay=1e-3."""
        optimizer = optim.Adam(self.parameters(), lr=1e-2, weight_decay=1e-3)
        return optimizer

In [50]:
# setup
regions = ["LP", "GRN"]  # regions whose cached sessions are loaded for fitting
n_rank_V = 2  # passed as rank_V to Hier_Reduced_Rank_Model
n_rank_B = 2  # passed as rank_B to Hier_Reduced_Rank_Model
n_epochs = 1000  # max_epochs of the Lightning Trainer
in_path = Path("/mnt/3TB/yizi/cached_ibl_data")  # directory holding one sub-folder of cached files per region

In [None]:
# prep data for fitting hier-RRR on trials with diff contrast
#
# For every contrast level: load each cached session, fit one hierarchical
# model across all sessions/regions, then summarise the fitted coefficient
# matrices by their leading right singular vectors.
# NOTE: SessionDataset moves tensors to the module-level DEVICE, so DEVICE
# must be defined before this cell is run.

res_dict = {}
for lvl in ["all", .0625, .125, .25, 1.]:

    print("=================")
    print(f"Started training on trials with contrast {lvl} ..")

    lst_datasets, lst_units, lst_regions, lst_sessions, lst_region_names, lst_pids = [], [], [], [], [], []
    n_t_bins = None

    pid_idx = 0
    for roi_idx, roi in enumerate(regions):

        f_names = os.listdir(in_path/roi)
        pids = [f_name.split("_")[1].split(".")[0] for f_name in f_names]

        print("=================")
        print(f"Loading {len(pids)} PIDs in region {roi}:")

        for pid in pids:
            print(pid)
            # Load THIS session's file. Loading once outside the loop would
            # reuse whichever pid was iterated last for every session.
            data_dict = np.load(in_path/roi/f"pid_{pid}.npy", allow_pickle=True).item()
            if lvl not in data_dict["neural_contrast"]:
                print(f"  no trials with contrast {lvl} in {pid}, skipping")
                continue
            xs = data_dict["neural_contrast"][lvl]
            ys = data_dict["choice_contrast"][lvl]
            lst_datasets.append((xs, ys))
            lst_units.append(data_dict["meta"]["n_units"])
            lst_regions.append(roi_idx)
            lst_region_names.append(roi)
            lst_sessions.append(pid_idx)
            lst_pids.append(pid)
            n_t_bins = data_dict["meta"]["n_t_bins"]
            pid_idx += 1

    if not lst_datasets:
        print(f"No sessions available for contrast {lvl}, skipping ..")
        continue

    train_loaders = dataloader(lst_datasets, lst_regions, lst_sessions, batch_size=128)
    train_loaders = CombinedLoader(train_loaders, mode="min_size")

    hier_rrr = Hier_Reduced_Rank_Model(
        n_roi = len(regions),
        n_units = lst_units,
        n_t_bin = n_t_bins,
        rank_V = n_rank_V,
        rank_B = n_rank_B
    )

    lit_hier_rrr = LitHierRRR(hier_rrr)
    trainer = L.Trainer(max_epochs=n_epochs)
    trainer.fit(model=lit_hier_rrr,
                train_dataloaders=train_loaders)

    # .cpu() makes the conversion independent of where the parameters live
    Us = [hier_rrr.Us[sess].detach().cpu().numpy() for sess in lst_sessions]
    Vs = hier_rrr.Vs.detach().cpu().numpy()

    # Per session: rebuild the coefficient matrix and keep its first n_rank_V
    # right singular vectors scaled by the singular values.
    svd_Vs = []
    for sess in lst_sessions:
        W = Us[sess] @ Vs[lst_regions[sess]]
        U, S, V = svd(W)  # scipy returns the right singular vectors as rows
        svd_Vs.append(np.diag(S[:n_rank_V]) @ V[:n_rank_V, :])
    svd_Vs = np.array(svd_Vs)

    # squeeze=False keeps `axes` 2-D so indexing also works for one region
    fig, axes = plt.subplots(len(regions), 1, figsize=(5, 2*len(regions)), squeeze=False)
    for i, region_id in enumerate(np.unique(lst_regions)):
        mask = np.array(lst_regions) == region_id
        # NOTE(review): singular vectors are sign-ambiguous; averaging before
        # abs() can cancel across sessions -- confirm this is intended.
        axes[i, 0].plot(np.abs(svd_Vs[mask].mean(0)[0]))
        axes[i, 0].set_title(f"{regions[region_id]} (contrast = {lvl})")
    plt.tight_layout()
    plt.show()

    res_dict[lvl] = {
        "pid_idxs": lst_sessions,
        "regions_idxs": lst_regions,
        "region_names": lst_region_names,
        "pids": lst_pids,
        "svd_Vs": svd_Vs,
    }

    print("=================")
    print(f"Finished training on trials with contrast {lvl} ..")

np.save(in_path/f"res_{date.today()}.npy", res_dict)

In [82]:
# Overrides the rank settings from the earlier setup cell for the
# cross-validation run; n_epochs is restated with the same value.
n_rank_V = 5
n_rank_B = 5
n_epochs = 1000

In [83]:
# Use the first CUDA device when one is visible, otherwise the CPU.
# (To force CPU, assign torch.device("cpu") here instead.)
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

cuda:0


In [84]:
# prep data for x-val

n_folds = 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

xval_dict = {}
for fold_idx in range(n_folds):
    
    print("=================")
    print(f"Started training on {fold_idx+1} / {n_folds} folds ..")
    
    train_datasets, test_datasets = [], []
    lst_units, lst_regions, lst_sessions, lst_region_names, lst_pids = [], [], [], [], []

    pid_idx = 0
    for roi_idx, roi in enumerate(regions):

        f_names = os.listdir(in_path/roi)
        pids = [f_name.split("_")[1].split(".")[0] for f_name in f_names]

        print("=================")
        print(f"Loading {len(pids)} PIDs in region {roi}:")
        for pid in pids:
            print(pid)

        data_dict = np.load(in_path/roi/f"pid_{pid}.npy", allow_pickle=True).item()

        for _, pid in enumerate(pids):
            xs = data_dict["neural_contrast"]["all"]
            ys = data_dict["choice_contrast"]["all"]
            for counter, (train, test) in enumerate(skf.split(xs, ys)):
                if counter == fold_idx:
                    train_xs, test_xs = xs[train], xs[test]
                    train_ys, test_ys = ys[train], ys[test]
            train_datasets.append((train_xs, train_ys))
            test_datasets.append((test_xs, test_ys))
            lst_units.append(data_dict["meta"]["n_units"])
            lst_regions.append(roi_idx)
            lst_region_names.append(roi)
            lst_sessions.append(pid_idx)
            lst_pids.append(pid)
            pid_idx += 1
            
    train_loaders = dataloader(train_datasets, lst_regions, lst_sessions, batch_size=128)
    test_loaders = dataloader(test_datasets, lst_regions, lst_sessions, batch_size=128)
    train_loaders = CombinedLoader(train_loaders, mode="min_size")
    test_loaders = CombinedLoader(test_loaders, mode="min_size")

    hier_rrr = Hier_Reduced_Rank_Model(
        n_roi = len(regions),
        n_units = lst_units, 
        n_t_bin = data_dict["meta"]["n_t_bins"], 
        rank_V = n_rank_V,
        rank_B = n_rank_B
    ).to(DEVICE)

    lit_hier_rrr = LitHierRRR(hier_rrr)
    trainer = L.Trainer(max_epochs=n_epochs)
    trainer.fit(model=lit_hier_rrr, 
                train_dataloaders=train_loaders)
    
    accs_per_batch, aucs_per_batch = [], []
    for batch in test_loaders:
        accs, aucs = [], []
        pred_lst, gt_lst = hier_rrr(batch[0])
        for i in range(len(batch[0])):
            auroc = AUROC(task="binary")
            accs.append(accuracy(pred_lst[i], gt_lst[i], task="binary").item())
            aucs.append(auroc(pred_lst[i], gt_lst[i]).item())
        accs_per_batch.append(accs)
        aucs_per_batch.append(aucs)
    test_accs = np.mean(accs_per_batch, 0)
    test_aucs = np.mean(aucs_per_batch, 0)
    print("Accuracy: ", test_accs)
    print("AUC: ", test_aucs)
    
    xval_dict.update({fold_idx: {}})
    xval_dict[fold_idx].update({"accs": test_accs})
    xval_dict[fold_idx].update({"aucs": test_aucs})
    xval_dict[fold_idx].update({"pid_idxs": lst_sessions})
    xval_dict[fold_idx].update({"regions_idxs": lst_regions})
    xval_dict[fold_idx].update({"region_names": lst_region_names})
    xval_dict[fold_idx].update({"pids": lst_pids})
    
    print("=================")
    print(f"Finished training on {fold_idx+1} / {n_folds} folds ..")

# Persist the per-fold cross-validation results, keyed by both model ranks.
# The `rankB_` slot must use `n_rank_B`: interpolating `n_rank_V` in both slots
# makes runs that differ only in the B rank overwrite each other on disk.
# `xval_dict` is a Python dict, so NumPy stores it as a pickled 0-d object
# array; read it back with `np.load(path, allow_pickle=True).item()`.
np.save(in_path/f"xval_rankB_{n_rank_B}_rankV_{n_rank_V}.npy", xval_dict)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | Hier_Reduced_Rank_Model | 2.9 K 
--------------------------------------------------
2.9 K     Trainable params
0         Non-trainable params
2.9 K     Total params
0.012     Total estimated model params size (MB)


Started training on 1 / 5 folds ..
Loading 10 PIDs in region LP:
a6b71993-165b-4c43-845c-c062fe7d7a11
ec2fbc3e-cb2b-48cb-a521-3a6ca15e244c
bef05a5c-68c3-4513-87c7-b3151c88da8e
7a620688-66cb-44d3-b79b-ccac1c8ba23e
d0046384-16ea-4f69-bae9-165e8d0aeacf
8c732bf2-639d-496c-bf82-464bc9c2d54b
b72b22c2-6e9d-4604-9910-20c0e1a467d7
8b31b4bd-003e-4816-a3bf-2df4cc3558f8
1a924329-65aa-465d-b201-c2dd898aebd0
22f26d69-0b30-450e-9618-ee801b720e0a
Loading 10 PIDs in region GRN:
e17db2b6-b778-4e2a-845c-c4d040b0c875
c0e59477-43f0-4441-9f81-3a55ddad9dad
df6012d0-d921-4d0a-af2a-2a91030d0f42
cc72fdb7-92e8-47e6-9cea-94f27c0da2d8
7332e6cf-9847-4aca-b2e3-d864989dd0fb
04c9890f-2276-4c20-854f-305ff5c9b6cf
6a098711-5423-4072-8909-7cff0e2d4531
39883ded-f5a2-4f4f-a98e-fb138eb8433e
aecd7612-b5c5-4ad2-9e76-e5b783387e47
2e720cee-05cc-440e-a24b-13794b1ac01d


/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

  return torch.tensor(self.xs[index]).to(DEVICE), torch.tensor(self.ys[index]).to(DEVICE), self.roi_idx, self.pid_idx
`Trainer.fit` stopped: `max_epochs=1000` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | Hier_Reduced_Rank_Model | 2.9 K 
--------------------------------------------------
2.9 K     Trainable params
0         Non-trainable params
2.9 K     Total params
0.012     Total estimated model params size (MB)


Accuracy:  [0.30859375 0.30859375 0.30859375 0.30859375 0.30859375 0.30859375
 0.30859375 0.30859375 0.30859375 0.30859375 0.8858507  0.86631945
 0.88975695 0.88975695 0.88975695 0.88975695 0.8936632  0.88975695
 0.8936632  0.88975695]
AUC:  [0.31666667 0.31626506 0.31639893 0.31639893 0.3165328  0.3165328
 0.3165328  0.31666667 0.31639893 0.31639893 0.93871684 0.93345626
 0.93675941 0.93663707 0.93627005 0.93663707 0.93773812 0.93639239
 0.93688175 0.93675941]
Finished training on 1 / 5 folds ..
Started training on 2 / 5 folds ..
Loading 10 PIDs in region LP:
a6b71993-165b-4c43-845c-c062fe7d7a11
ec2fbc3e-cb2b-48cb-a521-3a6ca15e244c
bef05a5c-68c3-4513-87c7-b3151c88da8e
7a620688-66cb-44d3-b79b-ccac1c8ba23e
d0046384-16ea-4f69-bae9-165e8d0aeacf
8c732bf2-639d-496c-bf82-464bc9c2d54b
b72b22c2-6e9d-4604-9910-20c0e1a467d7
8b31b4bd-003e-4816-a3bf-2df4cc3558f8
1a924329-65aa-465d-b201-c2dd898aebd0
22f26d69-0b30-450e-9618-ee801b720e0a
Loading 10 PIDs in region GRN:
e17db2b6-b778-4e2a-845c-c4d040b0

/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

  return torch.tensor(self.xs[index]).to(DEVICE), torch.tensor(self.ys[index]).to(DEVICE), self.roi_idx, self.pid_idx
`Trainer.fit` stopped: `max_epochs=1000` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | Hier_Reduced_Rank_Model | 2.9 K 
--------------------------------------------------
2.9 K     Trainable params
0         Non-trainable params
2.9 K     Total params
0.012     Total estimated model params size (MB)


Accuracy:  [0.81640625 0.8125     0.8125     0.8125     0.8125     0.81640625
 0.8125     0.8125     0.8125     0.80859375 0.8624132  0.8624132
 0.87413195 0.8780382  0.86631945 0.87413195 0.8624132  0.8780382
 0.86631945 0.8702257 ]
AUC:  [0.34318664 0.34066808 0.34835631 0.35113998 0.34053552 0.34769353
 0.34345175 0.34225875 0.34080064 0.34212619 0.97285679 0.97261115
 0.97408499 0.97420781 0.97408499 0.97408499 0.97310243 0.97322525
 0.97359371 0.97371653]
Finished training on 2 / 5 folds ..
Started training on 3 / 5 folds ..
Loading 10 PIDs in region LP:
a6b71993-165b-4c43-845c-c062fe7d7a11
ec2fbc3e-cb2b-48cb-a521-3a6ca15e244c
bef05a5c-68c3-4513-87c7-b3151c88da8e
7a620688-66cb-44d3-b79b-ccac1c8ba23e
d0046384-16ea-4f69-bae9-165e8d0aeacf
8c732bf2-639d-496c-bf82-464bc9c2d54b
b72b22c2-6e9d-4604-9910-20c0e1a467d7
8b31b4bd-003e-4816-a3bf-2df4cc3558f8
1a924329-65aa-465d-b201-c2dd898aebd0
22f26d69-0b30-450e-9618-ee801b720e0a
Loading 10 PIDs in region GRN:
e17db2b6-b778-4e2a-845c-c4d040b0c

/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

  return torch.tensor(self.xs[index]).to(DEVICE), torch.tensor(self.ys[index]).to(DEVICE), self.roi_idx, self.pid_idx
`Trainer.fit` stopped: `max_epochs=1000` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | Hier_Reduced_Rank_Model | 2.9 K 
--------------------------------------------------
2.9 K     Trainable params
0         Non-trainable params
2.9 K     Total params
0.012     Total estimated model params size (MB)


Accuracy:  [0.85546875 0.85546875 0.85546875 0.85546875 0.85546875 0.85546875
 0.85546875 0.85546875 0.85546875 0.85546875 0.94140625 0.94140625
 0.94140625 0.94140625 0.94140625 0.94140625 0.94140625 0.94140625
 0.94140625 0.94140625]
AUC:  [0.33801697 0.33801697 0.33801697 0.33801697 0.33801697 0.33801697
 0.33801697 0.33801697 0.33801697 0.33801697 0.98563007 0.98550725
 0.98563007 0.98526161 0.98513879 0.98599853 0.98501597 0.98563007
 0.98501597 0.98501597]
Finished training on 3 / 5 folds ..
Started training on 4 / 5 folds ..
Loading 10 PIDs in region LP:
a6b71993-165b-4c43-845c-c062fe7d7a11
ec2fbc3e-cb2b-48cb-a521-3a6ca15e244c
bef05a5c-68c3-4513-87c7-b3151c88da8e
7a620688-66cb-44d3-b79b-ccac1c8ba23e
d0046384-16ea-4f69-bae9-165e8d0aeacf
8c732bf2-639d-496c-bf82-464bc9c2d54b
b72b22c2-6e9d-4604-9910-20c0e1a467d7
8b31b4bd-003e-4816-a3bf-2df4cc3558f8
1a924329-65aa-465d-b201-c2dd898aebd0
22f26d69-0b30-450e-9618-ee801b720e0a
Loading 10 PIDs in region GRN:
e17db2b6-b778-4e2a-845c-c4d040b

/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

  return torch.tensor(self.xs[index]).to(DEVICE), torch.tensor(self.ys[index]).to(DEVICE), self.roi_idx, self.pid_idx
`Trainer.fit` stopped: `max_epochs=1000` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | Hier_Reduced_Rank_Model | 2.9 K 
--------------------------------------------------
2.9 K     Trainable params
0         Non-trainable params
2.9 K     Total params
0.012     Total estimated model params size (MB)


Accuracy:  [0.71875   0.71875   0.71875   0.71875   0.71875   0.71875   0.71875
 0.71875   0.71875   0.71875   0.90625   0.90625   0.9140625 0.90625
 0.90625   0.921875  0.90625   0.9140625 0.9140625 0.9140625]
AUC:  [0.80116649 0.80037116 0.80196182 0.80169671 0.80169671 0.79851538
 0.8014316  0.80169671 0.79957582 0.80196182 0.95962809 0.95987277
 0.95987277 0.95889405 0.95962809 0.96011745 0.95938341 0.95962809
 0.95987277 0.96011745]
Finished training on 4 / 5 folds ..
Started training on 5 / 5 folds ..
Loading 10 PIDs in region LP:
a6b71993-165b-4c43-845c-c062fe7d7a11
ec2fbc3e-cb2b-48cb-a521-3a6ca15e244c
bef05a5c-68c3-4513-87c7-b3151c88da8e
7a620688-66cb-44d3-b79b-ccac1c8ba23e
d0046384-16ea-4f69-bae9-165e8d0aeacf
8c732bf2-639d-496c-bf82-464bc9c2d54b
b72b22c2-6e9d-4604-9910-20c0e1a467d7
8b31b4bd-003e-4816-a3bf-2df4cc3558f8
1a924329-65aa-465d-b201-c2dd898aebd0
22f26d69-0b30-450e-9618-ee801b720e0a
Loading 10 PIDs in region GRN:
e17db2b6-b778-4e2a-845c-c4d040b0c875
c0e59477-43f0-4441-

/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/yizi/anaconda3/envs/clusterless/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

  return torch.tensor(self.xs[index]).to(DEVICE), torch.tensor(self.ys[index]).to(DEVICE), self.roi_idx, self.pid_idx
`Trainer.fit` stopped: `max_epochs=1000` reached.


Accuracy:  [0.71875   0.71875   0.7265625 0.71875   0.7421875 0.7421875 0.7265625
 0.7109375 0.71875   0.734375  0.9296875 0.9296875 0.9296875 0.9296875
 0.9296875 0.9296875 0.9296875 0.9296875 0.9296875 0.9296875]
AUC:  [0.74151644 0.74257688 0.74178155 0.74019088 0.74231177 0.74284199
 0.7431071  0.7431071  0.74363733 0.74204666 0.96617647 0.96617647
 0.96617647 0.96617647 0.96593137 0.96593137 0.96593137 0.96593137
 0.96593137 0.96593137]
Finished training on 5 / 5 folds ..


In [None]:
# V = 5, B = 5
# Average the first two columns of each fold's result table over the 5 folds,
# accumulating every fold's 1/5 share in fold order.
df = None
for fold in range(5):
    fold_share = pd.DataFrame(xval_dict[fold]).iloc[:, :2] / 5
    df = fold_share if df is None else df + fold_share
df

In [85]:
# V = 2, B = 10
# Average the first two columns of each fold's result table over the 5 folds,
# accumulating every fold's 1/5 share in fold order.
df = None
for fold in range(5):
    fold_share = pd.DataFrame(xval_dict[fold]).iloc[:, :2] / 5
    df = fold_share if df is None else df + fold_share
df

Unnamed: 0,accs,aucs
0,0.683594,0.508111
1,0.682813,0.50758
2,0.684375,0.509303
3,0.682813,0.509489
4,0.6875,0.507819
5,0.688281,0.50872
6,0.684375,0.508508
7,0.68125,0.508349
8,0.682813,0.507686
9,0.685156,0.50811


In [81]:
# V = 2, B = 5
# Average the first two columns of each fold's result table over the 5 folds,
# accumulating every fold's 1/5 share in fold order.
df = None
for fold in range(5):
    fold_share = pd.DataFrame(xval_dict[fold]).iloc[:, :2] / 5
    df = fold_share if df is None else df + fold_share
df

Unnamed: 0,accs,aucs
0,0.696875,0.528527
1,0.696875,0.528421
2,0.696875,0.528554
3,0.696875,0.528554
4,0.696875,0.528474
5,0.696875,0.528633
6,0.694531,0.527563
7,0.696875,0.528528
8,0.696875,0.52858
9,0.696875,0.528421


In [76]:
# V = 2, B = 2
# Average the first two columns of each fold's result table over the 5 folds,
# accumulating every fold's 1/5 share in fold order.
df = None
for fold in range(5):
    fold_share = pd.DataFrame(xval_dict[fold]).iloc[:, :2] / 5
    df = fold_share if df is None else df + fold_share
df

Unnamed: 0,accs,aucs
0,0.682813,0.502315
1,0.682813,0.502288
2,0.682813,0.501693
3,0.682031,0.502103
4,0.678906,0.501462
5,0.68125,0.500805
6,0.68125,0.499776
7,0.68125,0.501623
8,0.678125,0.49867
9,0.684375,0.502397
