In [None]:
# UTILS
from __future__ import print_function, division
from torch.utils.data import Dataset, DataLoader
import scipy.io as scp
import numpy as np
import h5py
import torch

# Dataset class for the NGSIM dataset
class ngsimDataset(Dataset):
    """Map-style dataset over pre-processed NGSIM trajectories stored in a .mat file.

    The .mat file must contain two variables:

    * ``traj``   -> ``self.D``: one row per sample. The columns read by this
      class are 0 (dataset id, 1-based), 1 (vehicle id, 1-based), 2 (time
      stamp), 6 (1-based lateral class, one of 3), 7 (1-based longitudinal
      class, one of 2) and 8: (vehicle ids occupying the neighbor grid cells,
      0 meaning "empty cell").
    * ``tracks`` -> ``self.T``: indexed as ``T[dsId-1][vehId-1]``; after a
      transpose each row is read as ``[t, p0, p1, ...]`` where columns 1:3 are
      used as the 2-D position.

    Args:
        mat_file: path of the .mat file to load.
        t_h: number of track rows to look back for the history window.
        t_f: number of track rows to look ahead for the future window.
        d_s: down-sampling stride applied to both windows.
        enc_size: size of the last axis of the neighbor mask built in
            ``collate_fn``.
        grid_size: ``(columns, rows)`` of the neighbor grid.
    """

    def __init__(self, mat_file, t_h=30, t_f=10, d_s=2, enc_size = 64, grid_size = (13,3)):
        # Parse the file once and pull both variables from the same dict;
        # calling loadmat twice reads and decodes the whole file twice.
        mat = scp.loadmat(mat_file)
        self.D = mat['traj']
        self.T = mat['tracks']
        self.t_h = t_h
        self.t_f = t_f
        self.d_s = d_s
        self.enc_size = enc_size
        self.grid_size = grid_size

    def __len__(self):
        """Return the number of samples (rows of ``self.D``)."""
        return len(self.D)

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            ``(hist, fut, neighbors, lat_enc, lon_enc, vehId, t, dsId)`` where
            ``hist``/``fut`` are ``(n, 2)`` arrays relative to the vehicle's
            position at ``t``, ``neighbors`` is a list with one history array
            per grid cell (empty ``(0, 2)`` array for unusable cells) and
            ``lat_enc``/``lon_enc`` are one-hot vectors of length 3 and 2.
        """
        dsId = self.D[idx, 0].astype(int)
        vehId = self.D[idx, 1].astype(int)
        t = self.D[idx, 2]
        grid = self.D[idx,8:]

        hist = self.getHistory(vehId,t,vehId,dsId)
        fut = self.getFuture(vehId,t,dsId)
        # Neighbor histories are expressed relative to the ego vehicle.
        neighbors = [self.getHistory(nbrId.astype(int), t, vehId, dsId) for nbrId in grid]

        # Class columns are 1-based, hence the "- 1".
        lon_enc = np.zeros([2])
        lon_enc[int(self.D[idx, 7] - 1)] = 1
        lat_enc = np.zeros([3])
        lat_enc[int(self.D[idx, 6] - 1)] = 1
        return hist,fut,neighbors,lat_enc,lon_enc, vehId, t, dsId

    def getHistory(self,vehId,t,refVehId,dsId):
        """Return the down-sampled past track of ``vehId`` relative to ``refVehId`` at ``t``.

        An empty ``(0, 2)`` array is returned when ``vehId`` is 0, is beyond
        the number of tracks, has no row at time ``t``, or does not have a
        full ``t_h // d_s + 1`` samples of history.
        """
        if vehId == 0:
            return np.empty([0,2])
        if self.T.shape[1] <= vehId-1:
            return np.empty([0,2])

        refTrack = self.T[dsId-1][refVehId-1].transpose()
        vehTrack = self.T[dsId-1][vehId-1].transpose()
        refPos = refTrack[np.where(refTrack[:,0]==t)][0,1:3]

        # Checked separately so an empty track is never column-indexed.
        if vehTrack.size == 0:
            return np.empty([0,2])
        matches = np.argwhere(vehTrack[:, 0] == t)
        if matches.size == 0:
            return np.empty([0,2])

        cur = matches.item()  # single scan instead of one per use
        stpt = np.maximum(0, cur - self.t_h)
        enpt = cur + 1
        hist = vehTrack[stpt:enpt:self.d_s,1:3]-refPos

        # Drop histories truncated by the start of the track.
        if len(hist) < self.t_h//self.d_s + 1:
            return np.empty([0,2])
        return hist

    def getFuture(self, vehId, t,dsId):
        """Return the down-sampled future track of ``vehId`` relative to its position at ``t``.

        The window is clipped at the end of the track, so the result may hold
        fewer than ``t_f // d_s`` rows.
        """
        vehTrack = self.T[dsId-1][vehId-1].transpose()
        refPos = vehTrack[np.where(vehTrack[:, 0] == t)][0, 1:3]
        cur = np.argwhere(vehTrack[:, 0] == t).item()
        stpt = cur + self.d_s
        enpt = np.minimum(len(vehTrack), cur + self.t_f + 1)
        fut = vehTrack[stpt:enpt:self.d_s,1:3]-refPos
        return fut

    def collate_fn(self, samples):
        """Collate ``__getitem__`` samples into padded batch tensors.

        Returns:
            ``(hist_batch, nbrs_batch, mask_batch, lat_enc_batch,
            lon_enc_batch, fut_batch, op_mask_batch, veh_ID, time, dsID)``.
            Sequence tensors are laid out as ``(step, index, 2)``;
            ``mask_batch`` is a byte tensor of shape
            ``(batch, grid rows, grid cols, enc_size)`` set to 1 for occupied
            cells; ``op_mask_batch`` is 1 where a future sample exists.

            When no sample in the batch has any usable neighbor the sentinel
            ``([-1], -1, -1, -1, -1, -1, -1, -1, -1, -1)`` is returned
            instead; callers detect it with ``isinstance(hist, list)``.
        """
        # Total number of non-empty neighbor histories across the batch.
        nbr_batch_size = 0
        for _,_,nbrs,_,_,_,_,_ in samples:
            nbr_batch_size += sum(len(nbr) != 0 for nbr in nbrs)

        if nbr_batch_size == 0:
            return [-1], -1, -1, -1, -1, -1, -1, -1, -1, -1

        maxlen = self.t_h//self.d_s + 1
        n_samples = len(samples)
        n_fut = self.t_f//self.d_s
        n_cols, n_rows = self.grid_size[0], self.grid_size[1]

        nbrs_batch = torch.zeros(maxlen,nbr_batch_size,2)
        mask_batch = torch.zeros(n_samples, n_rows, n_cols, self.enc_size).byte()
        hist_batch = torch.zeros(maxlen,n_samples,2)
        fut_batch = torch.zeros(n_fut,n_samples,2)
        op_mask_batch = torch.zeros(n_fut,n_samples,2)
        lat_enc_batch = torch.zeros(n_samples,3)
        lon_enc_batch = torch.zeros(n_samples, 2)

        count = 0  # running column of nbrs_batch
        veh_ID = []
        times = []
        dsID = []
        for sampleId,(hist, fut, nbrs, lat_enc, lon_enc, vehId, t, ds) in enumerate(samples):
            hist_batch[0:len(hist),sampleId,0] = torch.from_numpy(hist[:, 0])
            hist_batch[0:len(hist), sampleId, 1] = torch.from_numpy(hist[:, 1])
            fut_batch[0:len(fut), sampleId, 0] = torch.from_numpy(fut[:, 0])
            fut_batch[0:len(fut), sampleId, 1] = torch.from_numpy(fut[:, 1])
            op_mask_batch[0:len(fut),sampleId,:] = 1
            lat_enc_batch[sampleId,:] = torch.from_numpy(lat_enc)
            lon_enc_batch[sampleId, :] = torch.from_numpy(lon_enc)
            veh_ID.append(vehId)
            times.append(t)
            dsID.append(ds)
            for cell, nbr in enumerate(nbrs):
                if len(nbr) == 0:
                    continue
                nbrs_batch[0:len(nbr),count,0] = torch.from_numpy(nbr[:, 0])
                nbrs_batch[0:len(nbr), count, 1] = torch.from_numpy(nbr[:, 1])
                # Flat cell index -> (row, col) of the neighbor grid.
                col = cell % n_cols
                row = cell // n_cols
                mask_batch[sampleId,row,col,:] = 1
                count += 1
        return hist_batch, nbrs_batch, mask_batch, lat_enc_batch, lon_enc_batch, fut_batch, op_mask_batch, veh_ID, times, dsID

def outputActivation(x):
    """Return a new tensor holding only the first two channels of ``x``.

    ``x`` is indexed on three axes; channels 0 and 1 of the last axis are
    taken as the two predicted coordinates and re-joined unchanged (no
    non-linearity is applied).
    """
    mean_x, mean_y = x[:, :, 0:1], x[:, :, 1:2]
    return torch.cat((mean_x, mean_y), dim=2)

def maskedMSE(y_pred, y_gt, mask):
    """Masked mean of the squared 2-D distance between prediction and target.

    Args:
        y_pred: predicted coordinates, indexed ``[:, :, 0]`` and ``[:, :, 1]``.
        y_gt: ground-truth coordinates with the same layout.
        mask: weights with at least two channels on the last axis; only the
            first two channels receive the error. Non-float masks (byte/bool)
            are converted to the prediction dtype so errors are not truncated.

    Returns:
        Scalar tensor ``sum(error * mask) / sum(mask)``. When the mask sums to
        zero the result is 0 (with zero gradient) rather than NaN, so a fully
        masked batch cannot poison the optimizer state.
    """
    dx = y_gt[:, :, 0] - y_pred[:, :, 0]
    dy = y_gt[:, :, 1] - y_pred[:, :, 1]
    sq_dist = torch.pow(dx, 2) + torch.pow(dy, 2)

    # Float masks are used as-is (unchanged dtype and result); integer masks
    # would otherwise make the accumulator integer and truncate the errors.
    weights = mask if mask.is_floating_point() else mask.to(sq_dist.dtype)

    acc = torch.zeros_like(weights)
    acc[:, :, 0] = sq_dist
    acc[:, :, 1] = sq_dist
    acc = acc * weights

    # With an all-zero mask the numerator is exactly 0, so dividing by the
    # smallest positive normal float yields 0 instead of 0/0 = NaN; any
    # non-empty mask is far above this floor and is left untouched.
    denom = torch.sum(weights).clamp_min(torch.finfo(weights.dtype).tiny)
    lossVal = torch.sum(acc) / denom
    return lossVal

def maskedMSETest(y_pred, y_gt, mask):
    """Masked squared-error totals, un-normalised, for evaluation.

    The squared 2-D distance between ``y_pred`` and ``y_gt`` is weighted by
    ``mask`` and summed over axis 1.

    Returns:
        ``(lossVal, counts)``: for every index along axis 0, the summed
        masked squared distance and the number of active entries (sum of the
        first mask channel). Dividing the two is left to the caller.
    """
    weighted = torch.zeros_like(mask)
    dx = y_gt[:, :, 0] - y_pred[:, :, 0]
    dy = y_gt[:, :, 1] - y_pred[:, :, 1]
    sq_dist = torch.pow(dx, 2) + torch.pow(dy, 2)
    # The same error is stored in both coordinate channels before masking.
    for channel in (0, 1):
        weighted[:, :, channel] = sq_dist
    weighted = weighted * mask
    return torch.sum(weighted[:, :, 0], dim=1), torch.sum(mask[:, :, 0], dim=1)

In [None]:
# MODEL
from __future__ import division
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

class highwayNet(nn.Module):
    """LSTM encoder/decoder with attention for 2-D trajectory prediction.

    ``args`` is a dict providing ``use_cuda``, ``train_flag``,
    ``encoder_size``, ``decoder_size``, ``in_length``, ``out_length``,
    ``grid_size`` and ``input_embedding_size``. Within this class only the
    three size entries and ``out_length`` influence the computation; the
    others are stored as attributes.
    """

    ## Initialization
    def __init__(self,args):
        super(highwayNet, self).__init__()
        self.args = args
        self.use_cuda = args['use_cuda']
        self.train_flag = args['train_flag']
        self.encoder_size = args['encoder_size']
        self.decoder_size = args['decoder_size']
        self.in_length = args['in_length']
        self.out_length = args['out_length']
        self.grid_size = args['grid_size']
        self.input_embedding_size = args['input_embedding_size']
        # Input embedding layer: (x, y) -> embedding
        self.ip_emb = torch.nn.Linear(2,self.input_embedding_size)
        # Encoder LSTM (shared by the ego history and the neighbor histories)
        self.enc_lstm1 = torch.nn.LSTM(self.input_embedding_size,self.encoder_size,1)
        # Second encoder LSTM; not referenced by any method of this class.
        self.enc_lstm2 = torch.nn.LSTM(self.input_embedding_size,self.encoder_size,1)
        # Only referenced from decode_by_step (see the note there).
        self.spatial_embedding = nn.Linear(5, self.encoder_size)
        self.tanh = nn.Tanh()
        # Scores one attention logit per encoded vector.
        self.pre4att = nn.Sequential(
            nn.Linear(self.encoder_size, 1),
        )
        self.dec_lstm = torch.nn.LSTM(self.encoder_size, self.decoder_size)
        # Output layers:
        self.op = torch.nn.Linear(self.decoder_size,2) # 2-dimension (x, y)
        # Activations:
        self.leaky_relu = torch.nn.LeakyReLU(0.1)
        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)

    def attention(self, lstm_out_weight, lstm_out):
        """Soft attention over axis 1.

        Args:
            lstm_out_weight: logits of shape ``(batch, n, 1)``.
            lstm_out: vectors of shape ``(batch, n, features)``.

        Returns:
            ``(context, alpha)``: the ReLU of the attention-weighted sum,
            shape ``(batch, features)``, and the softmax weights, shape
            ``(batch, n, 1)``.
        """
        alpha = F.softmax(lstm_out_weight, dim=1)
        lstm_out = lstm_out.permute(0, 2, 1)
        new_hidden_state = torch.bmm(lstm_out, alpha).squeeze(2)
        new_hidden_state = F.relu(new_hidden_state)
        return new_hidden_state, alpha

    ## Forward Pass
    def forward(self,hist,nbrs,masks,lat_enc,lon_enc):
        """Predict the future track.

        ``hist`` and ``nbrs`` are sequence-first tensors with 2 features;
        ``masks`` is the 4-D neighbor-grid mask. ``lat_enc`` and ``lon_enc``
        are accepted for interface compatibility and are not used.

        Returns:
            ``(fut_pred, soft_attn_weights, soft_nbrs_attn_weights,
            soft_attn_weights_ha)``: the prediction, the temporal attention
            weights of the ego and neighbor encoders, and the attention
            weights over the grid cells plus the ego slot.
        """
        ## Ego history: embed -> LSTM -> temporal attention
        lstm_out, _ = self.enc_lstm1(self.leaky_relu(self.ip_emb(hist)))
        lstm_out = lstm_out.permute(1, 0, 2)
        lstm_weight = self.pre4att(self.tanh(lstm_out))
        new_hidden, soft_attn_weights = self.attention(lstm_weight, lstm_out)
        new_hidden = new_hidden.unsqueeze(2)

        ## Neighbor histories go through the same encoder; only their
        ## attention weights are used below (they are returned to the caller).
        nbrs_out, _ = self.enc_lstm1(self.leaky_relu(self.ip_emb(nbrs)))
        nbrs_out = nbrs_out.permute(1, 0, 2)
        nbrs_lstm_weight = self.pre4att(self.tanh(nbrs_out))
        _, soft_nbrs_attn_weights = self.attention(nbrs_lstm_weight, nbrs_out)

        # NOTE(review): soc_enc is created as zeros and is never filled with
        # the neighbor encodings, so every grid slot attended over below is a
        # zero vector and neighbors do not influence fut_pred. Confirm whether
        # this is an intentional no-interaction baseline before changing it;
        # altering it changes the output of existing checkpoints.
        soc_enc = torch.zeros_like(masks).float()
        soc_enc = soc_enc.permute(0,3,2,1)
        soc_enc = soc_enc.contiguous().view(soc_enc.shape[0], soc_enc.shape[1], -1)
        # Append the ego encoding as one extra slot after the grid cells.
        new_hs = torch.cat((soc_enc, new_hidden), 2)
        new_hs_per = new_hs.permute(0, 2, 1)

        # second attention
        weight = self.pre4att(self.tanh(new_hs_per))
        new_hidden_ha, soft_attn_weights_ha = self.attention(weight, new_hs_per)
        ## Concatenate encodings:
        enc = new_hidden_ha
        fut_pred = self.decode(enc)
        return fut_pred, soft_attn_weights, soft_nbrs_attn_weights, soft_attn_weights_ha

    def decode(self,enc):
        """Decode all ``out_length`` steps at once.

        The encoding is repeated ``out_length`` times along a new leading
        axis, run through the decoder LSTM, and projected to 2 outputs per
        step. The result is sequence-first.
        """
        enc = enc.repeat(self.out_length, 1, 1)
        h_dec, _ = self.dec_lstm(enc)
        h_dec = h_dec.permute(1, 0, 2)
        fut_pred = self.op(h_dec)
        fut_pred = fut_pred.permute(1, 0, 2)
        fut_pred = outputActivation(fut_pred)
        return fut_pred

    def decode_by_step(self,enc):
        """Decode one step at a time, feeding each prediction back as input.

        NOTE(review): ``spatial_embedding`` is ``Linear(5, encoder_size)``
        while the value fed to it is the 2-feature output of ``self.op``, so
        this method raises a shape error as written. It is not called from
        ``forward``. Resizing the layer would change the state_dict, so it is
        left for the owner to decide. The decoder LSTM state is also not
        carried from one step to the next.
        """
        pre_traj = []
        decoder_input = enc
        for _ in range(self.out_length):
            decoder_input = decoder_input.unsqueeze(0)
            h_dec, _ = self.dec_lstm(decoder_input)
            # Drop only the length-1 sequence axis; a bare squeeze() would
            # also remove the batch axis when the batch holds one sample.
            h_for_pred = h_dec.squeeze(0)
            fut_pred = self.op(h_for_pred)
            pre_traj.append(fut_pred.view(fut_pred.size()[0], -1))
            embedding_input = fut_pred
            decoder_input = self.spatial_embedding(embedding_input)
        pre_traj = torch.stack(pre_traj, dim=0)
        pre_traj = outputActivation(pre_traj)
        return pre_traj

In [5]:
# Mount Google Drive inside the Colab runtime at /content/drive; the dataset
# paths used by the training and evaluation cells live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from __future__ import print_function
import torch
from torch.utils.data import DataLoader
import time
import math
import datetime

if __name__ == '__main__':
    # ---- Network / run configuration -------------------------------------
    args = {}
    args['use_cuda'] = True
    args['encoder_size'] = 64 # lstm encoder hidden state size, adjustable
    args['decoder_size'] = 128 # lstm decoder hidden state size, adjustable
    args['in_length'] = 16
    args['out_length'] = 5
    args['grid_size'] = (13,3)

    args['input_embedding_size'] = 32 # input dimension for lstm encoder, adjustable

    args['train_flag'] = True
    start_time = datetime.datetime.now()
    net = highwayNet(args)
    if args['use_cuda']:
        net = net.cuda()
    trainEpochs = 1
    optimizer = torch.optim.Adam(net.parameters())
    batch_size = 128
    crossEnt = torch.nn.BCELoss()  # not used by the loops below
    trSet = ngsimDataset('/content/drive/MyDrive/Trajectory_Datasets/TrainSet.mat')
    valSet = ngsimDataset('/content/drive/MyDrive/Trajectory_Datasets/ValSet.mat')
    trDataloader = DataLoader(trSet,batch_size=batch_size,shuffle=True,num_workers=2,collate_fn=trSet.collate_fn)
    valDataloader = DataLoader(valSet,batch_size=batch_size,shuffle=True,num_workers=2,collate_fn=valSet.collate_fn)
    train_loss = []  # mean training loss per 100 batches
    val_loss = []    # mean validation loss per epoch
    prev_val_loss = math.inf

    for epoch_num in range(trainEpochs):
        # ---- Training ------------------------------------------------------
        net.train_flag = True
        avg_tr_loss = 0
        avg_tr_time = 0
        # The accuracy counters are never updated; they are kept only so the
        # progress line keeps its format.
        avg_lat_acc = 0
        avg_lon_acc = 0
        for i, data in enumerate(trDataloader):
            st_time = time.time()
            hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask, vehid, t, ds = data
            # collate_fn returns a sentinel whose first element is a list when
            # no sample in the batch has a usable neighbor; such a batch has
            # no tensors to train on, so it is skipped (the evaluation cell
            # applies the same check).
            if not isinstance(hist, list):
                if args['use_cuda']:
                    hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = (
                        tensor.cuda() for tensor in (hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask)
                    )
                fut_pred, weight_ts_center, weight_ts_nbr, weight_ha = net(hist, nbrs, mask, lat_enc, lon_enc)
                loss = maskedMSE(fut_pred, fut, op_mask)
                optimizer.zero_grad()
                loss.backward()
                # Clip the global gradient norm to 10 before the update.
                torch.nn.utils.clip_grad_norm_(net.parameters(), 10)
                optimizer.step()
                avg_tr_loss += loss.item()
                avg_tr_time += time.time()-st_time
            if i%100 == 99:
                # Remaining batches times the mean batch time of the last 100.
                eta = avg_tr_time/100*(len(trSet)/batch_size-i)
                print("Epoch no:",epoch_num+1,"| Epoch progress(%):",format(i/(len(trSet)/batch_size)*100,'0.2f'), "| Avg train loss:",format(avg_tr_loss/100,'0.4f'),"| Acc:",format(avg_lat_acc,'0.4f'),format(avg_lon_acc,'0.4f'), "| Validation loss prev epoch",format(prev_val_loss,'0.4f'), "| ETA(s):",int(eta))
                train_loss.append(avg_tr_loss/100)
                avg_tr_loss = 0
                avg_lat_acc = 0
                avg_lon_acc = 0
                avg_tr_time = 0

        # ---- Validation ----------------------------------------------------
        net.train_flag = False
        print("Epoch",epoch_num+1,'complete. Calculating validation loss...')
        avg_val_loss = 0
        avg_val_lat_acc = 0
        avg_val_lon_acc = 0
        val_batch_count = 0
        total_points = 0

        # No parameter update happens here, so autograd bookkeeping is
        # disabled to save memory and time.
        with torch.no_grad():
            for data in valDataloader:
                hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask, vehid, t, ds = data
                if isinstance(hist, list):
                    continue  # sentinel batch, see the training loop
                if args['use_cuda']:
                    hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = (
                        tensor.cuda() for tensor in (hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask)
                    )
                fut_pred, weight_ts_center, weight_ts_nbr, weight_ha = net(hist, nbrs, mask, lat_enc, lon_enc)
                loss = maskedMSE(fut_pred, fut, op_mask)
                avg_val_loss += loss.item()
                val_batch_count += 1

        # Guard against an empty validation pass instead of dividing by zero.
        mean_val_loss = avg_val_loss/val_batch_count if val_batch_count else math.nan
        print(mean_val_loss)
        print('Validation loss :',format(mean_val_loss,'0.4f'))
        val_loss.append(mean_val_loss)
        prev_val_loss = mean_val_loss
    end_time = datetime.datetime.now()
    print('Total training time: ', end_time-start_time)

In [None]:
# Persist only the trained parameters (state_dict) to /LSTM.tar at the
# filesystem root; the evaluation cell rebuilds the network and reloads the
# weights from this same path.
torch.save(net.state_dict(), '/LSTM.tar')

In [None]:
from __future__ import print_function
import torch
from torch.utils.data import DataLoader
import time
import pandas as pd
import numpy as np

# ---- Evaluation configuration (must match the sizes used for training) ----
args = {}
args['use_cuda'] = True
args['encoder_size'] = 64
args['decoder_size'] = 128
args['in_length'] = 16
args['out_length'] = 5
args['grid_size'] = (13,3)
args['input_embedding_size'] = 32
args['train_flag'] = False


metric = 'rmse'
net = highwayNet(args)
# Load onto the CPU first so the checkpoint can be read regardless of the
# device it was saved from; the network is moved to the GPU below if requested.
net.load_state_dict(torch.load('/LSTM.tar', map_location='cpu'))
if args['use_cuda']:
    net = net.cuda()
tsSet = ngsimDataset('/content/drive/MyDrive/Trajectory_Datasets/TestSet.mat')
tsDataloader = DataLoader(tsSet,batch_size=128,shuffle=True,num_workers=2,collate_fn=tsSet.collate_fn)
# Per-step accumulators; not updated by the loop below (lossVal / count are).
lossVals = torch.zeros(5)
counts = torch.zeros(5)
if args['use_cuda']:
    lossVals = lossVals.cuda()
    counts = counts.cuda()
lossVal = 0
count = 0
# Per-batch records collected for later inspection.
vehid = []
pred_x = []
pred_y = []
T = []
dsID = []
ts_cen = []
ts_nbr = []
wt_ha = []

# Pure inference: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for i, data in enumerate(tsDataloader):
        hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask, veh_id, t, ds = data
        # collate_fn signals "no usable neighbors in this batch" with a
        # sentinel whose first element is a list; skip those batches.
        if not isinstance(hist, list):
            vehid.append(veh_id)
            T.append(t)
            dsID.append(ds)
            if args['use_cuda']:
                hist = hist.cuda()
                nbrs = nbrs.cuda()
                mask = mask.cuda()
                lat_enc = lat_enc.cuda()
                lon_enc = lon_enc.cuda()
                fut = fut.cuda()
                op_mask = op_mask.cuda()
            fut_pred, weight_ts_center, weight_ts_nbr, weight_ha= net(hist, nbrs, mask, lat_enc, lon_enc)
            l, c = maskedMSETest(fut_pred, fut, op_mask)
            pred_x.append(fut_pred[:,:,0].cpu().numpy())
            pred_y.append(fut_pred[:,:,1].cpu().numpy())
            ts_cen.append(weight_ts_center[:, :, 0].cpu().numpy())
            ts_nbr.append(weight_ts_nbr[:, :, 0].cpu().numpy())
            wt_ha.append(weight_ha[:, :, 0].cpu().numpy())
            lossVal += l
            count += c

print ('lossVal is:', lossVal)
# Root of the mean squared error per prediction step, scaled by 0.3048
# (presumably feet -> metres; confirm against the dataset units).
print(torch.pow(lossVal / count,0.5)*0.3048)