In [80]:
import torch
import torch.nn as nn
from copy import deepcopy
from collections import OrderedDict

class XaiBase(nn.Module):
    def __init__(self):
        super(XaiBase, self).__init__()
        """
        need to define hook function at each method
        - f_hook
        - b_hook
        """
        self._reset_maps()
        self.handlers = list()
    
    def _reset_maps(self):
        self.maps = OrderedDict()
        
    def _save_maps(self, layer_name, x):
        self.maps[layer_name] = x    
    
    def _reset_handlers(self):
        for handle in self.handlers:
            handle.remove()
        self.handlers = []
                
    def _register(self, layers, registor_type="both"):
        """
        need to define hook functions to use
         - both: registor both forward and backward
         - f: registor forward
         - b: registor backward
        
        def f_hook(self, *x):
            '''
            m: module name
            i: forward input
            o: forward output
            '''
            m, i, o = x
        
        def b_hook(self, *x):
            '''
            m: module name
            i: gradient input
            o: gradient output
            '''
            m, i, o = x
        """
        for layer in layers:
            if registor_type == "both":
                handle1 = layer.register_forward_hook(self.f_hook)
                handle2 = layer.register_backward_hook(self.b_hook)
                self.handlers.append(handle1)
                self.handlers.append(handle2)
            elif registor_type == "f":
                handle1 = layer.register_forward_hook(self.f_hook)
                self.handlers.append(handle1)
            elif registor_type == "b":
                handle2 = layer.register_backward_hook(self.b_hook)
                self.handlers.append(handle2)

    def _return_indices(self, layers, on=True):
        """
        support for cnn layer which have `nn.MaxPool2d`,
        you can turn on/off pooling indices.
        please define a forward function to use it in your model
        '''
        # in your model
        def forward_switch(self, x):
            switches = OrderedDict()
            self.return_indices(on=True)
            for idx, layer in enumerate(self.convs):
                if isinstance(layer, nn.MaxPool2d):
                    x, indices = layer(x)
                    switches[idx] = indices
                else:
                    x = layer(x)
            self.return_indices(on=False)
            return x, switches
        '''
        """
        if on:
            for layer in layers:
                if isinstance(layer, nn.MaxPool2d):
                    layer.return_indices = True
        else:
            for layer in layers:
                if isinstance(layer, nn.MaxPool2d):
                    layer.return_indices = False  
                    
                    
class XaiModel(XaiBase):
    def __init__(self, model):
        super(XaiModel, self).__init__()
        self.model = deepcopy(model)

In [64]:
class ChannelAttention(nn.Module):
    """Channel Attention Module"""
    def __init__(self, C, H, W, ratio):
        """
        Method in [arXiv:1807.06521]
        args:
         - C: channel of input features
         - H: height of input features
         - W: width of input features
         - hid_size: hidden size of shallow network
         - ratio: reduction ratio
        """
        super(ChannelAttention, self).__init__()
        assert isinstance(2*C // ratio, int), "`2*C // ratio` must be int "
        kernel_size = (H, W)
        self.maxpool = nn.MaxPool2d(kernel_size)
        self.avgpool = nn.AvgPool2d(kernel_size)
        self.shallow_net = nn.Sequential(
            nn.Linear(2*C, 2*C // ratio),
            nn.ReLU(),
            nn.Linear(2*C // ratio, 2*C),
        )
    
    def forward(self, x):
        # (B, C, H, W) > (B, 2*C, 1, 1)
        x = torch.cat([self.maxpool(x), self.avgpool(x)], dim=1)
        # (B, 2*C) > (B, 2*C//2) > (B, 2*C)
        x = self.shallow_net(x.squeeze(-1).squeeze(-1))
        # (B, C), (B, C)
        x_max, x_avg = torch.chunk(x, 2, dim=1)
        # not using softmax in paper: something like gate function
        x = torch.sigmoid(x_max + x_avg)
        return x.unsqueeze(-1).unsqueeze(-1)

In [65]:
class SpatialAttention(nn.Module):
    """Spatial Attention Module"""
    def __init__(self, H, W, K_H=7, K_W=7, S_H=1, S_W=1):
        """
        Method in [arXiv:1807.06521]
        args:
         - H: height of input features
         - W: width of input features
         - K_H: height of kernel size
         - K_W: width of kernel size
         - S_H: stride height of conv layer
         - S_W: stride width of conv layer
        """
        super(SpatialAttention, self).__init__()
        P_H = self.cal_padding_size(H, K_H, S_H)
        P_W = self.cal_padding_size(W, K_W, S_W)
        kernel_size = (K_H, K_W)
        stride = (S_H, S_W)
        padding = (P_H, P_W)
        # same padding conv layer
        self.conv_layer = nn.Conv2d(2, 1, kernel_size, stride, padding)
    
    def cal_padding_size(self, x, K, S):
        return int((S * (x-1) + K - x) / 2)
    
    def forward(self, x):
        # (B, C, H, W) > (B, 1, H, W)
        x_max, _ = torch.max(x, dim=1, keepdim=True)
        x_avg = torch.mean(x, dim=1, keepdim=True)
        # (B, 2, H, W)
        x = torch.cat([x_max, x_avg], dim=1)
        # (B, 2, H, W) > (B, 1, H, W)
        x = self.conv_layer(x)
        # return gated features
        return torch.sigmoid(x)

In [66]:
class CBAM(nn.Module):
    """Convolution Block Attention Module"""
    def __init__(self, C, H, W, ratio, K_H=7, K_W=7, S_H=1, S_W=1):
        """
        Method in [arXiv:1807.06521]
        args:
         - C: channel of input features
         - H: height of input features
         - W: width of input features
         - ratio: reduction ratio
         - K_H: height of kernel size
         - K_W: width of kernel size
         - S_H: stride height of conv layer
         - S_W: stride width of conv layer
         
        return:
         - attentioned features, size = (B, C, H, W)
        """
        super(CBAM, self).__init__()
        self.channel_attn = ChannelAttention(C, H, W, ratio)
        self.spatial_attn = SpatialAttention(H, W, K_H, K_W, S_H, S_W)
        
    def forward(self, x, return_attn=False):
        """
        return: attentioned features, size = (B, C, H, W)
        """
        out = x
        c_attn = self.channel_attn(out)
        out = c_attn * out
        s_attn = self.spatial_attn(out)
        out = s_attn * out
        if return_attn:
            return out, (c_attn, s_attn)
        return out

In [75]:
import torch
import torch.nn as nn
from collections import OrderedDict

class CnnWithCBAM(XaiBase):
    def __init__(self, activation_type):
        """
        activation_type: "relu", "tanh", "sigmoid", "softplus"
        """
        super(CnnWithCBAM, self).__init__()
        act = {"relu": nn.ReLU, 
               "tanh": nn.Tanh, 
               "sigmoid": nn.Sigmoid, 
               "softplus": nn.Softplus}
        self.activation_func = act[activation_type]
        
        self.convs = nn.Sequential(
            nn.Conv2d(1, 32, 5),  # (B, 1, 32, 32) > (B, 32, 28, 28)
            CBAM(32, 28, 28, 16),
            self.activation_func(),
            nn.MaxPool2d(2),  # (B, 32, 28, 28) > (B, 32, 14, 14)
            nn.Conv2d(32, 64, 3),  # (B, 32, 14, 14) > (B, 64, 12, 12)
            CBAM(64, 12, 12, 16),
            self.activation_func(), 
            nn.MaxPool2d(2),  # (B, 64, 12, 12) > (B, 64, 6, 6)
        )
        self.fc = nn.Sequential(
            nn.Linear(64*6*6, 128),
            self.activation_func(),
            nn.Linear(128, 10)
        )
        
    def forward(self, x):        
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
    
    def forward_map(self, x):
        self._reset_maps()
        for i, layer in enumerate(self.convs):
            layer_name = type(layer).__name__
            if layer_name == "CBAM":
                x, attns = layer(x, return_attn=True)
                self._save_maps(f"{i}"+layer_name, attns)
            else:
                x = layer(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [76]:
model = CnnWithCBAM("relu")

In [81]:
class GradCAM(XaiModel):
    def __init__(self, model):
        super(GradCAM, self).__init__(model)
    
    def f_hook(self, x):
        m, i, o = x
        return m
    
    def b_hook(self, x):
        m, i, o = x
        return m
    
    def get_attribution(self, x):
        for layer in self.model.convs:
            break

In [82]:
gradcam = GradCAM(model)

In [84]:
gradcam._register(gradcam.model.convs)

In [86]:
x = gradcam.model.convs(torch.rand(1, 1, 28, 28))

TypeError: f_hook() takes 2 positional arguments but 4 were given