In [None]:
from utils import *
from plotly.subplots import make_subplots
from tqdm.auto import tqdm
import json
from loaders import *

In [None]:
# Each loader yields a pair: a runs DataFrame (used below through its "env"/"path"
# columns) and a reader whose .read(path) returns that run's logged scalars.
# Calls are evaluated left-to-right, i.e. in the same order as before.
(
    (res_df, scalars),  # warm start
    (res_b_df, scalars_b),  # baseline
    (res_a_df, scalars_a),  # warm start (actor)
    (res_s_df, scalars_s),  # sanity check ("Random init" in the plots)
) = (
    warm_start_loader(),
    baseline_loader(),
    warm_start_actor_loader(),
    sanity_check_loader(),
)

In [None]:
# Warm start vs. random init: mean ± std (across runs) of validation episode return,
# one subplot per environment.
ret_tag = "val/mean_ep_ret"
# (legend label, runs DataFrame, scalar reader) for every curve drawn in each subplot.
series = [
    ("Warm start", res_df, scalars),
    ("Random init", res_s_df, scalars_s),
]

envs = res_df["env"].unique()

fig = make_subplots(
    rows=1,
    cols=len(envs),
    column_titles=list(envs),
)

for col, env in enumerate(envs, 1):
    # Show each series in the legend only once (first column).
    legend_kw = {} if col == 1 else {"showlegend": False}
    # Fresh color iterator per column so a series keeps the same color in every subplot.
    colors = make_color_iter()
    for name, runs, reader in series:
        curves = []
        for path in runs.loc[runs["env"] == env, "path"]:
            logged = reader.read(path)
            logged = logged.loc[logged["tag"] == ret_tag]
            # Align runs by evaluation index. `.assign` returns a new frame instead of
            # writing into a filtered view (avoids SettingWithCopyWarning).
            curves.append(logged.assign(index=np.arange(len(logged))))
        avg_df = (
            pd.concat(curves)
            .groupby("index")
            .agg(
                score_mean=("value", "mean"),
                score_std=("value", "std"),
                step=("step", "median"),
            )
        )
        for trace in err_line(
            x=avg_df["step"],
            y=avg_df["score_mean"],
            std=avg_df["score_std"],
            color=next(colors),
            name=name,
            **legend_kw,
        ):
            fig.add_trace(trace, row=1, col=col)

fig.update_xaxes(title_text="Env step")
fig.update_layout(width=1000, height=400, yaxis_title="Score")
fig.write_image("../tex/assets/plas_check.pdf")
fig

In [None]:
# Actor-only warm start vs. random init: mean ± std (across runs) of validation
# episode return, one subplot per environment.
ret_tag = "val/mean_ep_ret"
# (legend label, runs DataFrame, scalar reader) for every curve drawn in each subplot.
series = [
    ("Warm start (Actor)", res_a_df, scalars_a),
    ("Random init", res_s_df, scalars_s),
]

envs = res_a_df["env"].unique()

fig = make_subplots(
    rows=1,
    cols=len(envs),
    column_titles=list(envs),
)

for col, env in enumerate(envs, 1):
    # Show each series in the legend only once (first column).
    legend_kw = {} if col == 1 else {"showlegend": False}
    # Fresh color iterator per column so a series keeps the same color in every subplot.
    colors = make_color_iter()
    for name, runs, reader in series:
        curves = []
        for path in runs.loc[runs["env"] == env, "path"]:
            logged = reader.read(path)
            logged = logged.loc[logged["tag"] == ret_tag]
            # Align runs by evaluation index. `.assign` returns a new frame instead of
            # writing into a filtered view (avoids SettingWithCopyWarning).
            curves.append(logged.assign(index=np.arange(len(logged))))
        avg_df = (
            pd.concat(curves)
            .groupby("index")
            .agg(
                score_mean=("value", "mean"),
                score_std=("value", "std"),
                step=("step", "median"),
            )
        )
        for trace in err_line(
            x=avg_df["step"],
            y=avg_df["score_mean"],
            std=avg_df["score_std"],
            color=next(colors),
            name=name,
            **legend_kw,
        ):
            fig.add_trace(trace, row=1, col=col)

fig.update_xaxes(title_text="Env step")
fig.update_layout(width=1000, height=400, yaxis_title="Score")
# Distinct file name so enabling this does not overwrite the warm-start figure
# (plas_check.pdf).
# fig.write_image("../tex/assets/plas_check.actor.pdf")
fig

In [None]:
def last_value(scalars_df, tag):
    """Return the last logged ``value`` for ``tag`` in a scalars DataFrame.

    Args:
        scalars_df: frame with at least ``tag`` and ``value`` columns.
        tag: scalar tag to look up.

    Raises:
        KeyError: if ``tag`` was never logged. This replaces the opaque IndexError
            raised by ``.iloc[-1]`` on an empty Series.
    """
    values = scalars_df.loc[scalars_df["tag"] == tag, "value"]
    if values.empty:
        raise KeyError(f"tag {tag!r} not found in scalars")
    return values.iloc[-1]


# One record per warm-start run: its final validation return plus the final plasticity
# metrics of the run that produced the checkpoint it was started from.
records = []
for _, row in res_df.iterrows():
    test_df = scalars.read(row["path"])
    record = {"env": row["env"], "perf": last_value(test_df, "val/mean_ep_ret")}
    cfg = load_config(row["path"])
    ckpt_path = Path(cfg["stages"][0]["load_ckpt"]["path"])
    # parents[1] is the checkpoint's grandparent directory. It is presumably the run
    # directory of the baseline run that produced the checkpoint — confirm against the
    # checkpoint layout.
    base_df = scalars_b.read(ckpt_path.parents[1])
    for k in ("wm", "rl/actor", "rl/critic"):
        record[f"{k}/dead"] = last_value(base_df, f"plas/{k}/dead_units/avg_freq")
        record[f"{k}/weight_norm"] = last_value(
            base_df, f"plas/{k}/weight_norm/rel_norm"
        )
    records.append(record)

plas_df = pd.DataFrame.from_records(records)
# Normalize by the mean score of the random-init (sanity check) runs in the same env.
avg_s = res_s_df.groupby("env")["score"].mean().reset_index()
# The inner merge below would silently drop envs without random-init runs.
missing = set(plas_df["env"]) - set(avg_s["env"])
assert not missing, f"no random-init runs for envs: {sorted(missing)}"
plas_df = plas_df.merge(avg_s, on="env")
plas_df["perf_norm"] = plas_df["perf"] / plas_df["score"]

In [None]:
# Plasticity metrics of the source checkpoint vs. normalized warm-start performance,
# with an OLS trendline per panel (rows: metric, columns: network).
nets = ["wm", "rl/actor", "rl/critic"]
metrics = ["dead", "weight_norm"]
net_titles = {"wm": "Model", "rl/actor": "Actor", "rl/critic": "Critic"}
row_titles = {"dead": "Dead units", "weight_norm": "Weight norm"}
titles = {"dead": "# of dead units", "weight_norm": "Weight norm"}

fig = make_subplots(
    rows=len(metrics),
    row_titles=[row_titles[m] for m in metrics],
    cols=len(nets),
    column_titles=[net_titles[n] for n in nets],
)

for row, metric in enumerate(metrics, 1):
    for col, net_id in enumerate(nets, 1):
        xpr = px.scatter(
            plas_df, y=f"{net_id}/{metric}", x="perf_norm", trendline="ols"
        )
        for trace in xpr.data:
            fig.add_trace(trace, row=row, col=col)
    # y-axis label only on the leftmost panel of each row.
    fig.update_yaxes(title_text=titles[metric], row=row, col=1)

# Single shared x label, centred under the bottom row (middle column).
fig.update_xaxes(
    title_text="Normalized performance",
    row=len(metrics),
    col=(len(nets) + 1) // 2,
)

fig.update_layout(width=800, height=600)
fig.write_image("../tex/assets/plas_check.metrics.pdf")
fig