In [None]:
%matplotlib inline
from typing import Union, TypedDict
import pylab as pl
from IPython import display
from matplotlib.transforms import Bbox
from pathlib import Path
from agents.trafficagent import TrafficAgent, FixedCycleTrafficAgent, QLearningTrafficAgent, DeepQLearningTrafficAgent
from utils.plotter import MetricName

class Plotter(TypedDict):
  """Typed view of a metric entry's subplot, stored under the 'plot' key."""
  plot: pl.Axes

# A metric entry maps each agent name to the list of values recorded for it,
# and additionally holds that metric's subplot under the 'plot' key.
Metric = Union[Plotter, dict[str, list[Union[int, float]]]]

seconds = 100

# Every agent runs on the same single-intersection network/route files with the
# same trailing settings; only the agent class, its name and plot colour differ.
# Sharing one tuple keeps the three agents from silently diverging.
# NOTE(review): the meaning of 5, 3, 5, 30 is defined by the TrafficAgent
# constructors (not visible here) — confirm and consider naming them.
_shared_agent_args = (
  'nets/single-intersection/single-intersection.net.xml',
  'nets/single-intersection/single-intersection.rou.xml',
  seconds,
  5,
  3,
  5,
  30,
)

# Agents keyed by name; constructed in the listed order (fixed cycle, Q-learning,
# deep Q-learning). Each agent's colour is used for its line in every plot.
agents: dict[str, TrafficAgent] = {
  name: agent_cls(name, color, *_shared_agent_args)
  for name, agent_cls, color in (
    ('fixedcycle', FixedCycleTrafficAgent, '#aa0000'),
    ('qlearning', QLearningTrafficAgent, '#00aa00'),
    ('deepqlearning', DeepQLearningTrafficAgent, '#0000aa'),
  )
}
# Names of the metrics to record and plot, in plotting order.
# 'step' is deliberately not among them.
_tracked_metric_names: tuple[MetricName, ...] = (
  'system_total_stopped',
  'system_total_waiting_time',
  'system_mean_waiting_time',
  'system_mean_speed',
  't_stopped',
  't_accumulated_waiting_time',
  't_average_speed',
  'agents_total_stopped',
  'agents_total_accumulated_waiting_time',
)

# One fresh, empty placeholder per metric; the figure setup later replaces each
# placeholder with the per-agent series and the metric's subplot.
metrics: dict[MetricName, Metric] = {
  metric_name: {} for metric_name in _tracked_metric_names
}
num_metrics = len(metrics)

# Grid layout: `plots_per_row` subplots side by side, each spanning two gridspec
# columns. `plots_per_col` is the number of grid rows needed to hold every metric.
plots_per_row = 1
# Ceiling division: the last row may be only partially filled.
plots_per_col = (num_metrics + plots_per_row - 1) // plots_per_row
dpi = 100

figure = pl.figure()
figure.set_dpi(dpi)
figure.set_figheight(plots_per_col * 8)
# Width grows with the simulated duration (at least 32 inches) but is capped so
# the rendered image stays below 2**16 pixels wide — presumably the Agg
# renderer's maximum image dimension.
figure.set_figwidth(min(max(seconds / 10, 32), (2**16 - 1) / dpi))
gridspec = figure.add_gridspec(plots_per_col, plots_per_row * 2)

for index, metric in enumerate(metrics):
  # Replace the placeholder with one empty series per agent; the metric's
  # subplot is stored alongside them under the 'plot' key.
  metrics[metric] = { agent: [] for agent in agents }
  row_index, col_slot = divmod(index, plots_per_row)
  col_index = col_slot * 2
  axes = figure.add_subplot(gridspec[row_index, col_index:(col_index + 2)])
  metrics[metric]['plot'] = axes
  axes.set_title(f'{metric} over time')
  axes.set_xlabel('step')
  axes.set_ylabel(metric)

def update_metric(agent: str, metric: MetricName, new_data: float):
  """Append one new sample of ``metric`` to ``agent``'s recorded series.

  Used as the reporting callback handed to each agent's ``learn``. Samples
  for metrics that are not keys of ``metrics`` are ignored. Nothing is drawn
  here; plotting happens once, after every agent has finished learning.
  """
  if metric not in metrics:
    return
  series = metrics[metric][agent]
  series.append(new_data)

# Train every agent; each one reports its samples through update_metric.
for traffic_agent in agents.values():
  traffic_agent.learn(update_metric)

# Draw every agent's series on each metric's subplot, then save that subplot on
# its own by cropping the whole figure to the subplot's tight bounding box.
output_dir = Path('outputs/plots/')
output_dir.mkdir(parents = True, exist_ok = True)
for metric in metrics:
  axes = metrics[metric]['plot'] # type: ignore
  for agent, traffic_agent in agents.items():
    axes.plot(metrics[metric][agent], color = traffic_agent.color, label = agent) # type: ignore
  # Build the legend once, after all of this subplot's lines exist.
  axes.legend() # type: ignore
  bbox = axes.get_tightbbox(renderer = figure.canvas.get_renderer()) # type: ignore
  if bbox is None:
    continue
  # get_tightbbox returns display (pixel) coordinates while savefig's
  # bbox_inches expects inches; dpi_scale_trans converts using the figure's
  # actual dpi rather than assuming it still equals the local `dpi` value.
  bbox_inches = bbox.transformed(figure.dpi_scale_trans.inverted())
  figure.savefig(output_dir / f'all_{metric}_plot.png', bbox_inches = bbox_inches.expanded(1.01, 1.01))

# The figure is intentionally left open (pl.close(figure) was commented out),
# presumably so the inline backend still renders it in the notebook — confirm.