### Imports

In [1]:
import os
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

### Instantiate the DataLoaders (original and hashed archives)

In [2]:
dataset_fn = 'sensorium.datasets.static_loaders'

dataset_name = '26872-17-20'

# Both archives share this path stem and differ only in the "_original" /
# "_hashed" suffix; per the trial_idx outputs below, the hashed archive exposes
# trial IDs as hashed strings instead of integers.
data_stem = './data/static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6'

# Loader options shared by both variants. Keeping a single copy guarantees the
# two dataloaders are built with identical options and cannot silently diverge.
base_dataset_config = {'normalize': True,
                       'include_behavior': False,
                       'include_eye_position': False,
                       'batch_size': 128,
                       'scale': .25,
                       }

# A fresh dict is built per call, so get_data never sees the shared base dict.
filenames = [f'{data_stem}_original.zip']
dataset_config = {'paths': filenames, **base_dataset_config}
dataloaders_original = get_data(dataset_fn, dataset_config)

filenames = [f'{data_stem}_hashed.zip']
dataset_config = {'paths': filenames, **base_dataset_config}
dataloaders_hashed = get_data(dataset_fn, dataset_config)

In [3]:
# Inspect the nested structure: tier -> dataset key -> DataLoader.
dataloaders_original

OrderedDict([('train',
              OrderedDict([('26872-17-20',
                            <torch.utils.data.dataloader.DataLoader at 0x7f439fcd2070>)])),
             ('validation',
              OrderedDict([('26872-17-20',
                            <torch.utils.data.dataloader.DataLoader at 0x7f439fcd2d30>)])),
             ('test',
              OrderedDict([('26872-17-20',
                            <torch.utils.data.dataloader.DataLoader at 0x7f439fcd2ca0>)])),
             ('final_test',
              OrderedDict([('26872-17-20',
                            <torch.utils.data.dataloader.DataLoader at 0x7f439fcd2cd0>)]))])

In [4]:
# Same tier -> dataset key -> DataLoader structure for the hashed archive.
dataloaders_hashed

OrderedDict([('train',
              OrderedDict([('26872-17-20',
                            <torch.utils.data.dataloader.DataLoader at 0x7f439fcd2910>)])),
             ('validation',
              OrderedDict([('26872-17-20',
                            <torch.utils.data.dataloader.DataLoader at 0x7f439fcda3d0>)])),
             ('test',
              OrderedDict([('26872-17-20',
                            <torch.utils.data.dataloader.DataLoader at 0x7f439fcda310>)])),
             ('final_test',
              OrderedDict([('26872-17-20',
                            <torch.utils.data.dataloader.DataLoader at 0x7f439fcda670>)]))])

In [5]:
# Hashed archive: final_test trial IDs are 32-character strings (dtype '<U32').
dataloaders_hashed["final_test"][dataset_name].dataset.trial_info.trial_idx

array(['2dffbc474aa176b6dc957938c15d0c8b',
       '31c23973a376c90940f5f5ff2118b5d2',
       'd02e9bdc27a894e882fa0c9055c99722', ...,
       'f33ba15effa5c10e873bf3842afb46a6',
       '7f24d240521d99071c93af3917215ef7',
       '3ce3bd7d63a2c9c81983cc8e9bd02ae5'], dtype='<U32')

In [6]:
# Original archive: final_test trial IDs are plain integers.
dataloaders_original["final_test"][dataset_name].dataset.trial_info.trial_idx

array([1449, 6999, 5563, ..., 1012,  551, 4400])

## Instantiate the model (SOTA architecture, configured here as a linear LN baseline)

In [16]:
# Stacked CNN core + full-Gaussian readout. Setting linear=True removes all
# nonlinearities from the CNN, which creates essentially an LN model: linear
# core + readout, with a subsequent nonlinearity.
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=3,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    # Sub-config nested under the 'grid_mean_predictor' key.
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    depth_separable=True,
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    linear=True,
)
# NOTE(review): the model is built from the hashed dataloaders; presumably only
# data shapes / neuron metadata are read here, so either variant works — confirm.
model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders_hashed,
    seed=42,
)

In [21]:
# Load the trained weights. map_location="cpu" lets a checkpoint that was saved
# from GPU tensors load on any machine without allocating GPU memory;
# load_state_dict copies the values into the model's existing parameters, so
# the model's own device is unchanged.
model.load_state_dict(torch.load("./checkpoints/sensorium_ln_model.pth", map_location="cpu"));
model.eval();

---

# A) Does hashing the trial IDs change the scores?

## Generate submission file and get the score

In [22]:
# Import the API from the competition repo: `evaluate` scores a submission CSV
# against a ground-truth CSV; `submission` writes ground-truth and submission CSVs.
from sensorium import evaluate
from sensorium.utility import submission

### 1. Original file

In [23]:
# Write the ground-truth CSVs for the *original* archive into ./sensorium_sota/.
# This is a directory, so it gets its own name instead of reusing
# `ground_truth_file`, which later cells bind to a CSV path.
# NOTE(review): section 2 writes into the same directory and likely overwrites
# these files — re-run this cell before re-scoring the original submission.
ground_truth_filename = ['./data/static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6_original']
ground_truth_dir = './sensorium_sota/'
submission.generate_ground_truth_file(filename=ground_truth_filename,
                                      path=ground_truth_dir)

In [24]:
# Predict with the loaded model on the original-ID dataloaders and write one
# submission CSV per tier (test and final_test, per the printed output).
# NOTE: the hashed run in section 2 prints the same CSV paths, so it overwrites these.
submission.generate_submission_file(
    trained_model=model,
    dataloaders=dataloaders_original,
    data_key=dataset_name,
    path="./sensorium_sota/",
    device="cuda",
)

Submission file saved for tier: test. Saved in: ./sensorium_sota/submission_file_test.csv
Submission file saved for tier: final_test. Saved in: ./sensorium_sota/submission_file_final_test.csv


In [25]:
# final_test CSVs to score. The submission path matches the one printed above;
# the ground-truth file is presumably the one generate_ground_truth_file wrote
# into ./sensorium_sota/ — confirm.
ground_truth_file = './sensorium_sota/ground_truth_file_final_test.csv'
submission_file = './sensorium_sota/submission_file_final_test.csv'

In [26]:
# Keep the original-ID submission in memory: the hashed run later writes the
# same submission filename and overwrites this CSV on disk.
# (pandas is already imported as `pd` in the first cell.)
df_original = pd.read_csv(submission_file)

In [27]:
# Score the final_test submission against the ground truth; the result maps
# metric name -> value (printed in the next cell).
out = evaluate(submission_file, ground_truth_file)

In [28]:
# Report each metric rounded to three decimals.
print("Results for the SOTA model:")
for metric in out:
    value = out[metric]
    print(f"{metric}: {np.round(value, 3)}")

Results for the SOTA model:
Single Trial Correlation: 0.207
Average Correlation: 0.377
FEVE: 0.125


### 2. Hashed file

In [29]:
# Write the ground-truth CSVs for the *hashed* archive. The target directory is
# the same ./sensorium_sota/ used for the original archive, so files with the
# same names are likely overwritten. The directory gets its own name instead of
# reusing `ground_truth_file`, which the following cells bind to a CSV path.
ground_truth_filename = ['./data/static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6_hashed']
ground_truth_dir = './sensorium_sota/'
submission.generate_ground_truth_file(filename=ground_truth_filename,
                                      path=ground_truth_dir)

In [30]:
# Same prediction run on the hashed-ID dataloaders. The printed paths are
# identical to the original run's, so those CSVs are overwritten here.
submission.generate_submission_file(
    trained_model=model,
    dataloaders=dataloaders_hashed,
    data_key=dataset_name,
    path="./sensorium_sota/",
    device="cuda",
)

Submission file saved for tier: test. Saved in: ./sensorium_sota/submission_file_test.csv
Submission file saved for tier: final_test. Saved in: ./sensorium_sota/submission_file_final_test.csv


#### a) With the generated GT file

In [31]:
# final_test CSVs for the hashed run. Same paths as in section 1, so these now
# refer to the files written by the hashed-archive cells above.
ground_truth_file = './sensorium_sota/ground_truth_file_final_test.csv'
submission_file = './sensorium_sota/submission_file_final_test.csv'

In [32]:
# Load the hashed-ID submission (pandas is already imported as `pd` in the first cell).
df_hashed = pd.read_csv(submission_file)

In [33]:
# Score the hashed submission against the ground truth generated from the hashed archive.
out = evaluate(submission_file, ground_truth_file)

In [34]:
# Report each metric rounded to three decimals.
print("Results for the SOTA model:")
for metric in out:
    value = out[metric]
    print(f"{metric}: {np.round(value, 3)}")

Results for the SOTA model:
Single Trial Correlation: 0.207
Average Correlation: 0.377
FEVE: 0.125


#### b) With Konsti's GT file

In [35]:
# Swap in the separately provided ground-truth file ("Konsti's GT file");
# presumably produced outside this notebook — confirm. The submission CSV is unchanged.
ground_truth_file = "./sensorium_sota/sensorium_ground_truth_file_final_test.csv"
submission_file = './sensorium_sota/submission_file_final_test.csv'

In [36]:
# Same submission CSV as in the previous subsection (only the ground-truth file
# differs), so this re-read yields the same frame as before.
# (pandas is already imported as `pd` in the first cell.)
df_hashed = pd.read_csv(submission_file)

In [37]:
# Score the hashed submission against Konsti's ground-truth file.
out = evaluate(submission_file, ground_truth_file)

In [38]:
# Report each metric rounded to three decimals.
print("Results for the SOTA model:")
for metric in out:
    value = out[metric]
    print(f"{metric}: {np.round(value, 3)}")

Results for the SOTA model:
Single Trial Correlation: 0.207
Average Correlation: 0.377
FEVE: 0.125


---

# B) Are the metrics consistent when computed from the CSV files vs. with the scoring functions directly?

In [39]:
from sensorium.utility import get_correlations, get_signal_correlations, get_fev

In [40]:
# Score the model directly on the original-ID final_test data, for comparison
# with the CSV-based scores above. Note the call shapes differ:
# get_correlations receives the final_test sub-dict, while the other two
# receive the full dataloader dict plus tier="final_test".
correlations = get_correlations(model, dataloaders_original["final_test"], device="cuda")
signal_correlations = get_signal_correlations(
    model, dataloaders_original, tier="final_test", device="cuda"
)
fevs = get_fev(
    model, dataloaders_original, tier="final_test", device="cuda"
)

In [41]:
# Reduce each metric result to its mean, keyed by a display label.
metrics = dict([
    ("Correlation (single trial)", correlations.mean()),
    ("Correlation (mean)", signal_correlations.mean()),
    ("FEVE", fevs.mean()),
])

In [42]:
# Report each directly computed metric with three decimals.
print("Results for the model:")
for metric in metrics:
    value = metrics[metric]
    print(f"{metric}: {value:.3f}")

Results for the model:
Correlation (single trial): 0.207
Correlation (mean): 0.377
FEVE: 0.124
