In [1]:
import argparse
import os.path as osp
import os

import math

from PIL import Image
import requests

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import init_dist, load_checkpoint
from mmcv.parallel import MMDataParallel, collate, scatter

from mmdet.datasets import build_dataloader, build_dataset
from mmdet.models import build_detector
from mmdet.core import tensor2imgs
import matplotlib.pyplot as plt
import numpy as np

%config InlineBackend.figure_format = 'retina'

import ipywidgets as widgets
from IPython.display import display, clear_output


In [None]:
class AttentionVisualizer:
    """Interactive ipywidgets viewer for the attention weights of a NonLocal block.

    Widgets (kept in ``self.sliders`` in this order):

    * a URL text box selecting a dataset image (looked up in ``url2idx``),
    * X / Y sliders giving the reference point in relative coordinates,
    * a checkbox choosing which axis pair of the attention tensor is indexed
      by the reference point (checked: the first pair, i.e. positions coming
      from ``theta_x``; unchecked: the last pair, from ``phi_x``),
    * a checkbox toggling the red dot on the attention map,
    * a "save" button writing the last overlay image to ``nl_vis/``.

    Args:
        model: detector called as ``model(return_loss=False, rescale=True, **data)``.
            NonLocal modules are searched in ``model.module`` when it exists
            (e.g. ``MMDataParallel``), otherwise in ``model`` itself.
        dataset: indexable dataset whose items are passed through ``collate``.
        url2idx: mapping from image URL to dataset index.
    """

    # Output directory used by the "save" button.
    _SAVE_DIR = 'nl_vis'

    def __init__(self, model, dataset, url2idx):
        self.model = model
        self.dataset = dataset
        self.url2idx = url2idx

        # URL whose forward pass is currently cached in `self.data` / `self.attn_weights`.
        self.url = ""
        self.cur_url = None
        self.pil_img = None
        self.tensor_img = None
        self.data = None

        # (H, W, H, W) attention weights of the first hooked NonLocal block.
        self.attn_weights = None
        # Everything the "save" button needs to redraw the last overlay.
        self._last_render = None

        self.setup_widgets()

    def setup_widgets(self):
        """Create the input widgets, the output area and the save-button handler."""
        self.sliders = [
            widgets.Text(
                value='http://images.cocodataset.org/val2017/000000084031.jpg',
                placeholder='Type something',
                description='URL (ENTER):',
                disabled=False,
                continuous_update=False,
                layout=widgets.Layout(width='100%')
            ),
            widgets.FloatSlider(
                min=0, max=0.99,
                step=0.02, description='X coordinate', value=0.72,
                continuous_update=False,
                layout=widgets.Layout(width='50%')
            ),
            widgets.FloatSlider(
                min=0, max=0.99,
                step=0.02, description='Y coordinate', value=0.40,
                continuous_update=False,
                layout=widgets.Layout(width='50%')
            ),
            widgets.Checkbox(
                value=True,
                description='Direction of self attention',
                disabled=False,
                indent=False,
                layout=widgets.Layout(width='50%'),
            ),
            widgets.Checkbox(
                value=True,
                description='Show red dot in attention',
                disabled=False,
                indent=False,
                layout=widgets.Layout(width='50%'),
            ),
            widgets.Button(
                description='save',
                disabled=False,
                button_style='',  # 'success', 'info', 'warning', 'danger' or ''
                tooltip='save',
                icon='check')
        ]
        self.o = widgets.Output()
        # Registered once here rather than on every redraw (which needed the
        # private `_click_handlers`); the handler reads `self._last_render`.
        self.sliders[-1].on_click(self._on_save_clicked)

    def compute_features(self, data):
        """Run one forward pass and cache the attention of the first NonLocal block.

        A forward hook is attached to every module whose class name contains
        ``'NonLocal'``; it recomputes the block's pairwise weights from the
        block input. The hooks are always removed afterwards (also on error),
        so repeated calls do not stack hooks on the model.

        Sets ``self.attn_weights`` to a numpy array of shape (H, W, H, W) for
        the first sample, where H, W are the spatial size of the block input.
        The first two axes index positions of ``theta_x``, the last two
        positions of ``phi_x``.

        Raises:
            RuntimeError: if no NonLocal hook produced an output.
        """
        model = self.model
        hidden_outputs = {}

        def attn_mask_hook(name):
            def hook(module, inputs, output):
                x = inputs[0]
                # Shape comments assume `reduction = 1`, i.e. `inter_channels = C`
                # (in "gaussian" mode theta/phi use `in_channels = C` directly).

                # NonLocal1d x: [N, C, H]
                # NonLocal2d x: [N, C, H, W]
                # NonLocal3d x: [N, C, T, H, W]
                n = x.size(0)

                # NonLocal1d g_x: [N, H, C]
                # NonLocal2d g_x: [N, HxW, C]
                # NonLocal3d g_x: [N, TxHxW, C]
                g_x = module.g(x).view(n, module.inter_channels, -1)
                g_x = g_x.permute(0, 2, 1)

                # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
                # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
                # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
                if module.mode == 'gaussian':
                    theta_x = x.view(n, module.in_channels, -1)
                    theta_x = theta_x.permute(0, 2, 1)
                    if module.sub_sample:
                        phi_x = module.phi(x).view(n, module.in_channels, -1)
                    else:
                        phi_x = x.view(n, module.in_channels, -1)
                elif module.mode == 'concatenation':
                    theta_x = module.theta(x).view(n, module.inter_channels, -1, 1)
                    phi_x = module.phi(x).view(n, module.inter_channels, 1, -1)
                else:
                    theta_x = module.theta(x).view(n, module.inter_channels, -1)
                    theta_x = theta_x.permute(0, 2, 1)
                    phi_x = module.phi(x).view(n, module.inter_channels, -1)

                pairwise_func = getattr(module, module.mode)
                # NonLocal1d pairwise_weight: [N, H, H]
                # NonLocal2d pairwise_weight: [N, HxW, HxW]
                # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
                pairwise_weight = pairwise_func(theta_x, phi_x)

                # Only valid for 2D input without sub-sampling: both the
                # theta and phi positions must form an H x W grid.
                hidden_outputs[name] = pairwise_weight.view(
                    n, x.size(2), x.size(3), x.size(2), x.size(3))

            return hook

        net = getattr(model, 'module', model)
        hooks = []
        try:
            for module_name, module in net.named_modules():
                if 'NonLocal' in str(module.__class__):
                    hooks.append(
                        module.register_forward_hook(attn_mask_hook(module_name)))
                    print(f'{module_name} is registered')
            # propagate through the model
            with torch.no_grad():
                model(return_loss=False, rescale=True, **data)
        finally:
            for handle in hooks:
                handle.remove()

        if not hidden_outputs:
            raise RuntimeError('no NonLocal module produced attention weights')
        first = next(iter(hidden_outputs.values()))
        self.attn_weights = first.detach().cpu().numpy()[0]

    def compute_on_image(self, url):
        """Load the dataset image for ``url`` and compute its attention weights.

        Nothing is recomputed when ``url`` equals the last successfully
        processed URL.

        Raises:
            KeyError: if ``url`` is not a key of ``url2idx``.
        """
        if url == self.url:
            return
        data = self.dataset[self.url2idx[url]]
        data = collate([data], samples_per_gpu=1)
        device = next(self.model.parameters()).device  # model device
        # scatter to the model's device
        data = scatter(data, [device])[0]
        self.compute_features(data)
        # Mark the URL as cached only after success, so a failure never leaves
        # stale data attached to the new URL.
        self.data = data
        self.url = url

    def _prepare_image(self):
        """Return ``(img_show, img_meta)`` for the cached ``self.data``.

        ``img_show`` is the de-normalized model input with the area beyond
        ``img_shape`` cropped off, converted BGR -> RGB and resized to the
        original ``ori_shape``.
        """
        img_tensor = self.data['img'][0]
        img_meta = self.data['img_metas'][0][0]
        img = tensor2imgs(img_tensor, **img_meta['img_norm_cfg'])[0]
        h, w, _ = img_meta['img_shape']
        img_show = mmcv.bgr2rgb(img[:h, :w, :])
        ori_h, ori_w = img_meta['ori_shape'][:-1]
        return mmcv.imresize(img_show, (ori_w, ori_h)), img_meta

    @staticmethod
    def _suppress_low_attention(attn_map, percentile=80):
        """Return a NEW array where values below ``percentile`` are set to the minimum.

        The input is never modified. This matters because callers pass views
        into the cached attention tensor.
        """
        attn_map = np.asarray(attn_map)
        threshold = np.percentile(attn_map, percentile)
        return np.where(attn_map < threshold, attn_map.min(), attn_map)

    @staticmethod
    def _resize_to_original(attn_map, fact, h, w, ori_h, ori_w):
        """Map a feature-level attention map onto the original image grid.

        Assumes the feature map covers the padded model input at stride
        ``fact``. It is upsampled by ``fact``, cropped to the unpadded
        ``(h, w)`` region, then resized to ``(ori_h, ori_w)``.
        """
        feat_h, feat_w = attn_map.shape
        attn_map = mmcv.imresize(attn_map, (feat_w * fact, feat_h * fact))
        attn_map = np.ascontiguousarray(attn_map[:h, :w])
        return mmcv.imresize(attn_map, (ori_w, ori_h))

    @staticmethod
    def _draw_overlay(ax, img_show, attn_map, px, py):
        """Draw ``img_show`` with ``attn_map`` blended on top and a marker at (px, py)."""
        ax.imshow(img_show)
        ax.imshow(attn_map, cmap='jet', interpolation='bilinear', alpha=0.3)
        ax.plot([px], [py], marker='s', color='r')

    def update_chart(self, change):
        """Re-render the figures from the current widget values.

        Used as the ``observe`` callback of the input widgets; ``change`` is ignored.
        """
        with self.o:
            clear_output()

            # j and i are the relative x and y coordinates of the reference point
            # sattn_dir selects which axis pair of the attention tensor is indexed
            # sattn_dot displays a red dot or not in the self-attention map
            url, j, i, sattn_dir, sattn_dot = [s.value for s in self.sliders[:-1]]

            if url not in self.url2idx:
                print(f'URL not found in the dataset: {url}')
                return
            self.compute_on_image(url)
            img_show, img_meta = self._prepare_image()
            h, w, _ = img_meta['img_shape']
            ori_h, ori_w = img_meta['ori_shape'][:-1]

            # convert reference point to absolute model-input coordinates
            j = int(j * w)
            i = int(i * h)

            # factors mapping model-input coordinates back to the original image
            scale_x = ori_w / w
            scale_y = ori_h / h

            # downsampling factor between model input and attention feature map,
            # rounded to a power of two
            sattn = self.attn_weights
            feat_h, feat_w = sattn.shape[-2:]
            fact = 2 ** round(math.log2(w / feat_w))

            # feature cell containing the reference point (clamped to the map)
            row = min(i // fact, feat_h - 1)
            col = min(j // fact, feat_w - 1)
            # centre of that cell, in original-image pixels
            px = (col + 0.5) * fact * scale_x
            py = (row + 0.5) * fact * scale_y
            radius = fact * scale_y / 2

            if sattn_dir:
                sattn_map = sattn[row, col, ...]
            else:
                sattn_map = sattn[..., row, col]
            # keep only the strongest 20% (returns a copy; the cache stays intact)
            sattn_map = self._suppress_low_attention(sattn_map)
            sattn_map = self._resize_to_original(sattn_map, fact, h, w, ori_h, ori_w)

            fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(9, 4))
            axs[0].imshow(img_show)
            axs[0].axis('off')
            axs[0].add_patch(plt.Circle((px, py), radius, color='r'))
            axs[1].imshow(sattn_map, cmap='jet', interpolation='bilinear')
            if sattn_dot:
                axs[1].add_patch(plt.Circle((px, py), radius, color='r'))
            axs[1].axis('off')
            axs[1].set_title(f'self-attention{(i, j)}')
            plt.show()

            # overlay rendered at the original image resolution
            dpi = plt.rcParams['figure.dpi']
            fig, ax = plt.subplots(figsize=(ori_w / dpi, ori_h / dpi))
            self._draw_overlay(ax, img_show, sattn_map, px, py)
            ax.set_title(img_meta['ori_filename'])
            plt.show()

            self._last_render = dict(
                img_show=img_show, attn_map=sattn_map, px=px, py=py,
                i=i, j=j, ori_w=ori_w, ori_h=ori_h,
                ori_filename=img_meta['ori_filename'])

    def _on_save_clicked(self, _button):
        """Save the most recently rendered overlay to ``nl_vis/<img>_<j>-<i><ext>``."""
        with self.o:
            render = self._last_render
            if render is None:
                print('Nothing has been rendered yet.')
                return
            dpi = plt.rcParams['figure.dpi']
            fig, ax = plt.subplots(
                figsize=(render['ori_w'] / dpi, render['ori_h'] / dpi))
            try:
                self._draw_overlay(ax, render['img_show'], render['attn_map'],
                                   render['px'], render['py'])
                mmcv.mkdir_or_exist(self._SAVE_DIR)
                img_id, ext = osp.splitext(osp.basename(render['ori_filename']))
                filename = osp.join(
                    self._SAVE_DIR, f"{img_id}_{render['j']}-{render['i']}{ext}")
                print(f"saving {filename}")
                fig.savefig(filename)
            finally:
                # never displayed; close so it doesn't pop up on the next plt.show()
                plt.close(fig)

    def run(self):
        """Attach observers, render once and return the assembled widget layout."""
        url, x, y, d, sattn_d, clicked = self.sliders
        # the button has no `value` trait; it is wired up in setup_widgets
        for widget in (url, x, y, d, sattn_d):
            widget.observe(self.update_chart, 'value')
        self.update_chart(None)
        return widgets.VBox([
            url,
            widgets.HBox([x, y]),
            widgets.HBox([d, sattn_d]),
            clicked,
            self.o,
        ])

In [2]:
# Config and weights of the same released model; paths are relative to the
# working directory (presumably the mmdetection repository root).
_model_stem = 'mask_rcnn_r50_fpn_nl_c4.4_1x_coco'
config_file = f'configs/nonlocal/{_model_stem}.py'
checkpoint_file = f'checkpoints/{_model_stem}-8cea600c.pth'


In [4]:
cfg = Config.fromfile(config_file)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
    torch.backends.cudnn.benchmark = True
# weights come from `checkpoint_file`, so disable any pretrained initialisation
cfg.model.pretrained = None
if cfg.model.get('neck'):
    if cfg.model.neck.get('rfp_backbone'):
        if cfg.model.neck.rfp_backbone.get('pretrained'):
            cfg.model.neck.rfp_backbone.pretrained = None
cfg.data.test.test_mode = True

# build the test dataset; the visualizer feeds it one image at a time
dataset = build_dataset(cfg.data.test)
# image URL -> dataset index, used by the visualizer's URL text box
url2idx = {info['coco_url']: idx for idx, info in enumerate(dataset.data_infos)}

# build the model and load checkpoint
# `cfg.get` also works for configs that keep test_cfg inside cfg.model
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
model = MMDataParallel(model, device_ids=[0])
checkpoint = load_checkpoint(model, checkpoint_file, map_location='cpu')
# inference: a freshly built nn.Module is in training mode
model.eval()
w = AttentionVisualizer(model, dataset, url2idx)
w.run()

loading annotations into memory...
Done (t=0.53s)
creating index...
index created!
