Describe your code using comments.

In [None]:
# Mount Google Drive only when running inside Colab. The project root set in the
# next cell is a local filesystem path. Outside Colab the `google.colab` package
# does not exist, so an unconditional import would abort the whole notebook.
try:
    from google.colab import drive
except ImportError:
    drive = None  # not running on Colab: nothing to mount
else:
    drive.mount('/content/drive')

In [None]:
# Project directory. It is appended to sys.path so that modules living under it
# (presumably `trajectories` and `utils`, imported below) can be found. Later
# cells also write checkpoints and the output GIF beneath it.
# NOTE(review): this is a local workstation path, not a /content/drive path,
# even though the previous cell mounts Google Drive. Confirm which environment
# the notebook targets.
root = '/home/vilab/ssd1tb/hj_ME455/Trajectory_Prediction'
import sys
sys.path.append(root)

import argparse
import gc
import logging
import os
import sys
import time
import numpy as np

from collections import defaultdict
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from trajectories import data_loader
from utils import relative_to_abs, get_dset_path

In [None]:
from easydict import EasyDict as edict

# Experiment configuration shared by the data loaders, the model and the
# training loop. Every field is given at construction time; attribute access
# (cfg.obs_len, cfg.h_dim, ...) works as usual.
cfg = edict(
    # dataset / loader
    dataset_name='zara1',
    delim='\t',
    loader_num_workers=4,
    # sequence windowing (obs_len / pred_len split each sample's timesteps)
    obs_len=8,
    pred_len=8,
    skip=1,
    # optimisation
    batch_size=16,
    num_epochs=200,
    learning_rate=5e-4,
    # model size
    embedding_dim=64,
    h_dim=64,
    num_layers=1,
    # NOTE(review): mlp_dim, dropout and best_k are not read by any cell of this
    # notebook. They may be consumed by the project helpers; confirm.
    mlp_dim=1024,
    dropout=0,
    best_k=1,
)

In [None]:
# Resolve where the train and validation splits of the configured dataset live.
train_path, val_path = (
    get_dset_path(cfg.dataset_name, split) for split in ('train', 'val')
)

In [None]:
# Build a (dataset, loader) pair for each split. The training dataset object is
# kept; validation only needs its loader.
(train_dset, train_loader), (_, val_loader) = (
    data_loader(cfg, split_path) for split_path in (train_path, val_path)
)

### Model for trajectory prediction

In [None]:

class Encoder(nn.Module):
    """LSTM encoder over an observed sequence that also emits a one-step prediction.

    Each 2-D input point is embedded, the embedded sequence is run through an
    LSTM, and every LSTM output is projected back to 2-D. ``forward`` returns
    the projection made after the last input step together with the final LSTM
    state, so a decoder can continue from it.

    Args:
        embedding_dim: size of the per-point embedding fed to the LSTM.
        h_dim: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        type: accepted for interface compatibility; only an LSTM is implemented.
    """

    def __init__(
        self, embedding_dim=64, h_dim=64, num_layers=1, type='lstm'):
        super(Encoder, self).__init__()

        self.h_dim = h_dim
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        # Time-major LSTM (batch_first=False): input is (seq_len, batch, embedding_dim).
        self.encoder = nn.LSTM(embedding_dim, h_dim, num_layers)

        # 2-D point -> embedding, and LSTM output -> 2-D point.
        self.en_mlp = nn.Linear(2, embedding_dim)
        self.de_mlp = nn.Linear(h_dim, 2)

    def make_state(self, batch):
        """Return a zero ``(h0, c0)`` pair, each of shape ``(num_layers, batch, h_dim)``.

        The tensors take the device and dtype of the module's parameters rather
        than being forced onto CUDA, so the encoder also runs on CPU.
        """
        weight = next(self.parameters())
        h0 = weight.new_zeros(self.num_layers, batch, self.h_dim)
        c0 = weight.new_zeros(self.num_layers, batch, self.h_dim)
        return (h0, c0)

    def forward(self, obs_traj):
        """Encode ``obs_traj`` of shape ``(seq_len, batch, 2)``.

        Returns:
            pred: ``(batch, 2)`` projection of the LSTM output at the last step.
            state: final LSTM ``(h, c)``, each ``(num_layers, batch, h_dim)``.
        """
        batch = obs_traj.size(1)
        encoder_input = self.en_mlp(obs_traj.reshape(-1, 2))
        encoder_input = encoder_input.view(-1, batch, self.embedding_dim)

        state = self.make_state(batch)
        encoder_output, state = self.encoder(encoder_input, state)

        pred = self.de_mlp(encoder_output.reshape(-1, self.h_dim))
        pred = pred.view(-1, batch, 2)

        # Only the prediction made after the final observed step is used downstream.
        return pred[-1], state
    

class Decoder(nn.Module):
    """Autoregressive LSTM decoder that rolls out ``seq_len`` relative steps.

    Starting from a displacement and an LSTM state, each step embeds the
    previous displacement and advances the LSTM by one timestep. It then
    projects the LSTM output to the next 2-D displacement, which is fed back
    as the next input.

    Args:
        seq_len: number of steps to generate.
        embedding_dim: size of the per-step embedding fed to the LSTM.
        h_dim: LSTM hidden size (must match the state passed to ``forward``).
        num_layers: number of stacked LSTM layers (must match the state).
        type: stored as ``model_type``; only an LSTM is implemented.
    """

    def __init__(
        self, seq_len, embedding_dim=64, h_dim=128, num_layers=1, type='lstm'):
        super(Decoder, self).__init__()

        self.seq_len = seq_len
        self.h_dim = h_dim
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.model_type = type

        # Time-major LSTM driven one timestep at a time: input is (1, batch, embedding_dim).
        self.decoder = nn.LSTM(embedding_dim, h_dim, num_layers)

        # 2-D displacement -> embedding, and LSTM output -> 2-D displacement.
        self.en_mlp = nn.Linear(2, embedding_dim)
        self.de_mlp = nn.Linear(h_dim, 2)

    def forward(self, last_pos, last_pos_rel, state):
        """Generate ``seq_len`` displacements.

        Args:
            last_pos: ``(batch, 2)`` absolute position reached by ``last_pos_rel``.
            last_pos_rel: ``(batch, 2)`` most recent displacement; the first input.
            state: LSTM ``(h, c)``, each ``(num_layers, batch, h_dim)``.

        Returns:
            pred_traj_fake_rel: ``(seq_len, batch, 2)`` predicted displacements.
            state: LSTM ``(h, c)`` after the last generated step.
        """
        batch = last_pos_rel.size(0)
        pred_traj_fake_rel = []
        decoder_input = self.en_mlp(last_pos_rel)
        decoder_input = decoder_input.view(1, batch, self.embedding_dim)

        for _ in range(self.seq_len):
            decoder_output, state = self.decoder(decoder_input, state)
            pos_rel = self.de_mlp(decoder_output.view(-1, self.h_dim))
            # The absolute position is tracked, but only displacements are returned.
            curr_pos = pos_rel + last_pos

            # Feed this step's prediction back in as the next input.
            decoder_input = self.en_mlp(pos_rel).view(1, batch, self.embedding_dim)
            pred_traj_fake_rel.append(pos_rel.view(batch, -1))
            last_pos = curr_pos

        pred_traj_fake_rel = torch.stack(pred_traj_fake_rel, dim=0)
        return pred_traj_fake_rel, state

In [None]:
class MyModel(nn.Module):
    """Encoder-decoder LSTM that predicts a sequence of relative displacements.

    The encoder consumes the observed displacements and predicts the first
    future displacement. The decoder then rolls out the remaining
    ``cfg.pred_len - 1`` steps. The output is both concatenated along time.

    Args:
        cfg: configuration providing ``embedding_dim``, ``h_dim``,
            ``num_layers`` and ``pred_len``.
        type: accepted for interface compatibility; both submodules are LSTMs.
    """

    def __init__(self, cfg, type):
        super(MyModel, self).__init__()
        self.encoder = Encoder(embedding_dim=cfg.embedding_dim, h_dim=cfg.h_dim, num_layers=cfg.num_layers, type='lstm')
        self.decoder = Decoder(seq_len=cfg.pred_len-1, embedding_dim=cfg.embedding_dim, h_dim=cfg.h_dim, num_layers=cfg.num_layers, type='lstm')

    def forward(self, obs_traj, obs_traj_rel):
        """Predict future displacements.

        Args:
            obs_traj: observed absolute positions; ``obs_traj[-1]`` is the
                last observed position of each trajectory.
            obs_traj_rel: observed displacements, time-major ``(obs_len, batch, 2)``.

        Returns:
            ``(pred_len, batch, 2)`` predicted displacements.
        """
        last_pos = obs_traj[-1]

        # First future displacement plus the LSTM state that produced it.
        pred_traj_rel, state = self.encoder(obs_traj_rel)
        # The decoder starts from that displacement, so its "last position" is
        # the last observed position advanced by it.
        pred_traj_fake_rel, _ = self.decoder(last_pos + pred_traj_rel, pred_traj_rel, state)

        pred_traj_rel = torch.unsqueeze(pred_traj_rel, 0)
        pred_traj_fake_rel = torch.cat([pred_traj_rel, pred_traj_fake_rel], dim=0)

        return pred_traj_fake_rel

### Pre-defined functions

In [None]:
def cal_ade(fake, gt):
    """Return the summed Euclidean error over every point of every trajectory.

    ``fake`` and ``gt`` are 3-D tensors whose last axis holds (x, y). The other
    two axes are both summed over, so the result is a 0-d tensor. Divide it by
    the number of points to obtain the mean (ADE).
    """
    diff = (gt - fake).permute(1, 0, 2)
    point_err = diff.pow(2).sum(dim=2).sqrt()
    per_traj = point_err.sum(dim=1)
    return per_traj.sum()


def cal_fde(fake, gt):
    """Return the summed Euclidean error between paired 2-D points.

    ``fake`` and ``gt`` are ``(num_points, 2)``, e.g. the final positions of
    each trajectory. The result is a 0-d tensor; divide it by ``num_points``
    for the mean (FDE).
    """
    squared = (gt - fake) ** 2
    return torch.sqrt(squared.sum(dim=1)).sum()

In [None]:
def loss_fn(pred_traj_fake_rel, pred_traj_gt_rel, loss_mask):
    """Masked mean-squared error between predicted and ground-truth displacements.

    Both trajectories are time-major and are permuted to ``(batch, time, 2)`` to
    line up with ``loss_mask`` of shape ``(batch, time)``. Masked-out entries are
    zeroed in both tensors.

    NOTE: the mean uses the default ``'mean'`` reduction over *all* elements,
    so masked-out entries still count in the denominator.
    """
    mask = loss_mask.unsqueeze(2)
    masked_pred = mask * pred_traj_fake_rel.permute(1, 0, 2)
    masked_gt = mask * pred_traj_gt_rel.permute(1, 0, 2)
    return torch.nn.functional.mse_loss(masked_pred, masked_gt)

### Build the training and validation loops for the model.

During validation we compute ADE and FDE with the pre-defined functions above.
To do this, we use the "relative_to_abs" function to convert relative positions (displacements) into absolute positions.

In [None]:
# Use the first GPU when available, otherwise the CPU; the model is placed here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Encoder-decoder LSTM built from the shared configuration.
model_lstm = MyModel(cfg, type='lstm')

model_lstm.to(device)

# Take the learning rate from the configuration instead of a hard-coded value,
# so that cfg.learning_rate actually controls training.
optimizer = torch.optim.Adam(model_lstm.parameters(), lr=cfg.learning_rate)

# Checkpoints are saved as <root>/ckpt_lstm/ckpt_lstm_<epoch:03d>.pt.
os.makedirs(os.path.join(root, 'ckpt_lstm'), exist_ok=True)
save_path = os.path.join(root, 'ckpt_lstm/ckpt_lstm_')

In [None]:
for i in tqdm(range(cfg.num_epochs)):
    # ---------------- training ----------------
    model_lstm.train()
    running_loss, num_batches = 0.0, 0
    for batch in train_loader:
        # Move batches to the same device as the model (works with or without a GPU).
        batch = [tensor.to(device) for tensor in batch]
        (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, _, loss_mask, seq_start_end) = batch
        # Keep only the mask columns that cover the prediction horizon.
        loss_mask = loss_mask[:, cfg.obs_len:]

        optimizer.zero_grad()
        pred_traj_fake_rel = model_lstm(obs_traj, obs_traj_rel)
        loss = loss_fn(pred_traj_fake_rel, pred_traj_gt_rel, loss_mask)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
    # Mean training loss per batch for this epoch.
    train_loss_value = running_loss / max(num_batches, 1)

    # ---------------- validation ----------------
    model_lstm.eval()
    ade_sum, fde_sum, total_traj = 0.0, 0.0, 0
    with torch.no_grad():
        for batch in val_loader:
            batch = [tensor.to(device) for tensor in batch]
            (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, _, loss_mask, seq_start_end) = batch
            loss_mask = loss_mask[:, cfg.obs_len:]

            pred_traj_fake_rel = model_lstm(obs_traj, obs_traj_rel)
            # Integrate the predicted displacements starting from the last observed position.
            pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

            ade_sum += cal_ade(pred_traj_fake, pred_traj_gt).item()
            fde_sum += cal_fde(pred_traj_fake[-1], pred_traj_gt[-1]).item()
            total_traj += pred_traj_gt.size(1)

    # cal_ade / cal_fde return sums: normalise to per-point / per-trajectory means.
    ade = ade_sum / (total_traj * cfg.pred_len)
    fde = fde_sum / total_traj

    torch.save({
        'epoch': i,
        'model_state_dict': model_lstm.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
        }, f'{save_path}{i:03d}.pt')

    if (i + 1) % 10 == 0:
        e = i + 1
        print(f'epoch: {e}, train_loss: {train_loss_value}, val_ade: {ade}, val_fde: {fde}')

### Visualize the prediction results with your trajectory prediction model and its checkpoint

In [None]:
# Rebuild the model and restore the weights saved after the final training
# epoch (the training cell writes f'{save_path}{epoch:03d}.pt' every epoch).
model_lstm = MyModel(cfg, type='lstm').to(device)
checkpoint = torch.load(f'{save_path}{cfg.num_epochs - 1:03d}.pt', map_location=device)
model_lstm.load_state_dict(checkpoint['model_state_dict'])
model_lstm.eval()

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import random

# Visualise one validation batch, chosen uniformly at random.
r = len(val_loader)
rr = random.randrange(0, r)

for batch_cnt, batch in enumerate(val_loader):
    if batch_cnt != rr:
        continue

    batch = [tensor.to(device) for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, _, loss_mask, seq_start_end) = batch
    loss_mask = loss_mask[:, cfg.obs_len:]

    # Predict displacements, then anchor them at the last observed position
    # to get absolute coordinates comparable with pred_traj_gt.
    with torch.no_grad():
        pred_traj_fake_rel = model_lstm(obs_traj, obs_traj_rel)
    pred_traj_fake_lstm = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

    # Pick the scene ([start, end) slice of this batch's trajectories) to draw.
    # NOTE(review): this stops at the first scene with FEWER than 3 trajectories
    # and otherwise falls through to the last scene. If the goal is to show a
    # crowded scene, the test was probably meant to be `>= 3`; confirm.
    for i in range(seq_start_end.size(0)):
        startpoint = seq_start_end[i][0]
        endpoint = seq_start_end[i][1]
        if endpoint - startpoint < 3:
            break

    obs_x = obs_traj[:, startpoint:endpoint, 0].detach().cpu().numpy()
    obs_y = obs_traj[:, startpoint:endpoint, 1].detach().cpu().numpy()
    gt_x = pred_traj_gt[:, startpoint:endpoint, 0].detach().cpu().numpy()
    gt_y = pred_traj_gt[:, startpoint:endpoint, 1].detach().cpu().numpy()
    pred_x_lstm = pred_traj_fake_lstm[:, startpoint:endpoint, 0].detach().cpu().numpy()
    pred_y_lstm = pred_traj_fake_lstm[:, startpoint:endpoint, 1].detach().cpu().numpy()

    fig = plt.figure()
    fig.set_size_inches(7, 7)

    plt.xlim([-15, 15])
    plt.ylim([-15, 15])

    # Observed points are drawn once. Predicted and ground-truth points are
    # revealed one timestep per animation frame.
    plt.plot(obs_x[:cfg.obs_len].ravel(), obs_y[:cfg.obs_len].ravel(), 'b.', label='observed')
    line_fake, = plt.plot([], [], 'ro', label='predicted')
    line_gt, = plt.plot([], [], 'g*', label='gt')

    plt.title('LSTM Traj')
    plt.legend()

    def update(frame):
        # Show every point up to and including `frame`. Building from slices
        # (rather than appending) keeps the frame correct even if matplotlib
        # draws the same frame more than once.
        line_fake.set_data(pred_x_lstm[:frame + 1].ravel(), pred_y_lstm[:frame + 1].ravel())
        line_gt.set_data(gt_x[:frame + 1].ravel(), gt_y[:frame + 1].ravel())
        return line_fake, line_gt

    anim = FuncAnimation(fig, update, frames=list(range(cfg.pred_len)), interval=1000)
    anim.save(os.path.join(root, "traj_lstm.gif"), fps=1)
    plt.close(fig)

    break
            