In [None]:
import torch
import matplotlib.pyplot as plt

import featurevis
from featurevis import models
from featurevis import ops
from featurevis import utils


# Use the GPU when one is available; print a warning when falling back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cpu':
    print('Running models on CPU!')

# Seed torch's RNG (all devices) so the random initial images created below, and therefore
# the optimized images, are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED);

First, let's load a pre-trained model and select the neuron we want to visualize.

In [None]:
from staticnet_analyses import multi_mei
from staticnet_experiments import models as static_models

# Identifies the dataset, the trained networks and the neuron whose features we visualize.
key = {
    'data_hash': '7572eed73113c993e7d1b92f83e270b4',
    'group_id': 29,
    'net_hash': '80d0d4bc112470b2ba04cd5eba048e39',
    'neuron_id': 119,
    'readout_key': 'group029-21067-9-17-0',
}

# Data statistics for this readout; height/width and the mean eye position are used below.
train_stats = multi_mei.prepare_data(key, key['readout_key'])
_, (_, _, height, width), _, mean_behavior, mean_eyepos, _ = train_stats

# Load every trained network matching group_id/net_hash and wrap them into one ensemble `model`
# (the original note says f is the average of four models).
model_key = {k: key[k] for k in ('group_id', 'net_hash')}
my_models = [(static_models.Model & mk).load_network()
             for mk in (static_models.Model & model_key).proj()]
model = models.Ensemble(
    my_models,
    key['readout_key'],
    eye_pos=mean_eyepos,
    neuron_idx=key['neuron_id'],
    device=device,
)

In [None]:
# First element of train_stats: an object exposing `.images` and `.tiers` (inspected in the next cell).
dset = train_stats[0]

In [None]:
# Pixel standard deviation computed over the 'train' tier images only.
train_mask = dset.tiers == 'train'
dset.images[train_mask].std()

In [None]:
# Display the raw train_stats tuple (its fields are unpacked in the model-loading cell).
# NOTE(review): may be verbose if it holds arrays — consider showing only the fields of interest.
train_stats

In [None]:
# One grayscale random starting image sized from the data (height/width unpacked from
# train_stats; presumably 36 x 64, the size previously hard-coded here).
initial_image = torch.randn(1, 1, height, width, dtype=torch.float32, device=device)

## Simplest optimization

### SGD (no bells and whistles) 

In [None]:
# Plain gradient ascent: default optimizer, no transform, regularization or post-update step.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model,
    initial_image,
    step_size=5,
    num_iterations=1000,
)

In [None]:
# SGD run: objective trace (left) and the optimized image (right).
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
axes[0].plot(fevals)
axes[0].set(title='SGD: f(x) per iteration', xlabel='Iteration', ylabel='f(x)')
axes[1].imshow(opt_x.squeeze().detach().cpu().numpy())
axes[1].set_title('SGD: optimized image');

### ADAM (no bells and whistles)

In [None]:
# Same objective, optimized with Adam instead of SGD.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model,
    initial_image,
    optim_name='Adam',
    step_size=0.1,
    num_iterations=200,
)

In [None]:
# Adam run: objective trace (left) and the optimized image (right).
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
axes[0].plot(fevals)
axes[0].set(title='Adam: f(x) per iteration', xlabel='Iteration', ylabel='f(x)')
axes[1].imshow(opt_x.squeeze().detach().cpu().numpy())
axes[1].set_title('Adam: optimized image');

## DiCarlo (Bashivan et al., 2018)
See Sec. *Synthesized "controller" images* on p. 9.
* Optimizer: SGD
* Transform: Jittering
* Regularization: Total variation
* Gradient function: Normalize the gradient (grad / norm(grad)) and clip between -1 and 1.

In [None]:
# Bashivan et al. (2018) recipe:
#   gradient  -> ChangeNorm(1), then clipped to [-1, 1]
#   transform -> random jitter with max_jitter=(2, 4)
#   penalty   -> total variation with weight 1e-3
dc_gradient = utils.Compose([ops.ChangeNorm(1),
                             ops.ClipRange(-1, 1)])
dc_transform = ops.Jitter(max_jitter=(2, 4))
dc_regularization = ops.TotalVariation(weight=0.001)

In [None]:
# Run the Bashivan et al. configuration defined above for 700 steps.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model, initial_image,
    step_size=1, num_iterations=700,
    transform=dc_transform,
    regularization=dc_regularization,
    gradient_f=dc_gradient,
)

In [None]:
# DiCarlo run: objective, regularization value and the optimized image.
fig, axes = plt.subplots(1, 3, figsize=(7*3, 4))
axes[0].plot(fevals)
axes[0].set(title='DiCarlo: f(x) per iteration', xlabel='Iteration', ylabel='f(x)')
axes[1].plot(reg_values)
axes[1].set(title='Total-variation regularization', xlabel='Iteration', ylabel='Regularization value')
axes[2].imshow(opt_x.squeeze().detach().cpu().numpy())
axes[2].set_title('DiCarlo: optimized image');

## DeepTune (Abbasi-Asl et al., 2018)
See the equation on p. 8.
* Optimizer: SGD
* Regularization: total variation and $L_6$ norm

In [None]:
# DeepTune penalty: a weighted combination of two regularizers.
dt_regularization = utils.Combine([
    ops.TotalVariation(weight=0.001),  # total-variation penalty
    ops.LpNorm(weight=1, p=6),         # L6-norm penalty
])

In [None]:
# Run the DeepTune configuration for 500 steps.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model, initial_image,
    step_size=1, num_iterations=500,
    regularization=dt_regularization,
)

In [None]:
# DeepTune run: objective, regularization value and the optimized image.
fig, axes = plt.subplots(1, 3, figsize=(7*3, 4))
axes[0].plot(fevals)
axes[0].set(title='DeepTune: f(x) per iteration', xlabel='Iteration', ylabel='f(x)')
axes[1].plot(reg_values)
axes[1].set(title='TV + L6 regularization', xlabel='Iteration', ylabel='Regularization value')
axes[2].imshow(opt_x.squeeze().detach().cpu().numpy())
axes[2].set_title('DeepTune: optimized image');

## Walker et al., 2019
* Optimizer: SGD
* Gradient: Fourier smoothing, divide by mean of absolute gradient and multiply by a decaying learning rate
* Post update: Clip range and blur image with a decaying sigma

In [None]:
# Walker et al. (2019) schedule over 1000 iterations: the gradient multiplier decays from
# 1/850 to 1/20400 and the blur sigma from 1.5 to 0.01.
walker_iters = 1000
walker_lr_start, walker_lr_end = 1/850, 1/20400
walker_sigma_start, walker_sigma_end = 1.5, 0.01
bias, scale = 111.28329467773438, 60.922306060791016

walker_gradient = utils.Compose([
    ops.FourierSmoothing(0.04),  # not exactly the same as fft_smooth(precond=0.1) but close
    ops.DivideByMeanOfAbsolute(),
    ops.MultiplyBy(walker_lr_start,
                   decay_factor=(walker_lr_start - walker_lr_end) / (1 - walker_iters)),
])

# ClipRange bounds are [0, 255] mapped through (x - bias) / scale.
walker_postup = utils.Compose([
    ops.ClipRange(-bias / scale, (255 - bias) / scale),
    ops.GaussianBlur(walker_sigma_start,
                     decay_factor=(walker_sigma_start - walker_sigma_end) / (1 - walker_iters)),
])

In [None]:
# Run the Walker et al. configuration: gradient preprocessing plus a post-update clip/blur.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model, initial_image,
    step_size=1, num_iterations=1000,
    post_update=walker_postup,
    gradient_f=walker_gradient,
)

In [None]:
# Walker et al. run: objective trace (left) and the optimized image (right).
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
axes[0].plot(fevals)
axes[0].set(title='Walker et al.: f(x) per iteration', xlabel='Iteration', ylabel='f(x)')
axes[1].imshow(opt_x.squeeze().detach().cpu().numpy())
axes[1].set_title('Walker et al.: optimized image');

In [None]:
# Reference MEI stored in the multi_mei.MEI table for this key, shown for comparison.
mei = (multi_mei.MEI & key).fetch1('mei')
fig, ax = plt.subplots(figsize=(8, 5))
ax.imshow(mei)
ax.set_title('MEI from deepdraw')

## Alternative MEI generation
TODO: find the simplest way to generate robust MEIs.
* Optimizer: SGD
* Post update: keep the image's standard deviation at 0.1.

In [None]:
# Only a post-update step: ChangeStd(0.1) is applied after each update (no transform or regularizer).
alt_postup = ops.ChangeStd(0.1)

In [None]:
# Alternative MEI: plain SGD with the std-fixing post-update step.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model, initial_image,
    step_size=1, num_iterations=1000,
    post_update=alt_postup,
)

In [None]:
# Alternative (std 0.1) run: objective trace (left) and the optimized image (right).
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
axes[0].plot(fevals)
axes[0].set(title='Std-0.1 MEI: f(x) per iteration', xlabel='Iteration', ylabel='f(x)')
axes[1].imshow(opt_x.squeeze().detach().cpu().numpy())
axes[1].set_title('Std-0.1 MEI: optimized image');

## LEI (least exciting image)

In [None]:
lei_postup = ops.ChangeStd(0.1)
# Negate the model output so that gradient *ascent* finds the least exciting image.
lei_model = utils.Compose([
    model,
    ops.MultiplyBy(-1),
])

In [None]:
# Maximize the negated response (i.e. minimize the neuron's response).
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    lei_model, initial_image,
    step_size=1, num_iterations=1000,
    post_update=lei_postup,
)

In [None]:
# LEI run: trace of the negated objective (left) and the least exciting image (right).
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
axes[0].plot(fevals)
axes[0].set(title='LEI: -f(x) per iteration', xlabel='Iteration', ylabel='-f(x)')
axes[1].imshow(opt_x.squeeze().detach().cpu().numpy())
axes[1].set_title('Least exciting image');

## Checking evolution of MEI
This just demonstrates `save_iters`. Something similar could be used to implement early stopping (by evaluating intermediate MEIs on a validation model).

In [None]:
# 100 steps of size 10; with save_iters=20 the returned opt_x holds the intermediate images
# (the next cell iterates over it).
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model, initial_image,
    step_size=10, num_iterations=100,
    save_iters=20,
)

In [None]:
# Initial image followed by the snapshots saved every SAVE_EVERY iterations.
# Keep these in sync with save_iters / num_iterations in the gradient_ascent call above.
SAVE_EVERY, N_ITERS = 20, 100
fig, axes = plt.subplots(1, 1 + N_ITERS // SAVE_EVERY, figsize=(20, 3))
axes[0].set_title('Iter {}: f(x) = {:.2f}'.format(0, fevals[0]))
axes[0].imshow(initial_image.squeeze().detach().cpu().numpy())
# NOTE(review): assumes fevals[i] is f(x) for the snapshot saved at iteration i — confirm.
for ax, i, one_x in zip(axes[1:], range(SAVE_EVERY, N_ITERS + 1, SAVE_EVERY), opt_x):
    ax.imshow(one_x.squeeze().detach().cpu().numpy())
    ax.set_title('Iter {}: f(x) = {:.2f}'.format(i, fevals[i]))

## Diverse MEIs

In [None]:
# Batch of 5 grayscale random starting images sized from the data (height/width unpacked from
# train_stats); each one becomes a separate MEI.
initial_batch = torch.randn(5, 1, height, width, dtype=torch.float32, device=device)
mask = None  # default: no mask; the next cell replaces it when a mask is available

In [None]:
# mask that contains the relevant part of the image (if available)
from staticnet_analyses import largefov
# Mask of the relevant image region for this key, as a float32 tensor on the model's device.
mask_array = (largefov.MEIMask & key).fetch1('mask')
mask = torch.as_tensor(mask_array, dtype=torch.float32, device=device)

### Computing similarity in pixel space

In [None]:
div_postup = ops.ChangeStd(0.1)
# Pixel-space diversity term: correlation between the batch images (within `mask`, if set), weight 10.
div_regularization = ops.Similarity(10, mask=mask, metric='correlation')

In [None]:
# Optimize the whole batch jointly so the similarity penalty can push the images apart.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model, initial_batch,
    step_size=10, num_iterations=500,
    regularization=div_regularization,
    post_update=div_postup,
)

In [None]:
# Pixel-space diversity: objective and similarity-penalty traces.
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
axes[0].plot(fevals)
axes[0].set(title='Pixel-space diversity: f(x)', xlabel='Iteration', ylabel='f(x)')
axes[1].plot(reg_values)
axes[1].set(title='Pixel-space similarity penalty', xlabel='Iteration', ylabel='Regularization value');

In [None]:
# One panel per optimized image. squeeze=False keeps `axes` 2-D, so a batch of one also works
# (otherwise plt.subplots returns a single, non-iterable Axes).
fig, axes = plt.subplots(1, len(opt_x), figsize=(20, 3), squeeze=False)
for idx, (ax, one_x) in enumerate(zip(axes[0], opt_x)):
    ax.imshow(one_x.squeeze().detach().cpu().numpy())
    ax.set_title('Pixel-space diverse MEI {}'.format(idx))

### Computing similarity in (VGG-19) feature space

In [None]:
from torch.nn import functional as F

# Downsample the pixel mask to the resolution of the VGG features used below, which have been
# downsampled twice (4x overall). Explicit [None, None] / [0, 0] indexing avoids .squeeze()
# also dropping a spatial dimension of size 1.
if mask is None:
    mini_mask = None  # no mask available: pass None through, as the pixel-space cell does
else:
    mini_mask = F.avg_pool2d(mask[None, None], kernel_size=4)[0, 0]

In [None]:
div_postup = ops.ChangeStd(0.1)
# Feature-space diversity: grayscale -> RGB, VGG-19 features at layer 15, then Euclidean
# similarity between the batch's feature maps within the downsampled mask (weight 0.02).
div_regularization = utils.Compose([
    ops.GrayscaleToRGB(),
    models.VGG19Core(layer=15),
    ops.Similarity(0.02, mask=mini_mask, metric='euclidean'),
])

In [None]:
# Same batch optimization, now penalizing similarity in VGG feature space.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model, initial_batch,
    step_size=10, num_iterations=500,
    regularization=div_regularization,
    post_update=div_postup,
)

In [None]:
# Feature-space diversity: objective and similarity-penalty traces.
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
axes[0].plot(fevals)
axes[0].set(title='VGG-feature diversity: f(x)', xlabel='Iteration', ylabel='f(x)')
axes[1].plot(reg_values)
axes[1].set(title='VGG-feature similarity penalty', xlabel='Iteration', ylabel='Regularization value');

In [None]:
# One panel per optimized image. squeeze=False keeps `axes` 2-D, so a batch of one also works.
fig, axes = plt.subplots(1, len(opt_x), figsize=(20, 3), squeeze=False)
for idx, (ax, one_x) in enumerate(zip(axes[0], opt_x)):
    ax.imshow(one_x.squeeze().detach().cpu().numpy())
    ax.set_title('VGG-feature diverse MEI {}'.format(idx))

## Texture
### Random crops
* Optimizer: SGD
* Transform: Take a random crop from the big image

In [None]:
# Optimize an image twice the data's height x width; at each step the model sees a random
# height x width crop of it (height/width unpacked from train_stats instead of hard-coded 36/64).
initial_image2 = torch.randn(1, 1, height * 2, width * 2, dtype=torch.float32, device=device)
text_transform = ops.RandomCrop(height, width)

In [None]:
# Long SGD run on the large image through the random-crop transform.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model, initial_image2,
    step_size=5, num_iterations=5000,
    transform=text_transform,
)

In [None]:
# Random-crop texture: objective trace (left) and the full optimized image (right).
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
axes[0].plot(fevals)
axes[0].set(title='Random-crop texture: f(x)', xlabel='Iteration', ylabel='f(x)')
axes[1].imshow(opt_x.squeeze().detach().cpu().numpy())
axes[1].set_title('Random-crop texture');

### Batched random crops (à la Santiago)
* Optimizer: SGD
* Transform: Create a batch with overlapping tiles of the big image, optimize the mean activity overall

The result doesn't look as nice, and I ran out of memory with a larger field of view (FOV).

In [None]:
# Large image = data height x width plus a (18, 31)-pixel margin; BatchedCrops tiles it into
# overlapping height x width crops (step_size=5, sigma=(8, 12)). height/width are unpacked from
# train_stats instead of the previously hard-coded 36/64.
initial_image2 = torch.randn(1, 1, height + 18, width + 31, dtype=torch.float32, device=device)
text_transform = ops.BatchedCrops(height, width, step_size=5, sigma=(8, 12))

In [None]:
# SGD on the large image through the batched-crops transform.
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model, initial_image2,
    step_size=10, num_iterations=400,
    transform=text_transform,
)

In [None]:
# Batched-crops texture: objective trace (left) and the full optimized image (right).
fig, axes = plt.subplots(1, 2, figsize=(18, 5))
axes[0].plot(fevals)
axes[0].set(title='Batched-crops texture: f(x)', xlabel='Iteration', ylabel='f(x)')
axes[1].imshow(opt_x.squeeze().detach().cpu().numpy())
axes[1].set_title('Batched-crops texture');

## In VGG19

Let's download a pretrained VGG-19 model.

In [None]:
# VGG-19 unit to visualize. NOTE: this rebinds `model`, replacing the neuron ensemble.
# Alternative: layer=14, channel=13 (conv 3_1, feature map 13).
model = models.VGG19(
    layer=40,      # conv 5_1
    channel=150,   # feature map 150
    device=device,
)

In [None]:
# 128 x 128 RGB random starting image (rebinds the grayscale `initial_image` used earlier).
initial_image = torch.randn((1, 3, 128, 128), dtype=torch.float32, device=device)

### MEI

In [None]:
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    model,
    initial_image,
    optim_name='Adam',
    step_size=0.1,
    num_iterations=1000,
    transform=ops.Jitter(3),         # jitter to avoid adversarial effects
    gradient_f=ops.GaussianBlur(1),  # slight blur of the gradient against high-frequency effects
    post_update=ops.ChangeStd(1),    # keep the image in a reasonable range
)

In [None]:
# VGG-19 MEI: objective trace, then the RGB image rendered two ways.
fig, axes = plt.subplots(1, 3, figsize=(21, 7))
axes[0].plot(fevals)
axes[0].set(title='VGG-19 MEI: f(x) per iteration', xlabel='Iteration', ylabel='f(x)')

# imshow expects (H, W, C). Assuming opt_x squeezes to (C, H, W), permute(1, 2, 0) gives that;
# transpose(0, -1) would give (W, H, C) and show the image spatially transposed.
rgb = opt_x.squeeze().permute(1, 2, 0).detach().cpu().numpy()

# Each channel min-max scaled to [0, 1] independently; guard against a constant channel.
ch_min, ch_max = rgb.min(axis=(0, 1)), rgb.max(axis=(0, 1))
axes[1].imshow((rgb - ch_min) / (ch_max - ch_min).clip(min=1e-8))
axes[1].set_title('Per-channel min-max scaled')

# Fixed affine map x / 2 + 0.5, clipped explicitly to imshow's valid RGB range [0, 1].
axes[2].imshow((rgb / 2 + 0.5).clip(0, 1))
axes[2].set_title('x / 2 + 0.5 (clipped)');

### LEI

In [None]:
# Least activating image for the VGG unit: maximize the negated response.
lei_model = utils.Compose([
    model,
    ops.MultiplyBy(-1),
])
opt_x, fevals, reg_values = featurevis.gradient_ascent(
    lei_model,
    initial_image,
    optim_name='Adam',
    step_size=0.1,
    num_iterations=1000,
    transform=ops.Jitter(3),         # jitter to avoid adversarial effects
    gradient_f=ops.GaussianBlur(1),  # slight blur of the gradient against high-frequency effects
    post_update=ops.ChangeStd(1),    # keep the image in a reasonable range
)

In [None]:
# VGG-19 LEI: trace of the negated objective, then the RGB image rendered two ways.
fig, axes = plt.subplots(1, 3, figsize=(21, 7))
axes[0].plot(fevals)
axes[0].set(title='VGG-19 LEI: -f(x) per iteration', xlabel='Iteration', ylabel='-f(x)')

# imshow expects (H, W, C). Assuming opt_x squeezes to (C, H, W), permute(1, 2, 0) gives that;
# transpose(0, -1) would give (W, H, C) and show the image spatially transposed.
rgb = opt_x.squeeze().permute(1, 2, 0).detach().cpu().numpy()

# Each channel min-max scaled to [0, 1] independently; guard against a constant channel.
ch_min, ch_max = rgb.min(axis=(0, 1)), rgb.max(axis=(0, 1))
axes[1].imshow((rgb - ch_min) / (ch_max - ch_min).clip(min=1e-8))
axes[1].set_title('Per-channel min-max scaled')

# Fixed affine map x / 3 + 0.5, clipped explicitly to imshow's valid RGB range [0, 1].
# NOTE(review): the MEI plot uses x / 2 + 0.5 — confirm which scale is intended.
axes[2].imshow((rgb / 3 + 0.5).clip(0, 1))
axes[2].set_title('x / 3 + 0.5 (clipped)');