# Test Notebook for Training on the FileTreeSet

In [None]:
import datajoint as dj

import os
import torch
import numpy as np
import pickle 

import nnfabrik
from nnfabrik import main, builder

import nnvision

# Get Dataloaders

## StaticImageSet

In [None]:
# Location of the preprocessed HDF5 recording read by the StaticImageSet loader.
# Edit this list to point at your local copy of the data.
paths = [
    '/data/mouse/static22564-2-13-preproc0.h5',
]

In [None]:
# Build dataloaders from the single HDF5 file (StaticImageSet).
# Use the fully-qualified nnvision loader -- the same function the FileTreeSet
# cell uses -- so the two loader cells differ only in `file_tree`.
# NOTE(review): a bare name such as 'mouse_static_loaders' is resolved by
# nnfabrik's builder against its own default base module, not nnvision,
# so it may pick up a different loader; confirm against nnfabrik.builder.
dataset_fn = 'nnvision.datasets.mouse_static_loaders'
dataset_config = dict(
    paths=paths,
    batch_size=64,
    normalize=True,
    file_tree=False,  # read the .h5 file directly, not a directory tree
)
dataloaders = builder.get_data(dataset_fn, dataset_config)

## FileTreeSet

In [None]:
# Location of the FileTreeSet directory (note the trailing slash: this is a
# directory, not an .h5 file). Edit this list to match your local data copy.
paths = [
    '/data/mouse/static22564-2-13-preproc0/',
]

In [22]:
# Build dataloaders from the directory-tree layout (FileTreeSet) by
# enabling `file_tree` on the nnvision loader.
dataset_fn = 'nnvision.datasets.mouse_static_loaders'
dataset_config = {
    'paths': paths,
    'batch_size': 64,
    'normalize': True,
    'file_tree': True,
}
dataloaders = builder.get_data(dataset_fn, dataset_config)

# Get NNkonsti Model

In [None]:
# Model-building function (dotted path resolved by nnfabrik's builder) and
# its keyword configuration.
model_fn = 'nnvision.models.se_core_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    n_se_blocks=0,
    depth_separable=True,
)
# Instantiate the model. The dataloaders are handed to the builder as well
# (presumably so the model can size itself to the data -- confirm in
# nnvision.models.se_core_gauss_readout); seed 1000 is also used for training.
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

# Get Trainer

In [None]:
# Training function (dotted path resolved by nnfabrik's builder) and the
# keyword configuration bound to it.
trainer_fn = 'nnvision.training.nnvision_trainer'
trainer_config = dict(
    max_iter=100,
    verbose=False,
    lr_decay_steps=4,
    avg_loss=False,
    patience=5,
    lr_init=0.0041,
)
# Resolve the trainer function and bind `trainer_config` to it; the result is
# called below with the model and dataloaders.
trainer = builder.get_trainer(
    trainer_fn,
    trainer_config,
)

# Run Training

In [None]:
# Run training. The seed matches the one used to build the model above.
# The trainer's result is unpacked into three values: score, output, model_state.
score, output, model_state = trainer(
    model=model,
    dataloaders=dataloaders,
    seed=1000,
)