In [106]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy as np
import pickle
from glob import glob
import pandas as pd
import matplotlib.pyplot as plt

# Load training data

In [107]:
# train_path = "./train/train"
train_path = "./val_in/val_in"
training_samples = []
# The glob module finds all the pathnames matching a specified pattern
train_pkl_lst = glob(os.path.join(train_path, '*'))
for i in range(np.array(train_pkl_lst).shape[0]):
    with open(train_pkl_lst[i], 'rb') as f:
        training_samples.append(pickle.load(f))

In [108]:
training_samples = np.array(training_samples)
print(training_samples.shape)

(3200,)


In [109]:
print(training_samples[0].keys())

dict_keys(['city', 'lane', 'lane_norm', 'scene_idx', 'agent_id', 'car_mask', 'p_in', 'v_in', 'track_id'])


In [112]:
cities_arr = []
for i in range(training_samples.shape[0]):
    cities_arr.append(training_samples[i]['city'])
cities_df = pd.DataFrame(cities_arr)
print(cities_df[0].unique())

['PIT' 'MIA']


In [113]:
# city can be either 'PIT' or 'MIA'
# replace with binary representation
# PIT is 0, MIA is 1
for i in range(training_samples.shape[0]):
    if training_samples[i]['city'] == 'PIT':
        training_samples[i]['city'] = 0
    elif training_samples[i]['city'] == 'MIA':
        training_samples[i]['city'] = 1

In [114]:
cities_arr = []
for i in range(training_samples.shape[0]):
    cities_arr.append(training_samples[i]['city'])
cities_df = pd.DataFrame(cities_arr)
print(cities_df[0].unique())

[0 1]


In [115]:
# scene index gives us no information so it can be dropped
scene_idx_arr = []
for i in range(training_samples.shape[0]):
    scene_idx_arr.append(training_samples[i]['scene_idx'])
scene_idx_df = pd.DataFrame(scene_idx_arr)
print(scene_idx_df[0].unique().shape)

(3200,)


In [116]:
# drop scene_idx
for i in range(training_samples.shape[0]):
    del training_samples[i]['scene_idx']

In [117]:
print(training_samples[100]['agent_id'])

00000000-0000-0000-0000-000000013270


In [118]:
# get unique values of Agent ID
agent_id_arr = []
for i in range(training_samples.shape[0]):
    agent_id_arr.append(training_samples[i]['agent_id'])
agent_id_df = pd.DataFrame(agent_id_arr)
print(agent_id_df[0].unique().shape)
print(agent_id_df[0][0])

(2979,)
00000000-0000-0000-0000-000000010910


In [119]:
# can replace last with only the last 8 values
agent_id_arr = []
for i in range(training_samples.shape[0]):
    training_samples[i]['agent_id'] = training_samples[i]['agent_id'][-8:]
    agent_id_arr.append(training_samples[i]['agent_id'])
agent_id_df = pd.DataFrame(agent_id_arr)
print(agent_id_df[0].unique().shape)
print(agent_id_df[0][0])

(2979,)
00010910


In [120]:
# for some reason the first track ID is always all 0s, lets remove that
track_id_arr = []
for i in range(training_samples.shape[0]):
    training_samples[i]['track_id'] = np.delete(training_samples[i]['track_id'], 0)
    track_id_arr.append(training_samples[i]['track_id'])
track_id_df = pd.DataFrame(track_id_arr)
print(track_id_df[0].unique().shape)

(2302,)


In [121]:
# we only need the last 8 values of track ID since the rest are zeros
track_id_arr = []
for i in range(training_samples.shape[0]):
    for j in range(training_samples[i]['track_id'].shape[0]):
        track_id_arr.append(training_samples[i]['track_id'][j])
track_id_df = pd.DataFrame(track_id_arr)
print(track_id_df[0].unique().shape)
print(track_id_df[0][0])

(16863,)
00000000-0000-0000-0000-000000010910


In [122]:
# lets actually remove them
track_id_arr = []
for i in range(training_samples.shape[0]):
    for j in range(training_samples[i]['track_id'].shape[0]):
        training_samples[i]['track_id'][j] = training_samples[i]['track_id'][j][-8:]
        track_id_arr.append(training_samples[i]['track_id'][j])
track_id_df = pd.DataFrame(track_id_arr)
print(track_id_df[0].unique().shape)
print(track_id_df[0][0])

(16863,)
00010910


In [123]:
print(training_samples[0].keys())

dict_keys(['city', 'lane', 'lane_norm', 'agent_id', 'car_mask', 'p_in', 'v_in', 'track_id'])


In [124]:
print(training_samples[0]['lane'].shape)

(90, 2)


In [125]:
# get only tracked agent info and normalized vector
train_point_arr_x = []
train_point_arr_y = []
for scene_idx, data in enumerate(training_samples):
    idx_track = 0
    for i, j in enumerate(training_samples[scene_idx]['track_id']):
        if training_samples[scene_idx]['agent_id']==j:
            idx_track=i
            break

    tracked_positions = np.array(training_samples[scene_idx]['p_in'][idx_track]-training_samples[scene_idx]['p_in'][idx_track][0])
    tracked_velocities = np.array(training_samples[scene_idx]['v_in'][idx_track])
    tracked_velocity_norms = np.linalg.norm(training_samples[scene_idx]['v_in'][idx_track], axis=1)
    
    train_point = np.append(np.append(tracked_positions.flatten(), tracked_velocities.flatten()), tracked_velocity_norms)
    train_point_arr_x.append(train_point)
    
    # train_point_arr_y.append(np.array(training_samples[scene_idx]['p_out'][idx_track]).flatten())
    
train_point_arr_x = np.array(train_point_arr_x)
train_point_arr_y = np.array(train_point_arr_y)
print(train_point_arr_x.shape)
print(train_point_arr_y.shape)

(3200, 95)
(0,)


In [126]:
# import pickle
# pickle.dump(train_point_arr_x, open("train_point_arr_x.p", "wb"))
# pickle.dump(train_point_arr_y, open("train_point_arr_y.p", "wb"))

In [127]:
import pickle
pickle.dump(train_point_arr_x, open("val_point_arr_x.p", "wb"))

In [128]:
print(train_point_arr_x[0])

[ 0.          0.          0.90710449  0.75946045  1.72692871  1.44744873
  2.54711914  2.13500977  3.36853027  2.8258667   4.27111816  3.57946777
  5.08874512  4.26391602  5.91101074  4.95184326  6.80944824  5.70318604
  7.62976074  6.38909912  8.44958496  7.07421875  9.26794434  7.75793457
 10.08642578  8.44445801 10.98864746  9.19610596 11.80871582  9.88146973
 12.62683105 10.56567383 13.52819824 11.32122803 14.35021973 12.00579834
 15.16748047 12.68920898  8.52231884  7.13423395  9.07077694  7.59421206
  8.19851398  6.88010645  8.20141506  6.87594032  8.21442795  6.9084959
  9.02537918  7.53567648  8.17660046  6.84430504  8.22331619  6.87958813
  8.98385525  7.51321602  8.20329666  6.85914564  8.19821739  6.85116768
  8.18357944  6.83722019  8.18468094  6.86532688  9.0217762   7.51629257
  8.20175266  6.85392761  8.18101501  6.84207535  9.01346779  7.5554204
  8.2196722   6.84594965  8.17323303  6.83389473 11.11427967 11.83009091
 10.70287327 10.70241862 10.73331925 11.75771616 10.6

In [129]:
print(train_point_arr_x[1])

[  0.           0.          -1.01525879  -0.89697266  -2.1776123
  -1.9152832   -3.19616699  -2.77685547  -4.28979492  -3.80169678
  -5.37182617  -4.68011475  -6.40429688  -5.52868652  -7.43225098
  -6.38427734  -8.53405762  -7.36395264  -9.66455078  -8.30609131
 -10.71777344  -9.22680664 -11.81213379 -10.1361084  -12.80065918
 -10.93200684 -13.921875   -11.87414551 -15.02600098 -12.76947021
 -16.05786133 -13.64532471 -17.11706543 -14.53503418 -18.25720215
 -15.4831543  -19.34973145 -16.33978271 -10.03384399  -8.87351036
 -10.15268517  -8.9695797  -11.62311935 -10.18312263 -10.18486691
  -8.61578369 -10.93676281 -10.24828243 -10.82073402  -8.78427601
 -10.32417583  -8.48564148 -10.27919674  -8.55607796 -11.01836777
  -9.79698658 -11.30561352  -9.42138863 -10.53173733  -9.20701981
 -10.94385624  -9.09283924  -9.88493347  -7.95918941 -11.21180439
  -9.42130184 -11.04222202  -8.95307255 -10.31856155  -8.75902939
 -10.59195518  -8.89668083 -11.40059471  -9.48131561 -10.92600346
  -8.566026