# Demo notebook: how to run models on static mouse datasets

In [None]:
# Render figures inline and enable autoreload so edits to imported modules
# are picked up without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2 

In [None]:
import os
import datajoint as dj

# Database credentials are read from the environment rather than hard-coded;
# a missing variable raises KeyError immediately.
for config_key, env_var in (
    ('database.host', 'DJ_HOST'),
    ('database.user', 'DJ_USER'),
    ('database.password', 'DJ_PASS'),
):
    dj.config[config_key] = os.environ[env_var]
dj.config['enable_python_native_blobs'] = True

# Suffix of the DataJoint schema this notebook writes to.
name = "test"
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

In [None]:
import torch
import numpy as np
import pickle
import pandas as pd
from collections import OrderedDict
# ``Iterable`` lives in ``collections.abc``; the old ``collections`` alias was
# removed in Python 3.10 and importing it from there raises ImportError.
from collections.abc import Iterable
import matplotlib.pyplot as plt

import nnfabrik
# NOTE(review): wildcard imports hide where names come from; kept because
# later cells may rely on names they provide.
from nnfabrik.main import *
from nnfabrik import builder

from nnsysident.tables.experiments import *
from nnsysident.datasets.mouse_loaders import static_shared_loaders
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.datasets.mouse_loaders import static_loader

# Get Dataloaders

In [None]:
# Dataset archive(s) to load, relative to the notebook's working directory.
paths = ['data/allen-agglomerate-NaturalScenesOracle-028f7ad795.zip']

# Build the dataloaders via nnfabrik from a dotted function path plus config.
dataset_fn = 'nnsysident.datasets.mouse_loaders.mouse_allen_scene_loaders'
dataset_config = {
    'paths': paths,
    'batch_size': 32,
    'areas': ('VISp',),  # restrict to the 'VISp' area only
}
dataloaders = builder.get_data(dataset_fn, dataset_config)

# Get Model

### The new Gaussian readout: change `gauss_type` for the different modes

In [None]:
# Disabled alternative: the se2d_fullgaussian2d model with its own config.
# To use it, uncomment this cell and skip the next one (the next cell
# overwrites ``model``).
# model_fn = 'nnsysident.models.models.se2d_fullgaussian2d'

# model_config = {'init_mu_range': 0.55,
#                  'init_sigma': 0.4,
#                  'input_kern': 15,
#                  'hidden_kern': 13,
#                  'gamma_input': 1.0,
#                  'grid_mean_predictor':{  'type': 'cortex',
#                                           'input_dimensions': 2,
#                                           'hidden_layers': 0,
#                                           'hidden_features': 0,
#                                           'final_tanh': False},
#                  'gamma_readout': 0.6697555139311314,
#                  'share_features': False,
#                  'share_transform': False}

# model = builder.get_model(model_fn=model_fn, model_config=model_config, dataloaders=dataloaders, seed=1)


In [None]:
model_fn = 'nnsysident.models.models.taskdriven_fullgaussian2d'

# Core taken from the network named by ``tl_model_name`` (first ``layers``
# layers), followed by a full Gaussian readout whose grid means are predicted
# from cortical coordinates by ``grid_mean_predictor``.
model_config = dict(
    tl_model_name='vgg16',
    layers=8,
    init_mu_range=0.55,
    init_sigma=0.4,
    share_features=False,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    gamma_readout=1.7925745123145327,
)

model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1,
)

In [None]:
core_dict = torch.load('d99820ef586bbebcfce36d7bc89877f3-11gaussian.pth.tar') # score: 0.329419


def get_grids(my_model):
    """Collect the readout grid of every readout in ``my_model``.

    Args:
        my_model: model whose ``readout`` supports ``.items()`` yielding
            ``(key, readout)`` pairs; each readout must expose a ``grid`` tensor.

    Returns:
        dict mapping each readout key to its grid as a NumPy array with
        singleton dimensions squeezed out. NOTE(review): presumably shape
        (n_neurons, 2) for a Gaussian readout -- confirm; a readout with a
        single neuron would also lose its neuron axis through ``squeeze()``.
    """
    return {
        key: readout.grid.squeeze().detach().cpu().numpy()
        for key, readout in my_model.readout.items()
    }


def _is_readout_key(state_key):
    """Return True if a state-dict / parameter key belongs to the ``readout`` submodule."""
    return state_key.split('.')[0] == 'readout'


def _readout_param_name(state_key):
    """Drop the first two dot-separated components (presumably ``readout.<data_key>``)."""
    return '.'.join(state_key.split('.')[2:])


# Readout parameter names (after the two-component prefix) to take from the
# checkpoint; all other readout entries are dropped and keep their fresh init.
# Other options: 'scales', '_features', 'mu_transform.0.weight', 'mu_transform.0.bias'
list_of_load = ['spatial']
# Readout parameter names to freeze (requires_grad=False) after loading.
list_of_detach = ['spatial']



#model_real.load_state_dict(core_dict, strict=False)
#real_grids = get_grids(model_real)

remove = []
keep = []
for key in core_dict.keys():
    if _is_readout_key(key):
        if _readout_param_name(key) in list_of_load:
            keep.append(key)
        else:
            print('Not loading:    {}'.format(key))
            remove.append(key)
for key in keep:
    print('Loading:  {}'.format(key))

# Delete in place rather than rebuilding a dict, so any ``_metadata``
# attribute on the loaded state dict survives for ``load_state_dict``.
for key in remove:
    del core_dict[key]
#model_start.load_state_dict(core_dict, strict=False)
#start_grids = get_grids(model_start)



# strict=False silently skips mismatched keys, so report what was skipped.
load_result = model.load_state_dict(core_dict, strict=False)
print('Missing keys (kept at init):  {}'.format(len(load_result.missing_keys)))
print('Unexpected keys (ignored):    {}'.format(load_result.unexpected_keys))

# Freeze the selected readout parameters. Distinct loop names avoid clobbering
# the module-level ``name`` used for the DataJoint schema.
for param_name, param in model.named_parameters():
    if _is_readout_key(param_name) and _readout_param_name(param_name) in list_of_detach:
        print('detaching:    {}'.format(param_name))
        param.requires_grad = False

# Get Trainer

In [None]:
# Passed to the trainer as ``detach_core``; presumably keeps the core fixed so
# only the readout is trained -- confirm in standard_trainer.
detach_core = True

trainer_fn = 'nnsysident.training.trainers.standard_trainer'
trainer_config = {'detach_core': detach_core, 'track_training': True}
trainer = builder.get_trainer(trainer_fn, trainer_config)

# Run Training

In [None]:
# Train ``model`` on ``dataloaders``. The trainer returns three values, unpacked
# here as score, output and model_state. NOTE(review): the exact contents of
# ``output`` depend on standard_trainer(track_training=True) -- confirm.
score, output, model_state = trainer(model=model, dataloaders=dataloaders, seed=1)

In [None]:
# Show each neuron's grid from every panel side by side, for the first
# neurons only (indices 0-20).
# NOTE(review): ``grids`` is assigned in the following cell
# (``grids = get_grids(model)``); run that cell first on a fresh kernel.
# NOTE(review): ``imshow`` needs a 2-D array per neuron; a Gaussian readout
# grid is presumably one (x, y) position per neuron -- confirm this cell
# matches the readout actually in use.
MAX_NEURONS_TO_PLOT = 21

# ``get_grids`` returns a dict keyed by data key; positional indexing needs a list.
grid_panels = list(grids.values()) if isinstance(grids, dict) else list(grids)

for neuron in range(min(len(grid_panels[0]), MAX_NEURONS_TO_PLOT)):
    fig, axes = plt.subplots(1, 2, dpi=250)
    for ax, panel in zip(axes.flatten(), grid_panels):
        ax.imshow(panel[neuron])
        ax.set(xticklabels=[], yticklabels=[])
    plt.show()
    plt.close(fig)  # release the figure; this loop creates one per neuron

In [None]:
grids = get_grids(model)

from nnsysident.utility.measures import get_correlations

# Test-set correlations per data key, reduced to a single mean per key
# (values are replaced in place, so ``cors`` keeps its original mapping type).
cors = get_correlations(model, dataloaders['test'], per_neuron=False, as_dict=True)
cors.update({data_key: np.mean(per_key) for data_key, per_key in cors.items()})

In [None]:
title = 'feat|scale: random init + learned - position: random 0.33 ortho init|0 + learned + shared'

# ``real_grids`` / ``start_grids`` are only created by the commented-out
# model_real / model_start code in the checkpoint cell. Include them only when
# they exist, so a fresh-kernel run does not fail with NameError.
dictionary = dict(cors=cors, grids=grids)
for optional_name in ('real_grids', 'start_grids'):
    if optional_name in globals():
        dictionary[optional_name] = globals()[optional_name]
# NOTE(review): ``title`` contains '|' and ':', which are invalid in Windows file names.
#torch.save(dictionary, title)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# "True" and initial positions are optional: they only exist if the
# commented-out model_real / model_start code in the checkpoint cell was run.
# When present, they are looked up by the same data key as ``grids``.
true_grids = globals().get('real_grids') or {}
init_grids = globals().get('start_grids') or {}

fig, axes = plt.subplots(2, 2, dpi=200, figsize=(6, 4.3))
axes_list = list(axes.flat)
# One panel per data key; zip stops at whichever runs out first.
panel_items = list(zip(axes_list, grids.items()))

for ind, (ax, (key, grid)) in enumerate(panel_items):
    ax.scatter(*grid.T, color='navy', s=.1, label='Learned position')
    if true_grids:
        ax.scatter(*true_grids[key].T, color='red', s=.1, label='"True" position')
    if init_grids:
        ax.scatter(*init_grids[key].T, color='black', s=.1, label='Init position', alpha=0.2)

    ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1), label='Learned correlation')
    ax.text(-1, .8, round(cors[key], 3), fontsize=8.5, color='navy')
    sns.despine(top=True, ax=ax)

    ax.set(xticks=[-1, -.5, 0, .5, 1])
    ax.set_axisbelow(True)
    ax.grid(ls='--')
    # Only the last filled panel keeps tick labels, so at least one panel is
    # labelled even when there are fewer than four data keys.
    if ind != len(panel_items) - 1:
        ax.set(xticklabels=[], yticklabels=[])

    ax.set_title(key, fontsize=10)

# Hide panels left empty when there are fewer data keys than axes.
for ax in axes_list[len(panel_items):]:
    ax.set_axis_off()

fig.suptitle(title.split('-')[0] + '\n' + title.split('-')[-1], y=1.05, fontsize=10.)

plt.tight_layout()
#plt.savefig(title + '.png', dpi=200, bbox_inches='tight')