In [1]:
from __future__ import print_function

import copy
import glob
import os
import os.path as osp
import random

import click
import cv2
import matplotlib.cm as cm
import numpy as np
import PIL
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms,datasets

from grad_cam import (
    BackPropagation,
    Deconvnet,
    GradCAM,
    GuidedBackPropagation,
)

In [2]:
#Dataset class for loading image and label
class CUBDataset(Dataset):
    """In-memory dataset of 224x224 RGB images paired with labels.

    All files in ``image_paths`` are opened, converted to RGB when needed,
    resized and kept in memory when the dataset is constructed.

    Args:
        image_paths: Iterable of image file paths.
        labels: Sequence of labels indexed in parallel with ``image_paths``.
        transform: Optional callable applied to the image in ``__getitem__``.
    """

    def __init__(self, image_paths,labels,transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.load_image_from_paths()

    def load_image_from_paths(self):
        """Fill ``self.images`` with a resized RGB copy of every file."""
        self.images = []
        for path in self.image_paths:
            # The context manager closes the file as soon as the pixels have
            # been read; without it every opened image keeps a descriptor
            # alive until garbage collection.
            with PIL.Image.open(path) as img:
                # Convert every non-RGB mode (grayscale, palette, RGBA, CMYK...)
                # so that downstream 3-channel normalisation always applies.
                if img.mode != "RGB":
                    img = img.convert("RGB")
                # resize() returns a new, fully loaded image that stays valid
                # after the file is closed.
                self.images.append(img.resize((224,224)))

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, applying ``transform`` if set."""
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [3]:
def get_device(cuda):
    """Choose the torch device to run on and print which one was picked.

    Args:
        cuda: Truthy to request a GPU; the request is honoured only when
            ``torch.cuda.is_available()`` reports one.

    Returns:
        torch.device: ``cuda`` when requested and available, ``cpu`` otherwise.
    """
    if not (cuda and torch.cuda.is_available()):
        print("Device: CPU")
        return torch.device("cpu")

    gpu_index = torch.cuda.current_device()
    print("Device:", torch.cuda.get_device_name(gpu_index))
    return torch.device("cuda")

def load_images(image_paths):
    """Preprocess every path and collect the results in two parallel lists.

    Each path is printed with its index before it is handed to
    ``preprocess``.

    Args:
        image_paths: Iterable of image file paths.

    Returns:
        tuple: ``(images, raw_images)`` - the first and second element of each
        ``preprocess`` result, in input order.
    """
    print("Images:")
    tensors, originals = [], []
    for index, path in enumerate(image_paths):
        print("\t#{}: {}".format(index, path))
        tensor, original = preprocess(path)
        tensors.append(tensor)
        originals.append(original)
    return tensors, originals

In [4]:
# Per-channel mean / std handed to transforms.Normalize (these are the
# values torchvision documents for its ImageNet-pretrained models).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Evaluation-time transform: PIL image -> float tensor -> normalised tensor.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])


def preprocess(image_path):
    """Load one image and return its normalised tensor and the resized image.

    Args:
        image_path: Path of the image file to open.

    Returns:
        tuple: ``(image, raw_image)`` where ``image`` is the result of
        ``test_transform`` and ``raw_image`` is the 224x224 RGB PIL image it
        was computed from.
    """
    # Close the file handle as soon as the pixels are read instead of
    # leaving it to the garbage collector.
    with PIL.Image.open(image_path) as source:
        # Convert any non-RGB mode (grayscale, palette, RGBA, ...) before
        # resizing, in the same order CUBDataset uses, so both code paths
        # feed the model identically prepared 3-channel images.
        rgb = source if source.mode == "RGB" else source.convert("RGB")
        raw_image = rgb.resize((224,224))
    image = test_transform(raw_image.copy())
    return image, raw_image

def save_gradient(filename, gradient):
    """Min-max scale a gradient map to [0, 255] and write it as an image.

    Args:
        filename: Output path passed to ``cv2.imwrite``.
        gradient: Tensor whose axes are reordered with ``transpose(1, 2, 0)``
            before saving (presumably channel-first - confirm with caller).
            It is left unmodified.
    """
    # Work on a copy: for CPU tensors ``.numpy()`` shares memory with the
    # tensor, so the in-place scaling below would otherwise overwrite the
    # caller's gradients. ``detach()`` allows tensors that require grad.
    gradient = gradient.detach().cpu().numpy().transpose(1, 2, 0).copy()
    gradient -= gradient.min()
    peak = gradient.max()
    # A constant map has peak == 0 after the shift; skip the division so the
    # output is all zeros rather than NaN cast to uint8.
    if peak > 0:
        gradient /= peak
    gradient *= 255.0
    cv2.imwrite(filename, np.uint8(gradient))

In [5]:
def save_gradcam(filename, gcam, raw_image, paper_cmap=False):
    """Overlay a Grad-CAM map on an image and write the result with OpenCV.

    Args:
        filename: Output path passed to ``cv2.imwrite``.
        gcam: Tensor holding the activation map; it is used both as the
            colormap input and, with ``paper_cmap``, as the blending weight.
        raw_image: Image to blend with. A ``numpy.ndarray`` is used as given
            (it is written by ``cv2.imwrite``, so it should already be in
            OpenCV's BGR order). Any other array-like, such as the PIL image
            returned by ``preprocess``, is treated as RGB and flipped to BGR.
        paper_cmap: If true, blend per pixel using the map as alpha;
            otherwise average colormap and image.
    """
    gcam = gcam.detach().cpu().numpy()

    if isinstance(raw_image, np.ndarray):
        raw_image = raw_image.astype(np.float64)
    else:
        # PIL images have no ``astype`` and store channels as RGB, while
        # cv2.imwrite interprets the last axis as BGR.
        raw_image = np.asarray(raw_image, dtype=np.float64)
        if raw_image.ndim == 3:
            raw_image = raw_image[..., ::-1]

    # Drop the alpha channel of the colormap output and scale to 0..255.
    cmap = cm.jet_r(gcam)[..., :3] * 255.0
    if paper_cmap:
        alpha = gcam[..., None]
        gcam = alpha * cmap + (1 - alpha) * raw_image
    else:
        gcam = (cmap.astype(np.float64) + raw_image) / 2
    cv2.imwrite(filename, np.uint8(gcam))

def save_sensitivity(filename, maps):
    """Render a signed sensitivity map with a diverging colormap and save it.

    Values are scaled by the largest absolute value so that zero maps to the
    middle of the ``bwr_r`` colormap, then the image is resized to 224x224
    with nearest-neighbour interpolation.

    Args:
        filename: Output path passed to ``cv2.imwrite``.
        maps: Tensor of signed sensitivity values.
    """
    maps = maps.detach().cpu().numpy()
    # Largest magnitude on either side of zero. Equivalent to
    # max(positive max, -negative min) but does not fail when the map has no
    # positive (or no non-positive) entries.
    scale = np.abs(maps).max()
    if scale == 0:
        # All-zero map: avoid 0/0 and render it as the neutral midpoint.
        scale = 1.0
    maps = maps / scale * 0.5
    maps += 0.5
    maps = cm.bwr_r(maps)[..., :3]
    maps = np.uint8(maps * 255.0)
    maps = cv2.resize(maps, (224, 224), interpolation=cv2.INTER_NEAREST)
    cv2.imwrite(filename, maps)

In [8]:
# Create the output directory; exist_ok makes the cell safe to re-run and
# avoids depending on a particular shell.
os.makedirs('embedding_results', exist_ok=True)

In [9]:
# Seed every RNG the notebook draws from. Python's ``random`` is included
# because a later cell picks an image with random.choice.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
# Ask cuDNN for deterministic kernels and disable its auto-tuner so repeated
# runs produce the same embeddings and saliency maps.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [10]:
# print(models.resnet18())
class ResNetFeatrueExtractor50(nn.Module):
    """ResNet-50 whose final fully connected layer outputs an embedding.

    NOTE: the misspelt class name is kept on purpose. It is the public name,
    and the checkpoint loaded later in this notebook is a pickled instance of
    this class, which is looked up by name.

    Args:
        pretrained: Forwarded to ``models.resnet50``.
        embedding_dim: Size of the output embedding. When omitted, the value
            is read from a global ``config['embedding_dim']`` as before.
    """

    def __init__(self, pretrained = True, embedding_dim=None):
        super(ResNetFeatrueExtractor50, self).__init__()
        if embedding_dim is None:
            # Backward-compatible fallback; ``config`` is not defined in this
            # notebook, so pass ``embedding_dim`` explicitly when constructing.
            embedding_dim = config['embedding_dim']
        self.model = models.resnet50(pretrained=pretrained)
        # Read the input width from the layer being replaced instead of
        # hard-coding 2048.
        self.model.fc = nn.Linear(self.model.fc.in_features, embedding_dim)

    def forward(self, x):
        """Return the embedding produced by the wrapped network for ``x``."""
        return self.model(x)

In [11]:
# Use the GPU when one is available and fall back to the CPU otherwise,
# instead of failing on machines without CUDA.
device = get_device(cuda=True)
output_dir = 'embedding_results'
topk = 1
#load model weights finetuned using EPSHN triplet loss on CUB_200_2011 Dataset with all 200 classes
# NOTE(security): this file holds a pickled model object; torch.load
# unpickles it and can execute arbitrary code, so only load trusted files.
model = torch.load('../models/cub_triplet_loss_epshn_resnet50_sgd_aug_200.pth',map_location=device)
model = model.to(device)
model.eval()

ResNetFeatrueExtractor50(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequent

In [12]:
#load class name dict
with open('../CUB_200_2011/classes.txt','r') as f:
    # The second whitespace-separated field of each line is the class
    # (folder) name; blank lines are skipped instead of raising IndexError.
    classes = [line.split()[1] for line in f if line.strip()]
# Map class name -> zero-based index in file order, for however many
# classes the file lists rather than a hard-coded 200.
class_dict = {name: index for index, name in enumerate(classes)}

## Create Proxy Embeddings for each class

In [13]:
proxy_class = dict() #store class proxies against class names
for folder_path, class_index in class_dict.items():#iterate over class names
    # Sorted so the stored path list is the same on every run.
    image_paths = sorted(glob.glob(osp.join('../CUB_200_2011/images', str(folder_path), '*')))
    if not image_paths:
        # Without images the mean below would silently become NaN.
        raise FileNotFoundError(
            "No images found for class '{}' under ../CUB_200_2011/images".format(folder_path))
    labels = [class_index] * len(image_paths)

    train_dataset = CUBDataset(image_paths, labels, test_transform)# load dataset for this class
    # Order is irrelevant for a mean, so no shuffling is needed.
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False)#create dataloader for this class

    batch_embeddings = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        # The batch labels get their own name so they do not overwrite the
        # per-class ``labels`` list stored in proxy_class below.
        for data, batch_labels in train_loader:
            batch_embeddings.append(model(data.to(device)).cpu())#get embeddings

    all_class_embeddings = torch.cat(batch_embeddings, dim=0)
    # Class proxy = L2-normalised mean embedding of the class.
    class_weight = F.normalize(all_class_embeddings.mean(dim=0, keepdim=True),
                               p=2.0, dim=-1).to(device)#equation 11 & 12
    proxy_class[str(folder_path)] = {'proxy':class_weight[0],'image_paths':image_paths,'labels':labels}#save class proxy

In [14]:
# Class to visualise; must be a key of proxy_class (a class folder name).
class_name = '188.Pileated_Woodpecker'
# Proxy embedding of that class, passed to the visualisation wrappers below.
proxy_embedding = proxy_class[class_name]['proxy']

In [48]:
# Images
# Pick one random image of the selected class. The file uses the ``osp``
# alias for os.path (the bare name ``os`` was never imported), and the glob
# result is sorted so a seeded random.choice picks the same file every run.
images, raw_images = load_images([random.choice(
    sorted(glob.glob(osp.join('../CUB_200_2011/images/',class_name+'/*'))))])
images = torch.stack(images).to(device)

Images:
	#0: ../CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0034_180419.jpg


In [15]:
# Fixed example image, used so the saved visualisations are reproducible.
image_path = '../CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0034_180419.jpg'
images, raw_images = load_images([image_path])
# Stack the per-image tensors into one batch and move it to the model's device.
images = torch.stack(images).to(device)

Images:
	#0: ../CUB_200_2011/images/188.Pileated_Woodpecker/Pileated_Woodpecker_0034_180419.jpg


In [16]:
# Layer names handed to GradCAM.generate(target_layer=...); the "model."
# prefix is the attribute under which the wrapper class stores the ResNet.
target_layers = ["model.layer1", "model.layer2", "model.layer3", "model.layer4"]

In [17]:
# Default layer for the Grad-CAM cell below: the last entry, "model.layer4".
target_layer = target_layers[-1]

## Vanilla BackPropagation

In [18]:

"""
Common usage:
1. Wrap your model with visualization classes defined in grad_cam.py
2. Run forward() with images
3. Run backward() (in this notebook the target is the proxy embedding passed to the constructor, so it is called without class ids)
4. Run generate() to export results
"""

# =========================================================================
print("Vanilla Backpropagation:")

bp = BackPropagation(model=model,proxy_embeddings=proxy_embedding)
try:
    _ = bp.forward(images)

    for i in range(topk):
        bp.backward()
        gradients = bp.generate()

        # Save results as image files
        for j in range(len(images)):
            save_gradient(
                filename=osp.join(
                    output_dir,
                    "{}-vanilla_200.png".format(j),
                ),
                gradient=gradients[j],
            )
finally:
    # Remove all the hook functions in the "model", even if something above
    # raised, so they cannot leak into the visualisations in later cells.
    bp.remove_hook()

Vanilla Backpropagation:


## Deconvnet

In [19]:
# =========================================================================
print("Deconvolution:")

deconv = Deconvnet(model=model,proxy_embeddings=proxy_embedding)
try:
    _ = deconv.forward(images)

    for i in range(topk):
        deconv.backward()
        gradients = deconv.generate()

        # Save one deconvnet map per input image.
        for j in range(len(images)):
            save_gradient(
                filename=osp.join(
                    output_dir,
                    "{}-deconvnet_200.png".format(j),
                ),
                gradient=gradients[j],
            )
finally:
    # Always detach the hooks from the shared model, even on failure, so the
    # following cells start from an unhooked network.
    deconv.remove_hook()

Deconvolution:




## Grad-CAM | Guided Backpropagation | Guided Grad-CAM

In [20]:
# =========================================================================
print("Grad-CAM/Guided Backpropagation/Guided Grad-CAM:")

gcam = GradCAM(model=model,proxy_embeddings=proxy_embedding)
gbp = None
try:
    _ = gcam.forward(images)

    gbp = GuidedBackPropagation(model=model,proxy_embeddings=proxy_embedding)
    _ = gbp.forward(images)

    for i in range(topk):
        # Guided Backpropagation
        gbp.backward()
        gradients = gbp.generate()

        # Grad-CAM
        gcam.backward()
        regions = gcam.generate(target_layer=target_layer)

        # Guided Grad-CAM is the element-wise product of the two results; it
        # does not depend on the image index, so compute it once per pass.
        guided_gradcam = torch.mul(regions, gradients)

        for j in range(len(images)):

            # Guided Backpropagation
            save_gradient(
                filename=osp.join(
                    output_dir,
                    "{}-guided_200.png".format(j),
                ),
                gradient=gradients[j],
            )

            # Grad-CAM
            save_gradcam(
                filename=osp.join(
                    output_dir,
                    "{}-gradcam-{}_200.png".format(
                        j,target_layer
                    ),
                ),
                gcam=regions[j, 0],
                raw_image=raw_images[j],
                paper_cmap=True
            )

            # Guided Grad-CAM
            save_gradient(
                filename=osp.join(
                    output_dir,
                    "{}-guided-gradcam-{}_200.png".format(
                        j,target_layer
                    ),
                ),
                gradient=guided_gradcam[j],
            )
finally:
    # Detach every hook that was registered, even if a step above failed, so
    # later cells see an unmodified model.
    gcam.remove_hook()
    if gbp is not None:
        gbp.remove_hook()

Grad-CAM/Guided Backpropagation/Guided Grad-CAM:


## Grad-CAM visualization, layer by layer

In [39]:
# =========================================================================
print("Grad-CAM/Guided Backpropagation/Guided Grad-CAM:")

gcam = GradCAM(model=model,proxy_embeddings=proxy_embedding)
gbp = None
try:
    _ = gcam.forward(images)

    # NOTE(review): the guided-backprop result is not saved in this cell. It
    # is kept because its hooks are registered while gcam.backward() runs and
    # may influence the Grad-CAM gradients - confirm against grad_cam.py
    # before removing it.
    gbp = GuidedBackPropagation(model=model,proxy_embeddings=proxy_embedding)
    _ = gbp.forward(images)

    # Guided Backpropagation
    gbp.backward()
    gradients = gbp.generate()

    # Grad-CAM: one backward pass, then one map per target layer.
    gcam.backward()
    for target_layer in target_layers:
        regions = gcam.generate(target_layer=target_layer)
        for j in range(len(images)):
            # Grad-CAM
            save_gradcam(
                filename=osp.join(
                    output_dir,
                    "{}-gradcam-{}_180.png".format(
                        j,target_layer
                    ),
                ),
                # Index by j so each output file shows its own image and map;
                # a fixed index 0 would repeat the first image for every j.
                gcam=regions[j, 0],
                raw_image=raw_images[j],
                paper_cmap=True
            )
finally:
    # Detach every registered hook even if a step above failed.
    gcam.remove_hook()
    if gbp is not None:
        gbp.remove_hook()

Grad-CAM/Guided Backpropagation/Guided Grad-CAM:


