In [None]:
from utils import *
from plotly.subplots import make_subplots
from tqdm.auto import tqdm
import json
from loaders import baseline_loader, reference_loader

In [None]:
# Load the baseline experiment results.
# Later cells read the "env", "ratio", "score" and "path" columns of `res_df`
# and call `scalars.read(path)` to obtain a per-run frame with "tag", "step"
# and "value" columns.
res_df, scalars = baseline_loader()
# Reference table; later cells select rows by "task" and read the "random" column.
ref_df = reference_loader()

In [None]:
# Figure: score as a function of the update ratio, one subplot per environment.
envs = res_df["env"].unique()

fig = make_subplots(rows=1, cols=len(envs), column_titles=list(envs))

for col, env in enumerate(envs, start=1):
    # Per-ratio mean and standard deviation of the score for this environment.
    by_ratio = res_df.loc[res_df["env"] == env].groupby("ratio")
    stats = pd.DataFrame(
        {
            "ratio": by_ratio["ratio"].first(),
            "score_mean": by_ratio["score"].mean(),
            "score_std": by_ratio["score"].std(),
        }
    )

    # A fresh iterator is built for every subplot, so each subplot takes the
    # first color that iterator yields.
    # NOTE(review): presumably this gives every subplot the same color — confirm
    # against make_color_iter.
    palette = make_color_iter()
    color = next(palette)

    traces = err_line(
        x=stats["ratio"],
        y=stats["score_mean"],
        std=stats["score_std"],
        color=color,
        showlegend=False,
    )
    for trace in traces:
        fig.add_trace(trace, row=1, col=col)

    # `.item()` raises unless exactly one reference row matches this environment.
    rand = ref_df.loc[ref_df["task"] == env, "random"].item()

    # The y-axis starts at the "random" reference value; the upper bound is
    # left for plotly to choose. Only the first subplot carries the y title.
    yaxis_opts = dict(range=[rand, None])
    if col == 1:
        yaxis_opts["title"] = "Score"

    fig.update_layout(
        **{
            f"xaxis{col}": dict(title="Update ratio K", type="category"),
            f"yaxis{col}": yaxis_opts,
        }
    )

fig.update_layout(width=1000, height=400)

fig.write_image("../tex/assets/baseline_score.pdf")
fig

In [None]:
# Figure: last logged "val/wm_loss" value of each run plotted against its
# score, one subplot per environment, with an OLS trendline.
envs = res_df["env"].unique()

fig = make_subplots(rows=1, cols=len(envs), column_titles=list(envs))

for col, env in enumerate(envs, start=1):
    score_df = res_df.loc[res_df["env"] == env].copy()
    # Pre-allocate the loss column (all missing) before filling it run by run.
    score_df["wm_loss"] = pd.Series(dtype=np.float32)

    for idx, path in score_df["path"].items():
        run_df = scalars.read(path)
        val_loss = run_df.loc[run_df["tag"] == "val/wm_loss", "value"]
        # Take the last logged entry; raises IndexError if the tag is absent.
        score_df.at[idx, "wm_loss"] = val_loss.iloc[-1]

    # Let plotly express build the scatter + trendline, then move its traces
    # into the shared subplot grid.
    scatter_fig = px.scatter(score_df, x="score", y="wm_loss", trendline="ols")
    for trace in scatter_fig.data:
        fig.add_trace(trace, row=1, col=col)

    # Every subplot gets an x title; only the first one gets the y title.
    axis_layout = {f"xaxis{col}": dict(title="Score")}
    if col == 1:
        axis_layout[f"yaxis{col}"] = dict(title="Model val loss")
    fig.update_layout(**axis_layout)


fig.update_layout(width=1000, height=400)

fig.write_image("../tex/assets/baseline_model_val.pdf")
fig

In [None]:
# Figure: last logged "wm/loss" value of each run plotted against its score,
# one subplot per environment, with an OLS trendline.
envs = res_df["env"].unique()

fig = make_subplots(rows=1, cols=len(envs), column_titles=list(envs))

for col, env in enumerate(envs, start=1):
    score_df = res_df.loc[res_df["env"] == env].copy()
    # Pre-allocate the loss column (all missing) before filling it run by run.
    score_df["wm_loss"] = pd.Series(dtype=np.float32)

    for idx, path in score_df["path"].items():
        run_df = scalars.read(path)
        train_loss = run_df.loc[run_df["tag"] == "wm/loss", "value"]
        # Take the last logged entry; raises IndexError if the tag is absent.
        score_df.at[idx, "wm_loss"] = train_loss.iloc[-1]

    # Let plotly express build the scatter + trendline, then move its traces
    # into the shared subplot grid.
    scatter_fig = px.scatter(score_df, x="score", y="wm_loss", trendline="ols")
    for trace in scatter_fig.data:
        fig.add_trace(trace, row=1, col=col)

    # Every subplot gets an x title; only the first one gets the y title.
    axis_layout = {f"xaxis{col}": dict(title="Score")}
    if col == 1:
        axis_layout[f"yaxis{col}"] = dict(title="Model train loss")
    fig.update_layout(**axis_layout)

fig.update_layout(width=1000, height=400)

fig.write_image("../tex/assets/baseline_model_train.pdf")
fig

In [None]:
from rsrch.utils import sched

# Figure: smoothed "rl/entropy" curves of the five highest-scoring runs per
# environment, drawn together with a piecewise-linear target schedule.
envs = res_df["env"].unique()

fig = make_subplots(rows=1, cols=len(envs), column_titles=list(envs))

# Target curve: the schedule sampled on a fixed grid and scaled by log(9).
# NOTE(review): log(9) is presumably the maximum entropy of a 9-way
# categorical policy — confirm.
s = sched.Linear((20e3, 0.99), (50e3, 0.1), (200e3, 5e-2))
ts = np.linspace(0.0, 400e3, 1024)
vs = np.array(list(map(s, ts))) * np.log(9)

for col, env in enumerate(envs, start=1):
    # Ascending sort by score, so the last five rows are the best runs.
    top_runs = res_df.loc[res_df["env"] == env].sort_values(by="score").tail(5)

    for _, run in top_runs.iterrows():
        run_df = scalars.read(run["path"])
        entropy = run_df.loc[run_df["tag"] == "rl/entropy"]
        # Map the logged step to a time step: scale by the run's ratio and
        # shift by 20e3 (the same value as the schedule's first knot).
        # NOTE(review): the offset is presumably a prefill length — confirm.
        elapsed = 20e3 + entropy["step"] * run["ratio"]
        curve = go.Scatter(
            x=elapsed,
            y=exp_mov_avg(entropy["value"], 0.9),
            mode="lines",
            showlegend=False,
        )
        fig.add_trace(curve, row=1, col=col)

    # The target is added to every subplot but listed in the legend only once.
    target = go.Scatter(
        x=ts,
        y=vs,
        mode="lines",
        line=dict(dash="dot", color="black"),
        name="Entropy target",
        showlegend=(col == 1),
    )
    fig.add_trace(target, row=1, col=col)

    # Log-scaled y-axis everywhere; only the first subplot gets the y title.
    yaxis_opts = dict(type="log")
    if col == 1:
        yaxis_opts["title"] = "Mean entropy"

    fig.update_layout(
        **{
            f"xaxis{col}": dict(title="Time step"),
            f"yaxis{col}": yaxis_opts,
        }
    )

fig.update_layout(width=1000, height=400)

fig.write_image("../tex/assets/baseline_ent.pdf")
fig