In [1]:
from os.path import isdir, join
import pickle
import numpy as np
import matplotlib.pyplot as plt
# import rosbag
import glob
from sklearn.model_selection import train_test_split, KFold
import os
import pdb
import math
import torch
from datetime import datetime
from torch import nn
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
from os.path import join
from os import makedirs
import time
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
dataset_filepath = 'short_prediction_data.pt'
dataset = torch.load(dataset_filepath)

xobs_train, xpred_train, yintention_train, xobs_test, xpred_test, yintention_test = \
dataset["xobs_train"], dataset["xpred_train"], dataset["yintention_train"], dataset["xobs_test"], \
dataset["xpred_test"], dataset["yintention_test"]
obs_seq_len, pred_seq_len = dataset["obs_seq_len"], dataset["pred_seq_len"]


In [3]:
class Arguments:
    def __init__(
        self,
        batch_size=32,
        obs_seq_len=4,
        pred_seq_len=6,
        embedding_size=32,
        hidden_size=32,
        num_layers=1,
        dropout=0.,
        lr=1e-3,
        num_epochs=100,
        clip_grad=10.,
        device="cuda:0",
        checkpoint_dir="checkpoints",
        ):
        self.batch_size = batch_size
        self.obs_seq_len = obs_seq_len
        self.pred_seq_len = pred_seq_len
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.lr = lr
        self.num_epochs = num_epochs
        self.clip_grad = clip_grad
        self.device = device
        self.checkpoint_dir=checkpoint_dir

args = Arguments(
    obs_seq_len=obs_seq_len,
    pred_seq_len=pred_seq_len,
    lr=5e-3,
    device=device,
)

In [4]:
class TrajectoriesDataset(Dataset):
    def __init__(
        self,
        xobs,
        xpred,
        yintention,
        obs_seq_len=4,
        pred_seq_len=6,
        ):
        super(TrajectoriesDataset, self).__init__()
        assert xobs.shape[0]==xpred.shape[0]==yintention.shape[0]
        assert xobs.shape[1]==obs_seq_len and xpred.shape[1]==pred_seq_len
        self.obs_seq_len = obs_seq_len
        self.pred_seq_len = pred_seq_len
        self.seq_len = self.obs_seq_len + self.pred_seq_len
        self.xobs = xobs
        self.xpred = xpred
        self.yintention = yintention
        self.num_seq = self.xobs.shape[0]

    def __len__(self):
        return self.num_seq


    def __getitem__(self, index):
        out = [
            self.xobs[index],
            self.xpred[index],
            self.yintention[index],
        ]
        return out

In [5]:
kf = KFold(n_splits=5, random_state=0, shuffle=True)
for train_index, validation_index in kf.split(xobs_train):
    xobs_train_k, xobs_val_k = xobs_train[train_index], xobs_train[validation_index]
    xpred_train_k, xpred_val_k = xpred_train[train_index], xpred_train[validation_index] 
    yintention_train_k, yintention_val_k = yintention_train[train_index], yintention_train[validation_index]
    dataset_train = TrajectoriesDataset(
        xobs_train_k,
        xpred_train_k,
        yintention_train_k,
        obs_seq_len=obs_seq_len,
        pred_seq_len=pred_seq_len,
    )
    dataset_val = TrajectoriesDataset(
        xobs_val_k,
        xpred_val_k,
        yintention_val_k,
        obs_seq_len=obs_seq_len,
        pred_seq_len=pred_seq_len,
    )
    loader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,  
    )
    loader_val = DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,  
    )
    break

In [6]:
class IntentionLstm(nn.Module):
    def __init__(
        self,
        embedding_size=32,
        hidden_size=32,
        num_layers=1,
        dropout=0.,
        bidirectional=False,
        obs_seq_len=4,
        pred_seq_len=6,
    ):
        super(IntentionLstm, self).__init__()
        self.lstm = nn.LSTM(
            input_size=2*embedding_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
            )
        self.spatial_embedding = nn.Linear(2, embedding_size)
        self.intention_embedding = nn.Linear(1, embedding_size)
        if bidirectional:
            self.directions = 2
        else:
            self.directions = 1
        self.hidden_to_pos = nn.Linear(self.directions*num_layers*hidden_size, 2)
        self.embedding_size = embedding_size
        self.obs_seq_len = obs_seq_len
        self.pred_seq_len = pred_seq_len
    
    def forward(self, b_xobs, b_yintention, device="cuda:0"):
        
        """
        Forward function.
        inputs:
            - b_xobs: batch of observation. (batch, obs_seq_len, 2)
            - b_yintention: batch of intention. (batch, 1)
        outputs:
            - b_xpred: (batch, pred_seq_len, 2)
            
        """
        batch_size, _, _ = b_xobs.shape
        b_xobs = self.spatial_embedding(b_xobs) # (batch, obs_seq_len, embedding_size)
        b_yintention = self.intention_embedding(b_yintention.unsqueeze(-1)) # (batch, 1, embedding_size)
        b_yintention_obs = b_yintention*torch.ones(batch_size, self.obs_seq_len, self.embedding_size).to(device)
        b_obs = torch.cat((b_xobs, b_yintention_obs),dim=2) # (batch, obs_seq_len, 2*embedding_size)
        _, (ht, ct) = self.lstm(b_obs) # (D∗num_layers, batch, hidden_size)
        b_xpred = []
        b_xpred_tt = self.hidden_to_pos(ht.permute(1,0,2).reshape(batch_size,1,-1)) # (batch, 1, 2)
        b_xpred.append(b_xpred_tt)
        for tt in range(1, self.pred_seq_len):
            b_pred_tt = torch.cat((self.spatial_embedding(b_xpred_tt), b_yintention), dim=2)
            _, (ht, ct) = self.lstm(b_pred_tt, (ht, ct))
            b_xpred_tt = self.hidden_to_pos(ht.permute(1,0,2).reshape(batch_size,1,-1)) # (batch, 1, 2)
            b_xpred.append(b_xpred_tt)
        b_xpred = torch.cat(b_xpred, dim=1)
        return b_xpred


In [7]:
def train(args, loader_train, loader_val):
    model = IntentionLstm(
        embedding_size=args.embedding_size,
        hidden_size=args.hidden_size,
        num_layers=args.num_layers,
        dropout=args.dropout,
        bidirectional=False,
        obs_seq_len=args.obs_seq_len,
        pred_seq_len=args.pred_seq_len,
    ).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    now = datetime.now()
    curr_checkpoint_dir = join(args.checkpoint_dir, now.strftime("%y%m%d_%H%M%S"))
    if not isdir(curr_checkpoint_dir):
        makedirs(curr_checkpoint_dir)
    with open(join(curr_checkpoint_dir, 'args.pickle'), 'wb') as f:
        pickle.dump(args, f)
    print('EPOCHS: ', args.num_epochs)
    train_loss_task, train_aoe_task, train_foe_task = [], [], []
    val_loss_task, val_aoe_task, val_foe_task = [], [], []
    
    for epoch in range(1, args.num_epochs+1):
        epoch_start_time = time.time()
        model.train()
        train_loss_epoch, train_aoe_epoch, train_foe_epoch = [], [], []
        for batch_idx, batch in enumerate(loader_train):
            batch = [b.to(args.device) for b in batch]
            b_xobs, b_xpred_gt, b_yintention = batch
            optimizer.zero_grad()
            b_xpred = model(b_xobs, b_yintention, device=args.device) # (batch, pred_seq_len, 2)
            loss = ((b_xpred-b_xpred_gt)**2.).mean()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            optimizer.step()
            offset_error = (((b_xpred-b_xpred_gt)**2.).sum(2))**0.5 # (batch, pred_seq_len)
            aoe, foe = offset_error.mean(-1).mean(), offset_error[-1].mean()
            train_loss_epoch.append(loss.detach().to('cpu').item())
            train_aoe_epoch.append(aoe.detach().to('cpu').item())
            train_foe_epoch.append(foe.detach().to('cpu').item())
        train_loss_epoch = sum(train_loss_epoch)/len(train_loss_epoch)
        train_aoe_epoch = sum(train_aoe_epoch)/len(train_aoe_epoch)
        train_foe_epoch = sum(train_foe_epoch)/len(train_foe_epoch)
        val_loss_epoch, val_aoe_epoch, val_foe_epoch = inference(args, loader_val, model)
        if epoch % 10 == 0:
#             print('Epoch: {0} | train loss: {1:.2f} | val loss: {2:.2f} | train aoe: {3:.2f} | val aoe: {4:.2f} | train foe: {5:.2f} | val foe: {6:.2f} | period: {7:.2f} sec'\
#                 .format(epoch, train_loss_epoch, val_loss_epoch,\
#                 train_aoe_epoch, val_aoe_epoch,\
#                 train_foe_epoch, val_foe_epoch,\
#                 time.time()-epoch_start_time)) 
            print('Epoch: {0} | train aoe: {1:.2f} | val aoe: {2:.2f} | train foe: {3:.2f} | val foe: {4:.2f} | period: {5:.2f} sec'\
                .format(epoch, \
                train_aoe_epoch, val_aoe_epoch,\
                train_foe_epoch, val_foe_epoch,\
                time.time()-epoch_start_time)) 
            model_filename = join(curr_checkpoint_dir, 'epoch_'+str(epoch)+'.pt')
            torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_loss_epoch,
                    'val_loss': val_loss_epoch,
                    'train_aoe': train_aoe_epoch,
                    'val_aoe': val_aoe_epoch, 
                    'train_foe': train_foe_epoch,
                    'val_foe': val_foe_epoch,   
                    }, model_filename)
            # print('epoch_'+str(epoch)+'.pt is saved.')
        train_loss_task.append(train_loss_epoch)
        train_aoe_task.append(train_aoe_epoch)
        train_foe_task.append(train_foe_epoch)
        val_loss_task.append(val_loss_epoch)
        val_aoe_task.append(val_aoe_epoch)
        val_foe_task.append(val_foe_epoch)
    hist = {}
    hist['train_loss'], hist['val_loss'] = train_loss_task, val_loss_task
    hist['train_aoe'], hist['val_aoe'] = train_aoe_task, val_aoe_task
    hist['train_foe'], hist['val_foe'] = train_foe_task, val_foe_task
    with open(join(curr_checkpoint_dir, 'train_hist.pickle'), 'wb') as f:
        pickle.dump(hist, f)
        print(join(curr_checkpoint_dir, 'train_hist.pickle')+' is saved.')
    return


def inference(args, loader, model):
    with torch.no_grad():
        model.eval()
        loss_epoch, aoe_epoch, foe_epoch = [], [], []
        for batch_idx, batch in enumerate(loader):
            batch = [b.to(args.device) for b in batch]
            b_xobs, b_xpred_gt, b_yintention = batch
            b_xpred = model(b_xobs, b_yintention, device=args.device) # (batch, pred_seq_len, 2)
            loss = ((b_xpred-b_xpred_gt)**2.).mean()
            offset_error = (((b_xpred-b_xpred_gt)**2.).sum(2))**0.5 # (batch, pred_seq_len)
            aoe, foe = offset_error.mean(-1).mean(), offset_error[-1].mean()
            loss_epoch.append(loss.detach().to('cpu').item())
            aoe_epoch.append(aoe.detach().to('cpu').item())
            foe_epoch.append(foe.detach().to('cpu').item())
        loss_epoch = sum(loss_epoch)/len(loss_epoch)
        aoe_epoch = sum(aoe_epoch)/len(aoe_epoch)
        foe_epoch = sum(foe_epoch)/len(foe_epoch)
    return loss_epoch, aoe_epoch, foe_epoch

In [8]:
train(args, loader_train, loader_val)

EPOCHS:  100
Epoch: 10 | train aoe: 297.71 | val aoe: 294.10 | train foe: 305.80 | val foe: 300.37 | period: 0.32 sec
Epoch: 20 | train aoe: 252.74 | val aoe: 249.53 | train foe: 254.81 | val foe: 256.25 | period: 0.32 sec
Epoch: 30 | train aoe: 209.17 | val aoe: 206.42 | train foe: 214.65 | val foe: 213.68 | period: 0.32 sec
Epoch: 40 | train aoe: 167.14 | val aoe: 165.08 | train foe: 168.63 | val foe: 173.00 | period: 0.32 sec
Epoch: 50 | train aoe: 127.65 | val aoe: 126.36 | train foe: 120.79 | val foe: 135.07 | period: 0.32 sec
Epoch: 60 | train aoe: 92.19 | val aoe: 92.32 | train foe: 88.53 | val foe: 102.04 | period: 0.31 sec
Epoch: 70 | train aoe: 66.50 | val aoe: 68.79 | train foe: 61.66 | val foe: 79.27 | period: 0.31 sec
Epoch: 80 | train aoe: 58.05 | val aoe: 60.65 | train foe: 66.23 | val foe: 64.72 | period: 0.31 sec
Epoch: 90 | train aoe: 58.01 | val aoe: 60.39 | train foe: 70.35 | val foe: 61.76 | period: 0.32 sec
Epoch: 100 | train aoe: 58.03 | val aoe: 60.40 | train fo