# MEI (Most Exciting Input) Demo

In [None]:
import datajoint as dj

# Allow native Python objects (e.g. dicts, lists) to be stored in blob attributes.
dj.config["enable_python_native_blobs"] = True
# The schema name is put into dj.config before nnfabrik is imported in the next
# cell (presumably nnfabrik reads it to choose its own schema). The tables
# declared in this notebook reuse the same value, so the two cannot drift apart.
dj.config["schema_name"] = "nnfabrik_tutorial"

schema = dj.schema(dj.config["schema_name"])

In [None]:
import os

from matplotlib import pyplot as plt
from torch import load

from featurevis.main import TrainedEnsembleModelTemplate, CSRFV1SelectorTemplate, MEIMethod, MEITemplate
from nnfabrik.template import TrainedModelBase
from nnfabrik.main import Dataset

## 1. Define Tables

In [None]:
@schema
class TrainedModel(TrainedModelBase):
    """nnfabrik's ``TrainedModelBase`` table, declared in this notebook's schema as-is."""


@schema
class TrainedEnsembleModel(TrainedEnsembleModelTemplate):
    """Ensemble table from ``featurevis``, bound to this notebook's tables.

    The template is configured only through the class attributes below. How it
    uses them is defined in ``featurevis.main`` (not shown here).
    """

    # nnfabrik dataset table; presumably where ensemble members' datasets are looked up.
    dataset_table = Dataset
    # Trained-model table declared above; presumably the source of ensemble members.
    trained_model_table = TrainedModel


@schema
class CSRFV1Selector(CSRFV1SelectorTemplate):
    """Selector table from ``featurevis``, bound to nnfabrik's ``Dataset`` table.

    NOTE(review): judging by the name, this targets the ``csrf_v1`` dataset and
    selects the units (neurons) MEIs are computed for. Confirm in ``featurevis.main``.
    """

    # Dataset table the selector entries are derived from (presumably).
    dataset_table = Dataset


@schema
class MEI(MEITemplate):
    """Table of generated MEIs, bound to this notebook's ensemble and selector tables.

    Models come from the *ensemble* table rather than single trained models, so
    MEI rows are presumably keyed by ``ensemble_id``. The template's use of these
    attributes is defined in ``featurevis.main`` (not shown here).
    """

    # Source of the models the MEIs are generated from: ensembles of trained models.
    trained_model_table = TrainedEnsembleModel
    # Selector that presumably picks the unit (neuron) each MEI is optimized for.
    selector_table = CSRFV1Selector

## 2. Reset Tables For Demo

In [None]:
# Empty the demo tables rather than dropping them. ``drop()`` removes the table
# definitions themselves (and, by cascade, dependent tables such as MEI), so the
# create/populate/insert cells below would fail until the definition cell above
# was re-run. ``delete()`` keeps the tables declared and only removes their rows,
# cascading to dependent rows (e.g. MEI entries built from these ensembles/selectors).
CSRFV1Selector().delete()
TrainedEnsembleModel().delete()
# Remove the MEI method entry that step 5 re-inserts.
(MEIMethod() & "method_id = 0").delete()

## 3. Create Ensemble Model

In [None]:
# Show the docstring/signature of ``create_ensemble`` (IPython help syntax; notebook only).
TrainedEnsembleModel.create_ensemble?

In [None]:
# Create an ensemble from the trained models restricted by this key
# (presumably all trained models on the ``csrf_v1`` dataset function).
TrainedEnsembleModel().create_ensemble({"dataset_fn": "csrf_v1"})

## 4. Populate Selector Table

In [None]:
# Fill the selector table, then display its contents (last expression of the cell).
CSRFV1Selector().populate()
CSRFV1Selector()

## 5. Specify MEI Method Parameters

In [None]:
# Register MEI method 0; only the key is given here, so any other method
# parameters come from the MEIMethod table definition (not shown).
MEIMethod().insert1({"method_id": 0})
MEIMethod()

## 6. Generate MEIs

In [None]:
# Compute MEIs for every pending key (with a progress bar), then display the table.
MEI().populate(display_progress=True)
MEI()

## 7. Look at MEIs

In [None]:
# Visualize the MEI of a single neuron from ensemble 0.
neuron_id = 381
# fetch1("mei") yields a path to a local copy of the stored file; the cell
# loads that file and then deletes it.
mei_path = (MEI() & dict(ensemble_id=0, neuron_id=neuron_id)).fetch1("mei")
try:
    # NOTE(review): imshow needs a 2-D (or HxWx3/4) array. Confirm the stored MEI
    # tensor has that shape, otherwise squeeze/index it first.
    plt.imshow(load(mei_path), cmap="gray")
    plt.gca().axis("off")
finally:
    # Remove the local copy even if loading or plotting fails, so no stray files
    # accumulate in the working directory.
    os.remove(mei_path)