In [None]:
# Star import first so every explicit import below wins on a name clash;
# it presumably also supplies np, pd, err_line and make_color_iter used later.
from utils import *

import json
from pathlib import Path

from plotly.subplots import make_subplots
from tqdm.auto import tqdm

from loaders import dreamerv2_loader, sanity_check_loader

In [None]:
# Our own PyTorch sanity-check runs: a table of runs (its "env" and "path"
# columns are read when plotting) plus a reader for each run's logged scalars.
sanity_results = sanity_check_loader()
res_df, scalars = sanity_results

# Reference DreamerV2 scores to compare against (filtered by "task" and "time"
# and averaged over "score" when plotting).
ref_df = dreamerv2_loader()

In [None]:
# Learning curves of the PyTorch runs vs. the reference DreamerV2 results,
# one subplot per environment; the figure is also exported for the paper.
EVAL_TAG = "val/mean_ep_ret"  # scalar tag plotted for our runs (presumably mean validation episode return)
REF_MAX_STEP = 6e6  # only reference points with time <= this env step are plotted
FIG_PATH = Path("../tex/assets/torch_v_ref.pdf")


def _torch_curve(env):
    """Mean/std of the ``EVAL_TAG`` scalar across all PyTorch runs of ``env``.

    Runs are aligned by the ordinal position of each evaluation (0, 1, 2, ...)
    rather than by exact step, and the x-coordinate of each point is the median
    step at that position.

    Returns a frame with columns ``score_mean``, ``score_std`` and ``step``.
    """
    runs = []
    for path in res_df.loc[res_df["env"] == env, "path"]:
        run = scalars.read(path)
        run = run[run["tag"] == EVAL_TAG]
        # .assign builds a new frame instead of writing into a filtered slice,
        # which would raise SettingWithCopyWarning.
        runs.append(run.assign(eval_idx=np.arange(len(run))))
    return (
        pd.concat(runs)
        .groupby("eval_idx")
        .agg(
            score_mean=("value", "mean"),
            score_std=("value", "std"),
            step=("step", "median"),
        )
    )


def _reference_curve(env):
    """Mean/std of the reference ``score`` per ``time`` for task ``env``.

    Only points with ``time <= REF_MAX_STEP`` are kept. Returns a frame with
    columns ``time``, ``score_mean`` and ``score_std``.
    """
    ref = ref_df[(ref_df["task"] == env) & (ref_df["time"] <= REF_MAX_STEP)]
    return (
        ref.groupby("time")["score"]
        .agg(score_mean="mean", score_std="std")
        .reset_index()
    )


envs = res_df["env"].unique()

fig = make_subplots(rows=1, cols=len(envs), column_titles=list(envs))

for col, env in enumerate(envs, start=1):
    # Legend entries only from the first subplot, so each series is listed once.
    kw = {} if col == 1 else {"showlegend": False}
    # Fresh iterator per subplot; presumably restarts the palette so a series
    # keeps the same color in every panel.
    colors = make_color_iter()

    torch_curve = _torch_curve(env)
    ref_curve = _reference_curve(env)
    for name, x, curve in (
        ("PyTorch", torch_curve["step"], torch_curve),
        ("Reference", ref_curve["time"], ref_curve),
    ):
        color = next(colors)
        for trace in err_line(
            x=x,
            y=curve["score_mean"],
            std=curve["score_std"],
            color=color,
            name=name,
            **kw,
        ):
            fig.add_trace(trace, row=1, col=col)

fig.update_xaxes(title_text="Env step")
fig.update_layout(width=1000, height=400, yaxis_title="Score")

# write_image fails if the target directory is missing (e.g. fresh checkout).
FIG_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(FIG_PATH))
fig