In [1]:
import math

from PIL import Image
import requests
import matplotlib.pyplot as plt

import ipywidgets as widgets
from IPython.display import display, clear_output

import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T

In [2]:
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
# Switch to inference behavior (e.g. dropout disabled) so the attention maps captured
# later are deterministic; the hub entrypoint is not assumed to do this for us.
# Trailing ';' suppresses the (very long) module repr returned by eval().
model.eval();

Using cache found in C:\Users\JungSungYeon/.cache\torch\hub\facebookresearch_detr_main


In [3]:
backbone_module = model.backbone[-2]  # the module the conv-feature hook is attached to
backbone_module

Backbone(
  (body): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): FrozenBatchNorm2d()
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d()
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d()
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): FrozenBatchNorm2d()
        )
      )
      (1): Bottleneck(
        (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  

In [4]:
class attention_detr(nn.Module):
    """Capture DETR intermediate activations with temporary forward hooks.

    Each element of the input is run through the wrapped model separately, and
    for every forward pass this records:

    * the output of ``model.backbone[-2]`` ("conv features"),
    * ``output[1]`` of the last encoder layer's ``self_attn`` module,
    * ``output[1]`` of the last decoder layer's ``multihead_attn`` module.

    The wrapped model must expose an indexable ``backbone`` plus
    ``transformer.encoder.layers`` and ``transformer.decoder.layers``.

    Args:
        model: DETR-style model whose activations are captured.
        batch_size: number of leading input elements to process per call;
            ``None`` (default) processes all of them.
    """

    def __init__(self, model, batch_size=None):
        super().__init__()
        self.model = model
        self.batch_size = batch_size

    def attention_score(self, x_list, batch_size=None):
        """Run the model on ``x_list[0:batch_size]`` and return hooked activations.

        Args:
            x_list: indexable sequence of model inputs; ``x_list[i]`` is passed
                to the model as one forward pass.
            batch_size: how many leading elements to process; ``None`` means
                ``len(x_list)``.

        Returns:
            ``(conv_features, enc_attn_weights, dec_attn_weights)``:
            ``conv_features`` and ``dec_attn_weights`` are lists with one entry
            per forward pass; ``enc_attn_weights`` is the per-pass encoder
            weights stacked along a new leading dimension.

        Raises:
            ValueError: if ``batch_size`` is not in ``1..len(x_list)``.
        """
        if batch_size is None:
            batch_size = len(x_list)
        # Validate before attaching hooks so a bad size cannot leak them.
        if not 0 < batch_size <= len(x_list):
            raise ValueError(
                f"batch_size must be in 1..{len(x_list)}, got {batch_size}"
            )

        conv_features, enc_attn_weights, dec_attn_weights = [], [], []

        # Forward hooks are called as hook(module, inputs, output). The hooked
        # attention modules are expected to return (attn_output, attn_weights),
        # so output[1] is taken as the weight tensor.
        def save_conv(module, inputs, output):
            conv_features.append(output)

        def save_enc(module, inputs, output):
            enc_attn_weights.append(output[1])

        def save_dec(module, inputs, output):
            dec_attn_weights.append(output[1])

        # Hook self.model (not a notebook global) so the wrapper honours the
        # model it was constructed with.
        hooks = [
            self.model.backbone[-2].register_forward_hook(save_conv),
            self.model.transformer.encoder.layers[-1].self_attn.register_forward_hook(save_enc),
            self.model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(save_dec),
        ]
        try:
            for i in range(batch_size):
                self.model(x_list[i])
        finally:
            # Always detach: a failed forward must not leave hooks registered,
            # or every later call would record duplicate activations.
            for hook in hooks:
                hook.remove()

        # Same result as unsqueeze(0) + repeated cat, but a single O(n) call.
        enc_attn_weights = torch.stack(enc_attn_weights)

        return conv_features, enc_attn_weights, dec_attn_weights

    def forward(self, x):
        """Equivalent to ``self.attention_score(x, self.batch_size)``."""
        return self.attention_score(x, self.batch_size)


In [5]:
# Wrapper that runs 4 separate forward passes per call and collects hooked activations.
attention = attention_detr(model, batch_size=4)

In [30]:
torch.manual_seed(0)  # reproducible dummy input
# Dummy input: 4 entries, each a (1, 3, 1024, 1024) tensor fed as one forward pass.
x = torch.randn((4, 1, 3, 1024, 1024))
# Inference only: without no_grad every hooked activation keeps its autograd
# graph alive across all 4 passes, which needlessly inflates memory.
with torch.no_grad():
    output = attention(x)

In [31]:
enc_attn_weights = output[1]  # stacked encoder self-attention weights, one per forward pass
enc_attn_weights.size()

torch.Size([4, 1, 1024, 1024])

In [11]:
# Swap axes 1 and 2: (b, f, c, h, w) -> (b, c, f, h, w).
# NOTE(review): this rebinds `x`, so re-running the cell permutes again; the
# shown output also predates the current `x` (cell ran before In [30]).
x = x.permute(0, 2, 1, 3, 4)
x.shape

torch.Size([4, 3, 16, 400, 400])