In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from skimage.measure import compare_ssim as ssim
import matplotlib.pyplot as plt
import time
import os
import copy
import h5py
import sys
import torch
from PIL import Image
import gc

sys.path.append('../../../code')

from python.finetune_commons import show_images, ft_train_model, visualize_model
from python.commons import load_dict_from_hdf5, save_dict_to_hdf5, inc_inference_e2e, full_inference_e2e, adaptive_drilldown, generate_heatmap
from python.vgg16 import VGG16
from python.resnet18 import ResNet18

%matplotlib inline

  from ._conv import register_converters as _register_converters


In [2]:
# Square input resolution (pixels) fed to the networks.
image_size = 224

def inc_inference(model, image_file_path, beta, patch_size=4, stride=1,
                  adaptive=False, weights_data=None, gpu=None):
    """Run incremental inference on one image via ``inc_inference_e2e`` and return its output.

    When running on the GPU, CUDA is synchronized before and after the call. CUDA
    kernels launch asynchronously, so wall-clock timing taken around this function
    then covers the finished GPU work rather than only the kernel launches.

    Args:
        model: model class forwarded to ``inc_inference_e2e`` (e.g. ``VGG16``).
        image_file_path: path of the image to analyse.
        beta: projective-field threshold (tau) forwarded to ``inc_inference_e2e``.
        patch_size: patch size forwarded to ``inc_inference_e2e``.
        stride: stride forwarded to ``inc_inference_e2e``.
        adaptive: must be False; this wrapper has no adaptive code path.
        weights_data: weight dict forwarded to ``inc_inference_e2e``.
        gpu: whether to run on the GPU. ``None`` (default) uses the notebook-level
            ``gpu`` flag if it exists, otherwise ``torch.cuda.is_available()``.

    Returns:
        Whatever ``inc_inference_e2e`` returns for this image.

    Raises:
        NotImplementedError: if ``adaptive`` is True.
    """
    if adaptive:
        # Without this guard the function would fall through and fail with an
        # UnboundLocalError on the return, because no result is ever computed.
        raise NotImplementedError("inc_inference: adaptive=True is not supported")

    if gpu is None:
        # The notebook defines `gpu` in a later config cell; read it lazily and
        # fall back to CUDA availability if it has not been defined yet.
        gpu = globals().get('gpu', torch.cuda.is_available())

    if gpu:
        torch.cuda.synchronize()

    with torch.no_grad():
        x = inc_inference_e2e(model, image_file_path, patch_size, stride,
                              batch_size=128, beta=beta, gpu=gpu, version='v1',
                              weights_data=weights_data, n_labels=2, c=0.0)

    if gpu:
        torch.cuda.synchronize()

    return x

# Image preprocessing: resize to a square image_size x image_size input, then
# convert the PIL image into a tensor.
# NOTE(review): not referenced by the cells visible in this notebook.
loader = transforms.Compose([
    transforms.Resize([image_size, image_size]),
    transforms.ToTensor(),
])

In [3]:
# Use the GPU only when CUDA is actually present; a hard-coded True makes every
# torch.cuda call fail on CPU-only hosts. Timings measured on CPU are not
# comparable with GPU timings, so check this flag before reading the time plots.
gpu = torch.cuda.is_available()
image_size = 224

# Validation images to evaluate (DRUSEN vs NORMAL OCT split).
file_paths = [
    '../../../data/oct/DRUSEN_NORMAL/validation/DRUSEN/DRUSEN-5333808-1.jpeg',
    # Other candidates from the same validation split:
    # '../../../data/oct/DRUSEN_NORMAL/validation/DRUSEN/DRUSEN-1020679-5.jpeg',
    # '../../../data/oct/DRUSEN_NORMAL/validation/DRUSEN/DRUSEN-1112835-10.jpeg',
    # '../../../data/oct/DRUSEN_NORMAL/validation/DRUSEN/DRUSEN-1130960-76.jpeg',
    # '../../../data/oct/DRUSEN_NORMAL/validation/DRUSEN/DRUSEN-1169820-1.jpeg',
    # '../../../data/oct/DRUSEN_NORMAL/validation/DRUSEN/DRUSEN-1146923-28.jpeg',
    # '../../../data/oct/DRUSEN_NORMAL/validation/NORMAL/NORMAL-1007507-1.jpeg',
    # '../../../data/oct/DRUSEN_NORMAL/validation/NORMAL/NORMAL-1001772-3.jpeg',
]

In [4]:
patch_size = 32
stride = 4

# Projective-field thresholds (tau) swept for incremental inference.
taus = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]

# Set True to also display every incremental heatmap (one per model x image x tau).
show_heatmaps = False

# One tuple per model: (model class, display name, base weight file).
# Keeping the fields together avoids parallel lists of different lengths,
# which zip() would silently truncate.
experiments = [
    (VGG16, 'VGG16', '../../../code/python/vgg16_weights_ptch.h5'),
    (ResNet18, 'ResNet18', '../../../code/python/resnet18_weights_ptch.h5'),
]

# Parameter keys copied from the OCT DRUSEN experiment weights over the base weights.
OCT_LAYER_KEYS = {
    'VGG16': ('fc8_W:0', 'fc8_b:0'),
    'ResNet18': ('fc:w', 'fc:b'),
}


def _load_weights(model_name, base_weight_file):
    """Load base weights and overwrite the OCT_LAYER_KEYS entries with the OCT DRUSEN weights."""
    weights_data = load_dict_from_hdf5(base_weight_file, gpu=gpu)
    oct_weights = load_dict_from_hdf5(
        '../../../exps/oct_drusen_' + model_name.lower() + '_ptch.h5', gpu=gpu)
    for key in OCT_LAYER_KEYS[model_name]:
        weights_data[key] = oct_weights[key]
    return weights_data


def _sweep_taus(model, weights_data, image_paths, tau_values):
    """Return parallel lists (tau, SSIM, seconds) over every image and every tau.

    SSIM compares each incremental-inference heatmap with the heatmap produced
    by full_inference_e2e for the same image.
    """
    x_vals, y_ssim, y_time = [], [], []
    for file_path in image_paths:
        full_out = full_inference_e2e(model, file_path, patch_size, stride, batch_size=128,
                                      gpu=gpu, weights_data=weights_data)
        orig_hm = generate_heatmap(file_path, full_out, show=False, label="")

        for tau in tau_values:
            # perf_counter is monotonic and high-resolution, suited to interval timing.
            start = time.perf_counter()
            inc_out = inc_inference(model, file_path, tau, patch_size=patch_size,
                                    stride=stride, weights_data=weights_data)
            elapsed = time.perf_counter() - start
            hm = generate_heatmap(file_path, inc_out, show=show_heatmaps, label="")

            x_vals.append(tau)
            y_ssim.append(ssim(orig_hm, hm, data_range=255, multichannel=True, win_size=3))
            y_time.append(elapsed)

        # Release per-image intermediates before the next image.
        gc.collect()
        if gpu:
            torch.cuda.empty_cache()
    return x_vals, y_ssim, y_time


# Run every measurement before creating the summary figure, so any plotting done
# inside generate_heatmap cannot draw into (or flush) that figure.
results = []
for model, model_name, weight_file in experiments:
    weights_data = _load_weights(model_name, weight_file)
    results.append((model_name, _sweep_taus(model, weights_data, file_paths, taus)))
    # Drop this model's weights before loading the next model's.
    del weights_data
    gc.collect()
    if gpu:
        torch.cuda.empty_cache()

# Row 1: SSIM vs tau, row 2: time vs tau; one column per model.
n_cols = len(results)
fig = plt.figure(figsize=(9, 5))
for col, (model_name, (x_vals, y_ssim, y_time)) in enumerate(results, start=1):
    ax_ssim = fig.add_subplot(2, n_cols, col)
    ax_ssim.scatter(x_vals, y_ssim)
    ax_ssim.set_title(model_name)
    ax_ssim.grid(True)  # explicit: plt.grid() with no args toggles instead
    ax_ssim.set_xticks(taus)
    ax_ssim.set_xticklabels(taus)

    ax_time = fig.add_subplot(2, n_cols, col + n_cols)
    ax_time.scatter(x_vals, y_time)
    ax_time.grid(True)
    ax_time.set_xlabel(r'Projective Field Threshold $(\tau)$')
    ax_time.set_xticks(taus)
    ax_time.set_xticklabels(taus)

    if col == 1:
        ax_ssim.set_ylabel('SSIM')
        ax_time.set_ylabel('Time (s)')

fig.subplots_adjust(top=0.8)
os.makedirs('../images', exist_ok=True)
fig.savefig('../images/ssim_tau.pdf', bbox_inches='tight')
plt.show()

RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /pytorch/aten/src/THC/generic/THCTensorCopy.c:20

<Figure size 648x360 with 0 Axes>