## Plotting Digits

In [1]:
# --- Jupyter Starter Pack ---

# autoreload: refresh code on every cell run
%reload_ext autoreload
%autoreload 2

# clean warnings
import warnings
warnings.filterwarnings("ignore")

# nicer printing
from pprint import pprint

# numpy / pandas nicer display
import numpy as np
np.set_printoptions(precision=4, suppress=True)

# matplotlib defaults
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6, 4)
plt.rcParams["figure.dpi"] = 120

# tqdm in notebooks
from tqdm.notebook import tqdm

# optional: render inline figures at high ("retina") resolution
%config InlineBackend.figure_format = "retina"

In [4]:
from pathlib import Path
from datetime import datetime, date
import json
import subprocess


def parse_overrides_file(path: Path) -> dict:
    """Read ``key=value`` entries from a text file into a dict of strings.

    Every line is whitespace-stripped. Blank lines and lines beginning with
    ``#`` are ignored, leading ``-`` characters (list markers) are dropped,
    and lines without an ``=`` are skipped. The text before the first ``=``
    becomes the key and everything after it the value, both stripped.
    A key that appears more than once keeps its last value.
    """
    parsed = {}
    for entry in map(str.strip, path.read_text().splitlines()):
        if entry == "" or entry[0] == "#":
            continue
        # Removing leading dashes is a no-op for entries that have none.
        entry = entry.lstrip("-").strip()
        name, sep, value = entry.partition("=")
        if not sep:
            continue
        parsed[name.strip()] = value.strip()
    return parsed


def _to_date(obj) -> date | None:
    if obj is None:
        return None
    if isinstance(obj, date):
        return obj
    return datetime.strptime(str(obj), "%Y-%m-%d").date()


def _extract_run_date(run_dir: Path) -> date | None:
    for part in run_dir.parts:
        try:
            return datetime.strptime(part, "%Y-%m-%d").date()
        except ValueError:
            continue
    return None


def _date_in_range(run_dir: Path, start_date: date | None, end_date: date | None) -> bool:
    """Tell whether the date found in ``run_dir`` lies in ``[start_date, end_date]``.

    Both bounds are inclusive and either may be ``None`` to leave that side
    open. A path with no recognisable date is accepted only when no bound
    is given at all.
    """
    found = _extract_run_date(run_dir)
    if found is None:
        # Undated runs pass only when no date restriction was requested.
        return not start_date and not end_date
    before_window = start_date and found < start_date
    return not (before_window or (end_date and found > end_date))


def _read_config_name(run_dir: Path) -> str | None:
    candidates = [run_dir / "multirun.yaml", run_dir / ".hydra" / "hydra.yaml", run_dir / ".hydra" / "config.yaml"]
    for cand in candidates:
        if not cand.exists():
            continue
        for raw in cand.read_text().splitlines():
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith("config_name:"):
                return line.split(":", 1)[1].strip()
    return None


def find_matching_runs(
    root: Path,
    required_fixed: dict[str, str],
    start_date: str | date | None = None,
    end_date: str | date | None = None,
    target_filter: list[str] | None = None,
    lambda_reg_filter: list[str] | None = None,
    running_state_cost_scaling_filter: list[str] | None = None,
    agents_filter: list[str] | None = None,
    config_name_filter: list[str] | None = None,
) -> list[dict]:
    """Collect runs below ``root`` whose recorded overrides match the filters.

    Every ``overrides.yaml`` found recursively under ``root`` is treated as
    belonging to the run directory two levels up (typically
    ``<run>/.hydra/overrides.yaml``).

    Args:
        root: Directory to scan; ``~`` is expanded and the path resolved.
        required_fixed: Override keys that must be present with exactly
            these values.
        start_date, end_date: Optional inclusive bounds on the date found in
            the run path (``YYYY-MM-DD`` string or ``date``).
        target_filter: Allowed values of ``exps.soc.optimality_target``.
        lambda_reg_filter: Allowed values of ``exps.soc.lambda_reg``.
        running_state_cost_scaling_filter: Allowed values of
            ``exps.soc.running_state_cost_scaling``.
        agents_filter: Allowed values of ``exps.soc.num_control_agents``.
        config_name_filter: Allowed config names as read from the run's
            YAML files.

        A filter that is ``None`` or empty is disabled. Override values are
        parsed as strings, so filter values are stringified before comparing
        (``[3]`` behaves like ``["3"]``); a bare string is treated as a
        single allowed value.

    Returns:
        One dict per matching run with the keys ``run_dir``, ``run_date``,
        ``overrides``, ``target``, ``lambda_reg``,
        ``running_state_cost_scaling``, ``num_control_agents`` and
        ``config_name``, ordered by the path of the overrides file.

    Raises:
        ValueError: if ``start_date`` or ``end_date`` is not ``YYYY-MM-DD``.
    """

    def _as_text(value):
        # Keep None as-is so a missing override can still be matched explicitly.
        return value if value is None else str(value)

    def _allowed_values(values):
        # None means "filter disabled".
        if not values:
            return None
        if isinstance(values, str):
            values = [values]
        return {_as_text(v) for v in values}

    start_date = _to_date(start_date)
    end_date = _to_date(end_date)
    root = root.expanduser().resolve()

    fixed = {key: _as_text(value) for key, value in required_fixed.items()}
    override_filters = (
        ("exps.soc.optimality_target", _allowed_values(target_filter)),
        ("exps.soc.lambda_reg", _allowed_values(lambda_reg_filter)),
        ("exps.soc.running_state_cost_scaling", _allowed_values(running_state_cost_scaling_filter)),
        ("exps.soc.num_control_agents", _allowed_values(agents_filter)),
    )
    allowed_config_names = _allowed_values(config_name_filter)

    matches = []
    # rglob yields entries in filesystem order; sort so results are reproducible.
    for overrides_path in sorted(root.rglob("overrides.yaml")):
        run_dir = overrides_path.parent.parent  # typically .../<run>/.hydra/overrides.yaml
        if not _date_in_range(run_dir, start_date, end_date):
            continue
        cfg = parse_overrides_file(overrides_path)
        if any(cfg.get(key) != value for key, value in fixed.items()):
            continue
        if any(
            allowed is not None and cfg.get(key) not in allowed
            for key, allowed in override_filters
        ):
            continue
        # Reading the config name touches more files, so do it only for runs
        # that already passed the cheaper override checks.
        config_name = _read_config_name(run_dir)
        if allowed_config_names is not None and config_name not in allowed_config_names:
            continue
        matches.append({
            "run_dir": run_dir,
            "run_date": _extract_run_date(run_dir),
            "overrides": cfg,
            "target": cfg.get("exps.soc.optimality_target"),
            "lambda_reg": cfg.get("exps.soc.lambda_reg"),
            "running_state_cost_scaling": cfg.get("exps.soc.running_state_cost_scaling"),
            "num_control_agents": cfg.get("exps.soc.num_control_agents"),
            "config_name": config_name,
        })
    return matches


def _sort_key_numeric_first(val: str) -> tuple:
    try:
        return (0, int(val))
    except (TypeError, ValueError):
        return (1, str(val))

# --- helper: format an optional date as an ISO-8601 string ---
def _datestr(d: date | None) -> str | None:
    return d.isoformat() if isinstance(d, date) else None

# --- search configuration ---
# overrides every matching run must contain with exactly these values
required_fixed = {
    "exps.sde.name": "VP",
}

# change this to whatever root you want to scan (e.g., multirun or outputs)
# NOTE: machine-specific absolute path; adjust it when running elsewhere.
root_dir = Path("/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun")

# optional filters (strings); set to None to disable
start_date = "2026-02-03"
end_date = "2026-02-03"
target_filter = ["3"]  # e.g., ["7", "8"]
lambda_reg_filter = ["10.0"]  # e.g., ["1.0", "10.0"]
running_state_cost_scaling_filter = ["1.0"]  # e.g., ["1.0"]
agents_filter = ["3"]  # e.g., ["2", "3"]
config_name_filter = None # ["exps/bptt_learning_agents_fine_tuning"]  # e.g., ["exps/fictitious_bptt_learning_agents_fine_tuning"]
open_first_run_dir = False  # set True to open first matching run folder in your OS file explorer

# parsed copies of the date bounds, used for the JSON metadata below
start_date_obj = _to_date(start_date)
end_date_obj = _to_date(end_date)

# --- find the runs ---
matches = find_matching_runs(
    root_dir,
    required_fixed,
    start_date=start_date,
    end_date=end_date,
    target_filter=target_filter,
    lambda_reg_filter=lambda_reg_filter,
    running_state_cost_scaling_filter=running_state_cost_scaling_filter,
    agents_filter=agents_filter,
    config_name_filter=config_name_filter,
)
print(
    f"Found {len(matches)} matching runs under {root_dir} between {start_date} and {end_date}"
)

# group by target -> agents -> config_name and emit JSON with all paths
grouped: dict[str, dict[str, dict[str, list[dict]]]] = {}
for m in matches:
    # runs missing one of the grouping keys are collected under "unknown"
    tgt = m.get("target") or "unknown"
    agents = m.get("num_control_agents") or "unknown"
    cfg_name = m.get("config_name") or "unknown"
    grouped.setdefault(tgt, {}).setdefault(agents, {}).setdefault(cfg_name, []).append({
        "run_dir": str(m["run_dir"]),
        "run_date": _datestr(m["run_date"]),
        "lambda_reg": m.get("lambda_reg"),
        "running_state_cost_scaling": m.get("running_state_cost_scaling"),
    })

# order targets and agents numerically when possible
ordered_grouped: dict[str, dict[str, dict[str, list[dict]]]] = {}
for tgt in sorted(grouped.keys(), key=_sort_key_numeric_first):
    agent_map = grouped[tgt]
    ordered_agent_map: dict[str, dict[str, list[dict]]] = {}
    for agents in sorted(agent_map.keys(), key=_sort_key_numeric_first):
        ordered_agent_map[agents] = agent_map[agents]
    ordered_grouped[tgt] = ordered_agent_map

# --- build output with metadata (dates) ---
output = {
    "meta": {
        "root_dir": str(root_dir.expanduser().resolve()),
        "start_date": _datestr(start_date_obj),
        "end_date": _datestr(end_date_obj),
        "required_fixed": required_fixed,
        "filters": {
            "target_filter": target_filter,
            "lambda_reg_filter": lambda_reg_filter,
            "running_state_cost_scaling_filter": running_state_cost_scaling_filter,
            "agents_filter": agents_filter,
            "config_name_filter": config_name_filter,
        },
        "num_matches": len(matches),
        "generated_at": datetime.now().isoformat(timespec="seconds"),
    },
    "runs": ordered_grouped,  # target -> agents -> config_name -> list of runs
}
# written relative to the notebook's current working directory
output_path = Path("matched_runs.json")
with output_path.open("w", encoding="utf-8") as f:
    json.dump(output, f, indent=2)

# report the output location once (this message used to be printed twice)
print(f"Wrote grouped run paths + dates to {output_path.resolve()}")

# one summary line per (target, agents, config) group
for tgt, agent_map in ordered_grouped.items():
    for agents, cfg_map in agent_map.items():
        for cfg_name, runs in cfg_map.items():
            print(f"target={tgt}, agents={agents}, config={cfg_name}: {len(runs)} runs")

if open_first_run_dir and matches:
    first_dir = matches[0]["run_dir"]
    print(f"Opening {first_dir}...")
    try:
        # best effort: a failing xdg-open exit status is ignored (check=False)
        subprocess.run(["xdg-open", str(first_dir)], check=False)
    except FileNotFoundError:
        print("xdg-open not available; open manually:", first_dir)


Found 2 matching runs under /home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun between 2026-02-03 and 2026-02-03
Wrote grouped run paths + dates to /home/rbarbano/home/git/multi-agent-diffusion-working-repo/viz/matched_runs.json
Wrote grouped run paths to /home/rbarbano/home/git/multi-agent-diffusion-working-repo/viz/matched_runs.json
target=3, agents=3, config=exps/fictitious_bptt_learning_agents_fine_tuning: 1 runs
target=3, agents=3, config=exps/bptt_learning_agents_fine_tuning: 1 runs


In [None]:
# AGENT 3 - DIGIT 9 - FICTITIOUS
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-02/23-35-28/0/weights"
# # AGENT 3 - DIGIT 3
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/04-23-26/0/weights"
# # AGENT 3 - DIGIT 0
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/08-13-10/0/weights"


# AGENT 3 - DIGIT 9 - JOINT
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/11-27-12/0/weights"
# # AGENT 3 - DIGIT 3
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/12-29-15/0/weights"
# # AGENT 3 - DIGIT 0
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/13-54-25/0/weights"


# AGENT 2 - DIGIT 9 - FICTITIOUS
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-02/23-36-08/0/weights"
# # AGENT 2 - DIGIT 3
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/02-23-27/0/weights"
# # AGENT 2 - DIGIT 0
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/03-19-14/0/weights"


# AGENT 2 - DIGIT 9 - JOINT
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/05-10-25/0/weights"
# # AGENT 2 - DIGIT 3
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/06-32-51/0/weights"
# # AGENT 2 - DIGIT 0
# path_to_controls = "/home/rbarbano/home/git/multi-agent-diffusion-working-repo/multirun/2026-02-03/07-00-24/0/weights"