In [31]:
import torch
import torchvision
import torchvision.transforms as T
import random
import numpy as np
from scipy.ndimage.filters import gaussian_filter1d
import matplotlib.pyplot as plt
#from cs231n.image_utils import SQUEEZENET_MEAN, SQUEEZENET_STD
from PIL import Image
import pandas as pd
from skimage import io, transform, color
import os
import math
import torchvision.models as models
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

In [100]:
path = '../data/'
size = 193
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
dtype = torch.float32

In [14]:
means = [0.49439774802337344,0.5177020917650417,0.5013496452715945]
stds = [0.16087996922691195,0.15714445907773483,0.1605951051365687]

In [17]:
[1/std for std in stds]

[6.21581421730357, 6.363571492554687, 6.226839847638]

In [18]:
def preprocess(img, size=size,means=means,stds=stds):
    transform = T.Compose([
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=means,
                    std=stds),
        T.Lambda(lambda x: x[None]),
    ])
    return transform(img)

def deprocess(img, should_rescale=True):
    transform = T.Compose([
        T.Lambda(lambda x: x[0]),
        T.Normalize(mean=[0, 0, 0], std=[1/std for std in stds]),
        T.Normalize(mean=-means, std=[1, 1, 1]),
        T.Lambda(rescale) if should_rescale else T.Lambda(lambda x: x),
        T.ToPILImage(),
    ])
    return transform(img)

def rescale(x):
    low, high = x.min(), x.max()
    x_rescaled = (x - low) / (high - low)
    return x_rescaled
    
def blur_image(X, sigma=1):
    X_np = X.cpu().clone().numpy()
    X_np = gaussian_filter1d(X_np, sigma, axis=2)
    X_np = gaussian_filter1d(X_np, sigma, axis=3)
    X.copy_(torch.Tensor(X_np).type_as(X))
    return X

# load saved_model

In [6]:
model = torch.load('classification_checkpoint.pth')['model']

In [7]:
for param in model.parameters():
    param.requires_grad = False

# load_data

In [21]:
validation_labels = pd.read_csv('../data_description/A1_A2_C1_filtered_validation_v2.csv')
validation_image_names = {2:validation_labels['image_name_2'],8:validation_labels['image_name_8'],5:validation_labels['image_name_5']}
validation_y = validation_labels['has_cell_13']

In [22]:
validation_image_names

{2: 0      well_A2/well3362_day02_well.png
 1      well_A1/well2413_day02_well.png
 2      well_C1/well2126_day02_well.png
 3      well_C1/well0543_day02_well.png
 4      well_A1/well0184_day02_well.png
 5      well_C1/well2965_day02_well.png
 6      well_A2/well0203_day02_well.png
 7      well_A1/well4708_day02_well.png
 8      well_A1/well1842_day02_well.png
 9      well_A2/well2752_day02_well.png
 10     well_A1/well1484_day02_well.png
 11     well_C1/well2141_day02_well.png
 12     well_C1/well3198_day02_well.png
 13     well_A1/well0968_day02_well.png
 14     well_A1/well1172_day02_well.png
 15     well_A1/well1745_day02_well.png
 16     well_A1/well2637_day02_well.png
 17     well_A1/well2436_day02_well.png
 18     well_A1/well2514_day02_well.png
 19     well_C1/well0010_day02_well.png
 20     well_C1/well1652_day02_well.png
 21     well_A2/well0123_day02_well.png
 22     well_C1/well2712_day02_well.png
 23     well_A1/well0957_day02_well.png
 24     well_C1/well0946_day02_well.p

In [84]:
def getXimage(index):
    all_images_list = []
    for day,img_names in validation_image_names.items():
            #print(day, "   ", index)
            
        img_name = img_names[index]
        img_loc = os.path.join(path, img_name)
        image = io.imread(img_loc)
#         mean, sd = mean_sd_dict[day]
#         image = np.true_divide(color.rgb2gray(image) - mean, sd)
        image = color.rgb2gray(image)
        all_images_list.append(image)
    images = np.array(all_images_list)
    #images = np.reshape(images, (1,2,0))
    #return torch.from_numpy(images).float()
    return images
def getY(index):
    Y = validation_y[index]
    #return torch.from_numpy(np.asarray(validation_y[index], dtype=float)).float()
    return Y

In [83]:
def getXimage(index):
    all_images_list = []
    for day,img_names in validation_image_names.items():
            #print(day, "   ", index)
            
        img_name = img_names[index]
        img_loc = os.path.join(path, img_name)
        image = io.imread(img_loc)
#         mean, sd = mean_sd_dict[day]
#         image = np.true_divide(color.rgb2gray(image) - mean, sd)
        image = color.rgb2gray(image)
        all_images_list.append(image)
    images = np.array(all_images_list)
    #images = np.reshape(images, (1,2,0))
    return images

In [80]:
X = getXimage(0)
y = getY(0)

In [85]:
X = []
Y = []
for i in range(5):
    X.append(getXimage(i))
    Y.append(getY(i))

In [90]:
X = torch.from_numpy(np.array(X)).float()

In [95]:
#Y = torch.from_numpy(np.array(Y)).float()

In [107]:
validation_y[:5]

0    1
1    1
2    0
3    0
4    1
Name: has_cell_13, dtype: int64

In [109]:
y = torch.LongTensor(validation_y[:5])

In [76]:
torch.cat(X).shape

torch.Size([15, 193, 193])

In [68]:
y.shape

torch.Size([])

In [59]:
X.dtype

dtype('float64')

In [56]:
preprocess(Image.fromarray(X))

TypeError: Cannot handle this data type

# Saliency Maps

In [8]:
# Example of using gather to select one entry from each row in PyTorch
def gather_example():
    N, C = 4, 5
    s = torch.randn(N, C)
    y = torch.LongTensor([1, 2, 1, 3])
    print(s)
    print(y)
    print(s.gather(1, y.view(-1, 1)).squeeze())
gather_example()

tensor([[ 0.1067,  1.0481,  1.1629,  0.7058,  1.3044],
        [-0.3272, -0.2228,  1.1700,  1.9523, -0.5114],
        [ 0.7604, -0.1452,  0.2288, -0.4644,  1.1930],
        [ 0.8360, -0.3876, -0.2788,  1.5241, -1.7641]])
tensor([1, 2, 1, 3])
tensor([ 1.0481,  1.1700, -0.1452,  1.5241])


In [9]:
    N, C = 4, 5
    s = torch.randn(N, C)
    y = torch.LongTensor([1, 2, 1, 3])

In [12]:
y.view(-1,1).squeeze()

tensor([1, 2, 1, 3])

In [114]:
def compute_saliency_maps(X, y, model):
    """
    Compute a class saliency map using the model for images X and labels y.

    Input:
    - X: Input images; Tensor of shape (N, 3, H, W)
    - y: Labels for X; LongTensor of shape (N,)
    - model: A pretrained CNN that will be used to compute the saliency map.

    Returns:
    - saliency: A Tensor of shape (N, H, W) giving the saliency maps for the input
    images.
    """
    # Make sure the model is in "test" mode
    model.eval()
    X = X.to(device=device, dtype=dtype) 
    y = y.to(device=device)
    # Make input tensor require gradient
    X.requires_grad_()
    
    saliency = None
    N = y.shape[0]
    
    scores = model(X)
    correct_scores = scores.gather(1,y.view(-1,1)).squeeze()
    correct_scores..to(device=device).backward(torch.ones(N))
    
    saliency = X.grad.data.abs()
    saliency,_=torch.max(saliency,dim=1)
    saliency = saliency.squeeze()

    return saliency

SyntaxError: invalid syntax (<ipython-input-114-91807e6936a7>, line 26)

In [113]:
X_tensor = X
y_tensor = y
saliency = compute_saliency_maps(X_tensor, y_tensor, model)

RuntimeError: invalid gradient at index 0 - expected type torch.cuda.FloatTensor but got torch.FloatTensor