In [1]:
from datasets import (
    load_dataset,
    load_from_disk,
    concatenate_datasets,
    load_dataset_builder,
)
from utils.dataset_utils import get_user_datasets, load_ibl_dataset, split_both_dataset
from accelerate import Accelerator
from loader.make_loader import make_loader
from utils.utils import set_seed, dummy_load
from utils.config_utils import config_from_kwargs, update_config
from utils.dataset_utils import get_data_from_h5
from models.ndt1 import NDT1
from models.stpatch import STPatch
from torch.optim.lr_scheduler import OneCycleLR
import torch
import numpy as np
import os
from trainer.make import make_trainer
import threading
from loader.dataset import build_dataloader
import json


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Assemble the run configuration: start from the model include, then layer the
# model config file and the SSL trainer config file on top, in that order.
kwargs = {"model": "include:src/configs/ndt1_stitching_prompting.yaml"}
config = config_from_kwargs(kwargs)
for config_path in (
    "src/configs/ndt1_stitching_prompting.yaml",
    "src/configs/ssl_sessions_trainer.yaml",
):
    config = update_config(config_path, config)

# set seed for reproducibility
set_seed(config.seed)

# NOTE(review): absolute, user-specific path; the notebook only runs on that account.
with open('/user/turishcheva/u14642/IBL_MtM_model/src/configs/config.json', 'r') as config_file:
    loader_config = json.load(config_file)

print('Create Dataloader.')
train_dataloader, val_dataloader = build_dataloader(loader_config)
print('Dataloader Created')

# Collect per-session metadata: the session key and the size of the last axis of
# `responses` in the first batch of each per-session loader.
session_ids = []
neuron_counts = []
for session_id, session_loader in train_dataloader.loaders.items():
    first_batch = next(iter(session_loader))
    neuron_counts.append(first_batch['responses'].shape[-1])
    session_ids.append(session_id)

meta_data = {
    "num_neurons": neuron_counts,
    "num_sessions": len(session_ids),
    "eids": session_ids,
}

num_sessions = len(meta_data["eids"])

seed set to 42
Create Dataloader.
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29623-4-9-Video-full/meta.json




No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29156-11-10-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29647-19-8-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29228-2-10-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29755-2-8-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29234-6-9-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29513-3-5-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29514-2-9-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29515-10-12-Video-full/meta.json
No metadata file found at /mnt/vast-react/projects/neural_foundation_model/dynamic29712-5-

In [3]:
# keys = ['dynamic29623-4-9-Video-full', 'dynamic29156-11-10-Video-full', 'dynamic29647-19-8-Video-full',
# 'dynamic29228-2-10-Video-full', 'dynamic29755-2-8-Video-full', 'dynamic29234-6-9-Video-full',
# 'dynamic29513-3-5-Video-full', 'dynamic29514-2-9-Video-full', 'dynamic29515-10-12-Video-full', 'dynamic29712-5-9-Video-full']

In [4]:
# len(keys)

In [5]:

# Build the log directory for this run from the settings that identify it:
# number of sessions, model class, method name, masking mode and stitching flag.
log_dir = os.path.join(
    config.dirs.log_dir,
    "train",
    "num_session_{}".format(num_sessions),
    "model_{}".format(config.model.model_class),
    "method_{}".format(config.method.model_kwargs.method_name),
    "mask_{}".format(config.encoder.masker.mode),
    "stitch_{}".format(config.encoder.stitching),
)
# exist_ok=True replaces the separate existence check: it is a no-op when the
# directory is already there and cannot fail if another process creates the
# directory between the check and the call.
os.makedirs(log_dir, exist_ok=True)



# # make the dataloader
# train_dataloader = make_loader(
#     train_dataset,
#     target=config.data.target,
#     load_meta=config.data.load_meta,
#     batch_size=config.training.train_batch_size,
#     pad_to_right=True,
#     pad_value=-1.0,
#     max_time_length=config.data.max_time_length,
#     max_space_length=config.data.max_space_length,
#     dataset_name=config.data.dataset_name,
#     sort_by_depth=config.data.sort_by_depth,
#     sort_by_region=config.data.sort_by_region,
#     stitching=config.encoder.stitching,
#     shuffle=True,
# )
# # /mnt/vast-react/projects/agsinz_foundation_model_brain/goirik/IBL_MtM_model/src/loader/base.py _preprocess_ibl_dataset
# # return {
# #     "spikes_data": binned_spikes_data,
# #     "time_attn_mask": time_attn_mask,
# #     "space_attn_mask": space_attn_mask,
# #     "spikes_timestamps": spikes_timestamps,
# #     "spikes_spacestamps": spikes_spacestamps,
# #     "target": target_behavior,
# #     "neuron_depths": neuron_depths, 
# #     "neuron_regions": list(neuron_regions),
# #     "eid": data['eid']
# # }

# val_dataloader = make_loader(
#     val_dataset,
#     target=config.data.target,
#     load_meta=config.data.load_meta,
#     batch_size=config.training.test_batch_size,
#     pad_to_right=True,
#     pad_value=-1.0,
#     max_time_length=config.data.max_time_length,
#     max_space_length=config.data.max_space_length,
#     dataset_name=config.data.dataset_name,
#     sort_by_depth=config.data.sort_by_depth,
#     sort_by_region=config.data.sort_by_region,
#     stitching=config.encoder.stitching,
#     shuffle=False,
# )

# Initialize the accelerator
accelerator = Accelerator()

# Model classes selectable through config.model.model_class.
NAME2MODEL = dict(NDT1=NDT1, STPatch=STPatch)

# Merge the per-session metadata into the config, then instantiate the selected
# model class with the model config, the method kwargs and the metadata.
config = update_config(config, meta_data)
model_class = NAME2MODEL[config.model.model_class]
model = model_class(
    config.model,
    **config.method.model_kwargs,
    **meta_data,
)



Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
# Show the module tree of the model just built from the config (before a checkpoint is loaded).
model

NDT1(
  (encoder): NeuralEncoder(
    (masker): Masker()
    (stitcher): NeuralStitcher(
      (stitcher_dict): ModuleDict(
        (7908): Linear(in_features=7908, out_features=668, bias=True)
        (7440): Linear(in_features=7440, out_features=668, bias=True)
        (8202): Linear(in_features=8202, out_features=668, bias=True)
        (7928): Linear(in_features=7928, out_features=668, bias=True)
        (8122): Linear(in_features=8122, out_features=668, bias=True)
        (8285): Linear(in_features=8285, out_features=668, bias=True)
        (7671): Linear(in_features=7671, out_features=668, bias=True)
        (7495): Linear(in_features=7495, out_features=668, bias=True)
        (7863): Linear(in_features=7863, out_features=668, bias=True)
        (7939): Linear(in_features=7939, out_features=668, bias=True)
      )
    )
    (embedder): NeuralEmbeddingLayer(
      (embed_spikes): Linear(in_features=668, out_features=1336, bias=True)
      (projection): Linear(in_features=1336, out

In [7]:
# Replace the freshly built model with the object stored under the 'model' key of the checkpoint.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects; only use it on a trusted file.
# NOTE(review): absolute, user-specific path; consider reading the checkpoint location from the config.
model = torch.load('/user/turishcheva/u14642/IBL_MtM_model/model_best.pt', weights_only=False)['model']

In [13]:
# model.load_state_dict(torch.load('/user/turishcheva/u14642/IBL_MtM_model/model_best.pt'))

In [8]:
# Show the module tree again, now for the model taken from the checkpoint.
model

NDT1(
  (encoder): NeuralEncoder(
    (masker): Masker()
    (stitcher): NeuralStitcher(
      (stitcher_dict): ModuleDict(
        (7671): Linear(in_features=7671, out_features=668, bias=True)
        (7495): Linear(in_features=7495, out_features=668, bias=True)
        (8122): Linear(in_features=8122, out_features=668, bias=True)
        (8202): Linear(in_features=8202, out_features=668, bias=True)
        (7440): Linear(in_features=7440, out_features=668, bias=True)
        (7908): Linear(in_features=7908, out_features=668, bias=True)
        (7863): Linear(in_features=7863, out_features=668, bias=True)
        (8285): Linear(in_features=8285, out_features=668, bias=True)
        (7939): Linear(in_features=7939, out_features=668, bias=True)
        (7928): Linear(in_features=7928, out_features=668, bias=True)
      )
    )
    (embedder): NeuralEmbeddingLayer(
      (embed_spikes): Linear(in_features=668, out_features=1336, bias=True)
      (projection): Linear(in_features=1336, out

In [9]:
# Masking mode requested by the loaded config.
config.encoder.masker.mode

'all'

In [10]:
# Method name recorded on the checkpointed model.
model.method

'ssl'

In [11]:
# Masking mode and mask flag stored on the checkpointed model's encoder.
# NOTE(review): the recorded outputs show 'all' in the config but 'neuron' here; confirm which one is intended.
model.encoder.masker.mode, model.encoder.mask

('neuron', True)

In [12]:
# Show the encoder's sub-module tree (masker, stitcher, embedder, ...).
model.encoder

NeuralEncoder(
  (masker): Masker()
  (stitcher): NeuralStitcher(
    (stitcher_dict): ModuleDict(
      (7671): Linear(in_features=7671, out_features=668, bias=True)
      (7495): Linear(in_features=7495, out_features=668, bias=True)
      (8122): Linear(in_features=8122, out_features=668, bias=True)
      (8202): Linear(in_features=8202, out_features=668, bias=True)
      (7440): Linear(in_features=7440, out_features=668, bias=True)
      (7908): Linear(in_features=7908, out_features=668, bias=True)
      (7863): Linear(in_features=7863, out_features=668, bias=True)
      (8285): Linear(in_features=8285, out_features=668, bias=True)
      (7939): Linear(in_features=7939, out_features=668, bias=True)
      (7928): Linear(in_features=7928, out_features=668, bias=True)
    )
  )
  (embedder): NeuralEmbeddingLayer(
    (embed_spikes): Linear(in_features=668, out_features=1336, bias=True)
    (projection): Linear(in_features=1336, out_features=512, bias=True)
    (act): Softsign()
    (em

In [None]:
# masking entry point https://github.com/sensorium-competition/IBL_MtM_model/blob/10e84fcecea45e2a3e8c51797372df06e7fca40d/src/models/ndt1.py#L483-L500

In [13]:
# Context-window settings and max_F stored on the encoder; the same three values are passed to create_context_mask further down.
model.encoder.context_forward, model.encoder.context_backward, model.encoder.max_F

(-1, -1, 100)

In [14]:
# Import through the same package path as the other project imports (`models.ndt1`),
# so the module is not loaded a second time under the `src.` prefix.
from models.ndt1 import create_context_mask

In [15]:
# Build the context mask from the encoder's own context settings.
encoder = model.encoder
context_mask = create_context_mask(
    encoder.context_forward,
    encoder.context_backward,
    encoder.max_F,
)

In [16]:
# Shape of the mask and its sum; the recorded result is a 100x100 mask summing to 10000.
context_mask.shape, context_mask.sum()

(torch.Size([100, 100]), tensor(10000))

In [17]:
# Prepare the model with the accelerator first; the optimizer below is built
# from the prepared model's parameters.
model = accelerator.prepare(model)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.optimizer.lr,
    weight_decay=config.optimizer.wd,
    eps=config.optimizer.eps,
)

# Scheduler length: batches over all epochs, integer-divided by the number of
# gradient-accumulation steps.
total_scheduler_steps = (
    config.training.num_epochs * len(train_dataloader)
) // config.optimizer.gradient_accumulation_steps

lr_scheduler = OneCycleLR(
    optimizer=optimizer,
    total_steps=total_scheduler_steps,
    max_lr=config.optimizer.lr,
    pct_start=config.optimizer.warmup_pct,
    div_factor=config.optimizer.div_factor,
)

trainer_kwargs = dict(
    log_dir=log_dir,
    accelerator=accelerator,
    lr_scheduler=lr_scheduler,
    config=config,
    stitching=config.encoder.stitching,
)

# The per-session metadata is forwarded to the trainer as extra keyword arguments.
trainer = make_trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=val_dataloader,
    optimizer=optimizer,
    **trainer_kwargs,
    **meta_data,
)

In [18]:
# Check which dataloader object and masking mode the trainer ended up with.
trainer.train_dataloader, trainer.masking_mode

(<experanto.utils.LongCycler at 0x150bd6c528f0>, 'neuron')

In [19]:
import random
from tqdm import tqdm

In [20]:
# Put the model in training mode, then pull a single batch from the trainer's
# dataloader; the cells below walk that batch through the encoder step by step.
model.train()
batch = next(iter(trainer.train_dataloader))

In [21]:
# Inspect the raw batch. The recorded output shows a (session key, dict of tensors) pair.
# NOTE(review): this prints full tensors; showing only shapes would keep the output readable.
batch

('29623-4-9',
 {'screen': tensor([[[[[ 1.1405e+00,  1.1519e+00,  1.1912e+00,  ...,  3.5917e-01,
               3.3671e-01,  3.1952e-01],
             [ 1.2385e+00,  1.2484e+00,  1.2758e+00,  ...,  3.6289e-01,
               3.4176e-01,  3.2940e-01],
             [ 1.4237e+00,  1.3995e+00,  1.3718e+00,  ...,  3.6092e-01,
               3.3987e-01,  3.3723e-01],
             ...,
             [-7.7821e-01, -8.1287e-01, -6.0339e-01,  ..., -5.0364e-01,
              -6.3979e-01, -7.9896e-01],
             [-6.6414e-01, -6.7335e-01, -4.7492e-01,  ..., -2.4066e-01,
              -3.8683e-01, -5.9326e-01],
             [-3.3078e-01, -4.5662e-01, -4.4365e-01,  ..., -2.3831e-01,
              -2.6302e-01, -2.9906e-01]],
  
            [[ 1.1237e+00,  1.1498e+00,  1.2020e+00,  ...,  3.5126e-01,
               3.2882e-01,  3.1952e-01],
             [ 1.2079e+00,  1.2285e+00,  1.2771e+00,  ...,  3.6012e-01,
               3.3695e-01,  3.2810e-01],
             [ 1.3534e+00,  1.3540e+00,  1.3598e+0

In [22]:
# Resolve the masking mode for this step. For the mixed modes ("combined"/"all")
# one scheme is drawn at random and the masker ratio is set to go with it;
# otherwise the trainer's own mode is used and the ratio is left alone.
masking_mode = trainer.masking_mode
if masking_mode in ("combined", "all"):
    masking_mode = random.sample(trainer.masking_schemes, 1)[0]
    masker = trainer.model.encoder.masker
    # Schemes with a fixed ratio; any other scheme falls back to the trainer's ratio.
    fixed_ratios = {'temporal': 0.3, 'causal': 0.6}
    if masking_mode in fixed_ratios:
        masker.ratio = fixed_ratios[masking_mode]
    else:
        masker.ratio = trainer.masking_ratio

In [23]:
# Masking mode that was selected for this walkthrough.
masking_mode

'neuron'

In [24]:
# Reference only: a copy of the trainer's experanto forward helper, kept as a string so it is not executed.
# The cells below rebuild the same arguments by hand with `trainer` in place of `self`.
'''
def _forward_model_outputs_experanto(self, batch, masking_mode):
        B, T, S = batch[1]['responses'].shape
        return self.model(
            batch[1]['responses'].to(torch.float32).to(self.accelerator.device), 
            time_attn_mask=torch.ones(B, T).to(torch.int64).to(self.accelerator.device),
            space_attn_mask=torch.ones(B, S).to(torch.int64).to(self.accelerator.device),
            spikes_timestamps=torch.arange(T).to(torch.int64).repeat(B,1).to(self.accelerator.device),
            spikes_spacestamps=torch.arange(S).to(torch.int64).repeat(B,1).to(self.accelerator.device),
            targets = float('nan')*torch.ones(B,1).to(torch.int64),
            neuron_regions=[['V1']*B]*S,
            masking_mode=masking_mode,
            spike_augmentation=self.config.data.spike_augmentation,
            num_neuron=S,
            eid='test-test-test'  # each batch consists of data from the same eid
        ) 
'''

"\ndef _forward_model_outputs_experanto(self, batch, masking_mode):\n        B, T, S = batch[1]['responses'].shape\n        return self.model(\n            batch[1]['responses'].to(torch.float32).to(self.accelerator.device), \n            time_attn_mask=torch.ones(B, T).to(torch.int64).to(self.accelerator.device),\n            space_attn_mask=torch.ones(B, S).to(torch.int64).to(self.accelerator.device),\n            spikes_timestamps=torch.arange(T).to(torch.int64).repeat(B,1).to(self.accelerator.device),\n            spikes_spacestamps=torch.arange(S).to(torch.int64).repeat(B,1).to(self.accelerator.device),\n            targets = float('nan')*torch.ones(B,1).to(torch.int64),\n            neuron_regions=[['V1']*B]*S,\n            masking_mode=masking_mode,\n            spike_augmentation=self.config.data.spike_augmentation,\n            num_neuron=S,\n            eid='test-test-test'  # each batch consists of data from the same eid\n        ) \n"

In [25]:
# Device used by the accelerator and the spike_augmentation setting from the trainer's config.
trainer.accelerator.device, trainer.config.data.spike_augmentation

(device(type='cuda'), False)

In [26]:
# Rebuild by hand the arguments that the trainer's experanto forward helper passes to the model.
device = trainer.accelerator.device
responses = batch[1]['responses']
B, T, S = responses.shape

spikes = responses.to(torch.float32).to(device)

# All-ones attention masks over the time and neuron axes.
time_attn_mask = torch.ones(B, T, dtype=torch.int64, device=device)
space_attn_mask = torch.ones(B, S, dtype=torch.int64, device=device)

# Position indices 0..T-1 and 0..S-1, repeated for every sample in the batch.
spikes_timestamps = torch.arange(T, dtype=torch.int64, device=device).repeat(B, 1)
spikes_spacestamps = torch.arange(S, dtype=torch.int64, device=device).repeat(B, 1)

# NaN placeholder targets: a float tensor of shape (B, 1), not moved to `device`.
targets = torch.full((B, 1), float('nan'))

# S references to one shared list of B 'V1' labels.
neuron_regions = [['V1'] * B] * S

# Self-assignment kept from the mirrored keyword list; it only requires that
# `masking_mode` is already defined.
masking_mode = masking_mode
spike_augmentation = trainer.config.data.spike_augmentation
num_neuron = S
eid = 'test-test-test'

In [27]:
# Unpack the three dimensions of the spikes tensor.
B, _T, N = spikes.shape

In [28]:
# outputs = self._forward_model_outputs(batch, masking_mode)
# outputs = trainer._forward_model_outputs_experanto(batch, masking_mode)

In [29]:
# spikes, targets_mask = model.encoder.masker(spikes, neuron_regions)
# targets_mask = targets_mask.to(torch.int64) & spikes_mask.unsqueeze(-1).expand(B,_T,N).to(torch.int64)

In [30]:
# Distinct region labels in the first entry of neuron_regions.
set(neuron_regions[0])

{'V1'}

In [31]:
# Inputs about to be handed to the masker: spikes shape, number of neuron_regions entries, and the region labels.
spikes.shape, len(neuron_regions), set(neuron_regions[0])

(torch.Size([64, 16, 7908]), 7908, {'V1'})

In [32]:
# Spikes before masking, for comparison with the masked tensor shown further down.
spikes

tensor([[[6.2725e-08, 1.5484e-08, 8.6449e-09,  ..., 6.4939e-09,
          1.9378e-08, 1.3059e-08],
         [5.6876e-02, 2.1745e-07, 4.2500e-09,  ..., 7.1709e-09,
          1.5093e-07, 1.3812e-08],
         [1.3601e-07, 2.1302e+00, 2.9259e-09,  ..., 1.1829e+00,
          1.6280e-01, 2.3847e-08],
         ...,
         [3.9993e-07, 1.7355e-08, 7.1607e-09,  ..., 3.7941e-09,
          1.1166e-08, 8.5400e-08],
         [7.5033e-01, 7.4884e-02, 1.1914e-08,  ..., 3.2651e-09,
          5.3026e-09, 4.8186e-08],
         [6.6764e-08, 2.7064e-07, 4.6936e-01,  ..., 3.1239e-09,
          3.6501e-09, 1.8141e-08]],

        [[3.7631e-08, 2.7802e-01, 5.6674e-09,  ..., 7.0721e-01,
          6.2580e-10, 1.7016e-09],
         [1.3292e-07, 7.2055e-08, 1.1955e-08,  ..., 1.4574e-08,
          5.3072e-10, 1.8141e-09],
         [2.2637e-07, 2.1218e-07, 1.4143e-08,  ..., 2.2836e-09,
          4.9449e-10, 2.2268e-09],
         ...,
         [1.5687e-07, 1.5437e-08, 4.9097e-09,  ..., 7.9802e-11,
          2.954

In [33]:
# Run the encoder's masker on the batch; it returns the masked spikes and a targets mask.
# Per the recorded outputs below, targets_mask.sum() equals the number of entries set to zero.
# NOTE(review): `spikes` is rebound to the masked tensor, so re-running this cell masks already-masked data.
spikes, targets_mask = model.encoder.masker(spikes, neuron_regions)

In [34]:
spikes.shape  # masking keeps the shape unchanged: [64, 16, 7908]

torch.Size([64, 16, 7908])

In [35]:
spikes  # masked positions are overwritten with literal zeros (compare with the pre-masking values above)

tensor([[[6.2725e-08, 0.0000e+00, 8.6449e-09,  ..., 6.4939e-09,
          1.9378e-08, 1.3059e-08],
         [5.6876e-02, 0.0000e+00, 4.2500e-09,  ..., 7.1709e-09,
          1.5093e-07, 1.3812e-08],
         [1.3601e-07, 0.0000e+00, 2.9259e-09,  ..., 1.1829e+00,
          1.6280e-01, 2.3847e-08],
         ...,
         [3.9993e-07, 0.0000e+00, 7.1607e-09,  ..., 3.7941e-09,
          1.1166e-08, 8.5400e-08],
         [7.5033e-01, 0.0000e+00, 1.1914e-08,  ..., 3.2651e-09,
          5.3026e-09, 4.8186e-08],
         [6.6764e-08, 0.0000e+00, 4.6936e-01,  ..., 3.1239e-09,
          3.6501e-09, 1.8141e-08]],

        [[3.7631e-08, 2.7802e-01, 5.6674e-09,  ..., 0.0000e+00,
          6.2580e-10, 1.7016e-09],
         [1.3292e-07, 7.2055e-08, 1.1955e-08,  ..., 0.0000e+00,
          5.3072e-10, 1.8141e-09],
         [2.2637e-07, 2.1218e-07, 1.4143e-08,  ..., 0.0000e+00,
          4.9449e-10, 2.2268e-09],
         ...,
         [1.5687e-07, 1.5437e-08, 4.9097e-09,  ..., 0.0000e+00,
          2.954

In [36]:
# Number of entries that are exactly zero after masking.
(spikes== 0).sum()

tensor(2438992, device='cuda:0')

In [37]:
# Sum of the targets mask; the recorded value equals the zero count of the masked spikes.
targets_mask.sum()

tensor(2438992, device='cuda:0')

In [38]:
# Use the (all-ones) time attention mask as `spikes_mask`, the name used when calling the embedder below.
spikes_mask = time_attn_mask

In [39]:
# AND the targets mask with spikes_mask expanded from (B, _T) to (B, _T, N), both cast to int64.
# With the all-ones spikes_mask used here the recorded sum is unchanged.
targets_mask = targets_mask.to(torch.int64) & spikes_mask.unsqueeze(-1).expand(B,_T,N).to(torch.int64)

In [41]:
# Shape check after combining the masks.
targets_mask.shape

torch.Size([64, 16, 7908])

In [42]:
# Sum after the AND; the recorded value is the same as before it.
targets_mask.sum()

tensor(2438992, device='cuda:0')

In [43]:
# Neuron count of this batch; its string form is what gets passed to the stitcher below.
num_neuron

7908

In [45]:
# The stitcher's layers: one Linear per neuron count, each with 668 output features.
model.encoder.stitcher

NeuralStitcher(
  (stitcher_dict): ModuleDict(
    (7671): Linear(in_features=7671, out_features=668, bias=True)
    (7495): Linear(in_features=7495, out_features=668, bias=True)
    (8122): Linear(in_features=8122, out_features=668, bias=True)
    (8202): Linear(in_features=8202, out_features=668, bias=True)
    (7440): Linear(in_features=7440, out_features=668, bias=True)
    (7908): Linear(in_features=7908, out_features=668, bias=True)
    (7863): Linear(in_features=7863, out_features=668, bias=True)
    (8285): Linear(in_features=8285, out_features=668, bias=True)
    (7939): Linear(in_features=7939, out_features=668, bias=True)
    (7928): Linear(in_features=7928, out_features=668, bias=True)
  )
)

In [47]:
# Masked spikes once more, immediately before stitching.
spikes

tensor([[[6.2725e-08, 0.0000e+00, 8.6449e-09,  ..., 6.4939e-09,
          1.9378e-08, 1.3059e-08],
         [5.6876e-02, 0.0000e+00, 4.2500e-09,  ..., 7.1709e-09,
          1.5093e-07, 1.3812e-08],
         [1.3601e-07, 0.0000e+00, 2.9259e-09,  ..., 1.1829e+00,
          1.6280e-01, 2.3847e-08],
         ...,
         [3.9993e-07, 0.0000e+00, 7.1607e-09,  ..., 3.7941e-09,
          1.1166e-08, 8.5400e-08],
         [7.5033e-01, 0.0000e+00, 1.1914e-08,  ..., 3.2651e-09,
          5.3026e-09, 4.8186e-08],
         [6.6764e-08, 0.0000e+00, 4.6936e-01,  ..., 3.1239e-09,
          3.6501e-09, 1.8141e-08]],

        [[3.7631e-08, 2.7802e-01, 5.6674e-09,  ..., 0.0000e+00,
          6.2580e-10, 1.7016e-09],
         [1.3292e-07, 7.2055e-08, 1.1955e-08,  ..., 0.0000e+00,
          5.3072e-10, 1.8141e-09],
         [2.2637e-07, 2.1218e-07, 1.4143e-08,  ..., 0.0000e+00,
          4.9449e-10, 2.2268e-09],
         ...,
         [1.5687e-07, 1.5437e-08, 4.9097e-09,  ..., 0.0000e+00,
          2.954

In [48]:
# Zero count right before stitching (same recorded value as after masking).
(spikes== 0).sum()

tensor(2438992, device='cuda:0')

In [49]:
# If the encoder has a stitcher, pass the spikes through it, using the neuron count (as a string) as the second argument.
# NOTE(review): `spikes` is rebound to the stitched tensor, so this cell cannot be re-run on its own output.
if hasattr(model.encoder, 'stitcher'):
    spikes = model.encoder.stitcher(spikes, str(num_neuron))
    print('mew')  # leftover debug marker that only confirms the branch ran

mew


In [51]:
# After stitching: the recorded output shows no exact zeros and a last dimension of 668.
(spikes== 0).sum(), spikes.shape

(tensor(0, device='cuda:0'), torch.Size([64, 16, 668]))

In [56]:
# No block or date index is used in this walkthrough; both are passed to the embedder as None.
block_idx = date_idx = None

In [57]:
# Embed the stitched spikes. The recorded shapes below show the second dimension growing from 16 to 17
# (presumably an extra prompt/context token is added; TODO confirm in the embedder).
# NOTE(review): `spikes_mask` is rebound to the embedder's output, and the third result is bound to
# `spikes_timestamp` (singular), leaving the input `spikes_timestamps` untouched.
x, spikes_mask, spikes_timestamp = model.encoder.embedder(
    spikes, spikes_mask, spikes_timestamps, block_idx, date_idx, targets_mask, masking_mode, eid
)


In [58]:
# Shape of the embedder output; recorded as [64, 17, 512].
x.shape

torch.Size([64, 17, 512])

In [60]:
# Shape of the mask returned by the embedder; recorded as [64, 17].
spikes_mask.shape

torch.Size([64, 17])

In [65]:
# Check that the returned mask contains no zeros: its sum should equal 64*17.
(spikes_mask == 0).sum(), spikes_mask.sum(), 64*17

(tensor(0, device='cuda:0'), tensor(1088, device='cuda:0'), 1088)

In [67]:
# Shape of the timestamps returned by the embedder; recorded as [64, 17].
spikes_timestamp.shape

torch.Size([64, 17])

In [None]:
# Reference notes, kept as comments because they are not runnable code:
# `eval_mask` is never defined in this notebook, `Optional` is not imported,
# and a bare `name: type,` line is a syntax error.
#
# Arguments at the call site (presumably the encoder's forward call; TODO confirm):
#   spikes, time_attn_mask, spikes_timestamps, block_idx, date_idx,
#   neuron_regions, masking_mode, eval_mask, num_neuron, eid
#
# Matching parameter list:
#   spikes:           torch.FloatTensor,  # (bs, seq_len, n_channels)
#   spikes_mask:      torch.LongTensor,   # (bs, seq_len)
#   spikes_timestamp: torch.LongTensor,   # (bs, seq_len)
#   block_idx:        Optional[torch.LongTensor] = None,   # (bs)
#   date_idx:         Optional[torch.LongTensor] = None,   # (bs)
#   neuron_regions:   Optional[np.ndarray] = None,  # (bs, n_channels)
#   masking_mode:     Optional[str] = None,
#   eval_mask:        Optional[torch.LongTensor] = None,
#   num_neuron:       Optional[torch.LongTensor] = None,
#   eid:              Optional[str] = None,

In [None]:
# model = accelerator.prepare(model)

# optimizer = torch.optim.AdamW(
#     model.parameters(),
#     lr=config.optimizer.lr,
#     weight_decay=config.optimizer.wd,
#     eps=config.optimizer.eps,
# )
# lr_scheduler = OneCycleLR(
#     optimizer=optimizer,
#     total_steps=config.training.num_epochs * len(train_dataloader) // config.optimizer.gradient_accumulation_steps,
#     max_lr=config.optimizer.lr,
#     pct_start=config.optimizer.warmup_pct,
#     div_factor=config.optimizer.div_factor,
# )

# trainer_kwargs = {
#     "log_dir": log_dir,
#     "accelerator": accelerator,
#     "lr_scheduler": lr_scheduler,
#     "config": config,
#     "stitching": config.encoder.stitching,
# }

# trainer = make_trainer(
#     model=model,
#     train_dataloader=train_dataloader,
#     eval_dataloader=val_dataloader,
#     optimizer=optimizer,
#     **trainer_kwargs,
#     **meta_data
# )
# Event used to tell the dummy-load thread to stop.
stop_dummy_load = threading.Event()

# On HPC nodes an optional background thread runs a dummy load so the GPU is
# not left idle; it is only started when the config asks for it.
dummy_thread = None
if config.training.dummy:
    print("Running dummy load")
    dummy_thread = threading.Thread(target=dummy_load, args=(stop_dummy_load,))
    dummy_thread.start()

try:
    # train loop
    trainer.train()
finally:
    # If a dummy-load thread was started, signal it to stop and wait for it,
    # even when training raised.
    if dummy_thread is not None:
        stop_dummy_load.set()
        dummy_thread.join()