In [1]:
import sys
sys.path.append('..')
import numpy as np
import pandas as pd
from logging import error, warning
from gsw.GSWriter import GSWriter
from constants import CONSTANTS as CONST


# InteractionDatasetLoader
## Task breakdown
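- Load a recording's track CSV from the dataset and extract a single scenario (`case_id`).
- Convert that scenario into a GS scenario file (`.osm`) starting at a chosen simulation time.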


In [2]:
# Dataset location / version / recordings to use, all taken from the project constants.
dataset_config = {
    "base_dir": CONST.DATASET_BASE_DIR,
    "version": CONST.DATASET_VERSION,
    "recording_map": CONST.RECORDING_MAP,
}

In [3]:
class InteractionDataset():
    """Resolve file paths for, and load, recordings of the dataset.

    Args:
        config: dict with at least ``"base_dir"`` (path prefix; paths are
            built by plain string concatenation, so it should end with a
            separator) and ``"recording_map"`` (a mapping whose *values*
            are recording names).
        train: if True, track CSVs are taken from the ``train/`` split,
            otherwise from the ``val/`` split.
    """

    def __init__(self, config, train=True):
        self.config = config
        self.train = train
        # The split only affects the track-file path (see
        # get_track_directory), so train and val share a single setup path.
        self.base_dir = config["base_dir"]
        self.recording_map = config["recording_map"]
        self.csv_dir_list = [self.get_track_directory(recording)
                             for recording in self.recording_map.values()]
        self.osm_dir_list = [self.get_map_directory(recording)
                             for recording in self.recording_map.values()]

    def get_track_directory(self, recording):
        """Return the track CSV path of ``recording`` for the active split."""
        if self.train:
            return self.base_dir + 'train/' + recording + '_train.csv'
        return self.base_dir + 'val/' + recording + '_val.csv'

    def get_map_directory(self, recording):
        """Return the ``.osm`` map path of ``recording``."""
        return self.base_dir + 'maps/' + recording + '.osm'

    def get_recording_df(self, recording):
        """Read the track CSV of ``recording`` into a DataFrame.

        Raises:
            AssertionError: if ``recording`` is not a value of ``recording_map``.
        """
        assert recording in self.recording_map.values(), \
            f'unknown recording {recording!r}'
        return pd.read_csv(self.get_track_directory(recording))

    def get_scenario_df_of_recording(self, scenario_id, recording_df):
        """Return the rows of ``recording_df`` whose ``case_id`` equals ``scenario_id``.

        Raises:
            AssertionError: if no row of ``recording_df`` has that ``case_id``.
        """
        assert scenario_id in recording_df['case_id'].unique(), \
            f'case_id {scenario_id!r} not found in recording'
        return recording_df.loc[recording_df['case_id'] == scenario_id]

In [4]:
class ScenarioData:
    """Per-vehicle view of the rows of a single scenario (one ``case_id``).

    Attributes are computed at construction, in this order because later
    ones depend on earlier ones:
        agent_num: number of distinct ``track_id`` values.
        duration: largest ``timestamp_ms`` in the scenario.
        vehicle_df_dict: ``{track_id: rows of that track}``.
        inclusive_timestamps: ``[latest first timestamp, earliest last
            timestamp]`` over all tracks, i.e. the window that lies inside
            every vehicle's recorded time span.
    """

    def __init__(self, scenario_df) -> None:
        self.scenario_df = scenario_df
        self.agent_num = self.get_num_of_agent()
        self.duration = self.get_total_time()
        self.vehicle_df_dict = self.get_vehicle_df_dict()
        self.inclusive_timestamps = self.get_inclusive_timestamps()

    def get_num_of_agent(self):
        """Return the number of distinct agents (``track_id`` values)."""
        # Count distinct ids instead of taking the largest id, so gaps in the
        # track numbering do not inflate the count.
        return self.scenario_df['track_id'].nunique()

    def get_inclusive_timestamps(self):
        """Return ``[start, end]`` of the window covered by every vehicle.

        ``start`` is the latest first timestamp over all vehicles (floored
        at 0) and ``end`` the earliest last timestamp (capped at
        ``self.duration``).
        """
        min_timestamp = 0
        max_timestamp = self.duration
        for vehicle_df in self.vehicle_df_dict.values():
            min_timestamp = max(min_timestamp, vehicle_df['timestamp_ms'].min())
            max_timestamp = min(max_timestamp, vehicle_df['timestamp_ms'].max())
        return [min_timestamp, max_timestamp]

    def get_total_time(self):
        """Return the largest ``timestamp_ms`` in the scenario."""
        return max(self.scenario_df['timestamp_ms'])

    def get_vehicle_df_dict(self):
        """Split the scenario into one DataFrame per vehicle, keyed by track id.

        Only track ids that actually occur are used, so non-contiguous
        numbering never yields empty per-vehicle frames.
        """
        # Plain-int keys keep vehicle names ('v1', ...) and ids unchanged.
        track_ids = sorted(int(tid) for tid in self.scenario_df['track_id'].unique())
        return {tid: self.scenario_df.loc[self.scenario_df['track_id'] == tid]
                for tid in track_ids}

    @staticmethod
    def _state_at(vehicle_df, timestamp_ms):
        """Return the vehicle's first row at ``timestamp_ms`` as a {column: value} dict, or None."""
        rows = vehicle_df.loc[vehicle_df['timestamp_ms'] == timestamp_ms]
        if rows.empty:
            return None
        # .tolist() yields plain Python scalars, as the positional version did.
        return dict(zip(rows.columns, rows.values[0].tolist()))

    def construct_gs_file(self, file_name, start_sim_time_ms):
        """Write a GS scenario file replaying this scenario from ``start_sim_time_ms``.

        Each vehicle gets a two-point route from its position at
        ``start_sim_time_ms`` to its position at the end of the inclusive
        window (``self.inclusive_timestamps[1]``); the scenario timeout is
        the time between the two.

        Args:
            file_name: output path passed to ``GSWriter.writeOSM``.
            start_sim_time_ms: start time; a vehicle is only included when it
                has rows at exactly this timestamp and at the window end.

        Returns:
            None. When the start time is at or after the end of the inclusive
            window, an error is logged and no file is written.
        """
        scenario_name = 'default scenario name'
        gsw = GSWriter()
        window_start_ms, window_end_ms = self.inclusive_timestamps
        if start_sim_time_ms < window_start_ms:
            warning('One or more vehicles are out of recording range under the starting simulation time.')
        if start_sim_time_ms >= window_end_ms:
            error('Starting simulation time is out of recording range.')
            return None

        ending_sim_time_ms = window_end_ms
        timeout_ms = window_end_ms - start_sim_time_ms

        gsw.addGlobalConfig(CONST.GC_ICON_LAT,  # global configuration icon position
                            CONST.GC_ICON_LON,
                            scenario_name,
                            CONST.MAP_DIR,  # lanelet map directory
                            CONST.COLLISION,  # collision
                            timeout_ms / 1000.0)  # ms -> s
        for vid, vehicle_df in self.vehicle_df_dict.items():
            starting_state = self._state_at(vehicle_df, start_sim_time_ms)
            ending_state = self._state_at(vehicle_df, ending_sim_time_ms)
            if starting_state is None or ending_state is None:
                # Skip instead of crashing on a missing row (e.g. a vehicle
                # that appears only after the start time).
                warning('Vehicle %s has no recorded state at %s ms and/or %s ms; skipping it.',
                        vid, start_sim_time_ms, ending_sim_time_ms)
                continue
            x_0, y_0 = starting_state['x'], starting_state['y']
            x_f, y_f = ending_state['x'], ending_state['y']
            print('x_0, y_0, x_f, y_f: ', x_0, y_0, x_f, y_f)
            vx_0, vy_0 = starting_state['vx'], starting_state['vy']
            # The altitude returned by m2ll is ignored.
            lat_0, lon_0, _ = gsw.m2ll(x_0, y_0)
            lat_f, lon_f, _ = gsw.m2ll(x_f, y_f)
            # NOTE(review): sign flip presumably maps psi_rad to GS's yaw convention — confirm.
            yaw_0_deg = -1 * np.rad2deg(starting_state['psi_rad'])
            gsw.addVehicle(
                vehicle_name='v' + str(vid),
                route_name='v' + str(vid) + '_route',
                behavior_type='SDV',
                starting_yaw_deg=yaw_0_deg,
                starting_vx=vx_0,
                starting_vy=vy_0,
                starting_ax=0,
                starting_ay=0,
                trajectory_lat=[lat_0, lat_f],
                trajectory_lon=[lon_0, lon_f],
                icon_lat=lat_0,
                icon_lon=lon_0,
                vehicle_id=vid,
                behavior_tree_dir='drive.btree')

        gsw.writeOSM(file_name)

In [5]:
# Load the configured validation recording, select one scenario and export it as a GS file.
ID = InteractionDataset(dataset_config, train=False)
recording_df = ID.get_recording_df(CONST.RECORDING)
scenario_df = ID.get_scenario_df_of_recording(CONST.CASE_ID, recording_df)
SD = ScenarioData(scenario_df)

starting_sim_time_ms = CONST.STARTING_SIM_TIME_MS
# Output file name: <RECORDING>_CASE_<case id>_START_TIME_<start ms>.osm
saving_dir = f"{CONST.RECORDING}_CASE_{CONST.CASE_ID}_START_TIME_{starting_sim_time_ms}.osm"
SD.construct_gs_file(saving_dir, starting_sim_time_ms)

x_0, y_0, x_f, y_f:  912.731 1006.49 892.346 1008.363
x_0, y_0, x_f, y_f:  932.067 1004.809 916.748 1006.182
x_0, y_0, x_f, y_f:  946.288 1005.504 935.162 1005.111
x_0, y_0, x_f, y_f:  982.462 1005.698 962.737 1005.263
x_0, y_0, x_f, y_f:  1011.525 1012.516 1000.874 1005.525


In [6]:
# [latest first timestamp, earliest last timestamp] (ms) over all vehicles of the scenario.
SD.inclusive_timestamps

[100, 2600]

In [7]:
# Rows of the selected scenario (all vehicles, all frames).
scenario_df

Unnamed: 0,case_id,track_id,frame_id,timestamp_ms,agent_type,x,y,vx,vy,psi_rad,length,width
14606,70.0,1,1,100,car,912.731,1006.490,-7.731,0.718,3.049,4.16,1.67
14607,70.0,2,1,100,car,932.067,1004.809,-5.339,0.227,3.099,4.34,1.74
14608,70.0,3,1,100,car,946.288,1005.504,-4.251,-0.300,-3.071,3.78,1.68
14609,70.0,4,1,100,car,982.462,1005.698,-7.796,-0.027,-3.138,4.78,1.94
14610,70.0,5,1,100,car,1011.525,1012.516,-2.324,-3.905,-2.108,4.04,1.71
...,...,...,...,...,...,...,...,...,...,...,...,...
14787,70.0,5,39,3900,car,991.876,1005.292,-7.694,0.108,3.128,4.04,1.71
14788,70.0,2,40,4000,car,906.579,1007.025,-7.596,0.646,3.057,4.34,1.74
14789,70.0,3,40,4000,car,928.536,1005.195,-5.038,0.153,3.111,3.78,1.68
14790,70.0,4,40,4000,car,952.019,1004.803,-7.592,-0.283,-3.104,4.78,1.94
