In [28]:
# Evaluate the pretrained pose model on the validation dataset

In [40]:
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from munch import Munch
from tqdm import tqdm_notebook as tqdm

import datasets
import models
import utils

In [41]:
# Pretrained run to evaluate. Other runs that can be swapped in here:
#   'pretrained/floating_kinect1_object/config.yml'
#   'pretrained/floating_kinect1_mask_with_occlusion/config.yml'
config_path = 'pretrained/kinect1_mask/config.yml'

In [42]:
# Parse the run's YAML config into an attribute-accessible Munch (cfg.arch, cfg.data, ...).
with open(config_path) as config_file:
    cfg = Munch.fromYAML(config_file)

In [43]:
# Enable cuDNN benchmark mode (autotunes kernels; helps most when input sizes don't vary).
cudnn.benchmark = True
# Build the network and wrap it in DataParallel on the GPU before restoring the
# weights, so the checkpoint's state_dict keys are matched against the wrapped model.
model = torch.nn.DataParallel(models.Model(cfg.arch)).cuda()
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(cfg.training.resume)
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # inference behaviour for layers such as dropout / batch-norm
print("=> loaded checkpoint '{}' (epoch {})".format(
    cfg.training.resume, checkpoint['epoch']))

=> loaded checkpoint 'pretrained/kinect1_mask/checkpoint_00002100.pth.tar' (epoch 2100)


In [44]:
# Validation data. ToTensor is passed to the dataset, presumably applied to each
# input image (TODO confirm against datasets.RenderedPoseDataset).
transform = transforms.ToTensor()
val_dataset = datasets.RenderedPoseDataset(
    cfg.data.root, cfg.data.objects, cfg.data.val_subset_num, transform)
# Evaluation needs no shuffling: a fixed order makes runs reproducible and lets
# individual samples be revisited by index, while the reported averages are
# unaffected (up to floating-point rounding).
# pin_memory=True allows the non_blocking .cuda() copies done in forward_batch.
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

In [45]:
def forward_batch(model, input, target, object_index):
    """Run ``model`` on one batch and return its batch-mean pose errors.

    Args:
        model: network called as ``model(input, object_index)`` that returns a
            ``(position, orientation)`` pair of predictions.
        input: batch of model inputs. It is not moved here; device placement is
            left to ``model`` (e.g. ``DataParallel`` scatters it to the GPUs).
        target: ground-truth poses. Columns ``[:3]`` are compared with the
            predicted position and columns ``[3:]`` with the predicted orientation.
        object_index: per-sample object ids forwarded to ``model``.

    Returns:
        ``(position_error, orientation_error)`` as 0-d tensors averaged over the
        batch: the Euclidean distance between true and predicted positions, and
        the rotation angle between true and predicted orientations scaled by
        180/pi (i.e. degrees, assuming ``utils.batch_rotation_angle`` returns
        radians -- TODO confirm).
    """
    gt_pose = target.cuda(non_blocking=True)
    object_ids = object_index.cuda(non_blocking=True)

    pred_position, pred_orientation = model(input, object_ids)

    # Per-sample L2 distance over the three position components.
    position_delta = gt_pose[:, :3] - pred_position
    per_sample_position_error = (position_delta ** 2).sum(dim=1).sqrt()

    rad_to_deg = 180.0 / np.pi
    per_sample_orientation_error = rad_to_deg * utils.batch_rotation_angle(
        gt_pose[:, 3:], pred_orientation)

    return per_sample_position_error.mean(), per_sample_orientation_error.mean()

In [46]:
# Average the per-batch mean errors over the whole validation set, weighting each
# batch by its size. Gradients are not needed for evaluation.
position_errors = utils.AverageMeter()
orientation_errors = utils.AverageMeter()
with torch.no_grad():
    for input, target, object_index in tqdm(val_loader):
        position_error, orientation_error = forward_batch(model, input, target, object_index)
        # forward_batch returns 0-d GPU tensors; .item() hands the meters plain
        # Python floats so they don't accumulate device tensors.
        position_errors.update(position_error.item(), input.size(0))
        orientation_errors.update(orientation_error.item(), input.size(0))
# The position error is scaled by 100 for display. With targets in meters
# (assumed -- TODO confirm dataset units), the printed value is in centimeters,
# not meters as the label previously claimed.
print('overall position error: {:.2f} centimeters'.format(100 * position_errors.avg))
print('overall orientation error: {:.2f} degrees'.format(orientation_errors.avg))

HBox(children=(IntProgress(value=0, max=2560), HTML(value='')))


overall position error: 1.10 meters
overall orientation error: 58.35 degrees
