# Model Training with nnfabrik

In [1]:
%load_ext autoreload
%autoreload 2

import datajoint as dj
dj.config["enable_python_native_blobs"] = True

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import sensorium

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from nnfabrik.builder import get_data, get_model, get_trainer


dj.config['nnfabrik.schema_name'] = "nnfabrik_dj_tutorial"

from sensorium.schemas.nnfabrik import Fabrikant, Model, Dataset, Trainer, Seed, TrainedModel, main_nnfabrik

Connecting kwilleke@134.76.19.44:3306


In [2]:
# this is the datajoint schema
schema = main_nnfabrik.schema
schema

Schema `nnfabrik_dj_tutorial`

In [3]:
Fabrikant()

fabrikant_name  Name of the contributor that added this entry,full_name  full name of the person,email  e-mail address,affiliation  conributor's affiliation (e.g. Sinz Lab),dj_username  DataJoint username
,,,,


# Initial Addition of Tables for the Schema

In [4]:
Fabrikant().insert1(dict(fabrikant_name='kwilleke',
                         email="konstantin.willeke@gmail.com",
                         affiliation='sinzlab',
                         dj_username="kwilleke"))

In [7]:
Fabrikant()

fabrikant_name  Name of the contributor that added this entry,full_name  full name of the person,email  e-mail address,affiliation  conributor's affiliation (e.g. Sinz Lab),dj_username  DataJoint username
kwilleke,,konstantin.willeke@gmail.com,sinzlab,kwilleke


In [5]:
Seed()

seed  Random seed that is passed to the model- and dataset-builder


In [6]:
Seed().insert1(dict(seed=1000))
Seed()

seed  Random seed that is passed to the model- and dataset-builder
1000


# Add Dataset

In [8]:
# the cascade mouse dataloader takes the files from the Preprocessed mouse table and stores them on the server under /data/mouse/toliaslab/static/. like the other datasets
filenames = ['../data/static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', ]

dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = {'paths': filenames,
                 'normalize': True,
                 'include_behavior': True,
                 'include_eye_position': True,
                 'batch_size': 128,
                 'exclude': None,
                 'file_tree': True,
                 'scale': .25,
                 'add_behavior_as_channels': False,
                 }

dataloaders = get_data(dataset_fn, dataset_config)

In [9]:
Dataset().add_entry(dataset_fn, dataset_config, dataset_comment="27204-5-13, 36x64px, w/o behavior")

{'dataset_fn': 'sensorium.datasets.static_loaders',
 'dataset_hash': '4062a9e2dce9da50014ed2d1e209e925'}

In [10]:
b = next(iter(dataloaders["train"]["27204-5-13"]))

In [11]:
b.images.shape

torch.Size([128, 1, 36, 64])

# Add Model

In [12]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': False,
}

model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42,)



In [13]:
Model().add_entry(model_fn, model_config, model_comment="default model")

{'model_fn': 'sensorium.models.stacked_core_full_gauss_readout',
 'model_hash': 'c498ac31e1eba12fa5cce4c6079cdd65'}

# Add Trainer

In [14]:
# let's add a quick trainer that just runs for 10 epochs
trainer_fn = "sensorium.training.standard_trainer"

trainer_config = {'max_iter': 10,
                 'verbose': False,
                 'lr_decay_steps': 4,
                 'avg_loss': False,
                 'lr_init': 0.009,
                 }

trainer = get_trainer(trainer_fn=trainer_fn, 
                     trainer_config=trainer_config)

In [17]:
Trainer().add_entry(trainer_fn, trainer_config, trainer_comment="quick trainer, 10 epochs")

{'trainer_fn': 'sensorium.training.standard_trainer',
 'trainer_hash': '583d1d028833c8198424973f1b53e872'}

# TrainedModel Table

In [19]:
# this table will run the training for us
TrainedModel()

model_fn  name of the model function,model_hash  hash of the model configuration,dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,trainer_fn  name of the Trainer loader function,trainer_hash  hash of the configuration object,seed  Random seed that is passed to the model- and dataset-builder,comment  short description,score  loss,output  trainer object's output,fabrikant_name  Name of the contributor that added this entry,trainedmodel_ts  UTZ timestamp at time of insertion
,,,,,,,,,,,


# Train Model

this table depends on the 4 tables that we've created earlier

In [20]:
Model()

model_fn  name of the model function,model_hash  hash of the model configuration,model_config  model configuration to be passed into the function,model_fabrikant  Name of the contributor that added this entry,model_comment  short description,model_ts  UTZ timestamp at time of insertion
sensorium.models.stacked_core_full_gauss_readout,c498ac31e1eba12fa5cce4c6079cdd65,=BLOB=,kwilleke,default model,2023-12-21 19:11:40


In [21]:
Dataset()

dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,dataset_config  dataset configuration object,dataset_fabrikant  Name of the contributor that added this entry,dataset_comment  short description,dataset_ts  UTZ timestamp at time of insertion
sensorium.datasets.static_loaders,4062a9e2dce9da50014ed2d1e209e925,=BLOB=,kwilleke,"27204-5-13, 36x64px, w/o behavior",2023-12-21 19:11:34


In [22]:
Trainer()

trainer_fn  name of the Trainer loader function,trainer_hash  hash of the configuration object,trainer_config  training configuration object,trainer_fabrikant  Name of the contributor that added this entry,trainer_comment  short description,trainer_ts  UTZ timestamp at time of insertion
sensorium.training.standard_trainer,583d1d028833c8198424973f1b53e872,=BLOB=,kwilleke,"quick trainer, 10 epochs",2023-12-21 19:12:15


In [23]:
Seed()

seed  Random seed that is passed to the model- and dataset-builder
1000


## run training

In [25]:
## if we populate this table, it will run all combinations of seeds/models/datasets/trainers
TrainedModel().populate(display_progress=True)


Epoch 1:   0%|          | 0/35 [00:00<?, ?it/s][A
Epoch 1:   3%|▎         | 1/35 [00:00<00:15,  2.22it/s][A
Epoch 1:   6%|▌         | 2/35 [00:00<00:14,  2.24it/s][A
Epoch 1:   9%|▊         | 3/35 [00:01<00:14,  2.21it/s][A
Epoch 1:  11%|█▏        | 4/35 [00:01<00:14,  2.18it/s][A
Epoch 1:  14%|█▍        | 5/35 [00:02<00:13,  2.17it/s][A
Epoch 1:  17%|█▋        | 6/35 [00:02<00:13,  2.15it/s][A
Epoch 1:  20%|██        | 7/35 [00:03<00:12,  2.16it/s][A
Epoch 1:  23%|██▎       | 8/35 [00:03<00:12,  2.15it/s][A
Epoch 1:  26%|██▌       | 9/35 [00:04<00:12,  2.15it/s][A
Epoch 1:  29%|██▊       | 10/35 [00:04<00:11,  2.16it/s][A
Epoch 1:  31%|███▏      | 11/35 [00:05<00:11,  2.17it/s][A
Epoch 1:  34%|███▍      | 12/35 [00:05<00:10,  2.16it/s][A
Epoch 1:  37%|███▋      | 13/35 [00:06<00:10,  2.17it/s][A
Epoch 1:  40%|████      | 14/35 [00:06<00:09,  2.16it/s][A
Epoch 1:  43%|████▎     | 15/35 [00:06<00:09,  2.16it/s][A
Epoch 1:  46%|████▌     | 16/35 [00:07<00:08,  2.17it/s]

# Look at Result

In [26]:
TrainedModel()

model_fn  name of the model function,model_hash  hash of the model configuration,dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,trainer_fn  name of the Trainer loader function,trainer_hash  hash of the configuration object,seed  Random seed that is passed to the model- and dataset-builder,comment  short description,score  loss,output  trainer object's output,fabrikant_name  Name of the contributor that added this entry,trainedmodel_ts  UTZ timestamp at time of insertion
sensorium.models.stacked_core_full_gauss_readout,c498ac31e1eba12fa5cce4c6079cdd65,sensorium.datasets.static_loaders,4062a9e2dce9da50014ed2d1e209e925,sensorium.training.standard_trainer,583d1d028833c8198424973f1b53e872,1000,"quick trainer, 10 epochs.default model.27204-5-13, 36x64px, w/o behavior",0.23279,=BLOB=,kwilleke,2023-12-21 19:17:26


## Validation Correlation ("Score") = 0.23, not bad for 10 epochs

---

# Load trained model

In [27]:
dataloaders, model = (TrainedModel()).load_model()



In [29]:
# now this is our trained model
model

FiringRateEncoder(
  (core): Stacked2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): AdaptiveELU()
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (spatial_conv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64, bias=False)
          (out_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): AdaptiveELU()
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): 

In [30]:
model.eval().cuda();

# Now the model is ready to be used

In [36]:
new_images = torch.randn(100,1,36,64).cuda()

In [37]:
predicted_responses = model(new_images)

In [38]:
predicted_responses

tensor([[0.3794, 0.4740, 0.3717,  ..., 0.3467, 0.1474, 0.3769],
        [0.4062, 0.1574, 0.3993,  ..., 0.2929, 0.1165, 0.3619],
        [0.2652, 0.2746, 0.2872,  ..., 0.2461, 0.1658, 0.2314],
        ...,
        [0.5462, 0.1589, 0.4439,  ..., 0.3607, 0.1473, 0.4137],
        [0.6244, 0.3021, 0.4183,  ..., 0.3002, 0.1902, 0.3277],
        [0.4220, 0.1997, 0.3004,  ..., 0.3353, 0.1493, 0.4713]],
       device='cuda:0', grad_fn=<AddBackward0>)

👍