In [None]:
import sys
sys.path.append('./models/')
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from data_loader import Dataset
import models.unet_normals as unet
from PIL import Image 
import imageio
from torchvision import transforms, utils
import cv2
%matplotlib inline 


class OPT:
    '''Plain container for the settings used by this notebook.

    An instance is handed to ``Dataset`` and read by the setup code below.
    Fields that the visible code does not read directly are presumably
    consumed by ``Dataset`` -- verify against data_loader.py.
    '''

    def __init__(self):
        # Data locations (presumably read by the data loader).
        self.dataroot = './data/'
        self.file_list = './data/datalist'
        self.weights_path = './data/results/weights'

        # The evaluation loop divides the dataset size by this to get the
        # number of iterations.
        self.batchSize = 1
        self.shuffle = False
        self.phase = 'eval'
        self.num_epochs = 500
        # NOTE(review): looks like (height, width) -- confirm in data_loader.
        self.imsize = (288, 512)
        # Number of output classes passed to the U-Net constructor.
        self.num_classes = 3
        # CUDA device index, kept as a string because it is appended to "cuda:".
        self.gpu = '0'
        self.logs_path = 'logs/exp2'
        # When truthy, a saved checkpoint is loaded into the model.
        self.use_pretrained = True


opt = OPT()


###################### Options #############################
# Output locations: side-by-side result images, 16-bit occlusion-weight PNGs,
# and colour-mapped previews of those weights.
DIR_RESULTS = 'data/results/exp0/'
DIR_WEIGHTS = 'data/results/exp0/weights_png/'
DIR_WEIGHTS_VIZ = 'data/results/exp0/weights_png/rgb-visualizations/'

# Create the output folders up front so that the first image write in the
# evaluation loop cannot fail on a missing directory.
for out_dir in (DIR_RESULTS, DIR_WEIGHTS, DIR_WEIGHTS_VIZ):
    os.makedirs(out_dir, exist_ok=True)

phase = opt.phase
# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:" + opt.gpu if torch.cuda.is_available() else "cpu")


###################### DataLoader #############################
dataloader = Dataset(opt)


###################### ModelBuilder #############################
model = unet.Unet(num_classes=opt.num_classes)
criterion = nn.CrossEntropyLoss(reduction='sum').to(device)

# Load weights from checkpoint
if opt.use_pretrained:
    checkpoint_path = 'logs/exp0/checkpoints/checkpoint-epoch_215.pth'
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint written on a GPU machine still loads when running on CPU.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Optimizer and LR scheduler are built over the model parameters; they are kept
# as module-level names for any code that refers to them.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)


def label_to_rgb(label):
    '''Colour-code a class-index map for visualization.

    Class 0 is drawn in red, class 1 in green and class 2 in blue. Pixels
    holding any other value stay black.

    Args:
        label (numpy.ndarray): Shape (height, width). Each pixel holds the index of the class it belongs to.

    Returns:
        numpy.ndarray: Shape (height, width, 3), dtype uint8: RGB rendering of the labels
    '''
    height, width = label.shape[0], label.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)

    # Channel k is switched fully on wherever the pixel belongs to class k.
    for class_idx in range(3):
        rgb[label == class_idx, class_idx] = 255

    return rgb
    

# Run evaluation of the occlusion boundary model

In [None]:

model.to(device)
model.eval()

# Inverse of a Normalize(mean=0.5, std=0.5) transform: computes x * 0.5 + 0.5.
# Presumably the data loader normalizes its images that way -- verify.
# Built once because it does not change between iterations.
inv_normalize = transforms.Normalize(
        mean=[-0.5/0.5, -0.5/0.5, -0.5/0.5],
        std=[1/0.5, 1/0.5, 1/0.5]
    )

num_batches = int(dataloader.size()/opt.batchSize)

# NOTE: only the first image of every batch is processed and saved below.
for i in range(num_batches):
    # Get data
    inputs, labels = dataloader.get_batch()
    inputs = inputs.to(device)
    labels = labels.to(device)

    # Forward pass. This is inference only, so autograd is switched off: no
    # graph is built and no optimizer state is touched.
    with torch.no_grad():
        logits = model(inputs)

    # Calculating occlusion weights from the per-class probabilities
    logits_softmax = torch.softmax(logits, dim=1).detach().cpu().numpy().astype(np.float32)
    file_arr = logits_softmax[0]  # select the first img
    # The weight is one minus the probability of class 1, cubed and scaled
    # to 0..1000, then stored as 16-bit integers.
    weight = 1 - file_arr[1, :, :]
    final_weight = (1000 * np.power(weight, 3)).astype(np.uint16)
    # Move the min and max values inwards by a small amount epsilon so that they
    # don't cause problems in the depth2depth optimization code.
    eps = 1
    final_weight[final_weight == 0] += eps
    final_weight[final_weight == 1000] -= eps

    # Model predictions absolute - each pixel classified into a class
    predictions = torch.max(logits, 1)[1].detach().cpu().numpy()
    predictions = predictions[0]  # select the first img
    predictions_color = label_to_rgb(predictions)

    # Original rgb image: undo the normalization, go from CHW to HWC, then to 8 bit
    rgb_img = inv_normalize(inputs.detach().cpu()[0])
    rgb_img = np.transpose(rgb_img.numpy(), (1, 2, 0))
    # Clip before the cast so values marginally outside [0, 1] cannot wrap around.
    rgb_img = np.clip(rgb_img * 255, 0, 255).astype(np.uint8)

    # Label
    labels = labels.detach().cpu().squeeze().numpy()
    labels_color = label_to_rgb(labels)

    # Colour-mapped preview of the un-cubed weight
    final_weight_color = (weight * 255).astype(np.uint8)
    final_weight_color = np.expand_dims(final_weight_color, axis=2)
    final_weight_color = cv2.applyColorMap(final_weight_color, cv2.COLORMAP_OCEAN)
    # OpenCV colour maps come out as BGR; everything else here is RGB.
    final_weight_color = cv2.cvtColor(final_weight_color, cv2.COLOR_BGR2RGB)

    # Save output results: input | prediction | label | weight, side by side
    result = np.concatenate([rgb_img, predictions_color, labels_color, final_weight_color], axis=1)
    imageio.imwrite(os.path.join(DIR_RESULTS, '%04d-results.png' % (i)), result)

    # Display Results
    print('Image %09d' % (i))
    plt.figure()
    plt.imshow(result)
    plt.show()

    # Save weights file as a 16-bit PNG. PIL takes (width, height), hence the
    # transposed shape; the raw uint16 buffer is decoded with the 'I;16' mode.
    array_buffer = final_weight.tobytes()
    img = Image.new("I", final_weight.T.shape)
    img.frombytes(array_buffer, 'raw', 'I;16')
    img.save(os.path.join(DIR_WEIGHTS, '%09d-occlusion-weight.png' % (i)))

    # Save weights rgb representation
    imageio.imwrite(os.path.join(DIR_WEIGHTS_VIZ, '%09d-occlusion-weight-rgb.png' % (i)), final_weight_color)
    
    

    