In [4]:
import torch
import time
import gc
import math
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import pickle

from skimage import data, img_as_float
from skimage.measure import compare_ssim as ssim

import sys
sys.path.append('../code')
from python.commons import full_inference_e2e, inc_inference_e2e, adaptive_drilldown, generate_heatmap
from python.imagenet_classes import class_names
from python.vgg16 import VGG16

# Image that every inference and heatmap call in this notebook is pointed at.
image_file_path = "../code/python/dog_resized.jpg"
# Logit index passed to the inference helpers; also used to look up the
# heatmap label in class_names.
interested_logit_index = 208

In [5]:
# Passed unchanged to the full-inference run and to every incremental run below,
# so all timings and heatmaps use the same settings.
# NOTE(review): presumably the side length and step of the occluding patch in
# pixels -- confirm against python.commons.
patch_size = 16
stride = 1

In [6]:
# Time a single full (non-incremental) inference pass over the image.
# torch.cuda.synchronize()
prev_time = time.time()
x = full_inference_e2e(
    VGG16,
    image_file_path,
    patch_size,
    stride,
    interested_logit_index,
    batch_size=128,
)
# Wait for outstanding GPU work before stopping the clock.
torch.cuda.synchronize()
full_inference_time = time.time() - prev_time
print("Full Inference Time: {}".format(full_inference_time))

  orig_image = Variable(loader(orig_image).unsqueeze(0), volatile=True)


RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58

In [None]:
# Heatmap of the full-inference output; later cells score their heatmaps
# against it with SSIM.
orig_hm = generate_heatmap(image_file_path, x, show=True, label=class_names[interested_logit_index])

# Pickle data is bytes, so the file must be opened in binary mode (text mode
# fails on Python 3). The with-block closes the file, which guarantees the data
# is flushed to disk before a later cell reads 'temp' back.
with open('temp', 'wb') as output:
    pickle.dump(orig_hm, output)

In [None]:
# Collect unreachable Python objects first, then ask PyTorch to release its
# cached GPU memory; the order matters because the cache can only give back
# blocks that no live tensor still references.
gc.collect()
torch.cuda.empty_cache()

In [None]:
def inc_inference(beta, patch_size=4, stride=1, adaptive=False, batch_size=1024, percentile=20):
    """Run one incremental inference pass on the notebook's image.

    Reads the module-level ``VGG16``, ``image_file_path`` and
    ``interested_logit_index`` and forwards them to the selected helper.

    Args:
        beta: Forwarded as the ``beta`` keyword of the helper.
        patch_size: Forwarded positionally to the helper.
        stride: Forwarded positionally to the helper.
        adaptive: If true, call ``adaptive_drilldown``; otherwise call
            ``inc_inference_e2e``.
        batch_size: Forwarded as the ``batch_size`` keyword of the helper.
            Defaults to 1024, the value that used to be hard-coded.
        percentile: Forwarded as the ``percentile`` keyword of
            ``adaptive_drilldown`` only. Defaults to 20, the value that used
            to be hard-coded.

    Returns:
        Whatever the selected helper returns.
    """
    # Drain previously queued GPU work so a caller timing this call does not
    # also measure leftovers from earlier cells.
    torch.cuda.synchronize()
    if adaptive:
        x = adaptive_drilldown(VGG16, image_file_path, patch_size, stride, interested_logit_index,
                               batch_size=batch_size, beta=beta, percentile=percentile)
    else:
        x = inc_inference_e2e(VGG16, image_file_path, patch_size, stride, interested_logit_index,
                              batch_size=batch_size, beta=beta)

    # Wait for this run's GPU work to finish before returning to the caller.
    torch.cuda.synchronize()

    return x

### Patch growth thresholding

In [None]:
# Reload the reference heatmap from the 'temp' pickle file.
# Binary mode is required for pickle data, and the with-block closes the handle
# instead of leaking it.
# NOTE: pickle.load can execute arbitrary code; only load a 'temp' file that
# this notebook wrote itself.
with open('temp', 'rb') as saved_hm:
    orig_hm = pickle.load(saved_hm)

times = []  # wall-clock seconds per beta, in loop order
score = []  # SSIM of each heatmap against orig_hm, in loop order

for beta in [1.0, 0.7, 0.5, 0.4, 0.35, 0.3, 0.25, 0.2]:
    # Only the inference call is timed; heatmap rendering happens afterwards.
    prev_time = time.time()
    x = inc_inference(beta, patch_size=patch_size, stride=stride)
    inc_inference_time = time.time() - prev_time
    times.append(inc_inference_time)

    label = "BETA: " + str(beta) + " Inference Time: " + str(inc_inference_time)
    print(label)
    hm = generate_heatmap(image_file_path, x, show=True, label=class_names[interested_logit_index])
    score.append(ssim(orig_hm, hm, data_range=255, multichannel=True))

In [None]:
# One point per beta: inference wall-clock time against SSIM to the reference
# heatmap. The title distinguishes this figure from the adaptive one below.
fig, ax = plt.subplots()
ax.plot(times, score, marker='o')
ax.grid(True)
ax.set_title('SSIM vs. runtime: patch growth thresholding')
ax.set_xlabel('runtime (s)')
ax.set_ylabel('SSIM');

In [None]:
# Reload the reference heatmap from the 'temp' pickle file.
# Binary mode is required for pickle data, and the with-block closes the handle
# instead of leaking it.
# NOTE: pickle.load can execute arbitrary code; only load a 'temp' file that
# this notebook wrote itself.
with open('temp', 'rb') as saved_hm:
    orig_hm = pickle.load(saved_hm)

times = []  # wall-clock seconds per beta, in loop order
score = []  # SSIM of each heatmap against orig_hm, in loop order

for beta in [1.0, 0.7, 0.5, 0.4, 0.35, 0.3, 0.25, 0.2]:
    # Same sweep as the non-adaptive cell, but routed through adaptive_drilldown.
    prev_time = time.time()
    x = inc_inference(beta, patch_size=patch_size, stride=stride, adaptive=True)
    inc_inference_time = time.time() - prev_time
    times.append(inc_inference_time)

    label = "BETA: " + str(beta) + " Inference Time: " + str(inc_inference_time)
    print(label)
    hm = generate_heatmap(image_file_path, x, show=True, label=class_names[interested_logit_index])
    score.append(ssim(orig_hm, hm, data_range=255, multichannel=True))

In [None]:
# One point per beta for the adaptive sweep: inference wall-clock time against
# SSIM to the reference heatmap. The title distinguishes this figure from the
# non-adaptive one above.
fig, ax = plt.subplots()
ax.plot(times, score, marker='o')
ax.grid(True)
ax.set_title('SSIM vs. runtime: adaptive drill-down')
ax.set_xlabel('runtime (s)')
ax.set_ylabel('SSIM');