In [None]:
%matplotlib inline

import sys, os
from PIL import Image
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from torch.autograd import Variable
import models.unet_normals as unet
import numpy as np
from data_loader import Dataset
from skimage.transform import resize
from skimage import img_as_uint

class OPT:
    """Hard-coded options for this evaluation notebook.

    An instance is handed to ``Dataset`` and its ``gpu``, ``num_classes``
    and ``logs_path`` fields are read by the evaluation code.
    """

    def __init__(self):
        # Data source and batching.
        self.dataroot, self.file_list = './data/', './data/datalist'
        self.batchSize, self.shuffle, self.phase = 1, True, 'eval'
        # Run length, input size and number of output channels of the model.
        self.num_epochs, self.imsize, self.num_classes = 1, 224, 2
        # CUDA device index (used as "cuda:" + gpu) and experiment directory.
        self.gpu, self.logs_path = '0', 'logs/exp2'

opt = OPT()
dataloader = Dataset(opt)

# Use the configured CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:" + opt.gpu)
else:
    device = torch.device("cpu")

# Weights file to evaluate (the file name refers to epoch 760 of this run).
checkpoint_path = f"{opt.logs_path}/checkpoints/checkpoint-epoch_760.pth"


# Evaluation: run the trained UNet on 100 batches from `dataloader`, then
# display and/or save the source image next to the thresholded channel-1
# prediction.

# Output switches (previously re-assigned inside the loop on every iteration).
show_plots = True
save_images = False

# Build the network and load the checkpoint ONCE. The previous version
# re-created the model and re-read the checkpoint from disk for every image.
# map_location lets a checkpoint written from a GPU run load on a CPU-only
# machine, where `device` falls back to "cpu".
fcn = unet.Unet(num_classes=opt.num_classes)
fcn.load_state_dict(torch.load(checkpoint_path, map_location=device))
fcn.to(device)
fcn.eval()

# Inverse of the input normalization, used to display the source image.
# NOTE(review): assumes Dataset normalizes with the ImageNet statistics
# mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225). The previous code
# used 0.255 for the blue std, which looks like a typo of 0.225 -- confirm
# against data_loader.
inv_normalize = transforms.Normalize(
    mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
)

# Create the output directories up front so savefig/imsave cannot fail on a
# missing folder.
if show_plots:
    os.makedirs('data/results', exist_ok=True)
if save_images:
    os.makedirs('data/results/test-results', exist_ok=True)

# Pure inference: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    for i in range(100):
        # Get data. Labels are not used in this evaluation.
        inputs, labels = dataloader.get_batch()
        # Drop the leading batch dim (opt.batchSize == 1); keep this copy for
        # display before `inputs` is moved to the device.
        rgb_tensor = inputs.squeeze(0)
        inputs = inputs.to(device)

        # Inference: L2-normalize the logits across the channel dimension.
        logits = fcn(inputs)
        logits_norm = nn.functional.normalize(logits, p=2, dim=1)
        output = logits_norm.squeeze(0).cpu().numpy()

        # Source image as HxWxC. It is computed unconditionally because both the
        # plot and the save branch need it. The de-normalized values can fall
        # slightly outside [0, 1]; imshow warns about that and plt.imsave
        # rejects float RGB data outside that range, so clip.
        rgb_img = inv_normalize(rgb_tensor).numpy()
        rgb_img = np.clip(np.transpose(rgb_img, (1, 2, 0)), 0.0, 1.0)

        # Predicted edges: 255 where the normalized channel-1 score is positive.
        output_edges = output[1, :, :]
        seg_viz = np.zeros(output_edges.shape, dtype=np.uint8)
        seg_viz[output_edges > 0] = 255

        ### Create Plots ###
        if show_plots:
            fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 12))
            ax0.imshow(rgb_img)
            ax0.set_title('Source RGB Image')
            ax1.imshow(seg_viz)
            ax1.set_title('Predicted Edges')
            # Save before show(), as matplotlib recommends; show() may close
            # the figure.
            fig.savefig('data/results/%09d-results.png' % i, dpi=fig.dpi)
            plt.show()
            plt.close('all')

        ### Save Images ###
        if save_images:
            plt.imsave('data/results/test-results/%09d-rgb.png' % i, rgb_img)
            # This used to reference an undefined `camera_normal_rgb`
            # (NameError). Save the prediction mask that is actually computed.
            plt.imsave('data/results/test-results/%09d-normals.png' % i, seg_viz)
        
    