In [None]:
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import PIL
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models

from ds import *
from networks import *
from utils import *

## Get data loaders

In [None]:
# Reproducibility: seed the global torch RNG, and give the shuffled train loader
# its own seeded generator. Per the DataLoader docs, the per-worker base seed is
# drawn from that generator, so the torch-RNG-based random augmentations in the
# workers become reproducible as well.
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 4
NUM_WORKERS = 16
DATA_DIR = '../../../Datasets/chest_xray_pneumonia/'
IMAGE_DIR = DATA_DIR + 'images_224x256/'

train_tfm = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomRotation(degrees=15),
        # NOTE(review): ColorJitter() with default arguments (all factors 0)
        # applies no jitter; pass e.g. brightness/contrast if jitter is intended.
        transforms.ColorJitter(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
])
train_dset = CxVAE_Dset(
    csv_file=DATA_DIR + 'train_labels.csv',
    root_dir=IMAGE_DIR,
    tfm=train_tfm
)
# NOTE(review): val/test rely on CxVAE_Dset's default `tfm`; confirm it yields
# tensors in the same format/range as `train_tfm` (minus the augmentation).
val_dset = CxVAE_Dset(
    csv_file=DATA_DIR + 'val_labels.csv',
    root_dir=IMAGE_DIR
)
test_dset = CxVAE_Dset(
    csv_file=DATA_DIR + 'test_labels.csv',
    root_dir=IMAGE_DIR
)

loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=False)
train_loader = DataLoader(
    train_dset,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    **loader_kwargs
)
val_loader = DataLoader(val_dset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dset, shuffle=False, **loader_kwargs)

## Define the model and pass it to the training loop

In [None]:
# ImageNet-pretrained VGG16 backbone.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.VGG16_Weights.IMAGENET1K_V1`; switch once the version is confirmed.
Net = models.vgg16(pretrained=True, progress=False)
print(Net)

# Freeze training for all layers. The attribute is `requires_grad`: the former
# `require_grad` spelling only created an unused attribute, so nothing was frozen.
for param in Net.parameters():
    param.requires_grad = False

# Replace the last classifier layer with a 2-class head. Newly created modules
# have requires_grad=True by default, so this layer is the only trainable one.
num_features = Net.classifier[6].in_features
fc_new = torch.nn.Linear(num_features, 2)
Net.classifier[6] = fc_new
print(Net)

# Pretrained autoencoder; its reconstructions are what the classifier sees in
# the training / evaluation / interpretation cells below.
NetAE = AutoEncoder(3, 3, 8, 4, 32, 16)
# map_location='cpu' lets a checkpoint saved from a GPU load on any host; the
# model itself is still on the CPU at this point.
NetAE.load_state_dict(torch.load('../ckpt/AE_pneu_best.pth', map_location='cpu'))

In [None]:
# Settings for fitting the classifier on autoencoder reconstructions.
# Presumably the loop writes `<ckpt_path>_best.pth`, which the next cell reloads.
train_cfg = dict(
    ckpt_path='../ckpt/VGG16_w_AE_pneu',
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    eval_every=2,
    init_lr=1e-6,
    n_epochs=6,
)
train_classifier_w_AE_loop(train_loader, val_loader, NetAE, Net, **train_cfg)

In [None]:
# Restore the best classifier weights (presumably saved by the training cell)
# and evaluate on the held-out test loader.
best_ckpt_path = '../ckpt/VGG16_w_AE_pneu_best.pth'
best_state = torch.load(best_ckpt_path)
Net.load_state_dict(best_state)
eval_classifier_w_AE_loop(test_loader, NetAE, Net, device='cuda', dtype=torch.cuda.FloatTensor)

In [None]:
!pip install captum

In [None]:
from captum.attr import IntegratedGradients
from captum.attr import Saliency
from captum.attr import DeepLift
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz
from captum.attr import GuidedGradCam
from captum.attr import InputXGradient

import scipy
from scipy.ndimage import gaussian_filter

import cv2

In [None]:
Net.eval()


def attribute_image_features(algorithm, input, label, **kwargs):
    """Run one captum attribution method against a single target class.

    Gradients on the notebook-level ``Net`` are cleared first; the return value is
    whatever ``algorithm.attribute`` returns, e.g. an ``(attributions, delta)`` pair
    when ``return_convergence_delta=True`` is passed through ``kwargs``.
    """
    Net.zero_grad()
    return algorithm.attribute(input, target=label, **kwargs)

In [None]:
def overlay_saliency_map(i1_rec, smap, alpha=0.5, threshold=0.75, sigma=10):
    """Blend a thresholded, smoothed attribution heat-map onto an image.

    The attribution is min-max normalised, entries not above ``threshold`` times
    its maximum are zeroed, channels are averaged, the result is Gaussian-blurred,
    re-normalised and coloured with OpenCV's PLASMA colormap.

    Args:
        i1_rec: (H, W, 3) float image expected in [0, 1]; it is clipped to that
            range before conversion to uint8.
        smap: (H, W, C) attribution map aligned with ``i1_rec``.
        alpha: blend weight of ``i1_rec``; the heat-map gets ``1 - alpha``.
            (Previously this argument was silently overridden with 0.5.)
        threshold: fraction of the maximum an attribution must exceed to be kept.
        sigma: standard deviation, in pixels, of the Gaussian blur.

    Returns:
        (H, W, 3) uint8 RGB image.
    """
    def _rescale01(arr):
        # Min-max scale to [0, 1]. A constant (e.g. all-zero) or NaN range gives
        # zeros instead of 0/0 NaNs, which np.uint8 would turn into garbage.
        lo, hi = np.min(arr), np.max(arr)
        if not hi > lo:
            return np.zeros(np.shape(arr))
        return (arr - lo) / (hi - lo)

    smap = _rescale01(smap)
    smap = smap * (smap > threshold * np.max(smap))
    smap = np.mean(smap, axis=2)
    smap = gaussian_filter(smap, sigma=sigma)
    smap = _rescale01(smap)
    heat = cv2.applyColorMap(np.uint8(smap * 255), colormap=cv2.COLORMAP_PLASMA)
    # applyColorMap returns BGR; convert so it blends with the RGB image and is
    # shown with the intended colours by matplotlib.
    heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)
    base = np.uint8(np.clip(i1_rec, 0, 1) * 255)
    return cv2.addWeighted(base, alpha, heat, 1 - alpha, 0)


def _hwc_numpy(chw):
    """Return a (C, H, W) tensor as an (H, W, C) numpy array on the CPU, detached from autograd."""
    return chw.detach().cpu().permute(1, 2, 0).numpy()


def _draw_panel(ax, image, title, with_colorbar=False):
    """Show ``image`` on ``ax`` with a title and no axes; optionally add a 0-1 PLASMA colorbar below."""
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
    if with_colorbar:
        cax = make_axes_locatable(ax).append_axes("bottom", size="5%", pad=0.05)
        # The overlays are RGB images, so their AxesImage has no data scale of its
        # own. Show the [0, 1] PLASMA scale that overlay_saliency_map applies to the
        # min-max-normalised attribution instead.
        scale = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=0.0, vmax=1.0), cmap='plasma')
        scale.set_array([])
        ax.figure.colorbar(scale, cax=cax, orientation="horizontal")


def show_classifier_w_AE_interp(
    eval_loader,
    NetAE,
    Net,
    dtype = torch.cuda.FloatTensor,
    device='cuda',
    n_show = 5,
    gradcam_layer=None,
):
    """Plot attribution maps explaining ``Net``'s output on ``NetAE`` reconstructions.

    For each image in the first ``n_show`` batches of ``eval_loader``, the image is
    reconstructed by ``NetAE`` and classified by ``Net``. The ground-truth class is
    then attributed on the reconstruction with Saliency, Guided Grad-CAM,
    Integrated Gradients (all-zero baseline) and Input x Gradient.

    Each image gets one figure showing the input, the reconstruction, and the four
    attributions overlaid on the reconstruction via ``overlay_saliency_map``. The
    ground truth and prediction are printed before each figure.

    Args:
        eval_loader: iterable of ``(xin, yout)`` batches. Presumably ``xin`` is
            (B, 3, H, W) in [0, 1] and ``yout`` holds scalar integer labels — confirm
            against the dataset.
        NetAE: autoencoder whose forward returns ``(reconstruction, _)``. It is
            moved and cast, frozen, and put in eval mode.
        Net: classifier; moved and cast, then put in eval mode. With the default
            ``gradcam_layer`` it must expose ``Net.features`` (as torchvision VGG does).
        dtype: tensor type both networks are cast to.
        device: device for the networks and the data.
        n_show: number of *batches* (not images) to visualise.
        gradcam_layer: layer used by Guided Grad-CAM; defaults to ``Net.features[10]``.
    """
    NetAE.to(device)
    NetAE.type(dtype)
    # Freeze the autoencoder. The attribute is `requires_grad`; the former
    # `require_grad` spelling only created an unused attribute.
    for param in NetAE.parameters():
        param.requires_grad = False
    NetAE.eval()

    Net.to(device)
    Net.type(dtype)
    Net.eval()

    if gradcam_layer is None:
        gradcam_layer = Net.features[10]

    # The attribution objects only wrap the model, so build them once.
    saliency = Saliency(Net)
    ig = IntegratedGradients(Net)
    ggcam = GuidedGradCam(Net, gradcam_layer)
    input_x_gradient = InputXGradient(Net)

    for idx, (xin, yout) in enumerate(eval_loader):
        if idx >= n_show:
            break
        xin, yout = xin.to(device), yout.to(device)
        with torch.no_grad():
            xin_rec, _ = NetAE(xin)
            # argmax over logits == argmax over log_softmax. No squeeze(): a final
            # batch of size 1 must still give a 1-D tensor indexable by j.
            predictions = Net(xin_rec).argmax(dim=1)

        for j in range(xin.shape[0]):
            i1 = _hwc_numpy(xin[j])
            i1_rec = _hwc_numpy(xin_rec[j].clip(0, 1))
            # Fresh (1, C, H, W) leaf tensor on which captum can enable gradients.
            sample = xin_rec[j:j + 1].detach().clone()
            target = int(yout[j])

            Net.zero_grad()
            attributions = [
                ('Saliency', saliency.attribute(sample, target=target)),
                ('Guided Grad-CAM', ggcam.attribute(sample, target=target)),
                ('Integrated Gradients',
                 ig.attribute(sample, target=target, baselines=torch.zeros_like(sample))),
                ('Input x Gradient', input_x_gradient.attribute(sample, target=target)),
            ]

            print('GT: {}, predicted: {}'.format(target, int(predictions[j])))

            fig, axes = plt.subplots(1, 6, figsize=(24, 8))
            _draw_panel(axes[0], i1, 'Input')
            _draw_panel(axes[1], i1_rec, 'Reconstruction')
            for ax, (title, attr) in zip(axes[2:], attributions):
                overlay = overlay_saliency_map(i1_rec, _hwc_numpy(attr[0]), alpha=0.5)
                _draw_panel(ax, overlay, title, with_colorbar=True)
            plt.show()
            plt.close(fig)

In [None]:
# Visualise attributions for the first five batches of the test split.
show_classifier_w_AE_interp(
    eval_loader=test_loader,
    NetAE=NetAE,
    Net=Net,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    n_show=5,
)