In [18]:
import json
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
from utils import ATARI_100k
import numpy as np

In [56]:
# Read the two JSON files kept under ref_scores/ into plain Python objects.
with open("ref_scores/atari-dreamerv2.json", "rb") as scores_file:
    scores = json.loads(scores_file.read())

with open("ref_scores/baselines.json", "rb") as baselines_file:
    baselines = json.loads(baselines_file.read())

In [61]:
# One row per baseline entry: the task name without its "atari_" prefix plus
# the three reference scores (None when an entry lacks one of them).
score_keys = ("random", "human_gamer", "human_record")
records = [
    {
        "task": name.removeprefix("atari_"),
        **{key: entry.get(key) for key in score_keys},
    }
    for name, entry in baselines.items()
]

base_df = pd.DataFrame.from_records(records)
base_df

Unnamed: 0,task,random,human_gamer,human_record
0,alien,228.80,7127.7,251916.0
1,amidar,5.80,1719.5,104159.0
2,assault,222.40,742.0,8647.0
3,asterix,210.00,8503.3,1000000.0
4,asteroids,719.10,47388.7,10506650.0
...,...,...,...,...
57,air_raid,579.25,,23050.0
58,carnival,700.80,,2541440.0
59,elevator_action,4387.00,,156550.0
60,journey_escape,-19977.00,,-4317804.0


In [68]:
# Flatten every run into long format: one row per (task, seed, time) point,
# pairing each entry of run["xs"] with the matching entry of run["ys"].
records = [
    {
        "task": run["task"].removeprefix("atari_"),
        "seed": int(run["seed"]),
        "time": step,
        "score": value,
    }
    for run in scores
    for step, value in zip(run["xs"], run["ys"])
]

df = pd.DataFrame.from_records(records)
df

Unnamed: 0,task,seed,time,score
0,alien,10,1000000.0,431.0
1,alien,10,2000000.0,859.0
2,alien,10,3000000.0,1575.0
3,alien,10,4000000.0,1871.0
4,alien,10,5000000.0,1245.0
...,...,...,...,...
116966,zaxxon,2,196000000.0,44730.0
116967,zaxxon,2,197000000.0,38250.0
116968,zaxxon,2,198000000.0,43680.0
116969,zaxxon,2,199000000.0,44140.0


In [74]:
# Collapse the per-seed rows: for every (task, time) pair keep the mean score
# and its standard deviation (pandas' default, ddof=1).
g = df.groupby(["task", "time"])

# Named aggregation builds both columns in one pass and names them explicitly,
# instead of handing DataFrame.from_records a dict of Series and relying on
# implicit key sorting and index alignment.
avg_df = g["score"].agg(score_mean="mean", score_std="std").reset_index()
avg_df

Unnamed: 0,task,time,score_mean,score_std
0,alien,1000000.0,365.181818,52.001573
1,alien,2000000.0,894.090909,132.407292
2,alien,3000000.0,1119.148760,241.020006
3,alien,4000000.0,1349.272727,356.397556
4,alien,5000000.0,1284.363636,305.911514
...,...,...,...,...
10995,zaxxon,196000000.0,48962.727273,10514.139138
10996,zaxxon,197000000.0,44340.000000,6915.578067
10997,zaxxon,198000000.0,45583.636364,8938.096299
10998,zaxxon,199000000.0,47744.545455,7550.926253


In [78]:
# Display the per-(task, time) aggregate frame again.
avg_df

Unnamed: 0,task,time,score_mean,score_std
0,alien,1000000.0,365.181818,52.001573
1,alien,2000000.0,894.090909,132.407292
2,alien,3000000.0,1119.148760,241.020006
3,alien,4000000.0,1349.272727,356.397556
4,alien,5000000.0,1284.363636,305.911514
...,...,...,...,...
10995,zaxxon,196000000.0,48962.727273,10514.139138
10996,zaxxon,197000000.0,44340.000000,6915.578067
10997,zaxxon,198000000.0,45583.636364,8938.096299
10998,zaxxon,199000000.0,47744.545455,7550.926253


In [76]:
# Normalise the ATARI_100k identifiers to the lower-case, underscore-joined
# form used in the "task" column.
tasks_100k = ["_".join(env_id.lower().split()) for env_id in ATARI_100k]

# Keep only rows with time <= 6e6 that belong to one of those tasks.
within_budget = avg_df["time"] <= 6e6
in_benchmark = avg_df["task"].isin(tasks_100k)
df_6m = avg_df.loc[within_budget & in_benchmark]
df_6m

Unnamed: 0,task,time,score_mean,score_std
0,alien,1000000.0,365.181818,52.001573
1,alien,2000000.0,894.090909,132.407292
2,alien,3000000.0,1119.148760,241.020006
3,alien,4000000.0,1349.272727,356.397556
4,alien,5000000.0,1284.363636,305.911514
...,...,...,...,...
9801,up_n_down,2000000.0,52950.991736,24876.365672
9802,up_n_down,3000000.0,70119.272727,42941.162674
9803,up_n_down,4000000.0,74837.090909,23942.651117
9804,up_n_down,5000000.0,108919.636364,36948.315832


In [80]:
# Grid geometry: four panels per row and as many rows as the tasks require.
cols = 4
tasks = df_6m["task"].unique()
rows = -(-len(tasks) // cols)  # ceiling division

# Panel titles: "up_n_down" -> "Up N Down".
pretty_names = [" ".join(map(str.capitalize, task.split("_"))) for task in tasks]

fig = make_subplots(rows=rows, cols=cols, subplot_titles=pretty_names)

# 1-based (row, col) coordinates of every panel, in row-major order.
grid = np.mgrid[:rows, :cols]
pos = np.stack(grid, axis=-1).reshape(-1, 2) + 1

colors = px.colors.qualitative.Plotly


def make_color_iter():
    """Yield the Plotly qualitative palette as ``rgb(r, g, b)`` strings, forever.

    Each ``#RRGGBB`` entry of ``px.colors.qualitative.Plotly`` is converted to
    decimal channel values; once the palette is exhausted it starts over.
    """
    while True:
        for code in px.colors.qualitative.Plotly:
            digits = code.removeprefix("#")
            # Two hex digits per channel: red, green, blue.
            channels = tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))
            yield f"rgb{channels}"


def to_rgba(desc: str, alpha: float = 0.2):
    """Convert an ``rgb(r, g, b)`` colour string into ``rgba(r, g, b, alpha)``.

    Args:
        desc: Colour description such as ``"rgb(0, 0, 255)"``.
        alpha: Opacity appended as the fourth component.

    Returns:
        The colour as an ``rgba(...)`` string.

    Raises:
        ValueError: If the text after the ``rgb`` prefix is not a plain
            literal, or does not unpack into exactly three components.
        SyntaxError: If that text is not valid Python literal syntax.
    """
    # Function-scope import so the cell does not depend on the import cell.
    import ast

    # literal_eval accepts only Python literals, so an expression hidden in
    # `desc` is rejected instead of being executed as eval() would do.
    r, g, b = ast.literal_eval(desc.removeprefix("rgb"))
    return f"rgba{(r, g, b, alpha)}"


# Tasks whose curves are drawn in blue by the plotting loop; every task that
# is not in this set is drawn in red.
selected = {
    "amidar",
    "assault",
    "asterix",
    "crazy_climber",
    "pong",
    "ms_pacman",
    "james_bond",
}


# One panel per task: the mean score curve plus a shaded band spanning one
# standard deviation on either side of it.
for (row, col), task in zip(pos, tasks):
    task_df = df_6m[df_6m["task"] == task]

    x, y = task_df["time"], task_df["score_mean"]
    y_lower = task_df["score_mean"] - task_df["score_std"]
    y_upper = task_df["score_mean"] + task_df["score_std"]

    # Tasks in `selected` are blue, all others red.
    color = "rgb(0, 0, 255)" if task in selected else "rgb(255, 0, 0)"
    traces = [
        # Mean curve.
        go.Scatter(
            x=x,
            y=y,
            mode="lines",
            line=dict(color=color),
            showlegend=False,
        ),
        # Std-dev band as a closed outline: the upper edge left to right,
        # then the lower edge right to left.
        go.Scatter(
            x=[*x, *x[::-1]],
            y=[*y_upper, *y_lower[::-1]],
            # "toself" fills only the interior of the outline. The previous
            # "tozerox" fills towards x = 0 as well, shading an extra strip
            # to the left of the first data point.
            fill="toself",
            fillcolor=to_rgba(color),
            # Fully transparent outline so only the fill is visible.
            line=dict(color="rgba(255, 255, 255, 0)"),
            showlegend=False,
        ),
    ]
    for trace in traces:
        fig.add_trace(trace, row=row, col=col)

# update_layout returns the figure, so as the last expression it also
# renders the plot.
fig.update_layout(width=800, height=1200)