In [None]:
%matplotlib inline
from agents.trafficagent import TrafficAgent, FixedCycleTrafficAgent, QLearningTrafficAgent, DeepQLearningTrafficAgent, default_metrics
from utils.plotter import Metric, MultiPlotter, PlotData
from warnings import filterwarnings

# Install a process-wide filter that ignores every warning category from here on.
# NOTE(review): this also hides deprecation/runtime warnings raised by the agents.
filterwarnings(action = 'ignore')

# Metrics recorded for every agent; the same PlotData instance is handed to each
# agent and to the shared MultiPlotter. Names are grouped by their prefix.
# NOTE(review): the meaning of the two trailing positional arguments (1 and 100) is
# defined in utils.plotter.PlotData and is not visible here — confirm there.
plot_data = PlotData(
  [
    # system_* metrics
    'system_total_stopped',
    'system_total_waiting_time',
    'system_mean_waiting_time',
    'system_mean_speed',
    # t_* metrics
    't_stopped',
    't_accumulated_waiting_time',
    't_average_speed',
    # agents_* metrics
    'agents_total_stopped',
    'agents_total_accumulated_waiting_time',
  ],
  1,
  100,
)

# Scenario inputs, given as relative paths (presumably SUMO network and route
# files, judging by the .net.xml / .rou.xml suffixes).
net = 'nets/2way-single-intersection/single-intersection.net.xml'
rou = 'nets/2way-single-intersection/single-intersection-vhvh.rou.xml'

# Run length and signal-timing settings passed to every agent constructor.
# NOTE(review): units are presumably simulation seconds — confirm in agents.trafficagent.
seconds, delta_time = 100, 4
yellow_time, min_green, max_green = 3, 5, 30

# Positional arguments every agent constructor receives after its name and colour,
# in this exact order: scenario files followed by the timing configuration.
_shared_args = (net, rou, seconds, delta_time, yellow_time, min_green, max_green)
# Hyperparameters common to both learning agents. Dict order matches the original
# call-site keyword order (alpha, gamma, init_eps, min_eps).
_learning_kwargs = {'alpha': 0.1, 'gamma': 0.99, 'init_eps': 1, 'min_eps': 0.005}

# The controllers compared in this run, keyed by the same name each is constructed
# with. The hex string (second argument) is presumably the agent's plot colour.
agents: dict[str, TrafficAgent] = {
  'fixedcycle': FixedCycleTrafficAgent('fixedcycle', '#aa0000', *_shared_args, plot_data),
  'qlearning': QLearningTrafficAgent(
    'qlearning', '#00aa00', *_shared_args,
    **_learning_kwargs, decay = 0.9, plot_data = plot_data,
  ),
  'deepqlearning': DeepQLearningTrafficAgent(
    'deepqlearning', '#0000aa', *_shared_args,
    **_learning_kwargs, decay_time = 0.2, plot_data = plot_data,
  ),
}

# One entry (name + colour) per agent, in the dict's insertion order. A list
# comprehension replaces the former list(map(lambda ..., list(...))) chain.
multi_plotter = MultiPlotter(
  [{'name': agent.name, 'color': agent.color} for agent in agents.values()],
  plot_data
)

# Run each agent in turn without the GUI, passing multi_plotter.append as the
# callback; the plotter is saved once, after every agent has finished.
for agent in agents.values():
  agent.run(multi_plotter.append, use_gui = False)

multi_plotter.save()

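# Alternative to the loop above: run each agent with a `load_path`
# (presumably restoring a previously saved agent — confirm in agents.trafficagent).
# NOTE(review): the deepqlearning save below was written with seconds=26, while this
# notebook uses seconds = 100 — confirm that mismatch is intended.
# agents['fixedcycle'].run(multi_plotter.append, load_path = '')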
# agents['qlearning'].run(multi_plotter.append, load_path = 'outputs/qlearning/saves/2023-02-20 12-24-21,seconds=100,delta_time=4,yellow_time=3,min_green=5,max_green=30,alpha=0.1,gamma=0.99,init_eps=1,min_eps=0.005,decay=0.9.json')
# agents['deepqlearning'].run(multi_plotter.append, load_path = 'outputs/deepqlearning/saves/2023-02-20 12-24-24,seconds=26,delta_time=4,yellow_time=3,min_green=5,max_green=30,alpha=0.1,gamma=0.99,init_eps=1,min_eps=0.005,decay=0.2.zip')

# multi_plotter.save()