### Load Model

In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as tf
import sys
from dorsalnet import DorsalNet, FC, interpolate_frames
sys.path.append('../code')
from VWAM.utils import SingleImageFolder, iterate_children, hook_model
from tqdm import tqdm
import os

# Run-wide configuration: compute device, compute dtype, backbone name and
# optimizer name. MODEL_NAME and OPTIMIZER are also used further down for the
# wandb config and the checkpoint file name.
DEVICE = 'cuda:0'
DTYPE = torch.bfloat16
MODEL_NAME = 'DorsalNet'
OPTIMIZER = 'Adam'

# Build the backbone selected by MODEL_NAME in eval mode, on DEVICE, cast to DTYPE.
if MODEL_NAME == 'DorsalNet':
    model = DorsalNet(False, 32).eval().to(DEVICE).to(DTYPE)
    # map_location keeps the checkpoint loadable even when it was saved from a
    # different device than the one used here.
    model.load_state_dict(
        torch.load(
            '/home/matthew/Data/DorsalNet_FC/base_models/DorsalNet/pretrained.pth',
            map_location=DEVICE,
        )
    )
elif MODEL_NAME in ('alexnet', 'inception_v3'):
    # Both torchvision backbones are fetched the same way; only the hub entry name differs.
    model = torch.hub.load('pytorch/vision:v0.6.0', MODEL_NAME, pretrained=True).eval().to(DEVICE).to(DTYPE)
else:
    # Fail here rather than with a NameError on `model` in a later cell.
    raise ValueError(
        f"Unsupported MODEL_NAME {MODEL_NAME!r}; expected 'DorsalNet', 'alexnet' or 'inception_v3'"
    )

  warn(f"Failed to load image Python extension: {e}")


<All keys matched successfully>

### Choose downsampling

In [2]:
# Upper bound on the number of activation values kept per layer after pooling
# (passed to choose_downsampling as `max_fs`).
MAX_FS = 5000
# `depth` argument handed to iterate_children when enumerating layers to hook.
DEPTH = 1
# Shape of the random tensor used for the probe forward pass; element 2 (32) is
# also the frame count passed to interpolate_frames during training/validation.
# presumably (batch, channels, frames, height, width) — TODO confirm against DorsalNet
input_size = (1, 3, 32, 112, 112)

import torch
import numpy as np
# from tqdm.notebook import tqdm

def choose_downsampling(activations, max_fs):
    """Pick an adaptive max-pooling layer that caps a layer's feature count.

    The pooled output has the same size ``d`` along every pooled dimension,
    where ``d`` is the largest integer with ``num_channels * d**k <= max_fs``
    (``k`` = 2 for 4-D activations, 3 for 5-D activations, and
    ``num_channels = activations.shape[1]``).

    Parameters
    ----------
    activations : tensor-like exposing ``ndim`` and ``shape``
        Example activations of the layer; only ``ndim`` and ``shape[1]`` are read.
    max_fs : int or float
        Maximum number of values to keep for this layer.

    Returns
    -------
    torch.nn.AdaptiveMaxPool2d, torch.nn.AdaptiveMaxPool3d or None
        ``None`` when the activations are neither 4-D nor 5-D, in which case
        callers leave the activations untouched.

    Notes
    -----
    ``d`` is never smaller than 1, so a layer with more channels than
    ``max_fs`` is pooled to a single value per channel and exceeds the budget
    instead of producing an empty output.
    """
    n_pooled_dims = activations.ndim - 2
    if n_pooled_dims == 2:
        pool_cls = torch.nn.AdaptiveMaxPool2d
    elif n_pooled_dims == 3:
        pool_cls = torch.nn.AdaptiveMaxPool3d
    else:
        return None

    num_channels = activations.shape[1]
    budget = max_fs / num_channels  # values available per channel
    if budget < 1:
        return pool_cls(1)

    # A bare int(budget ** (1/k)) truncates exact powers because of float
    # rounding (125 ** (1/3) == 4.999...), so start from the rounded root and
    # correct with exact integer comparisons.
    side = max(int(round(budget ** (1 / n_pooled_dims))), 1)
    while side > 1 and side ** n_pooled_dims > budget:
        side -= 1
    while (side + 1) ** n_pooled_dims <= budget:
        side += 1
    return pool_cls(side)

# Enumerate the model's children down to DEPTH and drop every entry whose name
# mentions dropout or concat before registering hooks.
layers_dict = iterate_children(model, depth=DEPTH)
layers_dict = {
    layer_name: layer
    for layer_name, layer in layers_dict.items()
    if 'dropout' not in layer_name and 'concat' not in layer_name
}
model = hook_model(model, layers_dict)

# Probe pass on random input: it only exists to populate model.activations so
# that a pooling function can be chosen per layer, so no autograd graph is needed.
with torch.no_grad():
    model(torch.randn(input_size).to(DEVICE).to(DTYPE))

    # layer name -> pooling module (or None when the layer is left as is)
    layer_downsampling_fns = {}
    for layer_name, layer_activations in model.activations.items():
        print('**************')
        print(layer_name)
        print('old_shape:', layer_activations.shape)
        print('old # activations:', layer_activations.flatten().shape)
        layer_downsampling_fn = choose_downsampling(layer_activations, MAX_FS)
        layer_downsampling_fns[layer_name] = layer_downsampling_fn
        if layer_downsampling_fn is not None:
            layer_activations = layer_downsampling_fn(layer_activations)
        print('new_shape:', layer_activations.shape)
        print('new # activations:', layer_activations.flatten().shape)

**************
model.conv1
old_shape: torch.Size([1, 64, 32, 56, 56])
old # activations: torch.Size([6422528])
new_shape: torch.Size([1, 64, 4, 4, 4])
new # activations: torch.Size([4096])
**************
model.s1
old_shape: torch.Size([1, 64, 32, 28, 28])
old # activations: torch.Size([1605632])
new_shape: torch.Size([1, 64, 4, 4, 4])
new # activations: torch.Size([4096])
**************
model.res0
old_shape: torch.Size([1, 32, 32, 28, 28])
old # activations: torch.Size([802816])
new_shape: torch.Size([1, 32, 5, 5, 5])
new # activations: torch.Size([4000])
**************
model.res1
old_shape: torch.Size([1, 32, 32, 28, 28])
old # activations: torch.Size([802816])
new_shape: torch.Size([1, 32, 5, 5, 5])
new # activations: torch.Size([4000])
**************
model.res2
old_shape: torch.Size([1, 32, 32, 28, 28])
old # activations: torch.Size([802816])
new_shape: torch.Size([1, 32, 5, 5, 5])
new # activations: torch.Size([4000])
**************
model.res3
old_shape: torch.Size([1, 32, 32, 28, 

### Initialize FC Layer

In [3]:
# Subject whose responses are fitted; also used in the save path and wandb config.
SUBJECT_ID = 'S00'

# Training responses: NaNs are replaced via np.nan_to_num, then moved to DEVICE.
trn_brain = torch.tensor(
    np.nan_to_num(
        np.load(f'/home/matthew/Data/DorsalNet_FC/fMRI_data/{SUBJECT_ID}/NaturalMovies/trn.npy')
    ),
    device=DEVICE,
)
# Output width of the readout = size of axis 1 of the training responses.
n_voxels = trn_brain.shape[1]

# Validation responses: NaNs replaced first, then averaged over axis 0
# (presumably the repeats axis, given the `val_rpts` file name — TODO confirm).
val_brain = torch.tensor(
    np.nan_to_num(
        np.load(f'/home/matthew/Data/DorsalNet_FC/fMRI_data/{SUBJECT_ID}/NaturalMovies/val_rpts.npy')
    ).mean(0),
    device=DEVICE,
)

# Readout that maps concatenated layer activations to n_voxels outputs.
fc = FC(n_voxels).to(DEVICE).to(DTYPE)
print(fc)

FC(
  (linear): LazyLinear(in_features=0, out_features=9853, bias=True)
)




### Train

In [4]:
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.optim import Adam

def column_corr(A, B, dof=0):
    """Correlate matching columns of two matrices.

    Does NOT compute the full correlation matrix between `A` and `B`; returns
    one correlation coefficient per column pair. FKA ccMatrix.

    Parameters
    ----------
    A, B : array-like
        Arrays whose columns are compared pairwise (column j of `A` against
        column j of `B`).
    dof : int, optional
        Passed as ``ddof`` to ``np.nanstd`` when z-scoring each column.

    Returns
    -------
    numpy.ndarray
        Sum over rows of the product of the z-scored columns, divided by the
        number of rows in which neither z-score is NaN.
    """
    def _zscore(x):
        return (x - np.nanmean(x, axis=0)) / np.nanstd(x, axis=0, ddof=dof)

    # Z-score each input once and reuse it for both the product and the NaN count.
    z_a = _zscore(A)
    z_b = _zscore(B)

    summed = np.nansum(z_a * z_b, axis=0)
    # Rows where either z-score is NaN do not contribute to the sum, so they
    # must not be counted in the denominator either.
    n_nan = np.sum(np.logical_or(np.isnan(z_a), np.isnan(z_b)), 0)
    n_valid = A.shape[0] - n_nan
    return summed / n_valid


# Stimulus set to train on; selects both the image folders and the batch size.
EXPERIMENT = 'NaturalMovies'

# Number of images per DataLoader batch for each stimulus set.
batch_sizes = dict(NaturalMovies=30, vedb_ver01=50)

# Applied to every image on load: resize, then convert to a tensor.
preprocess = tf.Compose([tf.Resize(112), tf.ToTensor()])

# Applied to training batches only (see train()): padded random crop,
# random rotation, and a second padded random crop.
image_augmentations = tf.Compose([
    tf.RandomCrop(112, padding=4),
    tf.RandomRotation(10),
    tf.RandomCrop(112, padding=3),
])

# Loaders are not shuffled: train()/validate() pair batch index i with rows of
# the response matrices, so the order of batches matters.
trn_dl = DataLoader(
    SingleImageFolder(
        f'/home/matthew/Data/DorsalNet_FC/stimuli/{EXPERIMENT}/images/trn',
        transform=preprocess,
    ),
    batch_size=batch_sizes[EXPERIMENT],
    shuffle=False,
)

val_dl = DataLoader(
    SingleImageFolder(
        f'/home/matthew/Data/DorsalNet_FC/stimuli/{EXPERIMENT}/images/val',
        transform=preprocess,
    ),
    batch_size=batch_sizes[EXPERIMENT],
    shuffle=False,
)

In [7]:
import wandb
wandb.login()

# Training schedule: number of epochs and the optimizer's learning rate.
N_EPOCHS = 50
LR_INIT = 1e-1

run = wandb.init(
    # Set the project where this run will be logged
    project="DorsalNet_FC_Pilot",
    # Track hyperparameters and run metadata
    config={
        "model_name": MODEL_NAME,
        "experiment": EXPERIMENT,
        "subject_id": SUBJECT_ID,
        "optimizer": OPTIMIZER,
        "epochs": N_EPOCHS,
        "learning_rate": LR_INIT,
    },
)

torch.cuda.empty_cache()
# Only the FC readout's parameters are handed to the optimizer.
if OPTIMIZER == 'SGD':
    optimizer = torch.optim.SGD(fc.parameters(), lr=LR_INIT)
elif OPTIMIZER == 'Adam':
    optimizer = torch.optim.Adam(fc.parameters(), lr=LR_INIT)
else:
    # Fail here rather than with a NameError on `optimizer` inside train().
    raise ValueError(f"Unsupported OPTIMIZER {OPTIMIZER!r}; expected 'SGD' or 'Adam'")

def train():
    """Run one training pass over ``trn_dl`` and update the FC readout.

    For batch ``i`` the hooked backbone is run on the augmented clip, every
    hooked layer's activations are pooled with its entry in
    ``layer_downsampling_fns``, averaged over axis 0, flattened and
    concatenated, and ``fc`` maps that vector to one value per voxel. The
    target is the mean of rows ``i+2`` and ``i+3`` of ``trn_brain`` (both
    clamped to the last row).

    Reads the module-level ``epoch`` for the progress-bar label, so it must be
    called from a loop that defines it.

    Returns
    -------
    list of float
        The loss of every batch, in order.
    """
    pbar = tqdm(enumerate(trn_dl), total=len(trn_dl), desc=f"Epoch {epoch} Training")
    trn_epoch_losses = []
    loss_sum = 0.0  # running sum, so the displayed mean costs O(1) per step
    last_row = len(trn_brain) - 1
    for i, batch in pbar:
        optimizer.zero_grad()
        batch = interpolate_frames(batch, input_size[2])
        # The optimizer only holds fc's parameters, so the backbone and the
        # pooling run without autograd; gradients are needed from fc onwards only.
        with torch.no_grad():
            model.forward(image_augmentations(batch).unsqueeze(0).to(DTYPE).to(DEVICE))
            all_activations = []
            for layer_name, layer_activations in model.activations.items():
                layer_downsampling_fn = layer_downsampling_fns[layer_name]
                if layer_downsampling_fn is not None:
                    layer_activations = layer_downsampling_fn(layer_activations)
                all_activations.append(layer_activations.mean(0).flatten())
                # Replace the stored activation so the dict no longer references it.
                model.activations[layer_name] = 0
            features = torch.cat(all_activations).unsqueeze(0)
        fc_out = fc(features)
        # presumably the 2-3 row offset accounts for the delay of the measured
        # response relative to the stimulus — TODO confirm
        batch_brain = (trn_brain[min(i + 2, last_row)] + trn_brain[min(i + 3, last_row)]) / 2
        # NOTE(review): the fixed 1/1000 output scale is not explained in this notebook.
        loss = torch.square(fc_out[0] / 1000 - batch_brain).sum().sqrt()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        trn_epoch_losses.append(loss_value)
        loss_sum += loss_value
        pbar.set_postfix_str(f"Mean Epoch Loss: {loss_sum / len(trn_epoch_losses):.2f}")
    return trn_epoch_losses

def validate():
    """Evaluate the FC readout on ``val_dl`` without updating it.

    Features are extracted as in training but without augmentation. For batch
    ``i`` the target is the mean of rows ``i+2`` and ``i+3`` of ``val_brain``
    (both clamped to the last row). The per-voxel correlation is computed
    against those same delayed targets, so the reported accuracy measures the
    quantity the loss optimises.

    Reads the module-level ``epoch`` for the progress-bar label.

    Returns
    -------
    val_epoch_losses : list of float
        The loss of every batch, in order.
    ccs : numpy.ndarray
        One correlation per voxel between predictions and delayed targets.
    """
    with torch.no_grad():
        pbar = tqdm(enumerate(val_dl), total=len(val_dl), desc=f"Epoch {epoch} Validation")
        val_outputs = []
        val_targets = []
        val_epoch_losses = []
        loss_sum = 0.0  # running sum, so the displayed mean costs O(1) per step
        last_row = len(val_brain) - 1
        for i, batch in pbar:
            batch = interpolate_frames(batch, input_size[2])
            model.forward(batch.unsqueeze(0).to(DTYPE).to(DEVICE))
            all_activations = []
            for layer_name, layer_activations in model.activations.items():
                layer_downsampling_fn = layer_downsampling_fns[layer_name]
                if layer_downsampling_fn is not None:
                    layer_activations = layer_downsampling_fn(layer_activations)
                all_activations.append(layer_activations.mean(0).flatten())
                # Replace the stored activation so the dict no longer references it.
                model.activations[layer_name] = 0
            fc_out = fc(torch.cat(all_activations).unsqueeze(0))
            batch_brain = (val_brain[min(i + 2, last_row)] + val_brain[min(i + 3, last_row)]) / 2
            loss = torch.square(fc_out[0] / 1000 - batch_brain).sum().sqrt()
            # Correlation is scale-invariant, so the unscaled output is stored.
            val_outputs.append(fc_out.cpu().float().numpy())
            val_targets.append(batch_brain.cpu().numpy())
            loss_value = loss.item()
            val_epoch_losses.append(loss_value)
            loss_sum += loss_value
            pbar.set_postfix_str(f"Mean Epoch Loss: {loss_sum / len(val_epoch_losses):.2f}")
        # Compare each prediction with the target it was scored against in the
        # loss, not with the undelayed row of val_brain.
        ccs = column_corr(np.concatenate(val_outputs), np.stack(val_targets))
        print(f"Mean Prediction Accuracy: {ccs.mean():.2f}")
    return val_epoch_losses, ccs
        
def log(epoch, trn_epoch_losses, val_epoch_losses, ccs):
    """Send one epoch's summary metrics to Weights & Biases.

    Logs the epoch index, the mean of the training losses, the mean of the
    validation losses and the mean of ``ccs``.
    """
    def _mean(values):
        return torch.mean(torch.tensor(values)).item()

    metrics = {"epoch": epoch}
    metrics["trn_loss"] = _mean(trn_epoch_losses)
    metrics["val_loss"] = _mean(val_epoch_losses)
    metrics["val_acc"] = ccs.mean()
    wandb.log(metrics)

# Output directory for this experiment/subject pair; the training cell writes
# its checkpoint here. Created up front, and an existing directory is not an error.
save_dir = f'/home/matthew/Data/DorsalNet_FC/fits/{EXPERIMENT}/{SUBJECT_ID}'
os.makedirs(save_dir, exist_ok=True)



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668893599611087, max=1.0…

In [8]:
# One training pass followed by one validation pass per epoch; train() and
# validate() read the loop variable `epoch` for their progress-bar labels.
for epoch in range(N_EPOCHS):
    trn_epoch_losses = train()
    val_epoch_losses, ccs = validate()
    log(epoch, trn_epoch_losses, val_epoch_losses, ccs)
# Persist the FC readout: it is the only module whose parameters were handed to
# the optimizer, so it is the only thing this run actually fitted. The backbone
# is never optimised here, so saving it would not capture the result of training.
torch.save(fc.state_dict(), f"{save_dir}/{MODEL_NAME}_{OPTIMIZER}_{N_EPOCHS}_{LR_INIT}.pt")

Epoch 0 Training:  41%|████      | 1474/3600 [08:48<12:41,  2.79it/s, Mean Epoch Loss: 81.69]


KeyboardInterrupt: 