In [21]:
import torch
import torch.nn as nn
from copy import deepcopy
from collections import OrderedDict

class XaiBase(nn.Module):
    def __init__(self, model):
        super(XaiBase, self).__init__()
        """
        need to define hook function at each method
        - f_hook
        - b_hook
        """
        self.model = deepcopy(model)
        self._reset_maps()
        self.handlers = list()
        
        
    def _reset_maps(self):
        self.maps = OrderedDict()
        
    def _reset_handlers(self):
        for handle in self.handlers:
            handle.remove()
        self.handlers = []
                
    def _register(self):
        """
        need to define hook functions to use
        
        def f_hook(self, *x):
            '''
            m: module name
            i: forward input
            o: forward output
            '''
            m, i, o = x
        
        def b_hook(self, *x):
            '''
            m: module name
            i: gradient input
            o: gradient output
            '''
            m, i, o = x
        """
        for layer in self.model.layers:
            handle1 = layer.register_forward_hook(self.f_hook)
            handle2 = layer.register_backward_hook(self.b_hook)
            self.handlers.append(handle1)
            self.handlers.append(handle2)    
                
    def save_maps(self, layer_name, x):
        self.maps[layer_name] = x

In [12]:
class Test(XaiBase):
    def __init__(self, model):
        super(Test, self).__init__(model)

In [19]:
class DeconvNet(XaiBase):
    def __init__(self, model):
        super(DeconvNet, self).__init__(model)
        
        self.deconvs = 
        
        
    def deconv_make_layers(self):
        self.deconv_module_len = 0
        layers = []
        conv_end = [i for i, l in enumerate(self.model.convs) if str(l) == "Reshape()"][0]
        # {0: 999, 3: 0}
        conv_bias_pos = {}
        conv_locs = [i for i, l in enumerate(model.layers[:conv_end]) if isinstance(l, nn.Conv2d)]
        for idx, i in enumerate(conv_locs):
            if idx == 0:
                conv_bias_pos[i] = 999
            else:
                conv_bias_pos[i] = conv_locs[idx-1]

        for idx, layer in enumerate(model.layers[:conv_end]):
            if isinstance(layer, nn.Conv2d):
                temp_layer = nn.ConvTranspose2d(layer.out_channels,
                                                layer.in_channels,
                                                layer.kernel_size, 
                                                layer.stride, 
                                                layer.padding,
                                                layer.output_padding,
                                                layer.groups, 
                                                False,  # bias
                                                layer.dilation,
                                                layer.padding_mode)
                temp_layer.weight.data = layer.weight.data
                if conv_bias_pos[idx] < 999:
                    temp_layer.bias = model.layers[:conv_end][conv_bias_pos[idx]].bias
                layers.append(temp_layer)
            elif isinstance(layer, nn.MaxPool2d):
                temp_layer = nn.MaxUnpool2d(layer.kernel_size,
                                            layer.stride,
                                            layer.padding)
                layers.append(temp_layer)
                self.deconv_module_len += 1
            else:
                layers.append(layer)
        layers = nn.Sequential(*reversed(layers))

SyntaxError: invalid syntax (<ipython-input-19-383e3b6bdfbd>, line 1)

In [None]:
class DeconvNet(nn.Module):
    """DeconvNet"""
    def __init__(self, model):
        super(DeconvNet, self).__init__()
        # deconv
        self.activation_func = model.activation_func
        self.model_type = model.model_type
        self.activation_type = model.activation_type
        
        self.layers = self.deconv_make_layers(model)
        
        self.activation_maps = OrderedDict()
        
    def deconv_make_layers(self, model):
        self.deconv_module_len = 0
        layers = []
        conv_end = [i for i, l in enumerate(model.layers) if str(l) == "Reshape()"][0]
        # {0: 999, 3: 0}
        conv_bias_pos = {}
        conv_locs = [i for i, l in enumerate(model.layers[:conv_end]) if isinstance(l, nn.Conv2d)]
        for idx, i in enumerate(conv_locs):
            if idx == 0:
                conv_bias_pos[i] = 999
            else:
                conv_bias_pos[i] = conv_locs[idx-1]

        for idx, layer in enumerate(model.layers[:conv_end]):
            if isinstance(layer, nn.Conv2d):
                temp_layer = nn.ConvTranspose2d(layer.out_channels,
                                                layer.in_channels,
                                                layer.kernel_size, 
                                                layer.stride, 
                                                layer.padding,
                                                layer.output_padding,
                                                layer.groups, 
                                                False,  # bias
                                                layer.dilation,
                                                layer.padding_mode)
                temp_layer.weight.data = layer.weight.data
                if conv_bias_pos[idx] < 999:
                    temp_layer.bias = model.layers[:conv_end][conv_bias_pos[idx]].bias
                layers.append(temp_layer)
            elif isinstance(layer, nn.MaxPool2d):
                temp_layer = nn.MaxUnpool2d(layer.kernel_size,
                                            layer.stride,
                                            layer.padding)
                layers.append(temp_layer)
                self.deconv_module_len += 1
            else:
                layers.append(layer)
        layers = nn.Sequential(*reversed(layers))
        
        deconv_locs = [i for i, l in enumerate(layers) if isinstance(l, nn.ConvTranspose2d)]
        self.conv_end = conv_end
        # {2: 2, 1: 5}
        self.deconv_locs = {(len(deconv_locs) - j):i for j, i in enumerate(deconv_locs)}
        return layers
    
    def save_activation_maps(self, layer, typ, idx, x):
        if isinstance(layer, typ):
            layer_name = f"({idx}) {str(layer).split('(')[0]}(in:{layer.in_channels}, out:{layer.out_channels})"
            self.activation_maps[layer_name] = x
    
    def deconv(self, x, switches, deconv_layer_num=None, store=False):
        """
        deconv_layer_num: 
            deconv from which module(m =  "MaxPool > activation > Conv2d") 
            numbering from the original cnn conv module(n = "Conv2d > activation > MaxPool")
            ex) deconv_layers = [m1, m2, m3, m4, m5]
                if deconv_layer_num = 4, will goes from m4 to m5
                * cnn_layer = [n5, n4, n3, n2, n1] (n.T = m, n5 is the first layer of cnn)
                
        x: should match module input size
        switches: from MNISTmodel forward method "forward_switches"
        store: if True, save activation maps
        """
        assert (deconv_layer_num <= self.deconv_module_len) or (deconv_layer_num==None), \
            "`deconv_layer_num` should <= `self.deconv_module_len` or == None"
        if deconv_layer_num == None: deconv_layer_num = 1
        deconvfrom = self.deconv_locs[deconv_layer_num]
        deconvlayers = self.layers[-(deconvfrom+1):]
        unpool_locs = {idx:(len(deconvlayers)-1 - idx) for idx, l in enumerate(deconvlayers) if isinstance(l, nn.MaxUnpool2d)}
        
        for idx, layer in enumerate(deconvlayers):
            if isinstance(layer, nn.MaxUnpool2d):
                x = layer(x, switches[unpool_locs[idx]])
            else:
                x = layer(x)
                if store:
                    self.save_activation_maps(layer, nn.ConvTranspose2d, idx, x)
        return x

In [None]:
__author__ = "simonjisu"
# models/relavance

from pathlib import Path
import sys
sys.path.append(str(Path(__file__).absolute().parent.parent))
from reshape import Reshape

import torch
import torch.nn as nn
from collections import OrderedDict
from .layers import relLinear, relConv2d, relMaxPool2d, relReLU

class LRP(nn.Module):
    """LRP"""
    def __init__(self, model):
        super(LRP, self).__init__()
        # lrp
        self.activation_func = model.activation_func
        self.model_type = model.model_type
        self.activation_type = model.activation_type
        
        self.layers = self.lrp_make_layers(model)
        
        self.activation_maps = OrderedDict()
        
    def lrp_make_layers(self, model):
        layers = []
        mapping_dict = {nn.Linear: relLinear, nn.Conv2d: relConv2d, nn.MaxPool2d: relMaxPool2d, 
                        nn.ReLU: relReLU}
        for layer in model.layers:
            if isinstance(layer, Reshape):
                layers.append(layer)
            else:
                layers.append(mapping_dict[layer.__class__](layer))
                
        return nn.Sequential(*layers)
    
    def forward(self, x):
        """
        lrp method
        must run forward first to save input and output at each layer
        """
        for layer in self.layers:
            x = layer(x)
        return x
    
    def save_activation_maps(self, layer, typ, idx, x):
        if isinstance(layer, typ):
            layer_name = f"({idx}) {str(layer).split('(')[0]}"
            self.activation_maps[layer_name] = x
    
    def get_attribution(self, x, target=None, store=False, use_rho=False):
        """
        store: if True, save activation maps
        """
        o = self.forward(x).detach()
        r = o * torch.zeros_like(o).scatter(1, o.argmax(1, keepdim=True), 1)
        for idx, layer in enumerate(self.layers[::-1]):
            r = layer.relprop(r, use_rho)
            if store:
                self.save_activation_maps(layer, relConv2d, idx, r)
        return r.detach()

In [None]:
import torch
import torch.nn as nn


class VanillaGrad(nn.Module):
    def __init__(self, model):
        super(VanillaGrad, self).__init__()
        
        # vanilla saliency
        self.activation_func = model.activation_func
        self.model_type = model.model_type
        self.activation_type = model.activation_type

        self.model = model.cpu()
        self.model.eval()
        
    def get_attribution(self, x, target):
        """vanilla gradient"""
        x.requires_grad_(requires_grad=True)
        self.model.zero_grad()
        o = self.model(x)
        grad_outputs = torch.zeros_like(o).scatter(1, target.unsqueeze(1), 1).detach()
        o.backward(gradient=grad_outputs)
        x.requires_grad_(requires_grad=False)
        
        return x.grad.clone()


class GradInput(nn.Module):
    def __init__(self, model):
        super(GradInput, self).__init__()
        
        # vanilla saliency
        self.activation_func = model.activation_func
        self.model_type = model.model_type
        self.activation_type = model.activation_type

        self.model = model.cpu()
        self.model.eval()
        
    def get_attribution(self, x, target):
        """vanilla gradient*input"""
        x.requires_grad_(requires_grad=True)
        self.model.zero_grad()
        o = self.model(x)
        grad_outputs = torch.zeros_like(o).scatter(1, target.unsqueeze(1), 1).detach()
        o.backward(gradient=grad_outputs)
        x.requires_grad_(requires_grad=False)
        
        return x.grad.clone() * x