In [1]:
### test for train ###
# 230902
# 

In [2]:
import sys
sys.path.append('/home/gpuadmin/dev/Trajectory_Prediction/traffino')

import argparse
import gc
import logging
import os
import sys
import time

from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim

from data.loader_basic import data_loader # basic train을 위한 loader

In [3]:
from losses import gan_g_loss, gan_d_loss, l2_loss
from losses import displacement_error, final_displacement_error

In [4]:
from traffino.model_basic import TrajectoryGenerator, TrajectoryDiscriminator
from traffino.utils import int_tuple, bool_flag, get_total_norm
from traffino.utils import relative_to_abs, get_dset_path

In [5]:
torch.backends.cudnn.benchmark = True # cudnn의 benchmark를 통해 최적 backend 연산을 찾는 flag를 true로 하겠음

In [6]:
parser = argparse.ArgumentParser()
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)

In [7]:
from typing import Tuple
class CreateArg():
    def __init__(self):
        # Dataset options
        self.dataset_name = 'waterloo'
        self.delim = '\t'
        self.loader_num_workers = 4
        self.obs_len = 8 # 8 timestep
        self.pred_len = 8   
        self.skip = 1
        # Optimization
        self.batch_size = 8                    ##### batch_size test !!!
        self.num_iterations = 1000 # 10000
        self.num_epochs = 50 # 200                     
        # Model Options
        self.embedding_dim = 64
        self.num_layers = 1
        self.dropout = 0.0
        self.batch_norm = 0 
        self.mlp_dim = 1024
        # self.default_backbone= 'resnet18'
                                    
        # Generator Options
        self.encoder_h_dim_g = 64
        self.decoder_h_dim_g = 128
        self.noise_dim : Tuple[int] = (0, 0) # default=None # type=int_tuple
        self.noise_type = 'gaussian'
        self.noise_mix_type = 'ped'
        self.clipping_threshold_g = 0 # type=float
        self.g_learning_rate = 5e-4 # type=float 
        self.g_steps = 1

        # Pooling Options
        self.pooling_type = 'pool_net' 
        self.pool_every_timestep = 1 # type=bool_flag

        # Pool Net Option
        self.bottleneck_dim = 1024 # type=int

        # Social Pooling Options
        self.neighborhood_size = 1024 # type=float
        self.grid_size = 8 # type=int

        # Discriminator Options
        self.d_type = 'local' # type=str
        self.encoder_h_dim_d = 64 # type=int
        self.d_learning_rate = 5e-4 # type=float
        self.d_steps = 2 # type=int        
        self.clipping_threshold_d = 0 # type=float  

        # Loss Options
        self.l2_loss_weight = 0 # type=float 
        self.best_k = 1 # type=int 

        # Output
        self.output_dir = os.getcwd() + '/output/basic_output'
        self.print_every = 5 # type=int
        self.checkpoint_every = 100 # type=int
        self.checkpoint_name = 'checkpoint_basic' 
        self.checkpoint_start_from = None
        self.restore_from_checkpoint = 1 # type=int
        self.num_samples_check = 5000 # type=int        

        # Misc
        self.use_gpu = 1 # type=int
        self.timing = 0 # type=int
        self.gpu_num = "1" # type=str   


args = CreateArg() 

In [8]:
print(args.batch_size)
print(args.num_iterations)
print(args.checkpoint_name)
print(args.output_dir)
print(args.restore_from_checkpoint )

8
1000
checkpoint
/home/gpuadmin/dev/Trajectory_Prediction/traffino/output/basic_output
1


In [9]:
def init_weights(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight)


def get_dtypes(args):
    long_dtype = torch.LongTensor
    float_dtype = torch.FloatTensor
    if args.use_gpu == 1:
    # use_gpu == 1:
        long_dtype = torch.cuda.LongTensor
        float_dtype = torch.cuda.FloatTensor
    return long_dtype, float_dtype

In [10]:
long_dtype, float_dtype = get_dtypes(args)

In [11]:
train_path = '/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/train/'

# val_path = get_dset_path(args.dataset_name, 'val')

val_path = '/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/val/'

In [12]:
logger.info("Initializing train dataset")
train_dset, train_loader = data_loader(args, train_path) # train_dset은 TrajectoryDataset, train_loader는 DataLoader (batch 단위로 변경)

logger.info("Initializing val dataset")
_, val_loader = data_loader(args, val_path)

[INFO: 2824878509.py:    1]: Initializing train dataset
['0771_prep.txt', '0776_prep.txt', '0775_prep.txt', '0778_prep.txt', '0777_prep.txt', '0770_prep.txt', '0769_prep.txt', '0779_prep.txt']
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/train/0769_prep.txt
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/train/0770_prep.txt
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/train/0771_prep.txt
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/train/0775_prep.txt
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/train/0776_prep.txt
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/train/0777_prep.txt
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/train/0778_prep.txt
/home/gpuadmin/dev/Trajectory_Prediction/traffino/datasets/waterloo/train/0779_prep.txt
[INFO: 2824878509.py:    4]: Initializing val dataset
['0780_prep.txt', '0782_prep.txt', '0783_prep.txt

In [13]:
len(train_dset) # self.num_seq (727) --> iterations_per_epoch 계산하기 위해 사용

5385

In [14]:
print(train_dset.num_seq)

5385


In [15]:
iterations_per_epoch = len(train_dset) / args.batch_size / args.d_steps
if args.num_epochs:
    args.num_iterations = int(iterations_per_epoch * args.num_epochs)

logger.info(
    'There are {} iterations per epoch'.format(iterations_per_epoch)
)

[INFO: 2378675245.py:    6]: There are 336.5625 iterations per epoch


In [16]:
generator = TrajectoryGenerator(
    obs_len=args.obs_len,
    pred_len=args.pred_len,
    embedding_dim=args.embedding_dim,
    encoder_h_dim=args.encoder_h_dim_g,
    decoder_h_dim=args.decoder_h_dim_g,
    mlp_dim=args.mlp_dim,
    num_layers=args.num_layers,
    noise_dim=args.noise_dim,
    bottleneck_dim=args.bottleneck_dim,
    noise_type=args.noise_type,
    noise_mix_type=args.noise_mix_type,
    pooling_type=args.pooling_type,
    pool_every_timestep=args.pool_every_timestep,
    dropout=args.dropout,
    activation = 'relu',
    batch_norm = args.batch_norm,
    neighborhood_size=args.neighborhood_size,
    # default_backbone = args.default_backbone
    # grid_size=args.grid_size
    )

input_dim:1088


In [17]:
generator.apply(init_weights)
generator.type(float_dtype).train()
logger.info('Here is the generator:')
logger.info(generator)

[INFO: 1295372059.py:    3]: Here is the generator:
[INFO: 1295372059.py:    4]: TrajectoryGenerator(
  (encoder): TrajEncoder(
    (encoder): LSTM(64, 64)
    (spatial_embedding): Linear(in_features=2, out_features=64, bias=True)
  )
  (decoder): Decoder(
    (decoder): LSTM(64, 128)
    (pool_net): PoolHiddenNet(
      (spatial_embedding): Linear(in_features=2, out_features=64, bias=True)
      (mlp_pre_pool): Sequential(
        (0): Linear(in_features=192, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=1024, bias=True)
        (3): ReLU()
      )
    )
    (mlp): Sequential(
      (0): Linear(in_features=1152, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=128, bias=True)
      (3): ReLU()
    )
    (spatial_embedding): Linear(in_features=2, out_features=64, bias=True)
    (hidden2pos): Linear(in_features=128, out_features=2, bias=True)
  )
  (pool_net): PoolHiddenNet(
    (spatial

In [18]:
discriminator = TrajectoryDiscriminator(
    obs_len=args.obs_len,
    pred_len=args.pred_len,
    embedding_dim=args.embedding_dim,
    h_dim=args.encoder_h_dim_d,
    mlp_dim=args.mlp_dim,
    num_layers=args.num_layers,
    dropout=args.dropout,
    batch_norm=args.batch_norm,
    d_type=args.d_type)

In [19]:
discriminator.apply(init_weights)
discriminator.type(float_dtype).train()
logger.info('Here is the discriminator:')
logger.info(discriminator)

[INFO: 2576673517.py:    3]: Here is the discriminator:
[INFO: 2576673517.py:    4]: TrajectoryDiscriminator(
  (encoder): TrajEncoder(
    (encoder): LSTM(64, 64)
    (spatial_embedding): Linear(in_features=2, out_features=64, bias=True)
  )
  (real_classifier): Sequential(
    (0): Linear(in_features=64, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=1, bias=True)
    (3): ReLU()
  )
)


In [20]:
g_loss_fn = gan_g_loss
d_loss_fn = gan_d_loss

optimizer_g = optim.Adam(generator.parameters(), lr=args.g_learning_rate)
optimizer_d = optim.Adam(
    discriminator.parameters(), lr=args.d_learning_rate
)

In [21]:
args.restore_from_checkpoint = 1 #default : 1

In [22]:
# Maybe restore from checkpoint
restore_path = None
if args.checkpoint_start_from is not None:
    restore_path = args.checkpoint_start_from
elif args.restore_from_checkpoint == 1:
    restore_path = os.path.join(args.output_dir,              # basic
                                '%s_basic_with_model.pt' % args.checkpoint_name)

In [23]:
print(restore_path) #checkpoint 불러오는 path
print(os.path.isfile(restore_path)) #checkpoint 불러오는 path 파일이 있는지 출력

/home/gpuadmin/dev/Trajectory_Prediction/traffino/output/basic_output/checkpoint_with_model.pt
False


In [24]:
if restore_path is not None and os.path.isfile(restore_path): # restore 할 파일이 있으면
    logger.info('Restoring from checkpoint {}'.format(restore_path))
    checkpoint = torch.load(restore_path)
    generator.load_state_dict(checkpoint['g_state'])
    discriminator.load_state_dict(checkpoint['d_state'])
    optimizer_g.load_state_dict(checkpoint['g_optim_state'])
    optimizer_d.load_state_dict(checkpoint['d_optim_state'])
    t = checkpoint['counters']['t']
    epoch = checkpoint['counters']['epoch']
    checkpoint['restore_ts'].append(t)
    
else:
    # Starting from scratch, so initialize checkpoint data structure
    t, epoch = 0, 0
    checkpoint = {
        'args': args.__dict__,
        'G_losses': defaultdict(list),
        'D_losses': defaultdict(list),
        'losses_ts': [],
        'metrics_val': defaultdict(list),
        'metrics_train': defaultdict(list),
        'sample_ts': [],
        'restore_ts': [],
        'norm_g': [],
        'norm_d': [],
        'counters': {
            't': None,
            'epoch': None,
        },
        'g_state': None,
        'g_optim_state': None,
        'd_state': None,
        'd_optim_state': None,
        'g_best_state': None,
        'd_best_state': None,
        'best_t': None,
        'g_best_nl_state': None,
        'd_best_state_nl': None,
        'best_t_nl': None,
    }

In [28]:
t0 = None
while t < args.num_iterations: # 10000
    gc.collect()
    d_steps_left = args.d_steps # 2
    g_steps_left = args.g_steps # 1
    epoch += 1
    logger.info('Starting epoch {}'.format(epoch))
    for batch in train_loader: # batch
        if args.timing == 1: # default = 0
            torch.cuda.synchronize()
            t1 = time.time()

        # Decide whether to use the batch for stepping on discriminator or
        # generator; an iteration consists of args.d_steps steps on the
        # discriminator followed by args.g_steps steps on the generator.
        
        if d_steps_left > 0:
            step_type = 'd'
            losses_d = discriminator_step(args, batch, generator,
                                            discriminator, d_loss_fn,
                                            optimizer_d)
            checkpoint['norm_d'].append(
                get_total_norm(discriminator.parameters()))
            d_steps_left -= 1
            
        elif g_steps_left > 0:
            step_type = 'g'
            losses_g = generator_step(args, batch, generator,
                                        discriminator, g_loss_fn, # g_loss_fn = gan_g_loss
                                        optimizer_g)
            
            checkpoint['norm_g'].append(
                get_total_norm(generator.parameters())
            )
            g_steps_left -= 1

        if args.timing == 1:
            torch.cuda.synchronize()
            t2 = time.time()
            logger.info('{} step took {}'.format(step_type, t2 - t1))

        # Skip the rest if we are not at the end of an iteration
        if d_steps_left > 0 or g_steps_left > 0:
            continue

        if args.timing == 1:
            if t0 is not None:
                logger.info('Interation {} took {}'.format(
                    t - 1, time.time() - t0
                ))
            t0 = time.time()

        # Maybe save loss
        if t % args.print_every == 0:
            logger.info('t = {} / {}'.format(t + 1, args.num_iterations))
            for k, v in sorted(losses_d.items()):
                logger.info('  [D] {}: {:.3f}'.format(k, v))
                checkpoint['D_losses'][k].append(v)
            for k, v in sorted(losses_g.items()):
                logger.info('  [G] {}: {:.3f}'.format(k, v))
                checkpoint['G_losses'][k].append(v)
            checkpoint['losses_ts'].append(t)

        # Maybe save a checkpoint
        if t > 0 and t % args.checkpoint_every == 0:
            checkpoint['counters']['t'] = t
            checkpoint['counters']['epoch'] = epoch
            checkpoint['sample_ts'].append(t)

            # Check stats on the validation set
            logger.info('Checking stats on val ...')
            metrics_val = check_accuracy(
                args, val_loader, generator, discriminator, 
                d_loss_fn, limit=False
            )
            logger.info('Checking stats on train ...')
            metrics_train = check_accuracy(
                args, train_loader, generator, discriminator,
                d_loss_fn, limit=True
            )

            for k, v in sorted(metrics_val.items()):
                logger.info('  [val] {}: {:.3f}'.format(k, v))
                checkpoint['metrics_val'][k].append(v)
                
            for k, v in sorted(metrics_train.items()):
                logger.info('  [train] {}: {:.3f}'.format(k, v))
                checkpoint['metrics_train'][k].append(v)

            min_ade = min(checkpoint['metrics_val']['ade'])
            min_ade_nl = min(checkpoint['metrics_val']['ade_nl'])

            if metrics_val['ade'] == min_ade:
                logger.info('New low for avg_disp_error')
                checkpoint['best_t'] = t
                checkpoint['g_best_state'] = generator.state_dict()
                checkpoint['d_best_state'] = discriminator.state_dict()

            if metrics_val['ade_nl'] == min_ade_nl:
                logger.info('New low for avg_disp_error_nl')
                checkpoint['best_t_nl'] = t
                checkpoint['g_best_nl_state'] = generator.state_dict()
                checkpoint['d_best_nl_state'] = discriminator.state_dict()

            # Save another checkpoint with model weights and
            # optimizer state
            checkpoint['g_state'] = generator.state_dict()
            checkpoint['g_optim_state'] = optimizer_g.state_dict()
            checkpoint['d_state'] = discriminator.state_dict()
            checkpoint['d_optim_state'] = optimizer_d.state_dict()
            
            checkpoint_path = os.path.join(
                args.output_dir, '%s_with_model.pt' % args.checkpoint_name
            )
            
            logger.info('Saving checkpoint to {}'.format(checkpoint_path))
            torch.save(checkpoint, checkpoint_path)
            logger.info('Done.')

            # Save a checkpoint with no model weights by making a shallow
            # copy of the checkpoint excluding some items
            
            checkpoint_path = os.path.join(
                args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
            
            logger.info('Saving checkpoint to {}'.format(checkpoint_path))
            
            key_blacklist = [
                'g_state', 'd_state', 'g_best_state', 'g_best_nl_state',
                'g_optim_state', 'd_optim_state', 'd_best_state',
                'd_best_nl_state'
            ]
            
            small_checkpoint = {}
            for k, v in checkpoint.items():
                if k not in key_blacklist:
                    small_checkpoint[k] = v
            torch.save(small_checkpoint, checkpoint_path)
            logger.info('Done.')

        t += 1
        d_steps_left = args.d_steps
        
        g_steps_left = args.g_steps
        if t >= args.num_iterations:
            break

torch.cuda.empty_cache()

[INFO: 2281761580.py:    7]: Starting epoch 1
[INFO: 2281761580.py:   55]: t = 1 / 16828
[INFO: 2281761580.py:   57]:   [D] D_data_loss: 1.448
[INFO: 2281761580.py:   57]:   [D] D_total_loss: 1.448
[INFO: 2281761580.py:   60]:   [G] G_discriminator_loss: 0.616
[INFO: 2281761580.py:   60]:   [G] G_total_loss: 0.616
[INFO: 2281761580.py:   55]: t = 6 / 16828
[INFO: 2281761580.py:   57]:   [D] D_data_loss: 1.386
[INFO: 2281761580.py:   57]:   [D] D_total_loss: 1.386
[INFO: 2281761580.py:   60]:   [G] G_discriminator_loss: 0.693
[INFO: 2281761580.py:   60]:   [G] G_total_loss: 0.693
[INFO: 2281761580.py:   55]: t = 11 / 16828
[INFO: 2281761580.py:   57]:   [D] D_data_loss: 1.386
[INFO: 2281761580.py:   57]:   [D] D_total_loss: 1.386
[INFO: 2281761580.py:   60]:   [G] G_discriminator_loss: 0.693
[INFO: 2281761580.py:   60]:   [G] G_total_loss: 0.693
[INFO: 2281761580.py:   55]: t = 16 / 16828
[INFO: 2281761580.py:   57]:   [D] D_data_loss: 1.386
[INFO: 2281761580.py:   57]:   [D] D_total_lo

In [25]:
def generator_step(
    args, batch, generator, discriminator, g_loss_fn, optimizer_g
):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)
    g_l2_loss_rel = []

    loss_mask = loss_mask[:, args.obs_len:]

    for _ in range(args.best_k):
        generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)

        pred_traj_fake_rel = generator_out
        pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

        if args.l2_loss_weight > 0:
            g_l2_loss_rel.append(args.l2_loss_weight * l2_loss(
                pred_traj_fake_rel,
                pred_traj_gt_rel,
                loss_mask,
                mode='raw'))

    g_l2_loss_sum_rel = torch.zeros(1).to(pred_traj_gt)
    if args.l2_loss_weight > 0:
        g_l2_loss_rel = torch.stack(g_l2_loss_rel, dim=1)
        for start, end in seq_start_end.data:
            _g_l2_loss_rel = g_l2_loss_rel[start:end]
            _g_l2_loss_rel = torch.sum(_g_l2_loss_rel, dim=0)
            _g_l2_loss_rel = torch.min(_g_l2_loss_rel) / torch.sum(
                loss_mask[start:end])
            g_l2_loss_sum_rel += _g_l2_loss_rel
        losses['G_l2_loss_rel'] = g_l2_loss_sum_rel.item()
        loss += g_l2_loss_sum_rel

    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    discriminator_loss = g_loss_fn(scores_fake)

    loss += discriminator_loss
    losses['G_discriminator_loss'] = discriminator_loss.item()
    losses['G_total_loss'] = loss.item()

    optimizer_g.zero_grad()
    loss.backward()
    if args.clipping_threshold_g > 0:
        nn.utils.clip_grad_norm_(
            generator.parameters(), args.clipping_threshold_g
        )
    optimizer_g.step()

    return losses


In [26]:
def discriminator_step(
    args, batch, generator, discriminator, d_loss_fn, optimizer_d
):
    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
     loss_mask, seq_start_end) = batch
    losses = {}
    loss = torch.zeros(1).to(pred_traj_gt)

    generator_out = generator(obs_traj, obs_traj_rel, seq_start_end)
    
    pred_traj_fake_rel = generator_out
    pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

    traj_real = torch.cat([obs_traj, pred_traj_gt], dim=0) # Sizes of tensors must match except in dimension 0. Expected size 2 but got size 4 for tensor number 1 in the list.
    traj_real_rel = torch.cat([obs_traj_rel, pred_traj_gt_rel], dim=0)
    traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
    traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

    scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
    scores_real = discriminator(traj_real, traj_real_rel, seq_start_end)

    # Compute loss with optional gradient penalty
    data_loss = d_loss_fn(scores_real, scores_fake)
    losses['D_data_loss'] = data_loss.item()
    loss += data_loss
    losses['D_total_loss'] = loss.item()

    optimizer_d.zero_grad()
    loss.backward()
    if args.clipping_threshold_d > 0:
        nn.utils.clip_grad_norm_(discriminator.parameters(),
                                 args.clipping_threshold_d)
    optimizer_d.step()

    return losses

In [27]:
def check_accuracy(
    args, loader, generator, discriminator, d_loss_fn, limit=False
):
    d_losses = []
    metrics = {}
    g_l2_losses_abs, g_l2_losses_rel = ([],) * 2
    disp_error, disp_error_l, disp_error_nl = ([],) * 3
    f_disp_error, f_disp_error_l, f_disp_error_nl = ([],) * 3
    total_traj, total_traj_l, total_traj_nl = 0, 0, 0
    loss_mask_sum = 0
    generator.eval()
    with torch.no_grad():
        for batch in loader:
            batch = [tensor.cuda() for tensor in batch]
            (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
             non_linear_ped, loss_mask, seq_start_end) = batch


            linear_ped = 1 - non_linear_ped
            loss_mask = loss_mask[:, args.obs_len:]

            pred_traj_fake_rel = generator(
                obs_traj, obs_traj_rel, seq_start_end
            )
            pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

            g_l2_loss_abs, g_l2_loss_rel = cal_l2_losses(
                pred_traj_gt, pred_traj_gt_rel, pred_traj_fake,
                pred_traj_fake_rel, loss_mask
            )
            ade, ade_l, ade_nl = cal_ade(
                pred_traj_gt, pred_traj_fake, linear_ped, non_linear_ped
            )

            fde, fde_l, fde_nl = cal_fde(
                pred_traj_gt, pred_traj_fake, linear_ped, non_linear_ped
            )

            traj_real = torch.cat([obs_traj, pred_traj_gt], dim=0)
            traj_real_rel = torch.cat([obs_traj_rel, pred_traj_gt_rel], dim=0)
            traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
            traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)

            scores_fake = discriminator(traj_fake, traj_fake_rel, seq_start_end)
            scores_real = discriminator(traj_real, traj_real_rel, seq_start_end)

            d_loss = d_loss_fn(scores_real, scores_fake)
            d_losses.append(d_loss.item())

            g_l2_losses_abs.append(g_l2_loss_abs.item())
            g_l2_losses_rel.append(g_l2_loss_rel.item())
            disp_error.append(ade.item())
            disp_error_l.append(ade_l.item())
            disp_error_nl.append(ade_nl.item())
            f_disp_error.append(fde.item())
            f_disp_error_l.append(fde_l.item())
            f_disp_error_nl.append(fde_nl.item())

            loss_mask_sum += torch.numel(loss_mask.data)
            total_traj += pred_traj_gt.size(1)
            total_traj_l += torch.sum(linear_ped).item()
            total_traj_nl += torch.sum(non_linear_ped).item()
            if limit and total_traj >= args.num_samples_check:
                break

    metrics['d_loss'] = sum(d_losses) / len(d_losses)
    metrics['g_l2_loss_abs'] = sum(g_l2_losses_abs) / loss_mask_sum
    metrics['g_l2_loss_rel'] = sum(g_l2_losses_rel) / loss_mask_sum

    metrics['ade'] = sum(disp_error) / (total_traj * args.pred_len)
    metrics['fde'] = sum(f_disp_error) / total_traj
    if total_traj_l != 0:
        metrics['ade_l'] = sum(disp_error_l) / (total_traj_l * args.pred_len)
        metrics['fde_l'] = sum(f_disp_error_l) / total_traj_l
    else:
        metrics['ade_l'] = 0
        metrics['fde_l'] = 0
    if total_traj_nl != 0:
        metrics['ade_nl'] = sum(disp_error_nl) / (
            total_traj_nl * args.pred_len)
        metrics['fde_nl'] = sum(f_disp_error_nl) / total_traj_nl
    else:
        metrics['ade_nl'] = 0
        metrics['fde_nl'] = 0

    generator.train()
    return metrics


def cal_l2_losses(
    pred_traj_gt, pred_traj_gt_rel, pred_traj_fake, pred_traj_fake_rel,
    loss_mask
):
    g_l2_loss_abs = l2_loss(
        pred_traj_fake, pred_traj_gt, loss_mask, mode='sum'
    )
    g_l2_loss_rel = l2_loss(
        pred_traj_fake_rel, pred_traj_gt_rel, loss_mask, mode='sum'
    )
    return g_l2_loss_abs, g_l2_loss_rel


def cal_ade(pred_traj_gt, pred_traj_fake, linear_ped, non_linear_ped):
    ade = displacement_error(pred_traj_fake, pred_traj_gt)
    ade_l = displacement_error(pred_traj_fake, pred_traj_gt, linear_ped)
    ade_nl = displacement_error(pred_traj_fake, pred_traj_gt, non_linear_ped)
    return ade, ade_l, ade_nl


def cal_fde(
    pred_traj_gt, pred_traj_fake, linear_ped, non_linear_ped
):
    fde = final_displacement_error(pred_traj_fake[-1], pred_traj_gt[-1])
    fde_l = final_displacement_error(
        pred_traj_fake[-1], pred_traj_gt[-1], linear_ped
    )
    fde_nl = final_displacement_error(
        pred_traj_fake[-1], pred_traj_gt[-1], non_linear_ped
    )
    return fde, fde_l, fde_nl