In [None]:
import json
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
from utils import ATARI_100k
import numpy as np
from loaders import dreamerv2_loader
from utils import to_rgba

In [None]:
# Load the DreamerV2 results through the project's loader. The following cells
# group by the "task" and "time" columns and aggregate "score", so `df` must
# provide at least those columns (presumably one row per run/seed per
# evaluation point — TODO confirm against `loaders.dreamerv2_loader`).
df = dreamerv2_loader()

In [None]:
# Aggregate the per-run scores into summary statistics for every (task, time)
# pair. Named aggregation gives an explicit, stable column order and avoids the
# unusual `from_records(dict)` construction plus the in-place `reset_index`,
# so re-running this cell is idempotent.
# NOTE: `score_std` is NaN for groups that contain a single run (ddof=1).
g = df.groupby(["task", "time"])
avg_df = (
    g["score"]
    .agg(
        score_mean="mean",
        score_std="std",
        score_min="min",
        score_max="max",
    )
    .reset_index()
)
avg_df

In [None]:
# Which tasks are present in the aggregated results (order of first appearance).
pd.unique(avg_df["task"])

In [None]:
# Keep only Atari-100k tasks, truncated to time <= 6e6.
df_6m = avg_df.loc[
    avg_df["task"].isin(ATARI_100k) & (avg_df["time"] <= 6e6)
]
df_6m

In [None]:
# Grid of per-task curves: mean score with a shaded ±1 std band, one panel per
# task in `df_6m`. Tasks in `selected` are drawn in blue, all others in red.
cols = 4
tasks = df_6m["task"].unique()
rows = (len(tasks) + cols - 1) // cols  # ceil(len(tasks) / cols)

fig = make_subplots(
    rows=rows,
    cols=cols,
    subplot_titles=list(tasks),
)

# Tasks highlighted in blue in the figure.
selected = {
    "Amidar",
    "Assault",
    "Asterix",
    "CrazyClimber",
    "Pong",
    "MsPacman",
    "Jamesbond",
}

for idx, task in enumerate(tasks):
    # make_subplots lays panels out row-major with 1-based row/col indices.
    row, col = divmod(idx, cols)
    row, col = row + 1, col + 1

    task_df = df_6m[df_6m["task"] == task]

    # Work on plain arrays so reversing below is purely positional.
    x = task_df["time"].to_numpy()
    y = task_df["score_mean"].to_numpy()
    # std is NaN for single-run groups; treat it as a zero-width band instead of
    # letting NaNs break the shaded polygon.
    std = task_df["score_std"].fillna(0).to_numpy()
    y_lower, y_upper = y - std, y + std

    color = "rgb(0, 0, 255)" if task in selected else "rgb(255, 0, 0)"

    # Band first so the mean line is drawn on top of it. The polygon runs along
    # the upper bound left-to-right and back along the lower bound; "toself"
    # shades its interior ("tozerox" would shade horizontally to x=0).
    band = go.Scatter(
        x=np.concatenate([x, x[::-1]]),
        y=np.concatenate([y_upper, y_lower[::-1]]),
        fill="toself",
        fillcolor=to_rgba(color),
        line=dict(color="rgba(255, 255, 255, 0)"),
        showlegend=False,
    )
    mean_line = go.Scatter(
        x=x,
        y=y,
        mode="lines",
        line=dict(color=color),
        showlegend=False,
    )
    fig.add_trace(band, row=row, col=col)
    fig.add_trace(mean_line, row=row, col=col)

# Label the y-axis of every panel in the first column.
fig.update_yaxes(title_text="Score", col=1)
fig.update_layout(width=800, height=1200)

fig.write_image("../tex/assets/env_selection.pdf")
fig