In [15]:
from pathlib import Path
import libsumo as ls
import traci
import pandas as pd
import platform
# from time import time
from collections import deque, defaultdict
import os
from tqdm import tqdm
from xml.etree.ElementTree import Element, tostring, SubElement
from xml.dom import minidom
import pickle

In [16]:
from model.batch_predict_action import batch_predict_v2

In [17]:
class DataStructure:
    """Rolling, fixed-length buffer holding one vehicle's recent trajectory.

    Every field is a ``deque`` bounded by ``max_len``: once full, each append
    discards the oldest entry. Fields are read as lists through item access,
    e.g. ``traj["observations"]``.
    """

    # Buffer fields exposed through ``__getitem__`` and cleared by ``reset``.
    _FIELDS = ("observations", "actions", "rewards", "dones", "t_steps", "target_return")

    def __init__(self, max_len: int = 20):
        """
        Args:
            max_len: maximum number of steps retained in each field.
        """
        self.max_len = max_len

        self.observations = deque(maxlen=max_len)
        self.actions = deque(maxlen=max_len)
        self.rewards = deque(maxlen=max_len)
        self.dones = deque(maxlen=max_len)
        self.t_steps = deque(maxlen=max_len)
        self.target_return = deque(maxlen=max_len)

    def add_observation(self, observation):
        """Append one step's observation."""
        self.observations.append(observation)

    def add_action(self, action):
        """Append one step's action."""
        self.actions.append(action)

    def add_reward(self, reward):
        """Append one step's reward."""
        self.rewards.append(reward)

    def add_done(self, done):
        """Append one step's done flag."""
        self.dones.append(done)

    def add_t_step(self, t_step):
        """Append one step's timestep index."""
        self.t_steps.append(t_step)

    def add_target_return(self, target_return):
        """Append one step's target return."""
        self.target_return.append(target_return)

    def reset(self):
        """Empty every buffer field (``max_len`` is kept)."""
        for name in self._FIELDS:
            getattr(self, name).clear()

    def __len__(self):
        """Number of stored steps, measured on the observations field."""
        return len(self.observations)

    def __getitem__(self, key):
        """Return a list copy of the buffer field named `key`.

        Raises:
            KeyError: if `key` is not one of the buffer fields. Methods and
                ``max_len`` are deliberately not reachable this way.
        """
        if key not in self._FIELDS:
            raise KeyError(f"Key '{key}' not found in DataStructure.")
        return list(getattr(self, key))

    def __repr__(self):
        return (f"DataStructure(observations={list(self.observations)}, "
                f"actions={list(self.actions)}, rewards={list(self.rewards)}, "
                f"dones={list(self.dones)}, t_steps={list(self.t_steps)}, "
                f"target_return={list(self.target_return)})")

In [18]:
class car_controller():
    """Run a SUMO scenario and optionally control vehicles with a learned policy.

    ``run_simulation`` builds the SUMO command line and runs a warm-up phase
    with no observation. It then calls ``simulate``. On every step that method
    records, for each vehicle that has a leader, an observation/action/reward
    tuple into a rolling ``DataStructure`` trajectory. Vehicles whose
    trajectory holds ``traj_length`` steps are passed to ``batch_predict_v2``.
    The first element of each prediction is applied with ``setAcceleration``.

    ``warm_up_duration`` and ``sim_duration_steps`` are durations in simulated
    seconds. Both are divided by ``step_duration`` to obtain step counts.
    """

    # Class-wide instance counter; each instance's number is its TraCI label.
    _counter = 0

    def __init__(self,
                 min_gap: float = 1.5,
                 step_duration: float = 0.04,
                 sim_duration_steps: int = 1800,
                 traj_length: int = 20,
                 max_episode_length: int = 4096,
                 warm_up_duration: int = 900,
                 gui: bool = False,
                 begin_time: int = 27900, 
                 auto_start = True,
                 quit_on_exit = True,
                 network_file_path = "./marios_net/athens_net_control.net.xml",
                 route_files = "./marios_net/macro_routes_peak_demand_0.8_1.0_900_5400_0.7.rou.xml",
                 time_to_teleport = 300,
                 scale = 0.7, 
                 ignore_junction_blocker = 60,
                 rerouting_period = 180,
                 routing_algorithm = "astar",
                 rerouting_threads = 4,
                 rerouting_pre_period = 0,
                 rerouting_probability = 1,
                 weights_priority_factor = 2,
                 synchronize_rerouting = True,
                 use_rerouting = False,
                 no_internal_links = True,
                 libsumo: bool = True,
                 seed: int = 0,
                 ignore_route_errors: bool = False,
                 deactivate_lane_change: bool = False,
                 activate_control: bool = True,
                 sumo_output: bool = False,
                 suffix: str = "control",
                 sumo_output_folder: str = "./sumo_results",
                 aggregation_period: int = 60,
                 save_episode_stats: bool = True):
        """Store the configuration. SUMO is not started until ``run_simulation``.

        Most arguments map directly onto SUMO command-line options (see
        ``create_configuration``). The ones used by this class itself are:

        Args:
            min_gap: constant added back to the distance returned by
                ``getLeader`` to form the observed gap (see ``state_follower``).
            step_duration: SUMO step length in seconds (``--step-length``).
            sim_duration_steps: simulated seconds of the observed/controlled phase.
            traj_length: steps kept per vehicle trajectory. A vehicle is sent to
                the policy once its trajectory holds this many steps.
            max_episode_length: a vehicle's trajectory and step counter are
                reset once the counter reaches this value.
            warm_up_duration: simulated seconds run before observation starts.
            libsumo: use the libsumo bindings instead of a TraCI connection.
            seed: SUMO random seed. ``None`` passes ``--random`` instead.
            deactivate_lane_change: apply lane-change mode 0b001000000000 to
                every vehicle. Otherwise only ``lcSpeedGain`` is set to 0.
            activate_control: apply the predicted accelerations.
            sumo_output: write SUMO statistics and edge-data outputs to
                ``sumo_output_folder`` (tagged with ``suffix``).
            save_episode_stats: pickle the aggregated episode rewards to
                ``sumo_output_folder`` after the run.
        """
        self.label = car_controller._counter
        car_controller._counter += 1

        self.op_system = platform.system()
        self.min_gap = min_gap
        self.traj_length = traj_length
        self.max_episode_length = max_episode_length
        self.step_duration = step_duration
        self.activate_control = activate_control
        self.warm_up_duration = warm_up_duration
        self.leader_max_comfortable_decel = 4.5
        self.sim_duration_steps = sim_duration_steps
        self.sumo_running = False
        self.gui = gui
        self.begin_time = begin_time
        self.auto_start = auto_start
        self.quit_on_exit = quit_on_exit
        self.network_file_path = network_file_path
        self.route_files = route_files
        self.time_to_teleport = time_to_teleport
        self.scale = scale
        self.ignore_junction_blocker = ignore_junction_blocker
        self.rerouting_period = rerouting_period
        self.routing_algorithm = routing_algorithm
        self.rerouting_threads = rerouting_threads
        self.rerouting_pre_period = rerouting_pre_period
        self.rerouting_probability = rerouting_probability
        self.weights_priority_factor = weights_priority_factor
        self.synchronize_rerouting = synchronize_rerouting
        self.use_rerouting = use_rerouting
        self.no_internal_links = no_internal_links
        self.libsumo = libsumo
        self.seed = seed
        self.ignore_route_errors = ignore_route_errors
        self.deactivate_lane_change = deactivate_lane_change
        self.sumo_output = sumo_output
        self.suffix = suffix
        self.sumo_output_folder = sumo_output_folder
        self.aggregation_period = aggregation_period
        self.save_episode_stats = save_episode_stats

        if self.sumo_output:
            os.makedirs(self.sumo_output_folder, exist_ok=True)


    def _num_steps(self, duration):
        """Number of simulation steps that cover `duration` simulated seconds.

        Rounds instead of using floor division. A step length that is not
        exactly representable as a float can otherwise floor to one step too
        few (e.g. ``1 // 0.1 == 9.0``).
        """
        return int(round(duration / self.step_duration))


    def run_simulation(self):
        """Configure and start SUMO, warm up, simulate, then clean up.

        SUMO is closed and the temporary edge-data additional file is removed
        even if the simulation raises. If ``save_episode_stats`` is set,
        ``self.episode_stats`` is pickled to
        ``<sumo_output_folder>/episode_stats_<suffix>.pkl``.
        """
        if self.sumo_output:
            self.edge_data_output_file_path = os.path.join(self.sumo_output_folder, f"edge_data_output_{self.suffix}.xml")
            self.edge_data_add_file_path = os.path.join(self.sumo_output_folder, f"edge_data_{self.suffix}.add.xml")

            self.create_edge_data_output_add_xml()

        self.create_configuration()

        self.start_sumo()

        try:
            # Warm-up: advance the simulation without observing or controlling.
            for _ in tqdm(range(self._num_steps(self.warm_up_duration)), desc = "Warm up steps:"):
                self.step_sim()

            self.edge_id_list = self.sumo.edge.getIDList()

            self.simulate()
        finally:
            # Release the simulation even on failure so the kernel is not left
            # holding a running SUMO instance.
            self.close_sumo()

            if self.sumo_output and os.path.exists(self.edge_data_add_file_path):
                os.remove(self.edge_data_add_file_path)

        if self.save_episode_stats:
            # The folder is only created in __init__ when sumo_output is set.
            os.makedirs(self.sumo_output_folder, exist_ok=True)
            with open(os.path.join(self.sumo_output_folder, f"episode_stats_{self.suffix}.pkl"), "wb") as f:
                pickle.dump(self.episode_stats, f)


    def state_neighbour(self,
                        vehicle_ids: list[str]):
        """Neighbourhood statistics for every vehicle in `vehicle_ids`.

        A vehicle's neighbours are all listed vehicles, including itself, whose
        position lies within ``RADIUS_NEIGHBOUR`` (Euclidean distance in SUMO
        position coordinates). The pairing is a pandas cross join, so cost
        grows quadratically with the number of vehicles.

        Returns:
            dict mapping vehicle id -> {"count_radius", "mean_speed_radius",
            "std_speed_radius", "mean_dist_radius", "std_dist_radius"}.
            Standard deviations that pandas leaves undefined (a single
            neighbour) are reported as 0.
        """
        if not vehicle_ids:
            return {}

        RADIUS_NEIGHBOUR = 100

        positions = {vehicle_id: self.sumo.vehicle.getPosition(vehicle_id)
                     for vehicle_id in vehicle_ids}
        speeds = {vehicle_id: self.sumo.vehicle.getSpeed(vehicle_id)
                  for vehicle_id in vehicle_ids}

        df_vehicles = pd.DataFrame.from_dict(positions, orient="index", columns=["x", "y"])
        df_vehicles["speed"] = pd.Series(speeds)

        # All (ego, neighbour) pairs; reset_index exposes the ids as "index".
        df_cross = df_vehicles.reset_index().merge(
            df_vehicles.reset_index(), how="cross", suffixes=("_ego", "_nbr")
        )

        df_cross["distance"] = (
            (df_cross["x_ego"] - df_cross["x_nbr"]) ** 2
            + (df_cross["y_ego"] - df_cross["y_nbr"]) ** 2
        ) ** 0.5

        df_neighbors = df_cross[df_cross["distance"] <= RADIUS_NEIGHBOUR]

        stats = (
            df_neighbors.groupby("index_ego")
            .agg(
                count_radius=("index_nbr", "count"),
                mean_speed_radius=("speed_nbr", "mean"),
                std_speed_radius=("speed_nbr", "std"),
                mean_dist_radius=("distance", "mean"),
                std_dist_radius=("distance", "std"),
            )
            .fillna(0)
        )

        return stats.to_dict("index")


    def state_follower(self,
                       vehicle_ids: list[str]):
        """Car-following state for the vehicles in `vehicle_ids` that have a leader.

        Vehicles with no leader within 200.0 (``getLeader`` returns None) are
        skipped. The keys of the result are therefore the vehicles that can be
        controlled.

        Returns:
            dict mapping vehicle id -> {"gap", "follower_speed", "leader_speed",
            "green_light" (0/1), "local_reward", "follower_accel"}.
        """
        state = {}

        for follower_id in vehicle_ids:

            # Ask for the leader first: leaderless vehicles are skipped, so the
            # other queries would be wasted on them.
            leader = self.sumo.vehicle.getLeader(follower_id, 200.0)
            if leader is None:
                continue

            follower_speed = self.sumo.vehicle.getSpeed(follower_id)
            upcoming_tls = self.sumo.vehicle.getNextTLS(follower_id)

            # Index 3 of the first upcoming signal is its state character.
            green_light = (
                upcoming_tls[0][3] in {"G", "g", "y"} if len(upcoming_tls) != 0
                else False
            )

            follower_accel = self.sumo.vehicle.getAcceleration(follower_id)
            follower_edge = self.sumo.vehicle.getRoadID(follower_id)

            leader_id, gap_minus_min_gap = leader
            leader_edge = self.sumo.vehicle.getRoadID(leader_id)
            gap = gap_minus_min_gap + self.min_gap
            leader_speed = self.sumo.vehicle.getSpeed(leader_id)

            # NOTE(review): `gap` already includes self.min_gap, so min_gap is
            # added twice here, and self.min_gap may differ from the vehicle
            # type's own minGap. Kept unchanged because it defines the reward
            # the policy is conditioned on; confirm against the training setup.
            kraus_follow_speed = self.sumo.vehicle.getFollowSpeed(
                follower_id,
                follower_speed,
                gap + self.min_gap,
                leader_speed,
                self.leader_max_comfortable_decel,
            )
            # The local reward only counts at a green/yellow signal or when the
            # leader is on the same edge; otherwise it is multiplied by False (0).
            leader_in_followers_subregion = follower_edge == leader_edge
            local_reward = self.get_local_reward(follower_speed, kraus_follow_speed) * (
                green_light or leader_in_followers_subregion
            )

            state[follower_id] = {
                "gap": gap,
                "follower_speed": follower_speed,
                "leader_speed": leader_speed,
                "green_light": int(green_light),
                "local_reward": local_reward,
                "follower_accel": follower_accel,
            }

        return state


    def get_flow_reward(self,
        prev_edge_vehicle_count: dict[str, int],
        curr_edge_vehicle_count: dict[str, int],
    ) -> dict[str, int]:
        """Per-edge flow reward: 1 if the edge's vehicle count dropped, else 0.

        Args:
            prev_edge_vehicle_count: edge id -> vehicle count at the earlier step.
                Must contain every edge in `curr_edge_vehicle_count`.
            curr_edge_vehicle_count: edge id -> vehicle count at the current step.
        """
        return {
            edge_id: int(count < prev_edge_vehicle_count[edge_id])
            for edge_id, count in curr_edge_vehicle_count.items()
        }


    def get_local_reward(self,
                         follower_speed, v_safe):
        """Piecewise-linear speed-tracking reward around `v_safe`.

        Returns 0 when ``follower_speed == v_safe``. The value falls linearly to
        -1 at ``(1 - SIGMA) * v_safe`` and at ``(1 + SIGMA) * v_safe``, and is
        -1 everywhere outside that band.
        """
        SIGMA = 0.2

        if follower_speed <= (1 - SIGMA) * v_safe:
            return -1
        if follower_speed <= v_safe:
            return (1 / SIGMA / v_safe) * follower_speed - 1 / SIGMA
        if follower_speed < (1 + SIGMA) * v_safe:
            return -(1 / SIGMA / v_safe) * follower_speed + 1 / SIGMA
        return -1


    def simulate(self):
        """Run the observed (and optionally controlled) simulation phase.

        Each step does the following:
        1. Advance SUMO and drop bookkeeping for vehicles that arrived.
        2. Apply the lane-change setting to newly seen vehicles.
        3. When control or stats are enabled, append one step to the
           trajectory of every vehicle with a leader.
        4. When control is active, apply the policy's accelerations to
           vehicles with full trajectories.

        A vehicle's episode ends when it is seen on a different edge. Its
        accumulated rewards are then added to ``self.episode_stats``. Episodes
        still running when the simulation ends are not included.
        """
        self.episode_stats = {"ep_rew_total": 0,
                              "ep_rew_local": 0,
                              "ep_rew_global": 0,
                              "ep_len": 0,
                              "num_ep": 0}

        # Per-vehicle bookkeeping, keyed by vehicle id.
        self.lane_deact_veh_id_list = set()  # vehicles whose lane-change setting was applied
        self.prev_edge_dict = dict()         # edge at the vehicle's last recorded step
        self.trajectories = dict()           # rolling DataStructure per vehicle
        self.step_counter = dict()           # steps recorded in the current episode
        self.stat_recorder = defaultdict(lambda: defaultdict(int))  # running episode totals

        # Vehicle count per edge at the previous step, used by the flow reward.
        prev_edge_vehicle_count = self._edge_vehicle_counts()

        for _ in tqdm(range(self._num_steps(self.sim_duration_steps)), desc = "Simulation steps:"):

            self.step_sim()
            # Forget vehicles that left the network so per-vehicle dicts stay bounded.
            self.update()

            # all vehicles currently in the simulation
            self.vehicle_id_list = list(self.sumo.vehicle.getIDList())
            self._configure_new_vehicles(self.vehicle_id_list)

            if not self.activate_control and not self.save_episode_stats:
                continue

            veh_state_neighbour = self.state_neighbour(self.vehicle_id_list)
            # keys are the vehicles that have a leader, i.e. control candidates
            veh_state_follower = self.state_follower(self.vehicle_id_list)

            curr_edge_vehicle_count = self._edge_vehicle_counts()
            flow_rewards = self.get_flow_reward(prev_edge_vehicle_count, curr_edge_vehicle_count)
            # Roll forward so each flow reward compares consecutive steps.
            prev_edge_vehicle_count = curr_edge_vehicle_count

            batch = []
            for vehicle_id, follower_state in veh_state_follower.items():
                if self._record_step(vehicle_id, follower_state, veh_state_neighbour, flow_rewards):
                    batch.append(vehicle_id)

            if batch and self.activate_control:
                self._apply_control(batch)


    def _edge_vehicle_counts(self):
        """Last-step vehicle count for every edge in ``self.edge_id_list``."""
        return {
            edge_id: self.sumo.edge.getLastStepVehicleNumber(edge_id)
            for edge_id in self.edge_id_list
        }


    def _configure_new_vehicles(self, vehicle_ids):
        """Apply the lane-change setting once to each vehicle not seen before."""
        for vehicle_id in vehicle_ids:
            if vehicle_id in self.lane_deact_veh_id_list:
                continue
            self.lane_deact_veh_id_list.add(vehicle_id)
            if self.deactivate_lane_change:
                self.sumo.vehicle.setLaneChangeMode(vehicle_id, 0b001000000000) # TEST BEHAVIOR
            else:
                self.sumo.vehicle.setParameter(vehicle_id, "laneChangeModel.lcSpeedGain", str(0))


    def _record_step(self, vehicle_id, follower_state, veh_state_neighbour, flow_rewards):
        """Append the current step to `vehicle_id`'s trajectory.

        Returns:
            True if the trajectory now holds ``traj_length`` steps, meaning the
            vehicle should be sent to the policy.
        """
        # Truncate over-long episodes: clear the buffer and skip this step.
        if vehicle_id in self.step_counter and \
                self.step_counter[vehicle_id] >= self.max_episode_length:
            self.trajectories[vehicle_id].reset()
            self.step_counter[vehicle_id] = 0
            return False

        current_edge = self.sumo.vehicle.getRoadID(vehicle_id)

        # An episode ends when the vehicle is seen on a different edge.
        previous_edge = self.prev_edge_dict.get(vehicle_id)
        end_episode = previous_edge is not None and previous_edge != current_edge
        self.prev_edge_dict[vehicle_id] = current_edge

        if vehicle_id not in self.trajectories:
            self.trajectories[vehicle_id] = DataStructure(self.traj_length)
            self.step_counter[vehicle_id] = 0

        if end_episode:
            self.trajectories[vehicle_id].reset()
            self.step_counter[vehicle_id] = 0
            self._finish_episode(vehicle_id)

        neighbour_state = veh_state_neighbour[vehicle_id]
        step_green_light = follower_state["green_light"]

        # step reward: local speed-tracking term plus flow term at a green light
        step_local_reward = follower_state["local_reward"]
        step_global_reward = step_green_light * flow_rewards[current_edge]
        step_reward = step_local_reward + step_global_reward

        recorder = self.stat_recorder[vehicle_id]
        recorder["rewards"] += step_reward
        recorder["local_rewards"] += step_local_reward
        recorder["global_rewards"] += step_global_reward
        recorder["ep_len"] += 1

        trajectory = self.trajectories[vehicle_id]
        trajectory.add_observation(
            [
                follower_state["gap"],
                follower_state["follower_speed"],
                follower_state["leader_speed"],
                neighbour_state["count_radius"],
                neighbour_state["mean_speed_radius"],
                neighbour_state["std_speed_radius"],
                neighbour_state["mean_dist_radius"],
                neighbour_state["std_dist_radius"],
                step_green_light,
            ]
        )
        # The acceleration the vehicle actually had is logged as its action.
        trajectory.add_action([follower_state["follower_accel"]])
        trajectory.add_done(False)
        trajectory.add_reward(step_reward)

        # Target return starts at 0 each episode and is decreased by each reward.
        if self.step_counter[vehicle_id] == 0:
            trajectory.add_target_return(0)
        else:
            trajectory.add_target_return(trajectory.target_return[-1] - step_reward)

        trajectory.add_t_step(self.step_counter[vehicle_id])
        self.step_counter[vehicle_id] += 1

        # Only vehicles with a full context window are controlled.
        return len(trajectory) == self.traj_length


    def _apply_control(self, batch):
        """Predict actions for the vehicles in `batch` and apply them for one step."""
        batch_preds = batch_predict_v2([self.trajectories[vehicle_id] for vehicle_id in batch], 0)
        for vehicle_id, prediction in zip(batch, batch_preds):
            # The first element of each prediction is used as the acceleration.
            self.sumo.vehicle.setAcceleration(vehicle_id, float(prediction[0]), self.step_duration)


    def _finish_episode(self, vehicle_id):
        """Move `vehicle_id`'s running episode totals into ``episode_stats``."""
        finished = self.stat_recorder.pop(vehicle_id, None)
        if finished is not None:
            self.update_stats(finished)


    def update_stats(self, dict_):
        """Add one finished episode's totals to ``self.episode_stats``.

        Args:
            dict_: mapping with "rewards", "local_rewards", "global_rewards"
                and "ep_len".
        """
        self.episode_stats["ep_rew_total"] += dict_["rewards"]
        self.episode_stats["ep_rew_local"] += dict_["local_rewards"]
        self.episode_stats["ep_rew_global"] += dict_["global_rewards"]
        self.episode_stats["ep_len"] += dict_["ep_len"]
        self.episode_stats["num_ep"] += 1


    def update(self):
        """Drop per-vehicle bookkeeping for ids in ``simulation.getArrivedIDList()``.

        An arrived vehicle may be missing from some containers, for example
        when it never had a leader and so never got a trajectory. Removals
        therefore tolerate absence. Unfinished episode totals of arrived
        vehicles are discarded.
        """
        for veh_id in self.sumo.simulation.getArrivedIDList():
            self.lane_deact_veh_id_list.discard(veh_id)
            self.prev_edge_dict.pop(veh_id, None)
            self.trajectories.pop(veh_id, None)
            self.step_counter.pop(veh_id, None)
            self.stat_recorder.pop(veh_id, None)


    def prettify(self, elem):
        """Return `elem` serialized as indented XML text."""
        rough_string = tostring(elem, 'utf-8')
        return minidom.parseString(rough_string).toprettyxml(indent="  ")


    def create_edge_data_output_add_xml(self):
        """Write the SUMO additional file that enables edgeData output.

        Data is aggregated every ``aggregation_period`` and written to
        ``self.edge_data_output_file_path``.
        """
        root = Element("additional")

        SubElement(root, "edgeData",
                   id = "edge_data_mfd",
                   file = os.path.abspath(self.edge_data_output_file_path),
                   period = str(self.aggregation_period),
                   excludeEmpty = "defaults")

        with open(self.edge_data_add_file_path, "w") as myfile:
            myfile.write(self.prettify(root))


    def step_sim(self):
        """Advance the simulation by one step."""
        self.sumo.simulationStep()


    def start_sumo(self):
        """Start SUMO with ``self.sumo_cmd``; store the API handle in ``self.sumo``."""
        if self.libsumo:
            ls.start(self.sumo_cmd)
            self.sumo = ls
        else:
            traci.start(self.sumo_cmd, label = self.label)
            self.sumo = traci.getConnection(self.label)
        self.sumo_running = True


    def close_sumo(self):
        """Close the running SUMO instance, if any."""
        if self.sumo_running:
            if self.libsumo:
                ls.close()
            else:
                traci.switch(self.label)
                traci.close(False)
            self.sumo_running = False


    def create_configuration(self):
        """Build the SUMO command line into ``self.sumo_cmd``.

        When ``sumo_output`` is set this reads ``self.edge_data_add_file_path``,
        which ``run_simulation`` assigns beforehand.
        """
        self.sumo_cmd = [self.create_sumobinary()]

        if self.auto_start:
            self.sumo_cmd += ["-S"]

        if self.quit_on_exit:
            self.sumo_cmd += ["-Q"]

        self.sumo_cmd += ["-n", self.network_file_path]

        if isinstance(self.route_files, str):
            self.sumo_cmd += ["-r", self.route_files]
        else:
            self.sumo_cmd += ["-r", ", ".join(self.route_files)]

        self.sumo_cmd += ["-b", str(self.begin_time),
                          "--time-to-teleport", str(self.time_to_teleport),
                          '--no-warnings', 'True',
                          "--no-step-log", 'True',
                          "--scale", str(self.scale),
                          "--collision.action", "none",
                          "--ignore-junction-blocker", str(self.ignore_junction_blocker),
                          "--step-length", str(self.step_duration)]

        if self.use_rerouting:
            self.sumo_cmd += ["--device.rerouting.period", str(self.rerouting_period),
                              "--routing-algorithm", self.routing_algorithm,
                              "--device.rerouting.threads", str(self.rerouting_threads),
                              "--device.rerouting.pre-period", str(self.rerouting_pre_period),
                              "--device.rerouting.probability", str(self.rerouting_probability),
                              "--weights.priority-factor", str(self.weights_priority_factor)]

        if self.sumo_output:
            # NOTE(review): the tripinfo-output.* flags are passed without a
            # --tripinfo-output file; confirm tripinfo output is wanted.
            self.sumo_cmd += ["--tripinfo-output.write-unfinished", "true",
                              "--tripinfo-output.write-undeparted", "true",
                              "--statistic-output", os.path.join(self.sumo_output_folder, f"statistics_{self.suffix}.xml"),
                              "--duration-log.statistics", "true"]

            self.sumo_cmd += ["-a", self.edge_data_add_file_path]

        if self.synchronize_rerouting:
            self.sumo_cmd += ["--device.rerouting.synchronize", "true"]

        if self.ignore_route_errors:
            self.sumo_cmd += ["--ignore-route-errors", "true"]

        if self.no_internal_links:
            self.sumo_cmd += ["--no-internal-links", "true"]

        if self.seed is None:
            self.sumo_cmd += ["--random"]
        else:
            self.sumo_cmd += ["--seed", str(self.seed)]


    def create_sumobinary(self):
        """Path of the SUMO executable (``sumo-gui`` when ``self.gui`` is set).

        The ``.exe`` suffix is only added on Windows.
        """
        if self.op_system == "Windows":
            sumo_location = Path(r'C:\Program Files (x86)') / 'Eclipse' / 'Sumo' / 'bin'
            exe_suffix = '.exe'
        else:
            sumo_location = Path('/usr/bin/')
            exe_suffix = ''

        binary_name = ('sumo-gui' if self.gui else 'sumo') + exe_suffix
        return str(sumo_location / binary_name)

In [19]:
# Controlled run: 1 s steps, 900 s of warm-up, then 1800 s (1800 steps) of
# observed/controlled simulation via libsumo, with SUMO outputs tagged "control".
env = car_controller(
    suffix="control",
    activate_control=True,
    libsumo=True,
    gui=False,
    sumo_output=True,
    step_duration=1,
    warm_up_duration=900,
    sim_duration_steps=1800,
)

In [20]:
# Warm up, run the controlled simulation, close SUMO, then pickle the episode stats.
env.run_simulation()

Loading net-file from './marios_net/athens_net_control.net.xml' ... done (88ms).
Loading additional-files from './sumo_results/edge_data_control.add.xml' ... done (1ms).
Loading done.
Simulation version 1.20.0 started via libsumo with time: 27900.00.


Warm up steps:: 100%|██████████| 900/900 [00:03<00:00, 276.65it/s] 
Simulation steps::   7%|▋         | 123/1800 [00:56<14:37,  1.91it/s]