In [1]:
%cd ../code

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from dorsalnet import DorsalNet

DEVICE = 'cuda:0'
DTYPE = torch.float
# Machine-specific location of the pretrained DorsalNet weights.
CHECKPOINT_PATH = '/home/matthew/Data/DorsalNet_FC/base_models/DorsalNet/pretrained.pth'

# Positional args go straight to the project constructor; 32 presumably matches
# the 32-frame clips fed to the model later in this notebook -- TODO confirm.
model = DorsalNet(False, 32).eval().to(DEVICE).to(DTYPE)
# map_location puts the tensors on DEVICE regardless of the device they were
# saved from; weights_only avoids unpickling arbitrary objects (the file is fed
# directly to load_state_dict, so it is expected to be a plain state dict).
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True))

/home/matthew/Code/DorsalNet_FC/code


<All keys matched successfully>

In [2]:
import torch
import numpy as np
from tqdm.notebook import tqdm
from functools import partial
from collections import defaultdict

def iterate_children(child, parent_name='model', depth=1):
    """Collect ``(qualified_name, module)`` pairs exactly ``depth`` levels below ``child``.

    Args:
        child: ``torch.nn.Module`` whose descendants are listed.
        parent_name: Name prefix for ``child``; each level appends ``'.' + name``.
        depth: How many levels of ``named_children()`` to descend
            (``1`` lists the direct children only).

    Returns:
        A list of ``(name, module)`` tuples in registration order. Branches that
        run out of children before reaching ``depth`` contribute nothing.
    """
    if depth <= 1:
        # A list (not a set) keeps the result deterministic and ordered like the
        # module definition, and gives the same return type at every depth.
        return [(parent_name + '.' + name, module) for name, module in child.named_children()]
    children_list = []
    for name, grandchild in child.named_children():
        children_list.extend(iterate_children(grandchild, parent_name + '.' + name, depth - 1))
    return children_list

def store_activations(activations_dict, layer_name, module, input, output):
    """Forward-hook callback that records ``output`` under ``layer_name``.

    Meant to be bound as ``partial(store_activations, activations_dict, layer_name)``
    and passed to ``register_forward_hook``; ``module`` and ``input`` belong to the
    hook signature and are ignored. Each call replaces the previous entry for
    ``layer_name``, so only the most recent forward pass is kept. The tensor is
    stored as-is (not detached).
    """
    activations_dict.update({layer_name: output})

def hook_model(model, depth):
    """Attach forward hooks that record layer outputs into ``model.activations``.

    Every module returned by ``iterate_children(model, depth=depth)`` gets a hook
    storing its latest output under its qualified name (e.g. ``'model.conv1'``).

    Calling this again on the same model first removes the hooks registered by
    the previous call. Otherwise, re-running the notebook cell would stack
    duplicate hooks that keep firing and keep old activation dicts alive.

    Args:
        model: ``torch.nn.Module`` to instrument (modified in place).
        depth: Depth passed to ``iterate_children``.

    Returns:
        The same ``model``, with ``model.activations`` reset to an empty
        ``defaultdict(list)`` and the hook handles kept in
        ``model._activation_hook_handles``.
    """
    for handle in getattr(model, '_activation_hook_handles', []):
        handle.remove()
    model.activations = defaultdict(list)
    model._activation_hook_handles = [
        child.register_forward_hook(partial(store_activations, model.activations, layer_name))
        for layer_name, child in iterate_children(model, depth=depth)
    ]
    return model

def choose_downsampling(activations, max_fs, show_progress=True):
    """Pick a 3-D max-pool that shrinks a 5-D activation map to at most ``max_fs`` elements.

    Candidates are every square spatial kernel ``k`` (1..W, where W is the last
    dim) with stride ``s`` (1..k) that tiles the width exactly
    (``(W - k) % s == 0``). Each candidate is tried on the first batch element with
    ``max_pool3d(kernel_size=(2, k, k), stride=s)``. Among the candidates whose
    output keeps more than one column and has at most ``max_fs`` elements, the
    largest output wins. Ties go to the smallest ``k``, then the smallest ``s``.

    Note: the integer stride applies to all three pooled dims, including the one
    with kernel size 2; the returned module uses the same settings.

    Args:
        activations: Tensor; only 5-D inputs are handled -- presumably
            (batch, channels, time, height, width), TODO confirm. Only the last
            dim is used to enumerate kernels/strides.
        max_fs: Upper bound on the number of elements of one pooled sample.
        show_progress: Show a tqdm progress bar during the search.

    Returns:
        ``torch.nn.MaxPool3d`` with the chosen settings, or ``None`` if the input
        is not 5-D or no candidate satisfies the constraints.
    """
    if activations.ndim != 5:
        return None
    sample = activations[0:1]
    width = sample.shape[-1]
    # numels[k, s] = pooled size for kernel k / stride s (0 means "not valid").
    # Both axes are width + 1 long so every k, s in 1..width is a valid index.
    numels = np.zeros((width + 1, width + 1))
    pbar = tqdm(total=width * (width + 1) // 2) if show_progress else None
    try:
        # Only output sizes matter; don't build an autograd graph per candidate.
        with torch.no_grad():
            for k in range(1, width + 1):
                for s in range(1, k + 1):
                    if pbar is not None:
                        pbar.update(1)
                        pbar.set_postfix_str(f"testing size {k}, stride {s}")
                    if (width - k) % s != 0:
                        continue  # this kernel/stride does not tile the width exactly
                    pooled = torch.nn.functional.max_pool3d(sample, kernel_size=(2, k, k), stride=s)
                    if pooled.shape[-1] > 1 and pooled.numel() <= max_fs:
                        numels[k, s] = pooled.numel()
    finally:
        if pbar is not None:
            pbar.close()
    best_k, best_s = np.unravel_index(np.argmax(numels), numels.shape)
    if numels[best_k, best_s] == 0:
        return None
    return torch.nn.MaxPool3d(kernel_size=(2, int(best_k), int(best_k)), stride=int(best_s))

In [3]:
import torchvision

# Resample a stack of frames in width and time at once. Assumes the input is
# (T, C, H, W), i.e. a batch of frames -- TODO confirm against the caller.
#   1. Permute: frame axis last                      -> (C, H, W, T)
#   2. Resize acts on the last two axes: W->112, T->32 -> (C, H, 112, 32)
#   3. Permute: channels, then time, height, width    -> (C, 32, H, 112)
interpolate_frames = torchvision.transforms.Compose(
    [
        torchvision.ops.Permute([1, 2, 3, 0]),
        torchvision.transforms.Resize([112, 32]),
        torchvision.ops.Permute([0, 3, 1, 2]),
    ]
)

In [4]:
MAX_FS = 1500  # upper bound on the number of features kept per layer after pooling

model = hook_model(model, 1)
# One forward pass on a random clip just to record every hooked layer's output
# shape (input layout presumably (batch, channels, frames, height, width)).
# no_grad: only shapes are needed, so don't keep an autograd graph attached to
# each stored activation or to every pooling tried by choose_downsampling.
with torch.no_grad():
    model(torch.randn((1, 3, 32, 112, 112)).to(DEVICE).to(DTYPE))

layer_downsampling_fns = {}
for layer_name, layer_activations in model.activations.items():
    print('**************')
    print(layer_name)
    print('old_shape:', layer_activations.flatten().shape)
    downsample = choose_downsampling(layer_activations, MAX_FS)
    layer_downsampling_fns[layer_name] = downsample
    if downsample is not None:
        layer_activations = downsample(layer_activations)
    print('new_shape:', layer_activations.flatten().shape)

**************
model.conv1
old_shape: torch.Size([6422528])


  0%|          | 0/1596 [00:00<?, ?it/s]

new_shape: torch.Size([1280])
**************
model.s1
old_shape: torch.Size([1605632])


  0%|          | 0/406 [00:00<?, ?it/s]

new_shape: torch.Size([1280])
**************
model.res0
old_shape: torch.Size([802816])


  0%|          | 0/406 [00:00<?, ?it/s]

new_shape: torch.Size([1440])
**************
model.res1
old_shape: torch.Size([802816])


  0%|          | 0/406 [00:00<?, ?it/s]

new_shape: torch.Size([1440])
**************
model.res2
old_shape: torch.Size([802816])


  0%|          | 0/406 [00:00<?, ?it/s]

new_shape: torch.Size([1440])
**************
model.res3
old_shape: torch.Size([802816])


  0%|          | 0/406 [00:00<?, ?it/s]

new_shape: torch.Size([1440])
**************
model.concat
old_shape: torch.Size([2408448])


  0%|          | 0/406 [00:00<?, ?it/s]

new_shape: torch.Size([1152])
**************
model.dropout
old_shape: torch.Size([802816])


  0%|          | 0/406 [00:00<?, ?it/s]

new_shape: torch.Size([1440])


In [5]:
import torch
import dill as pickle
import os
import numpy as np
from torchvision import datasets, transforms
from torchvision.models import inception_v3 as model_init
from torchvision.models import Inception_V3_Weights as model_weights
from collections import defaultdict
from tqdm import tqdm

model_name = 'dorsalnet'
DTYPE = torch.float32
iter_mode = 'children'  # not referenced by any visible cell
iter_depth = 1          # not referenced by any visible cell

# Root for this model's DNN outputs on the shared mount. Capital "Projects"
# matches the path the model-fitting cell loads activations from; a lowercase
# "projects" is a different directory on a case-sensitive filesystem, so saved
# activations would never be found there.
save_dir = f"/home/matthew/remote_mounts/pomcloud0/students/matthew/Projects/VWAM/DNNs/{model_name}/"

# DataLoader batch size per stimulus set (in the commented-out extraction
# cell, each batch of frames becomes one model input clip).
batch_sizes = {
    'LHimages': 1,
    'NaturalMovies': 30,
    'vedb_ver01': 50,
    'BiomotionPilot06': 48,
}

In [6]:
# preprocess = torchvision.transforms.Compose([
#     torchvision.transforms.Resize(112),
#     torchvision.transforms.ToTensor(),
#     # tf.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
# ])

# for experiment in ['LHimages', 'NaturalMovies', 'vedb_ver01'][2:]:
#     print('****', experiment, '****')
#     images_dir = f'/hdd01/stimuli/{experiment}/ImageFolder'
#     for split in ['trn', 'val']:
#         dataset = datasets.ImageFolder(images_dir+'/'+split, transform=preprocess)
#         dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_sizes[experiment], shuffle=False)
#         activations_dict = defaultdict(list)
#         for images, labels in tqdm(dataloader):
#             images = interpolate_frames(images).unsqueeze(0)
#             model(images.to(DEVICE));
#             layer_names = list(model.activations.keys())
#             for layer_name in layer_names:
#                 layer_activations = model.activations[layer_name].detach().cpu()
#                 del model.activations[layer_name]
#                 layer_downsampling_fn = layer_downsampling_fns[layer_name]
#                 if not isinstance(layer_downsampling_fn, type(None)):
#                     layer_activations = layer_downsampling_fn(layer_activations)
#                 # if experiment != 'LHimages':
#                 #     layer_activations = torch.mean(layer_activations, 0).unsqueeze(0)
#                 activations_dict[layer_name].append(layer_activations.numpy())
#         activations_dict = {name: np.concatenate(outputs, 0) for name, outputs in activations_dict.items()}
#         if not os.path.exists(f'{save_dir}/activations/{experiment}'):
#             os.makedirs(f'{save_dir}/activations/{experiment}')
#         np.savez(f'{save_dir}/activations/{experiment}/{split}_activations.npz', **activations_dict)
#         activations_concatenated = np.nan_to_num(np.concatenate([value.reshape(len(value), -1) for value in list(activations_dict.values())], 1).astype(np.float))
#         np.save(f'{save_dir}/activations/{experiment}/{split}_activations.npy', activations_concatenated)

## Model fitting

In [12]:
import vm_tools as vmt
import cortex as cx

experiments = ['NaturalMovies', 'vedb_ver01']

FMRI_DIR = '/home/matthew/Data/DorsalNet_FC/fMRI_data'
VWAM_DIR = '/home/matthew/remote_mounts/pomcloud0/students/matthew/Projects/VWAM'
# Number of stacked delays in the lagged feature matrix. The weight splitting
# below assumes vmt.utils.add_lags produces this many -- TODO confirm.
N_LAGS = 3

# NOTE(review): range(9) yields S00..S08; an S09 would never be fitted -- confirm.
# Subject/experiment pairs without an fMRI data directory are skipped.
for subject_id in [f'S0{i}' for i in range(9)]:
    for experiment in experiments:
        subject_dir = f'{FMRI_DIR}/{subject_id}/{experiment}'
        if not os.path.isdir(subject_dir):
            continue
        fit_dir = f'{VWAM_DIR}/regression_fits/{model_name}/{experiment}/{subject_id}'
        save_path = fit_dir + '/ridge.npz'
        if os.path.exists(save_path):
            # Reuse the cached fit instead of refitting.
            fit = np.load(save_path)
        else:
            activations_dir = f'{VWAM_DIR}/DNNs/{model_name}/activations/{experiment}'
            trn_a = np.load(f'{activations_dir}/trn_activations.npy')
            val_a = np.load(f'{activations_dir}/val_activations.npy')

            if experiment == 'NaturalMovies' and subject_id == 'S01':
                # Presumably S01 has fewer NaturalMovies samples; truncate the
                # features to 2400 trn / 180 val rows -- TODO confirm.
                trn_a = trn_a[:2400]
                val_a = val_a[:180]

            if experiment != 'LHimages':
                trn_a = vmt.utils.add_lags(trn_a)
                val_a = vmt.utils.add_lags(val_a)

            os.makedirs(fit_dir, exist_ok=True)
            trn_brain = np.load(f'{subject_dir}/trn.npy')
            # Average over axis 0 (presumably validation repeats).
            val_brain = np.load(f'{subject_dir}/val_rpts.npy').mean(0)
            fit = vmt.Regression.ridge_cv(trn_fs=trn_a, trn_data=trn_brain,
                                          val_fs=val_a, val_data=val_brain,
                                          alphas=list(np.logspace(0, 20, 20)),  ## default range is much too low
                                          select_by='individual_voxel_r2',
                                          do_re_zscore_fs=False, do_re_zscore_data=False, is_verbose=False,
                                          chunk_sz=100000,
                                          )
            if experiment != 'LHimages':
                weights_lagged = fit['weights'].copy()
                # Rows are N_LAGS equal blocks, one per delay (assumed layout of
                # add_lags). Use an explicit block size: `-len(w)//3` parses as
                # `(-len(w))//3` and mis-slices when len(w) % 3 != 0.
                n_rows = len(weights_lagged)
                assert n_rows % N_LAGS == 0, f"expected {N_LAGS} equal lag blocks, got {n_rows} rows"
                block = n_rows // N_LAGS
                fit['weights_lagged'] = weights_lagged
                fit['last_two_lags_mean'] = np.nanmean(
                    [weights_lagged[-2 * block:-block], weights_lagged[-block:]], axis=0)
                fit['weights'] = vmt.utils.avg_wts(fit['weights'].T, skipfirst=False).T
            np.savez(save_path, **fit)
        print('mean cc:', np.nanmean(fit['cc']))
        mask = np.load(f'{subject_dir}/mask.npy')
        cx.webshow(cx.Volume(fit['cc'], subject=subject_id, xfmname=experiment, mask=mask,
                             vmin=0, vmax=1, cmap='afmhot'),
                   title=f"{subject_id} {experiment} ccs", with_curvature=True)

mean cc: 0.27363303
Started server on port 24215


INFO:tornado.access:200 GET /mixer.html (127.0.0.1) 47.46ms
INFO:tornado.access:200 GET /mixer.html (127.0.0.1) 47.46ms
INFO:tornado.access:200 GET /resources/css/jquery-ui.min.css (127.0.0.1) 1.86ms
INFO:tornado.access:200 GET /resources/css/jquery-ui.min.css (127.0.0.1) 1.86ms
INFO:tornado.access:200 GET /resources/css/w2ui-1.4.2.min.css (127.0.0.1) 2.49ms
INFO:tornado.access:200 GET /resources/css/w2ui-1.4.2.min.css (127.0.0.1) 2.49ms
INFO:tornado.access:200 GET /resources/css/select2-4.0.3.min.css (127.0.0.1) 3.18ms
INFO:tornado.access:200 GET /resources/css/select2-4.0.3.min.css (127.0.0.1) 3.18ms
INFO:tornado.access:200 GET /resources/js/jquery-2.1.1.min.js (127.0.0.1) 7.92ms
INFO:tornado.access:200 GET /resources/js/jquery-2.1.1.min.js (127.0.0.1) 7.92ms
INFO:tornado.access:200 GET /resources/js/jquery-ui.min.js (127.0.0.1) 18.35ms
INFO:tornado.access:200 GET /resources/js/jquery-ui.min.js (127.0.0.1) 18.35ms
INFO:tornado.access:200 GET /resources/js/jquery.ddslick.min.js (127.0

INFO:tornado.access:200 GET /resources/js/facepick_worker.js (127.0.0.1) 1.52ms
INFO:tornado.access:200 GET /resources/js/facepick_worker.js (127.0.0.1) 1.52ms
INFO:tornado.access:200 GET /ctm/S01/S01_[inflated]_mg2_9_v3.svg (127.0.0.1) 16.95ms
INFO:tornado.access:200 GET /ctm/S01/S01_[inflated]_mg2_9_v3.svg (127.0.0.1) 16.95ms


Stopping server
Stopping server
Stopping server


## Activation maximization

In [8]:
import skvideo
from skvideo import io
import torchvision

model = DorsalNet(False, 32).eval().to(DEVICE).to(DTYPE)
model.load_state_dict(torch.load('/home/matthew/Data/DorsalNet_FC/base_models/DorsalNet/pretrained.pth',
                                 map_location=DEVICE, weights_only=True))
# Only the input spectrum is optimized. Freezing the weights stops every
# backward pass from also computing (and accumulating) parameter gradients.
model.requires_grad_(False)

preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(112),
    # torchvision.transforms.ToTensor(),
    # tf.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
])

# NOTE(review): not used by the optimization below; with no ToTensor in
# `preprocess`, iterating it would yield PIL images.
dl = DataLoader(ImageFolder('/home/matthew/Data/DorsalNet_FC/stimuli/NaturalMovies/images/trn', transform=preprocess), batch_size=32, shuffle=False)

# Random pad-and-crop jitter applied at every step, presumably so the result is
# robust to small spatial shifts.
invariance_transforms = transforms.Compose([
    transforms.RandomCrop((512,512), padding=5),
    # transforms.GaussianBlur(31),
    # transforms.RandomRotation([-5,5]),
    # transforms.RandomResizedCrop((500,500), scale=(.95,1.05), ratio=(1,1,1)),
    transforms.RandomCrop((512,512), padding=3),
])

lr = 1e2
N_STEPS = 100
SPECTRUM_SHAPE = (1, 3, 32, 512, 512)  # same axis layout as the model input clip


def _spectrum_to_frames(spectrum):
    """Map a complex (1, C, T, H, W) spectrum to non-negative real frames (C, T, H, W).

    NOTE(review): the inverse FFT spans every remaining axis, channels and time
    included; confirm this is intended rather than a spatial-only transform.
    """
    return torch.abs(torch.fft.ifftn(spectrum.squeeze(0))).to(DTYPE)


for dim in range(4):
    for loc in range(0, 28, 8):
        fspace = torch.randn(SPECTRUM_SHAPE, device=DEVICE, dtype=torch.complex64).requires_grad_(True)
        optimizer = torch.optim.Adam([fspace], lr=lr)
        iterator = tqdm(range(N_STEPS))
        for _ in iterator:
            optimizer.zero_grad()
            frames = _spectrum_to_frames(fspace)
            outputs = model(preprocess(invariance_transforms(frames)).unsqueeze(0))[0]
            # Sum out the first `dim` axes of the unbatched output, then maximize
            # the slice at index `loc` along the next remaining axis.
            for _ in range(dim):
                outputs = outputs.sum(0)
            loss = -outputs[loc].sum()
            iterator.set_postfix({'frames loss': loss.item(), 'mean pixel value': frames.mean().item(), 'pixel std': frames.std().item()})
            loss.backward()
            torch.nn.utils.clip_grad_norm_(fspace, 1e-4)
            optimizer.step()
        with torch.no_grad():
            frames = _spectrum_to_frames(fspace)
        # skvideo.io.vwrite expects (T, H, W, C) but frames is (C, T, H, W).
        # Clip to [0, 1] before scaling so out-of-range values saturate instead
        # of wrapping around in the uint8 cast.
        video = (frames.permute(1, 2, 3, 0).clamp(0, 1).cpu().numpy() * 255).astype(np.uint8)
        skvideo.io.vwrite(f"test_{dim}_{loc}.mp4", video, inputdict={'-r':'16'})

100%|██████████| 100/100 [00:04<00:00, 20.35it/s, frames loss=-3.13e+4, mean pixel value=0.086, pixel std=0.114] 
100%|██████████| 100/100 [00:04<00:00, 20.38it/s, frames loss=-2.57e+4, mean pixel value=0.0727, pixel std=0.124]
100%|██████████| 100/100 [00:04<00:00, 20.49it/s, frames loss=0, mean pixel value=0.000177, pixel std=9.23e-5]
100%|██████████| 100/100 [00:04<00:00, 20.52it/s, frames loss=-9.43e+4, mean pixel value=0.0729, pixel std=0.125]
100%|██████████| 100/100 [00:04<00:00, 20.44it/s, frames loss=-1.42e+4, mean pixel value=0.0419, pixel std=0.136]
100%|██████████| 100/100 [00:04<00:00, 20.38it/s, frames loss=-2.16e+4, mean pixel value=0.054, pixel std=0.132] 
100%|██████████| 100/100 [00:04<00:00, 20.43it/s, frames loss=-1.82e+4, mean pixel value=0.054, pixel std=0.133] 
100%|██████████| 100/100 [00:04<00:00, 20.34it/s, frames loss=-1.56e+4, mean pixel value=0.0546, pixel std=0.132]
100%|██████████| 100/100 [00:04<00:00, 20.44it/s, frames loss=-3.26e+4, mean pixel value=0.