### Load Model

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as tf
import sys
sys.path.append('../code')
from dorsalnet import DorsalNet, FC, interpolate_frames
from VWAM.utils import SingleImageFolder, iterate_children, hook_model

# Device and dtype shared by the backbone, the FC read-out and all inputs.
DEVICE = 'cuda:0'
DTYPE = torch.bfloat16

# Build the backbone in eval mode, then move it to the target device/dtype.
model = DorsalNet(False, 32).eval().to(DEVICE).to(DTYPE)
# map_location remaps the checkpoint's storages onto DEVICE, so a file that was
# saved from a different GPU index (or on a machine with more GPUs) still loads.
model.load_state_dict(
    torch.load(
        '/home/matthew/Data/DorsalNet_FC/base_models/DorsalNet/pretrained.pth',
        map_location=DEVICE,
    )
)

### Choose downsampling

In [None]:
# Upper bound handed to choose_downsampling: a pooled layer output is only
# accepted when it holds at most this many elements.
MAX_FS = 1500
# Depth passed to iterate_children when collecting the layers to hook.
DEPTH = 1
# Shape of the random probe input fed to the model; index 2 is also the frame
# count passed to interpolate_frames.
# presumably (batch, channels, frames, height, width) — TODO confirm
input_size = (1, 3, 32, 112, 112)

import torch
import numpy as np
from tqdm.notebook import tqdm

def choose_downsampling(activations, max_fs):
    """Choose a max-pooling layer that shrinks `activations` to at most `max_fs` elements.

    Every kernel size ``k`` in ``1..W`` and stride ``s`` in ``1..k`` is tried,
    where ``W`` is the size of the last dimension. A candidate is considered
    only when ``(W - k)`` is divisible by ``s``; it is accepted when the pooled
    tensor keeps more than one element along the last dimension and has no more
    than `max_fs` elements in total. Among accepted candidates the one with the
    largest pooled element count wins (ties go to the smallest ``k``, then the
    smallest ``s``).

    Args:
        activations: Tensor of layer activations. Only 5-D tensors are pooled;
            only the first entry along dim 0 is used for sizing.
            NOTE(review): the search range and the divisibility test look at
            the last dimension only, so square spatial maps are presumably
            expected — confirm.
        max_fs: Maximum number of elements allowed in the pooled output.

    Returns:
        A ``torch.nn.MaxPool3d`` with kernel ``(2, k, k)`` and stride ``s``
        (an int stride applies to all three pooled dimensions), or ``None``
        when `activations` is not 5-D or no candidate was accepted.
    """
    if activations.ndim != 5:
        return None

    # A single sample is enough to measure pooled sizes.
    activations = activations[0:1]
    test_range = activations.shape[-1]
    # numels[k, s] stores the pooled element count of an accepted (k, s);
    # zero means "rejected or not tested". One extra column keeps s == test_range
    # a valid index.
    numels = np.zeros((test_range + 1, test_range + 1))
    # The context manager closes the progress bar even if pooling raises.
    with tqdm(range(sum(range(test_range + 1)))) as pbar:
        for k in range(1, test_range + 1):
            for s in range(1, k + 1):
                pbar.update(1)
                pbar.set_postfix_str(f"testing size {k}, stride {s}")
                # Skip strides that do not tile the last dimension exactly.
                if (test_range - k) % s != 0:
                    continue
                pooled = torch.nn.functional.max_pool3d(activations, kernel_size=(2, k, k), stride=s)
                if pooled.shape[-1] > 1 and pooled.numel() <= max_fs:
                    numels[k, s] = pooled.numel()

    best_k, best_s = np.unravel_index(np.argmax(numels, axis=None), numels.shape)
    # Plain Python ints keep NumPy scalars out of the module's hyper-parameters.
    best_k, best_s = int(best_k), int(best_s)
    if (best_k, best_s) == (0, 0):
        # argmax of an all-zero table: nothing was accepted.
        return None
    return torch.nn.MaxPool3d(kernel_size=(2, best_k, best_k), stride=best_s)

# Hook the selected layers so a forward pass records them in model.activations.
layers_dict = iterate_children(model, depth=DEPTH)
model = hook_model(model, layers_dict)
# Probe pass: only the recorded activation shapes are needed, so skip autograd
# bookkeeping instead of building a graph that is never used.
with torch.no_grad():
    model(torch.randn(input_size).to(DEVICE).to(DTYPE))

# Map each hooked layer name to its pooling module (or None for "use as is").
layer_downsampling_fns = {}
for layer_name, layer_activations in model.activations.items():
    print('**************')
    print(layer_name)
    print('old_shape:', layer_activations.flatten().shape)
    layer_downsampling_fn = choose_downsampling(layer_activations, MAX_FS)
    layer_downsampling_fns[layer_name] = layer_downsampling_fn
    if layer_downsampling_fn is not None:
        layer_activations = layer_downsampling_fn(layer_activations)
    print('new_shape:', layer_activations.flatten().shape)

### Initialize FC Layer

In [None]:
subject_id = 'S00'

# Training responses: NaNs are replaced by np.nan_to_num before the array is
# moved to DEVICE. The FC output width is taken from its second axis.
trn_brain = torch.tensor(
    np.nan_to_num(
        np.load(f'/home/matthew/Data/DorsalNet_FC/fMRI_data/{subject_id}/NaturalMovies/trn.npy')
    ),
    device=DEVICE,
)
n_voxels = trn_brain.shape[1]

# Validation responses: NaNs are replaced first, then axis 0 is averaged away
# (presumably the repeat axis — TODO confirm).
# NOTE(review): because nan_to_num runs before the mean, NaN samples enter the
# average as zeros — confirm this is intended.
val_brain = torch.tensor(
    np.nan_to_num(
        np.load(f'/home/matthew/Data/DorsalNet_FC/fMRI_data/{subject_id}/NaturalMovies/val_rpts.npy')
    ).mean(0),
    device=DEVICE,
)

# Read-out layer sized by n_voxels, on the same device/dtype as the backbone.
fc = FC(n_voxels)
fc = fc.to(DEVICE).to(DTYPE)
print(fc)

### Train

In [None]:
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.optim import Adam

def column_corr(A, B, dof=0):
    """Compute the correlation between matching columns of two matrices.

    Does NOT compute the full correlation matrix between `A` and `B`; returns
    one coefficient per column. FKA ccMatrix.

    Args:
        A: 2-D array, rows are samples and columns are variables.
        B: Array with the same shape as `A`.
        dof: Delta degrees of freedom forwarded to ``np.nanstd`` as ``ddof``.

    Returns:
        1-D array with one correlation coefficient per column. Rows where
        either z-scored value is NaN are left out of both the sum and the
        sample count. Each matrix is z-scored over its own non-NaN rows, so
        with missing data a coefficient can fall slightly outside [-1, 1].
        Columns with zero variance yield NaN.
    """
    def _zscore(x):
        return (x - np.nanmean(x, axis=0)) / np.nanstd(x, axis=0, ddof=dof)

    # Z-score each input once; both the product and the NaN mask reuse them.
    z_a = _zscore(A)
    z_b = _zscore(B)
    summed = np.nansum(z_a * z_b, axis=0)
    # Make sure not to count rows that contributed a NaN.
    n_nan = np.sum(np.logical_or(np.isnan(z_a), np.isnan(z_b)), 0)
    n_valid = A.shape[0] - n_nan
    return summed / n_valid

# Stimulus set used for this run; selects both the image folders and batch size.
EXPERIMENT = 'NaturalMovies'

# DataLoader batch size for each known experiment.
batch_sizes = dict(NaturalMovies=30, vedb_ver01=50)

# Resize to 112, then convert the PIL image to a tensor.
preprocess = tf.Compose([tf.Resize(112), tf.ToTensor()])


def _ordered_loader(image_dir):
    """Return an unshuffled DataLoader over the images found in `image_dir`."""
    dataset = SingleImageFolder(image_dir, transform=preprocess)
    return DataLoader(dataset, batch_size=batch_sizes[EXPERIMENT], shuffle=False)


trn_dl = _ordered_loader(f'/home/matthew/Data/DorsalNet_FC/stimuli/{EXPERIMENT}/images/trn')
val_dl = _ordered_loader(f'/home/matthew/Data/DorsalNet_FC/stimuli/{EXPERIMENT}/images/val')

In [None]:
import wandb
wandb.login()

N_EPOCHS = 50
LR_INIT = 0.1

run = wandb.init(
    # Set the project where this run will be logged
    project="DorsalNet_FC_Pilot",
    # Track hyperparameters and run metadata
    config={
        "learning_rate": LR_INIT,
        "epochs": N_EPOCHS,
        "experiment": EXPERIMENT,
    },
)


def _extract_features(batch):
    """Run one stimulus batch through the hooked backbone and build the FC input.

    The batch is resampled to ``input_size[2]`` frames, passed through `model`
    as a single clip, and every hooked layer's activations are optionally
    pooled, averaged over dim 0 and flattened. The per-layer vectors are
    concatenated into one row.

    Args:
        batch: One batch yielded by the stimulus DataLoader.

    Returns:
        Tensor of shape ``(1, n_features)`` ready to be fed to `fc`.
    """
    batch = interpolate_frames(batch, input_size[2])
    # Only `fc` is handed to the optimizer, so the backbone needs no autograd
    # graph; skipping it avoids back-propagating through the frozen network.
    with torch.no_grad():
        model.forward(batch.unsqueeze(0).to(DTYPE).to(DEVICE))
    features = []
    for layer_name, layer_activations in model.activations.items():
        layer_downsampling_fn = layer_downsampling_fns[layer_name]
        if layer_downsampling_fn is not None:
            layer_activations = layer_downsampling_fn(layer_activations)
        features.append(layer_activations.mean(0).flatten())
        # Drop the stored tensor so it can be freed before the next pass.
        model.activations[layer_name] = 0
    return torch.cat(features).unsqueeze(0)


torch.cuda.empty_cache()
optimizer = Adam(fc.parameters(), lr=LR_INIT)
last_trn_idx = len(trn_brain) - 1
for epoch in range(N_EPOCHS):
    ### Train
    pbar = tqdm(enumerate(trn_dl), total=len(trn_brain), desc=f"Epoch {epoch} Training")
    # Running sum/count instead of rebuilding a tensor from the full list each step.
    loss_total = 0.0
    n_steps = 0
    for i, batch in pbar:
        optimizer.zero_grad()
        fc_out = fc(_extract_features(batch))
        # Target is the mean of the responses 2 and 3 samples after stimulus i,
        # clamped to the last available sample at the end of the run.
        batch_brain = (trn_brain[min(i+2, last_trn_idx)] + trn_brain[min(i+3, last_trn_idx)]) / 2
        loss = torch.square(fc_out[0]/1000 - batch_brain).sum().sqrt()
        loss.backward()
        optimizer.step()
        loss_total += loss.item()
        n_steps += 1
        pbar.set_postfix_str(f"Mean Epoch Loss: {loss_total / n_steps:.2f}")
    mean_trn_loss = loss_total / n_steps if n_steps else float('nan')
    ### Evaluate
    with torch.no_grad():
        pbar = tqdm(enumerate(val_dl), total=len(val_brain), desc=f"Epoch {epoch} Evaluation")
        val_outputs = []
        for i, batch in pbar:
            fc_out = fc(_extract_features(batch))
            val_outputs.append(fc_out.cpu().float().numpy())
        # NOTE(review): training targets are lagged by 2-3 samples, but these
        # predictions are compared with val_brain at the same index — confirm
        # that the validation data needs no matching lag.
        ccs = column_corr(np.concatenate(val_outputs), val_brain.cpu().numpy())
        # column_corr returns NaN for zero-variance columns; nanmean keeps a few
        # such columns from turning the whole summary into NaN.
        mean_cc = np.nanmean(ccs)
        print(f"Mean Prediction Accuracy: {mean_cc:.2f}")
    wandb.log({
        "epoch": epoch,
        "trn_loss": mean_trn_loss,
        "val_acc": mean_cc,
    })