Load data and model

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
import scipy.ndimage
from scipy import misc
from glob import glob
from scipy import stats
from sklearn.preprocessing import LabelEncoder, StandardScaler
import skimage
import imageio
import seaborn as sns
from PIL import Image
import glob
import matplotlib.pyplot as plt
import matplotlib
from sklearn.model_selection import StratifiedShuffleSplit

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# importing metadata
data_dir = ""
metadata = pd.read_csv(data_dir + 'HAM10000_metadata.csv')

# Integer-encode the diagnosis column ('dx') into a new 'label' column.
le = LabelEncoder()
le.fit(metadata['dx'])
print("Classes:", list(le.classes_))
metadata['label'] = le.transform(metadata['dx'])

# Notebook-style preview of ten random rows; the returned frame is not stored.
metadata.sample(10)

dest_dir = ""

# Short diagnosis codes as they appear in the 'dx' column, and their long names
# in the same order.
label = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
classes = [
    'actinic keratoses',
    'basal cell carcinoma',
    'benign keratosis-like lesions',
    'dermatofibroma',
    'melanoma',
    'melanocytic nevi',
    'vascular lesions',
]

def estimate_weights_mfb(label, metadata_df=None):
    """Compute median-frequency-balancing weights for the given class names.

    The weight of a class is ``median(counts) / count``, where ``count`` is
    the number of rows whose ``dx`` column equals that class name and the
    median is taken over the counts of all requested classes.

    Args:
        label: Sequence of class names as they appear in the ``dx`` column.
        metadata_df: Optional DataFrame with a ``dx`` column. When omitted,
            the module-level ``metadata`` frame is used.

    Returns:
        A float64 numpy array with one weight per entry of ``label``, in the
        same order.

    Raises:
        ValueError: If a requested class has no rows in the frame.
    """
    frame = metadata if metadata_df is None else metadata_df
    names = [str(name) for name in label]

    # Count every class once and look the requested names up by index label,
    # instead of filtering the frame per class and indexing positionally.
    per_class = frame['dx'].value_counts()
    counts = per_class.reindex(names, fill_value=0).to_numpy(dtype=float)

    missing = [name for name, count in zip(names, counts) if count == 0]
    if missing:
        raise ValueError("no rows found for classes: {}".format(missing))

    return np.median(counts) / counts

# Print the median-frequency-balancing weight next to each short class code.
classweight = estimate_weights_mfb(label)
for class_name, class_weight in zip(label, classweight):
    print(class_name, ":", class_weight)

# Per-channel statistics handed to transforms.Normalize.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

batch_size = 10
validation_batch_size = 10

# The same weights as a torch float tensor, used to weight the loss.
class_weights = torch.FloatTensor(estimate_weights_mfb(label))

# Training pipeline: fixed-size resize, random flip/rotation augmentation,
# tensor conversion and per-channel normalisation.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=60),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Evaluation pipeline: same resize and normalisation, no random augmentation.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Pre-computed split indices stored as text files. The builtin ``int`` is used
# for the cast because the ``np.int`` alias was removed in NumPy 1.24 (it was
# an alias of the builtin, so the resulting dtype is unchanged).
train_indices = np.loadtxt("/train_indices.txt").astype(int)
val_indices = np.loadtxt("/val_indices.txt").astype(int)
test_indices = np.loadtxt("/test_indices.txt").astype(int)

SubsetRandomSampler = torch.utils.data.sampler.SubsetRandomSampler

# Each sampler restricts a loader to the indices of one split.
train_samples = SubsetRandomSampler(train_indices)
val_samples = SubsetRandomSampler(val_indices)
test_samples = SubsetRandomSampler(test_indices)

# NOTE(review): the validation loader shares the dataset built with
# transform_train, so validation images also get the random flip/rotation.
# Confirm this is intended.
dataset = torchvision.datasets.ImageFolder(root=dest_dir, transform=transform_train)
train_data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=False, num_workers=1, sampler=train_samples)
validation_data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=validation_batch_size, shuffle=False, sampler=val_samples)

# The test loader reads the same folder through the augmentation-free pipeline.
dataset = torchvision.datasets.ImageFolder(root=dest_dir, transform=transform_test)
test_data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=validation_batch_size, shuffle=False, sampler=test_samples)


## define CNN
num_classes = len(classes)

# VGG16 whose last classifier layer is swapped for a num_classes-way linear head.
vgg = torchvision.models.vgg16(pretrained=True)
vgg.classifier[-1] = nn.Linear(4096, num_classes, bias=True)
vgg = vgg.to(device)

import torch.optim as optim

# Class-weighted cross-entropy; the weight tensor is moved to the model's device.
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
optimizer = optim.Adam(vgg.parameters(), lr=1e-6)

from sklearn.metrics import accuracy_score

def get_accuracy(predicted, labels):
    """Return ``(batch_len, correct)`` for one batch.

    ``batch_len`` is the size of the first dimension of ``labels`` and
    ``correct`` is the number of positions where ``predicted`` equals
    ``labels``, as a Python number.
    """
    n_correct = predicted.eq(labels).sum().item()
    return labels.size(0), n_correct

def evaluate(model, val_loader):
    """Compute the mean batch loss and the accuracy of *model* on *val_loader*.

    The model is switched to eval mode and left in that mode. Batches are
    moved to the module-level ``device`` and scored with the module-level
    ``criterion``.

    Args:
        model: Network called as ``model(inputs)``.
        val_loader: Iterable yielding ``(inputs, labels)`` batches; must also
            support ``len()``.

    Returns:
        ``(loss, accuracy)`` where ``loss`` is the sum of the batch losses
        divided by ``len(val_loader)`` and ``accuracy`` is the fraction of
        correctly predicted samples.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    losses = 0
    num_samples_total = 0
    correct_total = 0
    model.eval()
    # Evaluation needs no gradients; skipping autograd bookkeeping avoids
    # building a graph for every batch.
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            out = model(inputs)
            _, predicted = torch.max(out, 1)
            losses += criterion(out, labels).item()
            b_len, corr = get_accuracy(predicted, labels)
            num_samples_total += b_len
            correct_total += corr
    accuracy = correct_total / num_samples_total
    losses = losses / len(val_loader)
    return losses, accuracy

# map_location lets a checkpoint that was saved from a GPU load on whatever
# device was selected above (including a CPU-only machine).
vgg.load_state_dict(torch.load('/.pt', map_location=device))
# The rest of the file only runs inference and attribution, so put the model
# in eval mode; otherwise train-only layers such as dropout stay active.
vgg.eval()

In [None]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision.transforms import ToTensor
from tqdm import tqdm

from attribution_bottleneck.utils.misc import *
from attribution_bottleneck.utils.baselines import Mean
from attribution_bottleneck.evaluate.perturber import *
from sklearn.metrics import auc
# Threshold constant; not referenced anywhere in the visible part of this file.
warn_baseline_prob = 0.05
# Tile size handed to GridPerturber by the degradation loops.
tile_size = (56,56)

In [None]:
def to_np(t: torch.Tensor):
    """Return *t* as a numpy array.

    The tensor is detached from autograd first and copied to the CPU when it
    lives on a CUDA device.
    """
    detached = t.detach()
    host = detached.cpu() if detached.is_cuda else detached
    return host.numpy()

def eval_np(img_t):
    """Run *img_t* through ``vgg`` and return the scores of the first batch entry.

    The output is converted with ``to_np`` and indexed with ``[0]``, so the
    batch dimension is dropped.
    """
    scores = to_np(vgg(img_t))
    return scores[0]

def show_img(img, title="", place=None):
    """Draw *img* with matplotlib.

    Args:
        img: Image passed through ``to_np_img`` before plotting.
        title: Title text for the plot.
        place: Target to draw on. ``None`` means the ``plt`` module, in which
            case the figure is also shown immediately. When an ``Axes`` is
            given, the image and title are set on it and nothing is shown.

    Raises:
        TypeError: Re-raised from ``imshow`` after printing the image shape.
    """
    # Local import: Axes is not imported explicitly by the visible imports.
    from matplotlib.axes import Axes

    img = to_np_img(img)
    if place is None:
        place = plt
    try:
        if len(img.shape) == 3 and img.shape[2] == 1:
            # remove single grey channel
            img = img[..., 0]

        if len(img.shape) == 2:
            place.imshow(img, cmap="Greys_r")
        else:
            place.imshow(img)
    except TypeError:
        print("type error: shape is {}".format(img.shape))
        # Bare raise keeps the original message and traceback.
        raise

    if not isinstance(place, Axes):
        place.title(title)
        plt.show()
    else:
        place.set_title(title)


## gradcam

In [None]:
target_layers = [vgg.features[29]]  ## for vgg16

# Number of points of the linspace used to pick the steps that get a snapshot.
n_steps = 200

save_model_results = []
vgg.eval()

# One GradCAM object serves every image; it is released when the block exits.
with GradCAM(model=vgg,
             target_layers=target_layers,
             use_cuda=torch.cuda.is_available()) as cam:
    cam.batch_size = 32

    for i, (gc_inputs, gc_labels) in enumerate(train_data_loader):
        print("###################", i)
        if i == 100:
            break

        for im, lbl in zip(gc_inputs, gc_labels):
            # Saliency map for the ground-truth class of this image.
            grayscale_cam = cam(input_tensor=im.unsqueeze(0),
                                targets=[ClassifierOutputTarget(lbl)],
                                aug_smooth=True,
                                eigen_smooth=True)
            # Here grayscale_cam has only one image in the batch
            hmap = grayscale_cam[0, :]

            ## degradation ##
            # The image is already a float CHW tensor; only the batch axis and
            # the device move are needed.
            img_t = im.unsqueeze(0).to(device)
            img = to_np_img(img_t)

            # Baseline built by Mean().apply (presumably a mean-valued image --
            # confirm in attribution_bottleneck).
            baseline_img = Mean().apply(img)
            baseline_t = to_img_tensor(baseline_img, device=img_t.device)

            # Class with the highest score on the unperturbed image.
            with torch.no_grad():
                top1 = np.argmax(eval_np(img_t))

            if tile_size is None or tile_size == (1, 1):
                perturber = PixelPerturber(img_t, baseline_img)
            else:
                perturber = GridPerturber(img_t, baseline_t, tile_size)
            idxes = perturber.get_idxes(hmap)

            max_steps = len(idxes)
            # Steps after which the current perturbed image is recorded.
            snapshot_steps = {
                int(p) for p in np.round(np.linspace(0, 1, n_steps) * max_steps)
            }

            perturbed_ts = [img_t]
            for step in range(max_steps):
                perturber.perturbe(*idxes[step])
                if step in snapshot_steps:
                    perturbed_ts.append(perturber.get_current().clone())

            # Score all snapshots in one batch and keep the top-1 class column.
            with torch.no_grad():
                model_results = to_np(vgg(torch.cat(perturbed_ts, 0)))[:, top1]

            save_model_results.append(model_results / max(model_results))

  
# Stack the per-image curves and average them point-wise.
np_result = np.array(save_model_results)
result = np_result.mean(axis=0)

# Min-max rescale the mean curve.
_lowest, _highest = min(result), max(result)
result_normalize = (result - _lowest) / (_highest - _lowest)

x = np.arange(len(result))
# Area under the rescaled curve divided by (first value * number of points).
_auc_ratio = auc(x, result_normalize) / (result_normalize[0] * len(result_normalize))

plt.figure(figsize=(8, 8))
plt.title("gradcam" + ",AUC=" + str(round(_auc_ratio, 3)))
plt.ylabel('normalized model score')
plt.xlabel('level of degradation[%]')
plt.ylim([-0.05, 1.05])
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
plt.xticks([0, 20, 40, 60, 80, 100])

# Express the step index as a percentage of the last step.
new_x = x / max(x) * 100
plt.plot(new_x, result_normalize, color="gray")
plt.savefig("")

gradcam_deletion_result = result_normalize
np.savetxt('/gradcam_deletion_result.txt', gradcam_deletion_result)

plt.show()


## RISE

In [None]:
from explanations import RISE
from RISE_utils import *
from tqdm import tqdm

# Dummy is not defined in this file; presumably it comes from the RISE_utils
# star import and serves as a plain attribute container -- TODO confirm.
args = Dummy()

# Input resolution and GPU batch size handed to the RISE constructor below.
args.input_size = (224, 224)
args.gpu_batch = 1

# RISE explainer built around the VGG model.
explainer = RISE(vgg, args.input_size, args.gpu_batch)

def explain(idx, img, target):
    """Return the RISE saliency entry for *target* on a single image.

    ``img`` gets a leading batch axis, is moved to CUDA and passed to the
    module-level ``explainer`` without gradient tracking. The result is
    converted to numpy and indexed with ``target``. ``idx`` is accepted but
    not used.
    """
    batch = img.unsqueeze(0)
    with torch.no_grad():
        saliency = explainer(batch.cuda()).cpu().numpy()
    return saliency[target]

In [None]:
maskspath = 'masks.npy'
generate_new = True

# Reuse saved masks only when regeneration is not requested and the file exists;
# in every other case generate a fresh set and save it to maskspath.
if not generate_new and os.path.isfile(maskspath):
    explainer.load_masks(maskspath)
    print('Masks are loaded.')
else:
    explainer.generate_masks(N=3000, s=8, p1=0.1, savepath=maskspath)

Generating filters: 100%|██████████| 3000/3000 [00:22<00:00, 131.56it/s]


In [None]:
tile_size = (56, 56)
# Number of points of the linspace used to pick the steps that get a snapshot.
n_steps = 200

save_model_results = []
vgg.eval()
for i, (inputs, labels) in enumerate(train_data_loader):
    print("###################", i)
    if i == 100:
        break

    for im, lbl in zip(inputs, labels):
        # RISE saliency map for the ground-truth class of this image.
        hmap = explain(i, im, lbl)

        ## degradation ##
        # The image is already a float CHW tensor; only the batch axis and the
        # device move are needed.
        img_t = im.unsqueeze(0).to(device)
        img = to_np_img(img_t)

        # Baseline built by Mean().apply (presumably a mean-valued image --
        # confirm in attribution_bottleneck).
        baseline_img = Mean().apply(img)
        baseline_t = to_img_tensor(baseline_img, device=img_t.device)

        # Class with the highest score on the unperturbed image.
        with torch.no_grad():
            top1 = np.argmax(eval_np(img_t))

        if tile_size is None or tile_size == (1, 1):
            perturber = PixelPerturber(img_t, baseline_img)
        else:
            perturber = GridPerturber(img_t, baseline_t, tile_size)
        idxes = perturber.get_idxes(hmap)

        max_steps = len(idxes)
        # Steps after which the current perturbed image is recorded.
        snapshot_steps = {
            int(p) for p in np.round(np.linspace(0, 1, n_steps) * max_steps)
        }

        perturbed_ts = [img_t]
        for step in range(max_steps):
            perturber.perturbe(*idxes[step])
            if step in snapshot_steps:
                perturbed_ts.append(perturber.get_current().clone())

        # Score all snapshots in one batch and keep the top-1 class column.
        with torch.no_grad():
            model_results = to_np(vgg(torch.cat(perturbed_ts, 0)))[:, top1]

        save_model_results.append(model_results / max(model_results))

  
# Stack the per-image curves and average them point-wise.
np_result = np.array(save_model_results)
result = np_result.mean(axis=0)

# Min-max rescale the mean curve.
_lowest, _highest = min(result), max(result)
result_normalize = (result - _lowest) / (_highest - _lowest)

x = np.arange(len(result))
# Area under the rescaled curve divided by (first value * number of points).
_auc_ratio = auc(x, result_normalize) / (result_normalize[0] * len(result_normalize))

plt.figure(figsize=(8, 8))
plt.title("RISE" + ",AUC=" + str(round(_auc_ratio, 3)))
plt.ylabel('normalized model score')
plt.xlabel('level of degradation[%]')
plt.ylim([-0.05, 1.05])
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
plt.xticks([0, 20, 40, 60, 80, 100])

# Express the step index as a percentage of the last step.
new_x = x / max(x) * 100
plt.plot(new_x, result_normalize, color="gray")
plt.savefig("")

rise_deletion_result = result_normalize
np.savetxt('/rise_deletion_result.txt', rise_deletion_result)

plt.show()

## Extremal perturbation

In [None]:
from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward
from torchray.benchmark import get_example_data, plot_example
from torchray.utils import get_device, imsc

In [None]:
tile_size = (56, 56)
# Number of points of the linspace used to pick the steps that get a snapshot.
n_steps = 200

save_model_results = []
vgg.eval()
for i, (inputs, labels) in enumerate(train_data_loader):
    print("###################", i)
    if i == 100:
        break

    for im, lbl in zip(inputs, labels):
        # Single-image batch on the selected device; shared by the attribution
        # call and the degradation below.
        img_t = im.unsqueeze(0).to(device)
        target = lbl.tolist()

        masks_1, _ = extremal_perturbation(
            vgg, img_t, target,
            reward_func=contrastive_reward,
            debug=False,
            areas=[0.07],
        )
        hmap = masks_1.squeeze().cpu().numpy()

        ## degradation ##
        img = to_np_img(img_t)

        # Baseline built by Mean().apply (presumably a mean-valued image --
        # confirm in attribution_bottleneck).
        baseline_img = Mean().apply(img)
        baseline_t = to_img_tensor(baseline_img, device=img_t.device)

        # Class with the highest score on the unperturbed image.
        with torch.no_grad():
            top1 = np.argmax(eval_np(img_t))

        if tile_size is None or tile_size == (1, 1):
            perturber = PixelPerturber(img_t, baseline_img)
        else:
            perturber = GridPerturber(img_t, baseline_t, tile_size)
        idxes = perturber.get_idxes(hmap)

        max_steps = len(idxes)
        # Steps after which the current perturbed image is recorded.
        snapshot_steps = {
            int(p) for p in np.round(np.linspace(0, 1, n_steps) * max_steps)
        }

        perturbed_ts = [img_t]
        for step in range(max_steps):
            perturber.perturbe(*idxes[step])
            if step in snapshot_steps:
                perturbed_ts.append(perturber.get_current().clone())

        # Score all snapshots in one batch and keep the top-1 class column.
        with torch.no_grad():
            model_results = to_np(vgg(torch.cat(perturbed_ts, 0)))[:, top1]

        save_model_results.append(model_results / max(model_results))

  
# Stack the per-image curves and average them point-wise.
np_result = np.array(save_model_results)
result = np_result.mean(axis=0)

# Min-max rescale the mean curve.
_lowest, _highest = min(result), max(result)
result_normalize = (result - _lowest) / (_highest - _lowest)

x = np.arange(len(result))
# Area under the rescaled curve divided by (first value * number of points).
_auc_ratio = auc(x, result_normalize) / (result_normalize[0] * len(result_normalize))

plt.figure(figsize=(8, 8))
plt.title("Extremal Perturbation" + ",AUC=" + str(round(_auc_ratio, 3)))
plt.ylabel('normalized model score')
plt.xlabel('level of degradation[%]')
plt.ylim([-0.05, 1.05])
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
plt.xticks([0, 20, 40, 60, 80, 100])

# Express the step index as a percentage of the last step.
new_x = x / max(x) * 100
plt.plot(new_x, result_normalize, color="gray")
plt.savefig("")

plt.show()

In [None]:
# Keep the rescaled curve under a method-specific name and write it as text.
ext_deletion_result = result_normalize
np.savetxt(
    '/content/drive/MyDrive/HAM/ext_deletion_result.txt',
    ext_deletion_result,
)

## scoreCAM

In [None]:
from utils import *
from cam.scorecam import *

In [None]:
tile_size = (56, 56)
# Number of points of the linspace used to pick the steps that get a snapshot.
n_steps = 200

save_model_results = []
vgg.eval()

# Build the ScoreCAM wrapper once. Constructing it per image presumably
# registers new hooks on the target layer every time -- confirm in cam.scorecam.
vgg_model_dict = dict(type='vgg16', arch=vgg, layer_name='features_29', input_size=(224, 224))
vgg_scorecam = ScoreCAM(vgg_model_dict)

for i, (inputs, labels) in enumerate(train_data_loader):
    print("###################", i)
    if i == 100:
        break

    for im in inputs:
        # Single-image batch on the selected device; shared by the attribution
        # call and the degradation below.
        img_t = im.unsqueeze(0).to(device)

        # Called without a class index, as in the original cell.
        scorecam_map = vgg_scorecam(img_t)
        hmap = scorecam_map.cpu().squeeze()

        ## degradation ##
        img = to_np_img(img_t)

        # Baseline built by Mean().apply (presumably a mean-valued image --
        # confirm in attribution_bottleneck).
        baseline_img = Mean().apply(img)
        baseline_t = to_img_tensor(baseline_img, device=img_t.device)

        # Class with the highest score on the unperturbed image.
        with torch.no_grad():
            top1 = np.argmax(eval_np(img_t))

        if tile_size is None or tile_size == (1, 1):
            perturber = PixelPerturber(img_t, baseline_img)
        else:
            perturber = GridPerturber(img_t, baseline_t, tile_size)
        idxes = perturber.get_idxes(hmap)

        max_steps = len(idxes)
        # Steps after which the current perturbed image is recorded.
        snapshot_steps = {
            int(p) for p in np.round(np.linspace(0, 1, n_steps) * max_steps)
        }

        perturbed_ts = [img_t]
        for step in range(max_steps):
            perturber.perturbe(*idxes[step])
            if step in snapshot_steps:
                perturbed_ts.append(perturber.get_current().clone())

        # Score all snapshots in one batch and keep the top-1 class column.
        with torch.no_grad():
            model_results = to_np(vgg(torch.cat(perturbed_ts, 0)))[:, top1]

        save_model_results.append(model_results / max(model_results))

  
# Stack the per-image curves and average them point-wise.
np_result = np.array(save_model_results)
result = np_result.mean(axis=0)
# Min-max rescale the mean curve.
result_normalize = (result - min(result)) / (max(result) - min(result))

x = np.array(range(len(result)))
# Start a fresh 8x8 figure like the other plot cells do, so this curve is not
# drawn at a different size or onto a figure left open by an earlier cell.
plt.figure(figsize=(8, 8))
# The title shows the area under the rescaled curve divided by
# (first value * number of points).
plt.title("ScoreCAM" + ",AUC=" + str(round(auc(x, result_normalize)/(result_normalize[0]*len(result_normalize)), 3)))
plt.ylabel('normalized model score')
plt.xlabel('level of degradation[%]')
plt.ylim([-0.05, 1.05])
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
plt.xticks([0, 20, 40, 60, 80, 100])
# Express the step index as a percentage of the last step.
new_x = x / max(x) * 100
plt.plot(new_x, result_normalize)
plt.savefig("")

plt.show()


In [None]:
# Keep the rescaled curve under a method-specific name and write it as text.
scorecam_deletion_result = result_normalize
np.savetxt(
    '//scorecam_deletion_result.txt',
    scorecam_deletion_result,
)

## IBA

In [None]:
import IBA
from IBA.pytorch import IBA, tensor_to_np_img
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, Normalize
from torch.utils.data import DataLoader

In [None]:
imagenet_dir = ""

dev = torch.device('cuda:0')

vgg.to(dev).eval()


image_size = 224

# Images used to estimate the feature statistics for IBA. They are normalised
# with norm_mean / norm_std, the same statistics every other pipeline in this
# file feeds to the model, instead of a second hard-coded set of values.
# NOTE(review): the crop/resize here still differs from the Resize((224, 224))
# used by transform_train / transform_test -- confirm that is intended.
trainset = ImageFolder(
    os.path.join(imagenet_dir),
    transform=Compose([
        CenterCrop(256), Resize(image_size), ToTensor(),
        Normalize(mean=norm_mean, std=norm_std),
    ]))

trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)

In [None]:
# Attach the IBA object to layer 17 of the VGG feature extractor.
iba = IBA(vgg.features[17])

# Discard any previously collected statistics before estimating again.
iba.reset_estimate()

# Estimate statistics at that layer from the images yielded by trainloader.
iba.estimate(vgg, trainloader, device=dev, n_samples=10000, progbar=True)

# Spot-check the estimated mean and std at one index of the estimator output.
neuron = (12, 3, 4)
print("Neuron at position {:} has mean {:.2f} and std {:.2f}".format(
    neuron, iba.estimator.mean()[neuron],  iba.estimator.std()[neuron]))

# Notebook-style display of the estimator's sample count; the value is not stored.
iba.estimator.n_samples()

  0%|          | 0/10000 [00:00<?, ?it/s]

Neuron at position (12, 3, 4) has mean -6.96 and std 12.59


10048

In [None]:
tile_size = (56, 56)
# Number of points of the linspace used to pick the steps that get a snapshot.
n_steps = 200

save_model_results = []
vgg.eval()
for i, (inputs, labels) in enumerate(train_data_loader):
    print("###################", i)
    if i == 100:
        break

    for im, target in zip(inputs, labels):
        # Loss handed to IBA: negative log-softmax score of the ground-truth
        # class. The closure is consumed by iba.analyze right away, before
        # ``target`` is rebound by the next iteration.
        model_loss_closure = lambda x: -torch.log_softmax(vgg(x), 1)[:, target].mean()
        hmap = iba.analyze(im[None].to(dev), model_loss_closure)

        ## degradation ##
        # The image is already a float CHW tensor; only the batch axis and the
        # device move are needed.
        img_t = im.unsqueeze(0).to(device)
        img = to_np_img(img_t)

        # Baseline built by Mean().apply (presumably a mean-valued image --
        # confirm in attribution_bottleneck).
        baseline_img = Mean().apply(img)
        baseline_t = to_img_tensor(baseline_img, device=img_t.device)

        # Class with the highest score on the unperturbed image.
        with torch.no_grad():
            top1 = np.argmax(eval_np(img_t))

        if tile_size is None or tile_size == (1, 1):
            perturber = PixelPerturber(img_t, baseline_img)
        else:
            perturber = GridPerturber(img_t, baseline_t, tile_size)
        idxes = perturber.get_idxes(hmap)

        max_steps = len(idxes)
        # Steps after which the current perturbed image is recorded.
        snapshot_steps = {
            int(p) for p in np.round(np.linspace(0, 1, n_steps) * max_steps)
        }

        perturbed_ts = [img_t]
        for step in range(max_steps):
            perturber.perturbe(*idxes[step])
            if step in snapshot_steps:
                perturbed_ts.append(perturber.get_current().clone())

        # Score all snapshots in one batch and keep the top-1 class column.
        with torch.no_grad():
            model_results = to_np(vgg(torch.cat(perturbed_ts, 0)))[:, top1]

        save_model_results.append(model_results / max(model_results))

  
# Stack the per-image curves and average them point-wise.
np_result = np.array(save_model_results)
result = np_result.mean(axis=0)

# Min-max rescale the mean curve.
_lowest, _highest = min(result), max(result)
result_normalize = (result - _lowest) / (_highest - _lowest)

x = np.arange(len(result))
# Area under the rescaled curve divided by (first value * number of points).
_auc_ratio = auc(x, result_normalize) / (result_normalize[0] * len(result_normalize))

plt.figure(figsize=(8, 8))
plt.title("IBA" + ",AUC=" + str(round(_auc_ratio, 3)))
plt.ylabel('normalized model score')
plt.xlabel('level of degradation[%]')
plt.ylim([-0.05, 1.05])
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
plt.xticks([0, 20, 40, 60, 80, 100])

# Express the step index as a percentage of the last step.
new_x = x / max(x) * 100
plt.plot(new_x, result_normalize)
plt.savefig("")

plt.show()