In [None]:
%matplotlib inline
from agents.trafficagent import TrafficAgent, FixedCycleTrafficAgent, QLearningTrafficAgent, DeepQLearningTrafficAgent
from utils.plotter import Metric, MultiPlotter
from warnings import filterwarnings

# Suppress every warning for the whole notebook session. This is broad on
# purpose here; NOTE(review): narrow to specific categories if real warnings
# start getting hidden.
filterwarnings('ignore')

# Scenario files (SUMO-style `.net.xml` network and `.rou.xml` routes) for the
# two-way single-intersection setup.
net = 'nets/2way-single-intersection/single-intersection.net.xml'
rou = 'nets/2way-single-intersection/single-intersection-vhvh.rou.xml'

# Simulation / signal-timing settings shared by every agent below.
# NOTE(review): units presumably simulation seconds — confirm in TrafficAgent.
seconds = 100
delta_time = 4
yellow_time = 3
min_green = 5
max_green = 30

# Positional arguments every agent constructor receives after (name, color),
# in exactly the order the constructors take them.
_env_args = (net, rou, seconds, delta_time, yellow_time, min_green, max_green)

# Hyper-parameters common to both learning agents; each agent adds its own
# decay setting on top of these.
_learning_kwargs = dict(alpha=0.1, gamma=0.99, init_eps=1, min_eps=0.005)

# Agents keyed by name; insertion order is the order they are trained and
# plotted in later.
agents: dict[str, TrafficAgent] = {
  'fixedcycle': FixedCycleTrafficAgent('fixedcycle', '#aa0000', *_env_args),
  'qlearning': QLearningTrafficAgent(
    'qlearning', '#00aa00', *_env_args, **_learning_kwargs, decay=0.9
  ),
  'deepqlearning': DeepQLearningTrafficAgent(
    'deepqlearning', '#0000aa', *_env_args, **_learning_kwargs, decay_time=0.2
  ),
}
# Metric names to record for every agent during training. They are grouped by
# prefix: `system_*`, `t_*` and `agents_*`. Their exact meaning is defined by
# the `Metric` type in utils.plotter and by the agent implementations.
metrics: list[Metric] = [
  'system_total_stopped',
  'system_total_waiting_time',
  'system_mean_waiting_time',
  'system_mean_speed',
  't_stopped',
  't_accumulated_waiting_time',
  't_average_speed',
  'agents_total_stopped',
  'agents_total_accumulated_waiting_time',
]

# One series descriptor (display name + colour) per agent, in the same order
# as `agents`. A comprehension replaces the former map/lambda plus redundant
# list() wrappers.
multi_plotter = MultiPlotter(
  [{'name': agent.name, 'color': agent.color} for agent in agents.values()],
  metrics,
)

# Train each agent in turn without the GUI. `multi_plotter.append` is handed
# to `learn` as a callback. NOTE(review): presumably used to report metric
# samples — confirm in TrafficAgent.learn.
# Iterate the values directly instead of looping over keys and re-indexing.
for agent in agents.values():
  agent.learn(multi_plotter.append, use_gui=False)

# Write out whatever the plotter collected once all agents have finished.
multi_plotter.save()