In [None]:
%matplotlib inline

import sys, os
from PIL import Image
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from torch.autograd import Variable
import models.unet_normals as unet
import numpy as np
from data_loader import Dataset,Options
from skimage.transform import resize
from skimage import img_as_uint


###################### Loss function - Cosine #############################
def loss_fn_cosine(input_vec, target_vec, reduction='elementwise_mean'):
    """Compute the cosine loss ``1 - cos(theta)`` between two vector fields.

    Cosine similarity is taken along dim 1 (the vector-component axis), so the
    inputs are expected to be shaped (batchSize, 3, imsize, imsize).

    Args:
        input_vec: Predicted vectors.
        target_vec: Target vectors, broadcast-compatible with ``input_vec``.
        reduction: How to reduce the per-pixel loss:
            ``'elementwise_mean'`` (or its current PyTorch name ``'mean'``):
                sum of all per-pixel losses divided by their count (scalar).
            ``'sum'``: sum of all per-pixel losses (scalar).
            ``'none'``: per-pixel loss of shape (batchSize, imsize, imsize).

    Returns:
        torch.Tensor: The reduced or per-pixel cosine loss.

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """
    # eps keeps the similarity finite for zero-length vectors.
    loss_val = 1.0 - nn.functional.cosine_similarity(input_vec, target_vec, dim=1, eps=1e-6)
    if reduction in ('elementwise_mean', 'mean'):
        return torch.mean(loss_val)
    if reduction == 'sum':
        return torch.sum(loss_val)
    if reduction == 'none':
        return loss_val
    raise ValueError(
        "Warning! The reduction %r is invalid. Please use 'elementwise_mean' "
        "(or 'mean'), 'sum' or 'none'" % (reduction,))

###################### Loss function - Avg Angle Calc #############################
def loss_fn_radians(input_vec, target_vec, reduction='elementwise_mean'):
    """Compute the angle, in radians, between two vector fields.

    Cosine similarity is taken along dim 1 (the vector-component axis), so the
    inputs are expected to be shaped (batchSize, 3, imsize, imsize).

    Args:
        input_vec: Predicted vectors.
        target_vec: Target vectors, broadcast-compatible with ``input_vec``.
        reduction: How to reduce:
            ``'elementwise_mean'`` (or ``'mean'``): ``acos`` of the MEAN cosine
                similarity over all pixels (scalar). Note this is the angle of
                the mean cosine, not the mean of the per-pixel angles.
            ``'none'``: per-pixel angle of shape (batchSize, imsize, imsize).

    Returns:
        torch.Tensor: Angle(s) in radians, in [0, pi].

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """
    cos_sim = nn.functional.cosine_similarity(input_vec, target_vec, dim=1, eps=1e-6)
    # Floating-point rounding can push |cos| marginally above 1, where acos
    # returns NaN; clamp back into acos's domain.
    cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
    if reduction in ('elementwise_mean', 'mean'):
        return torch.acos(torch.clamp(torch.mean(cos_sim), -1.0, 1.0))
    if reduction == 'none':
        return torch.acos(cos_sim)
    raise ValueError(
        "Warning! The reduction %r is invalid. Please use 'elementwise_mean' "
        "(or 'mean') or 'none'" % (reduction,))

###################### Options #############################
class OPT:
    """Hard-coded evaluation options; an instance is passed to ``Dataset``."""

    def __init__(self):
        # Where the test data lives.
        self.dataroot = './data/'
        self.file_list = './data/datalist_test'

        # Batching / iteration.
        self.batchSize = 1
        self.shuffle = False
        self.phase = 'eval'
        self.num_epochs = 1000

        # Image and model geometry.
        self.imsize = 224
        self.num_classes = 3

        # Device index (string, appended to "cuda:") and log directory.
        self.gpu = '0'
        self.logs_path = 'logs/exp10'

opt = OPT()
dataloader = Dataset(opt)

# Alternative: evaluate a checkpoint from this experiment's own log directory.
# checkpoint_path = opt.logs_path + '/checkpoints/checkpoint-epoch_500.pth'
checkpoint_path = 'logs_dl_playground/exp19/checkpoints/checkpoint-epoch_1000.pth'
show_plots = True    # display RGB / predicted / ground-truth figures per image
save_images = False  # write per-image PNGs under data/results/test-results/

device = torch.device("cuda:" + opt.gpu if torch.cuda.is_available() else "cpu")
# Select Loss Func.
# NOTE: the per-image degree conversion arccos(1 - loss) is only valid for the
# cosine loss, which is why the evaluation loop calls loss_fn_cosine directly.
loss_fn = loss_fn_cosine


##### Load Model #####
model = unet.Unet(num_classes=opt.num_classes)
# map_location lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine (the `device` fallback above), or onto the selected GPU.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)
model.eval()

##### Evaluate on the test set #####
def _cosine_loss_to_deg(cosine_loss):
    """Convert a cosine loss value ``1 - cos(theta)`` into ``theta`` in degrees.

    The cosine is clipped to [-1, 1] first so floating-point rounding cannot
    push it outside arccos's domain and produce NaN.
    """
    cos_theta = np.clip(1.0 - cosine_loss, -1.0, 1.0)
    return float(np.degrees(np.arccos(cos_theta)))


results_dir = 'data/results/test-results'
if save_images:
    # plt.imsave does not create missing directories.
    os.makedirs(results_dir, exist_ok=True)

# Undo the input normalization for display. Loop-invariant, so built once.
# NOTE(review): the usual ImageNet blue-channel std is 0.225; 0.255 here may be
# a typo - confirm against the Normalize used in data_loader before changing it.
inv_normalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.255],
                                     std=[1/0.229, 1/0.224, 1/0.255])

num_batches = int(dataloader.size() / opt.batchSize)  # BatchSize = 1
running_loss = 0.0
for i in range(num_batches):
    inputs, labels = dataloader.get_batch()  # Get tensor of image and camera normal

    # Host-side copies with the batch dim dropped, kept for visualization.
    inputs_orig = inputs.clone().squeeze(0)
    labels_orig = labels.clone().squeeze(0).numpy()

    inputs = inputs.to(device)
    labels = labels.to(device)

    # Inference only: don't build the autograd graph.
    with torch.no_grad():
        normal_vectors = model(inputs)
    # Unit-length normals so the loss compares directions only.
    normal_vectors_norm = nn.functional.normalize(normal_vectors, p=2, dim=1)
    # Cosine loss specifically: the degree conversion below assumes loss = 1 - cos.
    loss = loss_fn_cosine(normal_vectors_norm, labels, reduction='elementwise_mean')
    running_loss += loss.item()
    loss_deg = _cosine_loss_to_deg(loss.item())

    print('Loss for img%03d is %0.4f = %.2f deg' % (i, loss.item(), loss_deg))

    ### Build the images (needed both for plotting and for saving) ###
    if show_plots or save_images:
        # Input RGB image
        rgb_img = inv_normalize(inputs_orig)
        rgb_img = torch.clamp(rgb_img, min=0.0, max=1.0)  # inv_norm isn't perfect. Some values out of range.
        rgb_img = np.array(transforms.ToPILImage(mode='RGB')(rgb_img))

        # Predicted normals
        output_norm = normal_vectors_norm.cpu().squeeze(0).detach().numpy()
        camera_normal_rgb = dataloader.normals_to_rgb_with_negatives(output_norm)
        camera_normal_rgb = np.transpose(camera_normal_rgb, (1, 2, 0))

        # Ground-truth normals
        # NOTE(review): unlike the prediction above, these are transposed to
        # channels-last *before* normals_to_rgb_with_negatives; that is only
        # equivalent if the helper is layout-agnostic - confirm.
        truth_normal = labels_orig.transpose(1, 2, 0)
        truth_normal = dataloader.normals_to_rgb_with_negatives(truth_normal)

    ### Create Plots ###
    if show_plots:
        plt.figure(figsize=(12, 12))
        ax0 = plt.subplot(131)
        ax1 = plt.subplot(132)
        ax2 = plt.subplot(133)
        ax0.imshow(rgb_img)
        ax0.set_title('Source RGB Image')
        ax1.imshow(camera_normal_rgb)
        ax1.set_title('Predicted Normals')
        ax2.imshow(truth_normal)
        ax2.set_title('Ground Truth Normals')
        plt.show()
        plt.close('all')

    ### Save Images ###
    if save_images:
        plt.imsave('%s/%09d-rgb.png' % (results_dir, i), rgb_img)
        plt.imsave('%s/%09d-normals.png' % (results_dir, i), camera_normal_rgb)
        plt.imsave('%s/%09d-normals-groundtruth.png' % (results_dir, i), truth_normal)

# Average over the iterations actually run; report the angle that corresponds
# to the average loss (not the last image's angle).
if num_batches > 0:
    avg_loss = running_loss / num_batches
    print('Avg Loss of Test Set is: %0.4f = %03.2f deg'
          % (avg_loss, _cosine_loss_to_deg(avg_loss)))
else:
    avg_loss = float('nan')
    print('Test set is empty: no average loss computed.')
    