# Demo notebook: how to run models on static mouse datasets

In [1]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Automatically reload edited modules before executing each cell.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import datajoint as dj

# Database credentials are read from the environment so they never live in the notebook.
for config_key, env_var in (
    ('database.host', 'DJ_HOST'),
    ('database.user', 'DJ_USER'),
    ('database.password', 'DJ_PASS'),
):
    dj.config[config_key] = os.environ[env_var]
dj.config['enable_python_native_blobs'] = True

# Suffix of the DataJoint schema name; change `name` to work in a different schema.
name = "test"
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

In [3]:
import pickle 
from collections import OrderedDict
from collections.abc import Iterable

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

import nnfabrik
from nnfabrik.main import *
from nnfabrik import builder

from nnsysident.tables.experiments import *
from nnsysident.datasets.mouse_loaders import static_shared_loaders
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.datasets.mouse_loaders import static_loader

  from collections import OrderedDict, Iterable


Connecting konstantin@sinzlab.chlkmukhxp6i.eu-central-1.rds.amazonaws.com:3306


# Get Dataloaders

In [None]:
# Change the paths here: one preprocessed static-scan archive per scan.
# For a quick single-scan run, use only ['data/static22564-2-12-preproc0.zip'].
scan_names = ['22564-2-12', '22564-2-13', '22564-3-8', '22564-3-12']
paths = ['data/static{}-preproc0.zip'.format(scan) for scan in scan_names]

# NOTE(review): image_n / multi_match_n / exclude_multi_match_n presumably limit which images
# and neurons are used, with the *_seed entries seeding those selections — confirm in
# static_shared_loaders.
dataset_fn = 'nnsysident.datasets.mouse_loaders.static_shared_loaders'
dataset_config = {
    'paths': paths,
    'batch_size': 64,
    'seed': 1,
    'image_n': 50,
    'image_base_seed': 1,
    'multi_match_n': 972,
    'multi_match_base_seed': 1,
    'exclude_multi_match_n': 3625,
}
dataloaders = builder.get_data(dataset_fn, dataset_config)

# Get Model

### Spatial Transformer

In [None]:
# Alternative model (currently disabled): se2d_spatialxfeaturelinear.
# Uncomment these lines to build it instead of the Gaussian-readout models below.
# model_fn = 'nnsysident.models.models.se2d_spatialxfeaturelinear'
# model_config = {
#     'pad_input': False,
#     'stack': -1,
#     'layers': 4,
#     'input_kern': 9,
#     'gamma_input': 20,
#     'gamma_readout': 0.005,
#     'hidden_dilation': 1,
#     'hidden_kern': 7,
#     'hidden_channels': 64,
#     'init_noise': 1e-3,
#     'depth_separable': True,
# }
# model = builder.get_model(model_fn=model_fn, model_config=model_config, dataloaders=dataloaders, seed=1000)

### The new Gaussian readout: change `gauss_type` for the different modes

In [None]:
# Task-driven model: core built from the network named by `tl_model_name`, with a full
# Gaussian 2D readout (per the model_fn name). The 'cortex' grid_mean_predictor presumably
# predicts readout positions from per-neuron coordinates — confirm in nnsysident.
model_fn = 'nnsysident.models.models.taskdriven_fullgaussian2d'

model_config = dict(
    tl_model_name='vgg16',
    layers=8,
    init_mu_range=0.55,
    init_sigma=0.4,
    share_features=False,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    gamma_readout=4.622488854650272,
)

model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1,
)

In [None]:
# Display the module structure of the current `model`.
model

In [None]:
# Data-driven model: se2d core with a full Gaussian 2D readout (per the model_fn name).
# NOTE: this cell rebinds `model`, replacing any model built by an earlier cell.
model_fn = 'nnsysident.models.models.se2d_fullgaussian2d'

model_config = dict(
    share_features=True,
    init_mu_range=0.55,
    init_sigma=0.4,
    input_kern=15,
    hidden_kern=13,
    gamma_input=1.0,
    gamma_readout=2.117604964706911,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
)

model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1,
)

In [None]:
# |<w[:, 0], w[:, 1]>| as a 1x1 tensor: absolute dot product of the two input columns of the
# first mu_transform layer's weight (close to 0 when the columns are orthogonal).
mu_weight = model.readout['22564-3-12-0'].mu_transform[0].weight
torch.mm(mu_weight[:, 0:1].T, mu_weight[:, 1:2]).abs()

In [None]:
# Transfer only the core weights ("core.*" keys) from a previously trained checkpoint;
# the readout keeps its fresh initialization.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints you trust.
CORE_CHECKPOINT = 'inshallahmodel.mojo'

# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict then copies each tensor onto the device of the matching model parameter.
checkpoint = torch.load(CORE_CHECKPOINT, map_location='cpu')
core_dict = OrderedDict((k, v) for k, v in checkpoint.items() if k.startswith('core.'))
assert core_dict, f"no 'core.' weights found in {CORE_CHECKPOINT}"

# strict=False is required because the readout keys are absent from core_dict, but every
# core key we pass must exist in the model — otherwise those weights would be silently dropped.
load_result = model.load_state_dict(core_dict, strict=False)
assert not load_result.unexpected_keys, f"core keys not found in model: {load_result.unexpected_keys}"
load_result

In [None]:
# Value of the readout regularizer for scan '22564-2-12-0' with the current readout parameters.
model.readout.regularizer('22564-2-12-0')

In [None]:
# NOTE(review): identical to the preceding regularizer cell with no state change in between;
# candidate for removal.
model.readout.regularizer('22564-2-12-0')

In [None]:
# Readout regularization strength stored on the model (presumably taken from
# model_config['gamma_readout'] — confirm).
model.readout.gamma_readout

# Get Trainer

In [None]:
# Build the trainer. NOTE(review): detach_core=True presumably keeps the core fixed and trains
# only the readout — confirm in standard_trainer.
trainer_fn = 'nnsysident.training.trainers.standard_trainer'
trainer_config = {'detach_core': True, 'track_training': True}
trainer = builder.get_trainer(trainer_fn, trainer_config)

# Run Training

In [None]:
# Train `model` on `dataloaders`; the trainer returns a (score, output, model_state) triple —
# see standard_trainer for the exact meaning of each element.
score, output, model_state = trainer(model=model, dataloaders=dataloaders, seed=1)

In [None]:
# Imported here so the cell also works on a fresh kernel (Restart & Run All).
import matplotlib.pyplot as plt

# Plot the learned readout positions (mu) of each scan, one figure per scan, on fixed
# [-1, 1] axes so the scans can be compared directly.
for data_key in ['22564-3-8-0', '22564-3-12-0', '22564-2-13-0', '22564-2-12-0']:
    grid = model.readout[data_key].mu.squeeze().detach().cpu().numpy()
    fig, ax = plt.subplots()
    ax.scatter(*grid.T)
    ax.set(xlim=[-1, 1], ylim=[-1, 1], xlabel='x', ylabel='y',
           title=f'Readout positions (mu): {data_key}')
    plt.show()

In [None]:
# Imported here so the cell also works on a fresh kernel (Restart & Run All).
import matplotlib.pyplot as plt

# Source grid of scan 22564-3-12-0. detach() guards against .numpy() failing if the
# tensor requires grad.
fig, ax = plt.subplots()
ax.scatter(*model.readout['22564-3-12-0'].source_grid.detach().T.cpu().numpy())
ax.set(title='Source grid: 22564-3-12-0', xlabel='dim 0', ylabel='dim 1');

In [None]:
# Learned readout positions (mu) for scan 22564-3-12-0 as a NumPy array.
grid = model.readout['22564-3-12-0'].mu.detach().cpu().squeeze().numpy()

In [None]:
import matplotlib.pyplot as plt

# Scatter the readout positions stored in `grid` (assigned by an earlier cell).
fig, ax = plt.subplots()
ax.scatter(*grid.T)
ax.set(title='Readout positions (mu)', xlabel='x', ylabel='y');

In [None]:
# Imported here so the cell also works on a fresh kernel (Restart & Run All).
import matplotlib.pyplot as plt

# First two coordinates of each neuron in scan 22564-2-12-0, for comparison with the source grid.
# NOTE(review): the previous `.neurons.[:,:2]` was a SyntaxError; `cell_motor_coordinates` is
# assumed to be the intended attribute (the per-neuron coordinates used by the 'cortex'
# grid_mean_predictor) — confirm against the dataset's `neurons` fields.
neuron_xy = dataloaders['train']['22564-2-12-0'].dataset.neurons.cell_motor_coordinates[:, :2]
fig, ax = plt.subplots()
ax.scatter(*neuron_xy.T)
ax.set(title='Neuron coordinates (first 2 dims): 22564-2-12-0', xlabel='dim 0', ylabel='dim 1');

In [None]:
# Absolute dot product of the two input columns of the first mu_transform layer's weight,
# returned as a 1x1 tensor (0 means the columns are orthogonal).
first_col = model.readout['22564-3-12-0'].mu_transform[0].weight[:, 0:1]
second_col = model.readout['22564-3-12-0'].mu_transform[0].weight[:, 1:2]
torch.mm(first_col.T, second_col).abs()

In [None]:
# Weight of the first layer of the readout's mu_transform for scan 22564-3-12-0.
model.readout['22564-3-12-0'].mu_transform[0].weight

In [None]:
# NOTE(review): identical to the preceding cell; candidate for removal.
model.readout['22564-3-12-0'].mu_transform[0].weight

In [None]:
# L2 norm of each input column of the first mu_transform layer's weight.
model.readout['22564-3-12-0'].mu_transform[0].weight.norm(p=2, dim=0)

In [None]:
# |‖w[:, 0]‖₂ − ‖w[:, 1]‖₂| for the first mu_transform layer's weight (0 when both columns
# have equal length).
mu_weight = model.readout['22564-3-12-0'].mu_transform[0].weight
(mu_weight[:, 0].norm(p=2) - mu_weight[:, 1].norm(p=2)).abs()

In [None]:
def compute_aspect(trans_mat, agg_fn):
    """Aggregate the absolute difference between the two column-norm terms of each matrix.

    This wraps an expression pasted from a readout method. That expression used
    `self.trans_mat` and `agg_fn`, neither of which exists at notebook level, so the
    cell always raised NameError.

    Args:
        trans_mat: tensor indexed as ``trans_mat[:, :, k]`` for k in {0, 1}; presumably
            of shape (n, d, 2) — confirm against the readout that owns it.
        agg_fn: reduction applied to the per-row differences, e.g. ``torch.sum`` or ``torch.mean``.

    Returns:
        ``agg_fn`` applied to ``|norm(trans_mat[:, :, 0]**2) - norm(trans_mat[:, :, 1]**2)|``,
        with the norms taken along dim 1.

    NOTE(review): the norms are of the *squared* entries, i.e. sqrt(sum x**4), not the squared
    norm sum(x**2). This is kept as in the original; confirm which was intended.
    """
    col0_term = torch.norm(trans_mat[:, :, 0] ** 2, dim=1)
    col1_term = torch.norm(trans_mat[:, :, 1] ** 2, dim=1)
    return agg_fn(torch.abs(col0_term - col1_term))

In [None]:
# NOTE(review): `l` is not assigned in any visible cell, so this fails on a fresh kernel
# (Restart & Run All); it presumably referred to a layer from a deleted cell.
l.weight

In [None]:
# Device holding the core's first conv weight (e.g. to check whether the model is on the GPU).
model.core.features.layer0.conv.weight.device

In [None]:
# Shape of the readout's source grid for scan 22564-3-12-0 (presumably one row per neuron — confirm).
model.readout['22564-3-12-0'].source_grid.shape

In [None]:
# Source grid for scan 22564-3-12-0 as a NumPy array (detached from autograd, moved to CPU).
sg = model.readout['22564-3-12-0'].source_grid.detach().cpu().numpy()

In [None]:
import matplotlib.pyplot as plt

# Scatter the source grid `sg`; equal axis scaling keeps its true aspect ratio visible.
fig, ax = plt.subplots()
ax.scatter(*sg.T)
ax.set_aspect('equal')
ax.set(title='Source grid: 22564-3-12-0', xlabel='dim 0', ylabel='dim 1');

In [None]:
# Ratio of the source grid's extent (max - min) along its first vs. second column.
x, y = np.ptp(sg, axis=0)
x / y

In [None]:
# Display the full source-grid array (consider `sg[:5]` / `sg.shape` if it is large).
sg