In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
from tqdm import tqdm

"""Change to the data folder"""
# Dataset folders -- change these to point at your local copy of the data.
# Alternative location kept for reference:
# new_path = "./new_train/"
new_path = "./proj/new_train.nosync/"
test_path = "./proj/new_val_in.nosync/"
# Number of sequences in each dataset:
#   train: 205942   val: 3200   test: 36272
# Sequences are sampled at a 10 Hz rate.

### Create a dataset class 

In [6]:
class ArgoverseDataset(Dataset):
    """Map-style dataset that serves one pickled sample per file in a folder.

    Every path matched by ``<data_path>/*`` is treated as one sample; the
    paths are kept in sorted order so that indices are stable between runs.

    Args:
        data_path: folder containing the pickled samples.
        transform: optional callable applied to each unpickled sample before
            it is returned.
    """

    def __init__(self, data_path: str, transform=None):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        pattern = os.path.join(self.data_path, '*')
        self.pkl_list = sorted(glob(pattern))

    def __len__(self):
        """Return the number of sample files found."""
        return len(self.pkl_list)

    def __getitem__(self, idx):
        """Unpickle sample ``idx`` and pass it through ``transform`` if one is set."""
        # pickle can execute arbitrary code while loading: only point this
        # dataset at files from a trusted source.
        with open(self.pkl_list[idx], 'rb') as handle:
            sample = pickle.load(handle)

        return self.transform(sample) if self.transform else sample


# Initialize the datasets from the folders configured at the top of the notebook.
# `valid_path` / `train_path` are not defined anywhere in this notebook, so
# referencing them raises NameError on a fresh kernel; use the configured
# `test_path` / `new_path` instead.
val_dataset = ArgoverseDataset(data_path=test_path)
train_dataset = ArgoverseDataset(data_path=new_path)

### Create a loader to enable batch processing

In [7]:
batch_sz = 4  # scenes per batch


def my_collate(batch):
    """Collate scenes into tensors of shape [batch_sz x agent_sz x seq_len x feature].

    For each scene, positions and velocities are concatenated along the last
    (feature) axis, and the per-scene arrays are then stacked along a new
    leading batch axis.

    Args:
        batch: list of scene dicts with array entries 'p_in', 'v_in',
            'p_out' and 'v_out'. All scenes must share the same array shapes
            so they can be stacked.

    Returns:
        [inp, out]: two float32 tensors built from the input and output
        halves of each scene. Float tensors are used because an integer
        tensor type (the previous LongTensor) cannot hold the fractional
        part of the coordinates and velocities.
    """
    # Stack into one ndarray first: building a tensor from a list of ndarrays
    # converts element by element and is very slow.
    inp = numpy.stack([numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch])
    out = numpy.stack([numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch])
    inp = torch.as_tensor(inp, dtype=torch.float32)
    out = torch.as_tensor(out, dtype=torch.float32)
    return [inp, out]

def collate_xy(batch):
    """Return the input x and y series of the target agent of the first scene.

    Only ``batch[0]`` is read. The agent's row is the first index along
    axis 0 at which the scene's 'track_id' equals its 'agent_id'.

    Returns:
        [x, y]: features 0 and 1 of 'p_in' for that row, over all time steps.
    """
    scene = batch[0]
    target = scene['agent_id']
    row = numpy.where(scene["track_id"] == target)[0][0]
    positions = scene['p_in']
    return [positions[row, :, 0], positions[row, :, 1]]

def collate_xy_out(batch):
    """Return the output x and y series of the target agent of the first scene.

    Only ``batch[0]`` is read. The agent's row is the first index along
    axis 0 at which the scene's 'track_id' equals its 'agent_id'.

    Returns:
        [x, y]: features 0 and 1 of 'p_out' for that row, over all time steps.
    """
    scene = batch[0]
    index = numpy.where(scene["track_id"] == scene["agent_id"])[0][0]
    out_x = scene['p_out'][index, :, 0]
    out_y = scene['p_out'][index, :, 1]
    return [out_x, out_y]

def collate_v(batch):
    """Return the input and output velocity series of the first scene's target agent.

    Only ``batch[0]`` is read. The agent's row is the first index along
    axis 0 at which the scene's 'track_id' equals its 'agent_id'.

    Returns:
        [in_x, in_y, out_x, out_y]: features 0 and 1 of 'v_in', followed by
        features 0 and 1 of 'v_out', for that row over all time steps.
    """
    scene = batch[0]
    target = scene['agent_id']
    row = numpy.where(scene["track_id"] == target)[0][0]

    v_in = scene['v_in']
    v_out = scene['v_out']
    return [v_in[row, :, 0], v_in[row, :, 1], v_out[row, :, 0], v_out[row, :, 1]]
    
# One scene per batch, in dataset order; the three loaders differ only in the
# collate function, i.e. in which series of the target agent they extract.
# NOTE(review): these are named val_* but wrap train_dataset -- confirm that
# is intended.
val_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_xy,
    num_workers=0,
)
val_loader1 = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_xy_out,
    num_workers=0,
)
val_loader2 = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_v,
    num_workers=0,
)

### Visualize the batch of sequences

In [9]:
import matplotlib.pyplot as plt
import random

agent_id = 0  # index along the agent axis of the batched tensors to plot


def show_sample_batch(sample_batch, agent_id):
    """Scatter-plot one agent's input and output trajectory for every scene in a batch.

    One subplot is drawn per scene. The plotted series are also printed.

    Args:
        sample_batch: pair ``(inp, out)`` of tensors indexed as
            [scene, agent, time step, feature]; features 0 and 1 are used
            as the (x, y) position.
        agent_id: index along the agent axis selecting the agent to draw.
    """
    inp, out = sample_batch
    batch_sz = inp.size(0)

    # squeeze=False always returns a 2-D array of Axes. With the default,
    # a single-scene batch yields a bare Axes object and .ravel() fails.
    fig, axs = plt.subplots(1, batch_sz, figsize=(15, 3), facecolor='w',
                            edgecolor='k', squeeze=False)
    fig.subplots_adjust(hspace=.5, wspace=.001)
    axs = axs.ravel()
    for i in range(batch_sz):
        axs[i].xaxis.set_ticks([])
        axs[i].yaxis.set_ticks([])
        print(inp[i, agent_id, :, 0], inp[i, agent_id, :, 1])
        print(out[i, agent_id, :, 0], out[i, agent_id, :, 1])
        # first two feature dimensions are (x,y) positions
        axs[i].scatter(inp[i, agent_id, :, 0], inp[i, agent_id, :, 1])
        axs[i].scatter(out[i, agent_id, :, 0], out[i, agent_id, :, 1])

        
# val_loader uses collate_xy and therefore yields two 1-D coordinate arrays,
# not the (inp, out) tensor pair that show_sample_batch indexes. Build a
# dedicated loader that batches whole scenes with my_collate.
# Assumes every scene has the same array shapes so they can be stacked --
# TODO confirm.
viz_loader = DataLoader(train_dataset, batch_size=batch_sz, shuffle=False,
                        collate_fn=my_collate, num_workers=0)

for i_batch, sample_batch in enumerate(viz_loader):
    inp, out = sample_batch
    # TODO:
    #   Deep learning model
    #   training routine
    print(inp.shape)
    show_sample_batch(sample_batch, agent_id)
    break  # only the first batch is visualized

In [None]:
# Import the callable, not the module: `import tqdm as tqdm` rebinds the name
# to the module object, and tqdm(val_loader) then raises
# "TypeError: 'module' object is not callable".
from tqdm import tqdm

# Collect per-scene chunks and concatenate once; calling numpy.append inside
# the loop copies the whole accumulated array on every iteration.
inp_x_parts = []
inp_y_parts = []

for i_batch, sample_batch in enumerate(tqdm(val_loader)):
    in_x, in_y = sample_batch
    inp_x_parts.append(numpy.ravel(in_x))
    inp_y_parts.append(numpy.ravel(in_y))

# The leading empty float64 array keeps the result identical to the repeated
# numpy.append version (flat float64 array, empty if the loader is empty).
inp_x = numpy.concatenate([numpy.empty(0), *inp_x_parts])
inp_y = numpy.concatenate([numpy.empty(0), *inp_y_parts])

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Histogram of the collected input x coordinates; bin edges chosen by 'auto'.
n, bins, patches = plt.hist(
    x=inp_x,
    bins='auto',
    color='#0504aa',
    alpha=0.7,
    rwidth=0.85,
)

In [None]:
# Histogram of the collected input y coordinates; bin edges chosen by 'auto'.
n1, bins1, patches1 = plt.hist(
    x=inp_y,
    bins='auto',
    color='#0504aa',
    alpha=0.7,
    rwidth=0.85,
)

In [None]:
from matplotlib import colors

In [None]:
# plt.subplots returns a (Figure, Axes) pair; unpack it instead of binding the
# whole tuple to `fig`, and draw through the explicit Axes.
fig, ax = plt.subplots(figsize=(10, 7))
# Log-scaled colour normalisation for the 2-D bin counts.
ax.hist2d(inp_x, inp_y, norm=colors.LogNorm(), cmap="Greens")
ax.set_title("Input positions")
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()

In [None]:
# Output positions come from val_loader1 (collate_xy_out). val_loader uses
# collate_xy and would yield the *input* positions again.
# Chunks are collected and concatenated once; calling numpy.append inside the
# loop copies the whole accumulated array on every iteration.
outp_x_parts = []
outp_y_parts = []

for i_batch, sample_batch in enumerate(tqdm(val_loader1)):
    out_x, out_y = sample_batch
    outp_x_parts.append(numpy.ravel(out_x))
    outp_y_parts.append(numpy.ravel(out_y))

# The leading empty float64 array keeps the result a flat float64 array,
# matching what repeated numpy.append onto numpy.empty(0) produces.
outp_x = numpy.concatenate([numpy.empty(0), *outp_x_parts])
outp_y = numpy.concatenate([numpy.empty(0), *outp_y_parts])

In [None]:
# Pool input and output velocities of the target agent over all scenes.
# Chunks are collected and concatenated once; calling numpy.append inside the
# loop copies the whole accumulated array on every iteration.
v_x_parts = []
v_y_parts = []

for i_batch, sample_batch in enumerate(tqdm(val_loader2)):
    in_x, in_y, out_x, out_y = sample_batch
    # Same element order as before: input series first, then output series.
    v_x_parts.extend([numpy.ravel(in_x), numpy.ravel(out_x)])
    v_y_parts.extend([numpy.ravel(in_y), numpy.ravel(out_y)])

# The leading empty float64 array keeps the result a flat float64 array,
# matching what repeated numpy.append onto numpy.empty(0) produces.
v_x = numpy.concatenate([numpy.empty(0), *v_x_parts])
v_y = numpy.concatenate([numpy.empty(0), *v_y_parts])

In [None]:
from matplotlib import pyplot as plt
from matplotlib import colors


In [None]:
# Histogram of the collected input x coordinates; bin edges chosen by 'auto'.
# NOTE(review): this repeats an earlier inp_x histogram cell, while outp_x and
# v_x are not plotted in any visible cell -- confirm which series was intended.
n, bins, patches = plt.hist(
    x=inp_x,
    bins='auto',
    color='#0504aa',
    alpha=0.7,
    rwidth=0.85,
)

In [None]:
# Histogram of the collected input y coordinates; bin edges chosen by 'auto'.
# NOTE(review): this repeats an earlier inp_y histogram cell, while outp_y and
# v_y are not plotted in any visible cell -- confirm which series was intended.
n1, bins1, patches1 = plt.hist(
    x=inp_y,
    bins='auto',
    color='#0504aa',
    alpha=0.7,
    rwidth=0.85,
)

In [None]:
# plt.subplots returns a (Figure, Axes) pair; unpack it instead of binding the
# whole tuple to `fig`, and draw through the explicit Axes.
# NOTE(review): this repeats an earlier "Input positions" plot of
# inp_x / inp_y, while outp_* and v_* are not plotted in any visible cell --
# confirm which series was intended.
fig, ax = plt.subplots(figsize=(10, 7))
# Log-scaled colour normalisation for the 2-D bin counts.
ax.hist2d(inp_x, inp_y, norm=colors.LogNorm(), cmap="Greens")
ax.set_title("Input positions")
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()