In [4]:
import re
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from ruamel.yaml import YAML
from tbparse import SummaryReader
from tqdm.auto import tqdm


# Input run directories and the local cache for their parsed scalars.
results_dir = Path("../results/at-400k")
cache_dir = Path(".cache/at_400k")

# Safe, pure-Python YAML loader; used below to read each run's config.yml.
yaml = YAML(pure=True, typ="safe")

# Disable MathJax in plotly's kaleido static-image export (used by fig.write_image).
pio.kaleido.scope.mathjax = None

In [5]:
# Cache each run's TensorBoard scalars as HDF5 so later cells (and re-runs of
# the notebook) don't have to re-parse the event files. Only directories are
# treated as runs; sorting makes the processing order deterministic.
test_dirs = sorted(p for p in results_dir.iterdir() if p.is_dir())
cache_dir.mkdir(parents=True, exist_ok=True)
for test_dir in tqdm(test_dirs, desc="caching scalars"):
    dst_path = cache_dir / f"{test_dir.name}.h5"
    if dst_path.exists():
        continue

    board = SummaryReader(test_dir / "board")
    # Write to a temporary file and rename only once the write has finished:
    # an interrupted to_hdf would otherwise leave a truncated .h5 that the
    # exists() check above would keep treating as a valid cache.
    # mode="w" truncates any leftover temp file from an earlier failed attempt.
    tmp_path = dst_path.with_name(dst_path.name + ".tmp")
    board.scalars.to_hdf(tmp_path, key="scalars", mode="w")
    tmp_path.replace(dst_path)


def read_scalars(test_dir: Path):
    """Return the cached scalar DataFrame for the run stored in ``test_dir``.

    The frame is read from ``cache_dir / "<test_dir.name>.h5"`` under the HDF5
    key ``"scalars"``. That file must already exist; it is written by the
    caching loop in this cell.
    """
    h5_file = cache_dir.joinpath(test_dir.name + ".h5")
    return pd.read_hdf(h5_file, key="scalars")

  0%|          | 0/53 [00:00<?, ?it/s]

In [6]:
def _parse_freq(run_name: str) -> int:
    """Return the integer N from the ``freq=N`` field of a run directory name.

    Raises:
        ValueError: if the name contains no ``freq=<digits>`` field.
    """
    match = re.search(r"freq=(\d+)", run_name)
    if match is None:
        raise ValueError(f"no 'freq=<int>' field in run directory name {run_name!r}")
    return int(match.group(1))


def _load_run(test_dir: Path) -> pd.DataFrame:
    """Load one run's ``val/mean_ep_ret`` scalars, tagged with env, seed and freq.

    ``env`` and ``seed`` are taken from the run's ``config.yml``. ``freq`` is
    parsed from the run directory name.
    """
    with open(test_dir / "config.yml", "r") as f:
        cfg = yaml.load(f)

    env_type = cfg["env"]["type"]
    assert env_type == "atari", f"{test_dir}: expected an atari run, got {env_type!r}"

    scalars = read_scalars(test_dir)
    run_df = scalars[scalars["tag"] == "val/mean_ep_ret"].copy()
    run_df["env"] = cfg["env"]["atari"]["env_id"]
    run_df["seed"] = cfg["repro"]["seed"]
    run_df["freq"] = _parse_freq(test_dir.name)
    return run_df


# Only directories are runs; sorting keeps the row order of `df` deterministic.
dfs = [_load_run(d) for d in sorted(results_dir.iterdir()) if d.is_dir()]
if not dfs:
    raise ValueError(f"no run directories found in {results_dir}")
df = pd.concat(dfs, axis=0)

In [7]:
env_ids = ["Pong", "CrazyClimber", "Assault"]
freqs = [64, 32, 16, 8, 4, 2]
# A fixed seed order makes px give each seed the same colour in every panel.
# By default px colours groups in order of first appearance in each sub-frame,
# which depends on directory-iteration order and can differ between panels.
seeds = sorted(df["seed"].unique().tolist())
row_height = 240  # px per subplot row (6 rows -> 1440 px total)
output_path = Path("../tex/assets/naive_at_400k.pdf")

fig = make_subplots(
    rows=len(freqs),
    cols=len(env_ids),
    column_titles=env_ids,
    # plotly documents row_titles as a list of str
    row_titles=[f"freq={freq}" for freq in freqs],
    horizontal_spacing=0.1,
    vertical_spacing=0.05,
)

for col, env_id in enumerate(env_ids, 1):
    env_df = df[df["env"] == env_id]
    for row, freq in enumerate(freqs, 1):
        # Sort so each seed's line is drawn in step order even if the scalar
        # rows were not stored chronologically.
        panel_df = env_df[env_df["freq"] == freq].sort_values(["seed", "step"])
        panel = px.line(
            panel_df,
            x="step",
            y="value",
            color="seed",
            category_orders={"seed": seeds},
        )
        for trace in panel.data:
            fig.add_trace(trace, row, col)

# Label only the outer axes so the text isn't repeated in every panel.
fig.update_xaxes(title_text="step", row=len(freqs))
fig.update_yaxes(title_text="val/mean_ep_ret", col=1)
fig.update_layout(
    showlegend=False,
    autosize=False,
    width=1024,
    height=row_height * len(freqs),
)

# write_image fails if the target directory is missing.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(output_path)

fig