In [1]:
# os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

from collections import deque
import random
# random_seed = 0
# random.seed(random_seed)

import os
import sys
from pathlib import Path
if 'SUMO_HOME' in os.environ:
    tools = os.path.join(os.environ['SUMO_HOME'], 'tools')
    sys.path.append(tools)
else:
    sys.exit("Please declare the environment variable 'SUMO_HOME'")
import traci
import sumolib
from ray.rllib.env.multi_agent_env import MultiAgentEnv
import numpy as np
import pandas as pd
# from traffic_signal_Onnut import TrafficSignal

# if 'SUMO_HOME' in os.environ:
#     tools = os.path.join(os.environ['SUMO_HOME'], 'tools')
#     sys.path.append(tools)
# else:
#     sys.exit("Please declare the environment variable 'SUMO_HOME'")
from gym import spaces
import gym
gym.logger.set_level(40)

import torch
from torch import nn
# from dqn import Network


# from buffer import Buffer
# from madqn import maDQN
# from env_Onnut import SumoEnvironment
# import gym
import time

class Buffer:
    """
    Per-agent experience replay storage.

    One bounded FIFO queue is kept for every agent. Despite its name,
    ``n_agents`` is the iterable of agent identifiers (it is iterated over
    and passed to ``len``), not a count.
    """

    def __init__(self,n_agents,buffer_size,batch_size):
        """
        :param n_agents: sized iterable of agent identifiers
        :param buffer_size: maximum number of transitions kept per agent
        :param batch_size: number of transitions drawn per agent by ``sample``
        """
        self.n_agents = n_agents
        self.batch_size = batch_size
        # One queue per agent, in the iteration order of ``n_agents``.
        self.replay_buffers = [deque(maxlen=buffer_size) for _ in n_agents]

    def store(self,transition):
        """
        Append one transition to every agent's queue.

        ``transition`` is a mapping keyed by ``(agent_id, field)`` where field
        is one of 'obs', 'actions', 'rewards', 'dones', 'new_obs'. Missing
        keys are stored as ``None``.
        """
        fields = ('obs', 'actions', 'rewards', 'dones', 'new_obs')
        for slot, agent_id in enumerate(self.n_agents):
            record = tuple(transition.get((agent_id, field)) for field in fields)
            self.replay_buffers[slot].append(record)

    def sample(self):
        """
        Draw ``batch_size`` transitions (without replacement) for each agent.

        :return: list with one list of transitions per agent
        :raises ValueError: from ``random.sample`` when a queue holds fewer
            than ``batch_size`` transitions
        """
        return [
            random.sample(self.replay_buffers[slot], self.batch_size)
            for slot in range(len(self.n_agents))
        ]


class SumoEnvironment(MultiAgentEnv):
    """
    SUMO Environment for Traffic Signal Control of two hard-coded junctions
    (the ONNUT cluster junction and the VIRTUAL junction ``gneJ42``).

    :param net_file: (str) SUMO .net.xml file
    :param out_csv_name: (str) prefix of the .csv output with simulation results. If None no output is generated
    :param use_gui: (bool) Whether to run SUMO simulation with GUI visualisation
    :param num_seconds: (int) Simulation time (as reported by SUMO) after which ``step`` reports the episode as done
    :param time_to_teleport: (int) Value passed to SUMO's ``--time-to-teleport`` option
    :param delta_time: (int) Simulation seconds between actions
    :param yellow_time: (int) Yellow time handed to every TrafficSignal
    :param min_green: (int) Minimum green time in a phase
    :param max_green_onnut: (int) Max green time in a phase at the ONNUT junction
    :param max_green_virtual: (int) Max green time in a phase at the VIRTUAL junction
    :param single_agent: (bool) If true, observations/rewards/actions of the first traffic signal only are exchanged. Else, it behaves like a MultiagentEnv (https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/multi_agent_env.py)
    """

    def __init__(self, net_file, out_csv_name=None, use_gui=False, num_seconds=68400,
                 time_to_teleport=900, delta_time=15, yellow_time=0, min_green=15
                 , max_green_onnut=135, max_green_virtual=30, single_agent=False):

        self._net = net_file
        self.use_gui = use_gui
        self._sumo_binary = sumolib.checkBinary('sumo-gui' if self.use_gui else 'sumo')

        self.sim_max_time = num_seconds
        self.delta_time = delta_time  # seconds on sumo at each step
        self.begin_time = 54000  # first action time handed to every TrafficSignal
        self.time_to_teleport = time_to_teleport
        self.min_green = min_green
        self.max_green_onnut = max_green_onnut
        self.max_green_virtual = max_green_virtual
        self.max_green =  {'cluster_1088409501_272206263_5136790697_70702637':max_green_onnut,'gneJ42':max_green_virtual}
        self.yellow_time = yellow_time

        self.single_agent = single_agent
        self.ts_ids = ['cluster_1088409501_272206263_5136790697_70702637','gneJ42']
        self.ts_junction = {'cluster_1088409501_272206263_5136790697_70702637':'ONNUT','gneJ42':'VIRTUAL'}

        self.observations = {ts: None for ts in self.ts_ids}
        self.rewards = {ts: None for ts in self.ts_ids}
        self.teleport_numbers = 0
        self.reward_range = (-float('inf'), float('inf'))
        self.run = 0
        self.metrics = []
        self.out_csv_name = out_csv_name

        # Start SUMO only to retrieve the traffic-light programs. The
        # connection is closed even when building a TrafficSignal fails,
        # otherwise the dangling connection blocks every later traci.start().
        traci.start([sumolib.checkBinary('sumo'), '-n', self._net])
        try:
            self.traffic_signals = self._build_traffic_signals()
        finally:
            traci.close()

    def _build_traffic_signals(self):
        """Create one TrafficSignal per controlled junction (requires an open TraCI connection)."""
        return {ts: TrafficSignal(self,
                                  ts,
                                  self.delta_time,
                                  self.yellow_time,
                                  self.min_green,
                                  self.max_green[ts],
                                  self.begin_time,
                                  self.ts_junction[ts]) for ts in self.ts_ids}

    def save_score(self):
        """Write the collected metrics for the current run and advance the run counter."""
        self.save_csv(self.out_csv_name, self.run)
        self.run += 1

    def save_score_max_green(self):
        """Like ``save_score`` but the file name also carries both max-green settings."""
        self.save_csv_max_green(self.out_csv_name, self.run, self.max_green_onnut, self.max_green_virtual)
        self.run += 1

    def reset(self,random_seed):
        """
        Start a fresh SUMO simulation seeded with ``random_seed``.

        The previous connection is closed only when ``self.run != 0``, i.e.
        after ``save_score``/``save_score_max_green`` has been called at
        least once.

        :return: the observation dict, or the first signal's observation in
            single-agent mode
        """
        if self.run != 0:
            traci.close()
        self.metrics = []

        traci.start([self._sumo_binary,
                     '-n', self._net,
                     '-c', "onnut_ake.sumocfg",
                     '--time-to-teleport', str(self.time_to_teleport),
                     '--start', 'true',
                     '--quit-on-end','true',
                     "--no-internal-links",'true',
                     "--ignore-junction-blocker",'-1',
                     '--no-warnings', 'true',
                     '--seed', str(random_seed),
                     ])

        self.traffic_signals = self._build_traffic_signals()

        observations = self._compute_observations()
        if self.single_agent:
            return observations[self.ts_ids[0]]
        return observations

    @property
    def sim_step(self):
        """
        Return current simulation second on SUMO
        """
        return traci.simulation.getTime()

    def step(self, action):
        """
        Advance the simulation until at least one signal is due to act again.

        With ``action`` ``None`` or ``{}`` the simulation simply runs for
        ``delta_time`` steps and the rewards are left untouched. Otherwise
        the actions are applied and the per-step rewards and teleport counts
        are accumulated until a signal reports ``time_to_act``.

        :return: ``(observations, rewards, done, info)``; in single-agent
            mode the first three are those of the first traffic signal
        """
        if action is None or action == {}:
            # No action: let SUMO follow its own signal programs.
            for _ in range(self.delta_time):
                self._sumo_step()
                self._record_metrics_if_due()
        else:
            self._apply_actions(action)

            # Reset the accumulators for the new decision interval.
            for ts in self.ts_ids:
                self.rewards[ts] = 0
            self.teleport_numbers = 0

            time_to_act = False
            while not time_to_act:
                self._sumo_step()

                for ts, reward in self._compute_rewards().items():
                    previous = self.rewards.get(ts)
                    self.rewards[ts] = (0 if previous is None else previous) + reward

                teleports = self._compute_teleports()
                self.teleport_numbers += 0 if teleports is None else teleports

                for ts in self.ts_ids:
                    signal = self.traffic_signals[ts]
                    signal.update()
                    if signal.time_to_act:
                        time_to_act = True

                self._record_metrics_if_due()

        observations = self._compute_observations()
        done = {'__all__': self.sim_step > self.sim_max_time}
        done.update({ts_id: False for ts_id in self.ts_ids})

        if self.single_agent:
            return observations[self.ts_ids[0]], self.rewards[self.ts_ids[0]], done['__all__'], {}
        # Hand out a copy: the internal dict is zeroed in place by the next
        # step(), which would silently change rewards a caller has stored.
        return observations, dict(self.rewards), done, {}

    def _record_metrics_if_due(self):
        """Append a metrics row whenever the simulation time is a multiple of 15 seconds."""
        if self.sim_step % 15 == 0:
            self.metrics.append(self._compute_step_info())

    def _apply_actions(self, actions):
        """
        Set the next green phase for the traffic signals
        :param actions: If single-agent, actions is the next phase for the first traffic signal
                        If multiagent, actions is a dict {ts_id : greenPhase}
        """
        if self.single_agent:
            self.traffic_signals[self.ts_ids[0]].set_next_phase(actions)
        else:
            for ts, action in actions.items():
                self.traffic_signals[ts].set_next_phase(action)

    def _compute_observations(self):
        """Refresh and return (as copies) the observations of the signals that are due to act."""
        self.observations.update({ts: self.traffic_signals[ts].compute_observation() for ts in self.ts_ids if self.traffic_signals[ts].time_to_act})
        return {ts: self.observations[ts].copy() for ts in self.observations.keys() if self.traffic_signals[ts].time_to_act}

    def _compute_rewards(self):
        """Return the current single-step reward of every traffic signal."""
        return {ts: self.traffic_signals[ts].compute_reward() for ts in self.ts_ids}

    def _compute_teleports(self):
        """Return the teleport count reported through the first traffic signal."""
        return self.traffic_signals[self.ts_ids[0]].compute_teleport()

    @property
    def observation_space(self):
        """Observation space of the first traffic signal."""
        return self.traffic_signals[self.ts_ids[0]].observation_space

    @property
    def action_space(self):
        """Action space of the first traffic signal."""
        return self.traffic_signals[self.ts_ids[0]].action_space

    def observation_spaces(self, ts_id):
        """Observation space of the traffic signal ``ts_id``."""
        return self.traffic_signals[ts_id].observation_space

    def action_spaces(self, ts_id):
        """Action space of the traffic signal ``ts_id``."""
        return self.traffic_signals[ts_id].action_space

    def _sumo_step(self):
        """Advance SUMO by one simulation step."""
        traci.simulationStep()

    def _compute_step_info(self):
        """Build one metrics row from the current state of both traffic signals."""
        onnut = self.traffic_signals[self.ts_ids[0]]
        virtual = self.traffic_signals[self.ts_ids[1]]
        return {
            'step_time': self.sim_step,
            'onnut_action': onnut.current_phase,
            'virtual_action': virtual.current_phase,
            'reward_onnut' : self.rewards[self.ts_ids[0]],
            'reward_virtual' : self.rewards[self.ts_ids[1]],
            'total_travel_time_onnut' : onnut.get_travel_time(),
            'total_travel_time_virtual' : virtual.get_travel_time(),
            'teleport_number' : self.teleport_numbers
        }

    def close(self):
        """Close the TraCI connection."""
        traci.close()

    def _write_metrics(self, out_csv_name, suffix):
        """Dump ``self.metrics`` to ``out_csv_name + suffix + '.csv'``, creating the parent directory."""
        df = pd.DataFrame(self.metrics)
        Path(out_csv_name).parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(out_csv_name + suffix + '.csv', index=False)

    def save_csv(self, out_csv_name, run):
        """Write the metrics to ``<out_csv_name>_run<run>.csv``; does nothing when the name is None."""
        if out_csv_name is not None:
            self._write_metrics(out_csv_name, '_run{}'.format(run))

    def save_csv_max_green(self, out_csv_name, run, max_green_onnut, max_green_virtual):
        """Write the metrics to ``<out_csv_name>_onnut<..>_virtual<..>_run<run>.csv``; does nothing when the name is None."""
        if out_csv_name is not None:
            suffix = '_onnut{}'.format(max_green_onnut) + '_virtual{}'.format(max_green_virtual) + '_run{}'.format(run)
            self._write_metrics(out_csv_name, suffix)

    def getTime(self,time):
        """Format a number of seconds as ``H:M:S`` (wrapped at 24 hours, no zero padding)."""
        hours, remainder = divmod(time % (24 * 3600), 3600)
        minutes, seconds = divmod(remainder, 60)
        return ':'.join('{}'.format(int(part)) for part in (hours, minutes, seconds))



class TrafficSignal:
    """
    This class represents a Traffic Signal of an intersection
    It is responsible for retrieving information and changing the traffic phase using Traci API
    """

    def __init__(self, env, ts_id, delta_time, yellow_time, min_green, max_green, begin_time, junction):
        """
        Read the signal program from TraCI and set up spaces and detector tables.

        :param env: owning environment; its ``sim_step`` is read by this class
        :param ts_id: SUMO traffic-light id
        :param delta_time: simulation seconds between two actions
        :param yellow_time: stored as ``self.yellow_time``
        :param min_green: minimum seconds in a phase before it may change
        :param max_green: maximum seconds in a phase
        :param begin_time: simulation time of the first action
        :param junction: "ONNUT" or "VIRTUAL"; selects observation size and detector ids
        """
        # ---- identity and timing parameters ----
        self.id = ts_id
        self.env = env
        self.delta_time = delta_time
        self.yellow_time = yellow_time
        self.min_green = min_green
        self.max_green = max_green

        # ---- controller state ----
        self.current_phase = 0
        self.is_yellow = False
        self.time_since_last_phase_change = 0
        self.next_action_time = begin_time
        self.last_measure = 0.0
        self.last_reward = None

        # Phases of the first program definition SUMO reports for this light.
        program = traci.trafficlight.getCompleteRedYellowGreenDefinition(self.id)[0]
        self.phases = program.phases
        self.num_phases = len(self.phases)

        # Junction tag, used wherever the two junctions differ.
        self.junction = junction

        # Observation = one-hot phase + 1 min-green flag + detector features
        # (18 for ONNUT, 15 for VIRTUAL). Any other tag leaves the attribute
        # unset, so the print below fails with AttributeError.
        detector_features = {"ONNUT": 18, "VIRTUAL": 15}
        if self.junction in detector_features:
            obs_dim = self.num_phases + 1 + detector_features[self.junction]
            self.observation_space = spaces.Box(low=np.zeros(obs_dim), high=np.ones(obs_dim), dtype=np.float32)

        print('Observation space of ', junction,'is :', self.observation_space)
        self.action_space = spaces.Discrete(self.num_phases)
        print('Action space of ', junction,'is :', self.action_space)

        # Detector tables: one dict per bound, keyed by 'UPSTREAM' / 'DOWNSTREAM',
        # each holding a list of detector ids.
        if self.junction == 'ONNUT' :
            # Bounds at this junction: SB, WB, NB.
            self.SB_detectorID_dict = {
                'UPSTREAM': [
                    "S_ONT_04_0", "S_ONT_04_1", "S_ONT_04_2",
                    "S_ONT_09_0", "S_ONT_09_1", "S_ONT_09_2",
                ],
                'DOWNSTREAM': [
                    "S_ONT_03_0", "S_ONT_03_1", "S_ONT_03_2",
                    "S_ONT_07_0", "S_ONT_07_1", "S_ONT_07_2",
                ],
            }

            self.WB_detectorID_dict = {
                'UPSTREAM': ["S_ONT_05_0", "S_ONT_05_1"],
                'DOWNSTREAM': ["S_ONT_18_0", "S_ONT_18_1"],
            }

            self.NB_detectorID_dict = {
                'UPSTREAM': [
                    "S_ONT_06_0", "S_ONT_06_1", "S_ONT_06_2", "S_ONT_06_3",
                    "S_ONT_11_0", "S_ONT_11_1", "S_ONT_11_2", "S_ONT_11_3",
                    "S_ONT_12_0", "S_ONT_12_1", "S_ONT_12_2", "S_ONT_12_3",
                ],
                'DOWNSTREAM': [
                    "S_ONT_08_0",
                    "S_ONT_10_0", "S_ONT_10_1", "S_ONT_10_2",
                ],
            }

            self.loopID = [
                'Induction_Loop_1', 'Induction_Loop_2', 'Induction_Loop_3',
                'Induction_Loop_4', 'Induction_Loop_5', 'Induction_Loop_6',
                'Induction_Loop_7', 'Induction_Loop_8', 'Induction_Loop_9',
                'Induction_Loop_10',
            ]

            self.edgeID_for_MOE = [
                # NB
                '824116560#0', '824116560#1', '824116560#2', '824116560#3',
                '824816455', '220429932#0', '824116561-AddedOffRampEdge', '113135465#5',
                '750035412#1-AddedOffRampEdge', '824816456', '113135465#0', '113135465#2',
                # SB
                '751454884#3', '751454884#2', '751454884#0',
                '751454885#0', '751454885#2', '751454885#3', '751454885#5',
                # WB
                '824456410#0', '824456410#2', '824456410#2.41',
                '156591171#2', '-824456410#2', '156591171#0',
            ]

        elif self.junction == "VIRTUAL" :
            # Bounds at this junction: SB, WB, BigC.
            self.SB_detectorID_dict = {
                'UPSTREAM': [
                    "S_ONT_13_0", "S_ONT_13_1",
                    "S_ONT_14_0", "S_ONT_14_1",
                ],
                'DOWNSTREAM': [
                    "S_ONT_15_0", "S_ONT_15_1",
                    "S_ONT_16_0", "S_ONT_16_1",
                ],
            }

            self.WB_detectorID_dict = {
                'UPSTREAM': [
                    "S_ONT_02_0", "S_ONT_02_1",
                    "S_ONT_13_0", "S_ONT_13_1",
                    "S_ONT_14_0", "S_ONT_14_1",
                    "S_ONT_17_0",
                ],
                'DOWNSTREAM': [
                    "S_ONT_01_0", "S_ONT_01_1",
                    "S_ONT_15_0", "S_ONT_15_1",
                    "S_ONT_16_0", "S_ONT_16_1",
                ],
            }

            # BigC has no downstream detectors.
            self.BIGC_detectorID_dict = {
                'UPSTREAM': ["S_ONT_17_0"],
                'DOWNSTREAM': [],
            }

            self.loopID = [
                "Virtual_loop_1", "Virtual_loop_2", "Virtual_loop_3",
                "Virtual_loop_4", "Virtual_loop_5",
            ]

            self.edgeID_for_MOE = [
                # SB
                '824456410#2', '824456410#2.41', '824456409#0', '-gneE25',
                '-824456410#2', '156591171#0', '-824456409#0', '-gneE24',
                # WB
                'gneE34', '824456409#5', '824456409#6', '113135397#1', '113135397#3',
                'gneE33', '-824456409#5', '-113135397#0', '-113135397#2', '-113135397#4',
                # BigC
                'gneE32', 'gneE29',
            ]

    @property
    def phase(self):
        """Phase index TraCI currently reports for this traffic light."""
        reported_phase = traci.trafficlight.getPhase(self.id)
        return reported_phase

    @property
    def time_to_act(self):
        # print(self.next_action_time , self.env.sim_step,self.next_action_time == self.env.sim_step)
        return self.next_action_time == self.env.sim_step

    def update(self):
        self.time_since_last_phase_change += 1

    def set_next_phase(self, new_phase):
        self.new_phase = new_phase
        # print('*************************************************')
        # print('tls id : ', self.id)
        # print('current phase : ', self.phase)
        # print('new phase :', self.new_phase)
        # print('current sim time : ',self.env.sim_step)
        # print('self.time_since_last_phase_change',self.time_since_last_phase_change)

        if self.phase == self.new_phase and self.time_since_last_phase_change < self.min_green:
            # print('current phase and agent\'s action phase are equal but duration of current phase is less than min green time')
            self.next_action_time = self.env.sim_step +  self.delta_time

        elif self.phase == self.new_phase and self.time_since_last_phase_change >= self.min_green:
            # print('current phase and agent\'s action phase are equal, and still less than max green time')
            self.next_action_time = self.env.sim_step + self.delta_time

        elif self.phase == self.new_phase and self.time_since_last_phase_change > self.max_green:
            # print('duration of current phase is over max_green now')
            if self.new_phase +1 >=self.num_phases:
                self.current_phase = 0
            else:
                self.current_phase = self.new_phase + 1

            traci.trafficlight.setPhase(self.id, self.current_phase)
            self.next_action_time = self.env.sim_step + self.delta_time
            self.time_since_last_phase_change = 0

        elif self.phase != self.new_phase and self.time_since_last_phase_change < self.min_green:
            # print('current phase and agent\'s action phase are not equal but duration of current phase is less than min green time')
            self.next_action_time = self.env.sim_step +  self.delta_time

        elif self.phase != self.new_phase and self.time_since_last_phase_change >= self.min_green:
            # print('current phase and agent\'s action phase are not equal, and duration is greater than min green time')
            self.current_phase = self.new_phase
            traci.trafficlight.setPhase(self.id, self.current_phase)
            self.next_action_time = self.env.sim_step + self.delta_time
            self.time_since_last_phase_change = 0
        #######################################################

        # print(self.next_action_time)


    def compute_observation(self):
        """
        Build the observation vector for this junction as a float32 array.

        Layout: one-hot of ``current_phase``, one flag that is 1 once the
        minimum green time has elapsed, then for each bound an
        (occupancy, flow, unjam) triple, upstream first and downstream after.
        ONNUT contributes three bounds on both sides; VIRTUAL contributes
        three upstream and only the first two downstream. Any other junction
        tag contributes no detector features.
        """
        phase_one_hot = [int(self.current_phase == idx) for idx in range(self.num_phases)]
        min_green_flag = [int(not (self.time_since_last_phase_change < self.min_green))]

        def bound_features(indicate):
            occu_sb, occu_wb, occu_other = self.get_occupancy_average_percent(indicate=indicate)
            flow_sb, flow_wb, flow_other = self.get_flow_sum(indicate=indicate)
            # NOTE(review): the "unjam" values are read from get_flow_sum, so
            # they duplicate the flow values; possibly get_unjamlength_meters
            # was intended. Confirm before changing.
            unjam_sb, unjam_wb, unjam_other = self.get_flow_sum(indicate=indicate)
            return [occu_sb, flow_sb, unjam_sb,
                    occu_wb, flow_wb, unjam_wb,
                    occu_other, flow_other, unjam_other]

        upstream_features = []
        downstream_features = []
        if self.junction == "ONNUT" :
            upstream_features = bound_features("UPSTREAM")
            downstream_features = bound_features("DOWNSTREAM")
        elif self.junction == "VIRTUAL" :
            upstream_features = bound_features("UPSTREAM")
            # The third bound's downstream triple is queried but not part of the observation.
            downstream_features = bound_features("DOWNSTREAM")[:6]

        return np.array(phase_one_hot + min_green_flag + upstream_features + downstream_features,
                        dtype=np.float32)
            
    def compute_reward(self):
        self.last_reward = self._throughput_reward()
        return self.last_reward

    def compute_teleport(self):
        """Query TraCI's ending-teleport number, remember it in ``last_teleport`` and return it."""
        ending_teleports = traci.simulation.getEndingTeleportNumber()
        self.last_teleport = ending_teleports
        return ending_teleports


    def _throughput_reward(self):
        ####################### detectors for each intersection ######################################################
        ONNUT_loopcoil = ['Induction_Loop_1','Induction_Loop_2','Induction_Loop_3',
                          'Induction_Loop_4','Induction_Loop_5','Induction_Loop_6',
                          'Induction_Loop_7','Induction_Loop_8','Induction_Loop_9','Induction_Loop_10']

        VIRTUAL_loopcoil = [
            "Virtual_loop_1","Virtual_loop_2","Virtual_loop_3",
            "Virtual_loop_4","Virtual_loop_5"]

        if self.junction == 'VIRTUAL':
            self.throughput = self.get_throughput(VIRTUAL_loopcoil)

        elif self.junction == 'ONNUT':
            self.throughput = self.get_throughput(ONNUT_loopcoil)

        return self.throughput

    def get_throughput(self,loopcoilIDs):
        throughput = 0
        for id in loopcoilIDs:

            laneID = traci.inductionloop.getLaneID(id)
            edgeID = traci.lane.getEdgeID(laneID)
            speed = traci.edge.getLastStepMeanSpeed(edgeID)
            if speed > 0:
                throughput += traci.inductionloop.getLastStepVehicleNumber(id)

        # # throughput = sum([traci.inductionloop.getLastStepVehicleNumber(i) for i in loopcoilIDs if traci.inductionloop.getLastStepMeanSpeed(i) > 0])
        return throughput
    ########################################################################
    #
    # getting attention places for each intersection
    #
    def get_observation_places(self):

        Jvirtual = {
            'SB_UPSTREAM' : [
                            "S_ONT_13_0","S_ONT_13_1",
                            "S_ONT_14_0","S_ONT_14_1",
                            ],
            'SB_DOWNSTREAM' : [
                            "S_ONT_15_0","S_ONT_15_1",
                            "S_ONT_16_0","S_ONT_16_1"
                            ],
            'WB_UPSTREAM' : [
                "S_ONT_02_0","S_ONT_02_1",
                "S_ONT_13_0","S_ONT_13_1",
                "S_ONT_14_0","S_ONT_14_1",
                "S_ONT_17_0"
            ],
            'WB_DOWNSTREAM' : [
                "S_ONT_01_0","S_ONT_01_1",
                "S_ONT_15_0","S_ONT_15_1",
                "S_ONT_16_0","S_ONT_16_1"
            ],

            'BIGC_UPSTREAM' : ["S_ONT_17_0"],
            # 'BIGC_DOWNSTREAM' : []
        }

        JOnnut = {
                'SB_UPSTREAM' : [
                    "S_ONT_04_0","S_ONT_04_1","S_ONT_04_2",
                    "S_ONT_09_0","S_ONT_09_1","S_ONT_09_2"
                ],
                'SB_DOWNSTREAM' : ["S_ONT_03_0","S_ONT_03_1","S_ONT_03_2",
                                   "S_ONT_07_0","S_ONT_07_1","S_ONT_07_2"],

                'WB_UPSTREAM' : ["S_ONT_05_0","S_ONT_05_1"],
                'WB_DOWNSTREAM' : ["S_ONT_18_0","S_ONT_18_1"],

                'NB_UPSTREAM' : [
                    "S_ONT_06_0","S_ONT_06_1","S_ONT_06_2","S_ONT_06_3",
                    "S_ONT_11_0","S_ONT_11_1","S_ONT_11_2","S_ONT_11_3",
                    "S_ONT_12_0","S_ONT_12_1","S_ONT_12_2","S_ONT_12_3"
                                ],
                'NB_DOWNSTREAM' : [
                    "S_ONT_08_0",
                    "S_ONT_10_0","S_ONT_10_1","S_ONT_10_2"
                ]
                }

        MAP = [Jvirtual, JOnnut]

        state_places = None
        if self.junction == 'ONNUT':
            state_places = MAP[0]
        elif self.junction == 'VIRTUAL':
            state_places = MAP[1]

        return len(state_places)

    def get_flow_sum(self,indicate):
        #     Speed (metres per sec) = flow (vehicle per sec) / density (veh per metre), Ajarn chaodit
        #         flow= int(densityPerLane) * float(meanSpeed)#flow per lane
        #     print('LastStepVehicleNumber', sum([traci.lanearea.getLastStepVehicleNumber(e) for e in detector_id]))
        #     print('length', sum([traci.lanearea.getLength(i) for i in detector_id]))
        #     density = sum([traci.lanearea.getLastStepVehicleNumber(e) for e in detector_id])/\
        #     sum([traci.lanearea.getLength(i) for i in detector_id])
        #     print('density', density)
        flow_SB = 0
        flow_WB = 0
        flow_OtherBound = 0
        if indicate == "UPSTREAM" :
            flow_SB = sum(([traci.lanearea.getLastStepVehicleNumber(e)*traci.lanearea.getLastStepMeanSpeed(e)/traci.lanearea.getLength(e)
                            for e in self.SB_detectorID_dict['UPSTREAM'] if traci.lanearea.getLastStepMeanSpeed(e) >= 0]))
            flow_WB = sum(([traci.lanearea.getLastStepVehicleNumber(e)*traci.lanearea.getLastStepMeanSpeed(e)/traci.lanearea.getLength(e)
                            for e in self.WB_detectorID_dict['UPSTREAM'] if traci.lanearea.getLastStepMeanSpeed(e) >= 0]))

            if self.junction == "ONNUT" :
                #OtherBound = NB in ONNUT
                flow_OtherBound = sum(([traci.lanearea.getLastStepVehicleNumber(e)*traci.lanearea.getLastStepMeanSpeed(e)/traci.lanearea.getLength(e)
                                        for e in self.NB_detectorID_dict['UPSTREAM'] if traci.lanearea.getLastStepMeanSpeed(e) >= 0]))

            elif self.junction == "VIRTUAL" :
                #OtherBound = BIGC in onnut
                flow_OtherBound = sum(([traci.lanearea.getLastStepVehicleNumber(e)*traci.lanearea.getLastStepMeanSpeed(e)/traci.lanearea.getLength(e)
                                        for e in self.BIGC_detectorID_dict['UPSTREAM'] if traci.lanearea.getLastStepMeanSpeed(e) >= 0]))

        elif indicate == "DOWNSTREAM" :

            flow_SB = sum(([traci.lanearea.getLastStepVehicleNumber(e)*traci.lanearea.getLastStepMeanSpeed(e)/traci.lanearea.getLength(e)
                            for e in self.SB_detectorID_dict["DOWNSTREAM"] if traci.lanearea.getLastStepMeanSpeed(e) >= 0]))
            flow_WB = sum(([traci.lanearea.getLastStepVehicleNumber(e)*traci.lanearea.getLastStepMeanSpeed(e)/traci.lanearea.getLength(e)
                            for e in self.WB_detectorID_dict["DOWNSTREAM"] if traci.lanearea.getLastStepMeanSpeed(e) >= 0]))

            if self.junction == "ONNUT" :
                #OtherBound = NB in ONNUT
                flow_OtherBound = sum(([traci.lanearea.getLastStepVehicleNumber(e)*traci.lanearea.getLastStepMeanSpeed(e)/traci.lanearea.getLength(e)
                                        for e in self.NB_detectorID_dict["DOWNSTREAM"] if traci.lanearea.getLastStepMeanSpeed(e) >= 0]))

            elif self.junction == "VIRTUAL" :
                #OtherBound = BIGC in onnut
                flow_OtherBound = sum(([traci.lanearea.getLastStepVehicleNumber(e)*traci.lanearea.getLastStepMeanSpeed(e)/traci.lanearea.getLength(e)
                                        for e in self.BIGC_detectorID_dict["DOWNSTREAM"] if traci.lanearea.getLastStepMeanSpeed(e) >= 0]))

        return flow_SB, flow_WB, flow_OtherBound #OtherBound refer to different bound in onnut - virtual


    def get_unjamlength_meters(self,indicate):
        """Return the un-jammed detector length for the SB, WB and "other" bounds.

        For each bound the value is the summed ``traci.lanearea.getLength`` of
        its detectors minus their summed ``traci.lanearea.getJamLengthMeters``.

        Args:
            indicate: "UPSTREAM" or "DOWNSTREAM"; selects which detector list of
                each ``*_detectorID_dict`` is read.

        Returns:
            Tuple ``(unjam_SB, unjam_WB, unjam_OtherBound)``.  The "other" bound
            is NB when ``self.id`` is "ONNUT" and BIGC when it is "VIRTUAL",
            otherwise 0.  Any other ``indicate`` value yields ``(0, 0, 0)``.
        """
        def _free_length(detector_ids):
            # Total detector length first, then the jammed part, then the difference.
            total_length = sum(traci.lanearea.getLength(det) for det in detector_ids)
            jammed_length = sum([traci.lanearea.getJamLengthMeters(det) for det in detector_ids])
            return total_length - jammed_length  # /detector_length

        if not (indicate == "UPSTREAM" or indicate == "DOWNSTREAM"):
            return 0, 0, 0

        unjam_SB = _free_length(self.SB_detectorID_dict[indicate])
        unjam_WB = _free_length(self.WB_detectorID_dict[indicate])

        # OtherBound refers to a different approach at each junction.
        unjam_OtherBound = 0
        if self.id == "ONNUT":
            unjam_OtherBound = _free_length(self.NB_detectorID_dict[indicate])
        elif self.id == "VIRTUAL":
            unjam_OtherBound = _free_length(self.BIGC_detectorID_dict[indicate])

        return unjam_SB, unjam_WB, unjam_OtherBound

    def get_occupancy_average_percent(self,indicate):
        """Return the gap-scaled mean detector occupancy for SB, WB and the "other" bound.

        Each bound's value is the mean of ``traci.lanearea.getLastStepOccupancy``
        over that bound's detectors on the requested side, multiplied by
        ``(4.62 + 2.37) / 4.62``, i.e. (Vehicle Length + MinimumGap) / Vehicle Length.

        Every bound is averaged over its own detector list for the requested
        side, and a bound with no detectors on that side contributes 0 instead
        of raising ``ZeroDivisionError``.

        Args:
            indicate: "UPSTREAM" or "DOWNSTREAM"; selects which detector list of
                each ``*_detectorID_dict`` is read.

        Returns:
            Tuple ``(occu_SB, occu_WB, occu_OtherBound)``.  The "other" bound is
            NB when ``self.junction`` is "ONNUT" and BIGC when it is "VIRTUAL",
            otherwise 0.  Any other ``indicate`` value yields ``(0, 0, 0)``.
        """
        # Vehicle Length = 4.62, MinimumGap = 2.37
        gap_scale = (4.62 + 2.37) / 4.62

        def _scaled_mean_occupancy(detector_ids):
            # Empty detector list: nothing to average, report zero occupancy.
            if len(detector_ids) == 0:
                return 0
            total = sum([traci.lanearea.getLastStepOccupancy(det) for det in detector_ids])
            return (total / len(detector_ids)) * gap_scale

        if indicate not in ("UPSTREAM", "DOWNSTREAM"):
            return 0, 0, 0

        occu_SB = _scaled_mean_occupancy(self.SB_detectorID_dict[indicate])
        occu_WB = _scaled_mean_occupancy(self.WB_detectorID_dict[indicate])

        # OtherBound refers to a different approach at each junction.
        occu_OtherBound = 0
        if self.junction == "ONNUT":
            occu_OtherBound = _scaled_mean_occupancy(self.NB_detectorID_dict[indicate])
        elif self.junction == "VIRTUAL":
            occu_OtherBound = _scaled_mean_occupancy(self.BIGC_detectorID_dict[indicate])

        return occu_SB, occu_WB, occu_OtherBound

    def get_travel_time(self) :
        """Return the vehicle-weighted travel time over ``self.edgeID_for_MOE``.

        The value is ``sum(travel_time * vehicles) / (sum(vehicles) * number of
        edges)`` using ``traci.edge.getTraveltime`` and
        ``traci.edge.getLastStepVehicleNumber``.  Returns 0 when no vehicle is
        on any of the edges.

        Vehicle counts are fetched once per edge and reused, since every TraCI
        query is a round trip to the simulator.
        """
        vehicle_counts = [traci.edge.getLastStepVehicleNumber(edgeID) for edgeID in self.edgeID_for_MOE]
        total_vehicles = sum(vehicle_counts)
        if total_vehicles == 0:
            return 0

        weighted_time = sum(traci.edge.getTraveltime(edgeID) * count
                            for edgeID, count in zip(self.edgeID_for_MOE, vehicle_counts))
        return weighted_time / (total_vehicles * len(self.edgeID_for_MOE))
    
    # def get_travel_time_all(self) :
    #     if ((sum(traci.edge.getLastStepVehicleNumber(edgeID) for edgeID in self.edgeID_all)) != 0) :
    #         travel_time_all = sum(traci.edge.getTraveltime(edgeID)*traci.edge.getLastStepVehicleNumber(edgeID) for edgeID in self.edgeID_all) \
    #             /(sum(traci.edge.getLastStepVehicleNumber(edgeID) for edgeID in self.edgeID_all) *len(self.edgeID_all))
    #     else :
    #         travel_time_all = 0
    #     return travel_time_all


class maDQN:
    """Independent DQN learners: one online/target network pair per agent.

    ``n_agents`` is the list of agent identifiers (not a count).  The networks
    are stored in dicts keyed by identifier, while ``optimizers`` and the
    per-agent batches passed to ``learn`` are indexed by position in that list.
    """

    def __init__(self,n_agents,num_actions,input_dims,learning_rate,gamma,env_name):
        """Build the online/target networks and one Adam optimizer per agent.

        Args:
            n_agents: list of agent identifiers.
            num_actions: per-agent output size, aligned with ``n_agents``.
            input_dims: per-agent input size, aligned with ``n_agents``.
            learning_rate: learning rate for every Adam optimizer.
            gamma: discount factor used in the TD target.
            env_name: sub-directory name used in checkpoint paths under ``/pv/``.
        """
        self.agents = {}
        self.target_agents = {}
        self.optimizers = []
        self.n_agents = n_agents
        self.gamma = gamma
        self.env_name = env_name
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        for agent_idx, agent_id in enumerate(self.n_agents):
            online_net = Network(gamma, input_dims[agent_idx], num_actions[agent_idx])
            online_net.to(self.device)
            self.agents[agent_id] = online_net

            target_net = Network(gamma, input_dims[agent_idx], num_actions[agent_idx])
            target_net.to(self.device)
            # The target network starts as an exact copy of the online network.
            target_net.load_state_dict(online_net.state_dict())
            self.target_agents[agent_id] = target_net

            self.optimizers.append(torch.optim.Adam(online_net.parameters(), lr=learning_rate))

    def target_update(self):
        """Copy every online network's weights into its target network."""
        for agent_id in self.n_agents:
            self.target_agents[agent_id].load_state_dict(self.agents[agent_id].state_dict())

    def save_checkpoint(self, score):
        """Sync the target networks to the online networks and save both.

        Files are written as ``/pv/<env_name>/agent<i>-score-<score>.pack`` and
        ``/pv/<env_name>/target-agent<i>-score-<score>.pack`` with the score
        formatted to two decimals.
        """
        print('... saving checkpoint ...')
        score_tag = '-score-{:.2f}.pack'.format(score)

        for agent_idx, agent_id in enumerate(self.n_agents):
            agent_path = '/pv/' + self.env_name + '/agent' + str(agent_idx) + score_tag
            target_path = '/pv/' + self.env_name + '/target-agent' + str(agent_idx) + score_tag

            self.target_agents[agent_id].load_state_dict(self.agents[agent_id].state_dict())
            self.agents[agent_id].save(agent_path)
            self.target_agents[agent_id].save(target_path)

    def load_checkpoint(self,load_score):
        """Load the online and target networks saved by ``save_checkpoint``.

        Args:
            load_score: score part of the file name.  A string is used as-is; a
                number is formatted to two decimals, matching ``save_checkpoint``.
        """
        print('... loading checkpoint ...')
        if isinstance(load_score, str):
            score_tag = load_score
        else:
            score_tag = '{:.2f}'.format(load_score)

        for agent_idx, agent_id in enumerate(self.n_agents):
            path_a = '/pv/'+ self.env_name + '/agent' + str(agent_idx) + '-score-' + score_tag + '.pack'
            # Target networks are saved under their own 'target-agent' prefix.
            path_ta = '/pv/' + self.env_name + '/target-agent' + str(agent_idx) + '-score-' + score_tag + '.pack'
            self.agents[agent_id].load(path_a)
            self.target_agents[agent_id].load(path_ta)

    def choose_actions(self, raw_obs):
        """Return each agent's greedy action, in ``self.agents`` insertion order.

        Args:
            raw_obs: mapping from agent identifier to that agent's observation.
        """
        return [agent.act(raw_obs[agent_id]) for agent_id, agent in self.agents.items()]

    def learn(self, batch):
        """Run one gradient step per agent.

        Args:
            batch: sequence of per-agent transition batches, indexed by the
                agent's position in ``self.n_agents``.
        """
        for agent_idx, agent_id in enumerate(self.n_agents):
            loss = self.compute_loss(batch[agent_idx], self.agents[agent_id], self.target_agents[agent_id])
            # Gradient Descent
            optimizer = self.optimizers[agent_idx]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def compute_loss(self, transitions, online_net, target_net):
        """Return the smooth-L1 TD loss of ``online_net`` on a batch.

        Args:
            transitions: sequence of ``(obs, action, reward, done, new_obs)``
                tuples; rows containing missing values are dropped.
            online_net: network being trained.
            target_net: network used to bootstrap the TD target.

        Returns:
            Scalar loss tensor that back-propagates into ``online_net`` only.
        """
        # Discard incomplete transitions (e.g. tuples of None) before batching.
        transitions = pd.DataFrame(transitions).dropna().to_numpy()

        obses = np.asarray([t[0] for t in transitions],dtype=np.float32)
        actions = np.asarray([t[1] for t in transitions],dtype=np.int64)
        rews = np.asarray([t[2] for t in transitions],dtype=np.float32)
        dones = np.asarray([t[3] for t in transitions],dtype=np.float32)
        new_obses = np.asarray([t[4] for t in transitions],dtype=np.float32)

        obses_t = torch.as_tensor(obses, dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(actions, dtype=torch.int64, device=self.device).unsqueeze(-1)
        rews_t = torch.as_tensor(rews, dtype=torch.float32, device=self.device).unsqueeze(-1)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(-1)
        new_obses_t = torch.as_tensor(new_obses, dtype=torch.float32, device=self.device)

        # Compute Targets
        # The target is a constant for this update: no graph is built through
        # the target network, so no gradients accumulate on its parameters.
        with torch.no_grad():
            target_q_values = target_net(new_obses_t)
            max_target_q_values = target_q_values.max(dim=1, keepdim=True)[0]
            targets = rews_t + self.gamma * (1 - dones_t) * max_target_q_values

        # Compute Loss
        q_values = online_net(obses_t)
        action_q_values = torch.gather(input=q_values, dim=1, index=actions_t)

        return nn.functional.smooth_l1_loss(action_q_values, targets)

# from torch import nn
# import torch
# import msgpack


class Network(nn.Module):
    """Q-network mapping an observation vector to one Q-value per action."""

    def __init__(self, gamma,input_dim, actions):
        """Create the network.

        Args:
            gamma: discount factor; stored on the instance only.
            input_dim: size of the observation vector.
            actions: number of Q-value outputs.
        """
        super().__init__()
        self.gamma = gamma
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): there is no activation between the two Linear layers,
        # so the model is a single affine map of the input. Left unchanged so
        # existing checkpoints (keys net.0.* / net.1.*) keep loading - confirm
        # whether a non-linearity was intended.
        self.net = nn.Sequential(
                                nn.Linear(input_dim, 64),
                                nn.Linear(64, actions))

    def forward(self, x):
        """Return the Q-values for a batch of observations ``x``."""
        return self.net(x)

    def act(self, obs):
        """Return the index of the highest Q-value for one observation, as an int."""
        # Use the device the weights actually live on, which may differ from
        # self.device if the module was never moved.
        param_device = next(self.parameters()).device
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=param_device)
        # Action selection needs no gradients.
        with torch.no_grad():
            q_values = self(obs_t.unsqueeze(0))
        max_q_index = torch.argmax(q_values, dim=1)[0]
        return max_q_index.item()

    def save(self, save_path):
        """Serialize the state dict to ``save_path`` with msgpack, creating parent directories."""
        params = {k: t.detach().cpu().numpy() for k, t in self.state_dict().items()}
        params_data = msgpack.dumps(params)
        parent_dir = os.path.dirname(save_path)
        # A bare file name has no parent directory to create.
        if parent_dir:
            os.makedirs(parent_dir,exist_ok=True)
        with open(save_path,'wb') as f:
            f.write(params_data)

    def load(self,load_path):
        """Load a state dict written by ``save``.

        Raises:
            FileNotFoundError: if ``load_path`` does not exist.
        """
        if not os.path.exists(load_path):
            raise FileNotFoundError(load_path)
        with open(load_path, 'rb') as f:
            params_numpy = msgpack.loads(f.read())
        params = {k: torch.as_tensor(v, device=self.device) for k,v in params_numpy.items()}

        self.load_state_dict(params)
#==================================== Main ==================================#


#======== for confidence interval ========#
# import random
# random_seed = 0
# random.seed(random_seed)
#========================================#

import msgpack
import msgpack_numpy as m
m.patch()

# Name used in checkpoint and score-history paths under /pv/.
ENV_NAME = 'onnut'

# Number of episodes run by the main training loop.
N_EPISODES = 500 #1000
# Upper bound on environment steps within one episode.
MAX_STEPS = 960
# Discount factor passed to maDQN.
GAMMA=0.95
# Passed to the replay Buffer as batch_size.
BATCH_SIZE = 60
# Passed to the replay Buffer as buffer_size.
BUFFER_SIZE=int(1000000)
# Epsilon is linearly interpolated from EPSILON_START to EPSILON_END over the
# first EPSILON_DECAY episodes (the interpolation variable is the episode index).
EPSILON_START= 1
EPSILON_END=0.01
EPSILON_DECAY= 10
# Target networks are synced on episodes that are a non-zero multiple of this.
TARGET_UPDATE_FREQ = 60
# Learning rate passed to maDQN.
LR = 0.01
# Logging runs on episodes that are a non-zero multiple of this.
PRINT_INTERVAL = 1
# Learning runs only during episodes that are a non-zero multiple of this
# (the check is against the episode index, not the step within an episode).
TRAIN_INTERVAL = 10
# TRAIN_INTERVAL = 2
# Number of most recent episode scores averaged for logging and checkpointing.
MOVING_AVERAGE = 960
# A checkpoint is saved only when the average score exceeds this running best.
best_score = 1
# Filled during setup with the action-space sizes of each traffic signal.
action_splits = []

#==== max_green ====#
# Passed to SumoEnvironment as max_green_onnut / max_green_virtual.
max_green_ONNUT = 90
max_green_VIRTUAL = 90

# Build the SUMO environment for the onnut network.
env = SumoEnvironment(
    net_file='onnut.net.xml',
    single_agent=False,
    out_csv_name='/pv/outputs/onnut-dqn',
    use_gui=False,
    num_seconds=68400,
    yellow_time=0,
    min_green=15,
    max_green_onnut=max_green_ONNUT,
    max_green_virtual=max_green_VIRTUAL,
)

# Record the action-space sizes of every traffic signal.
for ts in env.ts_ids:
    ts_action_space = env.action_spaces(ts)
    if isinstance(ts_action_space, gym.spaces.Discrete):
        # Discrete(n): a single choice among n actions.
        action_splits.append([ts_action_space.n])
    elif isinstance(ts_action_space, gym.spaces.MultiDiscrete):
        # MultiDiscrete: one size per sub-action.
        action_splits.append(ts_action_space.nvec)

# Network output size per signal is the sum of its action-space sizes.
n_actions_each = list(map(sum, action_splits))

# Network input size per signal is the length of its observation vector.
input_dims = [env.observation_spaces(ts).shape[0] for ts in env.ts_ids]

maDQN_agents = maDQN(
    n_agents=list(env.traffic_signals.keys()),
    num_actions=n_actions_each,
    input_dims=input_dims,
    learning_rate=LR,
    gamma=GAMMA,
    env_name=ENV_NAME,
)

replay_buffer = Buffer(
    n_agents=list(env.traffic_signals.keys()),
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
)

score_history = []
transition = {}
# Main Training Loop
# NOTE: `step` is the episode index, so the epsilon schedule, the target-network
# sync and the training gate below are all evaluated per episode.
try:
    for step in range(N_EPISODES):

        #======== for confidence interval ========#
        random_seed = step
        random.seed(random_seed)
        #========================================#

        obs =  env.reset(random_seed)

        done = [False]*3
        episode_reward = 0
        episode_step = 0
        start_time = time.time()
        epsilon = np.interp(step, [0, EPSILON_DECAY], [EPSILON_START, EPSILON_END])

        if step >=TARGET_UPDATE_FREQ and step % TARGET_UPDATE_FREQ == 0:
            maDQN_agents.target_update()

        while not any(done):
            rnd_sample = random.random()
            if rnd_sample <= epsilon:
                # Explore: one uniformly random action per traffic signal.
                action_onnut = random.randint(0,1)
                action_virtual = random.randint(0,2)
                actions = [action_onnut,action_virtual]
            else:
                # Exploit: greedy action from each agent's online network.
                actions = maDQN_agents.choose_actions(obs)

            actions = dict(zip(env.traffic_signals.keys(), actions))
            new_obs, rewards, done, _  = env.step(actions)

            # Assemble this step's transition for every signal first, then
            # store it once. Buffer.store appends an entry for every agent on
            # each call, so storing inside the per-signal loop would record
            # duplicate and stale (or empty) transitions.
            transition = {}
            for ts in env.ts_ids:
                transition[(ts, 'obs')] = obs[ts]
                transition[(ts, 'actions')] = actions[ts]
                transition[(ts, 'rewards')] = rewards[ts]
                transition[(ts, 'dones')] = done[ts]
                transition[(ts, 'new_obs')] = new_obs[ts]
            replay_buffer.store(transition)

            if episode_step >= MAX_STEPS:
                done = [True] * 3
            else:
                done = list(done.values())
            episode_step += 1

            obs = new_obs
            episode_reward += sum(rewards.values())

            if step >= TRAIN_INTERVAL and  step % TRAIN_INTERVAL == 0:
                batch = replay_buffer.sample()
                maDQN_agents.learn(batch)

        env.save_score()
        score_history.append(episode_reward)

        # Logging
        if step >= PRINT_INTERVAL and step % PRINT_INTERVAL == 0:
            avg_score = np.mean(score_history[-MOVING_AVERAGE:])
            print('Step:', step)
            print('Average Score: {:.2f}'.format(avg_score))
            np.save('/pv/'+ENV_NAME+'score_history.npy',np.array(score_history))
            if avg_score > best_score:
                best_score = avg_score
                maDQN_agents.save_checkpoint(best_score)

        elapsed_time = env.getTime((time.time() - start_time))
        print('episode', step, ' takes time', elapsed_time)
finally:
    # Always release the environment, even if an episode raises or the run is
    # interrupted.
    env.close()

Error: Canceled future for execute_request message before replies were done