In [7]:
%load_ext autoreload
%autoreload 2

import sys 
import logging
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import logging

import hydra
import lightning as L
import torch
from model import (
    BhvrDecoder,
    ContextManager,
    Decoder,
    Encoder,
    MaeMaskManager,
    SpikesPatchifier,
    SslDecoder,
)
from transforms import FilterUnit, Ndt2Tokenizer
from train import DataModule
from lightning.pytorch.utilities import CombinedLoader
from omegaconf import OmegaConf, open_dict
from torch import optim
from torchmetrics import R2Score
from brainsets.taxonomy import decoder_registry

from train import TrainWrapper, set_callbacks

log = logging.getLogger(__name__)
# Without a level and a reachable handler, the log.info(...) calls in this
# notebook are dropped (the default effective level is WARNING). Attach a
# handler only when none is reachable, so re-running the cell does not
# duplicate messages and an existing root handler is reused.
log.setLevel(logging.INFO)
if not log.hasHandlers():
    log.addHandler(logging.StreamHandler())


def load_cfg(config_dir="./configs", dataset="odoherty"):
    """Assemble the training config by hand instead of through Hydra.

    Loads ``<config_dir>/train.yaml``, then attaches
    ``<config_dir>/data_ssl/<dataset>.yaml`` as ``cfg.data_ssl`` and
    ``<config_dir>/data_superv/<dataset>.yaml`` as ``cfg.data_superv``.
    The ``defaults`` entry of train.yaml is removed because no Hydra
    composition happens here.

    Args:
        config_dir: Directory holding the YAML configs (relative to the
            kernel's working directory by default).
        dataset: Name of the data config file (without ``.yaml``) to load
            for both the SSL and supervised data sections.

    Returns:
        The assembled OmegaConf config.

    Side effects:
        Replaces ``sys.argv`` with just the program name.
    """
    # NOTE(review): presumably strips the Jupyter kernel's extra argv entries
    # so any downstream argv parsing does not see them — confirm still needed.
    sys.argv = [sys.argv[0]]
    cfg = OmegaConf.load(os.path.join(config_dir, "train.yaml"))
    cfg.data_ssl = OmegaConf.load(os.path.join(config_dir, "data_ssl", f"{dataset}.yaml"))
    cfg.data_superv = OmegaConf.load(
        os.path.join(config_dir, "data_superv", f"{dataset}.yaml")
    )
    # NOTE(review): any other groups listed under `defaults` are NOT loaded by
    # this function — confirm train.yaml composes only data_ssl/data_superv.
    # pop() tolerates a train.yaml without a `defaults` key (del would raise).
    cfg.pop("defaults", None)
    return cfg

# Build the config once for this session; W&B logging is switched off for
# interactive use.
cfg = load_cfg()
cfg.wandb["enable"] = False


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Seed the Python / NumPy / torch RNGs before any module is constructed.
L.seed_everything(cfg.seed)

# fast_dev_run: disable W&B logging and set num_workers to 0.
if cfg.fast_dev_run:
    cfg.wandb.enable = False
    cfg.num_workers = 0

# open_dict temporarily lifts OmegaConf struct mode so new keys can be added.
with open_dict(cfg):
    # Adjust batch size for multi-gpu: split the global batch sizes evenly
    # across visible GPUs. Divide by at least 1 so a kernel with no visible
    # CUDA device (device_count() == 0, e.g. when the pinned
    # CUDA_VISIBLE_DEVICES index does not exist on this host) keeps the full
    # batch instead of raising ZeroDivisionError.
    num_gpus = torch.cuda.device_count()
    num_batch_shards = max(num_gpus, 1)
    cfg.batch_size_per_gpu = cfg.batch_size // num_batch_shards
    # Supervised batch size falls back to the global batch size when unset.
    cfg.superv_batch_size = cfg.superv_batch_size or cfg.batch_size
    cfg.superv_batch_size_per_gpu = cfg.superv_batch_size // num_batch_shards
    log.info("Number of GPUs: %d", num_gpus)
    log.info("Batch size per GPU: %s", cfg.batch_size_per_gpu)
    log.info("Superv batch size per GPU: %s", cfg.superv_batch_size_per_gpu)

dim = cfg.model.dim

# Mask manager: only built for masked-autoencoder (MAE) SSL runs.
mae_mask_manager = MaeMaskManager(cfg.mask_ratio) if cfg.is_ssl else None

# Context manager; its vocabulary is registered after the data module is set up.
ctx_manager = ContextManager(dim)

# Spikes patchifier
spikes_patchifier = SpikesPatchifier(dim, cfg.patch_size)

# Model = Encoder + Decoder
encoder = Encoder(
    dim=dim,
    max_time_patches=cfg.model.max_time_patches,
    max_space_patches=cfg.model.max_space_patches,
    **cfg.model.encoder,
)

# SSL runs use SslDecoder (configured from cfg.model.predictor); supervised
# runs use BhvrDecoder (configured from cfg.model.bhv_decoder).
if cfg.is_ssl:
    decoder = SslDecoder(
        dim=dim,
        max_time_patches=cfg.model.max_time_patches,
        max_space_patches=cfg.model.max_space_patches,
        patch_size=cfg.patch_size,
        **cfg.model.predictor,
    )
else:
    decoder = BhvrDecoder(
        dim=dim,
        max_time_patches=cfg.model.max_time_patches,
        max_space_patches=cfg.model.max_space_patches,
        bin_time=cfg.bin_time,
        **cfg.model.bhv_decoder,
    )

# Train wrapper
train_wrapper = TrainWrapper(
    cfg, mae_mask_manager, ctx_manager, spikes_patchifier, encoder, decoder
)

# Tokenizer: SSL runs include masks; supervised runs include behaviour.
ctx_tokenizer = ctx_manager.get_ctx_tokenizer()
tokenizer = Ndt2Tokenizer(
    ctx_time=cfg.ctx_time,
    bin_time=cfg.bin_time,
    patch_size=cfg.patch_size,
    pad_val=cfg.pad_val,
    decoder_registry=decoder_registry,
    mask_ratio=cfg.mask_ratio,
    ctx_tokenizer=ctx_tokenizer,
    inc_behavior=not cfg.is_ssl,
    inc_mask=cfg.is_ssl,
)

# Set up data module
data_module = DataModule(cfg, tokenizer, cfg.is_ssl)
data_module.setup()

# Register context: the vocabulary comes from the data module, so this must
# run after data_module.setup().
ctx_manager.init_vocab(data_module.get_ctx_vocab(ctx_manager.keys))

# Re-seed after setup — presumably so later randomness does not depend on how
# many draws model/data construction consumed.
L.seed_everything(cfg.seed)

# Callbacks
callbacks = set_callbacks(cfg)

# Trainer construction is disabled; this notebook drives training_step
# manually. NOTE: `wandb_logger` is not defined in this notebook — define it
# before re-enabling the block below.
# trainer = L.Trainer(
#     logger=wandb_logger,
#     default_root_dir=cfg.log_dir,
#     check_val_every_n_epoch=cfg.eval_epochs,
#     max_epochs=cfg.epochs,
#     log_every_n_steps=cfg.log_every_n_steps,
#     callbacks=callbacks,
#     accelerator="gpu",
#     precision=cfg.precision,
#     fast_dev_run=cfg.fast_dev_run,
#     num_sanity_val_steps=cfg.num_sanity_val_steps,
#     strategy="ddp_find_unused_parameters_true",
# )

Seed set to 0
Seed set to 0
Seed set to 0


{'session': ['odoherty_sabes_nonhuman_2017_v2/indy_20160411_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160411_02', 'odoherty_sabes_nonhuman_2017_v2/indy_20160418_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160419_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160420_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160426_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160622_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160624_03', 'odoherty_sabes_nonhuman_2017_v2/indy_20160630_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160915_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160916_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160921_01', 'odoherty_sabes_nonhuman_2017_v2/indy_20160927_04', 'odoherty_sabes_nonhuman_2017_v2/indy_20160927_06', 'odoherty_sabes_nonhuman_2017_v2/indy_20160930_02', 'odoherty_sabes_nonhuman_2017_v2/indy_20160930_05', 'odoherty_sabes_nonhuman_2017_v2/indy_20161006_02', 'odoherty_sabes_nonhuman_2017_v2/indy_20161007_02', 'odoherty_sabes_nonhuman_2017_v2/indy_20161011_03',

In [3]:
# Smoke-test one training step on the first training batch, outside a Trainer.
# `train_loader` is created here so the cell does not rely on hidden kernel
# state.
# NOTE(review): assumes DataModule exposes Lightning's `train_dataloader()`
# hook and that iteration yields (batch, batch_idx, dataloader_idx) tuples as
# CombinedLoader does, so batch[0] is the data — confirm.
train_loader = data_module.train_dataloader()
batch = next(iter(train_loader))

# Without an attached Trainer, self.log() inside training_step only warns;
# the returned loss is still computed.
train_wrapper.training_step(batch[0], 0)

/nethome/aandre8/torch_brain/venv/lib/python3.9/site-packages/lightning/pytorch/core/module.py:447: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(1.2680, grad_fn=<AddBackward0>)

In [None]:
# for n, a in train_wrapper.bhv_decoder.named_parameters():
#     print(n)

In [None]:
# for batch in val_loader:
#     break

# batch_superv = batch[0]['superv']

# batch_superv["bhvr_vel"].sum()

In [None]:
# Inspect the spike-token tensor of the supervised part of a validation batch.
# `batch_superv` is built here so the cell does not depend on hidden kernel
# state.
# NOTE(review): assumes DataModule exposes Lightning's `val_dataloader()` hook,
# that iteration yields (batch, batch_idx, dataloader_idx) tuples as
# CombinedLoader does, and that the batch dict has a "superv" entry — confirm.
val_batch = next(iter(data_module.val_dataloader()))
batch_superv = val_batch[0]["superv"]
batch_superv["spike_tokens"].shape

torch.Size([2, 150, 32, 1])