The main thread starts one thread per agent (n agent threads in total).

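The main thread opens 2n channels, two per agent: a read channel and a write channel. (In the code below these channels are implemented as a shared `agent_requests` dict guarded by `requests_lock`, plus two condition variables: `main_thread_condition` for agent → main and `wake_up_cond` for main → agents.)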

Each agent thread performs its calculations and sends the chosen action on its write channel. The main thread listens on those channels and forwards the actions to the environment.

Once the environment (a SUMO instance running in Docker) responds, the main thread sends the response to all agents.


## Imports

In [ ]:
import threading
from threading import Condition,Lock
import time
import torch
from sumo_sim.Simulation import Simulation
import torch.optim as optim
import colabs.intro.api_endpoints as api

## The Agent Thread

In [ ]:
from utils.models.DeepQLearning import DeepQLearningAgent

# ---- State shared between the main thread and every agent thread ----
requests_lock = threading.Lock()  # guards agent_requests
agent_requests = {}  # tls_id -> action callable submitted by that agent
global_data_dict = {}  # latest simulation response; agents index it by tls_id
running = True  # loop flag checked by both the agent loops and the main loop
wake_up_cond = threading.Condition()  # main -> agents: a new response is available
main_thread_condition = threading.Condition()  # agents -> main: a request was submitted
def agent_thread(tls_id, selected_tls_info):
    """Run one DQN agent controlling traffic light ``tls_id``.

    Each iteration the agent:
      1. waits until the main thread has published data in ``global_data_dict``;
      2. builds its state from ``global_data_dict[tls_id]``, selects an action
         and submits the resulting callable in ``agent_requests[tls_id]``,
         then notifies ``main_thread_condition``;
      3. waits on ``wake_up_cond`` until the main thread has consumed the
         request (removed ``tls_id`` from ``agent_requests``);
      4. computes the reward from the refreshed ``global_data_dict[tls_id]``,
         stores the transition and runs one optimization / soft-update step.

    The loop ends when ``running`` becomes False or ``calculate_reward``
    reports the episode has ended.

    Args:
        tls_id: Traffic-light id; key into ``global_data_dict`` and ``agent_requests``.
        selected_tls_info: Traffic-light info whose ``"programs"`` entries provide
            the ``program_id`` values passed (sorted) to ``DeepQLearningAgent``.
    """
    global wake_up_cond, running, global_data_dict, requests_lock, agent_requests
    selected_program_ids = sorted(
        program['program_id'] for program in selected_tls_info["programs"]
    )

    agent = DeepQLearningAgent(tls_id, selected_program_ids)
    accumulated_loss = 0
    while running:
        # Wait for data before acting. The original code fell through to the
        # reward step with `state` / `selected_action` unbound when no data had
        # been published yet. wait_for re-checks the predicate, so a notify that
        # fired before we started waiting is not lost.
        with wake_up_cond:
            wake_up_cond.wait_for(lambda: global_data_dict or not running)
        if not running:
            break

        # Select an action from the current observation.
        state = torch.tensor(global_data_dict[tls_id], dtype=torch.float32, device=agent.device).unsqueeze(0)
        selected_action = agent.select_action(state)
        action_func = agent.decide_action(selected_action)

        # Hand the action to the main thread.
        with requests_lock:
            agent_requests[tls_id] = action_func
        with main_thread_condition:
            main_thread_condition.notify()

        # Wait until the main thread has consumed our request. The predicate is
        # expected to hold only once the next response has been published;
        # main must clear agent_requests after updating global_data_dict.
        with wake_up_cond:
            wake_up_cond.wait_for(lambda: tls_id not in agent_requests or not running)
        if not running:
            break

        observation, reward, is_ended = agent.calculate_reward(global_data_dict[tls_id])
        reward = torch.tensor([reward], device=agent.device)

        if is_ended:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=agent.device).unsqueeze(0)
        # Store the transition in memory. The next state is rebuilt from
        # global_data_dict at the top of the loop.
        agent.memory.push(state, selected_action, next_state, reward)

        # TODO(review): optimize_model, policy_net, target_net,
        # episode_durations, t and plot_durations are not defined anywhere in
        # this notebook (they look copied from the PyTorch DQN tutorial). They
        # will raise NameError; presumably they should come from `agent` — confirm.
        loss = optimize_model()
        accumulated_loss += loss

        # Soft update of the target network's weights:
        # θ′ ← τ θ + (1 − τ) θ′
        # TAU was an undefined name; use the configured value from `params`.
        tau = params["TAU"]
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * tau + target_net_state_dict[key] * (1 - tau)
        target_net.load_state_dict(target_net_state_dict)

        if is_ended:
            # NOTE(review): once this thread exits, the main loop keeps waiting
            # for a request from it; a shutdown path is still needed.
            episode_durations.append(t + 1)
            plot_durations()
            break
    

## The Main Thread

In [ ]:
# Settings passed to api.start_simulation(...) together with `architecture`.
params = dict(
    HIDDEN_SIZE=64,
    LEARNING_RATE=0.1,
    EPS_START=0.9,
    EPS_END=0.05,
    EPS_DECAY=1000,
    BATCH_SIZE=128,
    GAMMA=0.8,
    TAU=0.005,
    MEM_SIZE=100000,
    EPISODES=10,
)
# Free-text description of the network, also passed to api.start_simulation(...).
architecture = "SimpleNetwork with 3 layers"

In [ ]:
if __name__ == '__main__':
    # Start the simulation and create one agent thread per traffic light.
    agent_threads = []
    simulation_id = api.start_simulation('scenarios/bologna/acosta/run.sumocfg', is_gui=False, params=params, architecture=architecture)
    tls_list = api.get_traffic_lights(simulation_id)
    initial_data = api.get_initial_data(simulation_id)

    number_of_agents = len(tls_list)

    for tls in tls_list:
        selected_tls_info = initial_data['data']['tls'][tls['id']]
        agent_threads.append(threading.Thread(target=agent_thread, args=(tls['id'], selected_tls_info)))

    # Agents only submit an action once global_data_dict holds data, and this
    # loop waits for every agent's action before stepping. Without an initial
    # response the two sides deadlock, so publish one before starting the agents.
    # NOTE(review): this advances the simulation by 5 steps with no agent
    # actions; confirm whether a dedicated "initial observation" endpoint exists.
    global_data_dict = api.step_simulation(simulation_id, steps=5)

    # The original code built the threads but never started them.
    for thread in agent_threads:
        thread.start()

    while running:
        # Wait until every agent has submitted an action. wait_for evaluates the
        # predicate under the condition's lock, so a notify sent between the
        # check and the wait cannot be missed.
        with main_thread_condition:
            main_thread_condition.wait_for(lambda: len(agent_requests) >= number_of_agents)

        # Send requests to the API. Iterate the submitted callables (the values),
        # not the tls-id keys.
        with requests_lock:
            pending_actions = list(agent_requests.values())
        for action_func in pending_actions:
            action_func(simulation_id)

        # Make step and get response
        response = api.step_simulation(simulation_id, steps=5)

        # Publish the response before clearing the requests: agents treat
        # "my request is gone" as "new data is available".
        global_data_dict = response

        with requests_lock:
            agent_requests.clear()
        # notify_all requires holding the condition's lock; the original
        # notifyAll() call without it raised RuntimeError.
        with wake_up_cond:
            wake_up_cond.notify_all()
        