In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from torchvision import transforms

import argparse
import numpy as np
import pandas as pd
import random
import scipy.io
import sys
import tqdm
import yaml

from addict import Dict
from itertools import zip_longest
from PIL import Image, ImageFilter
from tensorboardX import SummaryWriter

from models.SegNet import SegNetBasic
from dataset import PartAffordanceDataset, CenterCrop, ToTensor, Normalize


In [2]:
def reverse_normalize(x, mean=(0.2191, 0.2349, 0.3598), std=(0.1243, 0.1171, 0.0748)):
    """Undo a per-channel normalization: ``out[:, c] = x[:, c] * std[c] + mean[c]``.

    The input tensor is left untouched; a new tensor is returned. (The
    previous implementation wrote the result back into ``x``, which silently
    altered the caller's batch.)

    Args:
        x: tensor whose dimension 1 is the channel dimension.
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.

    Returns:
        A tensor with the same shape and dtype as ``x``. Channels beyond
        ``len(mean)`` are copied through unchanged.

    Raises:
        IndexError: if ``x`` has fewer channels than there are mean/std pairs.
    """
    # Work on a copy so the caller's tensor is never modified.
    restored = x.clone()
    for channel, (channel_mean, channel_std) in enumerate(zip(mean, std)):
        restored[:, channel] = x[:, channel] * channel_std + channel_mean
    return restored

In [3]:
# RGB palette: row i of `colors` is the display color of class i
_CLASS_COLORS = (
    ('background', (0, 0, 0)),        # black
    ('grasp',      (255, 0, 0)),      # red
    ('cut',        (255, 255, 0)),    # yellow
    ('scoop',      (0, 255, 0)),      # green
    ('contain',    (0, 255, 255)),    # sky blue
    ('pound',      (0, 0, 255)),      # blue
    ('support',    (255, 0, 255)),    # purple
    ('wrap grasp', (255, 255, 255)),  # white
)
colors = torch.tensor([list(rgb) for _, rgb in _CLASS_COLORS])

In [4]:
# convert class prediction to the mask
def class_to_mask(cls, palette=None):
    """Map a batch of class-index maps to channel-first color masks.

    Args:
        cls: integer index tensor of shape (N, H, W); each entry selects a
            row of the palette.
        palette: optional (num_classes, 3) tensor of colors. Defaults to the
            module-level ``colors`` table.

    Returns:
        Tensor of shape (N, 3, H, W) with the palette's dtype, located on
        ``cls.device``.
    """
    if palette is None:
        palette = colors

    # Indexing a CPU tensor with indices that live on another device raises,
    # so bring the (tiny) palette to wherever the indices are.
    palette = palette.to(cls.device)

    # (N, H, W) -> (N, H, W, 3) -> (N, 3, H, W)
    mask = palette[cls].permute(0, 3, 1, 2)

    return mask

In [5]:
def predict(model, sample, device='cpu'):
    """Run the model on one batch for both tasks and save the results as images.

    Writes five JPEG files under ``result/``: the de-normalized input images
    plus, for each of the two tasks, the ground-truth and the predicted masks.

    Args:
        model: network called as ``model(x, task)``; it is moved to ``device``
            and switched to eval mode.
        sample: mapping with keys ``'image'`` (4-D tensor), ``'label'`` and
            ``'label1'`` (class-index maps for task 0 and task 1).
        device: device on which the forward passes are executed.

    Returns:
        None. The function is used for its side effect of writing images.
    """
    import os  # local import: only needed to create the output directory

    x, y0, y1 = sample['image'], sample['label'], sample['label1']
    batch_len, _, H, W = x.shape

    # The model has to live on the same device as its inputs.
    model.to(device)
    model.eval()

    x = x.to(device)

    # Task selectors: two planes per image, with plane `k` set to one for task k.
    task0 = torch.zeros((batch_len, 2, H, W), device=device)
    task0[:, 0] = 1
    task1 = torch.zeros((batch_len, 2, H, W), device=device)
    task1[:, 1] = 1

    # Inference only: keep the forward passes inside no_grad so that no
    # autograd graph is built.
    with torch.no_grad():
        h0 = model(x, task0)
        h1 = model(x, task1)
        _, y0_pred = h0.max(1)
        _, y1_pred = h1.max(1)

    def to_image(labels):
        # The palette lookup is done on the CPU, and the 0-255 integer colors
        # are rescaled to floats in [0, 1], the range save_image works with.
        return class_to_mask(labels.cpu()).float() / 255

    # De-normalize a private copy so the caller's batch is never modified.
    image = reverse_normalize(x.detach().cpu().clone())

    os.makedirs('result', exist_ok=True)

    save_image(image, 'result/orig_image.jpg')
    save_image(to_image(y0), 'result/true_masks_task0.jpg')
    save_image(to_image(y0_pred), 'result/pred_masks_task0.jpg')
    save_image(to_image(y1), 'result/true_masks_task1.jpg')
    save_image(to_image(y1_pred), 'result/pred_masks_task1.jpg')

In [6]:
# Build the network and restore the saved weights.
# NOTE(review): the meaning of the positional arguments (3, 2, 4) is defined
# by SegNetBasic, which is not visible in this notebook -- confirm there.
model = SegNetBasic(3, 2, 4)

# map_location='cpu' lets a checkpoint that was saved from a GPU be restored
# on a machine without CUDA.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(
    torch.load('./result/best_mean_iou_model.prm', map_location='cpu')
)

In [13]:
# Load the experiment configuration. The context manager closes the file
# handle, which the previous bare open() call left dangling.
with open('./result/config.yaml') as config_file:
    CONFIG = Dict(yaml.safe_load(config_file))


# DataLoader
test_data = PartAffordanceDataset(CONFIG.test_data,
                                config=CONFIG,
                                transform=transforms.Compose([
                                    CenterCrop(CONFIG),
                                    ToTensor(),
                                    Normalize()
                                ]))

# NOTE: shuffle=True without a fixed seed means a different batch is drawn on
# every run of the notebook.
test_loader = DataLoader(test_data, batch_size=8, shuffle=True, num_workers=0)

In [14]:
# Visualize only the first batch produced by the loader (nothing is done
# when the loader yields no batches).
sample = next(iter(test_loader), None)
if sample is not None:
    predict(model, sample)