In [None]:
import json
import os
import numpy as np
import PIL
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torchvision.models import resnet18
import torchvision.transforms as T
import torch.nn.functional as F

# reading lime_image
from lime import lime_image

from skimage.segmentation import mark_boundaries 

In [None]:
# Seed the NumPy and PyTorch global RNGs up front so a fresh Run-All is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the first CUDA device when available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# Pretrained ResNet-18 in inference mode, placed on `device`.
resnet = resnet18(pretrained=True).eval().to(device)

# The class-index JSON maps "<k>" -> [class_id, label] for k = 0..N-1
# (key layout as read below; entry order assumed to match the model's output index).
with open(os.path.join("../data", "imagenet_class_index.json"), 'r') as read_file:
    class_idx = json.load(read_file)

class_entries = [class_idx[str(k)] for k in range(len(class_idx))]
# output index -> human-readable label
idx2label = [entry[1] for entry in class_entries]
# class_id -> human-readable label
cls2label = {entry[0]: entry[1] for entry in class_entries}
# class_id -> output index
cls2idx = {entry[0]: k for k, entry in enumerate(class_entries)}

In [None]:
# Per-channel ImageNet normalisation statistics used for the model input.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Geometry shared by the classifier input and the image handed to LIME.
# Resize((256, 256)) squashes to exactly 256x256 (aspect ratio is not preserved),
# then the central 224x224 region is kept.
_geometry_transform = T.Compose([
    T.Resize((256, 256)),
    T.CenterCrop(224),
])

# Tensor conversion + normalisation applied after the geometry step.
_tensor_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


def pil_transform(img):
    """Resize and center-crop a PIL image to the 224x224 view the model sees.

    Returns a PIL image (no tensor conversion or normalisation), suitable for
    display or for passing to LIME as a numpy array.
    """
    return _geometry_transform(img)


def pil_to_torch(img):
    """Convert a PIL image to a normalised model-input batch of one.

    Reuses ``pil_transform`` so the classified crop is guaranteed to be the same
    one that LIME perturbs and that the plots display.

    Returns:
        Float tensor of shape (1, C, 224, 224); C is 3 for RGB input.
    """
    # unsqueeze(0) turns the single image into a batch of size 1.
    return _tensor_transform(pil_transform(img)).unsqueeze(0)

In [None]:
def cnn_predict(images):
    """Classifier callback for LIME: images -> softmax class probabilities.

    Args:
        images: sequence of images that have already been resized/cropped
            (LIME passes perturbed copies of ``np.array(pil_transform(...))``;
            presumably uint8 HxWx3 arrays -- TODO confirm). Each one is converted
            with ``ToTensor`` and normalised with the ImageNet statistics.

    Returns:
        numpy array of shape (len(images), num_classes) with per-image softmax
        probabilities computed by the global ``resnet`` on ``device``.
    """
    transf = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])

    batch = torch.stack([transf(img) for img in images], dim=0)

    # Pure inference: LIME calls this on many perturbed images, so skip autograd
    # bookkeeping entirely instead of building a graph and detaching it afterwards.
    with torch.no_grad():
        logits = resnet(batch.to(device)).cpu()
    probs = F.softmax(logits, dim=1)
    return probs.numpy()

In [None]:
# First example image, forced to 3-channel RGB, shown before any preprocessing.
img_file_name = "puppy_kitten.jpg"
img_path_0 = os.path.join("../data", img_file_name)
img_pil_0 = Image.open(img_path_0).convert('RGB')

_ = plt.imshow(img_pil_0)

In [None]:
# Classify the first image and display its top-5 (probability, class index, label) triples.
img0 = pil_to_torch(img_pil_0)

# Inference only: no_grad avoids building an autograd graph, so the outputs can be
# converted to numpy directly without .detach().
with torch.no_grad():
    logits = resnet(img0.to(device))
probs = F.softmax(logits, dim=1).cpu()
probs5 = probs.topk(5)
tuple((p, c, idx2label[c]) for p, c in zip(probs5.values[0].numpy(), probs5.indices[0].numpy()))

In [None]:
# Seeding the explainer's RandomState makes this cell reproducible on its own:
# without it LIME draws from NumPy's global RNG, so every re-run of this cell
# (or any change to earlier cells) yields a different explanation.
explainer = lime_image.LimeImageExplainer(random_state=0)
explanation = explainer.explain_instance(np.array(pil_transform(img_pil_0)),
                                         cnn_predict,  # classification function
                                         top_labels=2,
                                         hide_color=0,  # value used to fill hidden superpixels
                                         num_samples=1000)  # number of images that will be sent to classification function

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(20, 10))

# Panel 0: the same 224x224 crop that was passed to LIME.
axes[0].imshow(pil_transform(img_pil_0))
axes[0].axis('off')
axes[0].set_title("Original Image")

# Panels 1-2: the 5 strongest positive / negative superpixels for the top explained
# class, with the rest of the image hidden and segment boundaries overlaid.
mask_panels = (
    (axes[1], "Positive mask", dict(positive_only=True, negative_only=False)),
    (axes[2], "Negative mask", dict(positive_only=False, negative_only=True)),
)
for ax, title, sign_kwargs in mask_panels:
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0],
                                                num_features=5, hide_rest=True,
                                                **sign_kwargs)
    # temp holds 0-255 pixel values; mark_boundaries is given a [0, 1] image.
    img_boundry = mark_boundaries(temp / 255.0, mask)
    ax.imshow(img_boundry)
    ax.set_title(title)
    ax.axis('off')

plt.show()

In [None]:
# Second example image: load as RGB and compute its top-5 class probabilities.
img_file_name = "fruit.jpg"
img_pil_1 = Image.open(os.path.join("../data", img_file_name)).convert('RGB')

img1 = pil_to_torch(img_pil_1)
# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    logits = resnet(img1.to(device)).cpu()
probs = F.softmax(logits, dim=1)
probs5 = probs.topk(5)

In [None]:
_ = plt.imshow(img_pil_1)

# Top-5 predictions as (probability, class index, label) triples, most probable first.
top5_probs = probs5.values[0].detach().numpy()
top5_ids = probs5.indices[0].detach().numpy()
labels = tuple(zip(top5_probs, top5_ids, (idx2label[c] for c in top5_ids)))
labels

In [None]:
# Seeding the explainer's RandomState makes this cell reproducible on its own:
# without it LIME draws from NumPy's global RNG, so the result depends on how many
# times earlier cells (e.g. the first explanation) have been run.
explainer = lime_image.LimeImageExplainer(random_state=0)
explanation = explainer.explain_instance(np.array(pil_transform(img_pil_1)),
                                         cnn_predict,  # classification function
                                         top_labels=5,
                                         hide_color=0,  # value used to fill hidden superpixels
                                         num_samples=1000)  # number of images that will be sent to classification function


In [None]:
# Rank of the explained class to inspect (0 = most probable according to LIME).
label = 3
assert 0 <= label < len(explanation.top_labels), (
    f"label={label} but only {len(explanation.top_labels)} labels were explained")
explained_class = explanation.top_labels[label]

# Name the class the masks below actually refer to. `labels` ranks the single-image
# topk prediction, while `top_labels` ranks LIME's own batched prediction; the two
# may disagree on near-ties, so take the name from the explained class id.
print("Looking at:", idx2label[explained_class])

fig, axes = plt.subplots(1, 3, figsize=(20, 10))
axes[0].imshow(pil_transform(img_pil_1))
axes[0].axis('off')
axes[0].set_title("Original Image")

# Positive panel shows the 3 strongest supporting superpixels; negative panel the 10
# strongest opposing ones. hide_rest=True hides everything outside the selection.
mask_panels = (
    (axes[1], "Positive mask", dict(positive_only=True, negative_only=False, num_features=3)),
    (axes[2], "Negative mask", dict(positive_only=False, negative_only=True, num_features=10)),
)
for ax, title, mask_kwargs in mask_panels:
    temp, mask = explanation.get_image_and_mask(explained_class, hide_rest=True, **mask_kwargs)
    # temp holds 0-255 pixel values; mark_boundaries is given a [0, 1] image.
    img_boundry = mark_boundaries(temp / 255.0, mask)
    ax.imshow(img_boundry)
    ax.set_title(title)
    ax.axis('off')

plt.show()