In [1]:
%reload_ext autoreload
%autoreload 2

import metadrive
from metadrive.envs.base_env import BaseEnv
from metadrive.constants import DEFAULT_AGENT
from metadrive.utils import Config
from metadrive.manager.scenario_map_manager import ScenarioMapManager
from metadrive.manager.base_manager import BaseManager
from metadrive.scenario.scenario_description import ScenarioDescription
from metadrive.manager.scenario_light_manager import ScenarioLightManager
from metadrive.manager.scenario_traffic_manager import ScenarioTrafficManager
from metadrive.component.vehicle_navigation_module.trajectory_navigation import TrajectoryNavigation
from metadrive.scenario import ScenarioDescription as SD
from metadrive.policy.replay_policy import ReplayEgoCarPolicy


In [12]:
import utils.scenario_converter as scenario_converter
import utils.waymo_loader as waymo_loader

In [13]:
import os
import pickle

path = "~/data/waymo/training_20s.tfrecord-00000-of-01000"
path = os.path.expanduser(path)

print("Loading scenarios from Waymo dataset...")
raw_scenarios = waymo_loader.extract_scenarios_file(path)
print(f"Loaded {len(raw_scenarios)} scenarios")

Loading scenarios from Waymo dataset...


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Loaded 61 scenarios


In [31]:
data = scenario_converter.convert_scenario(raw_scenarios[2])

In [16]:
class InMemoryScenarioDataManager(BaseManager):
    def __init__(self, scenario):
        super().__init__()
        self.current_scenario = scenario
        self.current_scenario_length = scenario[SD.LENGTH]
    
    def get_scenario(self, i, should_copy=False):
        return self.current_scenario

class ReplayEnv(BaseEnv):
    @classmethod
    def default_config(cls) -> Config:
        config = super().default_config()
        config.update({
            "start_seed": 0,
            "num_scenarios": 1,
            "start_scenario_index": 0,
            "no_map": False,
            "store_map": False,
            "need_lane_localization": True,
            "vehicle_config": dict(
                lidar=dict(num_lasers=120, distance=50),
                lane_line_detector=dict(num_lasers=0, distance=50),
                side_detector=dict(num_lasers=12, distance=50),
                show_dest_mark=True,
                navigation_module=TrajectoryNavigation,
            ),
            "max_lateral_dist": 4,
            # whether or not to base vehicle class purely on size or whether to evenly sample from all vehicle classes
            "even_sample_vehicle_class": False,
            # do show traffic lights
            "no_light": False,
            "skip_missing_light": False,
            "static_traffic_object": True,
            "no_static_vehicles": False,
            # if true, then any vehicle that is overlapping with another vehicle will be filtered 
            "filter_overlapping_car": False,
            # whether to use the default vehicle model
            "default_vehicle_in_traffic":False,
            "reactive_traffic": False,
        })
        return config

    def __init__(self, config, scenario):
        super().__init__(config)
        self.scenario = scenario

    def done_function(self, vehicle_id: str):
        return False, {}

    def cost_function(self, vehicle_id: str):
        return 0, {}
    
    def reward_function(self, vehicle_id: str):
        return 0, {}

    def _get_observations(self):
        return {DEFAULT_AGENT: self.get_single_observation()}
    
    def setup_engine(self):
        self.engine.register_manager("agent_manager", self.agent_manager)
        self.engine.register_manager("map_manager", ScenarioMapManager())
        self.engine.register_manager("scenario_traffic_manager", ScenarioTrafficManager())
        self.engine.register_manager("scenario_light_manager", ScenarioLightManager()) 
        self.engine.register_manager("data_manager", InMemoryScenarioDataManager(self.scenario))


In [34]:
env = ReplayEnv(config={"use_render": True, "manual_control": False, "agent_policy": ReplayEgoCarPolicy}, scenario=ScenarioDescription(data))

In [35]:
env.reset()
for i in range(1000):
    env.step([0, 1])

Known pipe types:
  glxGraphicsPipe
(1 aux display modules not yet loaded.)
show_dest_mark and show_line_to_dest are not supported in TrajectoryNavigation


97
{'64': neighbor_lanes(lane=<metadrive.component.lane.scenario_lane.ScenarioLane object at 0x7fe168fd41d0>, entry_lanes=None, exit_lanes=None, left_lanes=None, right_lanes=None), '65': neighbor_lanes(lane=<metadrive.component.lane.scenario_lane.ScenarioLane object at 0x7fe17122cb50>, entry_lanes=None, exit_lanes=None, left_lanes=None, right_lanes=None), '66': neighbor_lanes(lane=<metadrive.component.lane.scenario_lane.ScenarioLane object at 0x7fe17122d250>, entry_lanes=None, exit_lanes=None, left_lanes=None, right_lanes=None), '67': neighbor_lanes(lane=<metadrive.component.lane.scenario_lane.ScenarioLane object at 0x7fe17122db50>, entry_lanes=None, exit_lanes=None, left_lanes=None, right_lanes=None), '68': neighbor_lanes(lane=<metadrive.component.lane.scenario_lane.ScenarioLane object at 0x7fe17122e1d0>, entry_lanes=None, exit_lanes=None, left_lanes=None, right_lanes=None), '69': neighbor_lanes(lane=<metadrive.component.lane.scenario_lane.ScenarioLane object at 0x7fe1699e3450>, entry

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [36]:
env.close()