In [22]:
import sys
from experanto.datasets import ChunkDataset
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import OrderedDict

In [23]:
# Base data directory; the per-experiment sub-folder is appended to it when the datasets are built.
root_folder = '../data/allen_data'
sampling_rate = 8  # passed to ChunkDataset as global_sampling_rate (unit not shown here; presumably Hz -- TODO confirm)
# Passed to ChunkDataset as global_chunk_size.
chunk_size = 32 # since we also use video data, we always use chunks of images so that temporal developments are captured as well

In [24]:
# Sample modality config for the training set, covering the screen stimuli and the responses.
#
# NOTE: the 'valid_condition' dict used to list 'stim_type' twice ('stimulus.Frame' and
# 'stimulus.Clip'). A Python dict literal keeps only the LAST value of a repeated key, so
# 'stimulus.Frame' was silently discarded and only 'stimulus.Clip' ever reached ChunkDataset.
# The dead entry is removed here; the effective configuration is unchanged.
# TODO(review): to really include both images and videos, check how experanto's
# 'valid_condition' expresses alternatives -- a plain dict cannot hold the same key twice.
train_modality_config = {
    'screen': {
        # None for both entries -- presumably falls back to the global values passed below (TODO confirm)
        'sampling_rate': None,
        'chunk_size': None,
        'valid_condition': {
            'tier': 'train',
            'stim_type': 'stimulus.Clip',  # the only stim_type that was effectively selected before
        },
        'offset': 0,
        'sample_stride': 4,
        # needed for the Allen dataset: every stimulus is followed by blanks,
        # and without including them no valid times are found
        'include_blanks': True,
        # insertion order of the entries is kept exactly as before
        'transforms': {
            'Normalize': {
                '_target_': 'torchvision.transforms.Normalize',
                'mean': 80.0,
                'std': 60.0,
            },
            'Resize': {
                '_target_': 'torchvision.transforms.Resize',
                'size': [144, 256],
            },
            'CenterCrop': {
                '_target_': 'torchvision.transforms.CenterCrop',
                'size': 144,
            },
            'greyscale': True,  # add this for greyscale data
        },
        'interpolation': {},
    },
    'responses': {
        'sampling_rate': None,
        'chunk_size': None,
        'offset': 0.1,
        'transforms': {
            'standardize': True,
        },
        'interpolation': {
            'interpolation_mode': 'nearest_neighbor',
        },
    },
}

train_dataset = ChunkDataset(
    root_folder=f'{root_folder}/experiment_951980471',
    global_sampling_rate=sampling_rate,
    global_chunk_size=chunk_size,
    modality_config=train_modality_config,
)

In [25]:
# Modality config for the validation set, covering the screen stimuli and the responses.
#
# NOTE: the 'valid_condition' dict used to list 'stim_type' twice ('stimulus.Frame' and
# 'stimulus.Clip'). A Python dict literal keeps only the LAST value of a repeated key, so
# 'stimulus.Frame' was silently discarded and only 'stimulus.Clip' ever reached ChunkDataset.
# The dead entry is removed here; the effective configuration is unchanged.
# TODO(review): to really include both images and videos, check how experanto's
# 'valid_condition' expresses alternatives -- a plain dict cannot hold the same key twice.
val_modality_config = {
    'screen': {
        # None for both entries -- presumably falls back to the global values passed below (TODO confirm)
        'sampling_rate': None,
        'chunk_size': None,
        'valid_condition': {
            'tier': 'val',
            'stim_type': 'stimulus.Clip',  # the only stim_type that was effectively selected before
        },
        'offset': 0,
        'sample_stride': 4,
        # needed for the Allen dataset: every stimulus is followed by blanks,
        # and without including them no valid times are found
        'include_blanks': True,
        # insertion order of the entries is kept exactly as before
        'transforms': {
            # NOTE(review): the training config has no 'ToTensor' step -- confirm that
            # this difference between the two splits is intended.
            'ToTensor': {
                '_target_': 'torchvision.transforms.ToTensor',
            },
            'Normalize': {
                '_target_': 'torchvision.transforms.Normalize',
                'mean': 80.0,
                'std': 60.0,
            },
            'Resize': {
                '_target_': 'torchvision.transforms.Resize',
                'size': [144, 256],
            },
            'CenterCrop': {
                '_target_': 'torchvision.transforms.CenterCrop',
                'size': 144,
            },
            'greyscale': True,  # add this for greyscale data
        },
        'interpolation': {},
    },
    'responses': {
        'sampling_rate': None,
        'chunk_size': None,
        'offset': 0.1,
        'transforms': {
            'standardize': True,
        },
        'interpolation': {
            'interpolation_mode': 'nearest_neighbor',
        },
    },
}

val_dataset = ChunkDataset(
    root_folder=f'{root_folder}/experiment_951980473',
    global_sampling_rate=sampling_rate,
    global_chunk_size=chunk_size,
    modality_config=val_modality_config,
)

In [26]:
# Interpolation showcase on the dataset object: fetch the first sample
# and report the shape of every modality it contains.
sample = train_dataset[0]

print(sample.keys())
for key, value in sample.items():
    print(f'This is shape {value.shape} for modality {key}')

dict_keys(['screen', 'responses', 'timestamps'])
This is shape torch.Size([1, 32, 144, 144]) for modality screen
This is shape torch.Size([32, 12]) for modality responses
This is shape torch.Size([32, 1]) for modality timestamps


In [27]:
# Wrap both dataset objects in DataLoaders, keyed by split name.
batch_size = 15  # number of samples per batch, shared by both loaders

# Only the training loader shuffles; the validation loader keeps a fixed order.
data_loaders = OrderedDict([
    ('train', DataLoader(train_dataset, batch_size=batch_size, shuffle=True)),
    ('val', DataLoader(val_dataset, batch_size=batch_size, shuffle=False)),
])

In [28]:
# Bare last expression: the notebook displays the mapping of split name -> DataLoader.
data_loaders

OrderedDict([('train',
              <torch.utils.data.dataloader.DataLoader at 0x7b01452a48b0>),
             ('val',
              <torch.utils.data.dataloader.DataLoader at 0x7b01452a4a00>)])

In [29]:
# Interpolation showcase on the data_loaders: pull the first training batch,
# print the shape of each modality, then stop.
for batch_idx, batch_data in enumerate(data_loaders['train']):
    # each batch is a dictionary holding 'screen', 'responses' and 'timestamps'
    screen_data, responses, timestamps = (
        batch_data[name] for name in ('screen', 'responses', 'timestamps')
    )

    print(f"Batch {batch_idx}:")
    for label, tensor in (
        ("Screen Data:", screen_data),
        ("Responses:", responses),
        ("Timestamps:", timestamps),
    ):
        print(label, tensor.shape)
    break  # one batch is enough for the showcase

Batch 0:
Screen Data: torch.Size([15, 1, 32, 144, 144])
Responses: torch.Size([15, 32, 12])
Timestamps: torch.Size([15, 32, 1])
