In [None]:
import os
import sys
sys.path.append('../../../src')
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt

import tortto as tt
import tortto.nn as nn
import tortto.nn.functional as F

from torchvision.transforms import Compose, ToTensor, Normalize, Resize

In [None]:
# Path to the dataset directory (not referenced by any other cell visible in this notebook).
dataset_path = './cifar10_data'

# Per-channel mean / standard deviation handed to Normalize below
# (these values match the commonly used ImageNet statistics).
img_mean = (0.485, 0.456, 0.406)
img_sd = (0.229, 0.224, 0.225)

# Human-readable label for each of the 10 output classes, indexed by class id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Preprocessing pipeline: resize to 224x224, convert to a tensor, then normalize per channel.
transform = Compose([Resize((224,224)), ToTensor(), Normalize(img_mean, img_sd)])

## Define a network

In [None]:
def conv3x3(in_channels, channels, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, channels, **conv_options)


def conv1x1(in_channels, channels, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding) with the given stride."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_channels, channels, **conv_options)


class Bottleneck(nn.Module):
    """Residual block made of a 1x1, a 3x3 and a 1x1 convolution, each followed by batch norm."""

    # The block outputs ``channels * expansion`` channels.
    expansion = 4

    def __init__(self, in_channels, channels, stride=1, downsample=None):
        """
        in_channels: number of channels of the tensor entering the block.
        channels: width of the first two convolutions; the last one widens to
            ``channels * expansion``.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the block input to build the shortcut;
            when ``None`` the input itself is used as the shortcut.
        """
        super().__init__()
        widened = channels * self.expansion

        # Submodules are created in the same order as before: conv1/bn1, conv2/bn2, conv3/bn3.
        self.conv1 = conv1x1(in_channels, channels)
        self.bn1 = nn.BatchNorm2d(channels)

        self.conv2 = conv3x3(channels, channels, stride)
        self.bn2 = nn.BatchNorm2d(channels)

        self.conv3 = conv1x1(channels, widened)
        self.bn3 = nn.BatchNorm2d(widened)

        self.downsample = downsample

    def forward(self, x):
        """
        x ──conv1,bn1,relu──conv2,bn2,relu──conv3,bn3──(+)──relu──>
          └──────────────── downsample (optional) ──────┘
        """
        # Shortcut branch: identity unless a downsample module was supplied.
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)

        # Main branch.
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))

        residual = self.conv2(residual)
        residual = F.relu(self.bn2(residual))

        residual = self.conv3(residual)
        residual = self.bn3(residual)  # no ReLU here; it is applied after the addition

        return F.relu(residual + shortcut)


class ResNet(nn.Module):
    """ResNet backbone whose forward pass also records a class activation map (CAM).

    ``__init__`` builds only the stem and the four residual stages. ``forward`` additionally
    reads ``self.classifier`` (used as a linear layer with a ``weight`` attribute); it is NOT
    created here and has to be attached by the caller before the first forward pass.

    After every forward pass ``self.cam`` holds the CAM of the FIRST sample in the batch,
    computed for that sample's highest-scoring class.
    """

    def __init__(self, block, layers, channels):
        """
        block: residual block class; must expose an ``expansion`` class attribute and accept
            ``(in_channels, channels, stride, downsample)``.
        layers: number of residual blocks in each of the four stages.
        channels: base number of channels of each of the four stages.
        """
        super(ResNet, self).__init__()
        self.in_channels = channels[0]
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)  # in_channels will increase accordingly after each _make_layer call
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, channels[0], layers[0])
        self.layer2 = self._make_layer(block, channels[1], layers[1], stride=2)
        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2)
        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2)

    def _make_layer(self, block, channels, blocks, stride=1):
        """Build one stage made of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and (when needed) a projection shortcut.
        As a side effect ``self.in_channels`` is advanced to the stage's output width,
        ``channels * block.expansion``.
        """
        # None means "identity shortcut"; this is the same value the remaining blocks of the
        # stage get by default, and what Bottleneck tests for.
        downsample = None
        # The first block needs a projection whenever it changes the spatial size (stride != 1)
        # or the channel count (e.g. the first bottleneck of conv2_x widens the channels).
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.in_channels, channels * block.expansion, stride),
                nn.BatchNorm2d(channels * block.expansion)
            )

        # append the first block in the layer
        layers = [block(self.in_channels, channels, stride, downsample)]
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, channels))  # stride=1, identity shortcut
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return per-class log-probabilities and store the CAM of sample 0 in ``self.cam``.

        Requires ``self.classifier`` to have been attached (see the class docstring).
        """
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Snapshot the raw feature maps of sample 0 before they are pooled away.
        feature_maps = x.data[0].copy()

        # Global average pooling over the two spatial axes, then the linear head.
        x = tt.mean(x, (-1, -2), keepdims=True)
        x = tt.flatten(x, 1)
        x = self.classifier(x)

        # CAM of sample 0: feature maps weighted by the classifier row of that sample's
        # top-scoring class, summed over the channel axis. The argmax is taken over sample 0's
        # logits only; an argmax over the whole (batch, classes) array would return a flattened
        # index that is not a class id as soon as the batch holds more than one sample.
        # Assumes classifier.weight is laid out as (classes, features) - TODO confirm for
        # non-Linear heads.
        top_class = x.data[0].argmax()
        class_weights = self.classifier.weight.data[top_class]
        self.cam = (class_weights[..., None, None] * feature_maps).sum(-3)
        return F.log_softmax(x, -1)

def resnet50():
    """Build a ``ResNet`` of ``Bottleneck`` blocks with the ResNet-50 stage layout."""
    blocks_per_stage = [3, 4, 6, 3]
    channels_per_stage = [64, 128, 256, 512]
    return ResNet(Bottleneck, blocks_per_stage, channels_per_stage)

In [None]:
# Adapted from https://github.com/jacobgil/vit-explain/blob/main/vit_explain.py
# The heatmap produced by cv2 is in BGR channel order, so it is converted to RGB
# before being added to the (RGB) image.
def show_mask_on_image(img, mask):
    """Blend a JET-coloured heatmap of ``mask`` onto ``img``.

    img: RGB image on a 0-255 scale; it is added element-wise to the (H, W, 3) heatmap,
        so it must broadcast against that shape.
    mask: 2-D array; it is scaled by 255 and cast to uint8, so values are expected to lie
        in [0, 1].

    Returns a uint8 RGB array: heatmap + image, rescaled so the maximum value is 255.
    """
    img = np.float32(img) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    # cv2 colour maps come back as BGR; reverse the channel axis to get RGB so that
    # "hot" regions end up red (not blue) once blended with the RGB image.
    heatmap = heatmap[..., ::-1]
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

In [None]:
# Build the backbone, then attach the classification head: ResNet.__init__ does not create
# `classifier`, but ResNet.forward reads it, so it must be set before the model is used.
net = resnet50()
# 2048 input features = 512 * Bottleneck.expansion (width of the last stage);
# 10 outputs = one per entry in `classes`.
net.classifier = nn.Linear(2048,10)
# Done after attaching the head so the head is included (presumably moves parameters to the GPU).
net = net.cuda()

# Restore weights from the saved checkpoint; only its 'model' entry is used here.
# The head must already be attached so that its parameters can be loaded as well.
checkpoint = tt.load('checkpoint_014.npy')
net.load_state_dict(checkpoint['model'])

# Uncomment the cell below to plot the class activation map (CAM)

In [None]:
# fns=os.listdir('images')
# fig,axs=plt.subplots(len(fns),4,figsize=(12,43))

# axs[0,0].set_title('Original',fontsize=15)
# axs[0,1].set_title('Class Activation Map',fontsize=15)
# axs[0,2].set_title('Overlap',fontsize=15)
# axs[0,3].set_title('Prediction',fontsize=15)
# i=0
# for fn in fns:
#     fn='images/'+fn
#     original=Image.open(fn).convert('RGB')
#     data=tt.tensor(transform(original)[None].numpy()).cuda()

#     net.eval()
#     with tt.no_grad():
#         outputs=net(data)
#         predicted=classes[outputs.argmax(-1).item()]
#     prob=outputs.exp().detach().cpu().numpy()
#     mask=net.cam.get()
#     mask=cv2.resize(mask, (224, 224))
#     mask-=mask.min()
#     mask/=mask.max()
#     data=data.cpu().numpy()[0]
#     for t, m, s in zip(data, img_mean, img_sd):
#         t*=s
#         t+=m
#     img=show_mask_on_image(255*(data.transpose(1,2,0)), mask)

#     axs[i,0].imshow(original)
#     axs[i,0].axis('off')
    
#     axs[i,1].imshow(mask, cmap='jet')
#     axs[i,1].axis('off')

#     axs[i,2].imshow(img, cmap='jet')
#     axs[i,2].axis('off')

#     axs[i,3].bar(classes, prob[0])
#     axs[i,3].tick_params('x', labelrotation=45)
#     axs[i,3].set_ylim([0, 1.1])
#     axs[i,3].set_aspect(10)
#     i+=1
# fig.tight_layout()
# plt.show()