In [3]:
import torch
import torch.nn as nn
from torchvision import models
from torch.hub import load_state_dict_from_url

from PIL import Image
import cv2
import numpy as np
from matplotlib import pyplot as plt

from torchvision import transforms
from torchsummary import summary

In [4]:
class FullyConvolutionalResnet18(models.ResNet):
    def __init__(self, num_classes=1000, pretrained=False, **kwargs):

        # Start with standard resnet18 defined here 
        super().__init__(block = models.resnet.BasicBlock, layers = [2, 2, 2, 2], num_classes = num_classes, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url( models.resnet.model_urls["resnet18"], progress=True)
            self.load_state_dict(state_dict)

        # Replace AdaptiveAvgPool2d with standard AvgPool2d 
        self.avgpool = nn.AvgPool2d((7, 7))

        # Convert the original fc layer to a convolutional layer.  
        self.last_conv = torch.nn.Conv2d( in_channels = self.fc.in_features, out_channels = num_classes, kernel_size = 1)
        self.last_conv.weight.data.copy_( self.fc.weight.data.view ( *self.fc.weight.data.shape, 1, 1))
        self.last_conv.bias.data.copy_ (self.fc.bias.data)

    # Reimplementing forward pass. 
    def _forward_impl(self, x):
        # Standard forward for resnet18
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)

        # Notice, there is no forward pass 
        # through the original fully connected layer. 
        # Instead, we forward pass through the last conv layer
        x = self.last_conv(x)
        return x

In [6]:
models.ResNet

torchvision.models.resnet.ResNet

In [7]:
with open('imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

In [9]:
labels[:10]

['tench, Tinca tinca',
 'goldfish, Carassius auratus',
 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
 'tiger shark, Galeocerdo cuvieri',
 'hammerhead, hammerhead shark',
 'electric ray, crampfish, numbfish, torpedo',
 'stingray',
 'cock',
 'hen',
 'ostrich, Struthio camelus']

In [20]:
# read image
original_image = cv2.imread('/data/file/img/camel.jpg')

In [21]:
original_image.shape

(725, 1920, 3)

In [22]:
# convert original image to RGB format
image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

In [23]:
image.shape

(725, 1920, 3)

In [24]:
# transform input image
# convert to Tensor
# subtract mean
# divide by standard deviation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.225],
        std=[0.229, 0.224, 0.225]
    )
])
image = transform(image)
image = image.unsqueeze(0)

In [25]:
model = FullyConvolutionalResnet18(pretrained=True).eval()

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /home/roczhang/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:30<00:00, 1.54MB/s]


In [27]:
with torch.no_grad():
    preds = model(image)
    preds = torch.softmax(preds, dim=1)

    print('Response map shape:', preds.shape)

    pred, class_idx = torch.max(preds, dim=1)
    print(class_idx)

    row_max, row_idx = torch.max(pred, dim=1)
    col_max, col_idx = torch.max(row_max, dim=1)
    predicted_class = class_idx[0, row_idx[0, col_idx], col_idx]

    print('Predicted Class:', labels[predicted_class], predicted_class)

Response map shape: torch.Size([1, 1000, 3, 8])
tensor([[[978, 557, 557, 557, 557, 557, 354, 682],
         [978, 970, 980, 977, 858, 970, 354, 461],
         [141, 143, 977, 977, 977, 977, 354, 354]]])
Predicted Class: Arabian camel, dromedary, Camelus dromedarius tensor([354])


In [28]:
score_map = preds[0, predicted_class, :, :].numpy()

In [31]:
score_map = score_map[0]

In [32]:
score_map = cv2.resize(score_map, (original_image.shape[1], original_image.shape[0]))

In [34]:
score_map.shape

(725, 1920)

In [35]:
_, score_map_for_contours = cv2.threshold(score_map, 0.25, 1, type=cv2.THRESH_BINARY)
score_map_for_contours = score_map_for_contours.astype(np.uint8).copy()

In [37]:
contours, _ = cv2.findContours(score_map_for_contours, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE)

In [38]:
rect = cv2.boundingRect(contours[0])

In [39]:
rect

(1384, 186, 536, 539)

In [40]:
score_map = score_map - np.min(score_map[:])
score_map = score_map / np.max(score_map[:])

In [41]:
score_map = cv2.cvtColor(score_map, cv2.COLOR_GRAY2BGR)
masked_image = (original_image * score_map).astype(np.uint8)
cv2.rectangle(masked_image, rect[:2], (rect[0] + rect[2], rect[1] + rect[3]), (0, 0, 255), 2)

array([[[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [  2,   2,   1],
        [  2,   2,   1],
        [  2,   2,   1]],

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [  2,   2,   1],
        [  2,   2,   1],
        [  2,   2,   1]],

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [  2,   2,   1],
        [  2,   2,   1],
        [  2,   2,   1]],

       ...,

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [ 46,  61,  79],
        [ 68,  83, 102],
        [  0,   0, 255]],

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [ 55,  70,  88],
        [ 31,  46,  65],
        [  0,   0, 255]],

       [[  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        ...,
        [  0,   0, 255],
        [  0,   0, 255],
        [  0,   0, 255]]

In [42]:
cv2.imshow("Original Image", original_image)
cv2.imshow("scaled_score_map", score_map)
cv2.imshow("activations_and_bbox", masked_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
