In [1]:
import os
import datajoint as dj
# DataJoint connection settings are taken from the environment, so no
# credentials live in the notebook. A missing variable raises KeyError here.
for config_key, env_var in (
    ("database.host", "DJ_HOST"),
    ("database.user", "DJ_USERNAME"),
    ("database.password", "DJ_PASSWORD"),
):
    dj.config[config_key] = os.environ[env_var]

dj.config["enable_python_native_blobs"] = True
dj.config["display.limit"] = 200

# The schema name is published both as an environment variable and as the
# "nnfabrik.schema_name" DataJoint setting (environment first, then config).
name = "mvi"
os.environ["DJ_SCHEMA_NAME"] = dj.config["nnfabrik.schema_name"] = f"metrics_{name}"

In [2]:
import re
import torch
import numpy as np
import pickle 
import json
import pandas as pd
# Notebook-wide pandas display limits: show up to 500 columns but only 10 rows.
pd.options.display.max_columns = 500
pd.options.display.max_rows = 10
import matplotlib as mpl
# White backgrounds everywhere (figure, axes, saved files) and a small
# default figure; individual plots may still override size and dpi.
mpl.rcParams.update({
    "figure.facecolor": 'w',
    "axes.facecolor": 'w',
    "savefig.facecolor": 'w',
    "figure.dpi": 100,
    "figure.figsize": (3, 3),
})
import matplotlib.pyplot as plt
import seaborn as sns

from neuralmetrics.utils import extract_data_key
from neuralmetrics.models.direct import ZIG
from neuralmetrics.training.losses import ZIGLoss
from neuralmetrics.training.trainers import nn_zig_trainer
from neuralmetrics.datasets import static_loaders
from neuralmetrics.models.neuralnet import zig_se2d_fullgaussian2d
from neuralpredictors.measures import corr

from dataport.bcm.static import PreprocessedMouseData

# Seed handed to the data loaders, the model builder and the trainer below,
# and the torch device the model is moved to.
random_seed, device = 27121992, 'cuda'

Connecting konstantin@134.76.19.44:3306


---

In [3]:
# Scan keys used in this notebook, one row per scan:
# (animal_id, session, scan_idx, scan_purpose).
# Every animal contributes an "imagenet" scan followed by a
# "dei_control_pair" scan.
datasets = [
    {
        "animal_id": animal_id,
        "session": session,
        "scan_idx": scan_idx,
        "scan_purpose": scan_purpose,
    }
    for animal_id, session, scan_idx, scan_purpose in (
        (26614, 1, 16, "imagenet"),
        (26614, 2, 17, "dei_control_pair"),
        (26726, 6, 11, "imagenet"),
        (26726, 7, 13, "dei_control_pair"),
        (26942, 1, 11, "imagenet"),
        (26942, 2, 8, "dei_control_pair"),
        (27468, 3, 12, "imagenet"),
        (27468, 4, 7, "dei_control_pair"),
    )
]

## ImageNet data

In [4]:
# The first entry of `datasets` is the ImageNet scan analysed in this notebook.
imagenet_key = datasets[0]
assert imagenet_key["scan_purpose"] == "imagenet"

# The archive name is built from the scan identifiers (animal, session, scan).
paths = [
    "./data/static{}-{}-{}-GrayImageNet-7bed7f7379d99271be5d144e5e59a8e7.zip".format(
        *(imagenet_key[field] for field in ("animal_id", "session", "scan_idx"))
    )
]
img_data_key = extract_data_key(paths[0])

dataset_config = dict(
    paths=paths,
    batch_size=64,
    seed=random_seed,
    loader_outputs=["images", "responses"],
    normalize=True,
    # NOTE: the checkpoint-loading cell reads this list to choose between the
    # "_unnormimage" and "_normimage" file suffix.
    exclude=["images"],
)

img_dataloaders = static_loaders(**dataset_config)
# Dataset object behind the test loader; used later to look up unit ids.
img_dataset = img_dataloaders["test"][img_data_key].dataset

---

## Model

In [5]:
# Zero threshold passed to the ZIG model for the ImageNet data key.
loc = np.exp(-10)

model_config = dict(
    layers=4,
    hidden_channels=64,
    feature_reg_weight=0.78,
    init_mu_range=0.55,
    init_sigma=0.4,
    grid_mean_predictor=dict(
        type="cortex",
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    zero_thresholds={img_data_key: loc},
    input_kern=15,
    gamma_input=1,
    hidden_kern=13,
    depth_separable=True,
)

# Build the model from the ImageNet loaders and move it to the chosen device.
model = zig_se2d_fullgaussian2d(img_dataloaders, random_seed, **model_config)
model.to(device);



---

## Train model

In [None]:
# Trainer settings: the ZIG loss is both the training objective and the
# stopping criterion, and it is minimised (maximize=False).
trainer_options = dict(
    loss_function="ZIGLoss",
    stop_function="get_loss",
    track_training=True,
    maximize=False,
)
score, output, state_dict = nn_zig_trainer(model, img_dataloaders, random_seed, **trainer_options)
model.eval();
# Saving is disabled. NOTE(review): the loading cell appends a
# "_unnormimage"/"_normimage" suffix to this file name, so the two paths
# differ unless the suffix is added here as well.
# torch.save(state_dict, "MVI_statedict_" + img_data_key)

In [7]:
# The checkpoint file name encodes whether images were excluded from the
# loader's `exclude` list handling ("_unnormimage") or not ("_normimage").
# NOTE: `dataset_config` is reassigned by the DEI cell further down, so this
# cell must run before that one (or after it only while "images" is still
# listed in `exclude`).
norm_image = "_unnormimage" if "images" in dataset_config["exclude"] else "_normimage"

checkpoint_path = "MVI_statedict_" + img_data_key + norm_image
# map_location remaps the stored tensors onto the configured device, so the
# checkpoint also loads when it was saved from a different device index/host.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval();
print(norm_image)

_unnormimage


___

### DEI data

In [8]:
def _is_dei_partner(candidate):
    """Return True for a scan of the same animal with a different session and scan index."""
    same_animal = candidate["animal_id"] == imagenet_key["animal_id"]
    other_session = candidate["session"] != imagenet_key["session"]
    other_scan = candidate["scan_idx"] != imagenet_key["scan_idx"]
    return same_animal and other_session and other_scan


# Exactly one entry must match; `.item()` raises otherwise.
idx = np.array([_is_dei_partner(dat) for dat in datasets])
dei_key = np.array(datasets)[idx].item()

assert dei_key["scan_purpose"] == "dei_control_pair"
paths = [
    "./data/static{}-{}-{}-GrayImageNetDEIInfo-7bed7f7379d99271be5d144e5e59a8e7.zip".format(
        *(dei_key[field] for field in ("animal_id", "session", "scan_idx"))
    )
]
dei_data_key = extract_data_key(paths[0])

# NOTE: this reassigns `paths` and `dataset_config` from the ImageNet cell.
dataset_config = dict(
    paths=paths,
    batch_size=64,
    seed=random_seed,
    return_test_sampler=True,
    tier="test",
    loader_outputs=["images", 'responses', 'trial_idx', "dei_unit_ids", "dei_src_unit_ids", "dei_mean_distances"],
    normalize=True,
    exclude=["images", "trial_idx", "dei_unit_ids", "dei_src_unit_ids", "dei_mean_distances"],
)

dei_dataloaders = static_loaders(**dataset_config)

# Dataset object behind the DEI test loader; used later to look up unit ids.
dei_dataset = dei_dataloaders["test"][dei_data_key].dataset

Returning only test sampler with repeats...


In [9]:
# A batch is kept only if it holds this many rows (presumably the repeats of
# one stimulus -- the loader reports "test sampler with repeats"; confirm).
N_REPEATS = 20
# Batches in which any DEI mean distance exceeds this value are dropped.
MAX_DEI_MEAN_DISTANCE = 10


def _stack_to_numpy(tensors):
    """Stack a list of equally shaped tensors and return the result as a NumPy array."""
    return torch.stack(tensors).detach().cpu().numpy()


images, responses, trial_idxs, dei_unit_ids, dei_src_unit_ids, dei_mean_distances = [], [], [], [], [], []
for image, response, trial_idx, dei_unit_id, dei_src_unit_id, dei_mean_distance in dei_dataloaders["test"][dei_data_key]:
    has_all_repeats = len(response) == N_REPEATS
    # `.all()` always yields a single truth value. The previous
    # `torch.unique(mask)` raised "Boolean value of Tensor ... is ambiguous"
    # whenever a batch contained both True and False entries.
    within_distance = bool((dei_mean_distance <= MAX_DEI_MEAN_DISTANCE).all())
    if has_all_repeats and within_distance:
        images.append(image)
        responses.append(response)
        trial_idxs.append(trial_idx)
        dei_unit_ids.append(dei_unit_id)
        dei_src_unit_ids.append(dei_src_unit_id)
        dei_mean_distances.append(dei_mean_distance)

# Images and responses stay torch tensors (the model consumes them later);
# the bookkeeping fields are converted to NumPy for indexing.
images = torch.stack(images)
responses = torch.stack(responses)
trial_idxs = _stack_to_numpy(trial_idxs)
dei_unit_ids = _stack_to_numpy(dei_unit_ids)
dei_src_unit_ids = _stack_to_numpy(dei_src_unit_ids)
dei_mean_distances = _stack_to_numpy(dei_mean_distances)

In [10]:
# Candidate unit ids, expressed in the frame of the source (ImageNet) dataset.
# Duplicate entries along axis 1 are collapsed first
# (assumes axis 1 is the repeat axis -- TODO confirm).
possible_src_unit_ids = np.squeeze(np.unique(dei_src_unit_ids, axis=1))

# Order the candidates by increasing mean distance.
per_image_distance = np.squeeze(np.unique(dei_mean_distances, axis=1))
src_sort_idx = np.argsort(per_image_distance)
possible_src_unit_ids = possible_src_unit_ids[src_sort_idx]

# A unit appears once per image (several DEIs / MEI): keep only its first
# occurrence so that the distance ordering is preserved.
_, idx = np.unique(possible_src_unit_ids, return_index=True)
first_occurrence = np.sort(idx)
possible_src_unit_ids = possible_src_unit_ids[first_occurrence]

In [11]:
# Each neuron is expected to come with exactly this many images
# (plotted further down as MEI, DEI1 and DEI2).
N_IMAGES_PER_UNIT = 3

# Rows: images of a neuron; columns: candidate neurons. Columns of neurons
# that are skipped below stay NaN and are filtered out by later cells.
result_shape = (N_IMAGES_PER_UNIT, len(possible_src_unit_ids))
means = np.full(result_shape, np.nan)
variances = np.full(result_shape, np.nan)
real_resp_means = np.full(result_shape, np.nan)
real_resp_vars = np.full(result_shape, np.nan)

for i, possible_src_unit_id in enumerate(possible_src_unit_ids):
    image_idx = np.unique(np.where(dei_src_unit_ids == possible_src_unit_id)[0])

    # Skip neurons with missing data.
    if len(image_idx) != N_IMAGES_PER_UNIT:
        continue

    # Map the unit ids to column positions in the respective datasets;
    # `.item()` raises if an id is missing or ambiguous.
    dei_neuron_id = np.unique(dei_unit_ids[image_idx]).item()
    src_neuron_id = np.unique(dei_src_unit_ids[image_idx]).item()
    src_neuron_idx = np.where(img_dataset.neurons.unit_ids == src_neuron_id)[0].item()
    dei_neuron_idx = np.where(dei_dataset.neurons.unit_ids == dei_neuron_id)[0].item()

    # Collapse identical copies along dim 1 so one image per stimulus remains.
    img = torch.unique(images[image_idx], dim=1).squeeze(1)

    # TODO: decide whether each image should be standardised before prediction:
    #     img = torch.stack([((im - im.mean()) / (im.std())) for im in img.squeeze()])[:, None]

    # Pure inference: disable autograd so no graph is built for the forward
    # passes (the values are unchanged, only memory/time is saved).
    with torch.no_grad():
        means_ = model.predict_mean(img, data_key=img_data_key).squeeze().cpu().numpy()
        variances_ = model.predict_variance(img, data_key=img_data_key).squeeze().cpu().numpy()

    means[:, i] = means_[:, src_neuron_idx]
    variances[:, i] = variances_[:, src_neuron_idx]

    # Measured responses: statistics over axis 1, then select the neuron.
    # The array is converted once and reused for both statistics.
    measured = responses[image_idx].detach().cpu().numpy()
    real_resp_means[:, i] = np.mean(measured, axis=1)[:, dei_neuron_idx]
    real_resp_vars[:, i] = np.var(measured, axis=1)[:, dei_neuron_idx]

### Compare Zhiwei's model with Konstantin's model

In [None]:
# Responses exported from the other model, transposed after loading.
# NOTE: pickle can execute arbitrary code -- only load files from a trusted source.
with open(r"group233_mei_dei_resps.pkl", "rb") as input_file:
    zhiwei_resps = pickle.load(input_file)
e = zhiwei_resps.T

In [None]:
# --- Data preparation ---------------------------------------------------------
# All responses are expressed relative to the MEI response (row 0).
y_zhiwei = e / e[0, :]
y_konstantin_all = means / means[0, :]
y_real_all = real_resp_means / real_resp_means[0, :]

# Use ONE shared mask for model and measured data. Filtering each array with
# its own NaN mask could drop different columns and silently misalign the
# neurons, which the per-neuron correlations computed afterwards rely on.
valid_columns = ~np.isnan(y_konstantin_all[0, :]) & ~np.isnan(y_real_all[0, :])
y_konstantin = y_konstantin_all[:, valid_columns]
y_real = y_real_all[:, valid_columns]

x = np.arange(e.shape[1])
if not (y_konstantin.shape[1] == y_real.shape[1] == x.size):
    raise ValueError(
        "Neuron count mismatch: Zhiwei={}, Konstantin={}, real={}".format(
            x.size, y_konstantin.shape[1], y_real.shape[1]
        )
    )

# --- Figure -------------------------------------------------------------------
fig, axes = plt.subplots(3, 1, figsize=(15, 11), dpi=100, sharex=True)
fontsize = 15

# One marker series per image type on each panel; only the bottom panel
# carries labels because it hosts the legend.
for i, label in enumerate(["MEI", "DEI1", "DEI2"]):
    axes[0].plot(x, y_zhiwei[i, :], ls="", marker="x")
    axes[1].plot(x, y_konstantin[i, :], ls="", marker="x")
    axes[2].plot(x, y_real[i, :], ls="", marker="x", label=label)

axes[0].set_title("Zhiwei model", fontsize=fontsize*1.3)
axes[1].set_title("Konstantin model", fontsize=fontsize*1.3)
axes[2].set_title("Real data (averaged over 20 repeats)", fontsize=fontsize*1.3)
axes[2].set_xlabel("neurons", fontsize=fontsize)
axes[2].set_ylabel(r"$\frac{resp}{resp(MEI)}$", fontsize=fontsize)

# Optional: clip the model panel.
# axes[1].set(ylim=[0, 6])

axes[2].legend(bbox_to_anchor=(0.15, 1., 0, 0), frameon=False, fontsize=fontsize*.8)
sns.despine(trim=True)
# fig.savefig("Zhiwei_Model_Comparison" + ".png", bbox_inches="tight", transparent=False)

In [None]:
from scipy.stats import spearmanr, pearsonr

In [None]:
def _spearman_per_neuron(reference, prediction):
    """Return the Spearman correlation of each column pair (p-values are discarded)."""
    return [
        spearmanr(reference[:, col], prediction[:, col], axis=0)[0]
        for col in range(reference.shape[-1])
    ]


# Rank agreement across the images of each neuron, measured data vs. each model.
c_ks = _spearman_per_neuron(y_real, y_konstantin)
c_zs = _spearman_per_neuron(y_real, y_zhiwei)

In [None]:
# Mean per-neuron Spearman correlation: (Konstantin model, Zhiwei model).
tuple(np.mean(per_neuron) for per_neuron in (c_ks, c_zs))

In [None]:
# Columns whose first-row measured response is not NaN, i.e. neurons that
# were actually filled in; the mask is computed once and reused.
valid_neurons = ~np.isnan(real_resp_means[0, :])
real_valid = real_resp_means[:, valid_neurons]

# Mean of the per-neuron correlations (corr along axis 0) for each model.
co_k = np.mean(corr(real_valid, means[:, valid_neurons], axis=0))
co_z = np.mean(corr(real_valid, e, axis=0))

In [None]:
# Mean correlation with the measured responses: (Konstantin model, Zhiwei model).
(co_k, co_z)