# Setup

In [1]:
# Packages
import torch
import torchvision
from dataclasses import dataclass
import numpy as np
from sklearn.metrics import top_k_accuracy_score
from tqdm import tqdm
import math
import os
import matplotlib.pyplot as plt

In [2]:
print("Torch version:", torch.__version__)
print("CUDA Available:", torch.cuda.is_available())
print("# GPUS:", torch.cuda.device_count())
# for idx in range(torch.cuda.device_count()):
#     print(idx, torch.cuda.get_device_name(idx))

Torch version: 1.10.0
CUDA Available: False
# GPUS: 0


In [3]:
@dataclass
class Configuration:
    cuda_device: str = "cuda:1"
    imagenet_root: str = "../datasets/ilsvrc2012/"
    imagenetv2_root: str = "../datasets/imagenetv2-matched-frequency-format-val/"
    val_batch_size: int = 512
    val_loader_num_workers: int = 4

        
configuration = Configuration()
configuration

Configuration(cuda_device='cuda:1', imagenet_root='../datasets/ilsvrc2012/', imagenetv2_root='../datasets/imagenetv2-matched-frequency-format-val/', val_batch_size=512, val_loader_num_workers=4)

# Dataset

In [4]:
"""
https://pytorch.org/vision/stable/models.html#classification
For res50,
interpolation: bilinear
input size: 224
crop ratio: 0.85 (original -> 224/0.85 = 256, 224)
"""
val_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        int(math.floor(224 / 0.85)),
        interpolation=torchvision.transforms.functional.InterpolationMode.BILINEAR
    ),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

In [None]:
"""
ImageNet (https://image-net.org)
Download (after signup) from https://image-net.org/challenges/LSVRC/2012/2012-downloads.php
- Development kit (Task 1 & 2). 2.5MB. ILSVRC2012_devkit_t12.tar.gz  
- Validation images (all tasks). 6.3GB. MD5: 29b22e2961454d5413ddabcf34fc5622 ILSVRC2012_img_val.tar
"""
imagenet_training_dataset = torchvision.datasets.ImageNet(
    root=configuration.imagenet_root, split="train", transform=val_transform,
)
imagenet_validation_dataset = torchvision.datasets.ImageNet(
    root=configuration.imagenet_root, split="val", transform=val_transform,
)

imagenet_training_loader = torch.utils.data.DataLoader(
    imagenet_training_dataset,
    batch_size=configuration.val_batch_size,
    shuffle=False,
    num_workers=configuration.val_loader_num_workers,
)
imagenet_validation_loader = torch.utils.data.DataLoader(
    imagenet_validation_dataset,
    batch_size=configuration.val_batch_size,
    shuffle=False,
    num_workers=configuration.val_loader_num_workers,
)

In [None]:
"""
ImageNetV2 Matched Freq. (https://github.com/modestyachts/ImageNetV2)
Download from http://imagenetv2public.s3-website-us-west-2.amazonaws.com/
"""
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

class ImageNetV2(torchvision.datasets.folder.DatasetFolder):
    """
    Modified from https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder
    """
    def __init__(self):
        super().__init__(
            root=configuration.imagenetv2_root,
            loader=torchvision.datasets.folder.default_loader,
            extensions=IMG_EXTENSIONS,
            transform=val_transform,
        )
        self.imgs = self.samples
        
    def find_classes(self, directory):
        """
        By default, torchvision.datasets.folder.ImageFolder will sort the folder in the str type.
        for example: "0", "1", "10", therefore we need a custom class_to_idx implementation here
        """ 
        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
        
        class_to_idx = {cls_name: int(cls_name) for cls_name in classes}
        return classes, class_to_idx
        

imagenetv2_dataset = ImageNetV2()

imagenetv2_validation_loader = torch.utils.data.DataLoader(
    imagenetv2_dataset,
    batch_size=configuration.val_batch_size,
    shuffle=False,
    num_workers=configuration.val_loader_num_workers,
)

# Validate and store the predictions

In [None]:
def evaluate(model, dataloader):
    model = model.to(configuration.cuda_device).eval()
    # Clean up the GPU cache
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True

    with torch.no_grad():
        y_labels = []
        y_probas = []
        for image, label in tqdm(dataloader):
            image = image.to(configuration.cuda_device)
            
            with torch.cuda.amp.autocast():
                batch_logits = model(image)
            
            batch_probas = torch.nn.functional.softmax(batch_logits, dim=-1).cpu().numpy()
            batch_labels = label.cpu().numpy()
            y_probas.append(batch_probas)
            y_labels.append(batch_labels)

        y_labels = np.concatenate(y_labels)
        y_probas = np.concatenate(y_probas)

    print(f"Top-1: {100 * top_k_accuracy_score(y_true=y_labels, y_score=y_probas, k=1): .3f}")
    print(f"Top-5: {100 * top_k_accuracy_score(y_true=y_labels, y_score=y_probas, k=5): .3f}")
    
    return y_probas, y_labels

In [None]:
pretrained_resnet50 = torchvision.models.resnet50(pretrained=True)

In [None]:
# """
# https://pytorch.org/vision/stable/models.html#classification
# Top-1: 76.130 Top-5: 92.862
# From the ImagenetV2 paper:
# Top-1: 76.1 Top-5 92.9
# """

probas, labels = evaluate(pretrained_resnet50, imagenet_validation_loader)
np.save("./resnet50_imagenet_valid_probabilities.npy", probas)
np.save("./resnet50_imagenet_valid_labels.npy", labels)

In [None]:
# """
# From the ImagenetV2 paper:
# Top-1: 63.3 Top-5 84.7
# """

probas, labels = evaluate(pretrained_resnet50, imagenetv2_validation_loader)
np.save("./resnet50_imagenetv2_valid_probabilities.npy", probas)
np.save("./resnet50_imagenetv2_valid_labels.npy", labels)

In [None]:
# probas, labels = evaluate(pretrained_resnet50, imagenet_training_loader)
# np.save("./resnet50_imagenet_train_probabilities.npy", probas)
# np.save("./resnet50_imagenet_train_labels.npy", labels)

# Build the map from the label to its confusing/competing classes

In [None]:
# Load the pretrained HLTM model
from apply_hltm import apply_hltm
import collections

clusterify_resnet50 = apply_hltm(cut_level=0, json_path="./ResNet50.json")
cluster_cls_map = collections.defaultdict(set)
for cls in range(0, 1000):
    cluster_id = clusterify_resnet50.paths[cls][clusterify_resnet50.cut_depth]
    cluster_cls_map[cluster_id].add(cls)

def get_competing_classes(cls):
    cluster_id = clusterify_resnet50.paths[cls][clusterify_resnet50.cut_depth]
    return list(cluster_cls_map[cluster_id])

print(get_competing_classes(0))
print(get_competing_classes(389))
print(get_competing_classes(402))
print(get_competing_classes(889))

# Compute the GradCam for the training images

In [None]:
from tqdm import tqdm
from torchray.attribution.grad_cam import grad_cam
from plt_wox import imsc

def get_activation_heat_map(
    model, probs, labels, dataset, batch_size, output_dir, layer_use="layer4",

):
    probs = torch.Tensor(probs)
    model.to(configuration.cuda_device)
    
    # Keep track of the original index
    for processing_cls in tqdm(range(730, 1000)):
        original_indices = []
        all_saliency = []
        # Clean up the GPU cache
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True

        # Get the competing classes
        poss = get_competing_classes(processing_cls)

        # Processing the images with the same predicted class
        _indice = np.where(labels == processing_cls)[0]
        print(f"Processing {len(_indice)} images with predicted class: {processing_cls}.")
        original_indices += list(_indice)
        
        _probs = probs[_indice, :]
        
        # Fetch the coressponding images from the pytorch dataset
        images = []
        for idx in _indice:
            images.append(dataset[idx][0])
        images = torch.stack(images).to(configuration.cuda_device)

        # Compute the gradcam
        grad = torch.zeros_like(_probs)
        poss_p = _probs[:, poss]
        grad[:, poss] = poss_p / poss_p.sum()
        grad = grad.to(configuration.cuda_device)
        
        # Batch processing
        start = 0
        end = batch_size
        while start < len(_indice):
            with torch.cuda.amp.autocast():
                saliency = grad_cam(
                    model, images[start: end, :], grad[start: end, :],
                    saliency_layer=layer_use, resize = True,
                )
            saliency = saliency.detach().cpu().numpy().reshape(-1, 224, 224)
            saliency = np.clip(saliency, a_min=0, a_max=None)
            all_saliency.append(saliency)
            start += batch_size
            end += batch_size
        
        all_saliency = np.concatenate(all_saliency)
        original_indices = np.array(original_indices)
        np.save(f"{output_dir}/class_{processing_cls}_original_indices.npy", original_indices)
        np.save(f"{output_dir}/class_{processing_cls}_saliency.npy", all_saliency)
        print(f"Save results to {output_dir}")
        print(all_saliency.shape, original_indices.shape)


In [None]:
# Load the predictions from the baseline pretrained model
train_probas = np.load("./resnet50_imagenet_train_probabilities.npy")
train_labels = np.load("./resnet50_imagenet_train_labels.npy")
get_activation_heat_map(
    pretrained_resnet50,
    train_probas, train_labels, imagenet_training_loader.dataset,
    output_dir = "./competiting_classes_gradcam",
    batch_size = configuration.val_batch_size,
)

In [None]:

for cls in [0, 486, 889, 402, 546]:
    print("Class", cls)
    saliency = np.load(f"./competiting_classes_gradcam/class_{cls}_saliency.npy")
    original_indices = np.load(f"./competiting_classes_gradcam/class_{cls}_original_indices.npy")
    
    for idx in [0, 100, 500, 1000, 1200]:
        img=imsc(imagenet_training_loader.dataset[original_indices[idx]][0])
        plt.imshow(img)
        plt.imshow(saliency[idx], cmap='jet', alpha=0.6)
        plt.axis('off')
        plt.show()