In [1]:
import torch
# Sanity check: the cell output shows whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

In [2]:
import os
import argparse

import torch
from torch.utils.data import DataLoader

In [3]:
from Proposed.models import Global_Scene_CAM_NFDecoder
from Proposed.utils import ModelTest
from dataset.nuscenes import NuscenesDataset, nuscenes_collate

In [18]:
# Model hyper-parameters, consumed when the model is constructed in the next cell.
scene_channels, sampling_rate = 3, 2
nfuture = int(3 * sampling_rate)  # number of predicted steps: 3 x sampling_rate

velocity_const, agent_embed_dim = 0.5, 128
num_candidates, att_dropout = 6, 0.1
crossmodal_attention = False

# Scene (map) input options and prior-loss flavour.
use_scene, scene_size = True, (64, 64)
ploss_type = 'map'

In [19]:
# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device: {}".format(device))

model = Global_Scene_CAM_NFDecoder(
    device=device,
    agent_embed_dim=agent_embed_dim,
    nfuture=nfuture,
    att_dropout=att_dropout,
    velocity_const=velocity_const,
    num_candidates=num_candidates,
    decoding_steps=nfuture,
    att=crossmodal_attention,
)

# Prior (p) loss: MSE against the ground truth for 'mseloss'; every other ploss_type
# uses the interpolated map-prior loss.
if ploss_type != 'mseloss':
    from R2P2_MA.model_utils import Interpolated_Ploss
    ploss_criterion = Interpolated_Ploss()
else:
    from R2P2_MA.model_utils import MSE_Ploss
    ploss_criterion = MSE_Ploss()

device: cuda


In [20]:
# Evaluation data settings.
test_partition, map_version = 'val', '2.0'
sample_stride, multi_agent = 1, 1
num_workers = 20  # forwarded to NuscenesDataset
test_cache = "./data/nuscenes_val_cache.pkl"
batch_size = 64

In [21]:
model = model.to(device)

dataset = NuscenesDataset(
    test_partition, map_version=map_version, sampling_rate=sampling_rate,
    sample_stride=sample_stride, use_scene=use_scene, scene_size=scene_size,
    ploss_type=ploss_type, num_workers=num_workers,
    cache_file=test_cache, multi_agent=multi_agent)

# Pass the collate function itself rather than wrapping it in a lambda: lambdas cannot be
# pickled, which breaks DataLoader worker processes under the 'spawn' start method
# (the default on Windows and macOS). Behaviour is otherwise identical.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                         pin_memory=True, collate_fn=nuscenes_collate,
                         num_workers=1)

print(f'Test Examples: {len(dataset)}')

Test Examples: 5118


In [8]:
import os
import sys
import time
import numpy as np
import datetime

import pickle as pkl

import matplotlib.pyplot as plt
import cv2
import torch
import pdb
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F

import logging

from multiprocessing import Pool

In [22]:
# NOTE(review): this cell re-declares exactly the values already set by the earlier
# configuration cells (model hyper-parameters and evaluation data settings). Keep the
# copies in sync, or drop one of them.
scene_channels, sampling_rate = 3, 2
nfuture = int(3 * sampling_rate)

velocity_const, agent_embed_dim = 0.5, 128
num_candidates, att_dropout = 6, 0.1
crossmodal_attention = False

use_scene, scene_size, ploss_type = True, (64, 64), 'map'

test_partition, map_version = 'val', '2.0'
sample_stride, multi_agent = 1, 1
num_workers = 20
test_cache = "./data/nuscenes_val_cache.pkl"
batch_size = 64

In [23]:
# Evaluation settings read by run_test().
beta = 0.1                                  # weight of the p (prior) loss
decoding_steps = int(3 * sampling_rate)
model_type = "Global_Scene_CAM_NFDecoder"
flow_based_decoder = True
out_dir = "./test"                          # metric.pkl is written here

test_ckpt = "test.pth.tar"

test_render = 1
test_times = 10                             # number of repeated evaluation passes
render = test_render

_data_dir = './data/nuscenes'


def map_file(scene_id):
    """Return the map pickle path for each scene-id tuple in `scene_id`.

    Each entry x yields _data_dir/x[0]/x[1]/x[2]/map/v1.3/x[3].pkl.
    """
    return [os.path.join(_data_dir, x[0], x[1], x[2], 'map/v1.3', x[3]) + '.pkl' for x in scene_id]


# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(test_ckpt, map_location=device)
model.load_state_dict(checkpoint['model_state'], strict=False)

<All keys matched successfully>

In [24]:
def dac(gen_trajs, map_file):
    """Drivable-Area Compliance (DAC) of the sampled trajectories for one scene.

    Args:
        gen_trajs: array of shape (num_agents, num_candidates, decoding_timesteps, 2) of
            (x, y) positions. Positions are mapped to pixels with ``(p + 56) * 2`` and
            checked against a 224 x 224 grid.
        map_file: path to the scene map. A '.png' is read with OpenCV; a '.pkl' is
            unpickled. A pixel is drivable when any of its channels is > 0.

    Returns:
        (dac_values, dac_mask): a float array with, per agent, the fraction of candidates
        that stayed compliant, and a boolean array marking agents that were NOT disregarded.

    A candidate is compliant for an agent when, at every timestep at which it lies inside
    the grid, it lands on a drivable pixel and no other agent occupies the same pixel at
    that timestep. A candidate that is outside the grid for more than 2 timesteps is
    disregarded (skipped while checking).

    NOTE(review): the disregard mask is recomputed for every candidate, so the mask that is
    returned (and used to zero out agents) comes from the LAST candidate only. Kept as-is so
    the reported metric does not change.

    Raises:
        ValueError: if `map_file` is neither a '.png' nor a '.pkl' path.
        FileNotFoundError: if a '.png' map cannot be read.
    """
    if '.png' in map_file:
        map_array = cv2.imread(map_file, cv2.IMREAD_COLOR)
        if map_array is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("Could not read map image: {}".format(map_file))
    elif '.pkl' in map_file:
        # NOTE: unpickling executes arbitrary code -- only use trusted map files.
        with open(map_file, 'rb') as pnt:
            map_array = pkl.load(pnt)
    else:
        raise ValueError("Unsupported map file (expected '.png' or '.pkl'): {}".format(map_file))

    da_mask = np.any(map_array > 0, axis=-1)

    num_agents, num_candidates, decoding_timesteps = gen_trajs.shape[:3]
    pix_trajs = ((gen_trajs + 56) * 2).astype(np.int64)

    stay_in_da_count = [0] * num_agents
    for k in range(num_candidates):
        pix_trajs_k = pix_trajs[:, k]
        stay_in_da = [True] * num_agents

        # (agents, timesteps): True where the position falls outside the 224 x 224 grid.
        oom_mask = np.any(np.logical_or(pix_trajs_k >= 224, pix_trajs_k < 0), axis=-1)
        diregard_mask = oom_mask.sum(axis=-1) > 2
        for t in range(decoding_timesteps):
            oom_mask_t = oom_mask[:, t]
            x, y = pix_trajs_k[:, t].T
            # One id per pixel; out-of-grid agents get -1 so they never match an in-grid agent.
            lin_xy = x * 224 + y
            lin_xy[oom_mask_t] = -1
            lin_list = lin_xy.tolist()  # converted once per timestep instead of once per agent
            for i in range(num_agents):
                if diregard_mask[i] or oom_mask_t[i]:
                    continue
                # The agent's own id is always in the list; a count > 1 means another agent
                # occupies the same pixel at this timestep.
                collides = lin_list.count(lin_list[i]) > 1
                if not da_mask[y[i], x[i]] or collides:
                    stay_in_da[i] = False

        for i in range(num_agents):
            if stay_in_da[i]:
                stay_in_da_count[i] += 1

    dac_values = [0.0 if diregard_mask[i] else stay_in_da_count[i] / num_candidates
                  for i in range(num_agents)]
    dac_mask = np.logical_not(diregard_mask)

    return np.array(dac_values), dac_mask


def dao(gen_trajs, map_file):
    """Drivable-Area Occupancy (DAO) of the sampled trajectories for one scene.

    Args:
        gen_trajs: array of shape (num_agents, num_candidates, decoding_timesteps, 2) of
            (x, y) positions, mapped to pixels with ``(p + 56) * 2`` on a 224 x 224 grid.
        map_file: path to the scene map ('.png' read with OpenCV, or a pickled '.pkl'
            array). A pixel is drivable when any of its channels is > 0.

    Returns:
        (dao_values, dao_mask): per agent, the number of distinct drivable pixels visited
        by any of its candidates (without colliding with another agent at that timestep)
        divided by the total number of drivable pixels; and a boolean array marking agents
        that were NOT disregarded.

    A candidate that is outside the grid for more than 2 timesteps is disregarded.
    NOTE(review): as in `dac`, the returned mask comes from the LAST candidate only; kept
    as-is so the reported metric does not change.

    Raises:
        ValueError: if `map_file` is neither a '.png' nor a '.pkl' path.
        FileNotFoundError: if a '.png' map cannot be read.
    """
    if '.png' in map_file:
        map_array = cv2.imread(map_file, cv2.IMREAD_COLOR)
        if map_array is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("Could not read map image: {}".format(map_file))
    elif '.pkl' in map_file:
        # NOTE: unpickling executes arbitrary code -- only use trusted map files.
        with open(map_file, 'rb') as pnt:
            map_array = pkl.load(pnt)
    else:
        raise ValueError("Unsupported map file (expected '.png' or '.pkl'): {}".format(map_file))

    da_mask = np.any(map_array > 0, axis=-1)
    total_drivable = da_mask.sum()

    num_agents, num_candidates, decoding_timesteps = gen_trajs.shape[:3]
    # Per agent: set of pixel ids already counted (set gives O(1) membership checks).
    occupied = [set() for _ in range(num_agents)]

    pix_trajs = ((gen_trajs + 56) * 2).astype(np.int64)

    for k in range(num_candidates):
        pix_trajs_k = pix_trajs[:, k]

        # (agents, timesteps): True where the position falls outside the 224 x 224 grid.
        oom_mask = np.any(np.logical_or(pix_trajs_k >= 224, pix_trajs_k < 0), axis=-1)
        diregard_mask = oom_mask.sum(axis=-1) > 2

        for t in range(decoding_timesteps):
            oom_mask_t = oom_mask[:, t]
            x, y = pix_trajs_k[:, t].T
            # One id per pixel; out-of-grid agents get -1 so they never match an in-grid agent.
            lin_xy = x * 224 + y
            lin_xy[oom_mask_t] = -1
            lin_list = lin_xy.tolist()  # converted once per timestep instead of once per agent
            for i in range(num_agents):
                if diregard_mask[i] or oom_mask_t[i]:
                    continue

                lin_xyi = lin_list[i]
                if lin_xyi in occupied[i]:
                    continue

                # count == 1 means only this agent is on the pixel at this timestep.
                if da_mask[y[i], x[i]] and lin_list.count(lin_xyi) == 1:
                    occupied[i].add(lin_xyi)

    dao_values = [0.0 if diregard_mask[i] else len(occupied[i]) / total_drivable
                  for i in range(num_agents)]
    dao_mask = np.logical_not(diregard_mask)

    return np.array(dao_values), dao_mask

In [14]:
def log_determinant(sigma):
    """Element-wise log-determinant of a batch of 2x2 matrices.

    `sigma` is indexed as sigma[:, :, row, col]; the caller passes A x Td x 2 x 2.
    Only the [0, 1] off-diagonal entry is used (it is squared), i.e. the matrices are
    assumed symmetric. 1e-9 is added to the determinant before taking the log.
    """
    s00 = sigma[:, :, 0, 0]
    s11 = sigma[:, :, 1, 1]
    s01 = sigma[:, :, 0, 1]
    return torch.log(s00 * s11 - s01 ** 2 + 1e-9)

In [89]:
def _positional_channels(height, width):
    """Build the two standardized positional channels appended to scene images (map v2.0).

    Returns (coordinate, distance), both (1, 1, height, width) float tensors on the CPU:
    `coordinate` is the row-major flat pixel index and `distance` is each pixel's Euclidean
    distance from the image centre, each standardized to zero mean and unit std.
    """
    coordinate_2d = np.indices((height, width))
    coordinate = np.ravel_multi_index(coordinate_2d, dims=(height, width))
    coordinate = torch.FloatTensor(coordinate).reshape((1, 1, height, width))
    coordinate_std, coordinate_mean = torch.std_mean(coordinate)
    coordinate = (coordinate - coordinate_mean) / coordinate_std

    center = np.array([(height - 1) / 2, (width - 1) / 2]).reshape((2, 1, 1))
    distance = np.sqrt(((coordinate_2d - center) ** 2).sum(axis=0))
    distance = torch.FloatTensor(distance).reshape((1, 1, height, width))
    distance_std, distance_mean = torch.std_mean(distance)
    distance = (distance - distance_mean) / distance_std

    return coordinate, distance


def _min_and_avg(per_candidate):
    """For an (agents, candidates) error tensor, return (mean of per-agent min, mean of per-agent mean) as floats."""
    min_per_agent, _ = per_candidate.min(dim=-1)
    return min_per_agent.mean().item(), per_candidate.mean(dim=-1).mean().item()


def _mean_std(values):
    """Return [mean, std] of a list of per-repetition metric values."""
    return [np.mean(values), np.std(values)]


def run_test():
    """Evaluate `model` on `data_loader`, repeating the full pass `test_times` times.

    Reads notebook globals: model, data_loader, device, test_times, map_version, scene_size,
    decoding_steps, num_candidates, flow_based_decoder, model_type, beta, ploss_type,
    ploss_criterion, map_file, out_dir, render, and the helpers log_determinant, dao, dac and
    write_img_output.

    Each repetition produces agent-weighted averages of min/avg ADE and FDE over the first
    two thirds ("2") and the whole ("3") decoding horizon, min/avg MSD, the flow q/p losses
    (flow decoders only) and per-scene DAO/DAC. Mean and std across repetitions are printed
    and pickled to `out_dir/metric.pkl`.

    Raises:
        ValueError: for a `model_type` that is not handled below.
    """
    # Models whose infer()/forward() take the scene-aware CAM argument lists below.
    scene_cam_models = ('Scene_CAM_NFDecoder', 'Global_Scene_CAM_NFDecoder',
                        'AttGlobal_Scene_CAM_NFDecoder')

    print('Starting model test.....')
    model.eval()  # Set model to evaluate mode.

    list_loss, list_qloss, list_ploss = [], [], []
    list_minade2, list_avgade2 = [], []
    list_minfde2, list_avgfde2 = [], []
    list_minade3, list_avgade3 = [], []
    list_minfde3, list_avgfde3 = [], []
    list_minmsd, list_avgmsd = [], []
    list_dao, list_dac = [], []

    # Loop-invariant work, hoisted out of the repetition loop.
    if map_version == '2.0':
        # NOTE(review): assumes scene images are `scene_size` pixels (the size given to the dataset).
        coordinate, distance = _positional_channels(*scene_size)
        coordinate = coordinate.to(device)
        distance = distance.to(device)
    # Standard-normal log-density constant for z with 2 * decoding_steps entries:
    # -(2 * decoding_steps / 2) * log(2 * pi).
    c1 = -decoding_steps * np.log(2 * np.pi)

    for test_time_ in range(test_times):
        epoch_loss = epoch_qloss = epoch_ploss = 0.0
        epoch_minade2 = epoch_avgade2 = epoch_minfde2 = epoch_avgfde2 = 0.0
        epoch_minade3 = epoch_avgade3 = epoch_minfde3 = epoch_avgfde3 = 0.0
        epoch_minmsd = epoch_avgmsd = 0.0
        epoch_agents = 0
        epoch_dao = epoch_dac = 0.0
        dao_agents = dac_agents = 0.0

        with torch.no_grad():
            for b, batch in enumerate(data_loader):
                (scene_images, log_prior,
                 agent_masks,
                 num_src_trajs, src_trajs, src_lens, src_len_idx,
                 num_tgt_trajs, tgt_trajs, tgt_lens, tgt_len_idx,
                 tgt_two_mask, tgt_three_mask,
                 decode_start_vel, decode_start_pos, scene_id) = batch

                # Detect dynamic batch size
                batch_size = scene_images.size(0)
                num_three_agents = torch.sum(tgt_three_mask)

                # Always move the scene to the device (previously only done for map v2.0).
                scene_images = scene_images.to(device)
                if map_version == '2.0':
                    scene_images = torch.cat((scene_images,
                                              coordinate.repeat(batch_size, 1, 1, 1),
                                              distance.repeat(batch_size, 1, 1, 1)), dim=1)

                src_trajs = src_trajs.to(device)
                src_lens = src_lens.to(device)

                tgt_trajs = tgt_trajs.to(device)[tgt_three_mask]
                tgt_lens = tgt_lens.to(device)[tgt_three_mask]

                num_tgt_trajs = num_tgt_trajs.to(device)
                episode_idx = torch.arange(batch_size, device=device).repeat_interleave(num_tgt_trajs)[tgt_three_mask]

                agent_masks = agent_masks.to(device)
                agent_tgt_three_mask = torch.zeros_like(agent_masks)
                agent_masks_idx = torch.arange(len(agent_masks), device=device)[agent_masks][tgt_three_mask]
                agent_tgt_three_mask[agent_masks_idx] = True

                decode_start_vel = decode_start_vel.to(device)[agent_tgt_three_mask]
                decode_start_pos = decode_start_pos.to(device)[agent_tgt_three_mask]

                log_prior = log_prior.to(device)

                if flow_based_decoder:
                    # Perturb the ground truth with small Gaussian noise (variance 1e-3)
                    # before inferring its latent code.
                    perterb = torch.normal(mean=0.0, std=np.sqrt(0.001), size=tgt_trajs.shape, device=device)
                    noisy_tgt_trajs = tgt_trajs + perterb

                    # Normalizing Flow (q loss)
                    # z: A X Td X 2, mu: A X Td X 2, sigma: A X Td X 2 X 2
                    if model_type == 'R2P2_SimpleRNN':
                        z_, mu_, sigma_, motion_encoding_ = model.infer(
                            noisy_tgt_trajs, src_trajs, decode_start_vel, decode_start_pos)
                    elif model_type == 'R2P2_RNN':
                        z_, mu_, sigma_, motion_encoding_, scene_encoding_ = model.infer(
                            noisy_tgt_trajs, src_trajs, episode_idx, decode_start_vel, decode_start_pos, scene_images)
                    elif model_type == 'CAM_NFDecoder':
                        z_, mu_, sigma_, motion_encoding_ = model.infer(
                            noisy_tgt_trajs, src_trajs, src_lens, agent_tgt_three_mask,
                            decode_start_vel, decode_start_pos, num_src_trajs)
                    elif model_type in scene_cam_models:
                        z_, mu_, sigma_, motion_encoding_, scene_encoding_ = model.infer(
                            noisy_tgt_trajs, src_trajs, src_lens, agent_tgt_three_mask, episode_idx,
                            decode_start_vel, decode_start_pos, num_src_trajs, scene_images)
                    else:
                        raise ValueError("Unsupported model_type: {}".format(model_type))

                    z_ = z_.reshape((num_three_agents, -1))  # A X (Td*2)
                    log_q0 = c1 - 0.5 * ((z_ ** 2).sum(dim=1))
                    log_qpi = log_q0 - log_determinant(sigma_).sum(dim=1)
                    batch_qloss = (-log_qpi).mean()

                    # Prior Loss (p loss): sample candidates from the cached encodings.
                    if model_type == 'R2P2_SimpleRNN':
                        gen_trajs, z, mu, sigma = model(
                            motion_encoding_, decode_start_vel, decode_start_pos, motion_encoded=True)
                    elif model_type == 'R2P2_RNN':
                        gen_trajs, z, mu, sigma = model(
                            motion_encoding_, episode_idx, decode_start_vel, decode_start_pos, scene_encoding_,
                            motion_encoded=True, scene_encoded=True)
                    elif model_type == 'CAM_NFDecoder':
                        gen_trajs, z, mu, sigma = model(
                            motion_encoding_, src_lens, agent_tgt_three_mask, decode_start_vel, decode_start_pos,
                            num_src_trajs, agent_encoded=True)
                    else:  # one of scene_cam_models (validated above)
                        gen_trajs, z, mu, sigma = model(
                            motion_encoding_, src_lens, agent_tgt_three_mask, episode_idx, decode_start_vel,
                            decode_start_pos, num_src_trajs, scene_encoding_, agent_encoded=True, scene_encoded=True)

                    if beta != 0.0:
                        if ploss_type == 'mseloss':
                            ploss = ploss_criterion(gen_trajs, tgt_trajs)
                        else:
                            ploss = ploss_criterion(episode_idx, gen_trajs, log_prior, -15.0)
                    else:
                        ploss = torch.zeros(size=(1,), device=device)
                    batch_ploss = ploss.mean()

                else:
                    if model_type != 'CAM':
                        raise ValueError("Unsupported model_type without a flow decoder: {}".format(model_type))
                    gen_trajs = model(src_trajs, src_lens, agent_tgt_three_mask, decode_start_vel, decode_start_pos, num_src_trajs)
                    gen_trajs = gen_trajs.reshape(num_three_agents, num_candidates, decoding_steps, 2)

                # Per-step squared / Euclidean error: (agents, candidates, steps).
                sq_error = ((gen_trajs - tgt_trajs.unsqueeze(1)) ** 2).sum(dim=-1)
                rs_error3 = sq_error.sqrt()
                rs_error2 = rs_error3[..., :int(decoding_steps * 2 / 3)]

                num_agents = gen_trajs.size(0)

                minade2, avgade2 = _min_and_avg(rs_error2.mean(-1))
                minfde2, avgfde2 = _min_and_avg(rs_error2[..., -1])
                minade3, avgade3 = _min_and_avg(rs_error3.mean(-1))
                minfde3, avgfde3 = _min_and_avg(rs_error3[..., -1])
                minmsd, avgmsd = _min_and_avg(sq_error.mean(-1))

                print("Working on test {:d}/{:d}, batch {:d}/{:d}... ".format(test_time_+1, test_times, b+1, len(data_loader)), end='\r')

                # Every batch metric is a mean over agents, so weight by the agent count and
                # divide by the total agent count at the end. Losses were previously added
                # twice per batch and weighted by the scene count.
                # NOTE(review): assumes ploss_criterion's output is per agent like qloss -- confirm.
                if flow_based_decoder:
                    epoch_qloss += batch_qloss.item() * num_agents
                    epoch_ploss += batch_ploss.item() * num_agents
                else:
                    epoch_loss += minade3 * num_agents

                epoch_minade2 += minade2 * num_agents
                epoch_avgade2 += avgade2 * num_agents
                epoch_minfde2 += minfde2 * num_agents
                epoch_avgfde2 += avgfde2 * num_agents
                epoch_minade3 += minade3 * num_agents
                epoch_avgade3 += avgade3 * num_agents
                epoch_minfde3 += minfde3 * num_agents
                epoch_avgfde3 += avgfde3 * num_agents
                epoch_minmsd += minmsd * num_agents
                epoch_avgmsd += avgmsd * num_agents
                epoch_agents += num_agents

                map_files = map_file(scene_id)
                output_files = [out_dir + '/' + x[2] + '_' + x[3] + '.jpg' for x in scene_id]

                cum_num_tgt_trajs = [0] + torch.cumsum(num_tgt_trajs, dim=0).tolist()
                cum_num_src_trajs = [0] + torch.cumsum(num_src_trajs, dim=0).tolist()

                src_trajs = src_trajs.cpu().numpy()
                src_lens = src_lens.cpu().numpy()
                tgt_trajs = tgt_trajs.cpu().numpy()
                tgt_lens = tgt_lens.cpu().numpy()
                gen_trajs = gen_trajs.cpu().numpy()

                tgt_three_mask = tgt_three_mask.numpy()
                agent_tgt_three_mask = agent_tgt_three_mask.cpu().numpy()

                # Re-insert zero rows for agents dropped by tgt_three_mask so that row indices
                # line up with cum_num_tgt_trajs again (np.insert indexes the pre-insert array).
                zero_ind = np.nonzero(tgt_three_mask == 0)[0]
                zero_ind -= np.arange(len(zero_ind))
                gen_trajs = np.insert(gen_trajs, zero_ind, 0, axis=0)
                tgt_trajs = np.insert(tgt_trajs, zero_ind, 0, axis=0)
                tgt_lens = np.insert(tgt_lens, zero_ind, 0, axis=0)

                for i in range(batch_size):
                    tgt_slice = slice(cum_num_tgt_trajs[i], cum_num_tgt_trajs[i + 1])
                    src_slice = slice(cum_num_src_trajs[i], cum_num_src_trajs[i + 1])
                    tgt_keep = tgt_three_mask[tgt_slice]
                    src_keep = agent_tgt_three_mask[src_slice]

                    candidate_i = gen_trajs[tgt_slice][tgt_keep]
                    tgt_traj_i = tgt_trajs[tgt_slice][tgt_keep]
                    tgt_lens_i = tgt_lens[tgt_slice][tgt_keep]
                    src_traj_i = src_trajs[src_slice][src_keep]
                    src_lens_i = src_lens[src_slice][src_keep]

                    dao_i, dao_mask_i = dao(candidate_i, map_files[i])
                    dac_i, dac_mask_i = dac(candidate_i, map_files[i])

                    epoch_dao += dao_i.sum()
                    dao_agents += dao_mask_i.sum()
                    epoch_dac += dac_i.sum()
                    dac_agents += dac_mask_i.sum()

                    # Rendering every scene is expensive; honour the notebook's `render` flag.
                    if render:
                        write_img_output(candidate_i, src_traj_i, src_lens_i, tgt_traj_i, tgt_lens_i,
                                         map_files[i], output_files[i])

        if flow_based_decoder:
            list_ploss.append(epoch_ploss / epoch_agents)
            list_qloss.append(epoch_qloss / epoch_agents)
            list_loss.append((epoch_qloss + beta * epoch_ploss) / epoch_agents)
        else:
            list_loss.append(epoch_loss / epoch_agents)

        # 2-Loss (first two thirds of the horizon)
        list_minade2.append(epoch_minade2 / epoch_agents)
        list_avgade2.append(epoch_avgade2 / epoch_agents)
        list_minfde2.append(epoch_minfde2 / epoch_agents)
        list_avgfde2.append(epoch_avgfde2 / epoch_agents)

        # 3-Loss (full horizon)
        list_minade3.append(epoch_minade3 / epoch_agents)
        list_avgade3.append(epoch_avgade3 / epoch_agents)
        list_minfde3.append(epoch_minfde3 / epoch_agents)
        list_avgfde3.append(epoch_avgfde3 / epoch_agents)

        list_minmsd.append(epoch_minmsd / epoch_agents)
        list_avgmsd.append(epoch_avgmsd / epoch_agents)

        list_dao.append(epoch_dao / dao_agents)
        list_dac.append(epoch_dac / dac_agents)

    if flow_based_decoder:
        test_ploss = _mean_std(list_ploss)
        test_qloss = _mean_std(list_qloss)
    else:
        test_ploss = [0.0, 0.0]
        test_qloss = [0.0, 0.0]
    test_loss = _mean_std(list_loss)

    test_minade2, test_avgade2 = _mean_std(list_minade2), _mean_std(list_avgade2)
    test_minfde2, test_avgfde2 = _mean_std(list_minfde2), _mean_std(list_avgfde2)
    test_minade3, test_avgade3 = _mean_std(list_minade3), _mean_std(list_avgade3)
    test_minfde3, test_avgfde3 = _mean_std(list_minfde3), _mean_std(list_avgfde3)
    test_minmsd, test_avgmsd = _mean_std(list_minmsd), _mean_std(list_avgmsd)
    test_dao, test_dac = _mean_std(list_dao), _mean_std(list_dac)

    test_ades = (test_minade2, test_avgade2, test_minade3, test_avgade3)
    test_fdes = (test_minfde2, test_avgfde2, test_minfde3, test_avgfde3)

    print("--Final Performane Report--")
    print("minADE3: {:.5f}±{:.5f}, minFDE3: {:.5f}±{:.5f}".format(test_minade3[0], test_minade3[1], test_minfde3[0], test_minfde3[1]))
    print("avgADE3: {:.5f}±{:.5f}, avgFDE3: {:.5f}±{:.5f}".format(test_avgade3[0], test_avgade3[1], test_avgfde3[0], test_avgfde3[1]))
    print("DAO: {:.5f}±{:.5f}, DAC: {:.5f}±{:.5f}".format(test_dao[0] * 10000.0, test_dao[1] * 10000.0, test_dac[0], test_dac[1]))

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metric.pkl'), 'wb') as f:
        pkl.dump({"ADEs": test_ades,
                  "FDEs": test_fdes,
                  "Qloss": test_qloss,
                  "Ploss": test_ploss,
                  "DAO": test_dao,
                  "DAC": test_dac,
                  # Previously computed but never reported:
                  "Loss": test_loss,
                  "MSDs": (test_minmsd, test_avgmsd)}, f)

In [45]:
def write_img_output_(gen_trajs, src_trajs, src_lens, tgt_trajs, tgt_lens, map_file, output_file):
    """Render one scene pixel-aligned with its map and write it to ``output_file + 'file.png'``.

    Candidates are drawn as small blue dots, histories as green crosses and ground truth as
    red dots over the map image, using world extent [-56, 56] on both axes.

    Args:
        gen_trajs: (num_tgt_agents, num_candidates, steps, 2) sampled futures.
        src_trajs, src_lens: per-agent histories and their valid lengths.
        tgt_trajs, tgt_lens: per-agent ground-truth futures and their valid lengths.
        map_file: '.png' (read with OpenCV, converted to RGB) or pickled '.pkl' map.
        output_file: path prefix; 'file.png' is appended.

    Raises:
        ValueError: if `map_file` is neither a '.png' nor a '.pkl' path.
        FileNotFoundError: if a '.png' map cannot be read.
    """
    if '.png' in map_file:
        map_array = cv2.imread(map_file, cv2.IMREAD_COLOR)
        if map_array is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("Could not read map image: {}".format(map_file))
        map_array = cv2.cvtColor(map_array, cv2.COLOR_BGR2RGB)
    elif '.pkl' in map_file:
        with open(map_file, 'rb') as pnt:
            map_array = pkl.load(pnt)
    else:
        raise ValueError("Unsupported map file (expected '.png' or '.pkl'): {}".format(map_file))

    H, W = map_array.shape[:2]
    # figsize is (width, height) in inches; at dpi=80 this gives a W x H pixel canvas.
    # (The previous (H, W) order scrambled the output for non-square maps.)
    fig = plt.figure(figsize=(float(W) / 80.0, float(H) / 80.0), facecolor='k', dpi=80)
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(map_array, extent=[-56, 56, 56, -56])
        ax.set_aspect('equal')
        ax.set_xlim([-56, 56])
        ax.set_ylim([-56, 56])
        ax.invert_yaxis()
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)

        num_tgt_agents, num_candidates = gen_trajs.shape[:2]

        for k in range(num_candidates):
            x_pts_k, y_pts_k = [], []
            for i in range(num_tgt_agents):
                tgt_len_i = tgt_lens[i]
                x_pts_k.extend(gen_trajs[i, k, :tgt_len_i, 0])
                y_pts_k.extend(gen_trajs[i, k, :tgt_len_i, 1])
            ax.scatter(x_pts_k, y_pts_k, s=0.5, marker='o', c='b')

        x_pts, y_pts = [], []
        for src_traj_i, src_len_i in zip(src_trajs, src_lens):
            x_pts.extend(src_traj_i[:src_len_i, 0])
            y_pts.extend(src_traj_i[:src_len_i, 1])
        ax.scatter(x_pts, y_pts, s=2.0, marker='x', c='g')

        x_pts, y_pts = [], []
        for i in range(num_tgt_agents):
            tgt_len_i = tgt_lens[i]
            x_pts.extend(tgt_trajs[i][:tgt_len_i, 0])
            y_pts.extend(tgt_trajs[i][:tgt_len_i, 1])
        ax.scatter(x_pts, y_pts, s=2.0, marker='o', c='r')

        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated since Matplotlib 3.8) and already
        # carries the canvas's true (rows, cols, 4) shape, so no manual reshape is needed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        buffer = cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGR)
        cv2.imwrite(output_file + 'file.png', buffer)
    finally:
        plt.close(fig)  # always release the figure, even if drawing/saving fails

In [92]:
def write_img_output1(gen_trajs, src_trajs, src_lens, tgt_trajs, tgt_lens, map_file, output_file):
    """Display all candidates, histories and ground truth for one scene on one inline plot.

    Scratch visualizer: the figure is shown with plt.show() and not saved; `output_file` is
    accepted only for signature compatibility with the other renderers.

    Raises:
        ValueError: if `map_file` is neither a '.png' nor a '.pkl' path.
        FileNotFoundError: if a '.png' map cannot be read.
    """
    if '.png' in map_file:
        map_array = cv2.imread(map_file, cv2.IMREAD_COLOR)
        if map_array is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("Could not read map image: {}".format(map_file))
        map_array = cv2.cvtColor(map_array, cv2.COLOR_BGR2RGB)
    elif '.pkl' in map_file:
        with open(map_file, 'rb') as pnt:
            map_array = pkl.load(pnt)
    else:
        raise ValueError("Unsupported map file (expected '.png' or '.pkl'): {}".format(map_file))

    fig, ax = plt.subplots(1, 1)
    ax.set_title("Inference")
    ax.imshow(map_array, extent=[-56, 56, 56, -56])
    ax.set_aspect('equal')
    ax.set_xlim([-56, 56])
    ax.set_ylim([-56, 56])
    ax.invert_yaxis()
    ax.set_axis_off()
    fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    ax.margins(0, 0)

    num_tgt_agents, num_candidates = gen_trajs.shape[:2]

    for k in range(num_candidates):
        x_pts_k, y_pts_k = [], []
        for i in range(num_tgt_agents):
            tgt_len_i = tgt_lens[i]
            x_pts_k.extend(gen_trajs[i, k, :tgt_len_i, 0])
            y_pts_k.extend(gen_trajs[i, k, :tgt_len_i, 1])
        # Each candidate gets its own colour from the default cycle, so give it a distinct
        # label instead of repeating 'generated' once per candidate in the legend.
        ax.scatter(x_pts_k, y_pts_k, label='generated {}'.format(k + 1))

    x_pts, y_pts = [], []
    for src_traj_i, src_len_i in zip(src_trajs, src_lens):
        x_pts.extend(src_traj_i[:src_len_i, 0])
        y_pts.extend(src_traj_i[:src_len_i, 1])
    ax.scatter(x_pts, y_pts, label='source')

    x_pts, y_pts = [], []
    for i in range(num_tgt_agents):
        tgt_len_i = tgt_lens[i]
        x_pts.extend(tgt_trajs[i][:tgt_len_i, 0])
        y_pts.extend(tgt_trajs[i][:tgt_len_i, 1])
    ax.scatter(x_pts, y_pts, label='target')

    ax.legend()
    plt.show()

In [90]:
def write_img_output2(gen_trajs, src_trajs, src_lens, tgt_trajs, tgt_lens, map_file, output_file):
    """Display one subplot per candidate (three per row) for a single scene, inline.

    Each panel shows that candidate's sampled futures as lines plus histories ('source') and
    ground truth ('target') as semi-transparent points. The figure is shown, not saved;
    `output_file` is accepted only for signature compatibility.

    Raises:
        ValueError: if `map_file` is neither a '.png' nor a '.pkl' path.
        FileNotFoundError: if a '.png' map cannot be read.
    """
    if '.png' in map_file:
        map_array = cv2.imread(map_file, cv2.IMREAD_COLOR)
        if map_array is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("Could not read map image: {}".format(map_file))
        map_array = cv2.cvtColor(map_array, cv2.COLOR_BGR2RGB)
    elif '.pkl' in map_file:
        with open(map_file, 'rb') as pnt:
            map_array = pkl.load(pnt)
    else:
        raise ValueError("Unsupported map file (expected '.png' or '.pkl'): {}".format(map_file))

    num_tgt_agents, num_candidates = gen_trajs.shape[:2]
    num_src_agents = len(src_trajs)

    def plot_candidate(title, ax, can_idx):
        """Draw candidate `can_idx` with histories and ground truth onto `ax`."""
        ax.set_title(title)
        ax.imshow(map_array, extent=[-56, 56, 56, -56])
        ax.set_aspect('equal')
        ax.set_xlim([-56, 56])
        ax.set_ylim([-56, 56])

        # One line per agent: concatenating all agents into a single polyline drew
        # spurious segments from the end of one agent's path to the start of the next.
        for i in range(num_tgt_agents):
            tgt_len_i = tgt_lens[i]
            ax.plot(gen_trajs[i, can_idx, :tgt_len_i, 0], gen_trajs[i, can_idx, :tgt_len_i, 1],
                    c='C0', label='generated' if i == 0 else '_nolegend_')

        x_pts, y_pts = [], []
        for i in range(num_src_agents):
            src_len_i = src_lens[i]
            x_pts.extend(src_trajs[i][:src_len_i, 0])
            y_pts.extend(src_trajs[i][:src_len_i, 1])
        ax.scatter(x_pts, y_pts, label='source', alpha=0.3)

        x_pts, y_pts = [], []
        for i in range(num_tgt_agents):
            tgt_len_i = tgt_lens[i]
            x_pts.extend(tgt_trajs[i][:tgt_len_i, 0])
            y_pts.extend(tgt_trajs[i][:tgt_len_i, 1])
        ax.scatter(x_pts, y_pts, label='target', alpha=0.3)
        ax.legend()

    # Grid sized from the candidate count (2 x 3 for six candidates) instead of a
    # hard-coded 2 x 3, which raised IndexError for fewer than six candidates.
    ncols = 3
    nrows = max(1, -(-num_candidates // ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(18, 3 * nrows), squeeze=False)

    for num, ax in enumerate(axes.flat):
        # Apply the orientation/axis settings to every panel; plt.gca() only hit the last one.
        ax.invert_yaxis()
        ax.set_axis_off()
        if num < num_candidates:
            plot_candidate('Inference {}'.format(num + 1), ax, num)

    plt.show()

In [139]:
def write_img_output(gen_trajs, src_trajs, src_lens, tgt_trajs, tgt_lens, map_file, output_file):
    """Save a grid of per-candidate plots for one scene as a JPEG.

    Each panel shows one candidate's sampled futures (green), the agents' histories (orange)
    and the ground-truth futures (red) over the map. The figure is written to
    ./test/results/scene_<img_save_count>.jpg, and the global `img_save_count` is then
    incremented. `output_file` is accepted for signature compatibility but not used.

    Raises:
        ValueError: if `map_file` is neither a '.png' nor a '.pkl' path.
        FileNotFoundError: if a '.png' map cannot be read.
    """
    global img_save_count

    if '.png' in map_file:
        map_array = cv2.imread(map_file, cv2.IMREAD_COLOR)
        if map_array is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("Could not read map image: {}".format(map_file))
        map_array = cv2.cvtColor(map_array, cv2.COLOR_BGR2RGB)
    elif '.pkl' in map_file:
        with open(map_file, 'rb') as pnt:
            map_array = pkl.load(pnt)
    else:
        raise ValueError("Unsupported map file (expected '.png' or '.pkl'): {}".format(map_file))

    num_tgt_agents, num_candidates = gen_trajs.shape[:2]
    num_src_agents = len(src_trajs)

    def plot_candidate(title, ax, can_idx):
        """Draw candidate `can_idx` with histories and ground truth onto `ax`."""
        ax.set_title(title)
        ax.imshow(map_array, extent=[-56, 56, 56, -56])
        ax.set_aspect('equal')
        ax.set_xlim([-56, 56])
        ax.set_ylim([-56, 56])

        for i in range(num_tgt_agents):
            tgt_len_i = tgt_lens[i]
            ax.plot(gen_trajs[i, can_idx, :tgt_len_i, 0], gen_trajs[i, can_idx, :tgt_len_i, 1],
                    c='g', linewidth=3.5)

        for i in range(num_src_agents):
            src_len_i = src_lens[i]
            ax.plot(src_trajs[i][:src_len_i, 0], src_trajs[i][:src_len_i, 1],
                    alpha=0.3, c='orange', linewidth=3.5)

        for i in range(num_tgt_agents):
            tgt_len_i = tgt_lens[i]
            ax.plot(tgt_trajs[i][:tgt_len_i, 0], tgt_trajs[i][:tgt_len_i, 1],
                    alpha=0.3, c='r', linewidth=3.5)

        # Empty proxy lines so the legend has exactly one entry per category.
        ax.plot([], [], c='r', alpha=0.3, label='ground-truth')
        ax.plot([], [], c='orange', alpha=0.3, label='history')
        ax.plot([], [], c='g', label='estimated')
        ax.legend()

    # Grid sized from the candidate count (2 x 3 / 18 x 12 in for six candidates) instead of
    # a hard-coded 2 x 3, which raised IndexError for fewer than six candidates.
    ncols = 3
    nrows = max(1, -(-num_candidates // ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(18, 6 * nrows), squeeze=False)
    try:
        for num, ax in enumerate(axes.flat):
            if num < num_candidates:
                plot_candidate('Inference {}'.format(num + 1), ax, num)
            else:
                ax.set_axis_off()  # unused grid cell

        save_dir = './test/results'
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, 'scene_{}.jpg'.format(img_save_count)),
                    bbox_inches='tight', pad_inches=0.5, dpi=150)
    finally:
        # Called once per scene during evaluation; without closing, every figure stays
        # alive in pyplot's registry and memory grows without bound.
        plt.close(fig)
    img_save_count += 1

In [None]:
# write_img_output names its saved figures from this global counter and increments it;
# resetting it makes file names start again at scene_0.jpg.
img_save_count = 0

run_test()

Starting model test.....
Working on test 1/10, batch 1/80... 

  "See the documentation of nn.Upsample for details.".format(mode))
