In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as tf
import sys
sys.path.append('../code')
from dorsalnet import DorsalNet, FC, interpolate_frames
from VWAM.utils import SingleImageFolder, choose_downsampling, iterate_children, hook_model
from tqdm import tqdm

DEVICE = 'cuda:0'
DTYPE = torch.bfloat16

model = DorsalNet(False, 32).eval().to(DEVICE).to(DTYPE)
model.load_state_dict(torch.load('/home/matthew/Data/DorsalNet_FC/base_models/DorsalNet/pretrained.pth'))

<All keys matched successfully>

### Choose downsampling

In [2]:
MAX_FS = 5000
DEPTH = 1
input_shape = (1, 3, 32, 112, 112)

import torch
import numpy as np

layers_dict = iterate_children(model, depth=DEPTH)
layers_dict = {k: v for k, v in layers_dict.items() if not 'dropout' in k}
model = hook_model(model, layers_dict)
model(torch.randn(input_shape).to(DEVICE).to(DTYPE))

layer_downsampling_fns = {}
for layer_name, layer_activations in model.activations.items():
    layer_activations = layer_activations
    print('**************')
    print(layer_name)
    print('old_shape:', layer_activations.shape)
    print('old # activations:', layer_activations.flatten().shape)
    layer_downsampling_fn = choose_downsampling(layer_activations, MAX_FS)
    layer_downsampling_fns[layer_name] = layer_downsampling_fn
    if layer_downsampling_fn is not None:
        layer_activations = layer_downsampling_fns[layer_name](layer_activations)
    print('new_shape:', layer_activations.shape)
    print('new # activations:', layer_activations.flatten().shape)

**************
model.conv1
old_shape: torch.Size([1, 64, 32, 56, 56])
old # activations: torch.Size([6422528])
new_shape: torch.Size([1, 64, 4, 4, 4])
new # activations: torch.Size([4096])
**************
model.s1
old_shape: torch.Size([1, 64, 32, 28, 28])
old # activations: torch.Size([1605632])
new_shape: torch.Size([1, 64, 4, 4, 4])
new # activations: torch.Size([4096])
**************
model.res0
old_shape: torch.Size([1, 32, 32, 28, 28])
old # activations: torch.Size([802816])
new_shape: torch.Size([1, 32, 5, 5, 5])
new # activations: torch.Size([4000])
**************
model.res1
old_shape: torch.Size([1, 32, 32, 28, 28])
old # activations: torch.Size([802816])
new_shape: torch.Size([1, 32, 5, 5, 5])
new # activations: torch.Size([4000])
**************
model.res2
old_shape: torch.Size([1, 32, 32, 28, 28])
old # activations: torch.Size([802816])
new_shape: torch.Size([1, 32, 5, 5, 5])
new # activations: torch.Size([4000])
**************
model.res3
old_shape: torch.Size([1, 32, 32, 28, 

In [3]:
for key1 in model.activations.keys():
    for key2 in model.activations.keys():
        if key1 != key2:
            if model.activations[key1].shape == model.activations[key2].shape:
                if np.allclose(model.activations[key1].detach().float().cpu().numpy(), model.activations[key2].detach().float().cpu().numpy()):
                    print(key1, key2)

In [7]:
from collections import defaultdict
import torchvision
import os
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

interpolate_frames = torchvision.transforms.Compose([
    torchvision.ops.Permute([1,2,3,0]),
    torchvision.transforms.Resize([input_shape[-1],input_shape[1]]),
    torchvision.ops.Permute([0,3,1,2]),
])

preprocess = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(input_shape[-1]),
    torchvision.transforms.CenterCrop(input_shape[-1]),
    # torchvision.transforms.Normalize(123.0, 75.0),
])

model_name='dorsalnet'
DTYPE=torch.float32
iter_mode = 'children'
iter_depth = 1

save_dir = f"/home/matthew/remote_mounts/pomcloud0/students/matthew/projects/activation_maximization/DNNs/{model_name}/"

batch_sizes = {
    # 'LHimages': 1,
    'NaturalMovies': 30,
    'vedb_ver01': 50,
    'BiomotionPilot06': 48,
}

model = model.to(DTYPE)

for experiment in ['NaturalMovies', 'vedb_ver01']:
    print('****', experiment, '****')
    images_dir = f'/home/matthew/Data/DorsalNet_FC/stimuli/{experiment}/images/trn'
    for split in ['trn', 'val']:
        dataloader = DataLoader(
            SingleImageFolder(f'/home/matthew/Data/DorsalNet_FC/stimuli/{experiment}/images/{split}', transform=preprocess),
                batch_size=batch_sizes[experiment], 
                shuffle=False)
        activations_dict = defaultdict(list)
        for images in tqdm(dataloader):
            images = interpolate_frames(images).unsqueeze(0).to(DTYPE).to(DEVICE)
            model(images);
            layer_names = list(model.activations.keys())
            for layer_name in layer_names:
                layer_activations = model.activations[layer_name].detach().cpu()
                del model.activations[layer_name]
                layer_downsampling_fn = layer_downsampling_fns[layer_name]
                if not isinstance(layer_downsampling_fn, type(None)):
                    layer_activations = layer_downsampling_fn(layer_activations)
                # if experiment != 'LHimages':
                #     layer_activations = torch.mean(layer_activations, 0).unsqueeze(0)
                activations_dict[layer_name].append(layer_activations.numpy())
        activations_dict = {name: np.concatenate(outputs, 0) for name, outputs in activations_dict.items()}
        if not os.path.exists(f'{save_dir}/activations/{experiment}'):
            os.makedirs(f'{save_dir}/activations/{experiment}')
        np.savez(f'{save_dir}/activations/{experiment}/{split}_activations_v2.npz', **activations_dict)
        activations_concatenated = np.nan_to_num(np.concatenate([value.reshape(len(value), -1) for value in list(activations_dict.values())], 1).astype(np.float))
        np.save(f'{save_dir}/activations/{experiment}/{split}_activations_v2.npy', activations_concatenated)

**** NaturalMovies ****


100%|██████████| 3600/3600 [18:51<00:00,  3.18it/s]
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
100%|██████████| 270/270 [01:24<00:00,  3.18it/s]


**** vedb_ver01 ****


100%|██████████| 1200/1200 [07:32<00:00,  2.65it/s]
100%|██████████| 90/90 [00:32<00:00,  2.78it/s]


## Model fitting

In [11]:
import os
import numpy as np
import vm_tools as vmt
import cortex as cx

experiments = ['NaturalMovies', 'vedb_ver01']
model_name='dorsalnet'

for subject_id in [f'S0{i}' for i in range(9)]:
    for experiment in experiments:
        if os.path.exists(f'/home/matthew/Data/DorsalNet_FC/fMRI_data/{subject_id}/{experiment}/'):
            fit_dir = f'/home/matthew/remote_mounts/pomcloud0/students/matthew/Projects/activation_maximization/regression_fits/{model_name}/{experiment}/{subject_id}'
            # if not os.path.exists(fit_dir+'/ridge_v2.npz'):
            if True:
                trn_a = np.load(f'/home/matthew/remote_mounts/pomcloud0/students/matthew/Projects/activation_maximization/DNNs/{model_name}/activations/{experiment}/trn_activations_v2.npy')
                val_a = np.load(f'/home/matthew/remote_mounts/pomcloud0/students/matthew/Projects/activation_maximization/DNNs/{model_name}/activations/{experiment}/val_activations_v2.npy')

                if experiment=='NaturalMovies' and subject_id=='S01':
                    trn_a = trn_a[:2400]
                    val_a = val_a[:180]

                if experiment != 'LHimages':
                    trn_a = vmt.utils.add_lags(trn_a)
                    val_a = vmt.utils.add_lags(val_a)
                
                if not os.path.exists(fit_dir):
                    os.makedirs(fit_dir)
                save_path = fit_dir+f"/ridge_v2.npz"
                trn_brain = np.load(f'/home/matthew/Data/DorsalNet_FC/fMRI_data/{subject_id}/{experiment}/trn.npy')
                val_brain = np.load(f'/home/matthew/Data/DorsalNet_FC/fMRI_data/{subject_id}/{experiment}/val_rpts.npy').mean(0)
                fit = vmt.Regression.ridge_cv(trn_fs=trn_a, trn_data=trn_brain,
                                                    val_fs=val_a, val_data=val_brain,
                                                    alphas = list(np.logspace(0,20,20)), ## default range is much too low
                                                    select_by='individual_voxel_r2',
                                                    do_re_zscore_fs=False, do_re_zscore_data=False, is_verbose=False,
                                                    chunk_sz=100000,
                                                    )
                if experiment != 'LHimages':
                    fit['weights_lagged'] = fit['weights'].copy()
                    fit['last_two_lags_mean'] = np.nanmean([fit['weights_lagged'][len(fit['weights_lagged'])//3:-len(fit['weights_lagged'])//3], fit['weights_lagged'][-len(fit['weights_lagged'])//3:]], axis=0)
                    fit['weights'] = vmt.utils.avg_wts(fit['weights'].T, skipfirst=False).T
                np.savez(save_path, **fit)
            fit = np.load(save_path)
            print('mean cc:', np.nanmean(fit['cc']))
            mask = np.load(f'/home/matthew/Data/DorsalNet_FC/fMRI_data/{subject_id}/{experiment}/mask.npy')
            # cx.webshow(cx.Volume(fit['cc'], subject=subject_id, xfmname=experiment, mask=mask, vmin=0, vmax=1, cmap='afmhot'), title=f"{subject_id} {experiment} ccs", with_curvature=True)

  trncc_byvox = np.nanmean(pred_by_alpha, axis=-1)


Computing SVD




mean cc: 0.2624363
Computing SVD


  zs = lambda x: (x-np.nanmean(x, axis=0))/np.nanstd(x, axis=0, ddof=dof)
  keepdims=keepdims)
  r = rTmp/n


mean cc: 0.18957913
Computing SVD
mean cc: 0.22632763
Computing SVD
mean cc: 0.3321591
Computing SVD
mean cc: 0.37056246
Computing SVD
mean cc: 0.25449616
Computing SVD
mean cc: 0.18786868
Computing SVD
mean cc: 0.2419159
Computing SVD
mean cc: 0.25973275
Computing SVD
mean cc: 0.22743388
Computing SVD
mean cc: 0.23946497


## Activation maximization

#### Just network

In [None]:
import skvideo
from skvideo import io
import torchvision

model = DorsalNet(False, 32).eval().to(DEVICE).to(DTYPE)
model.load_state_dict(torch.load('/home/matthew/Data/DorsalNet_FC/base_models/DorsalNet/pretrained.pth'))

preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(112),
    # torchvision.transforms.ToTensor(),
    # tf.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
])

dl = DataLoader(ImageFolder('/home/matthew/Data/DorsalNet_FC/stimuli/NaturalMovies/images/trn', transform=preprocess), batch_size=32, shuffle=False)

invariance_transforms = transforms.Compose([
    transforms.RandomCrop((512,512), padding=5),
    # transforms.GaussianBlur(31),
    # transforms.RandomRotation([-5,5]),
    # transforms.RandomResizedCrop((500,500), scale=(.95,1.05), ratio=(1,1,1)),
    transforms.RandomCrop((512,512), padding=3),
])

lr = 1e2

for dim in range(4):
    for loc in range(0,28,8):
        fspace = torch.randn((1,3,32,512,512), device=DEVICE, dtype=torch.complex64).requires_grad_(True)
        optimizer = torch.optim.Adam([fspace], lr=lr)
        iterator = tqdm(range(100))
        for i in iterator:
            loss = 0
            optimizer.zero_grad()
            frames = torch.abs(torch.fft.ifftn(fspace.squeeze())).to(DTYPE)
            outputs = model(preprocess(invariance_transforms(frames)).unsqueeze(0))[0]
            for _ in range(dim):
                outputs = outputs.sum(0)
            loss -= outputs[loc].sum()
            iterator.set_postfix({'frames loss': loss.item(), 'mean pixel value': frames.mean().item(), 'pixel std': frames.std().item()})
            loss.backward()
            torch.nn.utils.clip_grad_norm_(fspace, 1e-4)
            optimizer.step()
        frames = torch.abs(torch.fft.ifftn(fspace))
        skvideo.io.vwrite(f"test_{dim}_{loc}.mp4", (frames.squeeze().permute(1,0,2,3).detach().cpu().numpy()*255).astype(np.uint8), inputdict={'-r':'16'})

#### ROIs

In [None]:
import skvideo
from skvideo import io
import torchvision
import cortex as cx

experiment='vedb_ver01'
subject_id= 'S01'
roi = 'hMT'

mask = np.load(f'/home/matthew/Data/DorsalNet_FC/fMRI_data/{subject_id}/{experiment}/mask.npy')
fit_dir = f'/home/matthew/remote_mounts/pomcloud0/students/matthew/Projects/activation_maximization/regression_fits/{model_name}/{experiment}/{subject_id}'
fit = np.load(fit_dir+'/ridge.npz')
roi_mask = cx.get_roi_mask(subject_id, experiment, roi)[roi]
roi_weights = fit['weights'][:,roi_mask[mask].astype(bool)]
roi_weights = np.nanmean(roi_weights, axis=1)
roi_weights /= abs(roi_weights.sum())
roi_weights = torch.tensor(roi_weights).to(DEVICE).to(DTYPE)

In [None]:
model = DorsalNet(False, 32).eval().to(DEVICE).to(DTYPE)
model.load_state_dict(torch.load('/home/matthew/Data/DorsalNet_FC/base_models/DorsalNet/pretrained.pth'))
model = hook_model(model, 1)

preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(112),
    # torchvision.transforms.ToTensor(),
    # tf.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
])

dl = DataLoader(ImageFolder('/home/matthew/Data/DorsalNet_FC/stimuli/NaturalMovies/images/trn', transform=preprocess), batch_size=32, shuffle=False)

invariance_transforms = transforms.Compose([
    transforms.RandomCrop((512,512), padding=5),
    # transforms.GaussianBlur(31),
    # transforms.RandomRotation([-5,5]),
    # transforms.RandomResizedCrop((500,500), scale=(.95,1.05), ratio=(1,1,1)),
    transforms.RandomCrop((512,512), padding=3),
])

lr = 1e1

fspace = torch.randn((1,3,32,512,512), device=DEVICE, dtype=torch.complex64).requires_grad_(True)
optimizer = torch.optim.Adam([fspace], lr=lr)
iterator = tqdm(range(1000))
for i in iterator:
    loss = 0
    optimizer.zero_grad()
    frames = torch.abs(torch.fft.ifftn(fspace.squeeze())).to(DTYPE)
    model(preprocess(invariance_transforms(frames)).unsqueeze(0));
    all_activations = []
    for layer_name, layer_activations in model.activations.items():
        layer_downsampling_fn = layer_downsampling_fns[layer_name]
        if layer_downsampling_fn is not None:
            layer_activations = layer_downsampling_fn(layer_activations)
        all_activations.append(layer_activations.mean(0).flatten())
    all_activations = torch.cat(all_activations)
    # all_activations = torch.clip(all_activations, -1, 1)
    # loss -= all_activations@roi_weights
    loss -= -torch.nn.functional.cosine_similarity(all_activations.unsqueeze(0), roi_weights.unsqueeze(0))
    iterator.set_postfix({'frames loss': loss.item(), 'mean pixel value': frames.mean().item(), 'pixel std': frames.std().item()})
    loss.backward()
    torch.nn.utils.clip_grad_norm_(fspace, 1e-4)
    optimizer.step()
frames = torch.abs(torch.fft.ifftn(fspace))
skvideo.io.vwrite(f"test_{roi}.mp4", (frames.squeeze().permute(1,0,2,3).detach().cpu().numpy()*255).astype(np.uint8), inputdict={'-r':'16'})