# Improved ResNet

In [3]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, explained_variance_score
from scipy.stats import pearsonr
from tqdm.auto import tqdm
import torch.nn.functional as F

from utils import load_it_data, visualize_img
import matplotlib.pyplot as plt
import numpy as np
import gdown

from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error,explained_variance_score
from sklearn.decomposition import PCA

import torch
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from utils import load_it_data,visualize_img

path_to_data = '' ## Insert the folder where the data is, if you download in the same folder as this notebook then leave it blank

# stim: picture / object: object name + idx / spikes: rate of each spike per stim

stimulus_train, stimulus_val, stimulus_test, objects_train, objects_val, objects_test, spikes_train, spikes_val = load_it_data(path_to_data)


In [4]:
seed = 35
torch.manual_seed(seed)
np.random.seed(seed)

stimulus_train = torch.from_numpy(stimulus_train)
stimulus_val = torch.from_numpy(stimulus_val)

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
hooked_layer_names = ['layer3']

display(device)

resnet50 = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device)
resnet50_random = models.resnet50(weights=None).to(device)
    
resnet50.eval() # put in inference mode. e.g. disable dropout and batch normaliation
resnet50_random.eval();

# PRE_TRAINED RESNET

module_to_layername = {id(module): name for name, module in resnet50.named_modules()}
train_pca_models = {}

def record_hook(module, input, output:torch.Tensor): # called once per entire batch
    output_flattened = output.view(output.shape[0],-1).detach().to('cpu')
    layer_name = module_to_layername[id(module)]
    if train_mode:
        pca_model = PCA(n_components=0.7).fit(output_flattened)
        train_pca_models[layer_name] = pca_model
    else:
        pca_model = train_pca_models[layer_name]
    activations_pca = pca_model.transform(output_flattened)
    activations[layer_name] = activations_pca

def shape_hook(module, input, output:torch.Tensor): 
    output_flattened = output.view(output.shape[0],-1)
    layer_name = module_to_layername[id(module)]
    print(f"Output shape of {layer_name}:\t {output_flattened.shape}")

for layer_name in hooked_layer_names:
    layer = dict(resnet50.named_children())[layer_name]
    layer.register_forward_hook(record_hook)
    layer.register_forward_hook(shape_hook)

# TRAIN
train_mode = True # to guide the hook for pre-fit PCA or not
print('TRAIN')
activations = {}
with torch.no_grad(): # don't compute gradiants for speeding things up
    resnet50(stimulus_train.to(device))

pca_activations_train = list(activations.values())
# np.savez('pca_activations_train.npz',conv1=pca_activations_train[0],layer1=pca_activations_train[1],layer2=pca_activations_train[2],layer3=pca_activations_train[3],layer4=pca_activations_train[4],avg_pool=pca_activations_train[5])
# np.savez('pca_activations_train_80p.npz', layer3=pca_activations_train[0])
np.savez('pca_activations_train_80p.npz', layer3=pca_activations_train[0])

# VALIDATION
train_mode = False # use prefit PCA in train
print('\nVALIDATION')
activations = {}
with torch.no_grad(): # don't compute gradiants for speeding things up
    resnet50(stimulus_val.to(device))

pca_activations_val = list(activations.values())
# np.savez('pca_activations_val_80p.npz',conv1=pca_activations_val[0],layer1=pca_activations_val[1],layer2=pca_activations_val[2],layer3=pca_activations_val[3],layer4=pca_activations_val[4],avg_pool=pca_activations_val[5])
np.savez('pca_activations_val_80p.npz', layer3=pca_activations_val[0])

# save PCA model
np.savez('pca_model_80p.npz',layer3=train_pca_models['layer3'])

device(type='cuda', index=1)

TRAIN
Output shape of layer3:	 torch.Size([2592, 200704])

VALIDATION
Output shape of layer3:	 torch.Size([288, 200704])


In [None]:
seed = 35
torch.manual_seed(seed)
np.random.seed(seed)

# # 1000 PCA components from layer3 only load
# pca_activations_train = np.load('pca_activations_train.npz')
# pca_activations_val   = np.load('pca_activations_val.npz')

# X_train = pca_activations_train['layer3']
# X_val   = pca_activations_val['layer3']
# y_train = spikes_train
# y_val   = spikes_val

# 80% PCA from layer3 only load
pca_activations_train = np.load('pca_activations_train_80p.npz')
pca_activations_val   = np.load('pca_activations_val_80p.npz')

X_train = pca_activations_train['layer3']
X_val   = pca_activations_val['layer3']
y_train = spikes_train
y_val   = spikes_val

# # 95% PCA from layer3 only load
# pca_activations_train = np.load('pca_activations_train_95p.npz')
# pca_activations_val   = np.load('pca_activations_val_95p.npz')

# X_train = pca_activations_train['layer3']
# X_val   = pca_activations_val['layer3']
# y_train = spikes_train
# y_val   = spikes_val

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val   = scaler.transform(X_val)

X_train_t = torch.tensor(X_train, dtype=torch.float32).to(device)
y_train_t = torch.tensor(y_train, dtype=torch.float32).to(device)
X_val_t   = torch.tensor(X_val,   dtype=torch.float32).to(device)
y_val_t   = torch.tensor(y_val,   dtype=torch.float32).to(device)

batch_size = 64
train_loader = DataLoader(TensorDataset(X_train_t, y_train_t),
                          batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(TensorDataset(X_val_t, y_val_t),
                          batch_size=batch_size, shuffle=False)

class MLPRegressor(nn.Module):
    def __init__(self,
                 in_dim:  int,
                 out_dim: int,
                 hidden=(512, 256),
                 dropout=0.2):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(dropout)]
            last = h
        layers.append(nn.Linear(last, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
    
class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = self.drop(F.relu(self.fc2(h)))
        return x + h

        
def make_mlp(style: str, d_in: int, d_out: int):
    if style == "funnel":
        hidden = (4*d_in, 2*d_in, d_in)
    elif style == "hourglass":
        hidden = (d_in//2, d_in, d_in//2)
    elif style == "residual":
        first = 2*d_in
        blocks = 3
        return nn.Sequential(
            nn.Linear(d_in, first), nn.ReLU(),
            *[ResidualBlock(first) for _ in range(blocks)],
            nn.Linear(first, d_out)
        )
    layers = []
    last = d_in
    for h in hidden:
        layers += [nn.Linear(last, h), nn.ReLU(), nn.Dropout(0.3)]
        last = h
    layers.append(nn.Linear(last, d_out))
    return nn.Sequential(*layers)

model = MLPRegressor(in_dim=X_train.shape[1],
                     out_dim=y_train.shape[1]).to(device)

model = make_mlp("funnel", X_train.shape[1], y_train.shape[1]).to(device)

criterion = nn.MSELoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=1e+1, weight_decay=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=9.3e-5, weight_decay=9e-5, amsgrad=True, eps=1e-07)

# training
best_val_mse = float('inf')
patience      = 200
patience_ctr  = 0
num_epochs    = 500

for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        pred = model(xb)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * xb.size(0)
    train_mse = running_loss / len(train_loader.dataset)

    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_loader:
            pred = model(xb)
            val_running_loss += criterion(pred, yb).item() * xb.size(0)
    val_mse = val_running_loss / len(val_loader.dataset)

    print(f"Epoch {epoch:03d}: train MSE={train_mse:.5f} | val MSE={val_mse:.5f} | LR={optimizer.param_groups[0]['lr']:.5e}")

    if epoch > 200:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 5e-5
    if epoch > 300:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 1e-5
    if epoch > 400:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 5e-6

    if val_mse < best_val_mse - 1e-9:
        best_val_mse = val_mse
        torch.save(model.state_dict(), "best_mlp.pt")
        patience_ctr = 0
    else:
        patience_ctr += 1
        if patience_ctr >= patience:
            print("Early stopping triggered.")
            break

# eval and save
model.load_state_dict(torch.load("best_mlp.pt"))
model.eval()
with torch.no_grad():
    y_pred_val = model(X_val_t).cpu().numpy()

corr = pearsonr(y_val.flatten(), y_pred_val.flatten())[0]
ev   = explained_variance_score(y_val, y_pred_val, multioutput='raw_values')
mse  = mean_squared_error(y_val, y_pred_val, multioutput='raw_values')

print("\n=== Validation metrics ===")
print(f"Pearson r (all spikes flattened): {corr:.4f}")
print(f"Explained variance (mean over neurons): {ev.mean():.4f}")
print(f"MSE (mean over neurons): {mse.mean():.4f}")


Epoch 001: train MSE=0.13722 | val MSE=0.14053 | LR=9.30000e-05
Epoch 002: train MSE=0.13465 | val MSE=0.13856 | LR=9.30000e-05
Epoch 003: train MSE=0.12555 | val MSE=0.11646 | LR=9.30000e-05
Epoch 004: train MSE=0.09541 | val MSE=0.10478 | LR=9.30000e-05
Epoch 005: train MSE=0.08262 | val MSE=0.09796 | LR=9.30000e-05
Epoch 006: train MSE=0.06967 | val MSE=0.09234 | LR=9.30000e-05
Epoch 007: train MSE=0.06070 | val MSE=0.08983 | LR=9.30000e-05
Epoch 008: train MSE=0.05370 | val MSE=0.08733 | LR=9.30000e-05
Epoch 009: train MSE=0.04858 | val MSE=0.08537 | LR=9.30000e-05
Epoch 010: train MSE=0.04447 | val MSE=0.08399 | LR=9.30000e-05
Epoch 011: train MSE=0.04137 | val MSE=0.08331 | LR=9.30000e-05
Epoch 012: train MSE=0.03856 | val MSE=0.08279 | LR=9.30000e-05
Epoch 013: train MSE=0.03628 | val MSE=0.08216 | LR=9.30000e-05
Epoch 014: train MSE=0.03447 | val MSE=0.08230 | LR=9.30000e-05
Epoch 015: train MSE=0.03301 | val MSE=0.08108 | LR=9.30000e-05
Epoch 016: train MSE=0.03167 | val MSE=0

In [15]:
# load best model
model.load_state_dict(torch.load("best_mlp.pt"))
model.eval()
with torch.no_grad():
    y_pred_val = model(X_val_t).cpu().numpy()

corr = pearsonr(y_val.flatten(), y_pred_val.flatten())[0]
ev   = explained_variance_score(y_val, y_pred_val, multioutput='raw_values')
mse  = mean_squared_error(y_val, y_pred_val, multioutput='raw_values')

print("\n=== Validation metrics ===")
print(f"Pearson r (all spikes flattened): {corr:.4f}")
print(f"Explained variance (mean over neurons): {ev.mean():.4f}")
print(f"MSE (mean over neurons): {mse.mean():.4f}")
print(f"Model size: {sum(p.numel() for p in model.parameters())/1e6:.2f} M")


=== Validation metrics ===
Pearson r (all spikes flattened): 0.7081
Explained variance (mean over neurons): 0.4281
MSE (mean over neurons): 0.0710
Model size: 60.12 M
