In [2]:
import datajoint as dj

dj.config["display.limit"] = 50
dj.config["enable_python_native_blobs"] = True

dj.config['nnfabrik.schema_name'] = "nnfabrik_v1_tuning"

import datajoint as dj
schema = dj.schema("nnfabrik_v1_tuning")

from nndichromacy.tables.from_mei import MEIMethod

Connecting pawelp@134.2.168.16:3306


In [49]:
method_fn = 'mei.methods.gradient_ascent'
method_config = {'initial': {'path': 'mei.initial.RandomNormal'},
                 'optimizer': {'path': 'torch.optim.SGD', 'kwargs': {'lr': 2.0}},
                 'precondition': {'path': 'mei.legacy.ops.GaussianBlur','kwargs': {'sigma': 1}},
                 'postprocessing': {'path': 'nndichromacy.mei.ops.ChangeNormAndClip',
                  'kwargs': {'norm': 10, 'x_min': -2.2, 'x_max': 2.62}},
                 'stopper': {'path': 'mei.stoppers.NumIterations',
                  'kwargs': {'num_iterations': 1000}},
                 'objectives': [{'path': 'mei.objectives.EvaluationObjective',
                   'kwargs': {'interval': 10}}],
                 'device': 'cuda'
                }
MEIMethod().add_method(method_fn, method_config, comment='norm=10, lr=2.9')

In [52]:
MEIMethod()

method_fn  name of the method function,method_hash  hash of the method config,method_config  method configuration object,method_ts  UTZ timestamp at time of insertion,method_comment  a short comment describing the method
mei.methods.gradient_ascent,4f49ae3e50a41587cfc0ebdc39aaf25b,=BLOB=,2020-12-01 19:05:32,"norm=6.9, lr=1.5"
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,=BLOB=,2021-01-21 11:44:14,"norm=10, lr=2.9"
nnidentify.mei.methods.gradient_ascent,84a88459f8b43eace9a3f44dae00ea8e,=BLOB=,2021-01-04 09:54:14,"norm=6.9, lr=1.5 - with validation"


Run the cell below as a job on the compute server

In [2]:
%%writefile run_scripts/mei_run.py

import datajoint as dj

dj.config["display.limit"] = 50
dj.config["enable_python_native_blobs"] = True

dj.config['nnfabrik.schema_name'] = "nnfabrik_v1_tuning"

import datajoint as dj
schema = dj.schema("nnfabrik_v1_tuning")

from nndichromacy.tables.from_mei import MEISelector
from nndichromacy.tables.from_mei import MEIMethod
from nndichromacy.tables.from_mei import MEISeed
from nndichromacy.tables.from_mei import TrainedEnsembleModel
from nndichromacy.tables.measures import SignalToNoiseRatio

from mei import mixins

import numpy as np

@schema
class MEIMouse(mixins.MEITemplateMixin, dj.Computed):
    """MEI table template.

    To create a functional "MEI" table, create a new class that inherits from this template and decorate it with your
    preferred Datajoint schema. Next assign your trained model (or trained ensemble model) and your selector table to
    the class variables called "trained_model_table" and "selector_table". By default, the created table will point to
    the "MEIMethod" table in the Datajoint schema called "nnfabrik.main". This behavior can be changed by overwriting
    the class attribute called "method_table".
    """

    trained_model_table = TrainedEnsembleModel
    selector_table = MEISelector
    method_table = MEIMethod
    seed_table = MEISeed
    

dataset_key = {
    'dataset_hash': '72930dd1be6c229df4be82f74803262c' # Set the dataset here
}
    
np.random.seed(1000)
unit_keys = (SignalToNoiseRatio().Units() & dataset_key & 'unit_snr > 0.5').fetch('KEY')
unit_keys = [unit_keys[idx] for idx in np.random.choice(range(len(unit_keys)), size=150, replace=False)]

_unit_keys = []
for key in unit_keys:
        key['ensemble_hash'] = '8a2aaa598935fa40a3f7db6f84209d5f' # Set ensemble hash
        key['method_fn'] = 'mei.methods.gradient_ascent'
        key['method_hash'] = '89a56f45b906cc8036b4125cfb53da32'
        for seed in range(1, 6):
            key['mei_seed'] = seed

            _unit_keys.append(dict(**key))
        
MEIMouse.populate(_unit_keys, display_progress=True, reserve_jobs=True)

Connecting pawelp@134.2.168.16:3306


### Check progress

In [8]:
selection = {
    'ensemble_hash': '8a2aaa598935fa40a3f7db6f84209d5f', # Set ensemble hash here
    'method_hash': '89a56f45b906cc8036b4125cfb53da32',
}
MEIMouse & selection

method_fn  name of the method function,method_hash  hash of the method config,dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,ensemble_hash  the hash of the ensemble,unit_id  unique neuron identifier,data_key  unique session identifier,mei_seed  MEI seed,mei  the MEI as a tensor,score  some score depending on the used method function,output  object returned by the method function
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,46,25133-3-11,1,=BLOB=,5.19093,=BLOB=
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,46,25133-3-11,2,=BLOB=,5.18693,=BLOB=
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,46,25133-3-11,3,=BLOB=,5.20806,=BLOB=
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,46,25133-3-11,4,=BLOB=,5.20007,=BLOB=
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,46,25133-3-11,5,=BLOB=,5.18847,=BLOB=
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,75,25133-3-11,1,=BLOB=,7.52292,=BLOB=
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,75,25133-3-11,2,=BLOB=,7.52766,=BLOB=
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,75,25133-3-11,3,=BLOB=,7.52327,=BLOB=
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,75,25133-3-11,4,=BLOB=,7.54612,=BLOB=
mei.methods.gradient_ascent,89a56f45b906cc8036b4125cfb53da32,nndichromacy.datasets.static_loaders,72930dd1be6c229df4be82f74803262c,8a2aaa598935fa40a3f7db6f84209d5f,75,25133-3-11,5,=BLOB=,7.50573,=BLOB=


### Export MEIs

In [9]:
import pickle
import torch

results = (MEIMouse & selection).fetch("KEY", "mei", as_dict=True, order_by='unit_id')
scores = (MEIMouse & selection).fetch("score", order_by='unit_id')

for key in results:
    key["mei"] = torch.load(key["mei"]).detach().cpu().numpy().squeeze() 

curated_unit_keys = []
max_seed = scores.reshape(150, 5).argmax(1)

for gabor_idx, idx in enumerate(range(0, len(results), 5)):
    curated_unit_keys.append(results[idx + max_seed[gabor_idx]])
    
with open('./results/meis.pkl', 'wb') as f:
    pickle.dump(curated_unit_keys, f)

In [71]:
curated_unit_keys

[{'method_fn': 'mei.methods.gradient_ascent',
  'method_hash': '89a56f45b906cc8036b4125cfb53da32',
  'dataset_fn': 'nndichromacy.datasets.static_loaders',
  'dataset_hash': '3390ff21bea702ae0dcdbcbfefb78e5e',
  'ensemble_hash': 'd3eee24b439f842c9962777da4a7f0c5',
  'unit_id': 5,
  'data_key': '25137-5-23',
  'mei_seed': 2,
  'mei': array([[-5.9948604e-07, -4.0229355e-07, -2.2256324e-07, ...,
           3.3707366e-07, -1.0331654e-07,  3.6791272e-07],
         [ 2.7137105e-07,  1.3702996e-07, -3.8213923e-07, ...,
          -2.5228871e-07, -5.1841425e-07,  1.6825201e-07],
         [-3.5353531e-07, -5.4066312e-08, -1.4530441e-07, ...,
          -2.1645324e-07,  1.7972405e-07, -3.2616282e-07],
         ...,
         [-3.0711769e-07, -2.9277004e-07, -2.0879477e-07, ...,
          -3.2788131e-07,  1.3167700e-07, -1.5060650e-07],
         [ 1.0541684e-07, -5.3171313e-08,  2.6871396e-07, ...,
           2.0485349e-07, -9.7617317e-08,  5.0769096e-08],
         [ 7.1022174e-07,  4.0947413e-07, -1