In [4]:
import pandas as pd
import plotly.express as px
from io import StringIO

# Load the per-model aggregate results; one row per (Architecture, Size) presumably — TODO confirm.
df = pd.read_csv("model_aggregates.csv")

# Parse the grid size N from the "Size" label (first integer in the string).
# expand=False returns a Series instead of a one-column DataFrame; astype(str)
# keeps .str usable even if pandas inferred "Size" as numeric.
grid = df["Size"].astype(str).str.extract(r"(\d+)", expand=False)
unparsed = df.loc[grid.isna(), "Size"]
if not unparsed.empty:
    # Fail loudly with the offending labels instead of an opaque NaN->int cast error.
    raise ValueError(
        f"Could not parse a grid size from Size values: {unparsed.unique().tolist()}"
    )
df["Grid"] = grid.astype(int)

# Metric columns to convert to numbers and plot (one figure per metric).
metrics = [
    "Average Reward",
    "Average Steps",
    "Failure Rate (Max Steps)",
    "Average step (Success)",
    "Failure Rate (Wrong Door)",
    "Success",
]

# Drop "%" signs and surrounding whitespace, then convert. "12.5%" becomes 12.5
# (a percent number, not a fraction). Anything still unparseable becomes NaN.
for col in metrics:
    cleaned = df[col].astype(str).str.replace("%", "", regex=False).str.strip()
    df[col] = pd.to_numeric(cleaned, errors="coerce")

# Long format: one row per (Architecture, Grid, Metric) value, ready for plotting.
long = df.melt(
    id_vars=["Architecture", "Grid"],
    value_vars=metrics,
    var_name="Metric",
    value_name="Value",
)

In [5]:
px.defaults.template = "plotly"

# One line chart per metric: grid size on x, one coloured line per architecture.
for metric in metrics:
    # 1. Select this metric's rows. px.line joins points in row order, so sort by
    #    grid size (stable, to keep the CSV's architecture order otherwise) to
    #    avoid zig-zagging lines when the source file is not already sorted.
    df_metric = long[long["Metric"] == metric].sort_values("Grid", kind="stable")

    # 2. Build the line plot.
    fig = px.line(
        df_metric,
        x="Grid",
        y="Value",
        color="Architecture",
        markers=True,
        labels={"Grid": "Grid size (N×2N)", "Value": metric},
        title=metric,
    )

    # 3. Fixed 750×400 px canvas so every metric's figure has the same size.
    fig.update_layout(
        width=750,
        height=400,
        legend_title_text="Model",
        margin=dict(t=80, b=40, l=40, r=40),
    )

    # Force an x-axis tick at every integer grid size.
    fig.update_xaxes(dtick=1)

    fig.show()