Describe your code using comments.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [2]:
root = '/home/vilab/ssd1tb/hj_ME455/Trajectory_Prediction'
import sys
sys.path.append(root)

import argparse
import gc
import logging
import os
import sys
import time
import numpy as np

from collections import defaultdict
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from trajectories import data_loader
from utils import relative_to_abs, get_dset_path

In [3]:
from easydict import EasyDict as edict
cfg = edict()

cfg.dataset_name = 'zara1'
cfg.delim = '\t'
cfg.loader_num_workers = 4

cfg.obs_len = 8
cfg.pred_len = 8
cfg.skip = 1

cfg.batch_size = 16
cfg.num_epochs = 200
cfg.learning_rate = 5e-4

cfg.embedding_dim = 64
cfg.h_dim = 64
cfg.num_layers = 1
cfg.mlp_dim = 1024
cfg.dropout = 0

cfg.best_k = 1

In [4]:
train_path = get_dset_path(cfg.dataset_name, 'train')
val_path = get_dset_path(cfg.dataset_name, 'val')

In [5]:
train_dset, train_loader = data_loader(cfg, train_path)
_, val_loader = data_loader(cfg, val_path)

### Model for trajectory prediction

In [17]:

class Encoder(nn.Module):
    def __init__(
        self, embedding_dim=64, h_dim=64, num_layers=1, type='lstm'):
        super(Encoder, self).__init__()

        self.h_dim = h_dim
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        self.encoder = nn.LSTM(embedding_dim, h_dim, num_layers)

        self.en_mlp = nn.Linear(2, embedding_dim)
        self.de_mlp = nn.Linear(h_dim, 2)

    def make_state(self, batch):
        return (
            torch.zeros(self.num_layers, batch, self.h_dim).cuda(),
            torch.zeros(self.num_layers, batch, self.h_dim).cuda()
        )

    def forward(self, obs_traj):
        batch = obs_traj.size(1)
        encoder_input =  self.en_mlp(obs_traj)
        encoder_input = encoder_input.view(-1, batch, self.embedding_dim)
        
        state = self.make_state(batch)
        encoder_output, state = self.encoder(encoder_input, state)
        
        pred = self.de_mlp(encoder_output.reshape(-1, self.h_dim))
        pred = pred.view(-1, batch, 2)
        
        return encoder_output, state
    

class Decoder(nn.Module):
    def __init__(
        self, seq_len, embedding_dim=64, h_dim=128, num_layers=1, type='lstm'):
        super(Decoder, self).__init__()

        self.seq_len = seq_len
        self.h_dim = h_dim
        self.embedding_dim = embedding_dim
        self.model_type = type
        
        self.decoder = nn.LSTM(embedding_dim, h_dim, num_layers)

        self.en_mlp = nn.Linear(2, embedding_dim)
        self.de_mlp = nn.Linear(h_dim, 2)

    def forward(self, last_pos, last_pos_rel, state):
        batch = last_pos_rel.size(0)
        pred_traj_fake_rel = []
        decoder_input = self.en_mlp(last_pos_rel)
        decoder_input = decoder_input.view(1, batch, self.embedding_dim)

        for _ in range(self.seq_len):
            decoder_output, state = self.decoder(decoder_input, state)
            pos_rel = self.de_mlp(decoder_output.view(-1, self.h_dim))
            curr_pos = last_pos + pos_rel

            decoder_input = self.en_mlp(pos_rel).view(1, batch, self.embedding_dim)
            pred_traj_fake_rel.append(pos_rel)
            last_pos = curr_pos

        pred_traj_fake_rel = torch.stack(pred_traj_fake_rel, dim=0)
        return pred_traj_fake_rel, state

In [18]:
class MyModel(nn.Module):
    def __init__(self, cfg, type):
        super(MyModel, self).__init__()
        self.encoder = Encoder(embedding_dim=cfg.embedding_dim, h_dim=cfg.h_dim, num_layers=cfg.num_layers, type='lstm')
        self.decoder = Decoder(seq_len=cfg.pred_len-1, embedding_dim=cfg.embedding_dim, h_dim=cfg.h_dim, num_layers=cfg.num_layers, type='lstm')

    
    def forward(self, obs_traj, obs_traj_rel):
        last_pos = obs_traj[-1]
        last_pos_rel = obs_traj_rel[-1]

        encoder_output, state = self.encoder(obs_traj)
        pred_traj_fake_rel, state = self.decoder(last_pos, last_pos_rel, state)
        
        pred_traj_rel = torch.unsqueeze(obs_traj_rel[-1], 0) 
        pred_traj_fake_rel = torch.cat([pred_traj_rel, pred_traj_fake_rel], dim=0)
        
        return pred_traj_fake_rel

### Pre-defined functions

In [19]:
def cal_ade(fake, gt):
    ade = gt.permute(1, 0, 2) - fake.permute(1, 0, 2)
    ade = ade**2
    ade = torch.sqrt(ade.sum(dim=2)).sum(dim=1)
    return torch.sum(ade)


def cal_fde(fake, gt):
    fde = gt - fake
    fde = fde**2
    fde = torch.sqrt(fde.sum(dim=1))
    return torch.sum(fde)

In [20]:
def loss_fn(pred_traj_fake_rel, pred_traj_gt_rel, loss_mask):
    fake = loss_mask.unsqueeze(2)*pred_traj_fake_rel.permute(1, 0, 2)
    gt =  loss_mask.unsqueeze(2)*pred_traj_gt_rel.permute(1, 0, 2)
    
    l2_loss = torch.nn.MSELoss()
    
    
    return l2_loss(fake, gt)

### Build the train lodaer and validation loader for training the model.

While validation we will calculate ade and fde using pre-defined functions above.
To do this, we will use "relative_to_abs" function to change realtive position to absolute position.

In [21]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

model_lstm = MyModel(cfg, type='lstm')

model_lstm.to(device)

optimizer = torch.optim.Adam(model_lstm.parameters(), lr=0.001)

os.makedirs(os.path.join(root, 'ckpt_lstm'), exist_ok=True)
save_path = os.path.join(root, 'ckpt_lstm/ckpt_lstm_')

cuda:0


In [22]:
for i in tqdm(range(cfg.num_epochs)):
    train_loss_value = 0
    total_traj = 0
    ade, fde = 0, 0
    
    model_lstm.train()
    
    for batch in train_loader:
        batch = [tensor.cuda() for tensor in batch]
        (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, _, loss_mask, seq_start_end) = batch
        loss_mask = loss_mask[:, cfg.obs_len:]
        
        #############################
        ### fill the train loader ###
        #############################
        optimizer.zero_grad()

        pred_traj_fake_rel = model_lstm(obs_traj, obs_traj_rel)  # Forward pass
        loss = loss_fn(pred_traj_fake_rel, pred_traj_gt_rel, loss_mask)  # Calculate loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update weights

        train_loss_value += loss.item()
        
    train_loss_value /= len(train_loader)
        
    model_lstm.eval()
    
    with torch.no_grad():
        for batch in val_loader:
            batch = [tensor.cuda() for tensor in batch]
            (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, _, loss_mask, seq_start_end) = batch
            loss_mask = loss_mask[:, cfg.obs_len:]
            
            ########################################################################
            #### fill the validation loader ########################################
            # hint: use "relative_to_abs" function to validate your result #########
            ########################################################################
            pred_traj_fake_rel = model_lstm(obs_traj, obs_traj_rel)
            pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])

            ade += cal_ade(pred_traj_fake, pred_traj_gt)
            fde += cal_fde(pred_traj_fake, pred_traj_gt)

            total_traj += pred_traj_gt.size(1)
        

        
    ade = ade / (total_traj * cfg.pred_len)  # Normalize ADE by total trajectories and sequence length
    fde = fde / total_traj  # Normalize FDE by total trajectories
    
    torch.save({
        'epoch': i,
        'model_state_dict': model_lstm.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
        }, f'{save_path}{i:03d}.pt')

    if (i + 1) % 10 == 0:
        e = i + 1
        print(f'epoch: {e}, train_loss: {train_loss_value}, val_ade: {ade}, val_fde: {fde}')

  5%|▌         | 10/200 [00:44<14:00,  4.43s/it]

epoch: 10, train_loss: 0.007599214105673972, val_ade: 0.3013277053833008, val_fde: 0.3802464008331299


 10%|█         | 20/200 [01:29<13:27,  4.49s/it]

epoch: 20, train_loss: 0.007311215764738223, val_ade: 0.29600876569747925, val_fde: 0.38363611698150635


 15%|█▌        | 30/200 [02:14<12:36,  4.45s/it]

epoch: 30, train_loss: 0.007052944387590303, val_ade: 0.290357381105423, val_fde: 0.3792865574359894


 20%|██        | 40/200 [02:59<12:00,  4.50s/it]

epoch: 40, train_loss: 0.006971574380178247, val_ade: 0.3048304617404938, val_fde: 0.3884195387363434


 25%|██▌       | 50/200 [03:43<11:18,  4.52s/it]

epoch: 50, train_loss: 0.006827848668612495, val_ade: 0.2833017110824585, val_fde: 0.3732741177082062


 30%|███       | 60/200 [04:29<10:43,  4.60s/it]

epoch: 60, train_loss: 0.006664948988452378, val_ade: 0.2913244962692261, val_fde: 0.3790520429611206


 35%|███▌      | 70/200 [05:15<09:54,  4.58s/it]

epoch: 70, train_loss: 0.006510966578663806, val_ade: 0.2861684560775757, val_fde: 0.3795682489871979


 40%|████      | 80/200 [05:59<09:00,  4.50s/it]

epoch: 80, train_loss: 0.006375434652753395, val_ade: 0.29341068863868713, val_fde: 0.38223540782928467


 45%|████▌     | 90/200 [06:44<08:15,  4.50s/it]

epoch: 90, train_loss: 0.006287088625948634, val_ade: 0.30363062024116516, val_fde: 0.3876529932022095


 50%|█████     | 100/200 [07:29<07:30,  4.51s/it]

epoch: 100, train_loss: 0.006137465301779865, val_ade: 0.3019447326660156, val_fde: 0.3845038115978241


 55%|█████▌    | 110/200 [08:14<06:43,  4.48s/it]

epoch: 110, train_loss: 0.005936741359124082, val_ade: 0.2914031744003296, val_fde: 0.3811647593975067


 60%|██████    | 120/200 [08:59<06:07,  4.59s/it]

epoch: 120, train_loss: 0.005915057374241613, val_ade: 0.29874396324157715, val_fde: 0.38651323318481445


 65%|██████▌   | 130/200 [09:44<05:08,  4.41s/it]

epoch: 130, train_loss: 0.005823771969398715, val_ade: 0.2954752743244171, val_fde: 0.3857441544532776


 70%|███████   | 140/200 [10:28<04:29,  4.49s/it]

epoch: 140, train_loss: 0.005733573716960451, val_ade: 0.2985149025917053, val_fde: 0.3891157805919647


 75%|███████▌  | 150/200 [11:13<03:45,  4.52s/it]

epoch: 150, train_loss: 0.005609668992437731, val_ade: 0.29998156428337097, val_fde: 0.3895070552825928


 80%|████████  | 160/200 [11:59<03:02,  4.57s/it]

epoch: 160, train_loss: 0.005534043140191386, val_ade: 0.30115070939064026, val_fde: 0.3919069468975067


 85%|████████▌ | 170/200 [12:44<02:14,  4.50s/it]

epoch: 170, train_loss: 0.0053989481905574455, val_ade: 0.29978519678115845, val_fde: 0.3932998478412628


 90%|█████████ | 180/200 [13:29<01:29,  4.48s/it]

epoch: 180, train_loss: 0.00537403701865197, val_ade: 0.31395450234413147, val_fde: 0.39907920360565186


 95%|█████████▌| 190/200 [14:14<00:45,  4.56s/it]

epoch: 190, train_loss: 0.005283455152240907, val_ade: 0.3062886595726013, val_fde: 0.3977021276950836


100%|██████████| 200/200 [14:59<00:00,  4.50s/it]

epoch: 200, train_loss: 0.005184947397317407, val_ade: 0.3047943413257599, val_fde: 0.3929680287837982





### Visualize the prediction results with your trajectory prediction model and its checkpoint

In [29]:
model_lstm = MyModel(cfg, type='lstm')  # Ensure cfg is properly defined
checkpoint = torch.load(os.path.join(root, 'ckpt_lstm/ckpt_lstm_199.pt'))  # Load the last checkpoint
model_lstm.load_state_dict(checkpoint['model_state_dict'])
model_lstm.eval().to(device)

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import random

r = len(val_loader)
rr = random.randrange(0,r)

batch_cnt = 0   
for batch in val_loader:
    if batch_cnt == rr:
        batch = [tensor.cuda() for tensor in batch]
        (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, _, loss_mask, seq_start_end) = batch
        
        loss_mask = loss_mask[:, cfg.obs_len:]
        
        ################## fill the validation loader ####################
        ## hint: use "relative_to_abs" function to validate your result ##
        pred_traj_fake_rel = model_lstm(obs_traj, obs_traj_rel)
        pred_traj_fake_lstm = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
        ##################################################################
        
        for i in range(seq_start_end.size(0)):
            startpoint = seq_start_end[i][0]
            endpoint = seq_start_end[i][1]
            ped_num = endpoint - startpoint
            if endpoint - startpoint < 3:
                break
            
        obs_x, obs_y = obs_traj[:,startpoint:endpoint,0].detach().cpu().numpy(), obs_traj[:,startpoint:endpoint,1].detach().cpu().numpy()
        gt_x, gt_y = pred_traj_gt[:,startpoint:endpoint,0].detach().cpu().numpy(), pred_traj_gt[:,startpoint:endpoint,1].detach().cpu().numpy()
        pred_x_lstm, pred_y_lstm = pred_traj_fake_lstm[:,startpoint:endpoint,0].detach().cpu().numpy(), pred_traj_fake_lstm[:,startpoint:endpoint,1].detach().cpu().numpy()

        fig = plt.figure()
        fig.set_size_inches(7, 7)
        
        plt.xlim([-15, 15])      
        plt.ylim([-15, 15])
        
        list_x, list_y = [], []

        for i in range(8):
            for j in range(ped_num):
                list_x.append(obs_x[i][j])
                list_y.append(obs_y[i][j])

        line, = plt.plot(list_x, list_y, 'b.', label='obersed')
        
        list_x_fake, list_y_fake = [], []
        line_fake, = plt.plot([], [], 'ro', label='predicted')
        
        list_x_gt, list_y_gt = [], []
        line_gt, = plt.plot([], [], 'g*', label = 'gt')
        
        plt.title('LSTM Traj')
        plt.legend()
        
        def update(i):
            for j in range(ped_num):
                list_x_fake.append(pred_x_lstm[i][j])
                list_y_fake.append(pred_y_lstm[i][j])
                list_x_gt.append(gt_x[i][j])
                list_y_gt.append(gt_y[i][j])
            line_fake.set_data(list_x_fake, list_y_fake)
            line_gt.set_data(list_x_gt, list_y_gt)
            return line_fake, line_gt,
        

        anim = FuncAnimation(fig, update, frames=[0,1,2,3,4,5,6,7],interval=1000)
        anim.save(os.path.join(root, "traj_lstm.gif"), fps=1)
        plt.close()

        break
    batch_cnt += 1
            

MovieWriter ffmpeg unavailable; using Pillow instead.
