## Install the package dependencies before running this notebook

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob

"""
    number of trajectories in each city
    # austin --  train: 43041 test: 6325 
    # miami -- train: 55029 test:7971
    # pittsburgh -- train: 43544 test: 6361
    # dearborn -- train: 24465 test: 3671
    # washington-dc -- train: 25744 test: 3829
    # palo-alto -- train:  11993 test:1686

    trajectories sampled at 10HZ rate, input 5 seconds, output 6 seconds
    
"""

'\n    number of trajectories in each city\n    # austin --  train: 43041 test: 6325 \n    # miami -- train: 55029 test:7971\n    # pittsburgh -- train: 43544 test: 6361\n    # dearborn -- train: 24465 test: 3671\n    # washington-dc -- train: 25744 test: 3829\n    # palo-alto -- train:  11993 test:1686\n\n    trajectories sampled at 10HZ rate, input 5 seconds, output 6 seconds\n    \n'

## Create a `torch.utils.data.Dataset` class for the Argoverse trajectories

In [2]:
from glob import glob
import pickle
import numpy as np

# Directory that holds the per-split sub-folders and the submission CSV.
ROOT_PATH = "./Data/"

# City names; each one is the prefix of a "<city>_inputs" / "<city>_outputs" file.
cities = [
    "austin",
    "miami",
    "pittsburgh",
    "dearborn",
    "washington-dc",
    "palo-alto",
]

# Names of the split sub-folders below ROOT_PATH.
splits = [
    "train",
    "test",
]

def _load_pickled_array(path):
    """Read one pickle file and return its content as a numpy array."""
    # NOTE(review): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(path, "rb") as handle:  # context manager closes the file handle
        return np.asarray(pickle.load(handle))


def get_city_trajectories(city="palo-alto", split="train", normalized=False):
    """Load the pickled trajectories of one city for the requested split.

    Args:
        city: City name; files are looked up as ``<city>_inputs`` and
            ``<city>_outputs``.
        split: ``"train"`` loads the complete training files. ``"val"`` loads
            the last 20% of the training files. Any other value is used as the
            sub-folder name below ``ROOT_PATH`` and only the inputs are loaded.
        normalized: Accepted for interface compatibility; currently unused.

    Returns:
        Tuple ``(inputs, outputs)`` of numpy arrays. ``outputs`` is ``None``
        for every split other than ``"train"`` and ``"val"``.

    Raises:
        OSError: If one of the expected pickle files cannot be opened.
    """
    # The validation split has no folder of its own; it is carved out of the
    # training files.
    folder = "train" if split == "val" else split
    prefix = ROOT_PATH + folder + "/" + city

    inputs = _load_pickled_array(prefix + "_inputs")

    outputs = None
    if split in ("train", "val"):
        outputs = _load_pickled_array(prefix + "_outputs")

    if split == "val":
        # NOTE(review): "train" returns ALL samples, so it overlaps with this
        # 20% tail; validation numbers are therefore not on held-out data.
        cut = int(len(inputs) * 0.8)
        inputs, outputs = inputs[cut:], outputs[cut:]

    return inputs, outputs

class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse.

    Wraps the arrays returned by ``get_city_trajectories`` for one city and
    one split so they can be indexed item by item.
    """

    def __init__(self, city: str, split: str, transform=None):
        """Load the trajectories for ``city`` / ``split``.

        Args:
            city: City name forwarded to ``get_city_trajectories``.
            split: Split name forwarded to ``get_city_trajectories``.
            transform: Optional callable applied to every item before it is
                returned by ``__getitem__``.
        """
        super(ArgoverseDataset, self).__init__()
        self.transform = transform

        self.inputs, self.outputs = get_city_trajectories(city=city, split=split, normalized=False)

    def __len__(self):
        """Return the number of input trajectories."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the item at ``idx``.

        Returns:
            ``(input, output)`` when targets are available. For splits without
            targets (``self.outputs is None``) only the input is returned;
            indexing ``None`` previously raised ``TypeError`` here.
        """
        if self.outputs is None:
            data = self.inputs[idx]
        else:
            data = (self.inputs[idx], self.outputs[idx])

        if self.transform:
            data = self.transform(data)

        return data

## Create a DataLoader class for training

In [3]:
from torch import nn, optim

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# class Pred(nn.Module):

#     def __init__(self):
#         super().__init__()
        
#         self.encoder = nn.Sequential(
#             nn.Linear(100, 64),
#             nn.ReLU(),
#             nn.Linear(64, 64),
#             nn.ReLU(),
#             nn.Linear(64, 32),
#             nn.ReLU(),
#             nn.Linear(32, 32)
#         )
        
#         self.decoder = nn.Sequential(
#             nn.Linear(32, 64),
#             nn.ReLU(),
#             nn.Linear(64, 64),
#             nn.ReLU(),
#             nn.Linear(64, 120),
#             nn.ReLU(),
#             nn.Linear(120, 120)
#         )
        
#     def forward(self, x):
#         x = x.reshape(-1, 100).float()
#         x = self.encoder(x)
#         x = self.decoder(x)
#         x = x.reshape(-1, 60, 2)
#         return x

In [4]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# pred = Pred()
# pred.to(device)
# opt = optim.Adam(pred.parameters(), lr=1e-3)

## Training

In [6]:
import time

# NOTE(review): `transformer_model` and `opt` are not defined in any visible
# cell of this notebook; they must already exist in the kernel for this cell
# to run.

# Single wall-clock reference for the whole run: the elapsed time printed
# below is cumulative across cities and epochs.
start = time.time()

batch_sz = 4  # batch size
for city in cities:
    print(city)
    train_dataset = ArgoverseDataset(city=city, split='train')
    train_loader = DataLoader(train_dataset, batch_size=batch_sz)

    for epoch in range(30):

        # Sum of the per-batch squared errors over the epoch.
        total_loss = 0
        for i_batch, sample_batch in enumerate(train_loader):
            i, o = sample_batch
            inp, out = i.to(device), o.to(device)

            # The model receives both the observed and the target trajectory.
            preds = transformer_model(inp.float(), out.float())
            loss = (preds - out).pow(2).sum()

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.item()

        end = time.time()
        print(f'Time elapsed: {(end - start) / 60}')  # minutes since `start`
        print(f'epoch {epoch} loss: {total_loss / len(train_dataset)}')

austin
Time elapsed: 7.593198478221893
epoch 0 loss: 196982018.46488315
Time elapsed: 15.17181428273519
epoch 1 loss: 181544203.95763022
Time elapsed: 22.921051637331644
epoch 2 loss: 181507400.42740014
Time elapsed: 30.69908484617869
epoch 3 loss: 181507017.24751797
Time elapsed: 38.16384460528692
epoch 4 loss: 181507003.49082237
miami
Time elapsed: 48.55288451910019
epoch 0 loss: 1026971579.2091588
Time elapsed: 58.03132556676864
epoch 1 loss: 933328099.5034621
Time elapsed: 68.36286822954814
epoch 2 loss: 930720642.8687719
Time elapsed: 78.33688868284226
epoch 3 loss: 930577605.0854676
Time elapsed: 87.99915175040563
epoch 4 loss: 930565862.5793945
pittsburgh
Time elapsed: 95.661305920283
epoch 0 loss: 249492148.59736192
Time elapsed: 103.45517633358638
epoch 1 loss: 225963374.64082474
Time elapsed: 111.3959932923317
epoch 2 loss: 225698929.6957579
Time elapsed: 119.24318982362747
epoch 3 loss: 225696615.12162554
Time elapsed: 126.896784055233
epoch 4 loss: 225696543.21182418
dearbo

## Validation

In [7]:
# val_dataset = ArgoverseDataset(city = 'austin', split = 'val')
# val_loader = DataLoader(val_dataset, batch_size=batch_sz)

# val_loss = 0
# for i_batch, sample_batch in enumerate(val_loader):
#     i, o = sample_batch
#     inp, out = i.to(device), o.to(device)
#     preds = pred(inp)
#     loss = ((preds - out) ** 2).sum()

#     val_loss += loss.item()
# print('loss: {}'.format(val_loss / len(val_dataset)))

## Testing

In [8]:
import pandas as pd

# Load the submission template and make integer columns float so predicted
# coordinates can be written into them.
df = pd.read_csv(ROOT_PATH + 'submission.csv')
int_col = df.select_dtypes(include=['int'])
for col in int_col.columns.values:
    df[col] = df[col].astype('float32')
row = 0  # next row of the submission frame to fill

# NOTE(review): `pred` is not defined in any uncommented cell visible here
# (the training cell uses `transformer_model`); confirm which model is meant.

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    for city in cities:
        print(city)
        test_dataset = ArgoverseDataset(city = city, split = 'test')

        for i in range(len(test_dataset.inputs)):
            data = torch.from_numpy(test_dataset.inputs[i]).to(device)
            preds = pred(data)
            # Columns 1..120 receive the flattened prediction
            # (presumably 60 steps x 2 coordinates - TODO confirm).
            df.iloc[row, 1:121] = preds.cpu().detach().numpy().ravel()
            row += 1

# index=False: otherwise pandas writes an extra unnamed index column, which
# shifts the column positions when this same file is read back as the template.
df.to_csv(ROOT_PATH + 'submission.csv', index=False)

## Sample a batch of data and visualize 

In [9]:
import matplotlib.pyplot as plt
import random

def show_sample_batch(sample_batch, version='gt'):
    """Visualize the trajectories for a batch of samples.

    Draws one subplot per sample and scatters the input and the output
    trajectory into it.

    Args:
        sample_batch: Pair ``(inp, out)``; both are indexed as ``[i, :, 0]`` /
            ``[i, :, 1]`` and ``inp.size(0)`` gives the number of samples.
        version: ``'gt'`` titles the figure "Ground Truth", ``'pred'`` titles
            it "Prediction"; any other value leaves the figure untitled.
    """
    inp, out = sample_batch
    batch_sz = inp.size(0)

    # squeeze=False keeps `axs` a 2-D array even for a single sample, so the
    # ravel() below always works.
    fig, axs = plt.subplots(1, batch_sz, figsize=(15, 3), facecolor='w',
                            edgecolor='k', squeeze=False)
    fig.subplots_adjust(hspace = .5, wspace=.001)

    # The title belongs to the figure: `axs` is an array of Axes and has no
    # set_title method of its own.
    if version == 'gt':
        fig.suptitle('Ground Truth')
    elif version == 'pred':
        fig.suptitle('Prediction')

    axs = axs.ravel()
    for i in range(batch_sz):
        axs[i].xaxis.set_ticks([])
        axs[i].yaxis.set_ticks([])

        # first two feature dimensions are (x,y) positions
        axs[i].scatter(inp[i,:,0], inp[i,:,1])
        axs[i].scatter(out[i,:,0], out[i,:,1])

        
# Plot every 3000th batch of `train_loader`. The batches carry the ground-truth
# outputs, so the figures are labelled 'gt'; `version` is a required argument
# of show_sample_batch and was previously missing from this call.
for i_batch, sample_batch in enumerate(train_loader):
    if i_batch % 3000 == 0:
        show_sample_batch(sample_batch, 'gt')

In [14]:
# torch.save(transformer_model, './Data/saved_models/transformer_model')