In [1]:
import argparse
import os
import sys
sys.path.append('/home/gpuadmin/dev/Trajectory_Prediction/traffino')

In [2]:
import torch
import matplotlib.pyplot as plt
import numpy as np


In [3]:
# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and 3.8
# removed the old names, so plt.style.use("seaborn-dark") raises OSError there.
# Prefer the new name and fall back to the old one on older matplotlib.
plt.style.use(
    "seaborn-v0_8-dark" if "seaborn-v0_8-dark" in plt.style.available else "seaborn-dark"
)
### data loader ###
from data.loader_basic import data_loader
### model, utilities and losses ###
from model_basic import TrajectoryGenerator
from utils import (
    int_tuple,
    relative_to_abs,
    get_dset_path,
)
from losses import(
    displacement_error,
    final_displacement_error,
    l2_loss
)

In [4]:
from typing import Tuple
class CreateArg():
    """Hard-coded options object used in place of an argparse namespace.

    The ``# type=...`` notes record the argparse type each option presumably had
    in the training script (NOTE(review): confirm against the training entry point).
    In this notebook the object is read by ``get_generator``, ``plot_trajectory``
    and ``main``, and is passed whole to ``data_loader``; which fields
    ``data_loader`` reads is not visible here.
    """

    def __init__(self):
        self.num_samples = 20  # type=int; generator samples drawn per batch when plotting

        # Dataset options
        self.dataset_name = 'waterloo'
        self.delim = '\t'
        self.loader_num_workers = 4  # NOTE(review): an earlier note read "4 -> 1"; confirm intended count
        self.obs_len = 8
        self.pred_len = 8  # check: presumably must match the value the checkpoint was trained with
        self.skip = 1

        # Optimization
        self.batch_size = 4  # check: batch size used by the test loader
        self.num_iterations = 1000  # previously 10000
        self.num_epochs = 50  # previously 200

        # Model options
        self.embedding_dim = 64
        self.num_layers = 1
        self.dropout = 0.0
        self.batch_norm = 0
        self.mlp_dim = 1024

        # Generator options
        self.encoder_h_dim_g = 64
        self.decoder_h_dim_g = 128
        # type=int_tuple (default None in argparse); a variable-length tuple of ints.
        self.noise_dim: Tuple[int, ...] = (0, 0)
        self.noise_type = 'gaussian'
        self.noise_mix_type = 'ped'
        self.clipping_threshold_g = 0  # type=float
        self.g_learning_rate = 5e-4  # type=float
        self.g_steps = 1

        # Pooling options
        self.pooling_type = 'pool_net'
        self.pool_every_timestep = 1  # type=bool_flag

        # Pool net options
        self.bottleneck_dim = 1024  # type=int

        # Social pooling options
        self.neighborhood_size = 1024  # type=float
        self.grid_size = 8  # type=int

        # Discriminator options
        self.d_type = 'local'  # type=str
        self.encoder_h_dim_d = 64  # type=int
        self.d_learning_rate = 5e-4  # type=float
        self.d_steps = 2  # type=int
        self.clipping_threshold_d = 0  # type=float

        # Loss options
        self.l2_loss_weight = 0  # type=float
        self.best_k = 1  # type=int

        # Output
        self.output_dir = os.path.join(os.getcwd(), 'output2')
        self.print_every = 5  # type=int
        self.checkpoint_every = 100  # type=int
        self.checkpoint_name = 'checkpoint_basic'
        self.checkpoint_start_from = None
        self.restore_from_checkpoint = 1  # type=int
        self.num_samples_check = 5000  # type=int

        # Misc
        self.use_gpu = 1  # type=int; truthy -> run on CUDA
        self.timing = 0  # type=int
        self.gpu_num = "1"  # type=str; not applied anywhere in this notebook

        # Log / restore
        self.log_dir = "./"
        # Checkpoint loaded by main(); its "g_state" entry holds the generator weights.
        self.restore_path = "/home/gpuadmin/dev/Trajectory_Prediction/traffino/output/basic_output/checkpoint_basic_with_model.pt"



args = CreateArg()

In [5]:
# Echo the checkpoint path and batch size before running.
for _setting in ("restore_path", "batch_size"):
    print(f"{_setting}: {getattr(args, _setting)}")

restore_path: /home/gpuadmin/dev/Trajectory_Prediction/traffino/output/basic_output/checkpoint_basic_with_model.pt
batch_size: 4


In [6]:
def evaluate_helper(error, seq_start_end, model_output_traj, model_output_traj_best):
    """Pick, per sequence, the sampled trajectory with the lowest summed error.

    Args:
        error: list with one error tensor per sample; they are stacked on dim 1, so
            presumably each is (num_peds,) -> stacked (num_peds, num_samples).
        seq_start_end: iterable of (start, end) pairs of 0-d tensors delimiting
            each sequence's pedestrians.
        model_output_traj: list of sampled trajectories, indexed like ``error``;
            pedestrians are on dim 1.
        model_output_traj_best: tensor written IN PLACE with the chosen samples.

    Returns:
        ``model_output_traj_best`` (the same object that was passed in).
    """
    per_sample = torch.stack(error, dim=1)
    for seq_start, seq_end in seq_start_end:
        lo, hi = seq_start.item(), seq_end.item()
        # Total error of each sample over this sequence's pedestrians.
        sample_totals = per_sample[lo:hi].sum(dim=0)
        winner = sample_totals.min(dim=0).indices.item()
        model_output_traj_best[:, lo:hi, :] = model_output_traj[winner][:, lo:hi, :]
    return model_output_traj_best


In [7]:
def get_generator(checkpoint, config=None):
    """Build a ``TrajectoryGenerator`` from the options and load the checkpoint weights.

    Args:
        checkpoint: mapping loaded with ``torch.load``; its ``"g_state"`` entry is
            passed to ``load_state_dict`` and must match the architecture that
            ``config`` describes.
        config: options object (e.g. a ``CreateArg`` instance). Defaults to the
            notebook-level ``args`` so existing ``get_generator(checkpoint)``
            calls keep working.

    Returns:
        The generator in eval mode, moved to CUDA only when ``config.use_gpu`` is truthy.
    """
    cfg = args if config is None else config
    generator = TrajectoryGenerator(
        obs_len=cfg.obs_len,
        pred_len=cfg.pred_len,
        embedding_dim=cfg.embedding_dim,
        encoder_h_dim=cfg.encoder_h_dim_g,
        decoder_h_dim=cfg.decoder_h_dim_g,
        mlp_dim=cfg.mlp_dim,
        num_layers=cfg.num_layers,
        noise_dim=cfg.noise_dim,
        bottleneck_dim=cfg.bottleneck_dim,
        noise_type=cfg.noise_type,
        noise_mix_type=cfg.noise_mix_type,
        pooling_type=cfg.pooling_type,
        pool_every_timestep=cfg.pool_every_timestep,
        dropout=cfg.dropout,
        activation='relu',
        batch_norm=cfg.batch_norm,
        neighborhood_size=cfg.neighborhood_size,
    )
    generator.load_state_dict(checkpoint["g_state"])
    # Honour the use_gpu option instead of unconditionally calling .cuda().
    if cfg.use_gpu:
        generator.cuda()
    generator.eval()
    return generator

In [8]:
def cal_ade_fde(pred_traj_gt, pred_traj_fake):
    """Return the raw (un-reduced) ADE and FDE of a predicted trajectory.

    Args:
        pred_traj_gt: ground-truth future positions; dim 0 is presumably time,
            since the last entry along it is treated as the final position.
        pred_traj_fake: predicted future positions, same layout as ``pred_traj_gt``.

    Returns:
        ``(ade, fde)`` from ``displacement_error`` and ``final_displacement_error``,
        both called with ``mode="raw"``.
    """
    last_fake = pred_traj_fake[-1]
    last_gt = pred_traj_gt[-1]
    return (
        displacement_error(pred_traj_fake, pred_traj_gt, mode="raw"),
        final_displacement_error(last_fake, last_gt, mode="raw"),
    )

In [9]:
def _best_of_n_prediction(args, generator, obs_traj, obs_traj_rel, pred_traj_gt, seq_start_end):
    """Sample ``args.num_samples`` futures and keep, per sequence, the lowest-ADE one.

    Returns a tensor shaped like ``pred_traj_gt`` holding absolute positions.
    """
    samples = []
    sample_ades = []
    for _ in range(args.num_samples):
        pred_traj_fake_rel = generator(obs_traj, obs_traj_rel, seq_start_end)
        # Keep only the last pred_len steps of the generator output.
        pred_traj_fake_rel = pred_traj_fake_rel[-args.pred_len:]
        # Presumably converts displacements to absolute positions anchored at the
        # last observed point.
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
        samples.append(pred_traj_fake)
        ade, _ = cal_ade_fde(pred_traj_gt, pred_traj_fake)
        sample_ades.append(ade)
    # ones_like already allocates on pred_traj_gt's device; evaluate_helper fills it in place.
    best = torch.ones_like(pred_traj_gt)
    return evaluate_helper(sample_ades, seq_start_end, samples, best)


def _plot_sequence(obs, gt, pred, save_path):
    """Render one sequence to ``save_path``; inputs are (seq_len, num_peds, 2) numpy arrays."""
    fig, ax = plt.subplots(figsize=(20, 15), dpi=100)
    try:
        # Transpose each coordinate to (num_peds, seq_len).
        obs_x, obs_y = obs[:, :, 0].T, obs[:, :, 1].T
        gt_x, gt_y = gt[:, :, 0].T, gt[:, :, 1].T
        pred_x, pred_y = pred[:, :, 0].T, pred[:, :, 1].T
        for i in range(gt_x.shape[0]):
            observed_line = ax.plot(
                obs_x[i, :], obs_y[i, :], "r-", linewidth=4, label="Observed Trajectory",
            )[0]
            # Arrow head on the last observed step to show direction of travel.
            ax.annotate(
                "",
                xytext=(obs_x[i, -2], obs_y[i, -2]),
                xy=(obs_x[i, -1], obs_y[i, -1]),
                arrowprops=dict(arrowstyle="->", color=observed_line.get_color(), lw=1),
                size=20,
            )
            # Prepend the last observed point so future segments connect to the history.
            ax.plot(
                np.append(obs_x[i, -1], gt_x[i, :]),
                np.append(obs_y[i, -1], gt_y[i, :]),
                "b-",
                linewidth=4,
                label="Ground Truth",
            )
            ax.plot(
                np.append(obs_x[i, -1], pred_x[i, :]),
                np.append(obs_y[i, -1], pred_y[i, :]),
                color="#ffff00",  # yellow
                ls="--",
                linewidth=4,
                label="Predicted Trajectory",
            )
        fig.savefig(save_path)
    finally:
        # Close even if drawing/saving fails, so figures do not pile up in memory.
        plt.close(fig)


def plot_trajectory(args, loader, generator, out_dir="./traj_fig/basic"):
    """Draw observed, ground-truth and best-of-N predicted trajectories, one PNG per sequence.

    For every batch, ``args.num_samples`` futures are sampled from ``generator``; for
    each sequence the sample with the lowest summed ADE (via ``evaluate_helper``) is
    plotted against the ground truth.

    Args:
        args: options object; reads ``num_samples`` and ``pred_len``.
        loader: iterable of 7-element batches unpacked as (obs_traj, pred_traj_gt,
            obs_traj_rel, pred_traj_gt_rel, non_linear_ped, loss_mask, seq_start_end).
        generator: trained model called as ``generator(obs_traj, obs_traj_rel, seq_start_end)``.
        out_dir: directory that receives ``pic_0.png``, ``pic_1.png``, ...; created if missing.

    Returns:
        None. Figures are written to ``out_dir``.
    """
    # savefig does not create directories; without this a fresh checkout crashes.
    os.makedirs(out_dir, exist_ok=True)
    # Move batches to wherever the generator's weights live instead of hard-coding CUDA.
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    pic_cnt = 0
    with torch.no_grad():
        for batch in loader:
            (obs_traj, pred_traj_gt, obs_traj_rel, _pred_traj_gt_rel, _non_linear_ped,
             _loss_mask, seq_start_end) = [tensor.to(device) for tensor in batch]
            best_traj = _best_of_n_prediction(
                args, generator, obs_traj, obs_traj_rel, pred_traj_gt, seq_start_end
            )
            # Convert to numpy once per batch rather than once per coordinate per sequence.
            obs_np = obs_traj.cpu().numpy()
            gt_np = pred_traj_gt.cpu().numpy()
            pred_np = best_traj.cpu().numpy()
            for start, end in seq_start_end.tolist():
                _plot_sequence(
                    obs_np[:, start:end, :],
                    gt_np[:, start:end, :],
                    pred_np[:, start:end, :],
                    os.path.join(out_dir, "pic_{}.png".format(pic_cnt)),
                )
                pic_cnt += 1

In [10]:
def main(args):
    """Load the generator checkpoint, build the test loader, and write trajectory plots.

    Args:
        args: options object (e.g. a ``CreateArg`` instance). Reads ``restore_path``
            (the .pt checkpoint to load) and ``use_gpu``. An optional ``dset_path``
            attribute overrides the default test-set directory.
    """
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine when use_gpu
    # is off; with use_gpu on, tensors keep their saved device as before.
    map_location = None if args.use_gpu else torch.device("cpu")
    checkpoint = torch.load(args.restore_path, map_location=map_location)
    generator = get_generator(checkpoint)
    path = getattr(
        args,
        "dset_path",
        '/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/test/',
    )
    _, loader = data_loader(args, path)
    plot_trajectory(args, loader, generator)

In [12]:
# Run the whole pipeline: load the checkpoint, build the test loader, write the plots.
main(args=args)

input_dim:1088
['0785_prep.txt', '0784_prep.txt']
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/test/0784_prep.txt
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/test/0785_prep.txt
