# MEI Demo

In [None]:
import datajoint as dj

# Single source of truth for the schema name. It is written into dj.config
# (presumably read by nnfabrik/mei when they are imported in the next cell —
# confirm) and also used to create the schema the tables below live in.
SCHEMA_NAME = "nnfabrik_tutorial"

# DataJoint setting: allow native Python objects (dicts, lists, ...) in blobs.
dj.config["enable_python_native_blobs"] = True
dj.config["schema_name"] = SCHEMA_NAME

schema = dj.schema(SCHEMA_NAME)

In [None]:
import os

from matplotlib import pyplot as plt
from torch import load

from mei.main import TrainedEnsembleModelTemplate, CSRFV1SelectorTemplate, MEISeed, MEIMethod, MEITemplate
from nnfabrik.templates import TrainedModelBase
from nnfabrik.main import Dataset

## 1. Define Tables

In [None]:
@schema
class TrainedModel(TrainedModelBase):
    """Trained-model table in this tutorial's schema; all behaviour is inherited from nnfabrik's TrainedModelBase."""
    pass


@schema
class TrainedEnsembleModel(TrainedEnsembleModelTemplate):
    """Ensemble table in this tutorial's schema, built from mei's TrainedEnsembleModelTemplate.

    The template is bound to concrete tables through the class attributes below.
    Entries are added with ``create_ensemble(dataset_key, comment=...)``.
    """

    # Dataset table an ensemble is keyed on (restricted by dataset_fn/dataset_hash in this demo).
    dataset_table = Dataset
    # Table of trained models; presumably the ensemble members are drawn from it — TODO confirm in the template.
    trained_model_table = TrainedModel


@schema
class CSRFV1Selector(CSRFV1SelectorTemplate):
    """Neuron-selector table for the CSRF V1 dataset, filled via ``populate()``.

    NOTE(review): presumably maps each neuron_id to the model output unit used as
    the MEI target — confirm against CSRFV1SelectorTemplate.
    """

    # Dataset table the selector template is bound to.
    dataset_table = Dataset


@schema
class MEI(MEITemplate):
    """Table of generated MEIs, computed with ``populate()`` (restricted by neuron_id in this demo).

    Rows expose ``mei`` and ``output`` attributes that fetch as local file paths,
    which this notebook opens with ``torch.load`` and then deletes.
    """

    # Models come from the ensemble table rather than from single trained models.
    trained_model_table = TrainedEnsembleModel
    # Neuron selection is delegated to the CSRF V1 selector table.
    selector_table = CSRFV1Selector

## 2. Reset Tables For Demo

In [None]:
# Drop the demo's tables so the notebook starts from a clean slate, in the same
# order as before: selector, ensemble, method, seed. Per DataJoint's docs,
# drop() also removes dependent tables (presumably MEI) and prompts for
# confirmation when safemode is on.
# NOTE(review): confirm the dropped tables are re-declared (e.g. by re-running
# the table-definition cell) before the cells below use them.
for demo_table in (CSRFV1Selector, TrainedEnsembleModel, MEIMethod, MEISeed):
    demo_table().drop()

## 3. Create Ensemble Model

In [None]:
# Create an ensemble for the CSRF V1 dataset entry identified below, then show the table.
csrf_v1_dataset_key = dict(
    dataset_fn="csrf_v1",
    dataset_hash="3d94500a46b792bbb480aedfc30f9753",
)
ensemble_table = TrainedEnsembleModel()
ensemble_table.create_ensemble(csrf_v1_dataset_key, comment="Happy little ensemble")
TrainedEnsembleModel()

## 4. Populate Selector Table

In [None]:
# Populate the neuron-selector table, then show its contents.
CSRFV1Selector().populate()
CSRFV1Selector()

## 5. Specify MEI Method Parameters

In [None]:
from torch import cuda

# Gradient-ascent MEI method: random-normal initial image, SGD with lr=0.1,
# a stopper after 1000 iterations, and the evaluation objective with interval=10.
method_fn = "mei.methods.gradient_ascent"
method_config = dict(
    initial=dict(path="mei.initial.RandomNormal"),
    optimizer=dict(path="torch.optim.SGD", kwargs=dict(lr=0.1)),
    stopper=dict(path="mei.stoppers.NumIterations", kwargs=dict(num_iterations=1000)),
    objectives=[dict(path="mei.objectives.EvaluationObjective", kwargs=dict(interval=10))],
    # Fall back to the CPU so the demo also runs on machines without a CUDA GPU;
    # on GPU machines this is identical to the previous hard-coded "cuda".
    device="cuda" if cuda.is_available() else "cpu",
)
MEIMethod().add_method(method_fn, method_config, comment="My MEI method")
MEIMethod()

## 6. Add Seed

In [None]:
# Register seed value 42 for MEI generation and show the seed table.
seed_table = MEISeed()
seed_table.insert1({"mei_seed": 42})
seed_table

## 7. Generate MEIs

In [None]:
# Generate MEIs only for neuron 188 (with a progress bar), then show the table.
target_neuron = {"neuron_id": 188}
MEI().populate(target_neuron, display_progress=True)
MEI()

## 8. Look at MEIs

In [None]:
neuron_id = 188

# Restrict by the real primary key of the ensemble instead of a hard-coded
# value: DataJoint dict restrictions skip keys that are not attributes of the
# table, so a wrong key name silently widens the restriction. fetch1() raises
# if the ensemble table does not hold exactly one entry.
ensemble_key = TrainedEnsembleModel().fetch1("KEY")
mei_path = (MEI() & ensemble_key & dict(neuron_id=neuron_id)).fetch1("mei")
try:
    # map_location="cpu" lets matplotlib convert the tensor even if it was saved on a GPU.
    mei_image = load(mei_path, map_location="cpu")
finally:
    # The fetched file is only needed for loading; remove it even if loading fails.
    os.remove(mei_path)

plt.imshow(mei_image.squeeze(), cmap="gray")
plt.gca().axis("off")

## 9. Plot Evaluations Across Time

In [None]:
# Restrict by the real primary key of the ensemble plus the neuron_id chosen
# earlier, rather than a hard-coded ensemble value. DataJoint dict restrictions
# skip keys that are not attributes of the table, so a wrong key name would
# silently widen the restriction. fetch1() raises unless exactly one ensemble exists.
ensemble_key = TrainedEnsembleModel().fetch1("KEY")
output_path = (MEI() & ensemble_key & dict(neuron_id=neuron_id)).fetch1("output")
try:
    # map_location="cpu" keeps loading working on machines without the saving device.
    output = load(output_path, map_location="cpu")
finally:
    # The fetched file is only needed for loading; remove it even if loading fails.
    os.remove(output_path)

# Evaluation-objective trace recorded during optimisation: x = "times", y = "values".
evaluations = output["mei.objectives.EvaluationObjective"]
plt.plot(evaluations["times"], evaluations["values"])
plt.gca().set_xlabel("# iteration")
plt.gca().set_ylabel("evaluation")