In [1]:
import numpy as np
from collections import OrderedDict
import torch
import pickle
from staticnet.base import CorePlusReadout2d, Elu1
from staticnet.cores import GaussianLaplaceCore
from staticnet.readouts import SpatialTransformerPyramid2dReadout
from staticnet.shifters import MLPShifter
from staticnet.modulators import MLPModulator
from featurevis import models
import datajoint as dj
# Virtual module for the 'neurostatic_base' database schema; it is used further
# down to look up neuron ids through `base.Dataset.Unit`.
base = dj.create_virtual_module('neurostatic_base', 'neurostatic_base')

Connecting zhiwei@at-database.ad.bcm.edu:3306


In [6]:
def build_network(configs):
    """Assemble an untrained core + readout + shifter + modulator network.

    Args:
        configs: mapping that provides
            'img_shape'  - input shape; index 1 is used as the number of
                           input channels of the core,
            'n_neurons'  - passed to the readout, shifter and modulator,
            'core_key', 'ro_key', 'shift_key', 'mod_key' - keyword arguments
                           forwarded to the respective component.

    Returns:
        A CorePlusReadout2d combining a GaussianLaplaceCore, a
        SpatialTransformerPyramid2dReadout, an MLPShifter and an MLPModulator,
        with an Elu1 output nonlinearity.
    """
    img_shape = configs['img_shape']
    n_neurons = configs['n_neurons']

    # Components are created in the order core -> readout -> shifter -> modulator.
    core = GaussianLaplaceCore(input_channels=img_shape[1], **configs['core_key'])
    readout = SpatialTransformerPyramid2dReadout(
        CorePlusReadout2d.get_readout_in_shape(core, img_shape),
        n_neurons,
        **configs['ro_key'],
    )
    # NOTE(review): the literals 2 and 3 are presumably the number of
    # eye-position and behaviour input channels - confirm against staticnet.
    shifter = MLPShifter(n_neurons, 2, **configs['shift_key'])
    modulator = MLPModulator(n_neurons, 3, **configs['mod_key'])

    return CorePlusReadout2d(
        core,
        readout,
        nonlinearity=Elu1(),
        shifter=shifter,
        modulator=modulator,
    )

def load_network(configs, state_dict):
    """Build a network from `configs` and load trained parameters into it.

    Args:
        configs: architecture configuration, forwarded to `build_network`.
        state_dict: trained parameters, either
            * an object exposing `dtype.names` (e.g. a NumPy structured
              array); the first element of every named field is used, or
            * a mapping from parameter name to array-like value.

    Returns:
        The freshly built model with `state_dict` loaded. Parameters that the
        model expects but `state_dict` lacks keep their initialization value,
        and a warning is logged for each of them.
    """
    # Local import: the original referenced an undefined name `log`, which
    # raised NameError as soon as a parameter was missing from the checkpoint.
    import logging

    log = logging.getLogger(__name__)

    model = build_network(configs)

    # Convert to {name: tensor}. Only the attribute lookup is guarded, so an
    # unrelated AttributeError during conversion is not mistaken for
    # "this is a plain mapping".
    try:
        field_names = state_dict.dtype.names
    except AttributeError:
        state_dict = {k: torch.as_tensor(state_dict[k].copy()) for k in state_dict.keys()}
    else:
        state_dict = {k: torch.as_tensor(state_dict[k][0].copy()) for k in field_names}

    # Fill parameters missing from the checkpoint with the model's initial values.
    mod_state_dict = model.state_dict()
    for k in set(mod_state_dict) - set(state_dict):
        log.warning('Could not find paramater {} setting to initialization value'.format(repr(k)))
        state_dict[k] = mod_state_dict[k]

    model.load_state_dict(state_dict)
    return model

# Load target neurons and model parameters

In [7]:
# Stimulus images and the unit ids of the neurons of interest
# (imgs[i] is later paired with src_n_ids[i]).
imgs = torch.load("/src/static-networks/imgs.pt")
src_n_ids = torch.load("/src/static-networks/src_n_ids.pt")

# Architecture configuration needed to rebuild the network.
# NOTE: pickle.load executes arbitrary code - only use with trusted files.
with open('/src/static-networks/my_notebooks/group233_model_configs.pkl', 'rb') as configs_file:
    configs = pickle.load(configs_file)

# Trained parameters; each entry holds its weights under the 'model' key.
with open('/src/static-networks/my_notebooks/group233_model_state_dict_ls.pkl', 'rb') as weights_file:
    state_dict_ls = pickle.load(weights_file)

# Build network and load trained model state_dict

In [8]:
# Rebuild one network per stored checkpoint (trained from 4 different
# initialization seeds).
all_models = []
for seed_idx in range(len(state_dict_ls)):
    all_models.append(load_network(configs, state_dict_ls[seed_idx]['model']))

# Mean eye position (0, 0) as a (1, 2) float tensor on the GPU.
mean_eyepos = torch.tensor([0, 0], dtype=torch.float32, device='cuda').unsqueeze(0)

# Combine the individual networks into a single ensemble model.
model_ensemble = models.Ensemble(
    all_models,
    configs['key']['readout_key'],
    eye_pos=mean_eyepos,
    average_batch=False,
    device='cuda',
)

05-01-2023:08:56:34 INFO     cores.py             101:	 Ignoring input {'core_hash': '28bc2fa358337c5012278f899b5b6947'} when creating GaussianLaplaceCore
05-01-2023:08:56:34 INFO     readouts.py          146:	 Ignoring input {'ro_hash': 'a206f6da6a16ea14081062a1e2436b48', 'ro_type': 'SpatialTransformerPyramid2d'} when creating SpatialTransformerPyramid2dReadout
05-01-2023:08:56:34 INFO     shifters.py           55:	 Ignoring input {'shift_hash': '05c69a4aeaeea5e48fa8fc5e70181d67', 'shift_type': 'MLP'} when creating MLPShifter
05-01-2023:08:56:34 INFO     shifters.py           27:	 Ignoring input {} when creating MLP
05-01-2023:08:56:34 INFO     modulators.py         47:	 Ignoring input {'mod_hash': 'a757e992ae449e3057ff1d512a51bd1e', 'mod_type': 'MLP'} when creating MLPModulator
05-01-2023:08:56:34 INFO     modulators.py         18:	 Ignoring input {} when creating MLP
05-01-2023:08:56:34 INFO     cores.py             101:	 Ignoring input {'core_hash': '28bc2fa358337c5012278f899b5b694

# Get model predicted response

In [19]:
# Ensemble responses, collected as one row per neuron of interest.
resp_rows = []
for idx, unit_id in enumerate(src_n_ids):
    # Look up the neuron_id stored for this unit of group 233.
    restriction = 'group_id = 233 and unit_id = {}'.format(unit_id)
    neuron_id = (base.Dataset.Unit & restriction).fetch1('neuron_id')

    # Rescale every image to mean 0 and std 0.25, the average training
    # statistics for masked images such as MEIs or DEIs.
    normed = []
    for im in imgs[idx].squeeze():
        centered = im - im.mean()
        normed.append(centered / (im.std() + 1e-9) * 0.25)
    norm_images = torch.stack(normed).unsqueeze(1)  # add the channel axis

    # Inference only: no gradients are needed.
    with torch.no_grad():
        ensemble_out = model_ensemble(norm_images)
        resp_rows.append(ensemble_out[:, neuron_id].cpu().numpy().squeeze())

resps = np.stack(resp_rows)

In [22]:
# Express every row relative to its first column (the response to the first image).
resps / resps[:, :1]

array([[1.        , 0.8476606 , 0.8277773 ],
       [1.        , 0.82828283, 0.8701826 ],
       [1.        , 0.88066673, 0.85804814],
       [1.        , 0.8504938 , 0.8503026 ],
       [1.        , 0.8448976 , 0.8603661 ],
       [1.        , 0.8730932 , 0.8695517 ],
       [1.        , 0.8513239 , 0.8483376 ],
       [1.        , 0.86936873, 0.8452019 ],
       [1.        , 0.8458835 , 0.890649  ],
       [1.        , 0.84287995, 0.84245694],
       [1.        , 0.85185605, 0.86849165],
       [1.        , 0.84870005, 0.8245643 ],
       [1.        , 0.8604489 , 0.8525374 ],
       [1.        , 0.8468542 , 0.886497  ],
       [1.        , 0.843001  , 0.87426955],
       [1.        , 0.857618  , 0.8714664 ],
       [1.        , 0.8969005 , 0.8547968 ],
       [1.        , 0.8747891 , 0.8496681 ],
       [1.        , 0.85644585, 0.91385823],
       [1.        , 0.8573006 , 0.87817   ],
       [1.        , 0.84672385, 0.8551972 ],
       [1.        , 0.8296251 , 0.8783343 ],
       [1.

In [23]:
# Persist the response matrix so it can be reused without re-running the models.
with open('/src/static-networks/my_notebooks/group233_mei_dei_resps.pkl', 'wb') as out_file:
    pickle.dump(resps, out_file, protocol=pickle.HIGHEST_PROTOCOL)