# MEI Demo

In [1]:
import os
import datajoint as dj
# Connection settings are taken from environment variables so that no
# credentials are written into the notebook itself.
credential_sources = (
    ("database.host", "DJ_HOST"),
    ("database.user", "DJ_USERNAME"),
    ("database.password", "DJ_PASSWORD"),
)
for config_key, env_var in credential_sources:
    dj.config[config_key] = os.environ[env_var]

dj.config["enable_python_native_blobs"] = True
dj.config["display.limit"] = 200

# The schema name is derived from `name` and published both through the
# environment and through the DataJoint config entry read by nnfabrik.
name = "vei"
os.environ["DJ_SCHEMA_NAME"] = f"metrics_{name}"
dj.config["nnfabrik.schema_name"] = os.environ["DJ_SCHEMA_NAME"]

In [2]:
import os

from matplotlib import pyplot as plt
from torch import load

from mei.main import TrainedEnsembleModelTemplate, CSRFV1ObjectiveTemplate, MEISeed, MEIMethod, MEITemplate
from nnfabrik.main import Dataset, my_nnfabrik
from nnsysident.tables.experiments import TrainedModel

Connecting konstantin@134.76.19.44:3306


## 1. Define Tables

In [3]:
# Register the S3 external store "minio_models" with DataJoint. Endpoint and
# credentials are read from the environment so no secrets live in the notebook.
if "stores" not in dj.config:
    dj.config["stores"] = {}
dj.config["stores"]["minio_models"] = {
    "protocol": "s3",
    "endpoint": os.environ["MINIO_ENDPOINT"],
    "bucket": "kklurzmodels",
    "location": "dj-store",
    "access_key": os.environ["MINIO_ACCESS_KEY"],
    "secret_key": os.environ["MINIO_SECRET_KEY"],
    "secure": True,
}

# Resolve the schema name first so that only a genuinely missing environment
# variable is reported as "no schema name". Failures inside my_nnfabrik
# (for example database connection problems) keep their own traceback instead
# of being masked by a blanket except clause.
try:
    schema_name = os.environ["DJ_SCHEMA_NAME"]
except KeyError as error:
    raise ValueError(
        " ".join(
            [
                "No schema name has been specified.",
                "Specify it via",
                "os.environ['DJ_SCHEMA_NAME']='schema_name'",
            ]
        )
    ) from error
main = my_nnfabrik(schema_name, use_common_fabrikant=False)

# Expose the objects created by my_nnfabrik in the notebook namespace; the
# table definitions below use `schema`, which presumably comes from here.
# Dunder entries (e.g. __name__, __doc__) are skipped so that the notebook's
# own module-level attributes are not overwritten.
globals().update(
    {key: val for key, val in main.__dict__.items() if not key.startswith("__")}
)

@schema
class TrainedEnsembleModel(TrainedEnsembleModelTemplate):
    """Ensemble table of this demo schema, derived from TrainedEnsembleModelTemplate.

    The template is configured through the two class attributes below.
    """

    # Dataset table the template works with.
    # NOTE(review): `Dataset` may have been rebound by the namespace injection
    # in the cell above rather than being the imported class — confirm.
    dataset_table = Dataset
    # Table of trained models the template works with.
    trained_model_table = TrainedModel


@schema
class CSRFV1Selector(CSRFV1ObjectiveTemplate):
    """Selector table of this demo schema, derived from CSRFV1ObjectiveTemplate."""

    # Dataset table the template works with.
    dataset_table = Dataset


@schema
class MEI(MEITemplate):
    """MEI table of this demo schema, derived from MEITemplate.

    Connects the template to the ensemble and selector tables declared in
    this notebook.
    """

    # Ensemble table declared above.
    trained_model_table = TrainedEnsembleModel
    # Selector table declared above.
    selector_table = CSRFV1Selector

In [10]:
# Preview the Dataset table; the bare expression renders it as the cell output.
Dataset()

dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,dataset_config  dataset configuration object,dataset_fabrikant  Name of the contributor that added this entry,dataset_comment  short description,dataset_ts  UTZ timestamp at time of insertion
,,,,,


In [6]:
# Preview the selector table (empty until it is populated in section 4).
CSRFV1Selector()

dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,neuron_id  unique neuron identifier,neuron_position  integer position of the neuron in the model's output,session_id  unique session identifier
,,,,


In [7]:
# Preview the ensemble table (empty until an ensemble is created in section 3).
TrainedEnsembleModel()

dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,ensemble_hash  the hash of the ensemble,ensemble_comment  a short comment describing the ensemble
,,,


In [8]:
# Preview the MEI method table (empty until a method is added in section 5).
MEIMethod()

method_fn  name of the method function,method_hash  hash of the method config,method_config  method configuration object,method_ts  UTZ timestamp at time of insertion,method_comment  a short comment describing the method
,,,,


In [9]:
# Preview the MEI seed table (empty until a seed is inserted in section 6).
MEISeed()

mei_seed  MEI seed


## 2. Reset Tables For Demo

In [None]:
# Drop the demo tables so the walkthrough below starts from a clean state.
# WARNING: dropping removes the tables together with all of their contents.
# NOTE(review): dropped tables presumably have to be declared again (re-run
# the "Define Tables" cell) before the following sections work — confirm.
CSRFV1Selector().drop()
TrainedEnsembleModel().drop()
MEIMethod().drop()
MEISeed().drop()

## 3. Create Ensemble Model

In [None]:
# Key passed to create_ensemble to select the dataset the ensemble is built for.
ensemble_dataset_key = {
    "dataset_fn": "csrf_v1",
    "dataset_hash": "3d94500a46b792bbb480aedfc30f9753",
}
TrainedEnsembleModel().create_ensemble(
    ensemble_dataset_key,
    comment="Happy little ensemble",
)
# Show the table so the newly created ensemble is visible in the cell output.
TrainedEnsembleModel()

## 4. Populate Selector Table

In [None]:
# Fill the selector table. populate is called on an instance, matching how
# every other table in this notebook is used (e.g. MEI().populate(...)),
# instead of relying on class-level attribute forwarding.
CSRFV1Selector().populate()
# Show the populated table as the cell output.
CSRFV1Selector()

## 5. Specify MEI Method Parameters

In [None]:
# Dotted path of the function that generates the MEI.
method_fn = "mei.methods.gradient_ascent"

# Configuration handed to the method: each component is referenced by its
# dotted import path, optionally with keyword arguments.
method_config = {
    "initial": {"path": "mei.initial.RandomNormal"},
    "optimizer": {"path": "torch.optim.SGD", "kwargs": {"lr": 0.1}},
    "stopper": {
        "path": "mei.stoppers.NumIterations",
        "kwargs": {"num_iterations": 1000},
    },
    "objectives": [
        {"path": "mei.objectives.EvaluationObjective", "kwargs": {"interval": 10}},
    ],
    "device": "cuda",
}

MEIMethod().add_method(method_fn, method_config, comment="My MEI method")
# Show the table so the registered method is visible in the cell output.
MEIMethod()

## 6. Add Seed

In [None]:
# Register the seed used for MEI generation. skip_duplicates makes the cell
# safe to re-run: without it a second execution fails with a duplicate-entry
# error because the seed is already in the table.
MEISeed().insert1(dict(mei_seed=42), skip_duplicates=True)
# Show the table so the seed is visible in the cell output.
MEISeed()

## 7. Generate MEIs

In [None]:
# Only generate the MEI for a single neuron to keep the demo short.
neuron_restriction = {"neuron_id": 188}
MEI().populate(neuron_restriction, display_progress=True)
# Show the table so the generated entries are visible in the cell output.
MEI()

## 8. Look at MEIs

In [None]:
# Neuron whose MEI is displayed; the evaluation plot in section 9 reuses it.
neuron_id = 188

# Restrict by neuron only. The previous restriction compared `ensemble_hash`
# (a hash string) against the integer 0, which does not reliably match the
# ensemble created above. fetch1 requires exactly one matching MEI, so add
# ensemble_hash=<hash> to the restriction if several ensembles exist.
mei_path = (MEI() & dict(neuron_id=neuron_id)).fetch1("mei")
try:
    plt.imshow(load(mei_path).squeeze(), cmap="gray")
    plt.gca().axis("off")
finally:
    # Always remove the fetched file, even if loading or plotting fails.
    os.remove(mei_path)

## 9. Plot Evaluations Across Time

In [None]:
# Restrict by neuron only. The previous restriction used `ensemble_id`, which
# is not an attribute shown in any of the table headings in this notebook
# (the ensemble key is `ensemble_hash`). fetch1 requires exactly one match.
output_path = (MEI() & dict(neuron_id=neuron_id)).fetch1("output")
try:
    output = load(output_path)
finally:
    # Always remove the fetched file, even if loading fails.
    os.remove(output_path)

# Values tracked by the EvaluationObjective configured in the MEI method.
evaluation = output["mei.objectives.EvaluationObjective"]

plt.plot(evaluation["times"], evaluation["values"])
plt.gca().set_xlabel("# iteration")
plt.gca().set_ylabel("evaluation")