# 3rd Place Sensorium Submission

**IdV_ENS** team -> Institut de la Vision (Sorbonne University) and Ecole Normale Superieure, Paris

Simone Azeglio¹ ², Ulisse Ferrari¹, Peter Neri², Olivier Marre¹

¹ Institut de la Vision, Sorbonne University, Paris 

² Ecole Normale Superieure, Paris


### Imports

In [1]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

from sensorium.utility import submission

import os

In [2]:
dataset_name = '26872-17-20'

Install library for Scattering Networks

In [3]:
!pip install kymatio -q

## Model 1 - VOneCNN 256

dataset configuration

In [4]:
basepath = "../data/"

# as filenames, we'll select all 7 datasets
filenames = [os.path.join(basepath, file) for file in os.listdir(basepath) if "static" in file ]

dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = {'paths': filenames,
                 'normalize': True,
                 'scattering': False, 
                 'include_behavior': False,
                 'include_eye_position': False,
                 'batch_size': 128,
                 'scale':1.,
                 }

dataloaders = get_data(dataset_fn, dataset_config)

Instantiate VOne CNN Model

In [5]:
model_fn = 'sensorium.models.my_models.VOneNet_core_full_gauss_readout'

model_config = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': False,
}

model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42)

Neuronal distributions gabor parameters
Model:  VOneNet


### Trainer

In [6]:
trainer_fn = "sensorium.training.standard_trainer"

trainer_config = {'max_iter': 300,
                 'verbose': True,
                  'track_training': True,
                 'lr_decay_steps': 4,
                 'avg_loss': False, # True,
                 'lr_init': 0.009,
                 }

trainer = get_trainer(trainer_fn=trainer_fn, 
                     trainer_config=trainer_config)

### Model training

In [7]:
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

correlation -0.0028102435
poisson_loss 30768218.0


Epoch 1: 100%|██████████| 252/252 [02:18<00:00,  1.82it/s]


correlation 0.095125355
poisson_loss 18595728.0


Epoch 2: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.12528986
poisson_loss 18273534.0


Epoch 3: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.165501
poisson_loss 17873586.0


Epoch 4: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.18419854
poisson_loss 17644584.0


Epoch 5: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.20507738
poisson_loss 17422724.0


Epoch 6: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.2319045
poisson_loss 17101710.0


Epoch 7: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.24631763
poisson_loss 16935878.0


Epoch 8: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.25776964
poisson_loss 16805838.0


Epoch 9: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.2668028
poisson_loss 16703888.0


Epoch 10: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.27713114
poisson_loss 16570633.0


Epoch 11: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.28187335
poisson_loss 16515988.0


Epoch 12: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.28708482
poisson_loss 16465317.0


Epoch 13: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.29126644
poisson_loss 16410776.0


Epoch 14: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.2949069
poisson_loss 16363112.0


Epoch 15: 100%|██████████| 252/252 [01:53<00:00,  2.21it/s]


correlation 0.298493
poisson_loss 16318696.0


Epoch 16: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.30354127
poisson_loss 16250538.0


Epoch 17: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.30468026
poisson_loss 16249510.0


Epoch 18: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.30578336
poisson_loss 16238503.0


Epoch 19: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.30829525
poisson_loss 16202678.0


Epoch 20: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.31237358
poisson_loss 16146324.0


Epoch 21: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.3129938
poisson_loss 16141642.0


Epoch 22: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.31409112
poisson_loss 16130255.0


Epoch 23: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.31575835
poisson_loss 16113662.0


Epoch 24: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.31682444
poisson_loss 16096655.0


Epoch 25: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.31668517
poisson_loss 16102952.0


Epoch 26: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.31755045
poisson_loss 16087669.0


Epoch 27: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3192532
poisson_loss 16073632.0


Epoch 28: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32013738
poisson_loss 16058739.0


Epoch 29: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32044056
poisson_loss 16057000.0


Epoch 30: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32122472
poisson_loss 16053758.0


Epoch 31: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.3207503
poisson_loss 16048904.0


Epoch 32: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32215717
poisson_loss 16031880.0


Epoch 33: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32250175
poisson_loss 16032848.0


Epoch 34: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.32165733
poisson_loss 16039315.0


Epoch 35: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3239366
poisson_loss 16006722.0


Epoch 36: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.32304704
poisson_loss 16024409.0


Epoch 37: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32417846
poisson_loss 16012318.0


Epoch 38: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32389566
poisson_loss 16016049.0


Epoch 39: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.3242978
poisson_loss 16005207.0


Epoch 40: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.32510084
poisson_loss 15998535.0


Epoch 41: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32487085
poisson_loss 16006638.0


Epoch 42: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.3253169
poisson_loss 15996773.0


Epoch 43: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3250339
poisson_loss 16002287.0


Epoch 44: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.325992
poisson_loss 15984578.0


Epoch 45: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3251028
poisson_loss 16005615.0


Epoch 46: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3263185
poisson_loss 15980937.0


Epoch 47: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.32498175
poisson_loss 16001614.0


Epoch 48: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.3265427
poisson_loss 15982713.0


Epoch 49: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.32708755
poisson_loss 15973575.0


Epoch 50: 100%|██████████| 252/252 [01:53<00:00,  2.22it/s]


correlation 0.32742816
poisson_loss 15969666.0


Epoch 51: 100%|██████████| 252/252 [01:52<00:00,  2.23it/s]


correlation 0.32700533
poisson_loss 15978140.0


Epoch 52: 100%|██████████| 252/252 [01:52<00:00,  2.23it/s]


correlation 0.32784617
poisson_loss 15964206.0


Epoch 53: 100%|██████████| 252/252 [01:52<00:00,  2.23it/s]


correlation 0.3266869
poisson_loss 15991417.0


Epoch 54: 100%|██████████| 252/252 [01:52<00:00,  2.23it/s]


correlation 0.32702228
poisson_loss 15977679.0


Epoch 55: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32666644
poisson_loss 15986166.0


Epoch 56: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3276581
poisson_loss 15967320.0


Epoch 57: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3279334
poisson_loss 15963182.0


Epoch 58: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


Epoch    58: reducing learning rate of group 0 to 2.7000e-03.
correlation 0.3277286
poisson_loss 15962851.0


Epoch 59: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3344736
poisson_loss 15885944.0


Epoch 60: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.33509338
poisson_loss 15884366.0


Epoch 61: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33543995
poisson_loss 15876829.0


Epoch 62: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.33518472
poisson_loss 15882553.0


Epoch 63: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.334811
poisson_loss 15880056.0


Epoch 64: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.33484128
poisson_loss 15879692.0


Epoch 65: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33524987
poisson_loss 15883565.0


Epoch 66: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3356766
poisson_loss 15874933.0


Epoch 67: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3352277
poisson_loss 15881721.0


Epoch 68: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33570352
poisson_loss 15879596.0


Epoch 69: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33492386
poisson_loss 15890686.0


Epoch 70: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.33563107
poisson_loss 15877027.0


Epoch 71: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33572796
poisson_loss 15878197.0


Epoch 72: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


Epoch    72: reducing learning rate of group 0 to 8.1000e-04.
correlation 0.33495852
poisson_loss 15882894.0


Epoch 73: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.33727974
poisson_loss 15856123.0


Epoch 74: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33725935
poisson_loss 15860860.0


Epoch 75: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33746076
poisson_loss 15859466.0


Epoch 76: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3373697
poisson_loss 15856730.0


Epoch 77: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.33731776
poisson_loss 15857089.0


Epoch 78: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33705705
poisson_loss 15856511.0


Epoch 79: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


Epoch    79: reducing learning rate of group 0 to 2.4300e-04.
correlation 0.33701867
poisson_loss 15861851.0


Epoch 80: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.33768296
poisson_loss 15851664.0


Epoch 81: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.3377451
poisson_loss 15852438.0


Epoch 82: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.3378777
poisson_loss 15854055.0


Epoch 83: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


In [8]:
torch.save(model.state_dict(), '../model_tutorial/model_checkpoints/generalization_vone_256_model.pth')

### Load Checkpoints

In [9]:
model.load_state_dict(torch.load("../model_tutorial/model_checkpoints/generalization_vone_256_model.pth"));

In [10]:
model.eval();

### Submission File

In [11]:
submission.generate_submission_file(trained_model=model, 
                                    dataloaders=dataloaders,
                                    data_key=dataset_name,
                                    path="./submission_files/VOne256/",
                                    device="cuda")

Submission file saved for tier: live_test. Saved in: ./submission_files/VOne256/submission_file_live_test.csv
Submission file saved for tier: final_test. Saved in: ./submission_files/VOne256/submission_file_final_test.csv


In [12]:
del model 

# Model 2 - VOne SE CNN 256

In [13]:
model_se_vone_fn = 'sensorium.models.my_models.VOne_SE_core_full_gauss_readout'

model_se_vone_config = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': False,
}

model_se_vone = get_model(model_fn=model_se_vone_fn,
                  model_config=model_se_vone_config,
                  dataloaders=dataloaders,
                  seed=42)

Neuronal distributions gabor parameters
Model:  VOneNet


In [14]:
trainer_se_vone_fn = "sensorium.training.standard_trainer"

trainer_se_vone_config = {'max_iter': 300,
                 'verbose': True,
                  'track_training': True,
                 'lr_decay_steps': 4,
                 'avg_loss': False, 
                 'lr_init': 0.009,
                 }

trainer_se_vone = get_trainer(trainer_fn=trainer_se_vone_fn, 
                     trainer_config=trainer_se_vone_config)

### Model Training 

In [15]:
validation_score_se_vone, trainer_output_se_vone, state_dict_se_vone= trainer(model_se_vone, dataloaders, seed=42)

correlation -0.0028102435
poisson_loss 30768218.0


Epoch 1: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.0951371
poisson_loss 18595772.0


Epoch 2: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.12523207
poisson_loss 18273604.0


Epoch 3: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.16545868
poisson_loss 17875590.0


Epoch 4: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.18431029
poisson_loss 17643640.0


Epoch 5: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.20501798
poisson_loss 17423592.0


Epoch 6: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.2320032
poisson_loss 17100398.0


Epoch 7: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.2462325
poisson_loss 16936682.0


Epoch 8: 100%|██████████| 252/252 [01:52<00:00,  2.23it/s]


correlation 0.25787485
poisson_loss 16803580.0


Epoch 9: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.26697433
poisson_loss 16701562.0


Epoch 10: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.27709466
poisson_loss 16571046.0


Epoch 11: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.28193116
poisson_loss 16515180.0


Epoch 12: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.2869856
poisson_loss 16466742.0


Epoch 13: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.29134557
poisson_loss 16410084.0


Epoch 14: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.29506606
poisson_loss 16360467.0


Epoch 15: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.2986134
poisson_loss 16318095.0


Epoch 16: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.30351874
poisson_loss 16250094.0


Epoch 17: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.30471864
poisson_loss 16249858.0


Epoch 18: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.30577907
poisson_loss 16239219.0


Epoch 19: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.30824903
poisson_loss 16203084.0


Epoch 20: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.3124421
poisson_loss 16145865.0


Epoch 21: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.31291357
poisson_loss 16143104.0


Epoch 22: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.31412524
poisson_loss 16129080.0


Epoch 23: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3158773
poisson_loss 16111715.0


Epoch 24: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.31692082
poisson_loss 16095086.0


Epoch 25: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.31676725
poisson_loss 16103095.0


Epoch 26: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.31772205
poisson_loss 16086119.0


Epoch 27: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3193344
poisson_loss 16072790.0


Epoch 28: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3202154
poisson_loss 16058058.0


Epoch 29: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3203649
poisson_loss 16057933.0


Epoch 30: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.32139212
poisson_loss 16051613.0


Epoch 31: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.32073522
poisson_loss 16049119.0


Epoch 32: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32224125
poisson_loss 16030712.0


Epoch 33: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.32250893
poisson_loss 16033365.0


Epoch 34: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32163554
poisson_loss 16039430.0


Epoch 35: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.3239246
poisson_loss 16006536.0


Epoch 36: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.32301778
poisson_loss 16024837.0


Epoch 37: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3242318
poisson_loss 16011000.0


Epoch 38: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3239959
poisson_loss 16015448.0


Epoch 39: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.32415468
poisson_loss 16006940.0


Epoch 40: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3251419
poisson_loss 15998176.0


Epoch 41: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.32482696
poisson_loss 16006646.0


Epoch 42: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3254729
poisson_loss 15995165.0


Epoch 43: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.32495102
poisson_loss 16004636.0


Epoch 44: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.32605252
poisson_loss 15983002.0


Epoch 45: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.3250852
poisson_loss 16005844.0


Epoch 46: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32641944
poisson_loss 15980580.0


Epoch 47: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.32494414
poisson_loss 16002011.0


Epoch 48: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32656175
poisson_loss 15981710.0


Epoch 49: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3269574
poisson_loss 15975326.0


Epoch 50: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.32748026
poisson_loss 15969009.0


Epoch 51: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.32709423
poisson_loss 15976660.0


Epoch 52: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.32787842
poisson_loss 15962883.0


Epoch 53: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3267963
poisson_loss 15990771.0


Epoch 54: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3269961
poisson_loss 15977605.0


Epoch 55: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.32673138
poisson_loss 15984375.0


Epoch 56: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.32767943
poisson_loss 15967374.0


Epoch 57: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.327989
poisson_loss 15961628.0


Epoch 58: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


Epoch    58: reducing learning rate of group 0 to 2.7000e-03.
correlation 0.32775673
poisson_loss 15962915.0


Epoch 59: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3344876
poisson_loss 15886452.0


Epoch 60: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.33513543
poisson_loss 15883176.0


Epoch 61: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.33545315
poisson_loss 15876017.0


Epoch 62: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33521894
poisson_loss 15881464.0


Epoch 63: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.3349026
poisson_loss 15878594.0


Epoch 64: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.33486256
poisson_loss 15879291.0


Epoch 65: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.33531457
poisson_loss 15882805.0


Epoch 66: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33575752
poisson_loss 15873613.0


Epoch 67: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3352784
poisson_loss 15881134.0


Epoch 68: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.335698
poisson_loss 15879374.0


Epoch 69: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3350131
poisson_loss 15888704.0


Epoch 70: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3357192
poisson_loss 15875375.0


Epoch 71: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33581546
poisson_loss 15876562.0


Epoch 72: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


Epoch    72: reducing learning rate of group 0 to 8.1000e-04.
correlation 0.33503148
poisson_loss 15881602.0


Epoch 73: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3373322
poisson_loss 15854926.0


Epoch 74: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.33732587
poisson_loss 15859550.0


Epoch 75: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33752754
poisson_loss 15858068.0


Epoch 76: 100%|██████████| 252/252 [01:51<00:00,  2.26it/s]


correlation 0.33747217
poisson_loss 15855225.0


Epoch 77: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33739132
poisson_loss 15855756.0


Epoch 78: 100%|██████████| 252/252 [01:52<00:00,  2.24it/s]


correlation 0.33711803
poisson_loss 15855454.0


Epoch 79: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


Epoch    79: reducing learning rate of group 0 to 2.4300e-04.
correlation 0.33707926
poisson_loss 15860184.0


Epoch 80: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.33776987
poisson_loss 15850498.0


Epoch 81: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


correlation 0.33779463
poisson_loss 15851759.0


Epoch 82: 100%|██████████| 252/252 [01:52<00:00,  2.25it/s]


correlation 0.3379483
poisson_loss 15852816.0


Epoch 83: 100%|██████████| 252/252 [01:51<00:00,  2.25it/s]


In [17]:
torch.save(model_se_vone.state_dict(), '../model_tutorial/model_checkpoints/generalization_VOne_256_se_model.pth')

### Load Model 

In [18]:
model_se_vone.load_state_dict(torch.load("../model_tutorial/model_checkpoints/generalization_VOne_256_se_model.pth"));

In [19]:
model_se_vone.eval();

### Submission File

In [20]:
submission.generate_submission_file(trained_model=model_se_vone, 
                                    dataloaders=dataloaders,
                                    data_key=dataset_name,
                                    path="./submission_files/VOne256SE/",
                                    device="cuda")

Submission file saved for tier: live_test. Saved in: ./submission_files/VOne256SE/submission_file_live_test.csv
Submission file saved for tier: final_test. Saved in: ./submission_files/VOne256SE/submission_file_final_test.csv


In [21]:
del model_se_vone

# Model 3 - Scattering CNN

In [22]:
dataset_config_scat = {'paths': filenames,
                 'normalize': True,
                 'scattering': True, 
                 'include_behavior': False,
                 'include_eye_position': False,
                 'batch_size': 128,
                 'scale':1.,
                 }

dataloaders_scat = get_data(dataset_fn, dataset_config_scat)

In [23]:
model_fn_scat = 'sensorium.models.stacked_core_full_gauss_readout'

model_config_scat = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': False,
}

model_scat = get_model(model_fn=model_fn_scat,
                  model_config=model_config_scat,
                  dataloaders=dataloaders_scat,
                  seed=42)

In [24]:
trainer_fn_scat = "sensorium.training.standard_trainer"

trainer_config_scat = {'max_iter': 300,
                 'verbose': True,
                  'track_training': True,
                 'lr_decay_steps': 4,
                 'avg_loss': False, 
                 'lr_init': 0.009,
                 }

trainer_scat = get_trainer(trainer_fn=trainer_fn_scat, 
                     trainer_config=trainer_config_scat)

### Model training

In [25]:
validation_score_scat, trainer_output_scat, state_dict_scat = trainer(model_scat, dataloaders_scat, seed=42)

correlation 0.00085983006
poisson_loss 30813174.0


Epoch 1: 100%|██████████| 252/252 [06:17<00:00,  1.50s/it]


correlation 0.09728123
poisson_loss 18598136.0


Epoch 2: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.1297871
poisson_loss 18219434.0


Epoch 3: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.16658545
poisson_loss 17847296.0


Epoch 4: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.20383841
poisson_loss 17431458.0


Epoch 5: 100%|██████████| 252/252 [05:51<00:00,  1.39s/it]


correlation 0.23381865
poisson_loss 17091154.0


Epoch 6: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.25540918
poisson_loss 16831706.0


Epoch 7: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.2675754
poisson_loss 16693753.0


Epoch 8: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.27625263
poisson_loss 16594078.0


Epoch 9: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.28308916
poisson_loss 16519270.0


Epoch 10: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.28882203
poisson_loss 16446216.0


Epoch 11: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.2938879
poisson_loss 16388222.0


Epoch 12: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.2981759
poisson_loss 16339995.0


Epoch 13: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.303831
poisson_loss 16258810.0


Epoch 14: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.3033518
poisson_loss 16271953.0


Epoch 15: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.3068031
poisson_loss 16221714.0


Epoch 16: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.31059343
poisson_loss 16184462.0


Epoch 17: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.3117701
poisson_loss 16169471.0


Epoch 18: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.31166273
poisson_loss 16167576.0


Epoch 19: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.31512412
poisson_loss 16119231.0


Epoch 20: 100%|██████████| 252/252 [05:54<00:00,  1.40s/it]


correlation 0.314617
poisson_loss 16132927.0


Epoch 21: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.3175892
poisson_loss 16086360.0


Epoch 22: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.31752238
poisson_loss 16095530.0


Epoch 23: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.31881332
poisson_loss 16081442.0


Epoch 24: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.31957665
poisson_loss 16068656.0


Epoch 25: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.3191006
poisson_loss 16077901.0


Epoch 26: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.32077387
poisson_loss 16053385.0


Epoch 27: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.32118005
poisson_loss 16049121.0


Epoch 28: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.3222295
poisson_loss 16043546.0


Epoch 29: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.32210886
poisson_loss 16039737.0


Epoch 30: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.32260114
poisson_loss 16035691.0


Epoch 31: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.3220683
poisson_loss 16041454.0


Epoch 32: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.32358095
poisson_loss 16023062.0


Epoch 33: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.32349333
poisson_loss 16028983.0


Epoch 34: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.3239486
poisson_loss 16014286.0


Epoch 35: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.3251492
poisson_loss 16008658.0


Epoch 36: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.32405576
poisson_loss 16024255.0


Epoch 37: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.32485607
poisson_loss 16007946.0


Epoch 38: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.3245545
poisson_loss 16017459.0


Epoch 39: 100%|██████████| 252/252 [05:51<00:00,  1.39s/it]


correlation 0.32491636
poisson_loss 16004905.0


Epoch 40: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.3260271
poisson_loss 15990187.0


Epoch 41: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.32550177
poisson_loss 16001273.0


Epoch 42: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.32573402
poisson_loss 15997666.0


Epoch 43: 100%|██████████| 252/252 [05:51<00:00,  1.39s/it]


correlation 0.32561338
poisson_loss 15998974.0


Epoch 44: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.32672736
poisson_loss 15980279.0


Epoch 45: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.3272527
poisson_loss 15979179.0


Epoch 46: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.32708997
poisson_loss 15976057.0


Epoch 47: 100%|██████████| 252/252 [05:48<00:00,  1.38s/it]


correlation 0.32625702
poisson_loss 15990081.0


Epoch 48: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.32655644
poisson_loss 15983433.0


Epoch 49: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.32747003
poisson_loss 15976660.0


Epoch 50: 100%|██████████| 252/252 [05:51<00:00,  1.39s/it]


correlation 0.32818982
poisson_loss 15963950.0


Epoch 51: 100%|██████████| 252/252 [05:47<00:00,  1.38s/it]


correlation 0.3279931
poisson_loss 15966828.0


Epoch 52: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.32795203
poisson_loss 15969291.0


Epoch 53: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.32677796
poisson_loss 15994247.0


Epoch 54: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.32719648
poisson_loss 15980648.0


Epoch 55: 100%|██████████| 252/252 [05:48<00:00,  1.38s/it]


correlation 0.32703936
poisson_loss 15987482.0


Epoch 56: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.3279681
poisson_loss 15967889.0


Epoch 57: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


Epoch    57: reducing learning rate of group 0 to 2.7000e-03.
correlation 0.32680494
poisson_loss 15989935.0


Epoch 58: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.33480477
poisson_loss 15885134.0


Epoch 59: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.33557093
poisson_loss 15879250.0


Epoch 60: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.33534425
poisson_loss 15889979.0


Epoch 61: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.33569986
poisson_loss 15882371.0


Epoch 62: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.33570316
poisson_loss 15887035.0


Epoch 63: 100%|██████████| 252/252 [05:51<00:00,  1.39s/it]


correlation 0.33533666
poisson_loss 15887216.0


Epoch 64: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.33507496
poisson_loss 15888899.0


Epoch 65: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.33533308
poisson_loss 15886776.0


Epoch 66: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.3357461
poisson_loss 15882754.0


Epoch 67: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.33566692
poisson_loss 15884101.0


Epoch 68: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.33548194
poisson_loss 15892031.0


Epoch 69: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.33499855
poisson_loss 15896324.0


Epoch 70: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.33557913
poisson_loss 15885068.0


Epoch 71: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.3359958
poisson_loss 15879960.0


Epoch 72: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.33464846
poisson_loss 15902784.0


Epoch 73: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.3355489
poisson_loss 15881860.0


Epoch 74: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.3353317
poisson_loss 15893047.0


Epoch 75: 100%|██████████| 252/252 [05:48<00:00,  1.38s/it]


correlation 0.3348914
poisson_loss 15901380.0


Epoch 76: 100%|██████████| 252/252 [05:48<00:00,  1.38s/it]


correlation 0.33604282
poisson_loss 15879640.0


Epoch 77: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


Epoch    77: reducing learning rate of group 0 to 8.1000e-04.
correlation 0.33501834
poisson_loss 15893337.0


Epoch 78: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.33711293
poisson_loss 15867947.0


Epoch 79: 100%|██████████| 252/252 [05:48<00:00,  1.38s/it]


correlation 0.3374471
poisson_loss 15866382.0


Epoch 80: 100%|██████████| 252/252 [05:48<00:00,  1.38s/it]


correlation 0.33758056
poisson_loss 15866889.0


Epoch 81: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.33737594
poisson_loss 15868422.0


Epoch 82: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.3374794
poisson_loss 15865786.0


Epoch 83: 100%|██████████| 252/252 [05:51<00:00,  1.39s/it]


correlation 0.33761305
poisson_loss 15866137.0


Epoch 84: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.3374613
poisson_loss 15865731.0


Epoch 85: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.3379009
poisson_loss 15858704.0


Epoch 86: 100%|██████████| 252/252 [05:51<00:00,  1.39s/it]


correlation 0.33768508
poisson_loss 15864204.0


Epoch 87: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.33774808
poisson_loss 15861747.0


Epoch 88: 100%|██████████| 252/252 [05:51<00:00,  1.39s/it]


correlation 0.33753893
poisson_loss 15864352.0


Epoch 89: 100%|██████████| 252/252 [05:50<00:00,  1.39s/it]


correlation 0.33765787
poisson_loss 15865966.0


Epoch 90: 100%|██████████| 252/252 [05:49<00:00,  1.39s/it]


correlation 0.337493
poisson_loss 15868681.0


Epoch 91: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


In [26]:
torch.save(model_scat.state_dict(), '../model_tutorial/model_checkpoints/generalization_scattering_model.pth')

### Load Checkpoints

In [27]:
model_scat.load_state_dict(torch.load("../model_tutorial/model_checkpoints/generalization_scattering_model.pth"));

In [28]:
model_scat.eval();

### Submission File

In [29]:
submission.generate_submission_file(trained_model=model_scat, 
                                    dataloaders=dataloaders_scat,
                                    data_key=dataset_name,
                                    path="./submission_files/Scat/",
                                    device="cuda")

Submission file saved for tier: live_test. Saved in: ./submission_files/Scat/submission_file_live_test.csv
Submission file saved for tier: final_test. Saved in: ./submission_files/Scat/submission_file_final_test.csv


In [30]:
del model_scat

# Model 4 - Scattering SE CNN

In [31]:
dataset_config_scat_se = {'paths': filenames,
                 'normalize': True,
                 'scattering': True, 
                 'include_behavior': False,
                 'include_eye_position': False,
                 'batch_size': 128,
                 'scale':1.,
                 }

dataloaders_scat_se = get_data(dataset_fn, dataset_config_scat_se)

In [32]:
model_fn_scat_se = 'sensorium.models.my_models.se_core_full_gauss_readout'

model_config_scat_se = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': False,
}

model_scat_se = get_model(model_fn=model_fn_scat_se,
                  model_config=model_config_scat_se,
                  dataloaders=dataloaders_scat_se,
                  seed=42)

In [33]:
trainer_fn_scat_se = "sensorium.training.standard_trainer"

trainer_config_scat_se = {'max_iter': 300,
                 'verbose': True,
                  'track_training': True,
                 'lr_decay_steps': 4,
                 'avg_loss': False, # True,
                 'lr_init': 0.009,
                 }

trainer_scat_se = get_trainer(trainer_fn=trainer_fn_scat_se, 
                     trainer_config=trainer_config_scat_se)

### Model training

In [34]:
validation_score_scat_se, trainer_output_scat_se, state_dict_scat_se = trainer(model_scat_se, dataloaders_scat_se, seed=42)

correlation 0.00085983006
poisson_loss 30813174.0


Epoch 1: 100%|██████████| 252/252 [06:28<00:00,  1.54s/it]


correlation 0.097286575
poisson_loss 18598072.0


Epoch 2: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.12977888
poisson_loss 18219474.0


Epoch 3: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.16657974
poisson_loss 17847466.0


Epoch 4: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.20382412
poisson_loss 17431868.0


Epoch 5: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.23380183
poisson_loss 17091434.0


Epoch 6: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.2553136
poisson_loss 16832504.0


Epoch 7: 100%|██████████| 252/252 [05:56<00:00,  1.42s/it]


correlation 0.26750746
poisson_loss 16694542.0


Epoch 8: 100%|██████████| 252/252 [05:56<00:00,  1.42s/it]


correlation 0.27629665
poisson_loss 16592072.0


Epoch 9: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.28321865
poisson_loss 16518741.0


Epoch 10: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.28894228
poisson_loss 16444903.0


Epoch 11: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.29403597
poisson_loss 16386545.0


Epoch 12: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.29823405
poisson_loss 16338009.0


Epoch 13: 100%|██████████| 252/252 [05:56<00:00,  1.42s/it]


correlation 0.3039807
poisson_loss 16257331.0


Epoch 14: 100%|██████████| 252/252 [05:56<00:00,  1.42s/it]


correlation 0.3033345
poisson_loss 16272589.0


Epoch 15: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.30672663
poisson_loss 16222448.0


Epoch 16: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.3106179
poisson_loss 16183547.0


Epoch 17: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.31173757
poisson_loss 16169034.0


Epoch 18: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.31163013
poisson_loss 16167321.0


Epoch 19: 100%|██████████| 252/252 [05:56<00:00,  1.41s/it]


correlation 0.3151165
poisson_loss 16119925.0


Epoch 20: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.3146559
poisson_loss 16132997.0


Epoch 21: 100%|██████████| 252/252 [05:57<00:00,  1.42s/it]


correlation 0.3175699
poisson_loss 16086385.0


Epoch 22: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.31757832
poisson_loss 16094947.0


Epoch 23: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.31882408
poisson_loss 16080866.0


Epoch 24: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.3195542
poisson_loss 16068368.0


Epoch 25: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.31913608
poisson_loss 16077185.0


Epoch 26: 100%|██████████| 252/252 [05:56<00:00,  1.41s/it]


correlation 0.32075688
poisson_loss 16053574.0


Epoch 27: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.3211969
poisson_loss 16048994.0


Epoch 28: 100%|██████████| 252/252 [05:56<00:00,  1.41s/it]


correlation 0.32222214
poisson_loss 16043806.0


Epoch 29: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.3221119
poisson_loss 16039454.0


Epoch 30: 100%|██████████| 252/252 [05:56<00:00,  1.41s/it]


correlation 0.32256988
poisson_loss 16036718.0


Epoch 31: 100%|██████████| 252/252 [05:56<00:00,  1.41s/it]


correlation 0.32200754
poisson_loss 16041189.0


Epoch 32: 100%|██████████| 252/252 [05:56<00:00,  1.42s/it]


correlation 0.32356966
poisson_loss 16023713.0


Epoch 33: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.32351542
poisson_loss 16028252.0


Epoch 34: 100%|██████████| 252/252 [05:56<00:00,  1.41s/it]


correlation 0.3239274
poisson_loss 16014366.0


Epoch 35: 100%|██████████| 252/252 [05:56<00:00,  1.42s/it]


correlation 0.3250689
poisson_loss 16009318.0


Epoch 36: 100%|██████████| 252/252 [05:57<00:00,  1.42s/it]


correlation 0.32408503
poisson_loss 16024192.0


Epoch 37: 100%|██████████| 252/252 [05:56<00:00,  1.42s/it]


correlation 0.32482836
poisson_loss 16008660.0


Epoch 38: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.32467407
poisson_loss 16015881.0


Epoch 39: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.32493657
poisson_loss 16004441.0


Epoch 40: 100%|██████████| 252/252 [05:58<00:00,  1.42s/it]


correlation 0.3260387
poisson_loss 15989672.0


Epoch 41: 100%|██████████| 252/252 [05:57<00:00,  1.42s/it]


correlation 0.32549632
poisson_loss 16000856.0


Epoch 42: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.32573292
poisson_loss 15997647.0


Epoch 43: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.32557648
poisson_loss 15999494.0


Epoch 44: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.32673493
poisson_loss 15980187.0


Epoch 45: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.3272479
poisson_loss 15979106.0


Epoch 46: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.3270534
poisson_loss 15976383.0


Epoch 47: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.32623434
poisson_loss 15990376.0


Epoch 48: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.3265371
poisson_loss 15983715.0


Epoch 49: 100%|██████████| 252/252 [05:57<00:00,  1.42s/it]


correlation 0.32745764
poisson_loss 15976761.0


Epoch 50: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.32815558
poisson_loss 15964143.0


Epoch 51: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.32798266
poisson_loss 15966494.0


Epoch 52: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.32794032
poisson_loss 15969496.0


Epoch 53: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.32678807
poisson_loss 15994118.0


Epoch 54: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.32719263
poisson_loss 15981308.0


Epoch 55: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.3270488
poisson_loss 15987544.0


Epoch 56: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.32794464
poisson_loss 15968053.0


Epoch 57: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


Epoch    57: reducing learning rate of group 0 to 2.7000e-03.
correlation 0.32680902
poisson_loss 15990380.0


Epoch 58: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.33481172
poisson_loss 15885173.0


Epoch 59: 100%|██████████| 252/252 [05:56<00:00,  1.42s/it]


correlation 0.33555377
poisson_loss 15879314.0


Epoch 60: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.33532402
poisson_loss 15890117.0


Epoch 61: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.3356719
poisson_loss 15882520.0


Epoch 62: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.33568838
poisson_loss 15887148.0


Epoch 63: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.33531332
poisson_loss 15887493.0


Epoch 64: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.3350486
poisson_loss 15889130.0


Epoch 65: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.33532465
poisson_loss 15886816.0


Epoch 66: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.3357396
poisson_loss 15882778.0


Epoch 67: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.3356522
poisson_loss 15884178.0


Epoch 68: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.33547828
poisson_loss 15892118.0


Epoch 69: 100%|██████████| 252/252 [05:51<00:00,  1.40s/it]


correlation 0.33498493
poisson_loss 15896264.0


Epoch 70: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.33556294
poisson_loss 15885095.0


Epoch 71: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.33598867
poisson_loss 15880067.0


Epoch 72: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.33465827
poisson_loss 15902642.0


Epoch 73: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.33554336
poisson_loss 15881978.0


Epoch 74: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.33532113
poisson_loss 15893142.0


Epoch 75: 100%|██████████| 252/252 [05:56<00:00,  1.41s/it]


correlation 0.33487743
poisson_loss 15901497.0


Epoch 76: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.3360291
poisson_loss 15879690.0


Epoch 77: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


Epoch    77: reducing learning rate of group 0 to 8.1000e-04.
correlation 0.33499277
poisson_loss 15893588.0


Epoch 78: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.33710456
poisson_loss 15868024.0


Epoch 79: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.3374441
poisson_loss 15866470.0


Epoch 80: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.3375779
poisson_loss 15866993.0


Epoch 81: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.337363
poisson_loss 15868542.0


Epoch 82: 100%|██████████| 252/252 [05:54<00:00,  1.40s/it]


correlation 0.3374622
poisson_loss 15865881.0


Epoch 83: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.3376005
poisson_loss 15866276.0


Epoch 84: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


correlation 0.33745152
poisson_loss 15865886.0


Epoch 85: 100%|██████████| 252/252 [05:52<00:00,  1.40s/it]


correlation 0.33788666
poisson_loss 15858890.0


Epoch 86: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.33767277
poisson_loss 15864285.0


Epoch 87: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.33773184
poisson_loss 15861880.0


Epoch 88: 100%|██████████| 252/252 [05:53<00:00,  1.40s/it]


correlation 0.33752874
poisson_loss 15864459.0


Epoch 89: 100%|██████████| 252/252 [05:56<00:00,  1.41s/it]


correlation 0.3376406
poisson_loss 15866151.0


Epoch 90: 100%|██████████| 252/252 [05:54<00:00,  1.41s/it]


correlation 0.33748215
poisson_loss 15868708.0


Epoch 91: 100%|██████████| 252/252 [05:55<00:00,  1.41s/it]


In [35]:
torch.save(model_scat_se.state_dict(), '../model_tutorial/model_checkpoints/generalization_se_scattering_model.pth')

### Load Checkpoints

In [36]:
model_scat_se.load_state_dict(torch.load("../model_tutorial/model_checkpoints/generalization_se_scattering_model.pth"));

In [37]:
model_scat_se.eval();

### Submission File

In [38]:
submission.generate_submission_file(trained_model=model_scat_se, 
                                    dataloaders=dataloaders_scat_se,
                                    data_key=dataset_name,
                                    path="./submission_files/ScatSE/",
                                    device="cuda")

Submission file saved for tier: live_test. Saved in: ./submission_files/ScatSE/submission_file_live_test.csv
Submission file saved for tier: final_test. Saved in: ./submission_files/ScatSE/submission_file_final_test.csv


In [39]:
del model_scat_se

# Let's Ensemble!

In [40]:
import pandas as pd
from tqdm import tqdm

vone = pd.read_csv('./submission_files/VOne256/submission_file_live_test.csv')
scat = pd.read_csv('./submission_files/Scat/submission_file_live_test.csv')
se_scat = pd.read_csv('./submission_files/ScatSE/submission_file_live_test.csv')
vone_se = pd.read_csv('./submission_files/VOne256SE/submission_file_live_test.csv')

vone_final = pd.read_csv('./submission_files/VOne256/submission_file_final_test.csv')
scat_final = pd.read_csv('./submission_files/Scat/submission_file_final_test.csv')
se_scat_final = pd.read_csv('./submission_files/ScatSE/submission_file_final_test.csv')
vone_se_final = pd.read_csv('./submission_files/VOne256SE/submission_file_final_test.csv')


In [41]:
for i in tqdm(range(0, vone.shape[0])):
    vone["prediction"][i] = [float(x) for x in vone.prediction[i][1:-1].split(', ')] 
    scat["prediction"][i] = [float(x) for x in scat.prediction[i][1:-1].split(', ')] 
    se_scat["prediction"][i] = [float(x) for x in se_scat.prediction[i][1:-1].split(', ')] 
    vone_se["prediction"][i] = [float(x) for x in vone_se.prediction[i][1:-1].split(', ')] 

100%|██████████| 990/990 [00:08<00:00, 111.84it/s]


In [42]:
for i in tqdm(range(0, vone_final.shape[0])):
    vone_final["prediction"][i] = [float(x) for x in vone_final.prediction[i][1:-1].split(', ')] 
    scat_final["prediction"][i] = [float(x) for x in scat_final.prediction[i][1:-1].split(', ')] 
    se_scat_final["prediction"][i] = [float(x) for x in se_scat_final.prediction[i][1:-1].split(', ')] 
    vone_se_final["prediction"][i] = [float(x) for x in vone_se_final.prediction[i][1:-1].split(', ')] 
    
    

100%|██████████| 995/995 [00:08<00:00, 110.76it/s]


In [43]:
for i in tqdm(range(0, vone.shape[0])):
    vone["prediction"][i] = str(list((np.array(vone.prediction[i]) + 
                                        np.array(scat.prediction[i]) + 
                                        np.array(se_scat.prediction[i]) +
                                        np.array(vone_se.prediction[i])) / 4))
    
    

100%|██████████| 990/990 [00:04<00:00, 226.72it/s]


In [44]:
for i in tqdm(range(0, vone_final.shape[0])):
    vone_final["prediction"][i] = str(list((np.array(vone_final.prediction[i]) 
                                              + np.array(scat_final.prediction[i])
                                             + np.array(se_scat_final.prediction[i])
                                             + np.array(vone_se_final.prediction[i]))/4))
    

100%|██████████| 995/995 [00:04<00:00, 223.36it/s]


### Submission Files

In [48]:
vone.to_csv('./submission_files/ensembling_live_test_256_scat_se_256_scat.csv', index = False)

In [49]:
vone_final.to_csv('./submission_files/ensembling_final_test_256_scat_se_256_scat.csv', index = False)