In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import random
import numpy as np
from scipy.ndimage.filters import gaussian_filter1d
import matplotlib.pyplot as plt
from PIL import Image
from utils.dataset_processing import evaluation, grasp
from utils.dataset_processing.grasp import GraspRectangles, detect_grasps
from utils.data.jacquard_sal import JacquardSalDataset
from skimage.filters import gaussian
from scipy.ndimage.filters import gaussian_filter1d

%matplotlib inline
# Notebook-wide matplotlib defaults, applied in a single call.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),        # default size of plots
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})

# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines where CUDA is not available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
def preprocess(img, size=224):
    """
    Resize an image, convert it to a tensor, normalize it and add a batch axis.
    :param img: Input image handed to the torchvision transform pipeline (e.g. a PIL image)
    :param size: Size argument passed to ``transforms.Resize``
    :return: Normalized tensor with a new leading axis of length 1
    """
    # Uses the `transforms` alias bound in the import cell; the short alias `T`
    # is not imported anywhere in the visible cells and would raise NameError.
    # NOTE(review): SQUEEZENET_MEAN / SQUEEZENET_STD are not defined in the
    # visible cells either; they must exist before this is called (confirm).
    pipeline = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=SQUEEZENET_MEAN.tolist(),
                             std=SQUEEZENET_STD.tolist()),
        transforms.Lambda(lambda x: x[None]),  # prepend the batch axis
    ])
    return pipeline(img)

def deprocess(img, should_rescale=True):
    """
    Undo the normalization of ``preprocess`` and convert the result to a PIL image.
    :param img: Batched tensor; only element 0 of the leading axis is converted
    :param should_rescale: If True, pass the de-normalized tensor through ``rescale``
                           before building the image
    :return: PIL image produced by ``transforms.ToPILImage``
    """
    # Uses the `transforms` alias bound in the import cell; the short alias `T`
    # is not imported anywhere in the visible cells and would raise NameError.
    # NOTE(review): SQUEEZENET_MEAN / SQUEEZENET_STD are not defined in the
    # visible cells either; they must exist before this is called (confirm).
    pipeline = transforms.Compose([
        transforms.Lambda(lambda x: x[0]),  # drop the batch axis
        # Dividing by 1/STD multiplies by STD ...
        transforms.Normalize(mean=[0, 0, 0], std=(1.0 / SQUEEZENET_STD).tolist()),
        # ... and subtracting -MEAN adds MEAN back.
        transforms.Normalize(mean=(-SQUEEZENET_MEAN).tolist(), std=[1, 1, 1]),
        transforms.Lambda(rescale) if should_rescale else transforms.Lambda(lambda x: x),
        transforms.ToPILImage(),
    ])
    return pipeline(img)

def rescale(x):
    """
    Linearly map the values of ``x`` so its minimum becomes 0 and its maximum 1.
    :param x: Object exposing ``min()`` / ``max()`` and arithmetic (e.g. a torch Tensor)
    :return: ``(x - min) / (max - min)``; all zeros when every value of ``x`` is equal
    """
    low, high = x.min(), x.max()
    span = high - low
    if span == 0:
        # Constant input: dividing by the zero range would produce NaN, so
        # divide by 1 instead, which leaves the all-zero shifted values.
        span = 1
    return (x - low) / span
    
def blur_image(X, sigma=1):
    """
    Gaussian-blur a tensor in place along its axes 2 and 3.
    :param X: torch Tensor with at least 4 axes (presumably batch, channel,
              height, width; confirm against callers); overwritten with the result
    :param sigma: Standard deviation passed to ``gaussian_filter1d`` for both axes
    :return: The same tensor object ``X``, now holding the blurred values
    """
    # detach() allows tensors that track gradients (numpy() refuses those).
    # No clone is needed: gaussian_filter1d returns a new array and leaves its
    # input untouched.
    X_np = X.detach().cpu().numpy()
    X_np = gaussian_filter1d(X_np, sigma, axis=2)
    X_np = gaussian_filter1d(X_np, sigma, axis=3)
    # from_numpy keeps the array's dtype (torch.Tensor(...) would force float32);
    # copy_ converts to X's dtype/device, and no_grad permits the in-place write
    # even when X is a leaf tensor that requires grad.
    with torch.no_grad():
        X.copy_(torch.from_numpy(X_np))
    return X

def post_process_output(q_img, cos_img, sin_img, width_img):
    """
    Convert the raw GG-CNN output tensors to numpy arrays and smooth them.
    :param q_img: Q output of GG-CNN (as torch Tensors)
    :param cos_img: cos output of GG-CNN
    :param sin_img: sin output of GG-CNN
    :param width_img: Width output of GG-CNN
    :return: Filtered Q output, Filtered Angle output, Filtered Width output
    """
    # The angle is half of atan2(sin, cos), computed while still a tensor.
    half_angle = torch.atan2(sin_img, cos_img) / 2.0

    quality_np = q_img.detach().cpu().numpy().squeeze()
    angle_np = half_angle.detach().cpu().numpy().squeeze()
    # The width output is multiplied by 150 before filtering.
    width_np = width_img.detach().cpu().numpy().squeeze() * 150.0

    # Gaussian smoothing: sigma 2.0 for Q and angle, sigma 1.0 for width.
    return (
        gaussian(quality_np, 2.0, preserve_range=True),
        gaussian(angle_np, 2.0, preserve_range=True),
        gaussian(width_np, 1.0, preserve_range=True),
    )

In [3]:
def activation_map(X, target, y, model, loss='grasp'):
    """
    Compute a saliency map: the absolute gradient of a model loss with respect
    to the input, reduced with a max over dim 1.
    :param X: Input tensor; gradient tracking is enabled on it in place and
              ``X.grad`` is overwritten by this call
    :param target: Forwarded unchanged to ``model.compute_loss``
    :param y: Forwarded unchanged to ``model.compute_loss``
    :param model: Object providing ``eval()`` and ``compute_loss(X, target, y)``,
                  whose result is indexed as ``['loss']['grasp']`` / ``['loss']['class']``
    :param loss: Which loss to differentiate: 'grasp', 'class' or 'combined'
                 (the sum of both)
    :return: Detached tensor holding ``max(|dLoss/dX|)`` over dim 1
    :raises ValueError: If ``loss`` is not one of the supported names
    """
    if loss not in ('grasp', 'class', 'combined'):
        # Fail before the forward pass instead of printing and then crashing
        # on a missing gradient.
        raise ValueError(
            "No Loss implemented for loss=%r; expected 'grasp', 'class' or 'combined'" % (loss,)
        )

    model.eval()
    X.requires_grad_()
    # backward() accumulates into X.grad, so drop any gradient left over from
    # an earlier call; otherwise repeated calls inflate the saliency.
    X.grad = None

    losses = model.compute_loss(X, target, y)['loss']
    if loss == 'combined':
        # One backward pass over the sum yields the same input gradient as
        # back-propagating both losses separately.
        objective = losses['grasp'] + losses['class']
    else:
        objective = losses[loss]
    objective.backward()

    saliency, _ = torch.max(torch.abs(X.grad), 1)
    return saliency.detach()
    
def show_activation_maps(X, target, y, model, activation='grasp'):
    """
    Plot every input image (top row) above its saliency map (bottom row).
    :param X: Batch of inputs; X[i] is transposed with (1, 2, 0) for display,
              i.e. it is treated as channel-first
    :param target: Forwarded to ``activation_map``
    :param y: Forwarded to ``activation_map``
    :param model: Model forwarded to ``activation_map``
    :param activation: Loss name forwarded to ``activation_map`` as ``loss``
    :return:
    """
    # Only the saliency is needed for plotting, so no separate forward pass or
    # output post-processing is run here.
    saliency = activation_map(X, target, y, model, loss=activation)
    # Keep the batch axis (no squeeze) so saliency_np[i] is always the map that
    # belongs to image i, for N == 1 as well as N > 1.
    saliency_np = saliency.cpu().numpy()

    N = X.shape[0]

    for i in range(N):
        plt.subplot(2, N, i + 1)
        plt.imshow(X[i].detach().cpu().numpy().transpose(1, 2, 0))
        plt.axis('off')
        plt.subplot(2, N, N + i + 1)
        plt.imshow(saliency_np[i], cmap=plt.cm.hot)
        plt.axis('off')
    # The figure size does not depend on the loop index; set it once.
    plt.gcf().set_size_inches(12, 5)
    plt.show()
    
def plot_output(img, grasp_q_img, grasp_angle_img, no_grasps=1, grasp_width_img=None):
    """
    Plot the output of a GG-CNN
    :param img: Image shown in the first panel with the detected grasps drawn on it
    :param grasp_q_img: Q output of GG-CNN
    :param grasp_angle_img: Angle output of GG-CNN
    :param no_grasps: Maximum number of grasps to plot
    :param grasp_width_img: (optional) Width output of GG-CNN, passed to detect_grasps
    :return:
    """
    grasps = detect_grasps(
        grasp_q_img,
        grasp_angle_img,
        width_img=grasp_width_img,
        no_grasps=no_grasps,
    )

    fig = plt.figure(figsize=(10, 10))

    # Slot 1 of the 2x2 grid: the image with every detected grasp overlaid.
    overlay_ax = fig.add_subplot(2, 2, 1)
    overlay_ax.imshow(img)
    for candidate in grasps:
        candidate.plot(overlay_ax)
    overlay_ax.axis('off')

    # Slots 3 and 4: (grid slot, data, title, imshow keyword arguments).
    heatmaps = (
        (3, grasp_q_img, 'Q', dict(cmap='jet', vmin=0, vmax=1)),
        (4, grasp_angle_img, 'Angle', dict(cmap='hsv', vmin=-np.pi / 2, vmax=np.pi / 2)),
    )
    for slot, heat, title, style in heatmaps:
        panel = fig.add_subplot(2, 2, slot)
        mappable = panel.imshow(heat, **style)
        panel.set_title(title)
        panel.axis('off')
        plt.colorbar(mappable)

    plt.show()

In [4]:
# Conversion helpers: tensor -> PIL image, and the sample transform (ToTensor only).
topil = transforms.ToPILImage()
transform = transforms.Compose([transforms.ToTensor()])

# Jacquard saliency dataset with both depth and RGB enabled, built with train=False.
# NOTE(review): the root is an absolute, machine-specific path; confirm it exists
# on the machine running this notebook.
# NOTE(review): random_rotate / random_zoom are enabled and no RNG seed is set in
# the visible cells, so samples may differ between runs; confirm this is intended.
data = JacquardSalDataset(
    '/media/will/research/Jacquard/data',
    include_depth=True,
    include_rgb=True,
    train=False,
    random_rotate=True,
    random_zoom=True,
    transform=transform,
)

(tensor([[[-0.0603, -0.0600, -0.0596,  ...,  0.0639,  0.0644,  0.0648],
          [-0.0603, -0.0600, -0.0596,  ...,  0.0639,  0.0644,  0.0648],
          [-0.0603, -0.0600, -0.0596,  ...,  0.0639,  0.0644,  0.0648],
          ...,
          [-0.0603, -0.0600, -0.0596,  ...,  0.0639,  0.0644,  0.0648],
          [-0.0603, -0.0600, -0.0596,  ...,  0.0639,  0.0644,  0.0648],
          [-0.0603, -0.0600, -0.0596,  ...,  0.0639,  0.0644,  0.0648]],
 
         [[ 0.6902,  0.6863,  0.6863,  ...,  0.6745,  0.6706,  0.6706],
          [ 0.6902,  0.6902,  0.6902,  ...,  0.6706,  0.6706,  0.6667],
          [ 0.6902,  0.6902,  0.6863,  ...,  0.6706,  0.6627,  0.6549],
          ...,
          [ 0.7176,  0.7137,  0.7137,  ...,  0.7529,  0.7529,  0.7529],
          [ 0.7216,  0.7137,  0.7137,  ...,  0.7569,  0.7569,  0.7569],
          [ 0.7255,  0.7216,  0.7216,  ...,  0.7569,  0.7569,  0.7569]],
 
         [[ 0.6902,  0.6863,  0.6863,  ...,  0.6745,  0.6706,  0.6706],
          [ 0.6902,  0.6902,