# Visualizing inputs that maximally activate feature maps of a convnet

## Network used: ResNet-34

In [0]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [0]:
!pip install -q fastai==0.7.0 torchtext==0.2.3

In [0]:
from fastai.conv_learner import *
from cv2 import resize
import matplotlib.gridspec as gridspec
from math import ceil

In [0]:
from IPython.display import HTML

In [0]:
from pdb import set_trace

In [0]:
from scipy import ndimage

In [0]:
class SaveFeatures():
    """Capture the most recent forward output of a module.

    A forward hook is registered on `module` at construction time. After every
    forward pass through that module, its output is stored in `self.features`.
    Call `close()` to unregister the hook once the features are no longer needed.
    """
    def __init__(self, module):
        handle = module.register_forward_hook(self.hook_fn)
        self.hook = handle

    def hook_fn(self, module, input, output):
        # Signature imposed by register_forward_hook: (module, input, output).
        setattr(self, 'features', output)

    def close(self):
        # Detach from the module so the hook stops firing.
        self.hook.remove()

In [0]:
class FilterVisualizer():
    """Synthesize images that strongly activate one feature map of a pretrained ResNet-34.

    `visualize` starts from low-contrast grey noise and optimizes the pixels with
    Adam to maximize the mean activation of a chosen channel. It repeatedly
    upscales (and optionally blurs) the result so that larger patterns can form.
    """
    def __init__(self):
        # Pretrained ResNet-34 with its last two children removed, moved to the GPU,
        # in eval mode, and frozen so that only the input image is optimized.
        self.model = nn.Sequential(*list(resnet34(True).children())[:-2]).cuda().eval()
        set_trainable(self.model, False)

    def visualize(self, sz, layer, filter, upscaling_steps=12, upscaling_factor=1.2, lr=0.1, opt_steps=20, blur=None, save=False, print_losses=False):
        """Return an HWC image, clipped to [0, 1], that maximizes `filter` of `layer`.

        Args:
            sz: side length in pixels of the starting image.
            layer: module of `self.model` whose output is hooked.
            filter: channel index in `layer`'s output whose mean activation is maximized.
            upscaling_steps: number of optimize-then-upscale rounds.
            upscaling_factor: the image side is multiplied by this after every round.
            lr: Adam learning rate for the pixel values.
            opt_steps: Adam steps per round; rounds past the midpoint use int(1.3 * opt_steps).
            blur: if not None, box-blur kernel size applied after each upscale to
                damp high-frequency patterns.
            save: unused; kept for backward compatibility.
            print_losses: print the loss every 5th step of every 3rd round.

        Returns:
            The image produced by the last optimization round (before its final
            upscale), clipped to [0, 1]. If `upscaling_steps < 1`, this is the
            starting noise image.
        """
        img = (np.random.random((sz, sz, 3)) * 20 + 128.)/255.
        # Fallback result so that a call with no rounds never returns a stale image
        # from a previous call.
        self.output = img

        activations = SaveFeatures(layer)  # register hook
        try:
            for i in range(upscaling_steps):  # scale the image up upscaling_steps times
                _, val_tfms = tfms_from_model(resnet34, sz)
                img_var = V(val_tfms(img)[None], requires_grad=True)  # convert image to Variable that requires grad
                optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)
                # Rounds in the second half (larger images) get 30% more steps.
                opt_steps_ = int(opt_steps*1.3) if i > upscaling_steps/2 else opt_steps
                for n in range(opt_steps_):  # optimize pixel values for opt_steps_ times
                    optimizer.zero_grad()
                    self.model(img_var)  # the hook stores `layer`'s output in activations.features
                    # Minimizing the negated mean is gradient ascent on the activation.
                    loss = -1*activations.features[0, filter].mean()
                    if print_losses and i%3==0 and n%5==0:
                        print(f'{i} - {n} - {float(loss)}')
                    loss.backward()
                    optimizer.step()
                # NCHW tensor -> NHWC array -> denormalized HWC image.
                img = val_tfms.denorm(np.rollaxis(to_np(img_var.data),1,4))[0]
                self.output = img
                sz = int(upscaling_factor * sz)  # calculate new image size
                img = cv2.resize(img, (sz, sz), interpolation = cv2.INTER_CUBIC)  # scale image up
                if blur is not None: img = cv2.blur(img,(blur,blur))  # blur image to reduce high frequency patterns
        finally:
            # Always unregister the hook, even if optimization fails, so hooks do
            # not accumulate on `layer` across calls.
            activations.close()
        return np.clip(self.output, 0, 1)

    def get_transformed_img(self,img,sz):
        """Apply the ResNet-34 validation transforms at size `sz`, then denormalize back to an HWC image."""
        _, val_tfms = tfms_from_model(resnet34, sz)
        return val_tfms.denorm(np.rollaxis(to_np(val_tfms(img)[None]),1,4))[0]

    def most_activated(self, image, layer, limit_top=None):
        """Return the mean activation of every output channel of `layer` for `image`.

        Args:
            image: input image; it is passed through the 224-px validation
                transforms from `tfms_from_model(resnet34, 224)`.
            layer: module of `self.model` to hook.
            limit_top: unused; kept for backward compatibility.

        Returns:
            A list with one mean activation per channel (list index == filter index).
        """
        _, val_tfms = tfms_from_model(resnet34, 224)
        transformed = val_tfms(image)

        activations = SaveFeatures(layer)  # register hook
        try:
            self.model(V(transformed)[None])
            # NOTE(review): the trailing [0] assumes .mean() yields a 1-element tensor
            # (older PyTorch behaviour); a 0-dim result would not be indexable. Confirm
            # against the installed torch version.
            mean_act = [activations.features[0,i].mean().data.cpu().numpy()[0]
                        for i in range(activations.features.shape[1])]
        finally:
            activations.close()
        return mean_act

In [5]:
# Build the frozen ResNet-34 body once (pretrained weights are downloaded on first use; .cuda() requires a GPU).
FV = FilterVisualizer()

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.torch/models/resnet34-333f7ec4.pth
100%|██████████| 87306240/87306240 [00:03<00:00, 23733638.02it/s]


In [0]:
# Show the module tree; the child indices printed here are the ones used to pick hook targets below.
FV.model

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (2): ReLU(inplace)
  (3): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1), ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d

In [0]:
import requests
import base64
import pprint

In [0]:
# Load the Imgur API bearer token from a local file so the secret is not stored in the notebook itself.
with open('img_at.txt') as f:
    ACCESS_TOKEN = f.read().strip()

In [0]:
def upload_to_imgur(file_name,post_title,album_hash):
    """Upload a local image file to Imgur and return its public link.

    Uses the module-level ``ACCESS_TOKEN`` as a Bearer token.

    Args:
        file_name: path of the image file to upload.
        post_title: title attached to the uploaded image.
        album_hash: Imgur album hash to add the image to (None omits the field).

    Returns:
        The ``data.link`` field of the Imgur API JSON response.

    Raises:
        OSError: if ``file_name`` cannot be read.
        requests.HTTPError: if the API answers with a 4xx/5xx status.
    """
    url = 'https://api.imgur.com/3/image'
    # Context manager guarantees the file handle is closed.
    with open(file_name, 'rb') as fh:
        encoded = base64.b64encode(fh.read())
    payload = {'image': encoded,
               'album': album_hash,
               'type': 'base64',
               'title': post_title,
               'looping': False,
               }
    headers = {'Authorization': f'Bearer {ACCESS_TOKEN}'}
    response = requests.post(url, headers=headers, data=payload, allow_redirects=False)
    # Surface API errors as an HTTP error instead of an opaque KeyError on the body.
    response.raise_for_status()
    return response.json()['data']['link']

In [0]:
def plot_reconstructions_single_layer(imgs,layer_name,filters,
                                      n_cols=3,
                                      cell_size=4,save_fig=False,
                                      album_hash=None):
    """Plot one reconstruction per feature map in a grid; optionally save and upload it.

    Args:
        imgs: images (anything ``ax.imshow`` accepts), parallel to ``filters``.
        layer_name: human-readable layer name, used in the title and the file name.
        filters: feature-map indices; used for subplot titles and the file name.
        n_cols: number of grid columns.
        cell_size: width/height in inches of each grid cell.
        save_fig: if True, save the figure as a PNG in the working directory,
            close it, and upload it with ``upload_to_imgur``.
        album_hash: Imgur album hash forwarded to ``upload_to_imgur``.

    Returns:
        The Imgur link when ``save_fig`` is True; otherwise None (the figure is shown).
    """
    n_plots = min(len(imgs), len(filters))
    n_rows = max(1, ceil(len(imgs)/n_cols))

    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so `.flat` always works.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(cell_size*n_cols, cell_size*n_rows),
                             squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.grid(False)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        if i >= n_plots:
            # Spare cells in the last row: hide them rather than indexing past the end.
            ax.axis('off')
            continue

        ax.set_title(f'fmap {filters[i]}')
        ax.imshow(imgs[i])
    fig.suptitle(f'ResNet34 {layer_name}', fontsize="x-large",y=1.0)
    fig.tight_layout()
    fig.subplots_adjust(top=0.88)  # leave room for the suptitle
    save_name = layer_name.lower().replace(' ','_')
    if save_fig:
        fmap_ids = "_".join(str(f) for f in filters)
        base_name = f'resnet34_{save_name}_fmaps_{fmap_ids}'
        file_name = f'{base_name}.png'
        fig.savefig(file_name)
        plt.close(fig)  # free the figure before the (possibly failing) upload
        return upload_to_imgur(file_name, base_name, album_hash)
    plt.show()
    return None

In [0]:
def reconstructions_single_layer(layer, layer_name, filters,
                                 init_size=56, upscaling_steps=12,
                                 upscaling_factor=1.2,
                                 opt_steps=20, blur=5,
                                 lr=1e-1, print_losses=False,
                                 n_cols=3, cell_size=4,
                                 save_fig=False, album_hash=None):
    """Reconstruct every filter in `filters` for `layer` with the global `FV`, then plot the grid.

    The optimization settings are forwarded to `FV.visualize`, and the layout and
    saving settings to `plot_reconstructions_single_layer`. The return value is
    whatever the plotting function returns (an Imgur link when `save_fig` is True).
    """
    # The same optimization settings are used for every filter in the layer.
    visualize_opts = dict(upscaling_steps=upscaling_steps,
                          upscaling_factor=upscaling_factor,
                          opt_steps=opt_steps, blur=blur,
                          lr=lr, print_losses=print_losses)
    imgs = [FV.visualize(init_size, layer, fmap, **visualize_opts) for fmap in filters]

    return plot_reconstructions_single_layer(imgs, layer_name, filters,
                                             n_cols=n_cols, cell_size=cell_size,
                                             save_fig=save_fig, album_hash=album_hash)

In [18]:
# Stem convolution (child 0 of FV.model): reconstruct filters 6-11, upload the grid, render it inline.
url = reconstructions_single_layer(
    children(FV.model)[0],
    layer_name='Initial Conv',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [19]:
# Child 4 is the first residual stage (the Sequential of BasicBlocks printed above); [0] is its first block.
url = reconstructions_single_layer(
    children(FV.model)[4][0].conv1,
    layer_name='Layer 1 Block 1 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [20]:
# Second convolution of the first BasicBlock in the first residual stage.
url = reconstructions_single_layer(
    children(FV.model)[4][0].conv2,
    layer_name='Layer 1 Block 1 Conv2',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [21]:
# First convolution of the second BasicBlock in the first residual stage.
url = reconstructions_single_layer(
    children(FV.model)[4][1].conv1,
    layer_name='Layer 1 Block 2 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [23]:
# Child 5 is the stage labeled 'Layer 2' here; hook the second conv of its first block.
url = reconstructions_single_layer(
    children(FV.model)[5][0].conv2,
    layer_name='Layer 2 Block 1 Conv2',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [24]:
# Child 6 is the stage labeled 'Layer 3' here; hook the first conv of its first block.
url = reconstructions_single_layer(
    children(FV.model)[6][0].conv1,
    layer_name='Layer 3 Block 1 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [25]:
# 'Layer 3' stage, second block, first convolution.
url = reconstructions_single_layer(
    children(FV.model)[6][1].conv1,
    layer_name='Layer 3 Block 2 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [26]:
# 'Layer 3' stage, third block, first convolution.
url = reconstructions_single_layer(
    children(FV.model)[6][2].conv1,
    layer_name='Layer 3 Block 3 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [27]:
# 'Layer 3' stage, fourth block, first convolution.
url = reconstructions_single_layer(
    children(FV.model)[6][3].conv1,
    layer_name='Layer 3 Block 4 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [28]:
# 'Layer 3' stage, fifth block, first convolution.
url = reconstructions_single_layer(
    children(FV.model)[6][4].conv1,
    layer_name='Layer 3 Block 5 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [29]:
# 'Layer 3' stage, sixth block, first convolution.
url = reconstructions_single_layer(
    children(FV.model)[6][5].conv1,
    layer_name='Layer 3 Block 6 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [30]:
# Child 7 is the stage labeled 'Layer 4' here; hook the first conv of its first block.
url = reconstructions_single_layer(
    children(FV.model)[7][0].conv1,
    layer_name='Layer 4 Block 1 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [31]:
# 'Layer 4' stage, second block, first convolution.
url = reconstructions_single_layer(
    children(FV.model)[7][1].conv1,
    layer_name='Layer 4 Block 2 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

In [32]:
# 'Layer 4' stage, third block, first convolution.
url = reconstructions_single_layer(
    children(FV.model)[7][2].conv1,
    layer_name='Layer 4 Block 3 Conv1',
    filters=list(range(6, 12)),
    n_cols=3,
    save_fig=True,
    album_hash='OhQ94rP',
)
display(HTML(f"<img src={url} />"))

More reconstructions can be found in the [Imgur album](https://imgur.com/a/OhQ94rP).