In [1]:
import os
import datajoint as dj
# DataJoint connection settings are read from environment variables, so no
# credentials live in the notebook itself.
for config_option, env_var in [
    ('database.host', 'DJ_HOST'),
    ('database.user', 'DJ_USERNAME'),
    ('database.password', 'DJ_PASSWORD'),
]:
    dj.config[config_option] = os.environ[env_var]
dj.config['enable_python_native_blobs'] = True
dj.config['display.limit'] = 200

# nnfabrik schema for this analysis ("metrics_<name>"), exported via the environment too.
name = 'mvi'
os.environ["DJ_SCHEMA_NAME"] = f"metrics_{name}"
dj.config["nnfabrik.schema_name"] = os.environ["DJ_SCHEMA_NAME"]

In [2]:
import re
import torch
import numpy as np
import pickle
import json
import pandas as pd
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 10)
import matplotlib as mpl
# White backgrounds (figure, axes, saved files) and small default figures for every plot below.
mpl.rcParams["figure.facecolor"] = 'w'
mpl.rcParams["axes.facecolor"] = 'w'
mpl.rcParams["savefig.facecolor"] = 'w'
mpl.rcParams["figure.dpi"] = 100
mpl.rcParams["figure.figsize"] = (3, 3)
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr, pearsonr

from nnsysident.training.trainers import standard_trainer
from nnsysident.models.models import SE2DFullGaussian2d_Poisson, SE2DFullGaussian2d_ZIG, SE2DFullGaussian2d_ZIL
from nnsysident.models.ensemble_models import Ensemble
from nnsysident.utility.data_helpers import extract_data_key
from nnsysident.datasets.mouse_loaders import static_loaders

from neuralpredictors.measures.zero_inflated_losses import ZIGLoss, ZILLoss
from neuralpredictors.measures import corr

from dataport.bcm.static import PreprocessedMouseData

# Seed passed to the data loaders, model builders and trainers below.
random_seed = 27121992
# Use the GPU when one is available; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Connecting <user>@<database-host>:3306


---

In [3]:
# Modulator/shifter architecture settings exported with the group-233 models.
# Both components are used here without bias terms.
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
with open('group233_model_configs.pkl', 'rb') as config_file:
    zhiwei_configs = pickle.load(config_file)

modulator_kwargs = zhiwei_configs["mod_key"]
shifter_kwargs = zhiwei_configs["shift_key"]
for component_kwargs in (modulator_kwargs, shifter_kwargs):
    component_kwargs["bias"] = False

In [4]:
# One entry per recording as (animal_id, session, scan_idx, scan_purpose); each
# animal has an ImageNet scan followed by a DEI/control-pair scan.
datasets = [
    dict(zip(("animal_id", "session", "scan_idx", "scan_purpose"), scan))
    for scan in [
        (26614, 1, 16, "imagenet"),
        (26614, 2, 17, "dei_control_pair"),
        (26726, 6, 11, "imagenet"),
        (26726, 7, 13, "dei_control_pair"),
        (26942, 1, 11, "imagenet"),
        (26942, 2, 8, "dei_control_pair"),
        (27468, 3, 12, "imagenet"),
        (27468, 4, 7, "dei_control_pair"),
    ]
]

## Imagenet Data

In [5]:
# The models are fit on the first ImageNet recording.
imagenet_key = datasets[0]
assert imagenet_key["scan_purpose"] == "imagenet"
paths = [
    f"./data/static{imagenet_key['animal_id']}-{imagenet_key['session']}-{imagenet_key['scan_idx']}"
    "-GrayImageNet-7bed7f7379d99271be5d144e5e59a8e7.zip"
]
img_data_key = extract_data_key(paths[0])

# Loader outputs include pupil center and behavior, which the models take as inputs
# (use ["images", "responses"] for an image-only setup).
dataset_config = dict(
    paths=paths,
    batch_size=64,
    seed=random_seed,
    loader_outputs=["images", "responses", "pupil_center", "behavior"],
    normalize=True,
    exclude=["images"],
    subtract_behavior_mean=True,
)

img_dataloaders = static_loaders(**dataset_config)
img_dataset = img_dataloaders["test"][img_data_key].dataset

---

## Models

### ZIG

In [None]:
# Optional: train a 5-member ZIG ensemble (seeds 0-4) and save one state dict per
# member as "ZIG_statedict" + img_data_key + "-seed<n>.inshallah". The Ensemble cell
# further down loads members using the "ZIG_statedict" + img_data_key prefix and
# seeds np.arange(5); the exact file naming is decided by Ensemble — confirm it matches.
# Disabled by default because it retrains five models; set to True to (re)create the files.
TRAIN_ZIG_ENSEMBLE = False

if TRAIN_ZIG_ENSEMBLE:
    # A dedicated loop variable keeps the notebook-wide `random_seed` intact, and the
    # member-specific names leave the single-model `zig_model` / `zig_model_config` untouched.
    for ensemble_seed in np.arange(5):
        zero_threshold = np.exp(-10)

        zig_member_config = {
            "layers": 4,
            "hidden_channels": 64,
            "feature_reg_weight": 0.78,
            "init_mu_range": 0.55,
            "init_sigma": 0.4,
            'grid_mean_predictor': {'type': 'cortex',
                                    'input_dimensions': 2,
                                    'hidden_layers': 0,
                                    'hidden_features': 0,
                                    'final_tanh': False},
            'zero_thresholds': {img_data_key: zero_threshold},

            "input_kern": 15,
            "gamma_input": 1,
            "hidden_kern": 13,
            "depth_separable": True,
            "k_image_dependent": True,
            "modulator_kwargs": modulator_kwargs,
            "shifter_kwargs": shifter_kwargs,
        }

        zig_member = SE2DFullGaussian2d_ZIG().build_model(img_dataloaders, ensemble_seed, **zig_member_config)
        zig_member.to(device)

        member_score, member_output, member_state_dict = standard_trainer(
            zig_member,
            img_dataloaders,
            ensemble_seed,
            loss_function="ZIGLoss",
            stop_function="get_loss",
            track_training=False,
            maximize=False,
        )
        zig_member.eval()
        torch.save(member_state_dict, "ZIG_statedict" + img_data_key + f"-seed{ensemble_seed}" + ".inshallah")

In [None]:
# torch is already imported in the setup cell; just show the installed version.
torch.__version__

In [17]:
# Zero threshold for this dataset's ZIG likelihood, passed via `zero_thresholds`.
loc = np.exp(-10)

zig_model_config = dict(
    layers=4,
    hidden_channels=64,
    feature_reg_weight=0.78,
    init_mu_range=0.55,
    init_sigma=0.4,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    zero_thresholds={img_data_key: loc},
    input_kern=15,
    gamma_input=1,
    hidden_kern=13,
    depth_separable=True,
    k_image_dependent=True,
    modulator_kwargs=modulator_kwargs,
    shifter_kwargs=shifter_kwargs,
)

zig_model = SE2DFullGaussian2d_ZIG().build_model(img_dataloaders, random_seed, **zig_model_config)
zig_model.to(device);

In [18]:
# Fit the single ZIG model. The loss is minimized (maximize=False) and the trainer's
# stop criterion is "get_loss"; the per-epoch history is tracked.
zig_training_kwargs = dict(
    loss_function="ZIGLoss",
    stop_function="get_loss",
    track_training=True,
    maximize=False,
)
score, output, state_dict = standard_trainer(zig_model, img_dataloaders, random_seed, **zig_training_kwargs)
zig_model.eval();

# To persist this model, or restore it instead of retraining:
# torch.save(state_dict, "ZIG_statedict" + img_data_key + ".inshallah")
# zig_model.load_state_dict(torch.load("ZIG_statedict" + img_data_key + ".inshallah"))
# zig_model.eval();

val_correlation -0.0041950606
val_loss -12894278.0
train_loss -116609970.0


Epoch 1: 100% 71/71 [00:03<00:00, 21.92it/s]


val_correlation 0.04720706
val_loss -16144846.0
train_loss -145906240.0


Epoch 2: 100% 71/71 [00:03<00:00, 22.00it/s]


val_correlation 0.060895365
val_loss -16157542.0
train_loss -146057570.0


Epoch 3: 100% 71/71 [00:03<00:00, 22.11it/s]


val_correlation 0.069919236
val_loss -16161837.0
train_loss -146125170.0


Epoch 4: 100% 71/71 [00:03<00:00, 22.16it/s]


val_correlation 0.08081096
val_loss -16170463.0
train_loss -146192220.0


Epoch 5: 100% 71/71 [00:03<00:00, 21.96it/s]


val_correlation 0.08646358
val_loss -16150144.0
train_loss -146039440.0


Epoch 6: 100% 71/71 [00:03<00:00, 20.77it/s]


val_correlation 0.093453616
val_loss -16170310.0
train_loss -146275900.0


Epoch 7: 100% 71/71 [00:03<00:00, 21.79it/s]


val_correlation 0.10304318
val_loss -16173260.0
train_loss -146338300.0


Epoch 8: 100% 71/71 [00:03<00:00, 21.68it/s]


val_correlation 0.11010283
val_loss -16181480.0
train_loss -146410600.0


Epoch 9: 100% 71/71 [00:03<00:00, 21.90it/s]


val_correlation 0.120379664
val_loss -16165231.0
train_loss -146284180.0


Epoch 10: 100% 71/71 [00:03<00:00, 22.38it/s]


val_correlation 0.13069174
val_loss -16211837.0
train_loss -146761200.0


Epoch 11: 100% 71/71 [00:03<00:00, 22.42it/s]


val_correlation 0.14418732
val_loss -16209196.0
train_loss -146848290.0


Epoch 12: 100% 71/71 [00:03<00:00, 22.48it/s]


val_correlation 0.15466835
val_loss -16226073.0
train_loss -147049140.0


Epoch 13: 100% 71/71 [00:03<00:00, 22.02it/s]


val_correlation 0.16332681
val_loss -16235629.0
train_loss -147150140.0


Epoch 14: 100% 71/71 [00:03<00:00, 22.05it/s]


val_correlation 0.16676965
val_loss -16238948.0
train_loss -147191500.0


Epoch 15: 100% 71/71 [00:03<00:00, 22.15it/s]


val_correlation 0.17165507
val_loss -16224696.0
train_loss -147233920.0


Epoch 16: 100% 71/71 [00:03<00:00, 22.13it/s]


val_correlation 0.1802882
val_loss -16240865.0
train_loss -147344700.0


Epoch 17: 100% 71/71 [00:03<00:00, 21.15it/s]


val_correlation 0.18126956
val_loss -16251524.0
train_loss -147427280.0


Epoch 18: 100% 71/71 [00:03<00:00, 22.30it/s]


val_correlation 0.18283702
val_loss -16234316.0
train_loss -147340910.0


Epoch 19: 100% 71/71 [00:03<00:00, 22.31it/s]


val_correlation 0.18650793
val_loss -16256026.0
train_loss -147615570.0


Epoch 20: 100% 71/71 [00:03<00:00, 22.56it/s]


val_correlation 0.18997131
val_loss -16262447.0
train_loss -147762660.0


Epoch 21: 100% 71/71 [00:03<00:00, 22.33it/s]


val_correlation 0.19346991
val_loss -16271892.0
train_loss -147853150.0


Epoch 22: 100% 71/71 [00:03<00:00, 22.14it/s]


val_correlation 0.1951822
val_loss -16277024.0
train_loss -147944160.0


Epoch 23: 100% 71/71 [00:03<00:00, 22.40it/s]


val_correlation 0.19258626
val_loss -16280436.0
train_loss -148093390.0


Epoch 24: 100% 71/71 [00:03<00:00, 22.13it/s]


val_correlation 0.20129001
val_loss -16288058.0
train_loss -148142600.0


Epoch 25: 100% 71/71 [00:03<00:00, 22.29it/s]


val_correlation 0.19838548
val_loss -16279951.0
train_loss -148187440.0


Epoch 26: 100% 71/71 [00:03<00:00, 22.34it/s]


val_correlation 0.19857763
val_loss -16284840.0
train_loss -148262740.0


Epoch 27: 100% 71/71 [00:03<00:00, 22.30it/s]


val_correlation 0.19976783
val_loss -16287000.0
train_loss -148315540.0


Epoch 28: 100% 71/71 [00:03<00:00, 22.07it/s]


val_correlation 0.19177033
val_loss -16280662.0
train_loss -148165600.0


Epoch 29: 100% 71/71 [00:03<00:00, 21.55it/s]


val_correlation 0.19778611
val_loss -16280282.0
train_loss -148357700.0


Epoch 30: 100% 71/71 [00:03<00:00, 21.60it/s]


val_correlation 0.1997643
val_loss -16290236.0
train_loss -148406530.0


Epoch 31: 100% 71/71 [00:03<00:00, 21.71it/s]


val_correlation 0.19730262
val_loss -16288220.0
train_loss -148381660.0


Epoch 32: 100% 71/71 [00:03<00:00, 21.85it/s]


val_correlation 0.19704758
val_loss -16279913.0
train_loss -148436460.0


Epoch 33: 100% 71/71 [00:03<00:00, 21.40it/s]


val_correlation 0.20100723
val_loss -16277782.0
train_loss -148531570.0


Epoch 34: 100% 71/71 [00:03<00:00, 21.52it/s]


val_correlation 0.19903117
val_loss -16293164.0
train_loss -148516540.0


Epoch 35: 100% 71/71 [00:03<00:00, 22.06it/s]


val_correlation 0.1997643
val_loss -16290236.0
train_loss -148406540.0


Epoch 36: 100% 71/71 [00:03<00:00, 22.18it/s]


Epoch 00036: reducing learning rate of group 0 to 1.5000e-03.
val_correlation 0.19557859
val_loss -16261744.0
train_loss -148322290.0


Epoch 37: 100% 71/71 [00:03<00:00, 21.77it/s]


val_correlation 0.20350087
val_loss -16297288.0
train_loss -148729360.0


Epoch 38: 100% 71/71 [00:03<00:00, 21.69it/s]


val_correlation 0.20367672
val_loss -16290236.0
train_loss -148751800.0


Epoch 39: 100% 71/71 [00:03<00:00, 21.80it/s]


val_correlation 0.20595136
val_loss -16280965.0
train_loss -148891950.0


Epoch 40: 100% 71/71 [00:03<00:00, 22.06it/s]


val_correlation 0.1997643
val_loss -16290236.0
train_loss -148406530.0


Epoch 41: 100% 71/71 [00:03<00:00, 21.78it/s]


val_correlation 0.20527136
val_loss -16286613.0
train_loss -148714180.0


Epoch 42: 100% 71/71 [00:03<00:00, 22.18it/s]


Epoch 00042: reducing learning rate of group 0 to 4.5000e-04.
val_correlation 0.20396133
val_loss -16281092.0
train_loss -148746660.0


Epoch 43: 100% 71/71 [00:03<00:00, 20.99it/s]


val_correlation 0.20413546
val_loss -16290758.0
train_loss -148827120.0


Epoch 44: 100% 71/71 [00:03<00:00, 21.74it/s]


val_correlation 0.20604365
val_loss -16287606.0
train_loss -148848880.0


Epoch 45: 100% 71/71 [00:03<00:00, 21.54it/s]


val_correlation: 0.29963976

zig_val_loss: -16377972.0

zig_train_loss: -150074690.0

### Poisson

In [None]:
# Optional: train a 5-member Poisson ensemble (seeds 0-4) and save one state dict per
# member as "Poisson_statedict" + img_data_key + "-seed<n>.inshallah". The Ensemble cell
# further down loads members using the "Poisson_statedict" + img_data_key prefix and
# seeds np.arange(5); the exact file naming is decided by Ensemble — confirm it matches.
# Disabled by default because it retrains five models; set to True to (re)create the files.
TRAIN_POISSON_ENSEMBLE = False

if TRAIN_POISSON_ENSEMBLE:
    # A dedicated loop variable keeps the notebook-wide `random_seed` intact, and the
    # member-specific names leave `poisson_model` / `poisson_model_config` untouched.
    for ensemble_seed in np.arange(5):
        poisson_member_config = {
            "layers": 4,
            "hidden_channels": 64,
            "gamma_readout": 0.78,
            "init_mu_range": 0.55,
            "init_sigma": 0.4,
            'grid_mean_predictor': {'type': 'cortex',
                                    'input_dimensions': 2,
                                    'hidden_layers': 0,
                                    'hidden_features': 0,
                                    'final_tanh': False},
            "input_kern": 15,
            "gamma_input": 1,
            "hidden_kern": 13,
            "depth_separable": True,
            "modulator_kwargs": modulator_kwargs,
            "shifter_kwargs": shifter_kwargs,
        }

        poisson_member = SE2DFullGaussian2d_Poisson().build_model(img_dataloaders, ensemble_seed, **poisson_member_config)
        poisson_member.to(device)

        member_score, member_output, member_state_dict = standard_trainer(
            poisson_member,
            img_dataloaders,
            ensemble_seed,
            loss_function="PoissonLoss",
            track_training=False,
        )
        poisson_member.eval()
        # Filename built as prefix + "-seed<n>" + suffix (the seed must be inside the f-string).
        torch.save(member_state_dict, "Poisson_statedict" + img_data_key + f"-seed{ensemble_seed}" + ".inshallah")

In [19]:
# Same core/readout hyperparameters as the ZIG model, except that the readout
# regularizer is passed as `gamma_readout` and there is no zero threshold or
# `k_image_dependent` option.
poisson_model_config = dict(
    layers=4,
    hidden_channels=64,
    gamma_readout=0.78,
    init_mu_range=0.55,
    init_sigma=0.4,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    input_kern=15,
    gamma_input=1,
    hidden_kern=13,
    depth_separable=True,
    modulator_kwargs=modulator_kwargs,
    shifter_kwargs=shifter_kwargs,
)

poisson_model = SE2DFullGaussian2d_Poisson().build_model(img_dataloaders, random_seed, **poisson_model_config)
poisson_model.to(device);

In [20]:
# Fit the single Poisson model with the trainer's default stopping settings, tracking
# per-epoch history. This overwrites `score`, `output` and `state_dict` from the ZIG cell.
poisson_training_kwargs = dict(loss_function="PoissonLoss", track_training=True)
score, output, state_dict = standard_trainer(poisson_model, img_dataloaders, random_seed, **poisson_training_kwargs)
poisson_model.eval();

# To persist this model, or restore it instead of retraining:
# torch.save(state_dict, "Poisson_statedict" + img_data_key + ".inshallah")
# poisson_model.load_state_dict(torch.load("Poisson_statedict" + img_data_key + ".inshallah"))
# poisson_model.eval();

val_correlation -0.0021692142
val_loss 4523065.0
train_loss 40846644.0


Epoch 1: 100% 71/71 [00:02<00:00, 24.93it/s]


val_correlation 0.06089262
val_loss 2475356.2
train_loss 22528398.0


Epoch 2: 100% 71/71 [00:02<00:00, 25.34it/s]


val_correlation 0.072875775
val_loss 2464938.5
train_loss 22345140.0


Epoch 3: 100% 71/71 [00:02<00:00, 25.62it/s]


val_correlation 0.09017557
val_loss 2454380.5
train_loss 22239252.0


Epoch 4: 100% 71/71 [00:02<00:00, 25.63it/s]


val_correlation 0.11448877
val_loss 2422110.5
train_loss 21965224.0


Epoch 5: 100% 71/71 [00:02<00:00, 25.58it/s]


val_correlation 0.12299193
val_loss 2448570.2
train_loss 22124994.0


Epoch 6: 100% 71/71 [00:02<00:00, 25.59it/s]


val_correlation 0.14777163
val_loss 2380002.5
train_loss 21473530.0


Epoch 7: 100% 71/71 [00:02<00:00, 25.15it/s]


val_correlation 0.17114505
val_loss 2359384.2
train_loss 21275482.0


Epoch 8: 100% 71/71 [00:02<00:00, 25.59it/s]


val_correlation 0.19159639
val_loss 2326209.5
train_loss 20952694.0


Epoch 9: 100% 71/71 [00:02<00:00, 25.45it/s]


val_correlation 0.21444872
val_loss 2291545.5
train_loss 20566532.0


Epoch 10: 100% 71/71 [00:02<00:00, 25.75it/s]


val_correlation 0.226796
val_loss 2275843.5
train_loss 20373322.0


Epoch 11: 100% 71/71 [00:02<00:00, 25.70it/s]


val_correlation 0.2436114
val_loss 2247989.5
train_loss 20040616.0


Epoch 12: 100% 71/71 [00:02<00:00, 25.67it/s]


val_correlation 0.25637776
val_loss 2229100.5
train_loss 19792282.0


Epoch 13: 100% 71/71 [00:02<00:00, 25.53it/s]


val_correlation 0.2629218
val_loss 2219726.8
train_loss 19673228.0


Epoch 14: 100% 71/71 [00:02<00:00, 24.94it/s]


val_correlation 0.27408484
val_loss 2198571.5
train_loss 19447516.0


Epoch 15: 100% 71/71 [00:02<00:00, 25.66it/s]


val_correlation 0.2732739
val_loss 2213462.0
train_loss 19465218.0


Epoch 16: 100% 71/71 [00:02<00:00, 25.78it/s]


val_correlation 0.28695077
val_loss 2172455.5
train_loss 19131114.0


Epoch 17: 100% 71/71 [00:02<00:00, 25.51it/s]


val_correlation 0.2864941
val_loss 2186573.8
train_loss 19253208.0


Epoch 18: 100% 71/71 [00:02<00:00, 25.47it/s]


val_correlation 0.2904621
val_loss 2174763.8
train_loss 19102808.0


Epoch 19: 100% 71/71 [00:02<00:00, 25.62it/s]


val_correlation 0.29799125
val_loss 2159633.2
train_loss 18860118.0


Epoch 20: 100% 71/71 [00:02<00:00, 25.40it/s]


val_correlation 0.2989345
val_loss 2159043.5
train_loss 18787244.0


Epoch 21: 100% 71/71 [00:02<00:00, 25.53it/s]


val_correlation 0.29993096
val_loss 2159457.0
train_loss 18811408.0


Epoch 22: 100% 71/71 [00:02<00:00, 25.63it/s]


val_correlation 0.30322784
val_loss 2151338.5
train_loss 18679836.0


Epoch 23: 100% 71/71 [00:02<00:00, 25.89it/s]


val_correlation 0.306771
val_loss 2142980.5
train_loss 18578364.0


Epoch 24: 100% 71/71 [00:02<00:00, 25.75it/s]


val_correlation 0.30526882
val_loss 2150338.5
train_loss 18587190.0


Epoch 25: 100% 71/71 [00:02<00:00, 25.83it/s]


val_correlation 0.30340692
val_loss 2157931.5
train_loss 18582934.0


Epoch 26: 100% 71/71 [00:02<00:00, 25.76it/s]


val_correlation 0.30748805
val_loss 2150942.2
train_loss 18526000.0


Epoch 27: 100% 71/71 [00:02<00:00, 25.13it/s]


val_correlation 0.30622983
val_loss 2145603.8
train_loss 18473752.0


Epoch 28: 100% 71/71 [00:02<00:00, 25.60it/s]


val_correlation 0.3083354
val_loss 2149118.8
train_loss 18657004.0


Epoch 29: 100% 71/71 [00:02<00:00, 25.51it/s]


val_correlation 0.31233382
val_loss 2133266.5
train_loss 18354588.0


Epoch 30: 100% 71/71 [00:02<00:00, 25.84it/s]


val_correlation 0.3139222
val_loss 2135365.5
train_loss 18359218.0


Epoch 31: 100% 71/71 [00:02<00:00, 23.97it/s]


val_correlation 0.31518897
val_loss 2131033.2
train_loss 18290440.0


Epoch 32: 100% 71/71 [00:02<00:00, 25.54it/s]


val_correlation 0.31230223
val_loss 2139440.8
train_loss 18348586.0


Epoch 33: 100% 71/71 [00:02<00:00, 25.18it/s]


val_correlation 0.31495655
val_loss 2139164.5
train_loss 18296958.0


Epoch 34: 100% 71/71 [00:02<00:00, 25.66it/s]


val_correlation 0.318402
val_loss 2121950.0
train_loss 18197576.0


Epoch 35: 100% 71/71 [00:02<00:00, 24.25it/s]


val_correlation 0.31611818
val_loss 2131259.0
train_loss 18119764.0


Epoch 36: 100% 71/71 [00:02<00:00, 24.18it/s]


val_correlation 0.31726065
val_loss 2126163.0
train_loss 18124930.0


Epoch 37: 100% 71/71 [00:02<00:00, 25.53it/s]


val_correlation 0.3171519
val_loss 2134908.5
train_loss 18133332.0


Epoch 38: 100% 71/71 [00:02<00:00, 25.47it/s]


val_correlation 0.31794217
val_loss 2123299.0
train_loss 18097894.0


Epoch 39: 100% 71/71 [00:02<00:00, 25.63it/s]


val_correlation 0.318402
val_loss 2121950.0
train_loss 18197576.0


Epoch 40: 100% 71/71 [00:02<00:00, 25.42it/s]


Epoch 00040: reducing learning rate of group 0 to 1.5000e-03.
val_correlation 0.31376648
val_loss 2137814.5
train_loss 18282880.0


Epoch 41: 100% 71/71 [00:02<00:00, 25.59it/s]


val_correlation 0.3244521
val_loss 2109876.0
train_loss 17895174.0


Epoch 42: 100% 71/71 [00:02<00:00, 25.57it/s]


val_correlation 0.31940612
val_loss 2129886.5
train_loss 17994616.0


Epoch 43: 100% 71/71 [00:02<00:00, 25.64it/s]


val_correlation 0.32657978
val_loss 2108896.2
train_loss 17903320.0


Epoch 44: 100% 71/71 [00:02<00:00, 25.76it/s]


val_correlation 0.3215325
val_loss 2124483.0
train_loss 17973250.0


Epoch 45: 100% 71/71 [00:02<00:00, 25.82it/s]


val_correlation 0.32570863
val_loss 2106250.2
train_loss 17837070.0


Epoch 46: 100% 71/71 [00:02<00:00, 24.99it/s]


val_correlation 0.3239352
val_loss 2114610.5
train_loss 17776134.0


Epoch 47: 100% 71/71 [00:02<00:00, 25.52it/s]


val_correlation 0.32291067
val_loss 2115973.2
train_loss 17809742.0


Epoch 48: 100% 71/71 [00:02<00:00, 25.55it/s]


val_correlation 0.32657978
val_loss 2108896.2
train_loss 17903318.0


Epoch 49: 100% 71/71 [00:02<00:00, 25.57it/s]


Epoch 00049: reducing learning rate of group 0 to 4.5000e-04.
val_correlation 0.32356533
val_loss 2115160.5
train_loss 17838936.0


Epoch 50: 100% 71/71 [00:02<00:00, 25.66it/s]


val_correlation 0.32605797
val_loss 2104366.8
train_loss 17783328.0


Epoch 51: 100% 71/71 [00:02<00:00, 25.58it/s]


val_correlation 0.32780504
val_loss 2100218.0
train_loss 17795326.0


Epoch 52: 100% 71/71 [00:02<00:00, 25.71it/s]


val_correlation 0.3246255
val_loss 2114629.5
train_loss 17759352.0


Epoch 53: 100% 71/71 [00:02<00:00, 25.36it/s]


val_correlation 0.32496527
val_loss 2110148.5
train_loss 17792904.0


Epoch 54: 100% 71/71 [00:02<00:00, 25.45it/s]


val_correlation 0.32512647
val_loss 2109753.5
train_loss 17774310.0


Epoch 55: 100% 71/71 [00:02<00:00, 25.53it/s]


val_correlation 0.32743704
val_loss 2104880.0
train_loss 17724062.0


Epoch 56: 100% 71/71 [00:02<00:00, 25.55it/s]


In [21]:
# Training history from the most recent `standard_trainer` call (the Poisson model when
# cells run in order): a dict of per-epoch arrays such as 'val_correlation' and 'val_loss'.
output

{'val_correlation': array([-0.00216921,  0.06089262,  0.07287578,  0.09017557,  0.11448877,
         0.12299193,  0.14777163,  0.17114505,  0.19159639,  0.21444872,
         0.226796  ,  0.2436114 ,  0.25637776,  0.2629218 ,  0.27408484,
         0.2732739 ,  0.28695077,  0.2864941 ,  0.2904621 ,  0.29799125,
         0.2989345 ,  0.29993096,  0.30322784,  0.306771  ,  0.30526882,
         0.30340692,  0.30748805,  0.30622983,  0.3083354 ,  0.31233382,
         0.3139222 ,  0.31518897,  0.31230223,  0.31495655,  0.318402  ,
         0.31611818,  0.31726065,  0.3171519 ,  0.31794217,  0.318402  ,
         0.31376648,  0.3244521 ,  0.31940612,  0.32657978,  0.3215325 ,
         0.32570863,  0.3239352 ,  0.32291067,  0.32657978,  0.32356533,
         0.32605797,  0.32780504,  0.3246255 ,  0.32496527,  0.32512647,
         0.32743704], dtype=float32),
 'val_loss': array([4523065. , 2475356.2, 2464938.5, 2454380.5, 2422110.5, 2448570.2,
        2380002.5, 2359384.2, 2326209.5, 2291545.5, 22

In [22]:
# Score from the most recent `standard_trainer` call (the Poisson model when cells run in
# order); in the recorded run it equals the largest value in output['val_correlation'].
score

0.32780504

val_correlation 0.32643822

val_loss 2107371.0

train_loss 17724226.0

---

# Compare DEIs/MEIs

### DEI data

In [None]:
# The DEI/control-pair recording of the same animal: same animal_id, but a different
# scan, i.e. a different (session, scan_idx) pair than the ImageNet scan. Comparing the
# pair (rather than requiring session AND scan_idx to both differ) avoids missing a DEI
# scan that happens to share one of the two numbers with the ImageNet scan.
imagenet_scan = (imagenet_key["session"], imagenet_key["scan_idx"])
dei_candidates = [
    dat for dat in datasets
    if dat["animal_id"] == imagenet_key["animal_id"]
    and (dat["session"], dat["scan_idx"]) != imagenet_scan
]
assert len(dei_candidates) == 1, (
    f"expected exactly one other scan for animal {imagenet_key['animal_id']}, "
    f"found {len(dei_candidates)}"
)
dei_key = dei_candidates[0]

assert dei_key["scan_purpose"] == "dei_control_pair"
paths = ["./data/static{}-{}-{}-GrayImageNetDEIInfo-7bed7f7379d99271be5d144e5e59a8e7.zip".format(dei_key["animal_id"], dei_key["session"], dei_key["scan_idx"])]
dei_data_key = extract_data_key(paths[0])

# Only the test tier is loaded. Besides images/responses, the loaders also yield the
# trial index and DEI bookkeeping (DEI unit ids, source unit ids, mean distances).
dataset_config = {'paths': paths,
                  'batch_size': 64,
                  'seed': random_seed,
                  'return_test_sampler': True,
                  'tier': "test",
                  'loader_outputs': ["images", 'responses', 'trial_idx', "dei_unit_ids", "dei_src_unit_ids", "dei_mean_distances"],
                  'normalize': True,
                  'exclude': ["images", "trial_idx", "dei_unit_ids", "dei_src_unit_ids", "dei_mean_distances"],
                  'subtract_behavior_mean': True}

dei_dataloaders = static_loaders(**dataset_config)

dei_dataset = dei_dataloaders["test"][dei_data_key].dataset

In [None]:
# Keep only test batches with the expected number of repeats whose DEI mean distance
# is within the threshold.
N_REPEATS = 20
MAX_MEAN_DISTANCE = 10

images, responses, trial_idxs, dei_unit_ids, dei_src_unit_ids, dei_mean_distances = [], [], [], [], [], []
for image, response, trial_idx, dei_unit_id, dei_src_unit_id, dei_mean_distance in dei_dataloaders["test"][dei_data_key]:
    # torch.all requires every entry of the batch to pass. The truth value of
    # torch.unique(mask) is ambiguous (raises) as soon as a batch mixes True and False.
    if len(response) == N_REPEATS and torch.all(dei_mean_distance <= MAX_MEAN_DISTANCE):
        images.append(image)
        responses.append(response)
        trial_idxs.append(trial_idx)
        dei_unit_ids.append(dei_unit_id)
        dei_src_unit_ids.append(dei_src_unit_id)
        dei_mean_distances.append(dei_mean_distance)
assert images, "no DEI test batch passed the repeat-count / mean-distance filter"

images = torch.stack(images)
responses = torch.stack(responses)
trial_idxs = torch.stack(trial_idxs).cpu().data.numpy()
dei_unit_ids = torch.stack(dei_unit_ids).cpu().data.numpy()
dei_src_unit_ids = torch.stack(dei_src_unit_ids).cpu().data.numpy()
dei_mean_distances = torch.stack(dei_mean_distances).cpu().data.numpy()

In [None]:
# Get possible unit ids (in the source-dataset frame)
possible_src_unit_ids = np.unique(dei_src_unit_ids, axis=1).squeeze()

# Sort according to mean distances (increasing)
src_sort_idx = np.argsort(np.unique(dei_mean_distances, axis=1).squeeze())
possible_src_unit_ids = possible_src_unit_ids[src_sort_idx]

# Remove duplicates (from several DEIs/MEI)
_, idx = np.unique(possible_src_unit_ids, return_index=True)
possible_src_unit_ids = possible_src_unit_ids[np.sort(idx)]

In [None]:
# Load 5-member ensembles (seeds 0-4) from the state dicts written by the ensemble-training
# cells. Members are built with each model class's `build_model`, the same builder used for
# the single models above (no `*_se2d_fullgaussian2d` builder functions are imported here).
# NOTE(review): assumes Ensemble calls its builder as builder(dataloaders, seed, **config),
# matching the single-model cells — confirm against nnsysident's Ensemble.
# These assignments replace the single trained models of the same names.
zig_model = Ensemble(SE2DFullGaussian2d_ZIG().build_model, zig_model_config, img_dataloaders, "ZIG_statedict" + img_data_key, np.arange(5), device=device)
poisson_model = Ensemble(SE2DFullGaussian2d_Poisson().build_model, poisson_model_config, img_dataloaders, "Poisson_statedict" + img_data_key, np.arange(5), device=device)

# zig_model.shifter = None
# poisson_model.shifter = None
# poisson_model.modulator = None

In [None]:
# For every candidate source unit, gather model predictions and recorded responses for the
# MEI (row 0) and the two DEIs (rows 1, 2). Entries stay NaN when a stimulus is missing.
n_units = len(possible_src_unit_ids)
zig_means = np.full((3, n_units), np.nan)
zig_variances = np.full((3, n_units), np.nan)
poisson_means = np.full((3, n_units), np.nan)
poisson_variances = np.full((3, n_units), np.nan)
real_resp_means = np.full((3, n_units), np.nan)
real_resp_vars = np.full((3, n_units), np.nan)
imgs = np.full((3, n_units, 1, 36, 64), np.nan)
for i, possible_src_unit_id in enumerate(possible_src_unit_ids):
    # Rows of `images` (one per presented image) belonging to this source unit.
    image_idx = np.unique(np.where(dei_src_unit_ids == possible_src_unit_id)[0])

    # Map the unit ids to column indices in the ImageNet (model) and DEI (recorded) datasets.
    dei_neuron_id = np.unique(dei_unit_ids[image_idx]).item()
    src_neuron_id = np.unique(dei_src_unit_ids[image_idx]).item()
    src_neuron_idx = np.where(img_dataset.neurons.unit_ids == src_neuron_id)[0].item()
    dei_neuron_idx = np.where(dei_dataset.neurons.unit_ids == dei_neuron_id)[0].item()

    # One image per stimulus: drop the repeat dimension.
    img = torch.unique(images[image_idx], dim=1).squeeze(1)

    # With fewer than 3 images, use the trial image class to decide whether the MEI is present.
    mei_dei_idx = [0, 1, 2]
    if len(image_idx) != 3:
        trial_rows = np.where(np.isin(dei_dataset.trial_info.trial_idx, trial_idxs[image_idx]))[0]
        if 'mask_fixed_mei' in np.unique(dei_dataset.trial_info.frame_image_class[trial_rows]):
            mei_dei_idx = [0, 1] if len(image_idx) == 2 else [0]
        else:
            mei_dei_idx = [1, 2] if len(image_idx) == 2 else [1]
    imgs[mei_dei_idx, i, :, :, :] = img.cpu().data

    # TODO: Keep this line?
#     img = torch.stack([((im - im.mean()) / (im.std())) for im in img.squeeze()])[:, None]

    # Behavior and pupil-center inputs are fixed to zero for every prediction.
    behavior = torch.zeros((img.shape[0], 3)).to(device)
    pupil_center = torch.zeros((img.shape[0], 2)).to(device)
    model_inputs = dict(data_key=img_data_key, behavior=behavior, pupil_center=pupil_center)

    # Inference only: no_grad avoids building an autograd graph for every prediction call.
    with torch.no_grad():
        zig_means_ = zig_model.predict_mean(img, **model_inputs).cpu().data.numpy()
        zig_variances_ = zig_model.predict_variance(img, **model_inputs).cpu().data.numpy()
        poisson_means_ = poisson_model.predict_mean(img, **model_inputs).cpu().data.numpy()
        poisson_variances_ = poisson_model.predict_variance(img, **model_inputs).cpu().data.numpy()

    zig_means[mei_dei_idx, i] = zig_means_[:, src_neuron_idx]
    zig_variances[mei_dei_idx, i] = zig_variances_[:, src_neuron_idx]
    poisson_means[mei_dei_idx, i] = poisson_means_[:, src_neuron_idx]
    poisson_variances[mei_dei_idx, i] = poisson_variances_[:, src_neuron_idx]

    # Mean / variance of the recorded responses over the repeat axis (axis=1).
    unit_responses = responses[image_idx].cpu().data.numpy()
    real_resp_means[mei_dei_idx, i] = np.mean(unit_responses, axis=1)[:, dei_neuron_idx]
    real_resp_vars[mei_dei_idx, i] = np.var(unit_responses, axis=1)[:, dei_neuron_idx]

# Units for which all three stimuli (MEI, DEI1, DEI2) were available.
keep_idx = ~np.isnan(zig_means).any(axis=0)

### Compare Zhiwei Model with Konstantin model

In [None]:
# Zhiwei-model responses to the MEI/DEIs, transposed so that (presumably) rows are stimuli
# and columns are units, like `zig_means` — confirm against the exporter.
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
with open("group233_mei_dei_resps.pkl", "rb") as resp_file:
    zhiwei_resps = pickle.load(resp_file)
e = zhiwei_resps.T

In [None]:
fig, axes = plt.subplots(4, 1, figsize=(15, 10), dpi=150, sharex=True, sharey=True)
fontsize = 10

# Express every stimulus response relative to the MEI response (row 0) of the same unit.
# Model/data arrays are restricted to the kept units; the Zhiwei responses are used as loaded.
y_zhiwei = e / e[0, :]
y_zig = (zig_means / zig_means[0, :])[:, keep_idx]
y_poisson = (poisson_means / poisson_means[0, :])[:, keep_idx]
y_real = (real_resp_means / real_resp_means[0, :])[:, keep_idx]
x = np.arange(y_real.shape[1])

# One panel per source; only the real-data panel carries legend labels.
panels = [
    (axes[0], y_zhiwei, "Zhiwei model"),
    (axes[1], y_zig, "ZIG model"),
    (axes[2], y_poisson, "Poisson model"),
    (axes[3], y_real, "Real data  (averaged over 20 repeats)"),
]
for panel_ax, y_rel, title in panels:
    for row, stim_label in enumerate(["MEI", "DEI1", "DEI2"]):
        label_kwargs = {"label": stim_label} if panel_ax is axes[3] else {}
        panel_ax.plot(x, y_rel[row, :], ls="", marker="x", **label_kwargs)
    panel_ax.set_title(title, fontsize=fontsize*1.3)
axes[2].ticklabel_format(useOffset=False)

axes[3].set_xlabel("neurons", fontsize=fontsize*1.2)
axes[3].set_ylabel(r"$\frac{resp}{resp(MEI)}$", fontsize=fontsize*1.2)

# y-axes are shared, so this limit applies to all panels.
axes[0].set(ylim=[0, 1.5])

axes[3].legend(bbox_to_anchor=(0.1, 1.1, 0, 0), frameon=True, fontsize=fontsize*.8)

sns.despine(trim=True)
# fig.savefig("Zhiwei_Model_Comparison" + ".png", bbox_inches="tight", transparent=False)

In [None]:
# Compare each model (and the data) against the Zhiwei model on the DEI responses relative
# to the MEI. Row 0 (MEI) is identically 1 after MEI-normalization, so it carries no
# information; rows 1-2 (DEI1, DEI2) are used instead.
fig, ax = plt.subplots()
zhiwei_dei_rel = y_zhiwei[1:].ravel()
for y_rel, label in [(y_poisson, "Poisson"), (y_zig, "ZIG"), (y_real, "Real data")]:
    ax.scatter(zhiwei_dei_rel, y_rel[1:].ravel(), label=label)

# Identity line over a common range; the x- and y-limits generally differ, so mixing
# them would give a line that is not y = x.
(x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
identity_range = [min(x_lo, y_lo), max(x_hi, y_hi)]
ax.plot(identity_range, identity_range, ls="--", color="grey")
ax.set(xlabel="Zhiwei model: resp / resp(MEI)", ylabel="resp / resp(MEI)")
ax.legend(frameon=False);

#### Rank correlation

In [None]:
# Per-unit Spearman rank correlation over the three stimuli (MEI, DEI1, DEI2), averaged
# over units. The model/data arrays are restricted to the kept units once, up front.
zig_kept = zig_means[:, keep_idx]
poisson_kept = poisson_means[:, keep_idx]
real_kept = real_resp_means[:, keep_idx]

zig_zhiwei, poisson_zhiwei, zig_real, poisson_real, zhiwei_real = [], [], [], [], []
for unit in range(e.shape[-1]):
    comparisons = (
        (zig_zhiwei, e[:, unit], zig_kept[:, unit]),
        (poisson_zhiwei, e[:, unit], poisson_kept[:, unit]),
        (zig_real, real_kept[:, unit], zig_kept[:, unit]),
        (poisson_real, real_kept[:, unit], poisson_kept[:, unit]),
        (zhiwei_real, real_kept[:, unit], e[:, unit]),
    )
    for results, first, second in comparisons:
        rho, p = spearmanr(first, second, axis=0)
        results.append(rho)

print("ZIG     |  Zhiwei: {:.5f}".format(np.mean(zig_zhiwei)))
print("Poisson |  Zhiwei: {:.5f}".format(np.mean(poisson_zhiwei)))
print("ZIG     |  Real:   {:.5f}".format(np.mean(zig_real)))
print("Poisson |  Real:   {:.5f}".format(np.mean(poisson_real)))
print("Zhiwei  |  Real:   {:.5f}".format(np.mean(zhiwei_real)))

#### Full Correlation

In [None]:
# Correlation over the stimulus axis (axis=0) per unit, via neuralpredictors' `corr`,
# then averaged over units.
zig_kept = zig_means[:, keep_idx]
poisson_kept = poisson_means[:, keep_idx]
real_kept = real_resp_means[:, keep_idx]

zig_zhiwei = np.mean(corr(e, zig_kept, axis=0))
poisson_zhiwei = np.mean(corr(e, poisson_kept, axis=0))
zig_real = np.mean(corr(real_kept, zig_kept, axis=0))
poisson_real = np.mean(corr(real_kept, poisson_kept, axis=0))
zhiwei_real = np.mean(corr(real_kept, e, axis=0))

for template, value in [
    ("ZIG     |  Zhiwei: {:.5f}", zig_zhiwei),
    ("Poisson |  Zhiwei: {:.5f}", poisson_zhiwei),
    ("ZIG     |  Real:   {:.5f}", zig_real),
    ("Poisson |  Real:   {:.5f}", poisson_real),
    ("Zhiwei  |  Real:   {:.5f}", zhiwei_real),
]:
    print(template.format(np.mean(value)))

___

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(9, 3), dpi=150)
ax_real, ax_zig = axes
color = "navy"
ec = "white"
fontsize = 15
title_fontsize = fontsize*1.3

# Left: recorded response mean vs. variance for every (stimulus, unit) entry.
ax_real.scatter(real_resp_means.flatten(), real_resp_vars.flatten(), color=color, ec=ec)
ax_real.set_xlabel("Means", fontsize=fontsize)
ax_real.set_ylabel("Variances", fontsize=fontsize)
ax_real.set_title("Real Data", fontsize=title_fontsize)

# Right: ZIG-predicted mean vs. variance, with the Poisson reference variance == mean.
ax_zig.scatter(zig_means.flatten(), zig_variances.flatten(), color=color, ec=ec)
ax_zig.set_title("ZIG", fontsize=title_fontsize)
x_lo, x_hi = ax_zig.get_xlim()
ax_zig.plot([x_lo, x_hi], [x_lo, x_hi], ls="--", color="grey", label="Poisson")
# ax_zig.set(xlim=[0, 15], ylim=[0, 400])
ax_zig.legend(frameon=False, bbox_to_anchor=[.4,.8,0,0])
sns.despine(trim=True)

# fig.savefig("mean_variance_comparison" + ".png", bbox_inches="tight", transparent=False)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(9, 3), dpi=150)
ax_real, ax_zig = axes
color = "navy"
ec = "white"
fontsize = 15
title_fontsize = fontsize*1.3

# One line per unit (column) connecting its MEI/DEI1/DEI2 points in mean-variance space.
ax_real.plot(real_resp_means, real_resp_vars, marker="", lw=1)
ax_real.set_xlabel("Means", fontsize=fontsize)
ax_real.set_ylabel("Variances", fontsize=fontsize)
ax_real.set_title("Real Data", fontsize=title_fontsize)

ax_zig.plot(zig_means, zig_variances, marker="", lw=1)
ax_zig.set_title("ZIG", fontsize=title_fontsize)
# Poisson reference: variance equals mean over the current x-range.
x_lo, x_hi = ax_zig.get_xlim()
ax_zig.plot([x_lo, x_hi], [x_lo, x_hi], ls="--", color="grey", label="Poisson")
# ax_zig.set(xlim=[0, 15], ylim=[0, 400])
ax_zig.legend(frameon=False, bbox_to_anchor=[.4,.8,0,0])
sns.despine(trim=True)

# fig.savefig("mean_variance_comparison" + ".png", bbox_inches="tight", transparent=False)