In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy
import pickle
from glob import glob
from typing import Any, Dict, List, Tuple, Union
import pandas as pd

"""Change to the data folder"""
# Directory handed to the first ArgoverseDataset below; every entry directly
# inside it is treated as one pickled sample.
new_path = "./new_train/new_train"
# Directory handed to the second ArgoverseDataset below (inference inputs).
test_path = './new_val_in/new_val_in'
# Number of sequences in each dataset split, as recorded by the author:
#   train: 205942   val: 3200   test: 36272
# Sequences are sampled at a 10 Hz rate.

### Create a dataset class 

In [2]:
class ArgoverseDataset(Dataset):
    """Map-style dataset that serves one pickled Argoverse sample per file.

    Every entry found directly under ``data_path`` is treated as a pickle
    file. Paths are kept in sorted order so that a given index always maps
    to the same file.

    Args:
        data_path: directory whose entries are the pickled samples.
        transform: optional callable applied to each unpickled sample
            before it is returned.
    """

    def __init__(self, data_path: str, transform=None):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        # Sorted listing keeps the index -> file mapping stable across runs.
        self.pkl_list = sorted(glob(os.path.join(data_path, '*')))

    def __len__(self):
        """Number of sample files discovered under ``data_path``."""
        return len(self.pkl_list)

    def __getitem__(self, idx):
        """Unpickle the ``idx``-th file and return it (transformed if requested)."""
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this dataset at files from a trusted source.
        with open(self.pkl_list[idx], 'rb') as handle:
            sample = pickle.load(handle)

        if not self.transform:
            return sample
        return self.transform(sample)


# Initialize the datasets.
# NOTE(review): despite its name, `val_dataset` reads from `new_path`
# (the "new_train" folder) and is the data that is later fed to train().
val_dataset  = ArgoverseDataset(data_path=new_path)
# Samples read from `test_path`; these are the inputs used for inference.
test_dataset = ArgoverseDataset(data_path=test_path)

### Create a loader to enable batch processing

In [3]:
batch_sz = 4  # scenes per mini-batch; passed as batch_size to both DataLoaders below

def train_collate(batch):
    """Collate scene dicts into ``[inputs, targets]`` tensors for the target agent.

    For every scene, the rows of ``p_in`` / ``p_out`` whose track id (read
    from ``track_id[:, 0, 0]``) equals the scene's ``agent_id`` are selected.
    The per-scene selections are then stacked along a new leading batch axis.

    Args:
        batch: list of scene dicts providing the keys ``'p_in'``, ``'p_out'``,
            ``'track_id'`` and ``'agent_id'``.

    Returns:
        ``[inp, out]``: two float32 tensors shaped
        ``[batch, matched_tracks, seq_len, features]``. Only tracks matching
        ``agent_id`` are kept, so the second axis is not the full agent count.
    """
    inp, out = [], []
    for scene in batch:
        # Boolean mask over the track axis selecting the agent to be predicted.
        is_agent = scene['track_id'][:, 0, 0] == scene['agent_id']
        inp.append(scene['p_in'][is_agent])
        out.append(scene['p_out'][is_agent])

    # Stack in numpy first: building a tensor straight from a Python list of
    # ndarrays is the slow path that PyTorch warns about.
    inp = torch.as_tensor(numpy.stack(inp), dtype=torch.float32)
    out = torch.as_tensor(numpy.stack(out), dtype=torch.float32)
    return [inp, out]

def test_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'][scene['track_id'][:,0,0]==scene['agent_id'],:,:]]) for scene in batch]
    inp = torch.Tensor(inp)
    idx = [numpy.dstack([scene['scene_idx']]) for scene in batch]
    return inp, idx
    
# Batches of `val_dataset`, assembled by train_collate into [inputs, targets].
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_sz,
    shuffle=True,
    collate_fn=train_collate,
    num_workers=0,
)

# Batches of `test_dataset`, assembled by test_collate into (inputs, scene indices).
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_sz,
    shuffle=True,
    collate_fn=test_collate,
    num_workers=0,
)

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class EncoderRNN(nn.Module):
    """Single-step LSTM encoder: embed one observation, then advance the LSTM cell.

    Referenced from the official Argoverse forecasting code:
    https://github.com/jagjeet-singh/argoverse-forecasting

    Args:
        input_size: number of features in each observation.
        embedding_size: width of the linear embedding fed to the LSTM cell.
        hidden_size: width of the LSTM hidden and cell states.
    """

    def __init__(self, input_size=2, embedding_size=8, hidden_size=16):
        super().__init__()
        self.hidden_size = hidden_size
        # Layers are created in the order linear -> lstm.
        self.linear = nn.Linear(input_size, embedding_size)
        self.lstm = nn.LSTMCell(embedding_size, hidden_size)

    def forward(self, x, hidden):
        """Return the LSTM cell's ``(h, c)`` after consuming observation ``x``."""
        return self.lstm(F.relu(self.linear(x)), hidden)


class DecoderRNN(nn.Module):
    """Single-step LSTM decoder.

    Embeds the previous output, advances the LSTM cell, and projects the new
    hidden state back to the output space.

    Referenced from the official Argoverse forecasting code:
    https://github.com/jagjeet-singh/argoverse-forecasting

    Args:
        embedding_size: width of the linear embedding fed to the LSTM cell.
        hidden_size: width of the LSTM hidden and cell states.
        output_size: number of features consumed and produced per step.
    """

    def __init__(self, embedding_size=8, hidden_size=16, output_size=2):
        super().__init__()
        self.hidden_size = hidden_size

        # Layers are created in the order linear1 -> lstm -> linear2.
        self.linear1 = nn.Linear(output_size, embedding_size)
        self.lstm = nn.LSTMCell(embedding_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Return ``(output, (h, c))`` for one decoding step."""
        next_state = self.lstm(F.relu(self.linear1(x)), hidden)
        h_next = next_state[0]
        return self.linear2(h_next), next_state


In [5]:
from tqdm import tqdm_notebook as tqdm

def train(encoder, decoder, device, train_loader, encoder_optimizer, decoder_optimizer, epoch, log_interval=10000):
    """Run one optimisation pass of the encoder/decoder pair over ``train_loader``.

    Referenced from the official Argoverse forecasting code:
    https://github.com/jagjeet-singh/argoverse-forecasting

    Each batch is reduced to its first selected track, shifted so that the
    first observed (x, y) position becomes the origin, encoded step by step,
    and decoded autoregressively (the decoder is fed its own predictions).
    The loss is the mean over decoding steps of ``sqrt(MSE)`` between the
    prediction and the shifted target.

    Args:
        encoder: encoder module, either bare or wrapped (e.g. ``nn.DataParallel``);
            must expose ``hidden_size`` directly or via ``.module``.
        decoder: decoder module returning ``(output, hidden)`` per step.
        device: device the batches and hidden states are moved to.
        train_loader: iterable with ``len()`` yielding ``[inp, out]`` batches
            shaped ``[batch, tracks, seq_len, features]``.
        encoder_optimizer: optimizer stepping the encoder parameters.
        decoder_optimizer: optimizer stepping the decoder parameters.
        epoch: current epoch number (currently unused; kept for callers).
        log_interval: currently unused; kept for callers.

    Returns:
        The sample-weighted mean loss over the pass as a float, or ``nan``
        when the loader yielded no samples.
    """
    iterator = tqdm(train_loader, total=len(train_loader))
    criterion = nn.MSELoss()

    # Accept both bare modules and wrappers that expose the model as `.module`.
    hidden_size = getattr(encoder, "module", encoder).hidden_size

    encoder.train()
    decoder.train()

    total_loss = 0.0
    samples_seen = 0

    # Iterate the progress bar itself; iterating the loader directly leaves
    # the bar stuck at zero.
    for inp, out in iterator:
        # Keep the first selected track of each scene: [batch, seq_len, features].
        # Clone so the shift below does not write into the loader's tensors.
        inp = inp[:, 0].clone()
        out = out[:, 0].clone()

        # Express every trajectory relative to its first observed (x, y).
        origin = inp[:, :1, :2].clone()
        inp[:, :, :2] -= origin
        out[:, :, :2] -= origin

        _input, target = inp.to(device), out.to(device)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        batch_size = _input.shape[0]
        input_length = _input.shape[1]
        output_length = target.shape[1]

        encoder_hidden = (torch.zeros(batch_size, hidden_size, device=device),
                          torch.zeros(batch_size, hidden_size, device=device))

        # Encode the observed trajectory one time step at a time.
        for ei in range(input_length):
            encoder_hidden = encoder(_input[:, ei, :], encoder_hidden)

        # The decoder starts from the last observed coordinate and the
        # encoder's final state.
        decoder_input = _input[:, -1, :2]
        decoder_hidden = encoder_hidden

        loss = 0
        for di in range(output_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss = loss + torch.sqrt(criterion(decoder_output[:, :2], target[:, di, :2]))
            # Feed the prediction back in as the next input.
            decoder_input = decoder_output

        # Average over the prediction horizon.
        loss = loss / output_length

        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        # Report a running, sample-weighted mean rather than the current batch
        # loss divided by the step count.
        total_loss += loss.item() * batch_size
        samples_seen += batch_size
        iterator.set_postfix(loss=total_loss / samples_seen)

    return total_loss / samples_seen if samples_seen else float("nan")

In [6]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# cell still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = EncoderRNN(input_size=2)
decoder = DecoderRNN(output_size=2)

# The wrapped model remains reachable as `.module` on each wrapper.
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)

encoder.to(device)
decoder.to(device)

# Separate Adam optimizers (default hyper-parameters) for the two networks.
encoder_optimizer = torch.optim.Adam(encoder.parameters())
decoder_optimizer = torch.optim.Adam(decoder.parameters())

num_epoch = 4

# NOTE(review): `val_loader` wraps the dataset read from `new_path`
# (the "new_train" folder), i.e. this loop trains on the training data.
for epoch in range(1, num_epoch + 1):
    train(encoder, decoder, device, val_loader, encoder_optimizer, decoder_optimizer, epoch)

HBox(children=(IntProgress(value=0, max=51486), HTML(value='')))

HBox(children=(IntProgress(value=0, max=51486), HTML(value='')))

HBox(children=(IntProgress(value=0, max=51486), HTML(value='')))

HBox(children=(IntProgress(value=0, max=51486), HTML(value='')))

In [7]:
def infer_absolute(
        test_loader: torch.utils.data.DataLoader,
        encoder: torch.nn.Module,
        decoder: torch.nn.Module,
        pred_len: int = 30,
        device: Union[str, torch.device, None] = None,
) -> Dict[Any, List[List[float]]]:
    """Forecast trajectories in absolute coordinates with the LSTM encoder/decoder.

    Referenced from the official Argoverse forecasting code:
    https://github.com/jagjeet-singh/argoverse-forecasting

    Each batch is reduced to its first selected track and shifted so the first
    observed (x, y) position becomes the origin. The shifted track is encoded
    step by step, then decoded autoregressively for ``pred_len`` steps. The
    origin is added back to the predictions before they are collected.

    Args:
        test_loader: iterable yielding ``(inputs, idx)`` batches, where
            ``inputs`` is shaped ``[batch, tracks, seq_len, features]`` and
            ``idx[j][0][0][0]`` is the key under which sample ``j`` is stored.
        encoder: encoder module, either bare or wrapped (e.g. ``nn.DataParallel``);
            must expose ``hidden_size`` directly or via ``.module``.
        decoder: decoder module returning ``(output, hidden)`` per step, with a
            two-feature output.
        pred_len: number of future steps to decode.
        device: device to run on. Defaults to the device of the encoder's
            first parameter (CPU if the encoder has no parameters).

    Returns:
        Dict mapping each sample key to a list of ``pred_len`` ``[x, y]``
        predictions in the original coordinate frame.
    """
    if device is None:
        # Follow the model instead of depending on a notebook-level global.
        first_param = next(encoder.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Accept both bare modules and wrappers that expose the model as `.module`.
    hidden_size = getattr(encoder, "module", encoder).hidden_size

    encoder.eval()
    decoder.eval()

    forecasted_trajectories = {}

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for _input, idx in test_loader:
            # Keep the first selected track of each scene: [batch, seq_len, features].
            # Clone so the shift below does not write into the loader's tensors.
            track = _input[:, 0].to(device).clone()

            # Express every trajectory relative to its first observed (x, y).
            origin = track[:, :1, :2].clone()
            track[:, :, :2] -= origin

            batch_size = track.shape[0]
            input_length = track.shape[1]

            encoder_hidden = (torch.zeros(batch_size, hidden_size, device=device),
                              torch.zeros(batch_size, hidden_size, device=device))

            # Encode the observed trajectory one time step at a time.
            for ei in range(input_length):
                encoder_hidden = encoder(track[:, ei, :], encoder_hidden)

            # The decoder starts from the last observed coordinate and the
            # encoder's final state.
            decoder_input = track[:, -1, :2]
            decoder_hidden = encoder_hidden

            decoder_outputs = torch.zeros((batch_size, pred_len, 2), device=device)

            for di in range(pred_len):
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                decoder_outputs[:, di, :] = decoder_output
                # Feed the prediction back in as the next input.
                decoder_input = decoder_output

            # Undo the origin shift to return to absolute coordinates.
            decoder_outputs[:, :, :2] += origin

            # Use the real batch size so a smaller final batch (or a different
            # batch size) is handled correctly.
            for j, points in enumerate(decoder_outputs.tolist()):
                key = idx[j][0][0][0]
                forecasted_trajectories.setdefault(key, []).extend(points)

    return forecasted_trajectories

In [8]:
# Run inference over the test loader; the result is a dict keyed by scene
# index whose values are lists of predicted [x, y] points.
output = infer_absolute(test_loader, encoder, decoder)

In [9]:
import pandas as pd
# orient='index' turns each dict key into a row label; every column then
# holds one predicted [x, y] pair (one column per predicted step).
df = pd.DataFrame.from_dict(output, orient='index')
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
34566,"[2093.34130859375, 657.5591430664062]","[2092.33935546875, 656.6316528320312]","[2091.45361328125, 655.8548583984375]","[2090.546630859375, 655.0439453125]","[2089.654296875, 654.2310791015625]","[2088.7626953125, 653.4061279296875]","[2087.8623046875, 652.566650390625]","[2086.947021484375, 651.7138061523438]","[2086.0166015625, 650.8519287109375]","[2085.07470703125, 649.9873657226562]",...,"[2075.044189453125, 641.23486328125]","[2074.16552734375, 640.49462890625]","[2073.29296875, 639.7614135742188]","[2072.430419921875, 639.0377807617188]","[2071.580810546875, 638.3272094726562]","[2070.748046875, 637.6337890625]","[2069.934814453125, 636.9617919921875]","[2069.142578125, 636.3153076171875]","[2068.371337890625, 635.6976928710938]","[2067.61962890625, 635.1113891601562]"
36575,"[723.7435913085938, 1006.6578369140625]","[723.7767944335938, 1008.6141357421875]","[723.634033203125, 1010.0662841796875]","[723.7188720703125, 1011.821533203125]","[723.7057495117188, 1013.5205078125]","[723.7213134765625, 1015.1975708007812]","[723.725830078125, 1016.8411254882812]","[723.7177124023438, 1018.458740234375]","[723.6910400390625, 1020.0557861328125]","[723.6441040039062, 1021.6373291015625]",...,"[723.1075439453125, 1037.866943359375]","[723.6769409179688, 1040.212646484375]","[723.7183227539062, 1042.115234375]","[723.7128295898438, 1043.9173583984375]","[723.478515625, 1045.4654541015625]","[723.1060791015625, 1046.8658447265625]","[722.6753540039062, 1048.2139892578125]","[722.2733154296875, 1049.5953369140625]","[721.976318359375, 1051.072265625]","[721.784423828125, 1052.6080322265625]"
22274,"[2094.327880859375, 657.2356567382812]","[2093.201171875, 656.1349487304688]","[2092.071533203125, 655.0584106445312]","[2090.9482421875, 653.9935913085938]","[2089.851806640625, 652.9503173828125]","[2088.76220703125, 651.9138793945312]","[2087.66357421875, 650.875732421875]","[2086.547607421875, 649.8341674804688]","[2085.416015625, 648.7933959960938]","[2084.2763671875, 647.760498046875]",...,"[2072.33349609375, 637.2857055664062]","[2071.29736328125, 636.3683471679688]","[2070.28271484375, 635.4668579101562]","[2069.29345703125, 634.5877075195312]","[2068.330322265625, 633.7359619140625]","[2067.39013671875, 632.9151611328125]","[2066.466552734375, 632.126953125]","[2065.55126953125, 631.3720703125]","[2064.635498046875, 630.6513671875]","[2063.711669921875, 629.9669189453125]"
6405,"[585.6062622070312, 1362.032958984375]","[585.6181030273438, 1363.5887451171875]","[585.6260986328125, 1364.97509765625]","[585.7677001953125, 1366.68701171875]","[585.8218383789062, 1368.357666015625]","[585.916015625, 1370.00732421875]","[586.0078125, 1371.6146240234375]","[586.0916748046875, 1373.1793212890625]","[586.1583251953125, 1374.703857421875]","[586.2030639648438, 1376.19287109375]",...,"[586.1631469726562, 1392.0687255859375]","[586.5567626953125, 1394.3302001953125]","[586.6168823242188, 1396.29150390625]","[586.7073364257812, 1398.1845703125]","[586.5968017578125, 1399.78955078125]","[586.3278198242188, 1401.1915283203125]","[585.9295654296875, 1402.4658203125]","[585.442138671875, 1403.6785888671875]","[584.9178466796875, 1404.8919677734375]","[584.4242553710938, 1406.171142578125]"
813,"[575.1336059570312, 1290.397216796875]","[575.1258544921875, 1290.3404541015625]","[574.9265747070312, 1289.855712890625]","[574.9689331054688, 1289.7510986328125]","[574.8139038085938, 1289.3507080078125]","[574.8192138671875, 1289.189208984375]","[574.7109985351562, 1288.8482666015625]","[574.6863403320312, 1288.626953125]","[574.6110229492188, 1288.3138427734375]","[574.5712280273438, 1288.04248046875]",...,"[574.1673583984375, 1284.0631103515625]","[574.148681640625, 1283.64404296875]","[574.1340942382812, 1283.2191162109375]","[574.1239624023438, 1282.789306640625]","[574.1187744140625, 1282.355712890625]","[574.118896484375, 1281.91943359375]","[574.1246948242188, 1281.4814453125]","[574.1365356445312, 1281.0428466796875]","[574.1544799804688, 1280.604736328125]","[574.1787109375, 1280.1678466796875]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18030,"[2021.4365234375, 598.9205932617188]","[2019.8895263671875, 597.52978515625]","[2018.34033203125, 596.1505126953125]","[2016.8795166015625, 594.89990234375]","[2015.489013671875, 593.7333374023438]","[2014.1177978515625, 592.5946044921875]","[2012.732177734375, 591.4527587890625]","[2011.3236083984375, 590.2996826171875]","[2009.901123046875, 589.1400146484375]","[2008.4803466796875, 587.9808959960938]",...,"[1993.6568603515625, 575.3488159179688]","[1992.4512939453125, 574.3662109375]","[1991.258056640625, 573.416015625]","[1990.059326171875, 572.4747314453125]","[1988.8431396484375, 571.528076171875]","[1987.611083984375, 570.5847778320312]","[1986.3828125, 569.6757202148438]","[1985.196533203125, 568.8363647460938]","[1984.101806640625, 568.0867309570312]","[1983.1456298828125, 567.4263916015625]"
22315,"[211.1432647705078, 1526.822265625]","[211.7019805908203, 1527.2237548828125]","[212.59107971191406, 1526.7708740234375]","[213.23483276367188, 1526.381591796875]","[213.90614318847656, 1526.0540771484375]","[214.65943908691406, 1525.76904296875]","[215.49435424804688, 1525.5518798828125]","[216.40188598632812, 1525.404541015625]","[217.3624267578125, 1525.3238525390625]","[218.3538818359375, 1525.3033447265625]",...,"[228.39361572265625, 1526.9686279296875]","[229.14096069335938, 1527.149169921875]","[229.85671997070312, 1527.3184814453125]","[230.54188537597656, 1527.4754638671875]","[231.19786071777344, 1527.6195068359375]","[231.82632446289062, 1527.7503662109375]","[232.42913818359375, 1527.8685302734375]","[233.00827026367188, 1527.9744873046875]","[233.56564331054688, 1528.069091796875]","[234.10324096679688, 1528.153564453125]"
9384,"[413.079345703125, 1515.998779296875]","[413.3428649902344, 1516.902587890625]","[413.1959228515625, 1516.857177734375]","[413.4525451660156, 1517.6826171875]","[413.4298095703125, 1517.8790283203125]","[413.6017761230469, 1518.502685546875]","[413.6462097167969, 1518.889892578125]","[413.77020263671875, 1519.41259765625]","[413.8553466796875, 1519.8702392578125]","[413.9598388671875, 1520.353271484375]",...,"[414.6883850097656, 1525.09619140625]","[414.7353515625, 1525.536376953125]","[414.7856140136719, 1525.9970703125]","[414.8403625488281, 1526.483642578125]","[414.89990234375, 1527.0008544921875]","[414.9632568359375, 1527.55224609375]","[415.02734375, 1528.139892578125]","[415.08642578125, 1528.7628173828125]","[415.13134765625, 1529.41748046875]","[415.1494445800781, 1530.096923828125]"
27646,"[2150.760986328125, 701.836181640625]","[2149.608154296875, 700.7591552734375]","[2148.47900390625, 699.7073974609375]","[2147.359130859375, 698.6690673828125]","[2146.26806640625, 697.6537475585938]","[2145.185302734375, 696.6458740234375]","[2144.09375, 695.6359252929688]","[2142.984130859375, 694.620849609375]","[2141.856201171875, 693.603515625]","[2140.71728515625, 692.5902709960938]",...,"[2128.676513671875, 682.1731567382812]","[2127.636962890625, 681.2733154296875]","[2126.62158203125, 680.3975830078125]","[2125.63427734375, 679.5535888671875]","[2124.676025390625, 678.7474975585938]","[2123.744140625, 677.9827880859375]","[2122.8330078125, 677.260009765625]","[2121.93408203125, 676.5770263671875]","[2121.038818359375, 675.929931640625]","[2120.1396484375, 675.314208984375]"


In [10]:
# Split each per-step [x, y] column into two scalar columns:
# step i -> 'v{2i+1}' (x) and 'v{2i+2}' (y), giving v1 .. v60 overall.
for i in range(30):
    x_col = 'v{}'.format(2 * i + 1)
    y_col = 'v{}'.format(2 * i + 2)
    xy = pd.DataFrame(df.get(i).tolist(), index=df.index)
    df[[x_col, y_col]] = xy

In [11]:
import numpy as np
# Column labels 0..29 are the original per-step [x, y] list columns.
dropped_cols = list(np.arange(30))
# Drop them so only the flattened v1..v60 columns remain.
df2 = df.drop(dropped_cols, axis=1)
# Name the index so it is written out under the header 'ID'.
df2.index.name = 'ID'
df2

Unnamed: 0_level_0,v1,v2,v3,v4,v5,v6,v7,v8,v9,v10,...,v51,v52,v53,v54,v55,v56,v57,v58,v59,v60
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
34566,2093.341309,657.559143,2092.339355,656.631653,2091.453613,655.854858,2090.546631,655.043945,2089.654297,654.231079,...,2070.748047,637.633789,2069.934814,636.961792,2069.142578,636.315308,2068.371338,635.697693,2067.619629,635.111389
36575,723.743591,1006.657837,723.776794,1008.614136,723.634033,1010.066284,723.718872,1011.821533,723.705750,1013.520508,...,723.106079,1046.865845,722.675354,1048.213989,722.273315,1049.595337,721.976318,1051.072266,721.784424,1052.608032
22274,2094.327881,657.235657,2093.201172,656.134949,2092.071533,655.058411,2090.948242,653.993591,2089.851807,652.950317,...,2067.390137,632.915161,2066.466553,632.126953,2065.551270,631.372070,2064.635498,630.651367,2063.711670,629.966919
6405,585.606262,1362.032959,585.618103,1363.588745,585.626099,1364.975098,585.767700,1366.687012,585.821838,1368.357666,...,586.327820,1401.191528,585.929565,1402.465820,585.442139,1403.678589,584.917847,1404.891968,584.424255,1406.171143
813,575.133606,1290.397217,575.125854,1290.340454,574.926575,1289.855713,574.968933,1289.751099,574.813904,1289.350708,...,574.118896,1281.919434,574.124695,1281.481445,574.136536,1281.042847,574.154480,1280.604736,574.178711,1280.167847
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18030,2021.436523,598.920593,2019.889526,597.529785,2018.340332,596.150513,2016.879517,594.899902,2015.489014,593.733337,...,1987.611084,570.584778,1986.382812,569.675720,1985.196533,568.836365,1984.101807,568.086731,1983.145630,567.426392
22315,211.143265,1526.822266,211.701981,1527.223755,212.591080,1526.770874,213.234833,1526.381592,213.906143,1526.054077,...,231.826324,1527.750366,232.429138,1527.868530,233.008270,1527.974487,233.565643,1528.069092,234.103241,1528.153564
9384,413.079346,1515.998779,413.342865,1516.902588,413.195923,1516.857178,413.452545,1517.682617,413.429810,1517.879028,...,414.963257,1527.552246,415.027344,1528.139893,415.086426,1528.762817,415.131348,1529.417480,415.149445,1530.096924
27646,2150.760986,701.836182,2149.608154,700.759155,2148.479004,699.707397,2147.359131,698.669067,2146.268066,697.653748,...,2123.744141,677.982788,2122.833008,677.260010,2121.934082,676.577026,2121.038818,675.929932,2120.139648,675.314209


In [12]:
# Write the flattened predictions to CSV; index=True keeps the row index
# (named 'ID' above) as the first column.
df2.to_csv("outputs4ep.csv", index=True)