In [1]:
import copy
import sys
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from experanto.datasets import ChunkDataset

In [2]:
from omegaconf import OmegaConf, open_dict
from experanto.configs import DEFAULT_CONFIG

# Work on a private copy. Assigning into DEFAULT_CONFIG directly would mutate the
# package-level default for every other user of it in this kernel, and re-running
# the notebook would silently start from the already-modified object.
cfg = copy.deepcopy(DEFAULT_CONFIG)

# Target frame size, used both by the Resize transform and by interpolation
# rescaling so the two stay in sync.
IMAGE_SIZE = [144, 144]

cfg.dataset.modality_config.screen.transforms.Resize.size = list(IMAGE_SIZE)
cfg.dataset.modality_config.screen.interpolation.rescale_size = list(IMAGE_SIZE)
cfg.dataset.modality_config.screen.transforms.greyscale = True
modality_cfg = cfg.dataset.modality_config

# Keep only the 'screen' and 'responses' modalities for single-session loading;
# add further entries from modality_cfg here if other modalities are needed.
selected_modalities = OmegaConf.create({
    'screen': modality_cfg.screen,
    'responses': modality_cfg.responses,
})

root_folder = '../data/allen_data'
sampling_rate = 60
# Video data is loaded as chunks of consecutive frames rather than single images,
# so temporal dynamics are preserved within each sample.
chunk_size = 60

In [3]:
# Training split: ChunkDataset returns chunks of `chunk_size` samples, resampled to
# `sampling_rate`, restricted to the modalities chosen in `selected_modalities`.
train_dataset = ChunkDataset(
    root_folder=f'{root_folder}/experiment_951980471_train',
    global_sampling_rate=sampling_rate,
    global_chunk_size=chunk_size,
    modality_config=selected_modalities,
)

No metadata file found at ../data/allen_data/experiment_951980471_train/meta.json




In [4]:
# Validation split: same sampling rate, chunk size and modality selection as the
# training split, read from the validation experiment folder.
val_dataset = ChunkDataset(
    root_folder=f'{root_folder}/experiment_951980473_val',
    global_sampling_rate=sampling_rate,
    global_chunk_size=chunk_size,
    modality_config=selected_modalities,
)

No metadata file found at ../data/allen_data/experiment_951980473_val/meta.json


In [5]:
# Index the dataset directly to see one interpolated chunk: a mapping from
# modality name to the tensor produced for that modality.
sample = train_dataset[100]

print(sample.keys())
for modality_name, modality_tensor in sample.items():
    print(f'This is shape {modality_tensor.shape} for modality {modality_name}')

dict_keys(['screen', 'responses'])
This is shape torch.Size([1, 60, 144, 144]) for modality screen
This is shape torch.Size([60, 12]) for modality responses


In [6]:
# Wrap each dataset in a DataLoader. The training loader shuffles chunk order;
# seeding its generator makes that order reproducible across kernel restarts
# (an unseeded shuffle yields different batches on every run). Validation keeps
# a fixed, unshuffled order.
batch_size = 50
seed = 42
data_loaders = OrderedDict()

data_loaders['train'] = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)
data_loaders['val'] = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
# Display the split-name -> DataLoader mapping built in the previous cell.
data_loaders

OrderedDict([('train',
              <torch.utils.data.dataloader.DataLoader at 0x78f30686cfa0>),
             ('val',
              <torch.utils.data.dataloader.DataLoader at 0x78f30686ca60>)])

In [8]:
# Pull the first batch from the training loader to inspect batched shapes.
# Each batch is a dict keyed by modality ('screen', 'responses' — there is no
# 'timestamps' key); the loader stacks samples along a new leading batch axis.
batch_data = next(iter(data_loaders['train']))
screen_data = batch_data['screen']
responses = batch_data['responses']

print("Batch 0:")
print("Screen Data:", screen_data.shape)
print("Responses:", responses.shape)

Batch 0:
Screen Data: torch.Size([50, 1, 60, 144, 144])
Responses: torch.Size([50, 60, 12])
