In [None]:
# Visualizes the agent's policy using the saliency method from https://arxiv.org/abs/1711.00138
# Reference implementation (GitHub): https://github.com/greydanus/visualize_atari
# Imports
import os
import sys
import pdb
import json
import torch
import random
import argparse
import torchvision
import tensorboardX
import torch.optim as optim
import torchvision.utils as vutils
import matplotlib.pyplot as plt

"""
TODO: Set the absolute repository path.
"""
# Repository root. The SIDEKICK_REPO_PATH environment variable, when set, takes
# precedence so the notebook can run on another machine without editing this cell;
# otherwise the original hard-coded location is used.
repo_path = os.environ.get(
    'SIDEKICK_REPO_PATH',
    "/scratch/cluster/srama/Research/LearningToLookAround/SidekickPolicyLearning")

# Put <repo>/misc/ on the import path for the project-local imports that follow.
# The membership check keeps sys.path free of duplicates when the cell is re-run.
misc_path = os.path.join(repo_path, 'misc/')
if misc_path not in sys.path:
    sys.path.append(misc_path)

from utils import *
from models import *
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
from scipy.ndimage.filters import gaussian_filter

In [None]:
# Refer to https://github.com/greydanus/visualize_atari/blob/master/saliency.py

def occlude(I, mask):
    """Blend a Gaussian-blurred copy of an image into the region selected by `mask`.

    I: image array indexed as [row, col, channel]
    mask: 2D array of blending weights; 1 selects the blurred pixel (sigma=5),
          0 keeps the original pixel, values in between mix the two
    Returns a new array (default dtype of np.zeros) with the same shape as I.
    """
    blended = np.zeros(I.shape)
    for chn in range(I.shape[2]):
        plane = I[:, :, chn]
        blurred_plane = gaussian_filter(plane, sigma=5)
        # Each channel is blurred and blended independently.
        blended[:, :, chn] = plane*(1-mask) + blurred_plane*mask
    return blended
    
def get_mask(center, size, r):
    """Build a soft mask that peaks at 1 and is centred on `center`.

    center: (row, col) position of the mask centre
    size: (height, width) of the mask
    r: sigma of the Gaussian blur applied to the seed pixels
    Returns a `size`-shaped array scaled so that its maximum is 1.
    """
    # Row / column offsets of every pixel relative to the centre.
    rows, cols = np.ogrid[-center[0]:size[0]-center[0], -center[1]:size[1]-center[1]]
    # Seed: the pixels whose squared distance from the centre is at most 1.
    seed = np.zeros(size)
    seed[cols*cols + rows*rows <= 1] = 1
    # Blur the seed, then rescale so the peak value is exactly 1.
    blurred = gaussian_filter(seed, sigma=r)
    return blurred/blurred.max()

def saliency_on_image(saliency, image, fudge_factor, channels=(2,), sigma=0):
    """Overlay a saliency map on an image by adding it to selected colour channels.

    Refer to https://github.com/greydanus/visualize_atari/blob/master/saliency.py

    saliency: 2D array of saliency values; it is never modified by this function
    image: image array indexed as [row, col, channel]
    fudge_factor: scalar multiplier controlling the overlay strength
    channels: indices of the channels the overlay is added to; ignored for
              single-channel images, where channel 0 is always used
    sigma: if non-zero, the saliency map is Gaussian-blurred with this sigma first
    Returns a uint8 array with the same shape as `image`.
    """
    # Peak of the raw map, taken before any blurring or shifting.
    pmax = saliency.max()
    S = saliency if sigma == 0 else gaussian_filter(saliency, sigma=sigma)
    # Out-of-place on purpose: with sigma == 0, S is the caller's array and an
    # in-place `S -= ...` would silently modify the caller's saliency map.
    S = S - S.min()
    S = S.clip(0, 1)
    # Rescale so the overlay peaks at roughly 255*fudge_factor*pmax.
    S = 255.0*fudge_factor*pmax*S/(S.max() + 1e-7)
    overlay = S.astype('float32')

    I = np.copy(image.astype('float32'))
    if image.shape[2] == 1:
        I[:, :, 0] += overlay
    else:
        for channel in channels:
            I[:, :, channel] += overlay
    # Saturate instead of wrapping around when converting back to 8 bits.
    I = I.clip(0, 255.0).astype('uint8')
    return I

def _belief_policy_input(belief_flat, others):
    """Concatenate a flattened belief with the auxiliary inputs fed to `agent.policy.act`.

    Column order: belief, the first two columns of others['pro'], then
    others['elev'], others['azim'] and others['time'] for whichever keys are present.
    """
    parts = [belief_flat, others['pro'][:, :2]]
    for key in ('elev', 'azim', 'time'):
        if key in others:
            parts.append(others[key])
    return torch.cat(parts, dim=1)


def get_belief_saliency(belief, others, agent, iscuda, lr=1e-4, weight_decay=1e-1, iters=200):
    """
    Find a small perturbation of the belief that maximally changes the policy
    activations, and measure how much the decoded output changes under it.

    belief: current hidden state of the aggregator
    others: dictionary of elevation('elev'), azimuth('azim'), time('time'), proprioception change('pro')
    agent: object exposing M, N, C and policy.act / policy.decode
    iscuda: if True, the perturbation is moved to the GPU before use
    lr, weight_decay: settings of the SGD optimizer (momentum 0.9); weight_decay
                      also scales the explicit L2 penalty in the loss
    iters: maximum number of optimization steps

    Returns (saliency, diff):
        saliency: numpy array of shape (1, N, M, 32, 32), half the squared difference
                  between the decodings of the original and perturbed beliefs,
                  summed over the channel axis
        diff: negative squared difference between the original and perturbed
              policy activations, as a Python number

    NOTE: Assuming batch size = 1
          Ensure that the agent's parameters have requires_grad=False
    """
    M = agent.M
    N = agent.N
    C = agent.C
    belief_flat = belief.view(1, -1)
    # The perturbation has the same width as the flattened belief (previously
    # hard-coded to 256) and stays a CPU leaf variable owned by the optimizer.
    d_belief_orig = Variable(torch.randn(1, belief_flat.size(1))*1e-3, requires_grad=True)
    optimizer = optim.SGD([d_belief_orig], lr=lr, weight_decay=weight_decay, momentum=0.9)

    # Reference activations for the unperturbed belief.
    activations_1 = agent.policy.act(_belief_policy_input(belief_flat, others))

    def _perturbed_activations():
        # Evaluate the policy on (belief + current perturbation).
        d_belief = d_belief_orig.cuda() if iscuda else d_belief_orig
        activations = agent.policy.act(_belief_policy_input(d_belief + belief_flat, others))
        return d_belief, activations

    # Solve for d_belief that maximises change in policy activations
    for i in range(iters):
        # Without this the gradients of all previous iterations would be summed
        # into the current one.
        optimizer.zero_grad()
        d_belief, activations_2 = _perturbed_activations()
        # NOTE(review): the L2 penalty appears here and again as the optimizer's
        # weight_decay; kept as in the original.
        loss = -((activations_1 - activations_2)**2).sum() + weight_decay*(d_belief**2).sum()
        loss.backward()
        optimizer.step()
        # Stop once the perturbation is no longer small relative to the belief.
        if torch.norm(d_belief_orig.data) >= 0.75 * torch.norm(belief.data):
            break

    # Re-evaluate with the final perturbation so that `diff` and the decoded
    # saliency both reflect the last optimizer step on CPU and on GPU alike;
    # this also defines both values when iters == 0.
    d_belief, activations_2 = _perturbed_activations()
    diff = -((activations_1 - activations_2)**2).sum()

    # Get the saliency
    decoded_2 = agent.policy.decode(F.normalize(d_belief + belief_flat, p=1, dim=1))
    decoded_2 = decoded_2.view(1, N, M, C, 32, 32)
    decoded_1 = agent.policy.decode(F.normalize(belief_flat, p=1, dim=1))
    decoded_1 = decoded_1.view(1, N, M, C, 32, 32)
    saliency = np.sum(0.5*((decoded_1 - decoded_2).data.cpu().numpy())**2, axis=3)

    # `.data[0]` only works while sum() returns a 1-element tensor; newer
    # PyTorch returns a 0-dim tensor that must be read with .item().
    diff_data = diff.data
    diff_value = diff_data.item() if hasattr(diff_data, 'item') else diff_data[0]
    return saliency, diff_value

In [None]:
"""
TODO: Specify the path to model file to be visualized
"""
# Checkpoint to visualize; it holds the training options under 'opts' and the
# policy weights under 'state_dict'.
model_path = os.path.join(repo_path, 'models/sun360/ltla.net')

loaded_state = torch.load(model_path)
opts = loaded_state['opts']

# Override the stored training options with the settings used for this visualization
opts.critic_full_obs = False
opts.act_full_obs = False
opts.batch_size = 50
opts.trajectories_type = 'utility_maps'
# Both expert flags are forced off, so every `opts.expert_*` branch below takes
# its non-expert path.
opts.expert_trajectories = False
opts.expert_rewards = False

# Point the data and score paths at this repository, depending on the dataset id
# (0 uses the sun360 files, anything else the modelnet_hard files).
if opts.dataset == 0:
    opts.h5_path = os.path.join(repo_path, 'data/sun360/sun360_processed.h5')
    opts.utility_h5_path = os.path.join(repo_path, 'scores/sun360/ours-demo-scores.h5')
else:
    opts.h5_path = os.path.join(repo_path, 'data/modelnet_hard/modelnet30_processed.h5')    
    opts.utility_h5_path = os.path.join(repo_path, 'scores/modelnet_hard/ours-demo-scores.h5')
    
# Checkpoints whose options lack h5_path_unseen get a default value for it.
if not hasattr(opts, 'h5_path_unseen'):
    if opts.dataset == 0:
        opts.h5_path_unseen = ''
    else:
        opts.h5_path_unseen = os.path.join(repo_path, 'data/modelnet_hard/modelnet10_processed.h5')
        
if opts.expert_trajectories:
    opts.T_sup = 3

# Build the agent and restore the trained policy weights in evaluation mode.
if opts.expert_trajectories:
    agent = AgentSupervised(opts)
else:
    agent = Agent(opts)
    
agent.policy.load_state_dict(loaded_state['state_dict'])
agent.policy.eval()

# Pick the data loader class matching the expert settings.
if opts.expert_rewards:
    from DataLoader import DataLoaderExpert as DataLoader
elif opts.expert_trajectories:
    from DataLoader import DataLoaderExpertPolicy as DataLoader
else:
    from DataLoader import DataLoaderSimple as DataLoader

loader = DataLoader(opts)

# NOTE(review): seeding happens after the agent and the loader were constructed,
# so any randomness inside their constructors is not covered by this seed — confirm.
set_random_seeds(opts.seed)

# Per-channel statistics; opts.mean is added back to the decoded images in the
# rollout cell. Without mean subtraction they are the identity (mean 0, std 1).
if opts.dataset == 0:
    opts.num_channels = 3
    if opts.mean_subtract:
        # R, G, B means and stds
        opts.mean = [119.16, 107.68, 95.12]
        opts.std = [61.88, 61.72, 67.24]
    else:
        opts.mean = [0, 0, 0]
        opts.std = [1, 1, 1]
elif opts.dataset == 1:
    opts.num_channels = 1
    if opts.mean_subtract:
        # Mean and std of the single channel
        opts.mean = [193.0162338615919]
        opts.std = [37.716024486312811]
    else:
        opts.mean = [0]
        opts.std = [1]
else:
    raise ValueError('Dataset %d does not exist!'%(opts.dataset))
    
# Avoid computing gradients for the agent in backward
for parameter in agent.policy.parameters():
    parameter.requires_grad = False

In [None]:
# Roll the greedy policy out on `num_iters` test batches and record, for every
# batch and time step: the ground-truth viewgrid with the visited views marked
# (pano_all), the decoded viewgrid (decoded_all / decoded_all_raw), the belief
# saliency (saliency_belief_all) and the saliency overlaid on the decoded
# viewgrid (decoded_saliency_all). Each list is indexed as [batch][time step].
saliency_belief_all = []
images_all = []
pano_all = []
decoded_all = []
decoded_saliency_all = []
decoded_all_raw = []

num_iters = 6
for iters in range(num_iters):
    # The expert-policy loader returns one extra value; only `pano` is used below.
    if opts.expert_rewards:
        pano, _, depleted = loader.next_batch_test()
        pano_maps = None
        pano_rewards = None
    elif opts.expert_trajectories:
        pano, _, _, depleted = loader.next_batch_test()
        pano_rewards = None
        pano_maps = None
    else:
        pano, _, depleted = loader.next_batch_test()
        pano_rewards = None
        pano_maps = None

    # NOTE(review): assumes every test batch holds exactly opts.batch_size panoramas — confirm.
    batch_size = opts.batch_size
    start_idx = get_starts(opts.N, opts.M, batch_size, opts.start_view)
    state_object = State(pano, pano_rewards, start_idx, opts)
    hidden = None
    images_all.append([])
    decoded_all.append([])
    saliency_belief_all.append([])
    decoded_all_raw.append([])
    visited_idxes = []
    decoded_saliency_all.append([])

    pano_all.append([])

    for t in range(opts.T):
        # Get the original image's action distribution
        im_np, pro = state_object.get_view() # im_np - BxCx32x32
        im, pro = preprocess(im_np, pro)

        # Store the visited locations
        visited_idxes.append(state_object.idx)

        C = im.shape[1]
        W = im.shape[2]
        H = im.shape[3]

        # Get the panoramas with all views filled in till the current time step
        pano_full_view = np.copy(state_object.views)
        if opts.num_channels == 1:
            # Replicate the single channel three times so the border can be coloured
            pano_full_view_3chn = np.zeros((batch_size, opts.N, opts.M, 3, 32, 32))
            for c in range(3):
                pano_full_view_3chn[:, :, :, c, :, :] = pano_full_view[:, :, :, 0, :, :]
            pano_full_view = pano_full_view_3chn

        # Draw a 3-pixel border around every view visited so far: channel 0 is
        # set to 255 and the remaining channels to 0 (red, assuming RGB order).
        for i in range(batch_size):
            for t_prime in range(t+1):

                pano_full_view[i, visited_idxes[t_prime][i][0], visited_idxes[t_prime][i][1], 0, :3, :] = 255
                pano_full_view[i, visited_idxes[t_prime][i][0], visited_idxes[t_prime][i][1], 0, :, :3] = 255
                pano_full_view[i, visited_idxes[t_prime][i][0], visited_idxes[t_prime][i][1], 0, -3:, :] = 255
                pano_full_view[i, visited_idxes[t_prime][i][0], visited_idxes[t_prime][i][1], 0, :, -3:] = 255

                pano_full_view[i, visited_idxes[t_prime][i][0], visited_idxes[t_prime][i][1], 1:, :3, :] = 0
                pano_full_view[i, visited_idxes[t_prime][i][0], visited_idxes[t_prime][i][1], 1:, :, :3] = 0
                pano_full_view[i, visited_idxes[t_prime][i][0], visited_idxes[t_prime][i][1], 1:, -3:, :] = 0
                pano_full_view[i, visited_idxes[t_prime][i][0], visited_idxes[t_prime][i][1], 1:, :, -3:] = 0

        # Convert BxNxMx3x32x32 -> Bx1x3x(32*N)x(32*M)
        pano_full_view = pano_full_view.transpose((0, 3, 1, 4, 2, 5)).reshape(\
                                    batch_size, 1, 3, opts.N*32, opts.M*32)
        pano_all[iters].append(pano_full_view)

        policy_input = {'im': im, 'pro': pro}
        # Assuming only actOnElev
        policy_input['elev'] = torch.Tensor([[state_object.idx[i][0]] for i in range(batch_size)])
        policy_input['time'] = torch.Tensor([[t] for i in range(batch_size)])
        if opts.iscuda:
            for var in policy_input:
                policy_input[var] = policy_input[var].cuda()

        for var in policy_input:
            policy_input[var] = Variable(policy_input[var])

        probs, decoded, hidden_new, value = agent.policy.forward(policy_input, hidden)

        # Create the decoded image with saliency modified view filled in
        decoded_images = decoded.data.cpu().numpy()*255.0
        # Add the means
        for chn in range(len(opts.mean)):
            decoded_images[:, :, :, chn] += opts.mean[chn]

        decoded_images_shifted = np.zeros(decoded_images.shape)
        decoded_images_shifted_wo_sal = np.zeros(decoded_images.shape)

        # Roll each decoded viewgrid along axis 1 by the starting azimuth index
        # of that batch element.
        for i in range(batch_size):
            decoded_images_shifted[i] = np.copy(np.roll(decoded_images[i], state_object.start_idx[i][1], axis=1))
            decoded_images_shifted_wo_sal[i] = np.copy(np.roll(decoded_images[i], state_object.start_idx[i][1], axis=1))
        if opts.num_channels == 1:
            # Replicate the single channel three times, as for the ground truth
            decoded_images_shifted_3chn = np.zeros((batch_size, opts.N, opts.M, 3, 32, 32))
            decoded_images_shifted_wo_sal_3chn = np.zeros((batch_size, opts.N, opts.M, 3, 32, 32))
            for c in range(3):
                decoded_images_shifted_3chn[:, :, :, c, :, :] = np.copy(decoded_images_shifted[:, :, :, 0, :, :])
                decoded_images_shifted_wo_sal_3chn[:, :, :, c, :, :] = np.copy(decoded_images_shifted_wo_sal[:, :, :, 0, :, :])
            decoded_images_shifted = decoded_images_shifted_3chn
            decoded_images_shifted_wo_sal = decoded_images_shifted_wo_sal_3chn

        # Paste the actually observed views over the decoded ones
        for t_view, visited_view in enumerate(visited_idxes):
            for i in range(batch_size):
                curr_view = np.copy(state_object.views[i, visited_view[i][0], visited_view[i][1]])
                curr_view = np.transpose(curr_view, (1, 2, 0))
                if opts.num_channels == 1:
                    curr_view_3chn = np.zeros((32, 32, 3))
                    for c in range(3):
                        curr_view_3chn[:, :, c] = np.copy(curr_view[: ,:, 0])
                    curr_view = curr_view_3chn

                # Modify the current view to have some borders (for easy visualization)
                if t_view == t:
                    curr_view[:3, :, 0] = 255
                    curr_view[-3:, :, 0] = 255
                    curr_view[:, :3, 0] = 255
                    curr_view[:, -3:, 0] = 255
                    curr_view[:3, :, 1:] = 0
                    curr_view[-3:, :, 1:] = 0
                    curr_view[:, :3, 1:] = 0
                    curr_view[:, -3:, 1:] = 0

                decoded_images_shifted[i, visited_view[i][0], visited_view[i][1]] = \
                    np.transpose(curr_view, (2, 0, 1))

                decoded_images_shifted_wo_sal[i, visited_view[i][0], visited_view[i][1]] = \
                    np.transpose(curr_view, (2, 0, 1))

        # Convert BxNxMxCx32x32 -> Bx1xCx(32*N)x(32*M)
        decoded_image_full_view = decoded_images_shifted.transpose((0, 3, 1, 4, 2, 5)).reshape(\
                                            batch_size, 1, 3, opts.N*32, opts.M*32)
        decoded_image_full_view_wo_sal = decoded_images_shifted_wo_sal.transpose((0, 3, 1, 4, 2, 5)).reshape(\
                                            batch_size, 1, 3, opts.N*32, opts.M*32)

        decoded_all[iters].append(decoded_image_full_view)
        decoded_all_raw[iters].append(decoded_image_full_view_wo_sal)

        # ============ Get the belief state saliency ============
        saliency_belief = []
        for i in range(batch_size):
            policy_input_curr = {}
            # Get only current batch element
            # (.items() rather than .iteritems() so this also runs on Python 3)
            for k, v in policy_input.items():
                policy_input_curr[k] = v[i:i+1]
            saliency_belief_curr, _ = get_belief_saliency(hidden_new[0][:, i:i+1], policy_input_curr, agent, opts.iscuda)
            # roll saliency_belief_curr to reflect the unknown azimuth
            saliency_belief_curr = np.roll(saliency_belief_curr, state_object.start_idx[i][1], axis=2)
            # Convert 1xNxMx32x32 -> 1x(32*N)x(32*M)
            saliency_belief_curr = saliency_belief_curr.transpose((0, 1, 3, 2, 4)).reshape(1, opts.N*32, opts.M*32)
            saliency_belief.append(saliency_belief_curr)

        saliency_belief = np.concatenate(saliency_belief, axis=0)
        saliency_belief_all[iters].append(saliency_belief)

        # Act greedily
        _, act = probs.max(dim=1)
        act = act.data.view(-1, 1)
        # Rotate the view
        _ = state_object.rotate(act[:, 0])

        # Set hidden
        hidden = hidden_new

    # Normalize the saliency values
    # Each (batch element, time step) map is divided by its own maximum.
    for i in range(batch_size):
        saliency_max = 0
        for t in range(opts.T):
            saliency_max = saliency_belief_all[iters][t][i].max()
            saliency_belief_all[iters][t][i] /= (saliency_max + 1e-9)

    # Impose saliency on decoded views
    for t in range(opts.T):

        decoded_saliency_full_view = np.zeros((batch_size, 1, 3, opts.N*32, opts.M*32))
        for i in range(batch_size):
            curr_saliency = saliency_belief_all[iters][t][i]
            curr_view = decoded_all_raw[iters][t][i, 0].transpose(1, 2, 0)
            # The overlay is added to channels 0 and 2 after a sigma=2 blur
            decoded_saliency_full_view[i, 0] = saliency_on_image(curr_saliency, curr_view, 1, channels=[0, 2], sigma=2).transpose(2, 0, 1)

        decoded_saliency_all[iters].append(decoded_saliency_full_view)

In [None]:
# Log the decoded and the ground-truth viewgrids to tensorboard: for every
# panorama, one image strip with a frame per time step.
def _rescale_unit(frame):
    """Min-max scale `frame` to the [0, 1] range."""
    lowest = frame.min()
    highest = frame.max()
    return (frame - lowest)/(highest - lowest)

save_path = os.path.join(repo_path, 'visualizations/sun360/')
writer = SummaryWriter(log_dir=save_path)
images_count = 0
for iters in range(num_iters):
    batch_size = decoded_all[iters][0].shape[0]
    for i in range(batch_size):
        images_count += 1

        # Decoded viewgrid: frames are rescaled in channels-last layout and
        # moved back to channels-first for make_grid.
        outputs_all = [
            _rescale_unit(np.transpose(decoded_all[iters][t][i], (0, 2, 3, 1)))
            for t in range(opts.T)
        ]
        outputs_all = np.concatenate(outputs_all, axis=0).transpose((0, 3, 1, 2))
        x = vutils.make_grid(torch.Tensor(outputs_all), padding=5, normalize=True, scale_each=True, nrow=opts.T+1, pad_value=1.0)
        writer.add_image('#%d Decoded Viewgrid'%(images_count), x, 0)

        # Ground-truth viewgrid: same layout, but make_grid does no further normalization.
        outputs_all = [
            _rescale_unit(np.transpose(pano_all[iters][t][i], (0, 2, 3, 1)).astype('float32'))
            for t in range(opts.T)
        ]
        outputs_all = np.concatenate(outputs_all, axis=0).transpose((0, 3, 1, 2))
        x = vutils.make_grid(torch.Tensor(outputs_all), padding=5, normalize=False, scale_each=False, nrow=opts.T+1, pad_value=1.0)
        writer.add_image('#%d GT Viewgrid'%(images_count), x, 0)

In [None]:
# Display the saliency overlays of the first batch inline, one figure per
# (panorama, time step) pair.
for iters in range(1):
    batch_size = decoded_saliency_all[iters][0].shape[0]
    for i in range(batch_size):
        for t in range(opts.T):
            # Channels-last float image, min-max scaled to [0, 1] for imshow
            display_image = np.transpose(decoded_saliency_all[iters][t][i][0], (1, 2, 0)).astype('float32')
            lowest = display_image.min()
            highest = display_image.max()
            display_image = (display_image - lowest)/(highest - lowest)

            fig = plt.figure(figsize=(12, 16))
            if display_image.shape[2] != 1:
                plt.imshow(display_image)
            else:
                # Single-channel images are drawn with a gray colormap
                plt.imshow(display_image[:, :, 0], cmap='gray')
            plt.show()

In [None]:
# Log the saliency overlays to tensorboard: for every panorama, one image strip
# with a frame per time step.
images_count = 0
for iters in range(num_iters):
    batch_size = decoded_saliency_all[iters][0].shape[0]
    for i in range(batch_size):
        images_count += 1

        outputs_all = []
        for t in range(opts.T):
            display_image = decoded_saliency_all[iters][t][i].astype('float32')
            # Min-max scale every frame to [0, 1] on its own
            lowest = display_image.min()
            highest = display_image.max()
            display_image = (display_image - lowest)/(highest - lowest)
            outputs_all.append(display_image)
        outputs_all = np.concatenate(outputs_all, axis=0)

        x = vutils.make_grid(torch.Tensor(outputs_all), padding=5, normalize=True, scale_each=True, nrow=opts.T+1, pad_value=1.0)
        writer.add_image('#%d Belief Saliency'%(images_count), x, 0)

# All tensorboard output has been written; release the writer.
writer.close()