# Visualizing inputs that maximally activate feature maps of a convnet

## Network used: VGG16 (batch-normalized variant, `vgg16_bn`)

In [0]:
# Notebook setup: draw matplotlib figures inline and enable the autoreload extension (mode 2).
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [0]:
# Install the pinned fastai / torchtext versions; the code below imports `fastai.conv_learner`.
!pip install -q fastai==0.7.0 torchtext==0.2.3

In [0]:
from fastai.conv_learner import *
from cv2 import resize
import matplotlib.gridspec as gridspec
from math import ceil
from IPython.display import HTML
from pdb import set_trace
# from scipy import ndimage

In [0]:
class SaveFeatures():
    """Forward hook that remembers the most recent output of a module.

    Attributes:
        features: output passed to the hook on the module's latest forward
            pass, or None if the module has not run since registration.
        hook: handle returned by ``register_forward_hook``; used by ``close``.
    """
    def __init__(self, module):
        # Define the attribute up front so that reading it before the first
        # forward pass yields None instead of raising AttributeError.
        self.features = None
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        """Hook callback: store ``output`` as ``self.features``."""
        self.features = output
    def close(self):
        """Remove the hook from the module; ``features`` keeps its last value."""
        self.hook.remove()

In [0]:
class FilterVisualizer():
    """Produce input images that maximize the mean activation of one feature map.

    The network is a pretrained ``vgg16`` moved to the GPU, put in eval mode and
    frozen with ``set_trainable(..., False)``; only the input pixels are optimized.
    """
    def __init__(self):
        self.model = vgg16(True).cuda().eval()
        set_trainable(self.model, False)

    def visualize(self, sz, layer, filter, upscaling_steps=12, upscaling_factor=1.2, lr=0.1, opt_steps=20, blur=None, save=False, print_losses=False):
        """Optimize an image so that feature map ``filter`` of layer ``layer`` is maximally active.

        Args:
            sz: side length of the initial square image.
            layer: index into ``children(self.model)`` of the layer to hook.
            filter: channel index of the hooked layer's output to maximize.
            upscaling_steps: number of optimize-then-upscale rounds.
            upscaling_factor: factor by which ``sz`` grows after each round.
            lr: learning rate handed to Adam.
            opt_steps: optimizer steps per round; rounds in the second half
                (``i > upscaling_steps/2``) run ``int(opt_steps*1.3)`` steps.
            blur: if not None, kernel size passed to ``cv2.blur`` after each upscale.
            save: if True, write the result to disk via ``self.save``.
            print_losses: if True, print the loss on every 3rd round / 5th step.

        Returns:
            The image from the last optimization round, clipped to [0, 1]. If
            ``upscaling_steps`` is 0 this is the unoptimized starting noise.
        """
        # Start from low-contrast noise around mid-grey.
        img = (np.random.random((sz, sz, 3)) * 20 + 128.)/255.
        # Defined up front so the return below is valid even with zero rounds.
        self.output = img

        # Keep `layer` as the index: `save` needs it for the file name.
        layer_module = children(self.model)[layer]
        activations = SaveFeatures(layer_module)  # register hook
        try:
            for i in range(upscaling_steps):  # scale the image up upscaling_steps times
                # Transforms are rebuilt every round because `sz` changes.
                _, val_tfms = tfms_from_model(vgg16, sz)
                img_var = V(val_tfms(img)[None], requires_grad=True)  # convert image to Variable that requires grad
                optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)
                opt_steps_ = int(opt_steps*1.3) if i > upscaling_steps/2 else opt_steps
                for n in range(opt_steps_):  # optimize pixel values for opt_steps times
                    optimizer.zero_grad()
                    self.model(img_var)  # forward pass fills activations.features via the hook
                    # Negated so that minimizing the loss maximizes the activation.
                    loss = -activations.features[0, filter].mean()
                    if print_losses and i%3==0 and n%5==0:
                        print(f'{i} - {n} - {float(loss)}')
                    loss.backward()
                    optimizer.step()
                img = val_tfms.denorm(np.rollaxis(to_np(img_var.data),1,4))[0]
                self.output = img
                sz = int(upscaling_factor * sz)  # calculate new image size
                img = cv2.resize(img, (sz, sz), interpolation = cv2.INTER_CUBIC)  # scale image up
                if blur is not None: img = cv2.blur(img,(blur,blur))  # blur image to reduce high frequency patterns
            if save:
                self.save(layer, filter)
        finally:
            # Always detach the hook, even if optimization or saving fails.
            activations.close()
        return np.clip(self.output, 0, 1)

    def save(self, layer, filter):
        """Write ``self.output`` (clipped to [0, 1]) to ``layer_<layer>_filter_<filter>.jpg``.

        The file is created in the current working directory.
        """
        plt.imsave(f'layer_{layer}_filter_{filter}.jpg', np.clip(self.output, 0, 1))

    def most_activated(self, image, layer, limit_top=None):
        """Return the mean activation of every feature map of ``layer`` for ``image``.

        Args:
            image: image passed through the size-224 validation transforms.
            layer: index into ``children(self.model)`` of the layer to hook.
            limit_top: currently unused.

        Returns:
            A list with one mean activation per channel of the hooked layer's output.
        """
        _, val_tfms = tfms_from_model(vgg16, 224)
        transformed = val_tfms(image)
        activations = SaveFeatures(children(self.model)[layer])  # register hook
        try:
            self.model(V(transformed)[None])
            feats = activations.features
            print(feats.shape)
            mean_act = [feats[0,i].mean().data.cpu().numpy()[0] for i in range(feats.shape[1])]
        finally:
            # Always detach the hook, even if the forward pass fails.
            activations.close()
        return mean_act

    def get_transformed_img(self,img,sz):
        """Apply the size-``sz`` validation transforms to ``img`` and undo the normalization."""
        # Uses vgg16 like the rest of the class (the original passed resnet34 here).
        # NOTE(review): presumably both models share the same normalization stats in fastai 0.7 - confirm.
        _, val_tfms = tfms_from_model(vgg16, sz)
        return val_tfms.denorm(np.rollaxis(to_np(val_tfms(img)[None]),1,4))[0]

In [21]:
# Shared visualizer instance; `plot_reconstructions` and `reconstructions` below reference this global.
FV = FilterVisualizer()

Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.torch/models/vgg16_bn-6c64b313.pth
100%|██████████| 553507836/553507836 [00:05<00:00, 101382015.63it/s]


In [0]:
# Show the model's layers; presumably these indices match the `layer` / `layer_idx` arguments used below - confirm.
FV.model

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (2): ReLU(inplace)
  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (5): ReLU(inplace)
  (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
  (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  (9): ReLU(inplace)
  (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  (12): ReLU(inplace)
  (13): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
  (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
  (16): ReLU(inplace)
  (17): Conv2d(256, 256, ke

In [0]:
import requests
import base64
import pprint

In [0]:
# Read the bearer token used by `upload_to_imgur` from a local file.
with open('img_at.txt') as token_file:
    ACCESS_TOKEN = token_file.read().strip()

In [0]:
def upload_to_imgur(file_name,post_title,album_hash):
    """Upload an image file to Imgur and return the link reported by the API.

    Args:
        file_name: path of the image file to upload.
        post_title: title sent with the upload.
        album_hash: album identifier sent with the upload.

    Returns:
        The value of ``['data']['link']`` in the JSON response.

    Raises:
        requests.HTTPError: if the API answers with a 4xx/5xx status.
    """
    url = 'https://api.imgur.com/3/image'
    # Context manager closes the file handle once the bytes are read.
    with open(file_name, 'rb') as fh:
        encoded_image = base64.b64encode(fh.read())
    payload = {'image': encoded_image,
              'album':album_hash,
              'type':'base64',
              'title':post_title,
              'looping':False
              }
    files = {}
    headers = {
      'Authorization': f'Bearer {ACCESS_TOKEN}'
    }
    response = requests.request('POST', url, headers = headers, data = payload, files = files, allow_redirects=False)
    # Fail with the HTTP status instead of an opaque KeyError on the lookup below.
    response.raise_for_status()
    return response.json()['data']['link']

In [0]:
def plot_reconstructions(imgs,layer_idx,filters,n_cols=3, cell_size=4, save_fig=False,album_hash=None):
    """Plot the reconstructed images in a grid, one cell per feature map.

    Args:
        imgs: images to show, in the same order as ``filters``.
        layer_idx: layer index; used in the figure title and the file name.
        filters: feature-map indices; used in the cell titles and the file name.
        n_cols: number of grid columns.
        cell_size: width and height of each grid cell, passed on via ``figsize``.
        save_fig: if True, save the figure as a PNG, upload it and close it;
            otherwise show it.
        album_hash: album passed to ``upload_to_imgur`` when ``save_fig`` is True.

    Returns:
        The link returned by ``upload_to_imgur`` when ``save_fig`` is True, else None.
    """
    n_rows = ceil((len(imgs))/n_cols)
    n_plots = min(len(imgs), len(filters))

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` always works.
    fig,axes = plt.subplots(n_rows,n_cols, figsize=(cell_size*n_cols,cell_size*n_rows), squeeze=False)

    for i,ax in enumerate(axes.flat):
        ax.grid(False)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # The grid can have more cells than images; hide the spare ones
        # instead of indexing past the end of `filters` / `imgs`.
        if i >= n_plots:
            ax.axis('off')
            continue

        ax.set_title(f'fmap {filters[i]}')
        ax.imshow(imgs[i])
    fig.suptitle(f'VGG16 Layer: {layer_idx} : {children(FV.model)[layer_idx]}', fontsize="x-large",y=1.0)
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    if save_fig:
        # One name for the file on disk and the upload title.
        stem = f'vgg16_layer_{layer_idx}_fmaps_{"_".join([str(f) for f in filters])}'
        plt.savefig(f'{stem}.png')
        link = upload_to_imgur(f'{stem}.png', stem, album_hash)
        plt.close(fig)
        return link
    else:
        plt.show()
        return None
    

In [0]:
def reconstructions(layer_idx, filters,
                    init_size=56, upscaling_steps=12, 
                    upscaling_factor=1.2, 
                    opt_steps=20, blur=5,
                    lr=1e-1,print_losses=False,
                    n_cols=3, cell_size=4,
                    save_fig=False,
                    album_hash=None):
    """Visualize each feature map in ``filters`` for one layer and plot the results.

    Every feature map is rendered with ``FV.visualize`` and the images are
    handed to ``plot_reconstructions``, whose return value is returned as is.
    An album hash is required (asserted) when ``save_fig`` is True.
    """
    if save_fig:
        assert album_hash is not None

    # Settings shared by every FV.visualize call.
    visualize_kwargs = dict(upscaling_steps=upscaling_steps,
                            upscaling_factor=upscaling_factor,
                            opt_steps=opt_steps, blur=blur,
                            lr=lr, print_losses=print_losses)
    imgs = [FV.visualize(init_size, layer_idx, fmap, **visualize_kwargs)
            for fmap in filters]

    return plot_reconstructions(imgs, layer_idx, filters,
                                n_cols=n_cols, cell_size=cell_size,
                                save_fig=save_fig, album_hash=album_hash)

In [0]:
# l_idxs = [40]
# dd = 0
# for l in l_idxs:
#     start = 198
#     n_items_per_plot = 6
#     end =400
#     count = start
#     while count+n_items_per_plot<end:
#         dd+=1
#         print(f'reconstructions({l},list(range({count},{count+n_items_per_plot})),save_fig=True)')
#         reconstructions(l,list(range(count,count+n_items_per_plot)),save_fig=True)
#         count+=n_items_per_plot
# print(dd)

In [58]:
# Layer 7, feature maps 0-2: render, upload, then embed the hosted figure.
url = reconstructions(layer_idx=7, filters=list(range(3)), save_fig=True, album_hash="A5KpYOz")
display(HTML(f"<img src={url} />"))

In [60]:
# Layer 10, feature maps 24-29: render, upload, then embed the hosted figure.
url = reconstructions(layer_idx=10, filters=list(range(24, 30)), save_fig=True, album_hash="A5KpYOz")
display(HTML(f"<img src={url} />"))

In [62]:
# Layer 14, feature maps 0-5: render, upload, then embed the hosted figure.
url = reconstructions(layer_idx=14, filters=list(range(6)), save_fig=True, album_hash="A5KpYOz")
display(HTML(f"<img src={url} />"))

In [64]:
# Layer 14, feature maps 30-35: render, upload, then embed the hosted figure.
url = reconstructions(layer_idx=14, filters=list(range(30, 36)), save_fig=True, album_hash="A5KpYOz")
display(HTML(f"<img src={url} />"))

In [66]:
# Layer 24, feature maps 66-71: render, upload, then embed the hosted figure.
url = reconstructions(layer_idx=24, filters=list(range(66, 72)), save_fig=True, album_hash="A5KpYOz")
display(HTML(f"<img src={url} />"))

In [68]:
# Layer 24, feature maps 84-89: render, upload, then embed the hosted figure.
url = reconstructions(layer_idx=24, filters=list(range(84, 90)), save_fig=True, album_hash="A5KpYOz")
display(HTML(f"<img src={url} />"))

In [70]:
# Layer 24, feature maps 204-209: render, upload, then embed the hosted figure.
# NOTE(review): the other layer-24 cells upload to album "A5KpYOz"; this one uses "SsbtjRZ" - confirm that is intended.
url = reconstructions(layer_idx=24, filters=list(range(204, 210)), save_fig=True, album_hash="SsbtjRZ")
display(HTML(f"<img src={url} />"))

More images from earlier layers can be found [here](https://imgur.com/a/A5KpYOz).

In [49]:
# Layer 40, feature maps 0-5: render, upload, then embed the hosted figure.
url = reconstructions(layer_idx=40, filters=list(range(6)), save_fig=True, album_hash="SsbtjRZ")
display(HTML(f"<img src={url} />"))

In [51]:
# Layer 40, feature maps 6-11: render, upload, then embed the hosted figure.
url = reconstructions(layer_idx=40, filters=list(range(6, 12)), save_fig=True, album_hash="SsbtjRZ")
display(HTML(f"<img src={url} />"))

In [55]:
# Layer 40, feature maps 12-17: render, upload, then embed the hosted figure.
url = reconstructions(layer_idx=40, filters=list(range(12, 18)), save_fig=True, album_hash="SsbtjRZ")
display(HTML(f"<img src={url} />"))

Results for all feature maps from layer `40` can be found [here](https://imgur.com/a/SsbtjRZ).