In [1]:
from datasets import (
    load_dataset,
    load_from_disk,
    concatenate_datasets,
    load_dataset_builder,
)
from utils.dataset_utils import get_user_datasets, load_ibl_dataset, split_both_dataset
from accelerate import Accelerator
from loader.make_loader import make_loader
from utils.utils import set_seed, dummy_load
from utils.config_utils import config_from_kwargs, update_config
from utils.dataset_utils import get_data_from_h5
from models.ndt1 import NDT1
from models.stpatch import STPatch
from torch.optim.lr_scheduler import OneCycleLR
import torch
import numpy as np
import os
from trainer.make import make_trainer
import threading
from loader.dataset import build_dataloader
import json


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---- configuration ----
# Start from the model config, then pass it through update_config once per YAML file.
kwargs = {"model": "include:src/configs/ndt1_stitching_prompting.yaml"}
config = config_from_kwargs(kwargs)
for yaml_path in (
    "src/configs/ndt1_stitching_prompting.yaml",
    "src/configs/ssl_sessions_trainer.yaml",
):
    config = update_config(yaml_path, config)

# Seed once, up front, so the run is reproducible.
set_seed(config.seed)

# ---- dataloaders ----
# The loader settings live in a separate JSON file.
with open('/user/turishcheva/u14642/IBL_MtM_model/src/configs/config.json', 'r') as config_file:
    loader_config = json.load(config_file)

print('Create Dataloader.')
train_dataloader, val_dataloader = build_dataloader(loader_config)
print('Dataloader Created')

# ---- per-session metadata ----
# For every session loader, peek at one batch and read the size of the last
# axis of its 'responses' entry; record the session key alongside it.
session_keys = []
neuron_counts = []
for session_key, session_loader in train_dataloader.loaders.items():
    first_batch = next(iter(session_loader))
    neuron_counts.append(first_batch['responses'].shape[-1])
    session_keys.append(session_key)

meta_data = {
    "num_neurons": neuron_counts,
    "num_sessions": len(session_keys),
    "eids": session_keys,
}

num_sessions = len(meta_data["eids"])

seed set to 42
Create Dataloader.
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29623-4-9-Video-full/meta.json




No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29156-11-10-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29647-19-8-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29228-2-10-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29755-2-8-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29234-6-9-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29513-3-5-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29514-2-9-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29515-10-12-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29712-5-

In [4]:
# Load the checkpoint file onto the CPU and keep only its 'model' entry.
# NOTE: weights_only=False unpickles arbitrary Python objects, so only load
# checkpoints that come from a trusted source.
model = torch.load(
    '/user/turishcheva/u14642/IBL_MtM_model/model_best.pt',
    map_location=torch.device('cpu'),
    weights_only=False,
)['model']

In [5]:
# Display the loaded model's module structure.
model

NDT1(
  (encoder): NeuralEncoder(
    (masker): Masker()
    (stitcher): NeuralStitcher(
      (stitcher_dict): ModuleDict(
        (7671): Linear(in_features=7671, out_features=668, bias=True)
        (7495): Linear(in_features=7495, out_features=668, bias=True)
        (8122): Linear(in_features=8122, out_features=668, bias=True)
        (8202): Linear(in_features=8202, out_features=668, bias=True)
        (7440): Linear(in_features=7440, out_features=668, bias=True)
        (7908): Linear(in_features=7908, out_features=668, bias=True)
        (7863): Linear(in_features=7863, out_features=668, bias=True)
        (8285): Linear(in_features=8285, out_features=668, bias=True)
        (7939): Linear(in_features=7939, out_features=668, bias=True)
        (7928): Linear(in_features=7928, out_features=668, bias=True)
      )
    )
    (embedder): NeuralEmbeddingLayer(
      (embed_spikes): Linear(in_features=668, out_features=1336, bias=True)
      (projection): Linear(in_features=1336, out

In [6]:
# Inspect the model's `stitching` flag.
model.stitching

True

In [10]:
# TODO - should be True on a retrained model!
# Inspect the `use_session` flag of the encoder's embedding layer.
model.encoder.embedder.use_session 

False

In [None]:
# zero is masked out, 1 is kept
# Placeholder until the held-out pattern is decided: a single int64 "1" of
# shape (1,), which can be broadcast to any response shape and therefore
# masks nothing out.
# TODO: replace with the real 0/1 mask (zeros at the entries to hold out).
# NOTE(review): with this all-ones default nothing is held out, so an eval
# mask computed as 1 - mask is all zeros; confirm the model tolerates that.
mask = torch.ones(1, dtype=torch.int64)

In [None]:
# Run a single masked forward pass in eval mode and collect preds / targets.
#
# Inputs expected from earlier cells:
#   model - the loaded NDT1 model
#   mask  - 0/1 tensor (zero is masked out, 1 is kept); any shape that
#           broadcasts to the responses, e.g. (S,), (T, S) or (B, T, S)
#   batch - NOTE(review): must be defined before this cell runs; it is indexed
#           as batch[0] (session key) and batch[1]['responses'] with a
#           3-D shape unpacked as (B, T, S).
model.eval()
# TODO - do I really need this line?
masking_mode = 'neuron' if model.use_prompt else model.encoder.masker.mode
model.encoder.mask = False

# Take the device from the model itself so inputs and weights always agree.
device = next(model.parameters()).device

responses = batch[1]['responses'].to(device)
B, T, S = responses.shape

# Broadcast the mask to the full response shape. `mask.repeat(B, 1)` could not
# produce a tensor that multiplies against (B, T, S) responses in general.
# int64 matches the LongTensor annotation of `eval_mask` in the forward signature.
batched_mask = torch.broadcast_to(mask.to(device), (B, T, S)).to(torch.int64)

# NDT1Output(
#     loss=loss,
#     n_examples=n_examples,
#     preds=outputs,
#     targets=targets
# )

# Inference only: skip gradient tracking.
with torch.no_grad():
    outputs = model(
        (batched_mask * responses).to(torch.float32), # https://github.com/colehurwitz/IBL_MtM_model/blob/main/src/utils/eval_utils.py#L1109
        time_attn_mask=torch.ones(B, T, dtype=torch.int64, device=device),
        space_attn_mask=torch.ones(B, S, dtype=torch.int64, device=device),
        spikes_timestamps=torch.arange(T, dtype=torch.int64, device=device).repeat(B, 1),
        spikes_spacestamps=torch.arange(S, dtype=torch.int64, device=device).repeat(B, 1),
        # Dummy float NaN targets of shape (B, 1), as in the original cell.
        targets=torch.full((B, 1), float('nan'), device=device),
        # NOTE(review): this builds S lists of length B, while the forward
        # signature comment says (bs, n_channels) - kept as in the original; confirm.
        neuron_regions=[['V1']*B]*S,
        eval_mask=(1 - batched_mask), # https://github.com/colehurwitz/IBL_MtM_model/blob/main/src/utils/eval_utils.py#L1111
        masking_mode=masking_mode, # TODO - double check this
        num_neuron=S,
        eid=batch[0] # session key
    )
preds = outputs.preds
targets = outputs.targets

### Reference snippets copied from the upstream code (kept here temporarily)

In [None]:
# mask = torch.ones(spike_data.shape).to(torch.int64).to(spike_data.device)
# Open question: are the masked elements the zeros of `mask`?
# {"spikes": spike_data_masked, "heldout_idxs": hd, "eval_mask": 1-mask}

In [None]:
# https://github.com/colehurwitz/IBL_MtM_model/blob/main/src/models/ndt1.py#L653C3-L672C1
#   def forward(
#         self, 
#         spikes:           torch.FloatTensor,  # (bs, seq_len, n_channels)
#         time_attn_mask:      torch.LongTensor,   # (bs, seq_len)
#         space_attn_mask:      torch.LongTensor,   # (bs, seq_len)
#         spikes_timestamps: torch.LongTensor,   # (bs, seq_len)
#         spikes_spacestamps: torch.LongTensor,   # (bs, seq_len)
#         targets:          Optional[torch.FloatTensor] = None,  # (bs, tar_len)
#         spikes_lengths:   Optional[torch.LongTensor] = None,   # (bs) 
#         targets_lengths:  Optional[torch.LongTensor] = None,   # (bs)
#         block_idx:        Optional[torch.LongTensor] = None,   # (bs)
#         date_idx:         Optional[torch.LongTensor] = None,   # (bs)
#         neuron_regions:   Optional[torch.LongTensor] = None,   # (bs, n_channels)
#         masking_mode:     Optional[str] = None,
#         spike_augmentation: Optional[bool] = False,
#         eval_mask:        Optional[torch.LongTensor] = None,
#         num_neuron:       Optional[torch.LongTensor] = None,
#         eid:              Optional[str] = None,
#     ) -> NDT1Output:  


In [None]:
# def _forward_model_outputs_experanto(self, batch, masking_mode):
#         B, T, S = batch[1]['responses'].shape
#         return self.model(
#             batch[1]['responses'].to(torch.float32).to(self.accelerator.device), 
#             time_attn_mask=torch.ones(B, T).to(torch.int64).to(self.accelerator.device),
#             space_attn_mask=torch.ones(B, S).to(torch.int64).to(self.accelerator.device),
#             spikes_timestamps=torch.arange(T).to(torch.int64).repeat(B,1).to(self.accelerator.device),
#             spikes_spacestamps=torch.arange(S).to(torch.int64).repeat(B,1).to(self.accelerator.device),
#             targets = float('nan')*torch.ones(B,1).to(torch.int64),
#             neuron_regions=[['V1']*B]*S,
#             masking_mode=masking_mode,
#             spike_augmentation=self.config.data.spike_augmentation,
#             num_neuron=S,
#             eid='test-test-test'  # each batch consists of data from the same eid
#         )

In [None]:
# https://github.com/colehurwitz/IBL_MtM_model/blob/main/src/utils/eval_utils.py#L213-L245
# if counter <= tot_num_neurons:
#     mask_result = heldout_mask(
#         batch['spikes_data'].clone(),
#         mode='manual',
#         heldout_idxs=np.array([n_i+i])
#     )
#     mask_spikes_lst.append(mask_result['spikes'])
#     eval_mask_lst.append(mask_result['eval_mask'])
#     gt_spikes_lst.append(gt_spike_data)
#     time_attn_mask_lst.append(batch['time_attn_mask'])
#     space_attn_mask_lst.append(batch['space_attn_mask'])
#     spikes_timestamps_lst.append(batch['spikes_timestamps'])
#     spikes_spacestamps_lst.append(batch['spikes_spacestamps'])
#     targets_lst.append(batch['target'])
#     neuron_regions_lst.append(batch['neuron_regions'])
# else:
#     break

# masking_mode = 'neuron' if model.use_prompt else model.encoder.masker.mode
# model.encoder.mask = False

# outputs = model(
#     torch.cat(mask_spikes_lst, 0),
#     time_attn_mask=torch.cat(time_attn_mask_lst, 0),
#     space_attn_mask=torch.cat(space_attn_mask_lst, 0),
#     spikes_timestamps=torch.cat(spikes_timestamps_lst, 0), 
#     spikes_spacestamps=torch.cat(spikes_spacestamps_lst, 0), 
#     targets = torch.cat(targets_lst, 0),
#     neuron_regions=np.stack(neuron_regions_lst),
#     eval_mask=torch.cat(eval_mask_lst, 0),
#     masking_mode = masking_mode,
#     num_neuron=batch['spikes_data'].shape[2],
#     eid=batch['eid'][0]  # each batch consists of data from the same eid
# )