In [1]:
import argparse
import os.path as osp
import os
from collections import defaultdict

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel
from mmcv.runner import init_dist, load_checkpoint

from mmcls.datasets import build_dataloader, build_dataset
from mmcls.models import build_classifier
from mmcv import tensor2imgs
from mmcv.cnn.bricks import ContextBlock
import matplotlib.pyplot as plt
import numpy as np


In [2]:
# mmcls config for ResNet-50 with GC blocks, and the weights to evaluate with it.
config_file, checkpoint_file = (
    'configs/imagenet/resnet50_gc-r4_batch256.py',
    'resnet50_gc-r4.pth',
)


In [None]:
# block name -> most recent channel-add term recorded by a hook for that block
hidden_outputs = dict()


def activation_hook(name):
    """Build a forward hook that records a ContextBlock's channel-add term.

    The returned hook re-runs ``spatial_pool`` and ``channel_add_conv`` on the
    block's first positional input and stores the result, with its two
    trailing dimensions squeezed, in ``hidden_outputs[name]`` (overwriting any
    previous entry). The block's own forward output is ignored.

    Args:
        name: Key under which the hook stores its result.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook signature.
    """

    def _record(block, inputs, _output):
        # presumably (N, C, 1, 1) after pooling -- TODO confirm for non-default pooling
        pooled = block.spatial_pool(inputs[0])
        assert block.channel_add_conv is not None
        add_term = block.channel_add_conv(pooled)
        hidden_outputs[name] = add_term.squeeze(-1).squeeze(-1)

    return _record


def register_activation_hook(model):
    """Attach an ``activation_hook`` to every ContextBlock under ``model.module``.

    Each hook is keyed by the block's qualified name from ``named_modules()``,
    and a confirmation line is printed per registration. Hook handles are not
    kept, so calling this again registers a second set of hooks.
    """
    context_blocks = (
        (qualified_name, submodule)
        for qualified_name, submodule in model.module.named_modules()
        if isinstance(submodule, ContextBlock)
    )
    for qualified_name, submodule in context_blocks:
        submodule.register_forward_hook(activation_hook(qualified_name))
        print(f'{qualified_name} is registered')



activations = {}  # block name -> per-class accumulator, filled by single_gpu_vis


def single_gpu_vis(model,
                   data_loader,
                   show=False,
                   out_dir=None,
                   num_classes=1000):
    """Accumulate per-class GC channel-add activations over a dataset.

    Registers a forward hook on every ContextBlock, runs ``model`` over every
    batch of ``data_loader`` and, for each hooked block, scatter-adds each
    sample's hook output into row ``gt_label`` of the module-level
    ``activations[block_name]`` tensor of shape ``(num_classes, C)``.

    The stored values are per-class SUMS, not means. The original code divided
    a zeros tensor by the dataset size, which never normalised anything.

    Args:
        model: Wrapped classifier exposing ``.module`` (e.g. MMDataParallel).
        data_loader: Loader yielding dicts with ``'img'`` and ``'gt_label'``.
        show: Unused; kept for interface compatibility.
        out_dir: Unused; kept for interface compatibility.
        num_classes: Rows in each accumulator. Every ``gt_label`` must be
            smaller than this (``scatter_add_`` requires in-range indices).
            Defaults to 1000.
    """
    model.eval()
    # Hook handles are not removed, so re-running this registers hooks again.
    register_activation_hook(model)

    dataset_length = len(data_loader.dataset)
    prog_bar = mmcv.ProgressBar(dataset_length)
    for data in data_loader:
        batch_size = data['img'].size(0)
        with torch.no_grad():
            # The return value is discarded; only the hooks' side effect
            # (filling hidden_outputs) is used.
            model(return_loss=True, **data)

        gt_label = data['gt_label'].cuda()

        for name, recorded in hidden_outputs.items():
            hidden_output = recorded.view(batch_size, -1)
            if name not in activations:
                activations[name] = hidden_output.new_zeros(
                    num_classes, hidden_output.shape[-1])
            # Row gt_label[i] += hidden_output[i] for each sample i.
            activations[name].scatter_add_(
                0, gt_label.unsqueeze(1).expand_as(hidden_output),
                hidden_output)

        # Drop this batch's hook outputs before the next forward pass.
        hidden_outputs.clear()

        for _ in range(batch_size):
            prog_bar.update()

In [None]:
cfg = Config.fromfile(config_file)
# Enable cuDNN autotuning only when the config asks for it.
if cfg.get('cudnn_benchmark', False):
    torch.backends.cudnn.benchmark = True
# Weights are loaded from checkpoint_file below, so no separate pretrained init.
cfg.model.pretrained = None
cfg.data.test.test_mode = True

# This notebook only supports single-GPU, non-distributed inference.
distributed = False

# Build the test dataloader: fixed order, no padding of the final batch.
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=cfg.data.samples_per_gpu,
    workers_per_gpu=cfg.data.workers_per_gpu,
    dist=distributed,
    shuffle=False,
    round_up=False)


# Build the model and load checkpoint weights onto the CPU first.
model = build_classifier(cfg.model)
checkpoint = load_checkpoint(model, checkpoint_file, map_location='cpu')
# Old checkpoints may not store class names, or any 'meta' entry at all.
# In that case, fall back to the dataset's CLASSES.
meta = checkpoint.get('meta') or {}
if 'CLASSES' in meta:
    model.CLASSES = meta['CLASSES']
else:
    model.CLASSES = dataset.CLASSES

assert not distributed
model = MMDataParallel(model, device_ids=[0])


In [None]:
# Fill the global `activations` dict with per-class channel-add sums.
single_gpu_vis(model, data_loader, show=True)

In [19]:
out_dir = 'output/r50_gc-r4_c3'
vis_dir = osp.join(out_dir, 'vis_act')
# NOTE: this replaces the in-memory `activations` with a cached copy; the .pkl
# is not written by the cells shown here -- TODO confirm how it is produced.
# It is unpickled, so only load files you created yourself.
activations = mmcv.load(osp.join(out_dir, 'activations.pkl'))
mapping = mmcv.load('imagenet_class_index.json')
vis_indices = [1, 254, 726, 972]  # class rows drawn as individual curves
num_samples = 32  # roughly how many channel points each curve shows
mmcv.mkdir_or_exist(vis_dir)

for name in activations:
    act = activations[name]
    num_channels = act.shape[1]
    # Subsample channels to ~num_samples points; a step of 0 would make [::step] raise.
    step = max(1, num_channels // num_samples)
    x = np.arange(num_channels)[::step]

    fig, ax = plt.subplots(figsize=(10, 5))
    for vis_index in vis_indices:
        ax.plot(x, act[vis_index][::step], label=mapping[str(vis_index)][1])
    # Reference curve: mean over all rows (classes).
    ax.plot(x, act.mean(0)[::step], label='all')

    ax.legend(ncol=1, loc='upper right',
              columnspacing=2.0, labelspacing=1,
              handletextpad=0.5, handlelength=1.5,
              fancybox=True, shadow=True, fontsize=15)
    # Title: second name component's trailing digit + 1, then the remaining
    # components (presumably 'backbone.layerK.<...>' -> 'c{K+1}.<...>').
    name_splits = name.split('.')
    stage = int(name_splits[1][-1]) + 1
    ax.set_title(f'c{stage}.' + '.'.join(name_splits[2:]), fontsize=25)
    ax.set_ylabel('Context Amplitude', fontsize=20, labelpad=15)
    ax.set_xlabel('Channel Index', fontsize=20, labelpad=15)
    fig.savefig(osp.join(vis_dir, f'{name}.png'), bbox_inches='tight')
    # Close (not clf) so figures don't pile up or echo as empty outputs.
    plt.close(fig)

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>

<Figure size 720x360 with 0 Axes>