## Model Definition

In [1]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision


class TransformerResnet18(nn.Module):
    """ResNet-18 conv trunk + one Transformer encoder layer + linear classifier.

    The conv feature map is flattened into a sequence of ``H*W`` spatial tokens,
    passed through an 8-head ``nn.TransformerEncoderLayer``, folded back into a
    ``(B, C, H, W)`` map, globally average-pooled and classified.

    Submodule names (``resnet18_conv``, ``self_attention``, ``avg_pool``,
    ``linear``) are the checkpoint's ``state_dict`` keys, and their assignment
    order fixes ``parameters()`` order, so neither should change.

    Args:
        num_classes: number of output logits.
        pretrained: forwarded to ``torchvision.models.resnet18``.
    """

    def __init__(self, num_classes, pretrained=False):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        feature_dim = backbone.fc.in_features
        # Keep every child except the last two (torchvision's global avgpool and fc head).
        trunk_layers = list(backbone.children())[:-2]
        self.resnet18_conv = nn.Sequential(*trunk_layers)
        self.self_attention = nn.TransformerEncoderLayer(d_model=feature_dim, nhead=8)
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(in_features=feature_dim, out_features=num_classes, bias=True)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        features = self.resnet18_conv(x)
        batch, channels, height, width = features.shape
        # (B, C, H, W) -> (H*W, B, C): sequence-first token layout; the encoder
        # layer above is constructed without batch_first.
        tokens = features.flatten(start_dim=2).permute(2, 0, 1).contiguous()
        tokens = self.self_attention(tokens)
        # (H*W, B, C) -> (B, C, H, W)
        features = tokens.permute(1, 2, 0).contiguous().view(batch, channels, height, width)
        pooled = torch.flatten(self.avg_pool(features), start_dim=1)
        return self.linear(pooled)


class Resnet18(nn.Module):
    """Plain ResNet-18 classifier: conv trunk, global average pool, linear head.

    Submodule names (``resnet18_conv``, ``avg_pool``, ``linear``) are the
    checkpoint's ``state_dict`` keys, and their assignment order fixes
    ``parameters()`` order, so neither should change.

    Args:
        num_classes: number of output logits.
        pretrained: forwarded to ``torchvision.models.resnet18``.
    """

    def __init__(self, num_classes, pretrained=False):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        feature_dim = backbone.fc.in_features
        # Keep every child except the last two (torchvision's global avgpool and fc head).
        trunk_layers = list(backbone.children())[:-2]
        self.resnet18_conv = nn.Sequential(*trunk_layers)
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(in_features=feature_dim, out_features=num_classes, bias=True)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        pooled = self.avg_pool(self.resnet18_conv(x))
        return self.linear(pooled.flatten(start_dim=1))


## Load Checkpoint

In [2]:
def create_instance(config):
    """Instantiate ``config['class']`` with keyword arguments taken from the config.

    Expected layout::

        {'module': <dotted module name, or None>,
         'class': <class name>,
         <class name>: {<kwarg>: <value>, ...}}   # optional

    String kwarg values are passed through ``eval`` (so ``"'cpu'"`` becomes
    ``'cpu'`` and ``"2"`` becomes ``2``); non-string values are passed as-is.
    If ``module`` is falsy, the class name itself is ``eval``-ed in this
    module's namespace; otherwise it is looked up in the imported module.

    The config is not modified, so the same config can be instantiated again.

    SECURITY: ``eval`` executes arbitrary code -- only use with trusted configs.
    """
    from importlib import import_module

    module = config['module']
    class_ = config['class']
    # Build a fresh kwargs dict instead of writing evaluated values back into the
    # caller's config: in-place mutation made a second call re-eval an already
    # evaluated string (e.g. a weight path) and fail.
    config_kwargs = {}
    for key, value in config.get(class_, {}).items():
        config_kwargs[key] = eval(value) if isinstance(value, str) else value
    if module:
        return getattr(import_module(module), class_)(**config_kwargs)
    return eval(class_)(**config_kwargs)

In [124]:
from copy import deepcopy

# create_instance() eval()s string kwargs, so literal strings (paths, device)
# carry an extra pair of quotes.
resnet18_config = {
    'module': None,
    'class': 'CAM',
    'CAM': {
        'arch_config': {
            'module': None,
            'class': 'Resnet18',
            'Resnet18': {'num_classes': 2},
        },
        'classes': ['hole', 'none'],
        'image_size': (224, 224),
        'weight_path': "'weights/hole_detection/resnet18/2103191429/best_model_94_loss=-0.1392.pt'",
        'device': "'cpu'",
    },
}

# Same detector settings, but with the transformer-augmented backbone and its checkpoint.
trasformer_resnet18_config = deepcopy(resnet18_config)
trasformer_resnet18_config['CAM'].update(
    arch_config={'module': None, 'class': 'TransformerResnet18', 'TransformerResnet18': {'num_classes': 2}},
    weight_path="'weights/hole_detection/transformer_resnet18/2103191437/best_model_83_loss=-0.1259.pt'",
)

In [125]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class CAM(nn.Module):
    """Class Activation Map (CAM) visualiser for an image classifier.

    Builds the classifier described by ``arch_config`` (via ``create_instance``),
    loads its ``state_dict`` from ``weight_path`` and puts it in eval mode.
    Calling the instance on an ``(H, W, C)`` image array (as returned by
    ``cv2.imread``) returns the top-1 class name, its softmax score and the CAM
    heatmap blended 50/50 over the image.

    Args:
        arch_config: ``create_instance`` config for the classifier.
        classes: class names, indexed by the classifier's output logits.
        weight_path: path to the checkpoint (a ``state_dict``), loaded on CPU first.
        image_size: ``dsize`` passed to ``cv2.resize`` in ``preprocess``.
        device: device the classifier and inputs are moved to.
    """

    def __init__(self, arch_config, classes, weight_path, image_size, device):
        super(CAM, self).__init__()
        self.device = device
        self.model = create_instance(arch_config)
        self.model.load_state_dict(torch.load(weight_path, map_location='cpu'))
        self.model.to(self.device)
        self.model.eval()
        self.classes = classes
        self.image_size = image_size

        # Set by the forward hook installed in _register_forward_hook: (B, C, H, W).
        self.activation = None

    def _register_forward_hook(self, target_layer):
        """Hook ``target_layer`` so its output is stored in ``self.activation``.

        The stored activation is laid out as ``(B, C, H, W)``. Returns the hook's
        removable handle; the caller must ``remove()`` it, otherwise the hook keeps
        firing (and overwriting ``self.activation``) on every later forward pass.
        """
        assert target_layer in list(self.model._modules.keys()), 'target_layer must be in list of modules of model'

        def hook_feature(module, input, output):
            self.activation = output.data

        def hook_attention_feature(module, input, output):
            # Encoder output is sequence-first (H*W, B, C); the spatial grid is
            # assumed square (H == W), so H*W must be a perfect square.
            HW, B, C = output.shape
            H = W = math.isqrt(HW)
            self.activation = output.reshape(H, W, B, C).permute(2, 3, 0, 1).contiguous()

        hook = hook_attention_feature if target_layer == 'self_attention' else hook_feature
        return self.model._modules[target_layer].register_forward_hook(hook)

    def _get_softmax_weights(self, linear_layer=None):
        """Return the classifier weight matrix ``(num_classes, C)``.

        Args:
            linear_layer: name of a direct child module of ``self.model`` whose
                ``weight`` is used. If falsy, the model's second-to-last parameter
                is used (the last layer's weight when that layer has a bias).
        """
        if linear_layer:
            assert linear_layer in list(self.model._modules.keys()), 'classifier_layer must be in list of modules of model'
            # named_parameters() keys are dotted ('linear.weight'), so indexing them
            # with the bare module name always failed; read the module's weight.
            softmax_weights = self.model._modules[linear_layer].weight.data
        else:
            softmax_weights = list(self.model.parameters())[-2].data
        return softmax_weights

    def preprocess(self, image):
        """Resize ``image`` and turn it into a standardised ``(1, C, H, W)`` float tensor."""
        sample = cv2.resize(image, dsize=self.image_size)
        sample = torch.from_numpy(sample).to(self.device).to(torch.float)
        sample = sample.unsqueeze(dim=0).permute(0, 3, 1, 2)
        # Per-image standardisation; clamp std so a constant image doesn't divide by zero.
        sample = (sample - sample.mean()) / sample.std().clamp_min(1e-8)
        return sample

    def process(self, sample):
        """Run the classifier without autograd and return the raw logits."""
        with torch.no_grad():
            preds = self.model(sample)
        return preds

    def postprocess(self, preds):
        """Return ``(class_name, softmax_score)`` of the top-1 class (batch size 1)."""
        pred = preds.softmax(dim=1).squeeze(dim=0)
        top_idx = pred.argmax().item()
        class_name = self.classes[top_idx]
        class_score = pred[top_idx].item()
        return class_name, class_score

    def class_activation_map(self, sample, map_size, target_layer, linear_layer=None):
        """Compute a colour-mapped CAM for the top-1 class of ``sample``.

        Args:
            sample: preprocessed input; the reshape below assumes batch size 1.
            map_size: ``dsize`` (width, height) of the returned map.
            target_layer: child module of ``self.model`` providing the feature map.
            linear_layer: optional classifier module name (see ``_get_softmax_weights``).

        Returns:
            ``uint8`` heatmap from ``cv2.applyColorMap`` (JET), resized to ``map_size``.
        """
        handle = self._register_forward_hook(target_layer=target_layer)
        try:
            with torch.no_grad():
                preds = self.model(sample)
        finally:
            # Detach the hook so hooks don't pile up across calls; a stale
            # 'self_attention' hook would otherwise overwrite a later
            # 'resnet18_conv' activation.
            handle.remove()

        class_idx = preds.softmax(dim=1).squeeze(dim=0).argmax().item()
        softmax_weights = self._get_softmax_weights(linear_layer=linear_layer)
        softmax_weight = softmax_weights[class_idx, :]
        activation_map = self.activation

        # CAM[h, w] = sum_c w[c] * A[c, h, w]
        _, C, H, W = activation_map.shape
        activation_map = activation_map.reshape(C, H * W)
        saliency_map = torch.matmul(softmax_weight, activation_map)
        saliency_map = saliency_map.reshape(H, W)
        saliency_map = F.relu(saliency_map)

        # Min-max normalise to [0, 1]; clamp the range so a constant map (e.g. all
        # zeros after ReLU) yields zeros instead of NaN.
        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        saliency_map = ((saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min).clamp_min(1e-12)).data

        saliency_map = (saliency_map * 255).to(torch.uint8).cpu().detach().numpy()
        saliency_map = cv2.resize(saliency_map, dsize=map_size)
        saliency_map = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)

        return saliency_map

    def forward(self, image, target_layer='resnet18_conv', linear_layer=None):
        """Classify ``image`` and return ``(class_name, class_score, heatmap)``.

        ``heatmap`` is the CAM of ``target_layer`` resized to the image size and
        blended 50/50 with ``image`` as ``uint8``. The image presumably needs 3
        channels to broadcast against the colour-mapped CAM -- confirm for other inputs.
        Note the classifier runs twice per call (``process`` and ``class_activation_map``).
        """
        sample = self.preprocess(image)
        preds = self.process(sample)
        class_name, class_score = self.postprocess(preds)
        # cv2 dsize is (width, height): image.shape[:2] reversed.
        heatmap = self.class_activation_map(sample, image.shape[1::-1], target_layer, linear_layer)
        heatmap = (heatmap * 0.5 + image * 0.5).astype(np.uint8)
        return class_name, class_score, heatmap

In [126]:
# Build both CAM detectors from their configs (each loads its checkpoint from disk).
resnet18_detector, transformer_resnet18_detector = map(
    create_instance, (resnet18_config, trasformer_resnet18_config)
)

In [102]:
# cv2.imread returns None (no exception) when the file is missing or unreadable.
image = cv2.imread('../../ID_CARD/hole_classification/dataset_update/test/hole/CMQD_A/2_001136846_GTTT-page2.jpg')
if image is None:
    raise FileNotFoundError('cv2.imread could not read the sample image; check the path')

In [1]:
# print(resnet18_detector(image))
# print(transformer_resnet18_detector(image))

## Visualize Feature Map

In [127]:
import random
from pathlib import Path

# Every file with an extension under the "hole" test split, searched recursively,
# visited in random order.
image_paths = list(Path('../../ID_CARD/hole_classification/dataset_update/test/hole/').rglob('*.*'))
random.shuffle(image_paths)
print(len(image_paths))

922


In [None]:
# image = cv2.imread('../../ID_CARD/hole_classification/dataset_update/test/hole/CMQD_A/2_08029910_GTTT-page0.jpg')
max_height = 500  # NOTE(review): not referenced below -- leftover?
num_preview_images = 10

# (detector, layer whose activations feed the CAM, OpenCV window title)
cam_runs = (
    (resnet18_detector, 'resnet18_conv', 'resnet18_conv'),
    (transformer_resnet18_detector, 'resnet18_conv', 'transformer_resnet18_conv'),
    (transformer_resnet18_detector, 'self_attention', 'self_attention'),
)

for image_path in image_paths[:num_preview_images]:
    print(image_path)
    print('**' * 20)
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread returns None instead of raising; the recursive '*.*' glob
        # can pick up files OpenCV cannot decode.
        print(f'could not read {image_path}, skipping')
        continue

    for detector, target_layer, window_name in cam_runs:
        class_name, class_score, heat_map = detector(image, target_layer=target_layer)
        print(class_name, class_score)
        # Show until a key is pressed, then close before the next map.
        cv2.imshow(window_name, heat_map)
        cv2.waitKey()
        cv2.destroyAllWindows()

../../ID_CARD/hole_classification/dataset_update/test/hole/CMQD_D/2_15040569_GTTT-page0.jpg
****************************************
resnet18_conv torch.Size([1, 512, 7, 7])
none 0.9639642834663391
resnet18_conv torch.Size([1, 512, 7, 7])
none 0.9714662432670593
self_attention torch.Size([1, 512, 7, 7])
none 0.9714662432670593
../../ID_CARD/hole_classification/dataset_update/test/hole/CMQD_A_BACK/2_41A123052972_GTTT-page1.jpg
****************************************
resnet18_conv torch.Size([1, 512, 7, 7])
hole 0.9997718930244446
resnet18_conv torch.Size([1, 512, 7, 7])
hole 0.9997168183326721
self_attention torch.Size([1, 512, 7, 7])
hole 0.9997168183326721


## Class Activation Map

In [11]:
import cv2

def preprocess(image, image_size=(224, 224), device='cpu'):
    """Resize and standardise an image into a model input tensor.

    Args:
        image: ``(H, W, C)`` image array, e.g. from ``cv2.imread``.
        image_size: ``dsize`` passed to ``cv2.resize`` (width, height).
        device: device the tensor is moved to.

    Returns:
        ``(image, sample)``: the input image unchanged, and a float tensor of
        shape ``(1, C, image_size[1], image_size[0])`` standardised by the
        image's own mean and std.
    """
    sample = cv2.resize(image, dsize=image_size)
    sample = torch.from_numpy(sample).to(device).to(torch.float)
    # (1, H, W, C) -> (1, C, H, W)
    sample = sample.unsqueeze(dim=0).permute(0, 3, 1, 2)
    # Per-image standardisation; clamp std so a constant image doesn't divide by zero.
    sample = (sample - sample.mean()) / sample.std().clamp_min(1e-8)

    return image, sample

In [9]:
weight_path = 'weights/hole_detection/transformer_resnet18/2103191437/best_model_83_loss=-0.1259.pt'
device = 'cpu'
# Label names in model-output order (same order as the CAM configs' 'classes').
# The prediction cell decodes class indices with `classes`, which was otherwise
# never defined in this notebook.
classes = ['hole', 'none']
num_classes = len(classes)

model = TransformerResnet18(num_classes)
model.load_state_dict(torch.load(weight_path, map_location='cpu'))
model.to(device)
model = model.eval()

In [24]:
def get_feature_map(model, sample, last_layer_name='self_attention'):
    """Return the ``(B, C, H, W)`` feature map used to build the CAM.

    Always runs ``model.resnet18_conv`` first. If ``last_layer_name`` is
    ``'resnet18_conv'`` that conv map is returned as-is (the "without transformer"
    case). Otherwise the map is flattened to sequence-first tokens ``(H*W, B, C)``,
    passed through ``model._modules[last_layer_name]``, and folded back to
    ``(B, C, H, W)``.

    Args:
        model: module with a ``resnet18_conv`` child (and ``last_layer_name`` child).
        sample: input batch for ``resnet18_conv``.
        last_layer_name: name of the final layer whose output is returned.
    """
    feature_map = model._modules['resnet18_conv'](sample)  # [B, C, H, W], e.g. [1, 512, 7, 7]
    if last_layer_name == 'resnet18_conv':
        # Previously this path fed the token sequence back into the conv and crashed.
        return feature_map

    B, C, H, W = feature_map.shape
    tokens = feature_map.reshape(B, C, H * W).permute(2, 0, 1).contiguous()  # [H*W, B, C]
    tokens = model._modules[last_layer_name](tokens)
    return tokens.permute(1, 2, 0).contiguous().reshape(B, C, H, W)  # [B, C, H, W]

In [25]:
# cv2.imread returns None (no exception) when the file is missing or unreadable.
image = cv2.imread('../../ID_CARD/hole_classification/dataset_update/test/hole/CMQD_A/2_05003827_GTTT-page0.jpg')
if image is None:
    raise FileNotFoundError('cv2.imread could not read the sample image; check the path')

In [26]:
# Resize + standardise into a (1, C, 224, 224) float tensor on the CPU.
image, sample = preprocess(image, image_size=(224, 224), device='cpu')
print(sample.shape)

torch.Size([1, 3, 224, 224])


In [27]:
# Feature map after the transformer layer, folded back to (B, C, H, W).
feature_map = get_feature_map(model, sample, last_layer_name='self_attention')
print(feature_map.shape)

torch.Size([1, 512, 7, 7])


In [141]:
# Weight matrix of the final classifier layer, looked up by module name rather
# than by parameter position (list(model.parameters())[-2]), which silently picks
# the wrong tensor if the model's layer order changes.
fc_classes_weights = model.linear.weight
print(fc_classes_weights.shape)  # [n_classes, feature_map_channels]

torch.Size([2, 512])


In [142]:
# Top-1 prediction for the single preprocessed sample; decodes the index with the
# global `classes` label list (model-output order).
with torch.no_grad():
    preds = model(sample)

pred = torch.softmax(preds, dim=1).squeeze(dim=0)
class_idx = int(torch.argmax(pred))
class_name = classes[class_idx]
class_score = float(pred[class_idx])
print(class_name, class_score)

hole 0.9997307658195496


In [143]:
# Classifier row for the predicted class: one weight per feature-map channel.
fc_class_weights = fc_classes_weights[class_idx, :]

In [144]:
# CAM[h, w] = sum_c w[c] * feature_map[c, h, w]; the reshape assumes batch size 1.
# NOTE(review): this global `CAM` shadows the CAM class defined earlier; after this
# cell runs, create_instance() on a config with 'class': 'CAM' resolves to this tensor.
B, C, H, W = feature_map.shape
feature_map = feature_map.reshape(C, H * W)
CAM = (fc_class_weights @ feature_map).reshape(H, W)
print(CAM.shape)

torch.Size([7, 7])


In [145]:
# Min-max normalise to [0, 1]; clamp the range so a constant map yields zeros
# instead of NaN (0 / 0).
CAM = (CAM - CAM.min()) / (CAM.max() - CAM.min()).clamp_min(1e-12)
# Scale to uint8 for the cv2 resize / applyColorMap steps.
CAM = (CAM * 255).to(torch.uint8).cpu().detach().numpy()
print(CAM)

[[235 197 204 221 238 225 243]
 [240 175 195 205 230 210 233]
 [232 202 216 235 246 255 253]
 [170  77 121 171 241 236 248]
 [123  55  79 151 213 200 225]
 [ 81   0   8  83 175 171 216]
 [109  42  50 132 185 179 207]]


In [146]:
# cv2 dsize is (width, height), i.e. the first two entries of image.shape reversed.
CAM = cv2.resize(CAM, dsize=image.shape[1::-1])
print(CAM.shape)

(492, 733)


In [147]:
# Colour-map the 2-D uint8 CAM with OpenCV's JET palette.
CAM_heatmap = cv2.applyColorMap(CAM, cv2.COLORMAP_JET)

In [148]:
# cv2.imshow('heatmap', CAM_heatmap)
# cv2.waitKey()
# cv2.destroyAllWindows()

In [149]:
# Blend colour map and original image 50/50, then show until a key is pressed.
heatmap = np.add(CAM_heatmap * 0.5, image * 0.5).astype(np.uint8)
cv2.imshow('heatmap', heatmap)
cv2.waitKey()
cv2.destroyAllWindows()