# Demo notebook for datasets

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2 

In [None]:
import numpy as np
import pickle 
import pandas as pd
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 10)
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

from neuralpredictors.data.datasets import FileTreeDataset, StaticImageSet

There are two major dataset classes, the `FileTreeDataset` and the `StaticImageSet`. However, the `StaticImageSet` is hardly used in the lab any more, so this tutorial introduces the `FileTreeDataset` only.  
First, instantiate your dataset object with the path to the data folder and all the parts of the dataset that you want to include (currently possible: images, responses, behavior and pupil center).

In [None]:
path = 'data/static22085-2-20-preproc0'
dataset = FileTreeDataset(path, 'images', 'responses', 'behavior', 'pupil_center')

In [None]:
print('Number of neurons: {}'.format(dataset.n_neurons))
print('Image shape: {}'.format(dataset.img_shape))
print('Data keys: {}'.format(dataset.data_keys))

___

Tab completion currently does not work with this dataset object. Use the `dir()` function to find your way around the dataset object:

In [None]:
dir(dataset)
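
The output of `dir()` also lists private and dunder attributes, which makes it long. The following is a minimal sketch for narrowing it down to the public names. `public_attributes` and the stand-in class `Demo` are hypothetical and not part of neuralpredictors; the same helper should work when called on `dataset`.

```python
def public_attributes(obj):
    """Return the names from dir(obj) that do not start with an underscore."""
    return [name for name in dir(obj) if not name.startswith('_')]


class Demo:
    """Hypothetical stand-in object, used only to illustrate the helper."""
    n_neurons = 3
    _cache = None

    def data_keys(self):
        return ['images', 'responses']


print(public_attributes(Demo()))
```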

## Look at first datapoint

In [None]:
first_data_point = dataset[0]

In [None]:
first_data_point

In [None]:
image, response, behavior, pupil_center = first_data_point

### Image

In [None]:
fig, ax = plt.subplots(1, 1, dpi=100)
ax.imshow(image.squeeze(), cmap='gray') # Squeeze to get rid of the singleton (channel) dimension
ax.set_title('First image in dataset');

### Response

In [None]:
fig, ax = plt.subplots(1, 1, dpi=100)
ax.plot(response, color='navy')
ax.set(title='Responses to first image', xlabel='neuron index (NOT ID)', ylabel='response')
sns.despine(trim=True)

### Behavior and pupil center

The behavior data consists of 3 behavioral variables (in this order): 
- **pupil dilation**  
- **derivative of pupil dilation**  
- **running speed**

In [None]:
print('Behavioral variables during first image-response pair: \n\n' + 
      'pupil dilation: {}\n'.format(behavior[0]) + 
      'derivative of pupil dilation: {}\n'.format(behavior[1]) + 
      'running speed: {}'.format(behavior[2]))

___

The pupil center is not a "behavioral" variable and is therefore stored under a separate data key:

In [None]:
print('Pupil center: {}'.format(pupil_center))

### Access info about neurons

The info about neurons can be found in `dataset.neurons`. Every neuron has its unique ID (`dataset.neurons.unit_ids`) which should be used to refer to it. The IDs are NOT necessarily consecutive because certain neurons might have been discarded during data-preprocessing.

In [None]:
print(f'number of neurons: {dataset.n_neurons} \nhighest unit_id: {dataset.neurons.unit_ids.max()}')
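
Because the IDs have gaps, a unit ID cannot be used directly as a position in the response vector. The sketch below shows one way to translate IDs into positions. It uses made-up arrays: `unit_ids` stands in for `dataset.neurons.unit_ids`, and the response values are invented. It assumes that responses are ordered like the unit IDs.

```python
import numpy as np

# Hypothetical unit IDs with gaps, standing in for dataset.neurons.unit_ids
unit_ids = np.array([1, 2, 5, 8, 9])

# Made-up responses to one image, assumed to be ordered like unit_ids
response = np.array([0.1, 0.4, 0.0, 2.3, 0.7])

# Look-up table from unit ID to position in the response vector
id_to_index = {int(uid): idx for idx, uid in enumerate(unit_ids)}

# Select the responses of two neurons by their IDs
wanted_ids = [5, 9]
positions = [id_to_index[uid] for uid in wanted_ids]
selected = response[positions]

print(positions)
print(selected)
```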

___

The dataset holds various information about the neurons (run `dir(dataset.neurons)` for more), for example the brain areas in which the neurons are located:

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 5), dpi=150)
ax.scatter(dataset.neurons.unit_ids, dataset.neurons.area, color='navy', marker='|')
ax.set(title='Brain areas of neurons', xlabel='neuron ID', ylabel='area')
sns.despine(trim=True)

### Access info about trials of images

The info about trials can be found in `dataset.trial_info`. Every trial has its unique index (`dataset.trial_info.trial_idx`) which should be used to refer to it. The indices are NOT necessarily consecutive because certain trials might have been discarded during data-preprocessing.

In [None]:
print(f'number of trials: {len(dataset.trial_info.trial_idx)} \nhighest trial index: {dataset.trial_info.trial_idx.max()}')

___

The dataset holds various information about the trials (run `dir(dataset.trial_info)` for more), for example the tier (train, validation, test) that each trial should be used for:

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 5), dpi=150)
ax.scatter(dataset.trial_info.trial_idx, dataset.trial_info.tiers, color='navy', marker='|')
ax.set(title='Tiers of trial indices', xlabel='trial index', ylabel='tier')
sns.despine(trim=True)

Note that in the **test tier**, each image is repeated for 10 trials (sometimes fewer, if a trial had to be removed during preprocessing):

In [None]:
from collections import Counter

test_image_ids = dataset.trial_info.frame_image_id[dataset.trial_info.tiers == 'test']
counter = Counter(test_image_ids)
print('image_id | count')
print('----------------')
for key, value in counter.items():
    print(str(key) + '     |   ' + str(value))
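
Repeated presentations are commonly averaged per image. The following sketch shows the grouping step on made-up arrays: `image_ids` stands in for the test-tier `frame_image_id` values, and `responses` is an invented array of shape (trials, neurons). It assumes that both arrays are in the same trial order.

```python
import numpy as np

# Hypothetical image IDs of six test trials (three images, two repeats each)
image_ids = np.array([10, 11, 10, 12, 11, 12])

# Made-up responses with shape (trials, neurons)
responses = np.array([[1.0, 2.0],
                      [3.0, 4.0],
                      [3.0, 4.0],
                      [5.0, 6.0],
                      [5.0, 6.0],
                      [7.0, 8.0]])

# Average over all trials that showed the same image
unique_ids = np.unique(image_ids)
mean_responses = np.stack(
    [responses[image_ids == image_id].mean(axis=0) for image_id in unique_ids]
)

print(unique_ids)
print(mean_responses)
```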

# Transforms

You can apply transforms to your data, for example to normalize the responses or to select only specific neurons.

In [None]:
from neuralpredictors.data.transforms import NeuroNormalizer, Subsample

neurons_in_V1_indices = np.where(np.isin(dataset.neurons.area, 'V1'))[0]

# Note that the order of transforms is important! 
transforms = [NeuroNormalizer(dataset, exclude=['behavior', 'pupil_center']), Subsample(neurons_in_V1_indices)]

Look at the mean of the first image and the subselection of neurons with and without transforms. The normalization is done across all images, not on each image individually.

In [None]:
dataset.transforms = []
image, response, behavior, pupil_center = dataset[0]
print('Without transforms:')
print('Mean of first image: {}'.format(np.mean(image)))
print('Number of neurons: {}'.format(len(response)))

In [None]:
dataset.transforms = transforms
image, response, behavior, pupil_center = dataset[0]
print('With transforms:')
print('Mean of first image: {}'.format(np.mean(image)))
print('Number of neurons: {}'.format(len(response)))
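
Why the order of transforms matters can be illustrated with a toy example. This is not the `NeuroNormalizer` implementation. It is a sketch that assumes the normalizer holds one statistic per neuron, computed across all images of the full dataset. All arrays and the name `keep` are made up.

```python
import numpy as np

# Made-up responses with shape (images, neurons)
responses = np.array([[1.0, 10.0, 100.0],
                      [3.0, 30.0, 300.0]])

# One statistic per neuron, computed across all images (axis 0)
per_neuron_std = responses.std(axis=0)

keep = np.array([0, 2])  # hypothetical indices of the neurons to keep

# Normalize first, then subsample: the shapes match
normalized = responses[0] / per_neuron_std
normalized_then_subsampled = normalized[keep]

# Subsample first, then normalize: 2 responses meet 3 statistics
try:
    responses[0][keep] / per_neuron_std
    order_matters = False
except ValueError:
    order_matters = True

print(normalized_then_subsampled)
print(order_matters)
```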

# Dataloader for training models

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

batch_size = 64
tier = 'train'

trial_indices = np.where(dataset.trial_info.tiers == tier)[0]
sampler = SubsetRandomSampler(trial_indices)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
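
The same index selection can be done for every tier at once. The sketch below collects the positional indices per tier in a dictionary, using a made-up `tiers` array that stands in for `dataset.trial_info.tiers`. Each index array could then be passed to a sampler as in the cell above. For the validation and test tiers, a fixed iteration order is often preferable to random sampling.

```python
import numpy as np

# Hypothetical tier labels, standing in for dataset.trial_info.tiers
tiers = np.array(['train', 'test', 'train', 'validation', 'test', 'train'])

# Positional indices of the trials belonging to each tier
tier_indices = {str(tier): np.where(tiers == tier)[0] for tier in np.unique(tiers)}

for tier, indices in tier_indices.items():
    print(tier, indices)
```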

For more elaborate dataloaders, see for example https://github.com/sinzlab/Lurz_2020_code/blob/main/lurz2020/datasets/mouse_loaders.py