In [1]:
import sys
from experanto.datasets import ChunkDataset
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import OrderedDict

In [2]:
# Folder that holds the exported experiments; the relative path is resolved
# against the notebook's working directory.
root_folder = '../data/allen_data'
# Passed to ChunkDataset as `global_sampling_rate` (units presumably Hz — confirm).
sampling_rate = 8  # TODO(review): generated timestamps do not match the recorded ones; is this rate only meant for the resampled output?
# Passed to ChunkDataset as `global_chunk_size`.
chunk_size = 32

In [3]:
# Training dataset built from experiment 951980471.
# Per-modality `sampling_rate` / `chunk_size` are None while the global values
# are passed at the top level — presumably each modality then falls back to
# the global setting (confirm against the experanto documentation).
train_dataset = ChunkDataset(
    root_folder=f'{root_folder}/experiment_951980471',
    global_sampling_rate=sampling_rate,
    global_chunk_size=chunk_size,
    modality_config={
        'screen': {
            'sampling_rate': None,
            'chunk_size': None,
            'valid_condition': {
                'tier': 'train',
                # The original literal listed 'stim_type' twice
                # ('stimulus.Frame', then 'stimulus.Clip'). A dict literal
                # keeps only the last value for a repeated key, so only clips
                # were ever selected. The dead 'stimulus.Frame' entry is
                # removed; the effective condition is unchanged.
                # TODO(review): to really include both images and videos, use
                # whatever multi-value form experanto supports for
                # `valid_condition`.
                'stim_type': 'stimulus.Clip',
            },
            'offset': 0,
            'sample_stride': 4,
            # Required for the Allen data: a blank follows every stimulus, and
            # without including blanks no valid sample times are found.
            'include_blanks': True,
            # NOTE(review): torchvision's ToTensor rescales uint8 images to
            # [0, 1]; if the frames are uint8, mean=80 / std=60 would be
            # applied to already rescaled values — confirm the frame dtype.
            'transforms': {
                'ToTensor': {
                    '_target_': 'torchvision.transforms.ToTensor'
                },
                'Normalize': {
                    '_target_': 'torchvision.transforms.Normalize',
                    'mean': 80.0,
                    'std': 60.0
                },
                'Resize': {
                    '_target_': 'torchvision.transforms.Resize',
                    'size': [144, 256]
                },
                'CenterCrop': {
                    '_target_': 'torchvision.transforms.CenterCrop',
                    'size': 144
                }
            },
            'interpolation': {}
        },
        'responses': {
            'sampling_rate': None,
            'chunk_size': None,
            # Responses are offset by 0.1 relative to the screen
            # (units presumably seconds — confirm).
            'offset': 0.1,
            'transforms': {
                'standardize': True
            },
            'interpolation': {
                'interpolation_mode': 'nearest_neighbor'
            }
        },
    },
)



In [4]:
# Validation dataset built from experiment 951980473, restricted to tier 'val'.
# Per-modality `sampling_rate` / `chunk_size` are None while the global values
# are passed at the top level — presumably each modality then falls back to
# the global setting (confirm against the experanto documentation).
val_dataset = ChunkDataset(
    root_folder=f'{root_folder}/experiment_951980473',
    global_sampling_rate=sampling_rate,
    global_chunk_size=chunk_size,
    modality_config={
        'screen': {
            'sampling_rate': None,
            'chunk_size': None,
            'valid_condition': {
                'tier': 'val',
                # The original literal listed 'stim_type' twice
                # ('stimulus.Frame', then 'stimulus.Clip'). A dict literal
                # keeps only the last value for a repeated key, so only clips
                # were ever selected. The dead 'stimulus.Frame' entry is
                # removed; the effective condition is unchanged.
                # TODO(review): to really include both images and videos, use
                # whatever multi-value form experanto supports for
                # `valid_condition`.
                'stim_type': 'stimulus.Clip',
            },
            'offset': 0,
            'sample_stride': 4,
            # Required for the Allen data: a blank follows every stimulus, and
            # without including blanks no valid sample times are found.
            'include_blanks': True,
            # NOTE(review): torchvision's ToTensor rescales uint8 images to
            # [0, 1]; if the frames are uint8, mean=80 / std=60 would be
            # applied to already rescaled values — confirm the frame dtype.
            'transforms': {
                'ToTensor': {
                    '_target_': 'torchvision.transforms.ToTensor'
                },
                'Normalize': {
                    '_target_': 'torchvision.transforms.Normalize',
                    'mean': 80.0,
                    'std': 60.0
                },
                'Resize': {
                    '_target_': 'torchvision.transforms.Resize',
                    'size': [144, 256]
                },
                'CenterCrop': {
                    '_target_': 'torchvision.transforms.CenterCrop',
                    'size': 144
                }
            },
            'interpolation': {}
        },
        'responses': {
            'sampling_rate': None,
            'chunk_size': None,
            # Responses are offset by 0.1 relative to the screen
            # (units presumably seconds — confirm).
            'offset': 0.1,
            'transforms': {
                'standardize': True
            },
            'interpolation': {
                'interpolation_mode': 'nearest_neighbor'
            }
        },
    },
)

In [5]:
# Loader bundle in the nested layout {tier: {session_name: DataLoader}}.
batch_size = 15
m = 'allen_data'

data_loaders = OrderedDict(
    [
        # Training batches are drawn in shuffled order.
        ('train', OrderedDict(
            [(m, DataLoader(train_dataset, batch_size=batch_size, shuffle=True))]
        )),
        # The 'oracle' tier is fed by the validation dataset and keeps its order.
        ('oracle', OrderedDict(
            [(m, DataLoader(val_dataset, batch_size=batch_size, shuffle=False))]
        )),
    ]
)

In [6]:
# Extend the import path before the sensorium imports below — presumably this
# directory contains the `sensorium` package (confirm).
# NOTE(review): the absolute path ties the notebook to one machine/container layout.
sys.path.append('/src/sensorium_2023/')
import torch
# NOTE(review): mouse_video_loader and get_correlations are not referenced in
# the cells visible here.
from sensorium.datasets.mouse_video_loaders import mouse_video_loader
from sensorium.utility.scores import get_correlations
from nnfabrik.builder import get_trainer
from sensorium.models.make_model import make_video_model
from nnfabrik.utility.nn_helpers import set_random_seed
# Seed used for set_random_seed here and passed to make_video_model later.
seed = 42
set_random_seed(seed)

In [7]:
# Hyper-parameters for the factorised 3D convolutional core
# (handed to make_video_model as `core_dict`).
factorised_3D_core_dict = {
    # Three input channels. TODO(review): the original note asks whether this
    # lets both RGB and greyscale input be used and whether it needs adapting.
    'input_channels': 3,
    'hidden_channels': [32, 64, 128],
    # Kernel sizes: first layer, then hidden layers.
    'spatial_input_kernel': (11, 11),
    'temporal_input_kernel': 11,
    'spatial_hidden_kernel': (5, 5),
    'temporal_hidden_kernel': 5,
    'stride': 1,
    'layers': 3,
    # Regularisation strengths for the input-layer weights.
    'gamma_input_spatial': 10,
    'gamma_input_temporal': 0.01,
    'bias': True,
    'hidden_nonlinearities': 'elu',
    'x_shift': 0,
    'y_shift': 0,
    'batch_norm': True,
    'laplace_padding': None,
    'input_regularizer': 'LaplaceL2norm',
    'padding': False,
    'final_nonlin': True,
    'momentum': 0.7,
}


# No shifter configuration is supplied.
shifter_dict = None


# Settings for the readout (handed to make_video_model as `readout_dict`).
readout_dict = {
    'bias': True,
    'init_mu_range': 0.2,
    'init_sigma': 1.0,
    'gamma_readout': 0.0,
    'gauss_type': 'full',
    # No grid-mean predictor is configured. The alternative kept from the
    # original cell was:
    #   {'type': 'cortex', 'input_dimensions': 2, 'hidden_layers': 1,
    #    'hidden_features': 30, 'final_tanh': True}
    'grid_mean_predictor': None,
    'share_features': False,
    'share_grid': False,
    'shared_match_ids': None,
    'gamma_grid_dispersion': 0.0,
}

In [8]:
# Keyword options for make_video_model, gathered in one place.
model_options = dict(
    core_dict=factorised_3D_core_dict,
    core_type='3D_factorised',
    # A shallow copy is handed over instead of the notebook-level dict itself.
    readout_dict=readout_dict.copy(),
    readout_type='gaussian',
    # GRU and shifter are both switched off.
    use_gru=False,
    gru_dict=None,
    use_shifter=False,
    shifter_dict=shifter_dict,
    shifter_type='MLP',
    deeplake_ds=False,
)

# Build the model from the loader bundle and the seed set earlier.
factorised_3d_model = make_video_model(data_loaders, seed, **model_options)



In [9]:
# Show the model's repr (its module tree) as the cell output.
factorised_3d_model

VideoFiringRateEncoder(
  (core): Factorized3dCore(
    (_input_weight_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (temporal_regularizer): DepthLaplaceL21d(
      (laplace): Laplace1d()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv_spatial): Conv3d(3, 32, kernel_size=(1, 11, 11), stride=(1, 1, 1))
        (conv_temporal): Conv3d(32, 32, kernel_size=(11, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(32, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer1): Sequential(
        (conv_spatial_1): Conv3d(32, 64, kernel_size=(1, 5, 5), stride=(1, 1, 1))
        (conv_temporal_1): Conv3d(64, 64, kernel_size=(5, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(64, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer2): Sequential(
        (conv_spatial_2): Conv3d(64, 128, kernel_size=(1, 5, 5), stride=

In [10]:
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [13]:
# Dotted path of the training function that get_trainer resolves.
trainer_fn = "sensorium.training.video_training_loop.standard_trainer"

# Settings for a minimal smoke-test run.
trainer_config = dict(
    dataloaders=data_loaders,          # loader bundle built earlier in the notebook
    seed=111,                          # NOTE(review): differs from the notebook-level seed of 42
    use_wandb=False,
    verbose=True,
    lr_decay_steps=1,
    lr_init=0.005,
    device=device,                     # device selected earlier (CUDA if available, else CPU)
    detach_core=False,
    deeplake_ds=False,
    checkpoint_save_path='/tmp/',
    max_iter=1,                        # a single iteration, for a quick test
    batch_size=1,                      # NOTE(review): the loaders were built with batch_size=15; confirm how the trainer uses this value
)

trainer = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config)

In [14]:
# Run the configured trainer on the model, then unpack its three results.
training_result = trainer(factorised_3d_model)
validation_score, trainer_output, state_dict = training_result

optim_step_count = 1


KeyboardInterrupt: 

In [None]:
# Leftover post-mortem debugging magic, disabled so that unattended runs do not
# open an interactive debugger prompt. Uncomment to inspect the most recent
# traceback interactively.
# %debug