# Generate MEIs (most exciting inputs) for one dataset

In [None]:
# Notebook setup: autoreload, database connection, schema selection and imports.
# Reload edited project modules automatically (IPython magics; only valid inside IPython/Jupyter).
%load_ext autoreload
%autoreload 2

import os
import datajoint as dj
# Database credentials are read from the environment; a missing variable raises KeyError here.
dj.config['database.host'] = os.environ['DJ_HOST']
dj.config['database.user'] = os.environ['DJ_USER']
dj.config['database.password'] = os.environ['DJ_PASS']
dj.config['enable_python_native_blobs'] = True
dj.config['display.limit'] = 200

name = 'iclr'
# NOTE(review): schema_name is set BEFORE the nnfabrik / nnsysident imports below.
# Presumably those modules read it at import time to bind their tables to this schema,
# so keep this ordering. TODO confirm.
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

import nnfabrik
from nnfabrik.main import *
from nnfabrik import builder

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from nnsysident.tables.experiments import *
from nnsysident.tables.mei import *
from nnsysident.utility.measures import get_correlations

# Local directory passed as `download_path` when fetching MEI/output files from the database.
fetch_download_path = './mei_downloads'

In [None]:
def best_n_unit_ids(model, dataloaders, n, device='cuda'):
    """Return unit IDs and indices of the ``n`` neurons best predicted on the test set.

    Neurons are ranked by their per-neuron test correlation as computed by
    ``get_correlations``.

    Args:
        model: Trained model passed to ``get_correlations``.
        dataloaders: Dict whose ``"test"`` entry maps data keys to dataloaders.
            Exactly one data key is supported.
        n: Number of neurons to return. ``n <= 0`` yields empty selections;
            ``n`` larger than the number of neurons returns all of them.
        device: Device on which the correlations are computed.

    Returns:
        Tuple ``(unit_ids, indices)``, ordered from lowest to highest test
        correlation among the selected neurons. ``unit_ids`` are taken from the
        dataset's ``neurons.unit_ids``; ``indices`` are positions in the
        per-neuron correlation array (presumably the model's output order).

    Raises:
        ValueError: If ``dataloaders["test"]`` does not hold exactly one dataset.
    """
    test_loaders = dataloaders["test"]
    # Correlations are computed over all data keys at once, so the indices
    # below only map onto unit IDs when there is a single dataset.
    if len(test_loaders) != 1:
        raise ValueError(
            "Exactly one dataset is supported, got {}".format(len(test_loaders))
        )
    (dataloader,) = test_loaders.values()

    test_correlation = np.asarray(
        get_correlations(model, test_loaders, device=device, as_dict=False, per_neuron=True),
        dtype=float,
    )
    # argsort places NaN last, which would rank e.g. constant-prediction
    # neurons as "best"; push them below every real correlation instead.
    test_correlation = np.where(np.isnan(test_correlation), -np.inf, test_correlation)

    # argsort is ascending, so the best neurons are at the end. Use an explicit
    # start index: ``order[-n:]`` would return *all* neurons when n == 0.
    order = np.argsort(test_correlation)
    start = max(len(order) - max(n, 0), 0)
    indices = order[start:]

    unit_ids = dataloader.dataset.neurons.unit_ids[indices]
    return unit_ids, indices

___

# Select Dataset

In [None]:
# `key` restricts Experiments.Restrictions to one experiment and one dataset;
# the ensemble, MEISelector and model-loading cells below restrict by it.
experiment_restriction = 'experiment_name = "Real, Direct, se2d_fullgaussian2d, 20457-5-9"'
dataset_restriction = 'dataset_hash = "71c9ac7a98e066544ad88eb47ea282ec"'
key = Experiments.Restrictions() & experiment_restriction & dataset_restriction
key

# Creating the ensemble

#### The key has to restrict `TrainedModel` to exactly the models that should form the ensemble. The dataset hash MUST be unique.

In [None]:
# Inspect the trained models matched by `key`; these are the models that will form the ensemble.
(TrainedModel() & key)

# Create an ensemble

In [None]:
# Build one ensemble from the TrainedModel entries matched by `key`.
# NOTE(review): this presumably writes to the database; confirm what create_ensemble does if the cell is re-run.
TrainedEnsembleModel().create_ensemble(key=key, comment='Real, Direct, se2d_fullgaussian2d, 20457-5-9')

In [None]:
# Show the ensemble entries matched by `key`.
TrainedEnsembleModel & key

In [None]:
# Check whether the ensemble has the correct number of members
# (one row per TrainedModel that was combined into the ensemble):
TrainedEnsembleModel.Member() & key

# Populate the MEISelector table, which maps each dataset unit ID to the unit's position within the model

In [None]:
# Fill MEISelector for the entries matched by `key`; each resulting row corresponds to one neuron.
MEISelector().populate(key, display_progress=True)

In [None]:
# One entry in the table corresponds to one neuron.
# The unit_id comes straight from the dataset; the unit index is the position of that unit in the model.
MEISelector() & key

# Selecting the MEIMethod

### Normalized Images (z-scored)

In [None]:
# Gradient-ascent MEI method for models trained on normalized images.
# Hyper-parameters are defined once and reused in both the config and the
# human-readable description, so the two cannot drift apart.
# NOTE: config values, their types (lr is the float 2.0) and the key order are
# kept exactly as before; changing any of them presumably changes the method
# hash stored by MEIMethod.
method_fn = 'mei.methods.gradient_ascent'
sigma = 1
lr = 2.0
norm = 15
n_iters = 500

method_config = {'initial': {'path': 'mei.initial.RandomNormal'},
                 'optimizer': {'path': 'torch.optim.SGD', 'kwargs': {'lr': lr}},
                 'precondition': {'path': 'mei.legacy.ops.GaussianBlur',
                                  'kwargs': {'sigma': sigma}},
                 'postprocessing': {'path': 'mei.legacy.ops.ChangeNorm',
                                    'kwargs': {'norm': norm}},
                 'stopper': {'path': 'mei.stoppers.NumIterations',
                             'kwargs': {'num_iterations': n_iters}},
                 'objectives': [{'path': 'mei.objectives.EvaluationObjective',
                                 'kwargs': {'interval': 10}}],
                 'device': 'cuda'}

# `lr:g` renders 2.0 as "2", so the description stays "normalized image, norm=15, lr=2, iter=500".
MEIMethod().add_method(method_fn, method_config, f"normalized image, norm={norm}, lr={lr:g}, iter={n_iters}")

### Un-normalized Images (8-bit)

In [None]:
# Gradient-ascent MEI method for models trained on un-normalized 8-bit images.
# The postprocessing op (nnsysident's ChangeStdClampedMean) is configured with a
# target std, a [0, 255] pixel range and a clamped mean.
method_fn = 'mei.methods.gradient_ascent'
sigma, lr, std, mean, n_iters = 1, 1e4, 2, 111.3, 500

# Built with dict(...) keyword construction; keyword order fixes the key order.
method_config = dict(
    initial=dict(path='nnsysident.meis.initial.CustomRandomNormal'),
    optimizer=dict(path='torch.optim.SGD', kwargs=dict(lr=lr)),
    precondition=dict(path='mei.legacy.ops.GaussianBlur', kwargs=dict(sigma=sigma)),
    postprocessing=dict(
        path='nnsysident.meis.ops.ChangeStdClampedMean',
        kwargs=dict(std=std, x_min=0, x_max=255, clamped_mean=mean),
    ),
    stopper=dict(path='mei.stoppers.NumIterations', kwargs=dict(num_iterations=n_iters)),
    objectives=[dict(path='mei.objectives.EvaluationObjective', kwargs=dict(interval=10))],
    device='cuda',
)
MEIMethod().add_method(method_fn, method_config, f"8 bit image, std={std}, mean={mean}, lr={lr}, iter={n_iters}")

In [None]:
# Show the registered MEI methods, e.g. to look up the method_hash used when populating MEIs.
MEIMethod()

# Populate the MEI table

In [None]:
# Load the dataloaders and the single model trained with `seed` among those matched by `key`.
seed = 1
trained_model = TrainedModel() & key & f"seed={seed}"
dataloaders, model = trained_model.load_model()

In [None]:
# Select which MEIs to compute: one per (dataset, ensemble, method, unit, MEI seed).
n = 20  # number of units for best_n_unit_ids (only used by the commented-out selection)
mei_seed = 1

dataset_hash = '71c9ac7a98e066544ad88eb47ea282ec'  # same dataset hash as in `key`
ensemble_hash = '4dc2b15a95c86f907b7417d8811f54fe'
method_hash = '8db856e30d03df10c2c41326ff7e5422'

# select unit IDs of neurons that MEI should be computed for
# unit_ids, indices = best_n_unit_ids(model, dataloaders, n)
# unit_ids = np.array([ 259, 1641, 1369, 2999, 2532, 1951, 5443, 3648, 3316, 1038])
unit_ids = dataloaders['train']['20457-5-9-0'].dataset.neurons.unit_ids

# Build the SQL `IN (...)` list explicitly. Formatting `tuple(unit_ids)` yields
# invalid SQL for a single unit ("(259,)") and, with NumPy >= 2, renders the
# elements as "np.int64(259)". An empty list would produce "in ()".
if len(unit_ids) == 0:
    raise ValueError("No unit IDs selected; cannot build the MEI restriction.")
unit_id_list = ", ".join(str(int(unit_id)) for unit_id in unit_ids)

mei_restriction = dj.AndList(['dataset_hash = "{}"'.format(dataset_hash),
                              'ensemble_hash = "{}"'.format(ensemble_hash),
                              'method_hash = "{}"'.format(method_hash),
                              'unit_id in ({})'.format(unit_id_list),
                              'mei_seed = {}'.format(mei_seed)])

# Display how many MEIs would be computed
MEI.progress(mei_restriction)

In [None]:
# Compute the MEIs selected by `mei_restriction`.
# NOTE: per DataJoint's populate() documentation, reserve_jobs=True reserves keys in the
# jobs table so several workers can run this cell concurrently, and order='random'
# processes the pending keys in random order.
MEI.populate(mei_restriction,
             display_progress=True, 
             order='random',
             reserve_jobs=True)

# Fetch MEIs and plot activations

In [None]:
# Fetch, for every computed MEI, the paths of the downloaded MEI and output files
# (stored under fetch_download_path) together with the unit ID and score.
mei_query = (MEISelector * MEI()) & mei_restriction
mei_paths, output_paths, unit_IDs, score = mei_query.fetch(
    "mei", "output", "unit_id", "score", download_path=fetch_download_path
)

In [None]:
# Load the downloaded files into NumPy arrays.
# Everything is mapped to the CPU so tensors saved on a GPU can also be opened on a
# machine without CUDA, and each output file is loaded only once.
objective_key = 'mei.objectives.EvaluationObjective'
MEIs = np.stack([torch.load(path, map_location='cpu').detach().cpu().numpy().squeeze() for path in mei_paths])
outputs = [torch.load(path, map_location='cpu')[objective_key] for path in output_paths]
# Per-MEI objective values and the points at which they were recorded.
evaluations = np.stack([output["values"] for output in outputs])
t_evaluations = np.stack([output["times"] for output in outputs])

In [None]:
# mei_direct = dict(unit_IDs = unit_IDs, MEIs = MEIs)
# torch.save(mei_direct, 'mei_direct_highestdiff')

# Plot Activations

In [None]:
# Evaluation-objective trace of every MEI over the optimization, one line per neuron.
fig, ax = plt.subplots()
ax.plot(t_evaluations.T, evaluations.T)
sns.despine(trim=True, offset=10)
ax.set_xlabel("iterations")
ax.set_ylabel("activation")
ax.set_title("Activations over iterations for all neurons");

# Plot MEIs

In [None]:
# Show the MEIs on a 4 x 5 grid on a shared gray scale, each labelled with its unit ID.
n_rows, n_cols = 4, 5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 5), dpi=150)
panels = axes.flatten()
# Shared colour limits keep brightness comparable across panels; computed once.
vmin, vmax = MEIs.min(), MEIs.max()
# zip stops at the shorter sequence: with fewer MEIs than panels the remaining
# panels stay empty instead of raising IndexError; with more, only the first
# n_rows * n_cols MEIs are shown.
for ax, mei, unit_id in zip(panels, MEIs, unit_IDs):
    ax.imshow(mei, cmap="gray", vmin=vmin, vmax=vmax)
    ax.axis("off")
    ax.text(0.5, 2.5, str(unit_id), va='center', fontsize=8.5, color='k')
# Hide any panels left over when there are fewer MEIs than grid cells.
for ax in panels[len(MEIs):]:
    ax.axis("off")
# fig.savefig('mei.png')