In [1]:
from utils import *

from pathlib import Path
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Scalar-log reader for the sweep's runs. TBScalars is presumably provided by
# `from utils import *`; its `.read(run_dir)` is used below as returning a
# DataFrame with "tag", "step" and "value" columns, cached under
# .cache/split_freq (TODO confirm caching semantics in utils).
scalars = TBScalars(".cache/split_freq")
# One sub-directory per run of the wm_freq / rl_freq sweep.
results_dir = Path("../results/split-freq")

In [2]:
# Validation-return curves on a grid: rows vary wm_freq, columns vary rl_freq,
# and each subplot has one line per seed.
freqs = [4, 2, 1]
freq_labels = [str(freq) for freq in freqs]  # subplot titles are text

fig = make_subplots(
    rows=len(freqs),
    row_titles=freq_labels,
    y_title="wm_freq",
    cols=len(freqs),
    column_titles=freq_labels,
    x_title="rl_freq",
    horizontal_spacing=0.1,
    vertical_spacing=0.05,
)

# Shared seed -> colour map so a seed has the same colour in every subplot
# (px would otherwise restart its palette in each call).
palette = px.colors.qualitative.Plotly
seed_colors = {}
# Legend entries already shown; later subplots reuse the legend group so each
# seed appears in the legend once instead of once per subplot.
shown_seeds = set()

for row, wm_freq in enumerate(freqs, 1):
    for col, rl_freq in enumerate(freqs, 1):
        dfs = []
        # NOTE(review): the trailing `*` would also match e.g. rl_freq=16;
        # harmless for freqs [4, 2, 1] but worth tightening if the sweep grows.
        for test_dir in results_dir.glob(f"*wm_freq={wm_freq}-rl_freq={rl_freq}*"):
            with open(test_dir / "config.yml", "r") as f:
                # safe_load: yaml.load without an explicit Loader is a
                # TypeError on PyYAML >= 6 (and unsafe before that).
                cfg = yaml.safe_load(f)

            df = scalars.read(test_dir)
            if "tag" not in df.columns:
                continue  # run has not logged any scalars yet
            df = df[df["tag"] == "val/mean_ep_ret"].copy()
            df["seed"] = cfg["repro"]["seed"]
            dfs.append(df)

        if not dfs:
            # pd.concat([]) raises ValueError; leave this subplot empty.
            continue

        # Sorting by step keeps each line drawn left-to-right.
        df = pd.concat(dfs, axis=0).sort_values(["seed", "step"])

        for seed in sorted(df["seed"].unique()):
            seed_colors.setdefault(seed, palette[len(seed_colors) % len(palette)])

        xpr = px.line(
            df, x="step", y="value", color="seed", color_discrete_map=seed_colors
        )
        for trace in xpr.data:
            trace.legendgroup = trace.name
            trace.showlegend = trace.name not in shown_seeds
            shown_seeds.add(trace.name)
            fig.add_trace(trace, row=row, col=col)

fig.update_layout(legend_title_text="seed")
fig

In [3]:
def is_incomplete(test_dir: Path, min_steps: float = 350e3, reader=None) -> bool:
    """Return True if the run in ``test_dir`` has not finished training.

    A run counts as incomplete when any of these holds:

    - it has no ``config.yml``;
    - its scalar log has no ``tag`` column or no rows;
    - the largest logged ``step`` is below ``min_steps``.

    Args:
        test_dir: Run directory (one entry of ``results_dir``).
        min_steps: Step count a finished run must reach (default 350k).
        reader: Object with a ``read(test_dir) -> DataFrame`` method. Defaults
            to the notebook-level ``scalars`` reader.

    Returns:
        True if the run is incomplete, False otherwise.
    """
    if not (test_dir / "config.yml").is_file():
        return True

    if reader is None:
        reader = scalars
    df = reader.read(test_dir)
    # An empty frame would give max() == NaN, and `NaN < min_steps` is False,
    # which would wrongly mark the run complete.
    if "tag" not in df.columns or df.empty:
        return True

    # Written as `not (>=)` so that a NaN max still counts as incomplete.
    return not (df["step"].max() >= min_steps)


# Print every run directory that has not reached the target step count.
# Sorted for a stable listing; stray files (e.g. .DS_Store) in results_dir are
# skipped because is_incomplete expects a run directory containing config.yml.
for test_dir in sorted(results_dir.iterdir()):
    if test_dir.is_dir() and is_incomplete(test_dir):
        print(test_dir)

In [5]:
# Final validation return of every run, plotted against each swept frequency.
# The frequencies are parsed from the run directory name, whose "-"-separated
# sections of the form "name=value" (e.g. "wm_freq=4") become attributes.
records = []

for test_dir in sorted(results_dir.iterdir()):
    if not test_dir.is_dir():
        continue  # ignore stray files next to the run directories

    attrs = {}
    for section in test_dir.name.split("-"):
        if "=" in section:
            # maxsplit=1 keeps any "=" that appears inside the value.
            name, value = section.split("=", 1)
        else:
            name, value = None, section
        attrs[name] = value

    df = scalars.read(test_dir)
    if "tag" not in df.columns:
        continue

    df = df[df["tag"] == "val/mean_ep_ret"]
    if df.empty:
        continue  # no validation returns logged yet; there is no final score
    # Value at the largest step; positional lookup is safe even if the
    # filtered index has duplicate labels.
    score = df["value"].iloc[df["step"].argmax()]
    for attr in ("wm_freq", "rl_freq"):
        if attr in attrs:  # skip runs whose name lacks this field
            records.append({"attr": attr, "value": attrs[attr], "score": score})

df = pd.DataFrame.from_records(records, columns=["attr", "value", "score"])
# Parsed values are strings; make them numeric so the x axis is ordered by
# frequency rather than by directory iteration order.
df["value"] = pd.to_numeric(df["value"])
px.scatter(df, x="value", y="score", color="attr")

In [3]:
# Does world-model validation loss track final return? One point per run,
# using each metric's value at the largest logged step.
records = []
for test_dir in sorted(results_dir.iterdir()):
    if not test_dir.is_dir():
        continue  # ignore stray files next to the run directories
    df = scalars.read(test_dir)
    if "tag" not in df.columns:
        continue

    returns = df[df["tag"] == "val/mean_ep_ret"]
    wm_losses = df[df["tag"] == "val/wm_loss"]
    if returns.empty or wm_losses.empty:
        continue  # metric not logged yet; there is no final value to plot

    records.append({
        "run": test_dir.name,
        "score": returns["value"].iloc[returns["step"].argmax()],
        "val/wm_loss": wm_losses["value"].iloc[wm_losses["step"].argmax()],
    })

df_ = pd.DataFrame.from_records(records, columns=["run", "score", "val/wm_loss"])
px.scatter(df_, x="val/wm_loss", y="score", hover_name="run")