In [1]:
import sys
from experanto.datasets import ChunkDataset
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import OrderedDict

In [2]:
# Relative path to the directory holding the Allen experiments.
root_folder = '../data/allen_data'

# Forwarded to ChunkDataset as `global_chunk_size`.
chunk_size = 20

# Forwarded to ChunkDataset as `global_sampling_rate`.
# TODO (open question from the original author): timestamps generated from this
# rate do not match the recorded ones -- why resample, and is this only meant
# for the output data?
sampling_rate = 30

In [3]:
# Training dataset built from experiment 951980471.
# Every modality leaves 'sampling_rate' and 'chunk_size' as None; the global
# values are passed separately through global_sampling_rate / global_chunk_size.
train_dataset = ChunkDataset(
    root_folder=f'{root_folder}/experiment_951980471',
    global_sampling_rate=sampling_rate,
    global_chunk_size=chunk_size,
    modality_config={
        'screen': {
            'sampling_rate': None,
            'chunk_size': None,
            'valid_condition': {
                'tier': 'train',
                # The original literal listed 'stim_type' twice ('stimulus.Frame',
                # then 'stimulus.Clip'). A dict literal keeps only the last value
                # for a repeated key, so 'stimulus.Frame' was silently discarded
                # and only this value ever reached ChunkDataset.
                # TODO: the stated intent was to include both images and videos;
                # confirm which multi-value form (if any) ChunkDataset accepts.
                'stim_type': 'stimulus.Clip',
            },
            'offset': 0,
            'sample_stride': 4,
            # Required for the Allen dataset: a blank follows every stimulus, and
            # without including blanks no valid times are found.
            'include_blanks': True,
            'transforms': {
                'ToTensor': {
                    '_target_': 'torchvision.transforms.ToTensor',
                },
                'Normalize': {
                    '_target_': 'torchvision.transforms.Normalize',
                    'mean': 80.0,
                    'std': 60.0,
                },
                'Resize': {
                    '_target_': 'torchvision.transforms.Resize',
                    'size': [144, 256],
                },
                'CenterCrop': {
                    '_target_': 'torchvision.transforms.CenterCrop',
                    'size': 144,
                },
            },
            'interpolation': {},
        },
        'responses': {
            'sampling_rate': None,
            'chunk_size': None,
            # The only modality configured with a non-zero offset.
            'offset': 0.1,
            'transforms': {
                'standardize': True,
            },
            'interpolation': {
                'interpolation_mode': 'nearest_neighbor',
            },
        },
        'eye_tracker': {
            'sampling_rate': None,
            'chunk_size': None,
            'offset': 0,
            'transforms': {
                'normalize': True,
            },
            'interpolation': {
                'interpolation_mode': 'nearest_neighbor',
            },
        },
        'treadmill': {
            'sampling_rate': None,
            'chunk_size': None,
            'offset': 0,
            'transforms': {
                'normalize': True,
            },
            'interpolation': {
                'interpolation_mode': 'nearest_neighbor',
            },
        },
    },
)



In [4]:
# Validation dataset built from experiment 951980473 (tier 'val').
# Every modality leaves 'sampling_rate' and 'chunk_size' as None; the global
# values are passed separately through global_sampling_rate / global_chunk_size.
val_dataset = ChunkDataset(
    root_folder=f'{root_folder}/experiment_951980473',
    global_sampling_rate=sampling_rate,
    global_chunk_size=chunk_size,
    modality_config={
        'screen': {
            'sampling_rate': None,
            'chunk_size': None,
            'valid_condition': {
                'tier': 'val',
                # The original literal listed 'stim_type' twice ('stimulus.Frame',
                # then 'stimulus.Clip'). A dict literal keeps only the last value
                # for a repeated key, so 'stimulus.Frame' was silently discarded
                # and only this value ever reached ChunkDataset.
                # TODO: the stated intent was to include both images and videos;
                # confirm which multi-value form (if any) ChunkDataset accepts.
                'stim_type': 'stimulus.Clip',
            },
            'offset': 0,
            'sample_stride': 4,
            # Required for the Allen dataset: a blank follows every stimulus, and
            # without including blanks no valid times are found.
            'include_blanks': True,
            'transforms': {
                'ToTensor': {
                    '_target_': 'torchvision.transforms.ToTensor',
                },
                'Normalize': {
                    '_target_': 'torchvision.transforms.Normalize',
                    'mean': 80.0,
                    'std': 60.0,
                },
                'Resize': {
                    '_target_': 'torchvision.transforms.Resize',
                    'size': [144, 256],
                },
                'CenterCrop': {
                    '_target_': 'torchvision.transforms.CenterCrop',
                    'size': 144,
                },
            },
            'interpolation': {},
        },
        'responses': {
            'sampling_rate': None,
            'chunk_size': None,
            # The only modality configured with a non-zero offset.
            'offset': 0.1,
            'transforms': {
                'standardize': True,
            },
            'interpolation': {
                'interpolation_mode': 'nearest_neighbor',
            },
        },
        'eye_tracker': {
            'sampling_rate': None,
            'chunk_size': None,
            'offset': 0,
            'transforms': {
                'normalize': True,
            },
            'interpolation': {
                'interpolation_mode': 'nearest_neighbor',
            },
        },
        'treadmill': {
            'sampling_rate': None,
            'chunk_size': None,
            'offset': 0,
            'transforms': {
                'normalize': True,
            },
            'interpolation': {
                'interpolation_mode': 'nearest_neighbor',
            },
        },
    },
)

In [5]:
# Loaders are registered as data_loaders[<split>][<session key>].
batch_size = 1
m = 'allen_data'  # session key shared by both splits

data_loaders = OrderedDict(
    [
        # Training batches are drawn in shuffled order.
        ('train', OrderedDict([
            (m, DataLoader(train_dataset, batch_size=batch_size, shuffle=True)),
        ])),
        # The validation set is served under the 'oracle' split, in fixed order.
        ('oracle', OrderedDict([
            (m, DataLoader(val_dataset, batch_size=batch_size, shuffle=False)),
        ])),
    ]
)

# Draw one training batch so its contents can be inspected in the next cell.
it = next(iter(data_loaders['train'][m]))

In [6]:
# Report the shape of every entry in the sampled training batch `it`.
for key in it:
    entry = it[key]
    print(f"This is shape of {key} : {entry.shape}")

This is shape of eye_tracker : torch.Size([1, 20, 22])
This is shape of screen : torch.Size([1, 20, 1, 144, 144])
This is shape of treadmill : torch.Size([1, 20, 1])
This is shape of responses : torch.Size([1, 20, 12])
This is shape of timestamps : torch.Size([1, 20])


In [10]:
# Make the sensorium_2023 checkout importable.
# NOTE(review): this is an absolute, machine-specific path -- consider making it configurable.
sensorium_src = '/src/sensorium_2023/'
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if sensorium_src not in sys.path:
    sys.path.append(sensorium_src)

import torch
from sensorium.datasets.mouse_video_loaders import mouse_video_loader
from sensorium.utility.scores import get_correlations
from nnfabrik.builder import get_trainer
from sensorium.models.make_model import make_video_model
from nnfabrik.utility.nn_helpers import set_random_seed

# Fixed seed, also passed to make_video_model further down.
# NOTE(review): the shuffled training loader was already created and sampled
# before this point, so that first batch is presumably not covered by the
# seed -- confirm whether seeding should move to the top of the notebook.
seed = 42
set_random_seed(seed)

In [12]:
# Keyword arguments for the '3D_factorised' core handed to make_video_model.
factorised_3D_core_dict = {
    # Increase if behaviour is used.
    # NOTE(review): the screen batch printed earlier in this notebook has shape
    # (1, 20, 1, 144, 144); confirm that 4 input channels is what is intended.
    'input_channels': 4,
    'hidden_channels': [32, 64, 128],
    # Kernel sizes: input layer first, then hidden layers.
    'spatial_input_kernel': (11, 11),
    'temporal_input_kernel': 11,
    'spatial_hidden_kernel': (5, 5),
    'temporal_hidden_kernel': 5,
    'stride': 1,
    'layers': 3,
    # Regularisation settings for the input layer.
    'gamma_input_spatial': 10,
    'gamma_input_temporal': 0.01,
    'bias': True,
    'hidden_nonlinearities': 'elu',
    'x_shift': 0,
    'y_shift': 0,
    'batch_norm': True,
    'laplace_padding': None,
    'input_regularizer': 'LaplaceL2norm',
    'padding': False,
    'final_nonlin': True,
    'momentum': 0.7,
}


# No shifter configuration; the model cell below also passes use_shifter=False.
shifter_dict = None


# Keyword arguments for the 'gaussian' readout handed to make_video_model.
readout_dict = {
    'bias': True,
    'init_mu_range': 0.2,
    'init_sigma': 1.0,
    'gamma_readout': 0.0,
    'gauss_type': 'full',
    # Alternative left by the original author: grid_mean_predictor=None
    'grid_mean_predictor': {
        'type': 'cortex',
        'input_dimensions': 2,
        'hidden_layers': 1,
        'hidden_features': 30,
        'final_tanh': True,
    },
    'share_features': False,
    'share_grid': False,
    'shared_match_ids': None,
    'gamma_grid_dispersion': 0.0,
}

In [13]:
# Build the video model from the loaders and the configuration dicts above.
# NOTE(review): the recorded run of this cell failed with
#   RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d,
#   but got input of size: [1, 20, 22]
# That size equals the 'eye_tracker' batch shape printed earlier, so the builder
# presumably treated the eye-tracker entry as the video input -- TODO confirm
# how make_video_model selects its input/output entries from a batch.
video_model_kwargs = dict(
    # Core and readout
    core_type='3D_factorised',
    core_dict=factorised_3D_core_dict,
    readout_type='gaussian',
    readout_dict=readout_dict.copy(),  # pass a shallow copy, as the original call did
    # Optional modules, both switched off
    use_gru=False,
    gru_dict=None,
    use_shifter=False,
    shifter_type='MLP',
    shifter_dict=shifter_dict,
    deeplake_ds=False,
)
factorised_3d_model = make_video_model(data_loaders, seed, **video_model_kwargs)

RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 20, 22]