In [1]:
import sys
sys.path.append('../')

In [2]:
import os
import torch
from mei.methods import gradient_ascent
from mei.modules import ConstrainedOutputModel
import urllib
import datajoint as dj

import matplotlib.pyplot as plt
from nnfabrik.builder import get_model


In [3]:
# only need to run once at the beginning
!pip3 install git+https://github.com/sinzlab/nnvision.git@model_builder

Collecting git+https://github.com/sinzlab/nnvision.git@model_builder
  Cloning https://github.com/sinzlab/nnvision.git (to revision model_builder) to c:\users\leon\appdata\local\temp\pip-req-build-eg4t1umx
  Resolved https://github.com/sinzlab/nnvision.git to commit 8218786fa1ef313ead3068af74f09c803701a06b
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting neuralpredictors@ git+https://github.com/KonstantinWilleke/neuralpredictors.git@transformer_readout (from nnvision==0.1)
  Cloning https://github.com/KonstantinWilleke/neuralpredictors.git (to revision transformer_readout) to c:\users\leon\appdata\local\temp\pip-install-lnpjuv9h\neuralpredictors_d076acec423442d682ed8cd51a2a2000
  Resolved https://github.com/KonstantinWilleke/neuralpredictors.git to commit 5243b69ac2ff34d9aac33aeefdce0b4697da21ff
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to b

  Running command git clone --filter=blob:none --quiet https://github.com/sinzlab/nnvision.git 'C:\Users\Leon\AppData\Local\Temp\pip-req-build-eg4t1umx'
  Running command git checkout -b model_builder --track origin/model_builder
  branch 'model_builder' set up to track 'origin/model_builder'.
  Switched to a new branch 'model_builder'
  Running command git clone --filter=blob:none --quiet https://github.com/KonstantinWilleke/neuralpredictors.git 'C:\Users\Leon\AppData\Local\Temp\pip-install-lnpjuv9h\neuralpredictors_d076acec423442d682ed8cd51a2a2000'
  Running command git checkout -b transformer_readout --track origin/transformer_readout
  branch 'transformer_readout' set up to track 'origin/transformer_readout'.
  Switched to a new branch 'transformer_readout'
  Running command git clone --filter=blob:none --quiet https://github.com/sinzlab/mei.git 'C:\Users\Leon\AppData\Local\Temp\pip-install-lnpjuv9h\mei_26a97558260f4f12ad1a1491c4fdd2e3'
  Running command git checkout -b inception_l

### Download monkey V4 example model files

In [8]:

# download example model
os.makedirs('./pretrained_models', exist_ok=True)
url = 'https://github.com/sinzlab/nnvision/raw/model_builder/nnvision/data/model_weights/v4_task_driven/task_driven_ensemble_model_01.pth.tar'
filepath = './pretrained_models/task_driven_ensemble_model_01.pth.tar'
urllib.request.urlretrieve(url, filepath)

('./pretrained_models/task_driven_ensemble_model_01.pth.tar',
 <http.client.HTTPMessage at 0x1a105c0eb50>)

In [None]:
import urllib.request

url = "https://example.com/file.txt"

In [None]:
# some hackity hack to supress datajoint promts (ignore)
dj.config["database.host"] = "{host_address}"
dj.config["database.user"] = "{user}"
dj.config["database.password"] = "{password}"

In [None]:
# method for loading example model
def get_v4_model():
    model_fn = 'nnvision.models.ptrmodels.task_core_gauss_readout'
    model_config =  {'input_channels': 1,
      'model_name': 'resnet50_l2_eps0_1',
      'layer_name': 'layer3.0',
      'pretrained': False,
      'bias': False,
      'final_batchnorm': True,
      'final_nonlinearity': True,
      'momentum': 0.1,
      'fine_tune': False,
      'init_mu_range': 0.4,
      'init_sigma_range': 0.6,
      'readout_bias': True,
      'gamma_readout': 3.0,
      'gauss_type': 'isotropic',
      'elu_offset': -1,
                     }
    
    data_info = {
        "all_sessions": {
            "input_dimensions": torch.Size([64, 1, 100, 100]),
            "input_channels": 1,
            "output_dimension": 1244,
            "img_mean": 124.54466,
            "img_std": 70.28,
        },
    }
    
    filename = os.path.join('./pretrained_models/task_driven_ensemble_model_01.pth.tar')
    state_dict = torch.load(filename, map_location='cpu')
    
    v4_data_task_sota = get_model(
        model_fn, model_config, seed=10, data_info=data_info, state_dict=state_dict,
    )
    
    return v4_data_task_sota, data_info



# Load the model


In [None]:
model, data_info = get_v4_model()

# Create single cell model
Here you select the index of the cell for which you'd like to optimize the MEI

In [None]:
cell_index = 403

In [None]:
single_cell_model = ConstrainedOutputModel(model, cell_index)

# Configurations
Configurations for the MEI optimization algorithm

In [None]:
seed = 8

# learning rate
lr = 1

# number of iterations for which the MEI is optimized
num_iterations = 1000 

# interval with which the objective will be evaluated
interval = 10

# shape of the MEI
input_shape = (1,) + data_info["all_sessions"]['input_dimensions'][1:] 

device='cpu'

In [None]:
method_config = dict(
    initial=dict(path="mei.initial.RandomNormal"),
    optimizer=dict(path="torch.optim.SGD", kwargs=dict(lr=lr)),
    stopper=dict(path="mei.stoppers.NumIterations", kwargs=dict(num_iterations=num_iterations)),
    objectives=[dict(path="mei.objectives.EvaluationObjective", kwargs=dict(interval=interval))],
    device=device,
)

# Optimize MEI

In [None]:
mei, final_evaluation, tracker_log = gradient_ascent(model=single_cell_model, config=method_config, seed=seed, shape=input_shape)   

# Visualize MEI

In [None]:
print(mei.shape)
print(torch.max(torch.abs(mei)))
plt.imshow(mei[0,0], cmap='gray')

In [None]:
plt.plot(
    tracker_log["mei.objectives.EvaluationObjective"]["times"],
    tracker_log["mei.objectives.EvaluationObjective"]["values"],
)
plt.gca().set_xlabel("# iteration")
plt.gca().set_ylabel("evaluation")

# Constraints 
So far, we have not introduced contraints into MEI optimization. 
This results in implausible MEIs with high pixel values and high contrast and a forever increasing activation. We need to constrain the overall norm of the MEI and also clip the range of pixel values to the range of pixel values which were in the training data. 
The norm is then established empirically such, that ideally, value clipping is not necessary.

In [None]:
p = 2
norm_value = 10
max_pixel_value = 2.2606
min_pixel_value = -1.75047

method_config = dict(
    initial=dict(path="mei.initial.RandomNormal"),
    optimizer=dict(path="torch.optim.SGD", kwargs=dict(lr=lr)),
    stopper=dict(path="mei.stoppers.NumIterations", kwargs=dict(num_iterations=num_iterations)),
    objectives=[dict(path="mei.objectives.EvaluationObjective", kwargs=dict(interval=interval))],
    postprocessing=dict(path="mei.postprocessing.PNormConstraintAndClip", kwargs=dict(p=p, norm_value=norm_value, max_pixel_value=max_pixel_value, min_pixel_value=min_pixel_value)),
    device=device,
)

In [None]:
mei, final_evaluation, tracker_log = gradient_ascent(model=single_cell_model, config=method_config, seed=seed,
                                                     shape=input_shape)   

In [None]:
print(mei.shape)
max_value = torch.max(torch.abs(mei))
print(max_value)
plt.imshow(mei[0,0], cmap='gray', vmin=min_pixel_value, vmax=max_pixel_value)
plt.colorbar()
plt.show()


In [None]:
plt.plot(
    tracker_log["mei.objectives.EvaluationObjective"]["times"],
    tracker_log["mei.objectives.EvaluationObjective"]["values"],
)
plt.gca().set_xlabel("# iteration")
plt.gca().set_ylabel("evaluation")