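# Model Training with nnfabrik

Registers a dataset, model, and trainer configuration in the nnfabrik schema, then trains the resulting model via `TrainedModel` and loads it back.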

In [2]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import nnfabrik
from nnfabrik.builder import get_data, get_model, get_trainer
import cascade


import datajoint as dj
dj.config["database.host"] = '18.198.155.36'

dj.config["enable_python_native_blobs"] = True
dj.config['nnfabrik.schema_name'] = "nnfabrik_neural_prediction_challenge"

schema = dj.schema("nnfabrik_neural_prediction_challenge")

from nnfabrik.main import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Initial Entries for the Fabrikant and Seed Tables

In [10]:
# Register the owner of the entries created below. skip_duplicates=True keeps
# the cell re-runnable (Restart & Run All) once this row already exists.
Fabrikant().insert1(dict(fabrikant_name='kwilleke',
                         email="konstantin.willeke@gmail.com",
                         affiliation='sinzlab',
                         dj_username="kwilleke"),
                    skip_duplicates=True)

In [11]:
# Random seed passed to the model- and dataset-builder. skip_duplicates=True
# lets the cell be re-run after the seed is already stored.
Seed().insert([{'seed': 1000}], skip_duplicates=True)
Seed()

seed  Random seed that is passed to the model- and dataset-builder
1000


# Add Dataset

In [4]:
# The cascade mouse dataloader takes these files from the Preprocessed mouse
# table and stores them on the server under /data/mouse/toliaslab/static/,
# like the other datasets. Archives that are already unpacked there are reused
# (see the output below).
filenames = [
    'static26645-2-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip',
    'static26644-14-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip',
]

dataset_fn = 'cascade.datasets.static_loaders'
dataset_config = dict(
    paths=filenames,
    normalize=True,
    include_behavior=False,
    include_eye_position=True,
    batch_size=128,
    exclude=None,
    file_tree=True,
    scale=0.5,  # image scale factor; gives 72x128 px images here (see b.images.shape below)
)

dataloaders = get_data(dataset_fn, dataset_config)

/data/mouse/toliaslab/static/static26645-2-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6 exists already. Not unpacking /data/mouse/toliaslab/static/static26645-2-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip
/data/mouse/toliaslab/static/static26644-14-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6 exists already. Not unpacking /data/mouse/toliaslab/static/static26644-14-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip


In [5]:
# Register the dataset config; skip_duplicates=True keeps the cell re-runnable
# instead of failing on the already-existing entry.
Dataset().add_entry(dataset_fn, dataset_config,
                    dataset_comment="Both Test Scans, 72x128px, w/o behavior",
                    skip_duplicates=True)

{'dataset_fn': 'cascade.datasets.static_loaders',
 'dataset_hash': '9483f97db6cca8f6c9c8c917eae046eb'}

In [6]:
# Pull a single training batch for scan 26644-14-17 to sanity-check the inputs.
train_loader_26644 = dataloaders["train"]["26644-14-17"]
b = next(iter(train_loader_26644))

In [7]:
# Expected (batch_size, 1, 72, 128): batch_size=128 and 72x128 px grayscale images.
b.images.size()

torch.Size([128, 1, 72, 128])

# Add Model

In [8]:
model_fn = 'cascade.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,
                'stack': -1,
                'layers': 4,
                'input_kern': 9,
                'gamma_input': 6.3831,
                'gamma_readout': 0.0076,
                'hidden_dilation': 1,
                'hidden_kern': 7,
                'hidden_channels': 64,
                'depth_separable': True,
                'init_sigma': 0.1,
                'init_mu_range': 0.3,
                'gauss_type': 'full',
                'shifter': True,
                'shift_layers': 3,
                }

# Build the model once locally to check that the config works with these
# dataloaders. NOTE(review): seed=1 here, while the Seed table holds 1000, so
# this instance presumably differs from the one TrainedModel will train.
model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=1,)



In [9]:
# Register the model config; skip_duplicates=True keeps the cell re-runnable.
# NOTE(review): the comment says "color", but the dataset above is GrayImageNet
# with a single input channel -- confirm the wording before adding new entries.
Model().add_entry(model_fn, model_config,
                  model_comment="color mei default model",
                  skip_duplicates=True)

{'model_fn': 'cascade.models.stacked_core_full_gauss_readout',
 'model_hash': 'd2664a1ba94dcd1d911c1ac6907934ec'}

# Add Trainer

In [10]:
trainer_fn = "cascade.training.standard_trainer"

# NOTE(review): max_iter=2 looks like a quick test setting -- confirm before a full run.
trainer_config = dict(
    max_iter=2,
    verbose=False,
    lr_decay_steps=4,
    avg_loss=False,
    lr_init=0.009,
)

# Resolve the trainer function locally; `trainer` is not used further below.
trainer = get_trainer(trainer_fn=trainer_fn,
                      trainer_config=trainer_config)

In [11]:
# Register the trainer config; skip_duplicates=True keeps the cell re-runnable.
Trainer().add_entry(trainer_fn, trainer_config,
                    trainer_comment="color mei default trainer",
                    skip_duplicates=True)

{'trainer_fn': 'cascade.training.standard_trainer',
 'trainer_hash': '534a2b5c34a223df3479adadd9f326ac'}

In [12]:
from nnvision.tables.from_nnfabrik import TrainedModel

In [None]:
# Train the entries still missing from TrainedModel (1 pending here, per the
# progress bar). NOTE(review): the pending set is presumably every
# Dataset x Model x Trainer x Seed combination in the schema, not just the
# entries added in this notebook.
trained_model_table = TrainedModel()
trained_model_table.populate(display_progress=True)

  0%|          | 0/1 [00:00<?, ?it/s]

/data/mouse/toliaslab/static/static26645-2-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6 exists already. Not unpacking /data/mouse/toliaslab/static/static26645-2-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip
/data/mouse/toliaslab/static/static26644-14-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6 exists already. Not unpacking /data/mouse/toliaslab/static/static26644-14-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip



Epoch 1:   0%|          | 0/72 [00:00<?, ?it/s][A
Epoch 1:   1%|▏         | 1/72 [00:01<01:29,  1.26s/it][A
Epoch 1:   3%|▎         | 2/72 [00:02<01:30,  1.30s/it][A
Epoch 1:   4%|▍         | 3/72 [00:04<01:33,  1.35s/it][A
Epoch 1:   6%|▌         | 4/72 [00:05<01:33,  1.38s/it][A
Epoch 1:   7%|▋         | 5/72 [00:07<01:34,  1.41s/it][A
Epoch 1:   8%|▊         | 6/72 [00:08<01:34,  1.43s/it][A
Epoch 1:  10%|▉         | 7/72 [00:09<01:33,  1.44s/it][A
Epoch 1:  11%|█         | 8/72 [00:11<01:33,  1.45s/it][A
Epoch 1:  12%|█▎        | 9/72 [00:12<01:31,  1.46s/it][A
Epoch 1:  14%|█▍        | 10/72 [00:14<01:30,  1.47s/it][A
Epoch 1:  15%|█▌        | 11/72 [00:15<01:28,  1.46s/it][A
Epoch 1:  17%|█▋        | 12/72 [00:17<01:28,  1.47s/it][A
Epoch 1:  18%|█▊        | 13/72 [00:18<01:27,  1.48s/it][A
Epoch 1:  19%|█▉        | 14/72 [00:20<01:25,  1.47s/it][A
Epoch 1:  21%|██        | 15/72 [00:21<01:24,  1.48s/it][A
Epoch 1:  22%|██▏       | 16/72 [00:23<01:22,  1.48s/it]

[001|00/05] ---> 0.05924167484045029



Epoch 2:   0%|          | 0/72 [00:00<?, ?it/s][A
Epoch 2:   1%|▏         | 1/72 [00:00<00:57,  1.24it/s][A
Epoch 2:   3%|▎         | 2/72 [00:01<01:04,  1.08it/s][A
Epoch 2:   4%|▍         | 3/72 [00:03<01:08,  1.01it/s][A
Epoch 2:   6%|▌         | 4/72 [00:04<01:11,  1.05s/it][A
Epoch 2:   7%|▋         | 5/72 [00:05<01:11,  1.07s/it][A
Epoch 2:   8%|▊         | 6/72 [00:06<01:12,  1.10s/it][A
Epoch 2:  10%|▉         | 7/72 [00:07<01:12,  1.12s/it][A
Epoch 2:  11%|█         | 8/72 [00:08<01:12,  1.13s/it][A
Epoch 2:  12%|█▎        | 9/72 [00:10<01:11,  1.13s/it][A
Epoch 2:  14%|█▍        | 10/72 [00:11<01:10,  1.14s/it][A
Epoch 2:  15%|█▌        | 11/72 [00:12<01:09,  1.14s/it][A
Epoch 2:  17%|█▋        | 12/72 [00:13<01:09,  1.15s/it][A
Epoch 2:  18%|█▊        | 13/72 [00:14<01:08,  1.16s/it][A
Epoch 2:  19%|█▉        | 14/72 [00:15<01:06,  1.14s/it][A
Epoch 2:  21%|██        | 15/72 [00:16<01:05,  1.15s/it][A
Epoch 2:  22%|██▏       | 16/72 [00:18<01:04,  1.16s/it]

In [None]:
# Reload a trained model together with its dataloaders from the TrainedModel table.
# NOTE(review): no key is passed, so this presumably works only while the table
# holds exactly one entry -- restrict TrainedModel() to the desired key otherwise.
trained_model_table = TrainedModel()
dataloaders, model = trained_model_table.load_model()