In [None]:
# Count observational (no-issue) rows and interventional (issue) episodes/rows in PetShop
import pandas as pd
from pathlib import Path

# 1) Clone the PetShop repository under /content unless a checkout is already there
ROOT = Path("/content")
REPO = ROOT / "petshop-root-cause-analysis"
if not REPO.exists():
    # IPython shell escape; -q silences git's progress output
    !git -C /content clone -q https://github.com/amazon-science/petshop-root-cause-analysis.git

# The scenario folders scanned below live under <repo>/dataset
DATA_ROOT = REPO / "dataset"
assert DATA_ROOT.exists(), "dataset/ not found in the repo."

def len_csv(csv_path: Path) -> int:
    """Return the number of data rows in the CSV at ``csv_path``.

    Any error raised while reading is reported with a ``[warn]`` line and the
    file is counted as having 0 rows, so one bad file does not stop the scan.
    """
    try:
        frame = pd.read_csv(csv_path, engine="python")
        n_rows = frame.shape[0]
    except Exception as e:
        print(f"[warn] failed reading {csv_path}: {e}")
        n_rows = 0
    return n_rows

# One summary record per scenario directory (visited in sorted order)
rows = []
for scenario_dir in sorted([p for p in DATA_ROOT.iterdir() if p.is_dir()]):
    # observational (noissue): row count of <scenario>/noissue/metrics.csv, 0 if the file is absent
    noissue_csv = scenario_dir / "noissue" / "metrics.csv"
    obs_rows = len_csv(noissue_csv) if noissue_csv.exists() else 0

    # interventional: each metrics.csv not under noissue is one issue episode
    # (the check is on whole path components, so only a directory named exactly "noissue" is excluded)
    issue_csvs = [p for p in scenario_dir.rglob("metrics.csv") if "noissue" not in p.parts]
    issue_eps  = len(issue_csvs)
    issue_rows_total = sum(len_csv(p) for p in issue_csvs)
    # average rows per episode; 0 when the scenario has no issue episodes
    issue_rows_avg   = (issue_rows_total / issue_eps) if issue_eps > 0 else 0

    rows.append({
        "scenario": scenario_dir.name,
        "obs_rows": obs_rows,
        "issue_episodes": issue_eps,
        "issue_rows_total": issue_rows_total,
        "issue_rows_avg": round(issue_rows_avg, 2),
    })

summary = pd.DataFrame(rows).sort_values("scenario").reset_index(drop=True)

# Add grand totals (kept in a separate one-row frame; the average column is left blank)
totals = pd.DataFrame([{
    "scenario": "TOTAL",
    "obs_rows": int(summary["obs_rows"].sum()),
    "issue_episodes": int(summary["issue_episodes"].sum()),
    "issue_rows_total": int(summary["issue_rows_total"].sum()),
    "issue_rows_avg": ""
}])

display(summary)
display(totals)

# Save to CSV for convenience (only the per-scenario table is written, not the totals row)
out_path = "/content/petshop_observational_interventional_counts.csv"
summary.to_csv(out_path, index=False)
print(f"\nSaved: {out_path}")


Unnamed: 0,scenario,obs_rows,issue_episodes,issue_rows_total,issue_rows_avg
0,high_traffic,592,26,208,8.0
1,low_traffic,592,26,208,8.0
2,temporal_traffic1,1655,8,64,8.0
3,temporal_traffic2,1655,8,64,8.0


Unnamed: 0,scenario,obs_rows,issue_episodes,issue_rows_total,issue_rows_avg
0,TOTAL,4494,68,544,



Saved: /content/petshop_observational_interventional_counts.csv


Across all four scenarios, the dataset has 4,494 observational (no-issue) rows and 41 features, with 544 interventional (issue) samples.

In [None]:
# Count perturbed features per incident in PetShop (mean/variance shift)
import pandas as pd, numpy as np, re
from pathlib import Path

# --- repo/data ---
# Same checkout logic as the first cell, repeated so this cell can run on its own.
ROOT = Path("/content")
REPO = ROOT / "petshop-root-cause-analysis"
if not REPO.exists():
    # IPython shell escape; -q silences git's progress output
    !git -C /content clone -q https://github.com/amazon-science/petshop-root-cause-analysis.git
DATA_ROOT = REPO / "dataset"
assert DATA_ROOT.exists(), "dataset/ not found"

# --- helpers ---
def read_csv_any(p: Path) -> pd.DataFrame:
    """Load the CSV at ``p`` into a DataFrame using pandas' Python parser engine."""
    frame = pd.read_csv(p, engine="python")
    return frame

def numeric_table(df: pd.DataFrame) -> pd.DataFrame:
    """Return a fully numeric, NaN-free copy of ``df``.

    Every column is coerced with ``pd.to_numeric(errors="coerce")`` (unparseable
    cells become NaN), columns and rows that are entirely NaN are dropped, and
    the remaining gaps are filled forward, then backward, then with 0.0.
    The input frame is not modified.
    """
    out = df.copy()
    for c in out.columns:
        out[c] = pd.to_numeric(out[c], errors="coerce")
    out = out.dropna(axis=1, how="all").dropna(axis=0, how="all")
    # simple fill to avoid NaNs breaking stats.
    # .ffill()/.bfill() replace the deprecated fillna(method=...), which emitted
    # a FutureWarning on every call.
    out = out.ffill().bfill().fillna(0.0)
    return out

def select_wide(df: pd.DataFrame) -> pd.DataFrame:
    """Return a numeric table whose columns are the features.

    The table is kept in the orientation it was read in (rows = samples,
    columns = features) whenever that orientation has at least 10 numeric
    columns, or at least as many numeric columns as the transpose. Only
    otherwise is the transposed table returned.

    The previous rule also required >= 30 rows and then preferred whichever
    orientation had more columns. For files sharing one schema this transposed
    a long file (more rows than columns) but not a short one (< 30 rows), so
    the two results had no column labels in common and every downstream
    feature comparison silently ran on zero features.
    """
    A = numeric_table(df)
    B = numeric_table(df.T)
    # Decide from the column count alone so files of different lengths but the
    # same schema always end up in the same orientation.
    if A.shape[1] >= 10 or A.shape[1] >= B.shape[1]:
        return A
    return B

def smd_mean_var(obs: np.ndarray, iss: np.ndarray, eps=1e-8):
    """Compare two samples by mean shift and by variance change.

    Returns ``(smd, lvr)``:
      * ``smd`` - absolute difference of the means divided by the square root
        of the average of the two variances;
      * ``lvr`` - absolute value of the log of the variance ratio (issue / obs).
    Variances use ``ddof=1`` and have ``eps`` added before use.
    """
    mean_obs = float(obs.mean())
    mean_iss = float(iss.mean())
    var_obs = float(obs.var(ddof=1) + eps)
    var_iss = float(iss.var(ddof=1) + eps)

    pooled_sd = np.sqrt(0.5 * (var_iss + var_obs))
    smd = abs(mean_iss - mean_obs) / pooled_sd
    lvr = abs(np.log(var_iss / var_obs))
    return smd, lvr

SMD_THR = 0.8        # "large" effect size (Cohen)
LOGVR_THR = np.log(1.5)  # ≥ 50% variance change

def count_perturbed(obs_wide: pd.DataFrame, iss_wide: pd.DataFrame):
    """Count features whose mean or variance differs between two tables.

    Only columns present in both frames are compared. For each one,
    ``smd_mean_var`` is evaluated on the observational and issue values; a
    feature is flagged for mean when ``smd >= SMD_THR`` and for variance when
    ``lvr >= LOGVR_THR``.

    Returns ``(mean_hits, var_hits, either_hits, per_feat)`` where ``per_feat``
    is a list of ``(column, smd, lvr, is_mean, is_var)`` tuples. With no shared
    columns the result is ``(0, 0, 0, [])``.
    """
    # Keep the shared columns in obs_wide's column order. The previous
    # list(set & set) ordering depended on string hashing, so the per-feature
    # rows came out in a different order on every run.
    iss_cols = set(iss_wide.columns)
    common = list(dict.fromkeys(c for c in obs_wide.columns if c in iss_cols))
    if len(common) == 0:
        return 0, 0, 0, []
    obs = obs_wide[common].to_numpy(np.float64)
    iss = iss_wide[common].to_numpy(np.float64)

    mean_hits = 0
    var_hits = 0
    either_hits = 0
    per_feat = []
    for j, col in enumerate(common):
        smd, lvr = smd_mean_var(obs[:, j], iss[:, j])
        is_mean = smd >= SMD_THR
        is_var  = lvr >= LOGVR_THR
        mean_hits += int(is_mean)
        var_hits  += int(is_var)
        either_hits += int(is_mean or is_var)
        per_feat.append((col, smd, lvr, int(is_mean), int(is_var)))
    return mean_hits, var_hits, either_hits, per_feat

# --- main scan ---
# rows: one record per issue episode; detail_rows: one record per (episode, feature)
rows = []
detail_rows = []

for scen_dir in sorted(p for p in DATA_ROOT.iterdir() if p.is_dir()):
    # Scenarios without an observational baseline are skipped entirely
    noissue = scen_dir / "noissue" / "metrics.csv"
    if not noissue.exists():
        continue
    try:
        obs_df = select_wide(read_csv_any(noissue))
    except Exception as e:
        print(f"[warn] cannot read {noissue}: {e}")
        continue

    # Every metrics.csv outside a "noissue" directory is treated as one issue episode
    issue_csvs = [p for p in scen_dir.rglob("metrics.csv") if "noissue" not in p.parts]
    for ic in sorted(issue_csvs):
        try:
            iss_df = select_wide(read_csv_any(ic))
            # m / v / e: number of features flagged for mean, variance, or either
            m, v, e, per_feat = count_perturbed(obs_df, iss_df)
            rows.append({
                "scenario": scen_dir.name,
                "issue_path": str(ic),
                "n_features_common": len(set(obs_df.columns) & set(iss_df.columns)),
                "perturbed_mean": m,
                "perturbed_var": v,
                "perturbed_either": e
            })
            for col, smd, lvr, im, iv in per_feat:
                detail_rows.append({
                    "scenario": scen_dir.name,
                    "issue_path": str(ic),
                    "feature": col,
                    "smd": smd,
                    "log_var_ratio_abs": lvr,
                    "flag_mean": im,
                    "flag_var": iv
                })
        except Exception as e:
            # A failing episode is reported and skipped; the scan continues
            print(f"[warn] issue read failed {ic}: {e}")

summary = pd.DataFrame(rows).sort_values(["scenario","issue_path"])
details = pd.DataFrame(detail_rows)

# Totals
totals = summary[["perturbed_mean","perturbed_var","perturbed_either"]].sum().to_dict()
print("Totals across all issue episodes:")
print(totals)

# Show a compact per-scenario aggregation
agg = (summary
       .groupby("scenario")[["perturbed_mean","perturbed_var","perturbed_either"]]
       .sum()
       .reset_index()
       .sort_values("scenario"))
print("\nPer-scenario summed counts:")
print(agg.to_string(index=False))

# Save CSVs
summary.to_csv("/content/petshop_perturbed_feature_counts_by_issue.csv", index=False)
details.to_csv("/content/petshop_perturbed_feature_details.csv", index=False)
print("\nSaved:",
      "/content/petshop_perturbed_feature_counts_by_issue.csv",
      "and /content/petshop_perturbed_feature_details.csv")


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(m

Totals across all issue episodes:
{'perturbed_mean': 0, 'perturbed_var': 0, 'perturbed_either': 0}

Per-scenario summed counts:
         scenario  perturbed_mean  perturbed_var  perturbed_either
     high_traffic               0              0                 0
      low_traffic               0              0                 0
temporal_traffic1               0              0                 0
temporal_traffic2               0              0                 0

Saved: /content/petshop_perturbed_feature_counts_by_issue.csv and /content/petshop_perturbed_feature_details.csv


In [None]:
# ===== PETSHOP: robust (name→canonical→correlation) alignment + Set Transformer (BCE) =====
!pip -q install einops ruamel.yaml openpyxl

import re, json, random
from pathlib import Path
import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from ruamel.yaml import YAML

# -------------------- config --------------------
# Seed Python, NumPy and PyTorch RNGs
SEED = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

SCENARIO_NAME = "temporal_traffic2"   # change if needed (try other scenario dirs under dataset/)
S_TOP = 3                             # weak-label support size if GT missing
EPOCHS = 20  # short notebook run; per the original note, 500 was used on the server
BATCH_TRAIN = 8
BATCH_TEST = 8
LR = 2e-4
SIM_THRESHOLD = 0.6                  # min cosine similarity to accept a row-pair in correlation fallback
MAX_ROWS_MATCH = 2000                 # safety limit on the number of row pairs matched by correlation

# -------------------- repo / data --------------------
ROOT = Path("/content")
REPO = ROOT / "petshop-root-cause-analysis"
if not REPO.exists():
    # IPython shell escape; -q silences git's progress output
    !git -C /content clone -q https://github.com/amazon-science/petshop-root-cause-analysis.git

DATA_ROOT = REPO / "dataset"
assert DATA_ROOT.exists(), f"{DATA_ROOT} missing"

def get_scenario_dir(name_hint: str):
    """Resolve a scenario directory under ``DATA_ROOT``.

    An entry named exactly ``name_hint`` wins. Otherwise the first directory
    (in sorted order) whose name contains ``name_hint`` case-insensitively is
    returned.

    Raises:
        FileNotFoundError: if nothing matches.
    """
    exact = DATA_ROOT / name_hint
    if exact.exists():
        return exact
    needle = name_hint.lower()
    # Sort so an ambiguous hint (e.g. one matching several scenario names)
    # always resolves to the same directory instead of depending on the
    # filesystem's listing order.
    for p in sorted(DATA_ROOT.iterdir()):
        if p.is_dir() and needle in p.name.lower():
            return p
    raise FileNotFoundError(f"Scenario '{name_hint}' not found under {DATA_ROOT}")

SCENARIO = get_scenario_dir(SCENARIO_NAME)
print("Using scenario:", SCENARIO)

# -------------------- CSV → candidate "wide" matrices (rows=services, cols=time) --------------------
def read_csv_any(csv_path: Path) -> pd.DataFrame:
    """Load ``csv_path`` into a DataFrame with pandas' Python parser engine.

    No other parser options are set: pandas' default separator and encoding
    handling apply.
    """
    frame = pd.read_csv(csv_path, engine="python")
    return frame

def _to_numeric_df(df: pd.DataFrame) -> pd.DataFrame:
    out = df.copy()
    for c in out.columns:
        out[c] = pd.to_numeric(out[c], errors="coerce")
    return out

def _clean_fill(W: pd.DataFrame) -> pd.DataFrame:
    W = W.replace([np.inf, -np.inf], np.nan)
    W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
    return W.astype(np.float32)

def candidates_for_csv(csv_path: Path):
    """Parse ``csv_path`` into candidate wide matrices.

    Returns a list of ``(variant_name, frame)`` pairs; each frame is cleaned
    with ``_clean_fill`` and has at least 5 rows and 5 columns:
      * ``"rows_are_services"`` - the file indexed by its ``microservice``
        column (only when that column exists);
      * ``"cols_are_services"`` - the numeric part of the file, transposed,
        with the index labels converted to strings.
    A file that cannot be read yields an empty list after a ``[read fail]`` message.
    """
    try:
        raw = read_csv_any(csv_path)
    except Exception as e:
        print(f"[read fail] {csv_path}: {e}")
        return []

    found = []

    # Variant A: index by the 'microservice' column
    if "microservice" in raw.columns:
        by_service = _to_numeric_df(raw.set_index("microservice"))
        by_service = by_service.dropna(axis=1, how="all").dropna(axis=0, how="all")
        by_service = _clean_fill(by_service)
        if min(by_service.shape) >= 5:
            found.append(("rows_are_services", by_service))

    # Variant B: transpose the numeric table
    numeric = _to_numeric_df(raw).dropna(axis=1, how="all").dropna(axis=0, how="all")
    if min(numeric.shape) >= 5:
        transposed = numeric.T.copy()
        transposed.index = [str(i) for i in transposed.index]
        transposed = _clean_fill(transposed)
        if min(transposed.shape) >= 5:
            found.append(("cols_are_services", transposed))

    # Drop candidates repeating an earlier (name, shape, top-left 5x5 sum) signature
    unique, signatures = [], set()
    for name, W in found:
        sig = (name, W.shape[0], W.shape[1], float(np.nan_to_num(W.values[:5,:5]).sum()))
        if sig in signatures:
            continue
        signatures.add(sig)
        unique.append((name, W))
    return unique

# -------------------- alignment (exact → canonical → correlation) --------------------
def canon_name(s: str) -> str:
    """Reduce a label to a canonical lowercase form for name matching.

    In order: lowercase; blank out ``arn:aws:...`` tokens and ``aws::...``
    type prefixes; turn every run of non-alphanumerics into a space; blank out
    standalone environment words (prod, stage, stg, dev, test, qa); collapse
    whitespace and trim.
    """
    text = str(s).lower()
    patterns = (
        r'arn:aws:[^ ]+',
        r'aws::[a-z0-9:_\-]+',
        r'[^a-z0-9]+',
        r'\b(prod|stage|stg|dev|test|qa)\b',
        r'\s+',
    )
    for pattern in patterns:
        text = re.sub(pattern, ' ', text)
    return text.strip()

def align_exact(A: pd.DataFrame, B: pd.DataFrame):
    """Align two frames on index labels that appear in both.

    Returns ``(A_rows, B_rows, labels, info)``. With fewer than 5 shared
    labels the frames are empty, ``labels`` is ``[]`` and ``info["count"]`` is 0.
    """
    shared = A.index.intersection(B.index)
    n_shared = len(shared)
    if n_shared < 5:
        return A.iloc[[]], B.iloc[[]], [], {"mode":"exact", "count":0}
    return A.loc[shared].copy(), B.loc[shared].copy(), list(shared), {"mode":"exact", "count":n_shared}

def align_canonical(A: pd.DataFrame, B: pd.DataFrame):
    """Align two frames on the ``canon_name`` form of their index labels.

    For each frame the first label producing a given non-empty canonical key
    is kept. Keys found in both frames are paired in sorted key order.
    Returns ``(A_rows, B_rows, labels_from_A, info)``; fewer than 5 shared
    keys yields empty frames and ``info["count"] == 0``.
    """
    def first_label_per_key(index):
        table = {}
        for label in index:
            key = canon_name(label)
            if key:
                table.setdefault(key, label)
        return table

    mapA = first_label_per_key(A.index)
    mapB = first_label_per_key(B.index)
    keys = sorted(mapA.keys() & mapB.keys())
    if len(keys) < 5:
        return A.iloc[[]], B.iloc[[]], [], {"mode":"canonical", "count":0}

    labelsA = [mapA[k] for k in keys]
    labelsB = [mapB[k] for k in keys]
    return A.loc[labelsA].copy(), B.loc[labelsB].copy(), labelsA, {"mode":"canonical", "count":len(keys)}

def l2_normalize_rows(M: np.ndarray, eps=1e-8):
    """Divide every row of ``M`` by its Euclidean norm.

    Norms below ``eps`` are raised to ``eps`` so all-zero rows stay zero
    instead of dividing by zero.
    """
    row_norms = np.linalg.norm(M, axis=1, keepdims=True)
    return M / np.maximum(row_norms, eps)

def align_by_correlation(A: pd.DataFrame, B: pd.DataFrame, threshold=SIM_THRESHOLD, max_rows=MAX_ROWS_MATCH):
    """
    Greedy one-to-one matching of rows of A to rows of B by cosine similarity
    of their mean-centred values.

    Only the first min(TA, TB) columns of each frame are compared. Candidate
    pairs are visited from most to least similar; a pair is accepted when both
    rows are still free and its similarity is >= threshold. The search ends at
    the first free pair below the threshold or after max_rows pairs.

    Returns (A_rows, B_rows, labels_from_A, info), with matched rows in A's
    row order. Fewer than 5 common columns or fewer than 5 accepted pairs
    gives empty frames.
    """
    def no_match(count):
        return A.iloc[[]], B.iloc[[]], [], {"mode":"correlation", "count":count}

    # Trim both sides to the same number of leading columns
    T = min(A.shape[1], B.shape[1])
    if T < 5:
        return no_match(0)

    def centred_unit_rows(frame):
        # subtract each row's mean, then scale the row to unit length
        M = frame.iloc[:, :T].to_numpy(np.float32)
        return l2_normalize_rows(M - M.mean(axis=1, keepdims=True))

    sim = centred_unit_rows(A) @ centred_unit_rows(B).T  # (RA, RB)
    RA, RB = sim.shape
    taken_a = np.zeros(RA, dtype=bool)
    taken_b = np.zeros(RB, dtype=bool)

    pairs = []
    # walk all (row of A, row of B) cells from highest to lowest similarity
    for flat in np.argsort(sim.ravel())[::-1]:
        if len(pairs) >= max_rows:
            break
        r, c = divmod(flat, RB)
        if taken_a[r] or taken_b[c]:
            continue
        score = sim[r, c]
        if score >= threshold:
            taken_a[r] = True
            taken_b[c] = True
            pairs.append((r, c, float(score)))
        elif score < threshold:
            # everything after this cell is at most as similar, so stop
            break

    if len(pairs) < 5:
        return no_match(len(pairs))

    # emit matches in A's row order
    pairs.sort(key=lambda t: t[0])
    A2 = A.iloc[[p[0] for p in pairs]].copy()
    B2 = B.iloc[[p[1] for p in pairs]].copy()
    return A2, B2, list(A2.index), {"mode":"correlation", "count":len(pairs)}

def best_issue_alignment(W_obs, issue_W, issue_tag="issue"):
    """
    Run the exact, canonical and correlation aligners (in that order) and
    return the ``(A, B, rows, info)`` result with the largest ``info["count"]``.
    On a tie the earlier aligner wins. The chosen mode and count are printed.
    """
    aligners = (align_exact, align_canonical, align_by_correlation)
    trials = [aligner(W_obs, issue_W) for aligner in aligners]
    best = max(trials, key=lambda trial: trial[3]["count"])
    print(f"[align] {issue_tag}: mode={best[3]['mode']}, matched_rows={best[3]['count']}")
    return best

# -------------------- read OBS (pick candidate with most rows) --------------------
obs_csv = SCENARIO / "noissue" / "metrics.csv"
if not obs_csv.exists():
    raise FileNotFoundError(f"Missing noissue/metrics.csv under {SCENARIO}")

obs_cands = candidates_for_csv(obs_csv)
if not obs_cands:
    raise RuntimeError("Unable to parse noissue/metrics.csv into any candidate orientations.")
# Largest row count first; the top candidate becomes the observational baseline W_obs
obs_cands = sorted(obs_cands, key=lambda x: x[1].shape[0], reverse=True)
obs_variant, W_obs = obs_cands[0][0], obs_cands[0][1]
print(f"[obs] picked variant={obs_variant}, shape={W_obs.shape}")

# -------------------- gather issue files --------------------
# NOTE: this is a substring test on the whole path string, so any path containing
# "noissue" anywhere (not just a directory with that exact name) is excluded.
issue_csvs = [p for p in SCENARIO.rglob("metrics.csv") if "noissue" not in str(p)]
print("Found issue metric files:", len(issue_csvs))

# -------------------- side files to extract GT (optional) --------------------
def try_load_meta(issue_dir: Path):
    """Load every JSON / YAML side file found directly inside ``issue_dir``.

    Files are visited per extension (``*.json``, ``*.yml``, ``*.yaml``) in
    sorted order. ``.json`` files are parsed with ``json``; the others with
    ruamel's safe YAML loader.

    Returns:
        list: the parsed objects. A file that cannot be read or parsed is
        reported with a ``[warn]`` line and skipped, so loading stays
        best-effort but failures are no longer invisible.
    """
    files = []
    for pattern in ("*.json", "*.yml", "*.yaml"):
        # sorted() makes the order of the returned objects reproducible
        files.extend(sorted(issue_dir.glob(pattern)))

    out = []
    for fp in files:
        try:
            if fp.suffix == ".json":
                out.append(json.loads(fp.read_text()))
            else:
                out.append(YAML(typ="safe").load(fp.read_text()))
        except Exception as e:
            print(f"[warn] could not load metadata file {fp}: {e}")
    return out

def extract_root_services(objs):
    """Collect candidate root-cause names from parsed metadata objects.

    Nested dicts and lists are walked. A string value is collected when the
    lowercased dotted key path leading to it contains any of "root", "cause",
    "service", "node" or "culprit". Only string values stored directly under a
    dict key are collected; list items inherit the path of their list.

    Returns the set of collected strings, stripped, with blank ones removed.
    """
    tokens = ["root","cause","service","node","culprit"]
    names = set()
    pending = [(obj, "") for obj in objs]
    while pending:
        node, path = pending.pop()
        if isinstance(node, dict):
            for key, value in node.items():
                key_path = (path + "." + str(key)).lower()
                if isinstance(value, str) and any(t in key_path for t in tokens):
                    names.add(value)
                pending.append((value, key_path))
        elif isinstance(node, list):
            pending.extend((item, path) for item in node)
    return {name.strip() for name in names if name.strip()}

# -------------------- build episodes --------------------
# One episode per issue metrics.csv: per-row summary tokens X, binary labels y,
# the aligned row labels, the source path and the alignment mode that was used.
episodes = []
skipped = 0
for mpath in sorted(issue_csvs):
    try:
        issue_cands = candidates_for_csv(mpath)
        if not issue_cands:
            print(f"[skip] no candidates parsed for {mpath.parent.name}")
            skipped += 1; continue

        # choose issue candidate with max alignment
        # (strict '>' below keeps the first candidate on ties; `best` itself is never read)
        best = None
        best_count = -1
        best_pack = None
        for name, W_issue in issue_cands:
            A2,B2,rows,info = best_issue_alignment(W_obs, W_issue, issue_tag=f"{mpath.parent.name}/{name}")
            if info["count"] > best_count:
                best_count = info["count"]
                best_pack = (A2,B2,rows,info)

        if best_pack is None:
            print(f"[skip] {mpath.parent.name}: no alignment at all")
            skipped += 1; continue

        # A: aligned observational rows, B: aligned issue rows (same row order)
        A,B,rows,info = best_pack
        if len(rows) < 5 or A.shape[1] < 5 or B.shape[1] < 5:
            print(f"[skip] {mpath.parent.name}: insufficient overlap/time (rows={len(rows)}, Acols={A.shape[1]}, Bcols={B.shape[1]})")
            skipped += 1; continue

        # per-row stats across the columns of each aligned frame
        mu_obs = A.mean(axis=1).values
        mu_int = B.mean(axis=1).values
        dmu    = mu_int - mu_obs
        # eps keeps the log finite for rows with zero variance
        eps = 1e-6
        logv_obs = np.log(A.var(axis=1).values + eps)
        logv_int = np.log(B.var(axis=1).values + eps)

        # token per row: [mu_obs, mu_int, dmu, logv_obs, logv_int] -> shape (len(rows), 5)
        tokens = np.stack([mu_obs, mu_int, dmu, logv_obs, logv_int], axis=1).astype(np.float32)

        # ground truth from side files (if any); fallback to top-|Δμ|
        meta_objs = try_load_meta(mpath.parent)
        gt = extract_root_services(meta_objs)

        y = np.zeros(len(rows), dtype=np.float32)
        if gt:
            # match by canonical names against A rows (original obs rows);
            # a row is positive when either canonical name contains the other
            rlow = [canon_name(r) for r in rows]
            for g in gt:
                gl = canon_name(g)
                hits = [i for i,r in enumerate(rlow) if (gl in r or r in gl)]
                for i in hits: y[i] = 1.0
        if y.sum() == 0:
            # weak labels: mark the S_TOP rows with the largest absolute mean shift
            order = np.argsort(-np.abs(dmu))
            y[order[:min(S_TOP,len(rows))]] = 1.0

        episodes.append({"X": tokens, "y": y, "rows": rows, "src": str(mpath), "align_mode": info["mode"]})
    except Exception as e:
        # any failure while processing one file skips that file only
        print("skip:", mpath, "reason:", e)
        skipped += 1

print(f"Episodes built: {len(episodes)} (skipped {skipped})")
if len(episodes) == 0:
    # print some heads to inspect
    print("\n--- DEBUG: noissue head (first 3 rows/3 cols) ---")
    print(W_obs.iloc[:3, :3])
    any_issue = sorted(issue_csvs)[0]
    print(f"\n--- DEBUG: first issue CSV raw head() from {any_issue.parent.name} ---")
    print(read_csv_any(any_issue).head())
    raise RuntimeError("Still no episodes after correlation fallback. See DEBUG heads printed above.")

# -------------------- dataset / loaders --------------------
class EpisodeDS(Dataset):
    """Dataset over episode dicts, truncated to a common number of rows.

    Each episode dict carries ``"X"`` (tokens), ``"y"`` (labels), ``"rows"``
    (row labels), ``"src"`` and ``"align_mode"``. On construction, ``X``,
    ``y`` and ``rows`` of every episode are cut to the smallest row count
    found, so items can be stacked into a batch.

    Note: the truncation is applied to the dicts passed in (they are modified
    in place and shared with the caller).
    """
    def __init__(self, eps):
        self.E = eps
        # Common row count P. default=0 keeps an empty split constructible:
        # min() over an empty sequence would otherwise raise ValueError, which
        # happens whenever the train/test split leaves one side without episodes.
        self.p = min((e["X"].shape[0] for e in eps), default=0)
        for e in self.E:
            e["X"] = e["X"][:self.p]
            e["y"] = e["y"][:self.p]
            e["rows"] = e["rows"][:self.p]

    def __len__(self):
        return len(self.E)

    def __getitem__(self, i):
        """Return ``(X tensor, y tensor, rows, src, align_mode)`` for episode ``i``."""
        e = self.E[i]
        return torch.from_numpy(e["X"]), torch.from_numpy(e["y"]), e["rows"], e["src"], e["align_mode"]

def collate(batch):
    """Merge dataset items into one batch.

    The first two fields of every item (X and y) are stacked into float
    tensors; the remaining three (rows, src, align_mode) are returned as
    plain lists in item order.
    """
    def field(position):
        return [item[position] for item in batch]

    X = torch.stack(field(0)).float()
    y = torch.stack(field(1)).float()
    return X, y, field(2), field(3), field(4)

# Shuffle in place, then use the first 70% of episodes for training and the rest for testing
random.shuffle(episodes)
split = int(0.7*len(episodes))
train_eps, test_eps = episodes[:split], episodes[split:]
# Each EpisodeDS truncates its own episodes to that split's smallest row count
train_loader = DataLoader(EpisodeDS(train_eps), batch_size=BATCH_TRAIN, shuffle=True, collate_fn=collate)
test_loader  = DataLoader(EpisodeDS(test_eps),  batch_size=BATCH_TEST,  shuffle=False, collate_fn=collate)

# -------------------- Set Transformer (BCE only) --------------------
class MAB(nn.Module):
    """Attention block: multi-head attention and a feed-forward layer, each
    followed by a residual connection and LayerNorm.

    Args:
        d: model width.
        n_heads: number of attention heads.
        d_ff: hidden width of the feed-forward layer.
        dropout: dropout passed to the attention module.
    """
    def __init__(self, d, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True, dropout=dropout)
        self.ln1 = nn.LayerNorm(d)
        self.ff = nn.Sequential(
            nn.Linear(d, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d),
        )
        self.ln2 = nn.LayerNorm(d)

    def forward(self, Q, K, V):
        """Attend from ``Q`` to ``K``/``V`` and return a tensor shaped like ``Q``."""
        attended = self.attn(Q, K, V, need_weights=False)[0]
        hidden = self.ln1(Q + attended)
        return self.ln2(hidden + self.ff(hidden))

class SAB(nn.Module):
    """Self-attention block: an ``MAB`` where queries, keys and values are the same input."""
    def __init__(self, d, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.mab = MAB(d, n_heads, d_ff, dropout)

    def forward(self, X):
        """Apply the wrapped ``MAB`` with ``X`` as query, key and value."""
        return self.mab(X, X, X)

class SetDetector(nn.Module):
    """Per-element scorer: token encoder, a stack of ``SAB`` blocks, and a
    linear head whose output is passed through a sigmoid.

    Args:
        d_in: size of each input token.
        d: model width.
        depth: number of ``SAB`` blocks.
        n_heads, d_ff, dropout: forwarded to every ``SAB``.
    """
    def __init__(self, d_in=5, d=128, depth=3, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.enc = nn.Sequential(
            nn.LayerNorm(d_in),
            nn.Linear(d_in, d),
            nn.GELU(),
            nn.Linear(d, d),
        )
        stack = []
        for _ in range(depth):
            stack.append(SAB(d, n_heads, d_ff, dropout))
        self.blocks = nn.ModuleList(stack)
        self.head = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, 1))

    def forward(self, X):                   # X: (B,P,5)
        """Return one probability per element, shape (B,P)."""
        hidden = self.enc(X)
        for block in self.blocks:
            hidden = block(hidden)
        logits = self.head(hidden).squeeze(-1)
        return torch.sigmoid(logits)  # (B,P)

# Use the GPU when one is available; the model is built with SetDetector's default sizes
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SetDetector().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return the mean per-batch BCE loss.

    Args:
        loader: yields ``(X, Y, rows, srcs, modes)`` batches; only X and Y are used.
        train: when True the model is put in training mode and the optimizer
            steps after every batch; when False the model is only evaluated.

    Returns:
        float: average loss over the batches (0.0 for an empty loader).
    """
    model.train(train)
    total, n = 0.0, 0
    # Track gradients only when training. Evaluation passes previously built
    # an autograd graph for every batch that was never used.
    with torch.set_grad_enabled(train):
        for X,Y,_,_,_ in loader:
            X,Y = X.to(device), Y.to(device)
            P = model(X)
            # the model returns probabilities; clamp keeps log() finite inside BCE
            loss = F.binary_cross_entropy(P.clamp(1e-6,1-1e-6), Y)
            if train:
                opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item(); n += 1
    return total/max(1,n)

# Train for EPOCHS epochs; the test loss is computed and logged on epoch 1 and every 5th epoch
for ep in range(1, EPOCHS+1):
    tr = run_epoch(train_loader, True)
    if ep % 5 == 0 or ep == 1:
        te = run_epoch(test_loader, False)
        print(f"epoch {ep:02d} | train {tr:.4f} | test {te:.4f}")

# -------------------- evaluation + CSVs --------------------
@torch.no_grad()
def evaluate(loader, s_default=S_TOP):
    """Score every episode in ``loader`` and compare top-ranked rows to the labels.

    For each episode the rows with label > 0.5 form the true set. Two
    predicted sets are built from the model's probabilities: the top-S rows
    (S = size of the true set, or ``s_default`` when it is empty) and the
    top-5 rows.

    Returns:
        (summary, details): two DataFrames with one row per episode. The
        summary holds the counts (``p``, ``s_true``, ``hits@S``, ``hits@5``);
        the details hold row labels, index sets and raw probabilities.
    """
    def top_k(scores, k):
        """Indices of the ``k`` largest scores as a set (all indices if k >= len)."""
        k = min(k, scores.shape[0])
        if k <= 0:
            return set()
        # argpartition needs kth < len(scores). Partitioning at k-1 places the
        # k largest scores first and stays valid when k == len(scores); the
        # previous kth=k raised ValueError in that case (e.g. exactly 5 rows).
        return set(np.argpartition(-scores, k - 1)[:k].tolist())

    model.eval()
    rows, details = [], []
    for X,Y,labels,srcs,modes in loader:
        X,Y = X.to(device), Y.to(device)
        P = model(X)
        B,Pdim = P.shape
        for b in range(B):
            pr = P[b].detach().cpu().numpy()
            y  = Y[b].cpu().numpy()
            true = set(np.where(y>0.5)[0].tolist())
            s = len(true) if true else s_default
            topS = top_k(pr, s)
            top5 = top_k(pr, 5)
            rows.append({
                "src": srcs[b],
                "align_mode": modes[b],
                "p": int(Pdim),
                "s_true": int(len(true)),
                "hits@S": int(len(topS & true)),
                "hits@5": int(len(top5 & true)),
            })
            details.append({
                "src": srcs[b],
                "align_mode": modes[b],
                "services": [str(r) for r in labels[b]],
                "true_idx": sorted(list(true)),
                "pred_topS_idx": sorted(list(topS)),
                "pred_top5_idx": sorted(list(top5)),
                "probs": pr.tolist(),
            })
    return pd.DataFrame(rows), pd.DataFrame(details)

# Evaluate on the held-out episodes and print the column means of the numeric summary fields
summary_df, details_df = evaluate(test_loader)
print("\n=== Test summary (means) ===")
print(summary_df.mean(numeric_only=True).to_string())

# Written relative to the current working directory
summary_df.to_csv("petshop_setxf_summary.csv", index=False)
details_df.to_csv("petshop_setxf_details.csv", index=False)
print("\nSaved: petshop_setxf_summary.csv, petshop_setxf_details.csv")


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/119.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.9/119.9 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/788.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m788.2/788.2 kB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
[?25hUsing scenario: /content/petshop-root-cause-analysis/dataset/temporal_traffic2


  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)


[obs] picked variant=rows_are_services, shape=(1652, 275)
Found issue metric files: 8
[align] issue_0/rows_are_services: mode=exact, matched_rows=0
[align] issue_0/cols_are_services: mode=correlation, matched_rows=73
[align] issue_1/rows_are_services: mode=exact, matched_rows=0


  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)


[align] issue_1/cols_are_services: mode=correlation, matched_rows=54


  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)


[align] issue_2/rows_are_services: mode=exact, matched_rows=0
[align] issue_2/cols_are_services: mode=correlation, matched_rows=50


  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)


[align] issue_3/rows_are_services: mode=exact, matched_rows=0
[align] issue_3/cols_are_services: mode=correlation, matched_rows=59


  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)


[align] issue_4/rows_are_services: mode=exact, matched_rows=0
[align] issue_4/cols_are_services: mode=correlation, matched_rows=68


  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)


[align] issue_5/rows_are_services: mode=exact, matched_rows=0
[align] issue_5/cols_are_services: mode=correlation, matched_rows=76


  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)


[align] issue_0/rows_are_services: mode=exact, matched_rows=0
[align] issue_0/cols_are_services: mode=correlation, matched_rows=89


  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)
  W = W.fillna(method="ffill", axis=1).fillna(method="bfill", axis=1).fillna(0.0)


[align] issue_1/rows_are_services: mode=exact, matched_rows=0
[align] issue_1/cols_are_services: mode=correlation, matched_rows=76
Episodes built: 8 (skipped 0)
epoch 01 | train 0.7139 | test 0.3724
epoch 05 | train 0.1371 | test 0.1247
epoch 10 | train 0.1274 | test 0.1135
epoch 15 | train 0.1129 | test 0.1232
epoch 20 | train 0.1131 | test 0.1582
epoch 25 | train 0.1098 | test 0.1379
epoch 30 | train 0.1058 | test 0.1119
epoch 35 | train 0.1041 | test 0.1067
epoch 40 | train 0.1007 | test 0.1131
epoch 45 | train 0.0991 | test 0.1120
epoch 50 | train 0.0967 | test 0.0974
epoch 55 | train 0.0951 | test 0.0905
epoch 60 | train 0.0932 | test 0.0890
epoch 65 | train 0.0916 | test 0.0825
epoch 70 | train 0.0902 | test 0.0789
epoch 75 | train 0.0887 | test 0.0773
epoch 80 | train 0.0873 | test 0.0742
epoch 85 | train 0.0859 | test 0.0731
epoch 90 | train 0.0844 | test 0.0715
epoch 95 | train 0.0830 | test 0.0700
epoch 100 | train 0.0815 | test 0.0691
epoch 105 | train 0.0801 | test 0.0687
e

In [None]:
# ======================= PetShop → Set Transformer root-cause detector (BCE only) =======================
# One cell you can run in Colab.
!pip -q install ruamel.yaml openpyxl

import json, re, random
from pathlib import Path
import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from ruamel.yaml import YAML

# ---------------------------- config ----------------------------
# Seed Python, NumPy and PyTorch RNGs
SEED = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# Training / eval
EPOCHS        = 25
LR            = 2e-4
BATCH_TRAIN   = 8
BATCH_TEST    = 8
TOP_S_DEFAULT = 3          # weak-label support size if GT missing
P_MIN_KEEP    = 30         # drop episodes with < P_MIN_KEEP services (after alignment)
SIM_THRESHOLD = 0.6        # correlation-match threshold for row alignment
MAX_ROWS_MATCH = 2000      # upper bound on row pairs matched by correlation — TODO confirm against its use further down

# Model sizes
D_TOKEN   = 5              # [mu_obs, mu_iss, dmu, logvar_obs, logvar_iss]
D_MODEL   = 128
N_HEADS   = 8
D_FF      = 256
DEPTH     = 3
DROPOUT   = 0.0

# ---------------------------- pull repo & locate data ----------------------------
ROOT = Path("/content")
REPO = ROOT / "petshop-root-cause-analysis"
if not REPO.exists():
    # IPython shell escape; -q silences git's progress output
    !git -C /content clone -q https://github.com/amazon-science/petshop-root-cause-analysis.git
DATA_ROOT = REPO / "dataset"
assert DATA_ROOT.exists(), "dataset/ not found in the repo."

# ---------------------------- CSV helpers ----------------------------
def read_csv_any(p: Path) -> pd.DataFrame:
    """Load the CSV at ``p`` into a DataFrame using pandas' Python parser engine."""
    frame = pd.read_csv(p, engine="python")
    return frame

def numeric_table(df: pd.DataFrame) -> pd.DataFrame:
    """Return a float32, NaN-free numeric copy of ``df``.

    Every column is coerced with ``pd.to_numeric(errors="coerce")`` (unparseable
    cells become NaN), columns and rows that are entirely NaN are dropped, and
    the remaining gaps are filled forward, then backward, then with 0.0.
    The input frame is not modified.
    """
    out = df.copy()
    for c in out.columns:
        out[c] = pd.to_numeric(out[c], errors="coerce")
    out = out.dropna(axis=1, how="all").dropna(axis=0, how="all")
    # .ffill()/.bfill() replace the deprecated fillna(method=...), which emitted
    # a FutureWarning on every call.
    out = out.ffill().bfill().fillna(0.0)
    return out.astype(np.float32)

def candidates_for_csv(csv_path: Path):
    """Build the candidate wide tables for one metrics CSV.

    Returns a list of ``(orientation_name, wide_df)`` pairs in which rows are
    meant to be services and columns time steps. Two readings are attempted:

    * ``"rows_are_services"``: only when a ``microservice`` column exists; that
      column becomes the index.
    * ``"cols_are_services"``: the cleaned numeric table transposed, with the
      new index labels converted to strings.

    A reading is kept only if it has at least 5 rows and 5 columns. Candidates
    sharing the same (name, n_rows, n_cols) signature are reported once.
    """
    MIN_SIDE = 5

    def big_enough(table):
        return table.shape[0] >= MIN_SIDE and table.shape[1] >= MIN_SIDE

    raw = read_csv_any(csv_path)
    found = []

    # Reading A: an explicit 'microservice' column names the rows.
    if "microservice" in raw.columns:
        by_service = numeric_table(raw.set_index("microservice"))
        if big_enough(by_service):
            found.append(("rows_are_services", by_service))

    # Reading B: rows are time steps, so transpose to put services on the rows.
    transposed = numeric_table(raw).T
    transposed.index = [str(label) for label in transposed.index]
    if big_enough(transposed):
        found.append(("cols_are_services", transposed))

    # Keep the first candidate seen for each (name, rows, cols) signature.
    first_by_signature = {}
    for label, wide in found:
        first_by_signature.setdefault((label, wide.shape[0], wide.shape[1]), (label, wide))
    return list(first_by_signature.values())

# ---------------------------- alignment (exact → canonical → correlation) ----------------------------
def canon_name(s: str) -> str:
    """Reduce a row/service label to a canonical lowercase key.

    After lowercasing, each of the following is replaced by a single space, in
    order: ``arn:aws:...`` tokens, ``aws::...`` type tokens, runs of
    non-alphanumeric characters, the standalone environment words
    prod/stage/stg/dev/test/qa, and runs of whitespace. The result is stripped.
    """
    text = str(s).lower()
    patterns = (
        r'arn:aws:[^ ]+',
        r'aws::[a-z0-9:_\-]+',
        r'[^a-z0-9]+',
        r'\b(prod|stage|stg|dev|test|qa)\b',
        r'\s+',
    )
    for pattern in patterns:
        text = re.sub(pattern, ' ', text)
    return text.strip()

def align_exact(A: pd.DataFrame, B: pd.DataFrame):
    """Align two tables on index labels that appear verbatim in both.

    Returns ``(A_rows, B_rows, labels, info)``. When fewer than 5 labels are
    shared, both frames are empty slices, ``labels`` is ``[]`` and
    ``info["count"]`` is 0.
    """
    shared = A.index.intersection(B.index)
    n_shared = len(shared)
    if n_shared < 5:
        return A.iloc[[]], B.iloc[[]], [], {"mode":"exact", "count":0}
    left = A.loc[shared].copy()
    right = B.loc[shared].copy()
    return left, right, list(shared), {"mode":"exact", "count":n_shared}

def align_canonical(A: pd.DataFrame, B: pd.DataFrame):
    """Align two tables on the canonical form of their index labels.

    Each label is passed through ``canon_name``; labels whose canonical form is
    empty are ignored and, when several labels collapse to the same key, the
    first one encountered wins. Rows are paired on the keys present on both
    sides, ordered by key. Returns ``(A_rows, B_rows, A_labels, info)``; with
    fewer than 5 shared keys the frames are empty and the count is 0.
    """
    def first_label_per_key(index):
        table = {}
        for label in index:
            key = canon_name(label)
            if key:
                table.setdefault(key, label)
        return table

    left_map = first_label_per_key(A.index)
    right_map = first_label_per_key(B.index)
    common = sorted(left_map.keys() & right_map.keys())
    if len(common) < 5:
        return A.iloc[[]], B.iloc[[]], [], {"mode":"canonical", "count":0}

    left_labels = [left_map[key] for key in common]
    right_labels = [right_map[key] for key in common]
    return (
        A.loc[left_labels].copy(),
        B.loc[right_labels].copy(),
        left_labels,
        {"mode":"canonical", "count":len(common)},
    )

def l2_normalize_rows(M: np.ndarray, eps=1e-8):
    """Divide each row of `M` by its Euclidean norm, with norms below `eps` treated as `eps`."""
    return M / np.maximum(np.linalg.norm(M, axis=1, keepdims=True), eps)

def align_by_correlation(A: pd.DataFrame, B: pd.DataFrame, threshold=SIM_THRESHOLD, max_rows=MAX_ROWS_MATCH):
    """Pair rows of `A` with rows of `B` greedily by similarity of their series.

    Only the first ``T = min(A.shape[1], B.shape[1])`` columns of each table are
    compared. Rows are mean-centred and scaled to unit length, so the dot
    product between an `A` row and a `B` row is a cosine similarity. Candidate
    pairs are visited from most to least similar; a pair is accepted when
    neither row is already matched, and the scan stops at the first unmatched
    pair whose similarity is below `threshold` or once `max_rows` pairs exist.

    Returns ``(A_rows, B_rows, A_labels, info)`` with matched rows in ascending
    `A` position. If ``T < 5`` or fewer than 5 pairs are found, the frames are
    empty slices and `A_labels` is ``[]``.
    """
    n_cols = min(A.shape[1], B.shape[1])
    if n_cols < 5:
        return A.iloc[[]], B.iloc[[]], [], {"mode":"correlation", "count":0}

    def unit_centered(frame):
        # Centre each row on its mean, then scale to unit length for cosine similarity.
        values = frame.iloc[:, :n_cols].to_numpy(np.float32)
        return l2_normalize_rows(values - values.mean(axis=1, keepdims=True))

    sim = unit_centered(A) @ unit_centered(B).T
    n_left, n_right = sim.shape
    taken_left = np.zeros(n_left, dtype=bool)
    taken_right = np.zeros(n_right, dtype=bool)

    matches = []
    for flat_idx in np.argsort(sim.ravel())[::-1]:
        if len(matches) >= max_rows:
            break
        i, j = divmod(flat_idx, n_right)
        if taken_left[i] or taken_right[j]:
            continue
        if sim[i, j] < threshold:
            # Everything after this point is no more similar, so stop scanning.
            break
        taken_left[i] = True
        taken_right[j] = True
        matches.append((i, j))

    if len(matches) < 5:
        return A.iloc[[]], B.iloc[[]], [], {"mode":"correlation", "count":len(matches)}

    matches.sort(key=lambda pair: pair[0])
    left_pos = [i for i, _ in matches]
    right_pos = [j for _, j in matches]
    A2 = A.iloc[left_pos].copy()
    B2 = B.iloc[right_pos].copy()
    return A2, B2, list(A2.index), {"mode":"correlation", "count":len(matches)}

def best_align(W_obs: pd.DataFrame, W_issue: pd.DataFrame, tag="issue"):
    """Try exact, canonical and correlation alignment and keep the one with the most matches.

    On ties the earlier strategy in that order wins. A one-line summary tagged
    with `tag` is printed, and the winning
    ``(obs_rows, issue_rows, labels, info)`` tuple is returned.
    """
    strategies = (align_exact, align_canonical, align_by_correlation)
    outcomes = [strategy(W_obs, W_issue) for strategy in strategies]
    winner = max(outcomes, key=lambda outcome: outcome[3]["count"])
    info = winner[3]
    print(f"[align] {tag}: mode={info['mode']}, matched_rows={info['count']}")
    return winner

# ---------------------------- label helpers (from side files if possible) ----------------------------
def try_load_meta(issue_dir: Path):
    """Parse every JSON / YAML side file found directly inside `issue_dir`.

    Files are visited as ``*.json``, then ``*.yml``, then ``*.yaml``, each group
    in sorted order so the result does not depend on directory listing order.
    JSON files are parsed with ``json.loads``; the others with ruamel's safe
    YAML loader. Loading is best-effort: a file that cannot be read or parsed
    is reported with a ``[warn]`` line and skipped rather than raising.

    Returns:
        list: the parsed objects, in visiting order (may be empty).
    """
    files = [fp for pattern in ("*.json", "*.yml", "*.yaml") for fp in sorted(issue_dir.glob(pattern))]
    out = []
    for fp in files:
        try:
            text = fp.read_text(encoding="utf-8")
            if fp.suffix == ".json":
                out.append(json.loads(text))
            else:
                out.append(YAML(typ="safe").load(text))
        except Exception as exc:
            # Still best-effort, but make the failure visible: a silently dropped
            # metadata file changes which labels the episode ends up with.
            print(f"[warn] could not parse metadata file {fp}: {exc}")
    return out

def extract_root_services(objs):
    """Collect candidate root-cause names from parsed metadata objects.

    Nested dicts and lists are walked while tracking a lowercase dotted key
    path. A string value is collected when its key path contains any of
    "root", "cause", "service", "culprit" or "node"; the match is made on the
    whole path, so strings nested under a matching key are collected as well.
    List items inherit the path of the list. Returns the set of collected
    strings, stripped, with blank entries removed.
    """
    hints = ["root","cause","service","culprit","node"]

    def walk(node, trail):
        if isinstance(node, dict):
            for key, value in node.items():
                here = ".".join((trail, str(key))).lower()
                if isinstance(value, str) and any(hint in here for hint in hints):
                    yield value
                yield from walk(value, here)
        elif isinstance(node, list):
            for item in node:
                yield from walk(item, trail)

    collected = {value for obj in objs for value in walk(obj, "")}
    return {name.strip() for name in collected if name and str(name).strip()}

# ---------------------------- episode builder across all scenarios ----------------------------
# Build one training/eval "episode" per issue metrics.csv, pairing it with the
# scenario's no-issue metrics.csv as the observational baseline.
episodes = []   # accumulated episode dicts: X (tokens), y (labels), rows, src, align_mode
skipped  = 0    # issue files that could not be parsed, aligned, or passed the filter

for scen_dir in sorted(p for p in DATA_ROOT.iterdir() if p.is_dir()):
    noissue_csv = scen_dir / "noissue" / "metrics.csv"
    # A scenario without a baseline file cannot produce episodes.
    if not noissue_csv.exists():
        continue
    obs_cands = candidates_for_csv(noissue_csv)
    if not obs_cands:
        print(f"[skip] cannot parse {noissue_csv}")
        continue
    # choose obs candidate with most rows
    obs_cands.sort(key=lambda z: z[1].shape[0], reverse=True)
    W_obs = obs_cands[0][1]

    # Every metrics.csv below the scenario that is not under a "noissue" folder is an issue episode.
    issue_csvs = [p for p in scen_dir.rglob("metrics.csv") if "noissue" not in p.parts]
    for ic in sorted(issue_csvs):
        try:
            issue_cands = candidates_for_csv(ic)
            if not issue_cands:
                skipped += 1; continue
            # choose best aligned
            # (each orientation of the issue table is aligned against W_obs; the one with most matched rows wins)
            best = None; best_count = -1; best_pack = None
            for name, W_issue in issue_cands:
                A,B,rows,info = best_align(W_obs, W_issue, tag=f"{scen_dir.name}/{ic.parent.name}/{name}")
                if info["count"] > best_count:
                    best_count = info["count"]; best_pack = (A,B,rows,info)
            A,B,rows,info = best_pack
            # basic filter
            # (needs at least P_MIN_KEEP matched rows and at least 5 columns on each side)
            if len(rows) < P_MIN_KEEP or A.shape[1] < 5 or B.shape[1] < 5:
                print(f"[skip] {ic} (rows={len(rows)}, Acols={A.shape[1]}, Bcols={B.shape[1]})")
                skipped += 1; continue

            # per-service stats across time
            # A holds the aligned baseline rows, B the aligned issue rows; stats are taken along axis=1.
            mu_obs = A.mean(axis=1).values
            mu_iss = B.mean(axis=1).values
            dmu    = mu_iss - mu_obs
            eps = 1e-6   # keeps the log finite when a row has zero variance
            logv_obs = np.log(A.var(axis=1).values + eps)
            logv_iss = np.log(B.var(axis=1).values + eps)
            # One 5-feature token per aligned row: (len(rows), 5).
            tokens = np.stack([mu_obs, mu_iss, dmu, logv_obs, logv_iss], axis=1).astype(np.float32)

            # labels (root cause). Prefer side-files; else weak labels by |Δμ|.
            y = np.zeros(len(rows), dtype=np.float32)
            gt = extract_root_services(try_load_meta(ic.parent))
            if gt:
                # A row is marked positive when its canonical name contains, or is contained in,
                # the canonical form of any extracted name.
                rlow = [canon_name(r) for r in rows]
                for g in gt:
                    gl = canon_name(g)
                    hits = [i for i,r in enumerate(rlow) if (gl in r or r in gl)]
                    for i in hits: y[i] = 1.0
            if y.sum() == 0:
                # No side-file label matched any row: label the TOP_S_DEFAULT rows with the largest |dmu|.
                order = np.argsort(-np.abs(dmu))
                y[order[:min(TOP_S_DEFAULT, len(rows))]] = 1.0

            episodes.append({"X": tokens, "y": y, "rows": rows, "src": str(ic), "align_mode": info["mode"]})
        except Exception as e:
            # Any failure while processing one issue file is logged and counted; the loop continues.
            print("skip:", ic, "reason:", e)
            skipped += 1

print(f"\nEpisodes built: {len(episodes)} (skipped {skipped})")
assert len(episodes) > 0, "No usable issue episodes were constructed."

# ---------------------------- dataset (truncate to common P per split) ----------------------------
class EpisodeDS(Dataset):
    """Dataset of episodes cut down to a common number of rows.

    ``p`` is the smallest row count of ``X`` among the given episodes. Each
    episode's ``X``, ``y`` and ``rows`` entries are truncated to their first
    ``p`` items; the truncation is written back into the episode dicts that
    were passed in. Indexing yields
    ``(X tensor, y tensor, rows, src, align_mode)``.
    """
    def __init__(self, eps):
        # Common length for this split: the shortest episode decides.
        self.p = min(ep["X"].shape[0] for ep in eps)
        self.E = [ep for ep in eps if ep["X"].shape[0] >= self.p]
        for ep in self.E:
            for field in ("X", "y", "rows"):
                ep[field] = ep[field][:self.p]

    def __len__(self):
        return len(self.E)

    def __getitem__(self, i):
        ep = self.E[i]
        features = torch.from_numpy(ep["X"])
        labels = torch.from_numpy(ep["y"])
        return features, labels, ep["rows"], ep["src"], ep["align_mode"]

def collate(batch):
    """Merge a list of dataset samples into one batch.

    Features and labels are stacked into float tensors of shape (B, P, 5) and
    (B, P); row names, source paths and alignment modes stay as Python lists
    with one entry per sample.
    """
    def column(pos):
        return [sample[pos] for sample in batch]

    X = torch.stack(column(0)).float()   # (B,P,5)
    y = torch.stack(column(1)).float()   # (B,P)
    return X, y, column(2), column(3), column(4)

# Shuffle the episodes in place, then use the first 70% for training and the rest for testing.
random.shuffle(episodes)
split = int(0.7*len(episodes))
train_eps, test_eps = episodes[:split], episodes[split:]
# Each split gets its own EpisodeDS, so each split is truncated to its own minimum row count.
train_loader = DataLoader(EpisodeDS(train_eps), batch_size=BATCH_TRAIN, shuffle=True, collate_fn=collate)
test_loader  = DataLoader(EpisodeDS(test_eps),  batch_size=BATCH_TEST,  shuffle=False, collate_fn=collate)

# Note: the P printed here is the training split's row count only.
print(f"Train episodes: {len(train_loader.dataset)} | Test episodes: {len(test_loader.dataset)} | P={train_loader.dataset.p}")

# ---------------------------- Set Transformer (SAB blocks) — FIXED FFN ----------------------------
class MAB(nn.Module):
    """Multihead attention block: attention plus feed-forward, each with a residual and LayerNorm.

    Args:
        d: model width of queries, keys, values and output.
        n_heads: number of attention heads.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout rate for the attention layer and after the feed-forward output.
    """
    def __init__(self, d, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True, dropout=dropout)
        self.ln1 = nn.LayerNorm(d)
        # Feed-forward expands to d_ff and projects back to d so the residual sum is well-formed.
        ffn_layers = [
            nn.Linear(d, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*ffn_layers)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, Q,K,V):
        attended = self.attn(Q, K, V, need_weights=False)[0]
        normed = self.ln1(Q + attended)
        return self.ln2(normed + self.ff(normed))

class SAB(nn.Module):
    """Self-attention block: an MAB whose queries, keys and values are all the same input."""
    def __init__(self, d, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.mab = MAB(d, n_heads, d_ff, dropout)

    def forward(self, X):
        return self.mab(X, X, X)

class SetDetector(nn.Module):
    """Per-element scorer over a set of tokens.

    Tokens are embedded by a small MLP, passed through `depth` self-attention
    blocks, and mapped by a linear head to one value per token. ``forward``
    returns the sigmoid of that value, so the output has shape (B, P) with
    entries in (0, 1).
    """
    def __init__(self, d_in=D_TOKEN, d=D_MODEL, depth=DEPTH, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT):
        super().__init__()
        self.enc = nn.Sequential(
            nn.LayerNorm(d_in),
            nn.Linear(d_in, d),
            nn.GELU(),
            nn.Linear(d, d),
        )
        self.blocks = nn.ModuleList([SAB(d, n_heads, d_ff, dropout) for _ in range(depth)])
        self.head = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, 1))

    def forward(self, X):
        hidden = self.enc(X)
        for block in self.blocks:
            hidden = block(hidden)
        scores = self.head(hidden).squeeze(-1)
        return torch.sigmoid(scores)  # (B,P)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Model built with the default sizes from the config section.
model = SetDetector().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

def run_epoch(loader, train=True):
    """Run one pass over `loader` and return the mean per-batch BCE loss.

    Uses the module-level `model`, `opt` and `device`.

    Args:
        loader: iterable yielding ``(X, Y, rows, srcs, modes)`` batches; only
            `X` (features) and `Y` (0/1 targets) are used.
        train: when True the model is put in training mode and an optimizer
            step is taken per batch; when False the model is put in eval mode
            and no parameters are updated.

    Returns:
        float: sum of batch losses divided by the number of batches
        (0.0 for an empty loader).
    """
    model.train(train)
    total, n = 0.0, 0
    # Gradient tracking is only needed when stepping the optimizer; disabling it
    # for evaluation passes avoids building an autograd graph that is never used.
    with torch.set_grad_enabled(train):
        for X,Y,_,_,_ in loader:
            X,Y = X.to(device), Y.to(device)
            P = model(X)  # (B,P)
            # The model already outputs probabilities; clamp to keep log() finite.
            loss = F.binary_cross_entropy(P.clamp(1e-6,1-1e-6), Y)
            if train:
                opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item(); n += 1
    return total/max(1,n)

print("\nTraining …")
for ep in range(1, EPOCHS+1):
    # One optimisation pass over the training split, then a loss-only pass over the test split.
    tr = run_epoch(train_loader, True)
    te = run_epoch(test_loader, False)
    # Report on the first epoch and every fifth epoch.
    if ep % 5 == 0 or ep == 1:
        print(f"epoch {ep:02d} | train {tr:.4f} | test {te:.4f}")

# ---------------------------- evaluation & saving ----------------------------
def _top_k_indices(scores, k):
    """Return the set of positions of the `k` largest entries of 1-D `scores`.

    `k` is clamped to ``[0, len(scores)]``. ``np.argpartition`` needs
    ``kth < len(scores)``, so the "take everything" and "take nothing" cases
    are answered directly instead of being passed to it.
    """
    size = int(scores.shape[0])
    k = max(0, min(int(k), size))
    if k == 0:
        return set()
    if k == size:
        return set(range(size))
    # Partitioning at k-1 places the k smallest values of -scores in the first k slots.
    return set(np.argpartition(-scores, k - 1)[:k].tolist())


@torch.no_grad()
def evaluate(loader, s_default=TOP_S_DEFAULT):
    """Score every episode in `loader` and compare top-ranked rows with the labels.

    Uses the module-level `model` and `device`. For each episode the predicted
    probabilities are ranked and two selections are taken: the top S rows,
    where S is the number of positive labels (or `s_default` if there are
    none), and the top 5 rows. Selections never exceed the number of rows.

    Args:
        loader: iterable yielding ``(X, Y, row_names, srcs, modes)`` batches.
        s_default: selection size used when an episode has no positive label.

    Returns:
        tuple: ``(summary, details)`` DataFrames with one row per episode.
        `summary` holds src, align_mode, p, s_true, hits@S and hits@5;
        `details` holds the row names, true indices, both predicted index
        lists and the raw probabilities.
    """
    model.eval()
    rows, details = [], []
    for X,Y,row_names,srcs,modes in loader:
        X,Y = X.to(device), Y.to(device)
        P = model(X)  # (B,P)
        B,Pdim = P.shape
        for b in range(B):
            pr = P[b].detach().cpu().numpy()
            y  = Y[b].cpu().numpy()
            true = set(np.where(y>0.5)[0].tolist())
            s = len(true) if true else s_default
            idx_topS = _top_k_indices(pr, s)
            idx_top5 = _top_k_indices(pr, 5)
            rows.append({
                "src": srcs[b],
                "align_mode": modes[b],
                "p": int(Pdim),
                "s_true": int(len(true)),
                "hits@S": int(len(idx_topS & true)),
                "hits@5": int(len(idx_top5 & true)),
            })
            details.append({
                "src": srcs[b],
                "align_mode": modes[b],
                "services": [str(r) for r in row_names[b]],
                "true_idx": sorted(list(true)),
                "pred_topS_idx": sorted(list(idx_topS)),
                "pred_top5_idx": sorted(list(idx_top5)),
                "probs": pr.tolist(),
            })
    return pd.DataFrame(rows), pd.DataFrame(details)

# Evaluate on the held-out split and print the column-wise mean of the numeric summary columns.
summary_df, details_df = evaluate(test_loader)
print("\n=== Test summary (mean across incidents) ===")
print(summary_df.mean(numeric_only=True).to_string())

# Save CSVs
summary_df.to_csv("/content/petshop_setxf_summary.csv", index=False)
details_df.to_csv("/content/petshop_setxf_details.csv", index=False)
print("\nSaved: /content/petshop_setxf_summary.csv and /content/petshop_setxf_details.csv")

# (Optional) quick download in Colab
# Best-effort: any failure here (e.g. the google.colab import) is deliberately ignored.
try:
    from google.colab import files
    files.download("/content/petshop_setxf_summary.csv")
    files.download("/content/petshop_setxf_details.csv")
except Exception:
    pass


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_0/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_0/cols_are_services: mode=correlation, matched_rows=38


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_1/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_1/cols_are_services: mode=correlation, matched_rows=35


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_10/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_10/cols_are_services: mode=correlation, matched_rows=24
[skip] /content/petshop-root-cause-analysis/dataset/high_traffic/test/issue_10/metrics.csv (rows=24, Acols=273, Bcols=5)


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_11/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_11/cols_are_services: mode=correlation, matched_rows=37


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_12/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_12/cols_are_services: mode=correlation, matched_rows=56


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_13/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_13/cols_are_services: mode=correlation, matched_rows=35


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_14/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_14/cols_are_services: mode=correlation, matched_rows=27
[skip] /content/petshop-root-cause-analysis/dataset/high_traffic/test/issue_14/metrics.csv (rows=27, Acols=273, Bcols=5)


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_15/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_15/cols_are_services: mode=correlation, matched_rows=26
[skip] /content/petshop-root-cause-analysis/dataset/high_traffic/test/issue_15/metrics.csv (rows=26, Acols=273, Bcols=5)


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_16/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_16/cols_are_services: mode=correlation, matched_rows=44


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_17/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_17/cols_are_services: mode=correlation, matched_rows=16
[skip] /content/petshop-root-cause-analysis/dataset/high_traffic/test/issue_17/metrics.csv (rows=16, Acols=273, Bcols=5)


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_2/rows_are_services: mode=correlation, matched_rows=5
[align] high_traffic/issue_2/cols_are_services: mode=correlation, matched_rows=53


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_3/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_3/cols_are_services: mode=correlation, matched_rows=69


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_4/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_4/cols_are_services: mode=correlation, matched_rows=30


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_5/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_5/cols_are_services: mode=correlation, matched_rows=69


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_6/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_6/cols_are_services: mode=correlation, matched_rows=46


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_7/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_7/cols_are_services: mode=correlation, matched_rows=63


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_8/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_8/cols_are_services: mode=correlation, matched_rows=30


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_9/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_9/cols_are_services: mode=correlation, matched_rows=59


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_0/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_0/cols_are_services: mode=correlation, matched_rows=36


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_1/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_1/cols_are_services: mode=correlation, matched_rows=51


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_2/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_2/cols_are_services: mode=correlation, matched_rows=39


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_3/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_3/cols_are_services: mode=correlation, matched_rows=52


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_4/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_4/cols_are_services: mode=correlation, matched_rows=32


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_5/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_5/cols_are_services: mode=correlation, matched_rows=48


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_6/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_6/cols_are_services: mode=correlation, matched_rows=44


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] high_traffic/issue_7/rows_are_services: mode=exact, matched_rows=0
[align] high_traffic/issue_7/cols_are_services: mode=correlation, matched_rows=45


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_0/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_0/cols_are_services: mode=correlation, matched_rows=35


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_1/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_1/cols_are_services: mode=correlation, matched_rows=27
[skip] /content/petshop-root-cause-analysis/dataset/low_traffic/test/issue_1/metrics.csv (rows=27, Acols=287, Bcols=5)


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_10/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_10/cols_are_services: mode=correlation, matched_rows=34


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_11/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_11/cols_are_services: mode=correlation, matched_rows=48
[align] low_traffic/issue_12/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_12/cols_are_services: mode=correlation, matched_rows=29
[skip] /content/petshop-root-cause-analysis/dataset/low_traffic/test/issue_12/metrics.csv (rows=29, Acols=287, Bcols=5)


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_13/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_13/cols_are_services: mode=correlation, matched_rows=35
[align] low_traffic/issue_14/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_14/cols_are_services: mode=correlation, matched_rows=44


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_15/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_15/cols_are_services: mode=correlation, matched_rows=26
[skip] /content/petshop-root-cause-analysis/dataset/low_traffic/test/issue_15/metrics.csv (rows=26, Acols=287, Bcols=5)


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_16/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_16/cols_are_services: mode=correlation, matched_rows=59


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_17/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_17/cols_are_services: mode=correlation, matched_rows=41


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_2/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_2/cols_are_services: mode=correlation, matched_rows=48


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_3/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_3/cols_are_services: mode=correlation, matched_rows=33


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_4/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_4/cols_are_services: mode=correlation, matched_rows=46


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_5/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_5/cols_are_services: mode=correlation, matched_rows=31


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_6/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_6/cols_are_services: mode=correlation, matched_rows=54


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_7/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_7/cols_are_services: mode=correlation, matched_rows=48


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_8/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_8/cols_are_services: mode=correlation, matched_rows=27
[skip] /content/petshop-root-cause-analysis/dataset/low_traffic/test/issue_8/metrics.csv (rows=27, Acols=287, Bcols=5)


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_9/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_9/cols_are_services: mode=correlation, matched_rows=34


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_0/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_0/cols_are_services: mode=correlation, matched_rows=70


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_1/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_1/cols_are_services: mode=correlation, matched_rows=51


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_2/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_2/cols_are_services: mode=correlation, matched_rows=56


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_3/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_3/cols_are_services: mode=correlation, matched_rows=37


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_4/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_4/cols_are_services: mode=correlation, matched_rows=42


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_5/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_5/cols_are_services: mode=correlation, matched_rows=51


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_6/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_6/cols_are_services: mode=correlation, matched_rows=62


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] low_traffic/issue_7/rows_are_services: mode=exact, matched_rows=0
[align] low_traffic/issue_7/cols_are_services: mode=correlation, matched_rows=64


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic1/issue_0/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic1/issue_0/cols_are_services: mode=correlation, matched_rows=60


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic1/issue_1/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic1/issue_1/cols_are_services: mode=correlation, matched_rows=82


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic1/issue_2/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic1/issue_2/cols_are_services: mode=correlation, matched_rows=57


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic1/issue_3/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic1/issue_3/cols_are_services: mode=correlation, matched_rows=65


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic1/issue_4/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic1/issue_4/cols_are_services: mode=correlation, matched_rows=31


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic1/issue_5/rows_are_services: mode=correlation, matched_rows=5
[align] temporal_traffic1/issue_5/cols_are_services: mode=correlation, matched_rows=75


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic1/issue_0/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic1/issue_0/cols_are_services: mode=correlation, matched_rows=56


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic1/issue_1/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic1/issue_1/cols_are_services: mode=correlation, matched_rows=81


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic2/issue_0/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic2/issue_0/cols_are_services: mode=correlation, matched_rows=66


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic2/issue_1/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic2/issue_1/cols_are_services: mode=correlation, matched_rows=50


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic2/issue_2/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic2/issue_2/cols_are_services: mode=correlation, matched_rows=41


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic2/issue_3/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic2/issue_3/cols_are_services: mode=correlation, matched_rows=53


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic2/issue_4/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic2/issue_4/cols_are_services: mode=correlation, matched_rows=44


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic2/issue_5/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic2/issue_5/cols_are_services: mode=correlation, matched_rows=70


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic2/issue_0/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic2/issue_0/cols_are_services: mode=correlation, matched_rows=57


  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  out = out.fillna(method="ffill").fillna(method="bfill").fillna(0.0)


[align] temporal_traffic2/issue_1/rows_are_services: mode=exact, matched_rows=0
[align] temporal_traffic2/issue_1/cols_are_services: mode=correlation, matched_rows=69

Episodes built: 60 (skipped 8)
Train episodes: 42 | Test episodes: 18 | P=30

Training …
epoch 01 | train 0.3348 | test 0.2334
epoch 05 | train 0.1946 | test 0.2401
epoch 10 | train 0.1845 | test 0.2066
epoch 15 | train 0.1597 | test 0.2040
epoch 20 | train 0.1738 | test 0.2311
epoch 25 | train 0.1731 | test 0.1767

=== Test summary (mean across incidents) ===
p         30.000000
s_true     1.888889
hits@S     0.777778
hits@5     1.000000

Saved: /content/petshop_setxf_summary.csv and /content/petshop_setxf_details.csv


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# ======================= Add synthetic linear-SEM episodes and retrain =======================
import numpy as np
import torch
import pandas as pd

# ---------- 1) Linear SEM utilities ----------
def sample_strict_lower_triangular_B(p: int, exp_degree: float = 3.0, weight_std: float = 0.3, g=None):
    """
    Sample a random strictly lower-triangular weight matrix B of shape (p, p).

    Each entry below the diagonal is kept with probability
    q = clip(exp_degree / max(1, p-1), 0, 0.999) and, when kept, holds a
    N(0, weight_std^2) weight. Everything on or above the diagonal is zero,
    so (I - B) is unit lower-triangular and therefore invertible.

    Args:
        p: number of nodes.
        exp_degree: target expected number of edges per node used to derive q.
        weight_std: standard deviation of the edge weights.
        g: optional torch.Generator; a fresh torch.Generator() is used if None.
           NOTE(review): an unseeded torch.Generator appears to start from
           PyTorch's fixed default seed, so repeated calls without `g` would
           return the same matrix; confirm.

    Returns:
        (p, p) float tensor.
    """
    gen = torch.Generator() if g is None else g
    edge_prob = exp_degree / max(1, p - 1)
    edge_prob = min(0.999, max(0.0, edge_prob))
    # Draw order is part of the reproducibility contract: uniform coins first, weights second.
    coins = torch.rand(p, p, generator=gen)
    weights = torch.randn(p, p, generator=gen) * weight_std
    below_diagonal = torch.tril(torch.ones(p, p), diagonal=-1) == 1
    keep = below_diagonal & (coins < edge_prob)
    return torch.where(keep, weights, torch.zeros(p, p))  # (p,p)

def sample_D_diag(p: int, low: float = 0.5, high: float = 1.5, g=None):
    """
    Draw p noise variances uniformly from [low, high).

    Args:
        p: number of nodes.
        low, high: bounds of the uniform range.
        g: optional torch.Generator; a fresh torch.Generator() is used if None.

    Returns:
        (p,) tensor of variances (not standard deviations).
    """
    gen = g if g is not None else torch.Generator()
    unit = torch.rand(p, generator=gen)
    span = high - low
    return low + span * unit  # variances (p,)

def simulate_sem(B: torch.Tensor, sig2: torch.Tensor, n: int, delta: torch.Tensor | None, g=None):
    """
    ε ~ N(0, diag(sig2)); interventional ε' = ε + δ.
    X = (I - B)^(-1) ε.
    Returns X as (n,p) float32 tensor.
    """
    if g is None:
        g = torch.Generator()
    p = B.shape[0]
    I = torch.eye(p)
    A = torch.linalg.inv(I - B)        # (p,p)
    std = torch.sqrt(sig2).view(1, p)  # (1,p)
    Z = torch.randn(n, p, generator=g)
    E = Z * std                        # (n,p)
    if delta is not None:
        E = E + delta.view(1, p)
    X = E @ A.T                        # (n,p)
    return X.to(torch.float32)

def build_sem_episode(p=64, n_obs=512, n_int=512, s=3,
                      exp_degree=3.0, delta_low=0.3, delta_high=1.0,
                      sig2_low=0.5, sig2_high=1.5, seed=None):
    """
    Build one synthetic linear-SEM episode.

    A random DAG and noise variances are sampled, `s` nodes are picked as
    root causes and receive a signed noise-mean shift with magnitude in
    [delta_low, delta_high). Observational and interventional samples are
    then summarised per node.

    Returns a dict with:
      - "X": (p,5) float32 tokens [mu_obs, mu_int, dmu, logvar_obs, logvar_int]
      - "y": (p,) float32 labels, 1 on the shifted nodes, else 0
      - "rows": node names "node_0" .. "node_{p-1}"
      - "src": "synth/sem_seed{seed}"
      - "align_mode": "synthetic"
    """
    gen = torch.Generator()
    if seed is not None:
        gen.manual_seed(int(seed))

    # RNG draw order (graph, variances, support, magnitudes, signs, obs, int) fixes the output per seed.
    B = sample_strict_lower_triangular_B(p, exp_degree=exp_degree, weight_std=0.3, g=gen)
    sig2 = sample_D_diag(p, low=sig2_low, high=sig2_high, g=gen)

    support = torch.randperm(p, generator=gen)[:s]
    magnitudes = delta_low + (delta_high - delta_low) * torch.rand(s, generator=gen)
    flip = torch.rand(s, generator=gen) < 0.5
    signs = torch.where(flip, -torch.ones(s), torch.ones(s))
    delta = torch.zeros(p)
    delta[support] = signs * magnitudes

    X_obs = simulate_sem(B, sig2, n_obs, delta=None,  g=gen)
    X_int = simulate_sem(B, sig2, n_int, delta=delta, g=gen)

    def _mean_and_logvar(sample):
        # Population variance plus a small floor so the log is finite.
        return sample.mean(dim=0), torch.log(sample.var(dim=0, unbiased=False) + 1e-6)

    mu_obs, logv_obs = _mean_and_logvar(X_obs)
    mu_int, logv_int = _mean_and_logvar(X_int)
    token_columns = [mu_obs, mu_int, mu_int - mu_obs, logv_obs, logv_int]
    tokens = torch.stack(token_columns, dim=1).cpu().numpy().astype(np.float32)

    labels = torch.zeros(p, dtype=torch.float32)
    labels[support] = 1.0
    return {
        "X": tokens,                 # (p,5)
        "y": labels.numpy(),         # (p,)
        "rows": [f"node_{i}" for i in range(p)],
        "src": f"synth/sem_seed{seed}",
        "align_mode": "synthetic",
    }

def make_sem_episodes(N=100, **kwargs):
    """
    Build N synthetic episodes with consecutive seeds 10_000, 10_001, ...

    Extra keyword arguments are forwarded unchanged to build_sem_episode.
    Returns a list of episode dicts.
    """
    return [build_sem_episode(seed=10_000 + k, **kwargs) for k in range(N)]

# ---------- 2) Add synthetic episodes to your existing 'episodes' list ----------
# Assumes you already built 'episodes' from PetShop earlier in the notebook.
# If not, create an empty list and just use the synthetic part (works standalone).
try:
    _ = episodes          # probe only: raises NameError when no earlier cell defined it
except NameError:
    episodes = []

N_SYNTH   = 100      # how many synthetic incidents to add
P_SYNTH   = 64       # #services in synthetic incidents
S_SUPPORT = 3        # #root causes per synthetic incident
N_OBS     = 512      # observational samples drawn per synthetic incident
N_INT     = 512      # interventional samples drawn per synthetic incident

# Seeds are assigned inside make_sem_episodes (10_000 + k), so this call is deterministic.
synth_eps = make_sem_episodes(
    N=N_SYNTH, p=P_SYNTH, n_obs=N_OBS, n_int=N_INT, s=S_SUPPORT,
    exp_degree=3.0, delta_low=0.3, delta_high=1.0,
    sig2_low=0.5, sig2_high=1.5
)

# New list object: shuffling episodes_all later does not reorder 'episodes' itself,
# but both lists still reference the same episode dicts.
episodes_all = episodes + synth_eps
print(f"Real episodes: {len(episodes)} | Synthetic episodes added: {len(synth_eps)} | Total: {len(episodes_all)}")

# ---------- 3) Rebuild loaders ----------
# NOTE: the code below redefines EpisodeDS, collate, MAB/SAB/SetDetector, run_epoch and evaluate,
# replacing any same-named definitions left over from earlier cells.

from torch.utils.data import Dataset, DataLoader

class EpisodeDS(Dataset):
    """
    Dataset of episodes truncated to a common number of rows.

    Every episode is cut to the first `p` rows, where `p` is the smallest
    row count found in `eps`, so that a batch can be stacked without padding.

    The input episode dicts are NOT modified: each stored episode is a shallow
    copy whose "X", "y" and "rows" entries are truncated views/slices. This
    keeps the caller's list reusable (e.g. re-running the cell or building a
    second split from the same episodes).

    NOTE(review): truncation keeps only the first `p` rows, so labels located
    at row index >= p are dropped from the episode.

    Attributes:
        E: list of truncated episode dicts.
        p: common number of rows (0 when `eps` is empty).
    """
    def __init__(self, eps):
        eps = list(eps)
        # default=0 lets an empty split produce an empty dataset instead of raising.
        self.p = min((e["X"].shape[0] for e in eps), default=0)
        self.E = []
        for e in eps:
            item = dict(e)  # shallow copy so the caller's dict keeps its full arrays
            item["X"] = e["X"][:self.p]
            item["y"] = e["y"][:self.p]
            item["rows"] = e["rows"][:self.p]
            self.E.append(item)

    def __len__(self):
        return len(self.E)

    def __getitem__(self, i):
        """Return (X tensor, y tensor, rows, src, align_mode) for episode i."""
        e = self.E[i]
        return torch.from_numpy(e["X"]), torch.from_numpy(e["y"]), e["rows"], e["src"], e["align_mode"]

def collate(batch):
    """
    Stack equally sized episodes into batch tensors.

    Each batch item is indexed as (X, y, rows, src, align_mode).
    Returns (X, y, rows, srcs, modes): X and y are stacked float tensors,
    the remaining three are plain Python lists in batch order.
    """
    token_list, label_list, rows, srcs, modes = [], [], [], [], []
    for item in batch:
        token_list.append(item[0])
        label_list.append(item[1])
        rows.append(item[2])
        srcs.append(item[3])
        modes.append(item[4])
    X = torch.stack(token_list).float()
    y = torch.stack(label_list).float()
    return X, y, rows, srcs, modes

# NOTE: `random` is not imported in this cell; it must already be in the kernel namespace
# from an earlier cell. The shuffle is in place and uses whatever state the global RNG has.
random.shuffle(episodes_all)
# 50/50 split over the shuffled mix of real and synthetic episodes.
split = int(0.5 * len(episodes_all))
train_eps, test_eps = episodes_all[:split], episodes_all[split:]
# Each EpisodeDS truncates to the minimum row count of its own split.
train_loader = DataLoader(EpisodeDS(train_eps), batch_size=8, shuffle=True,  collate_fn=collate)
test_loader  = DataLoader(EpisodeDS(test_eps),  batch_size=8, shuffle=False, collate_fn=collate)
print(f"Train episodes: {len(train_loader.dataset)} | Test episodes: {len(test_loader.dataset)} | P={train_loader.dataset.p}")

# ---------- 4) Model (same fixed SetDetector as you used) ----------
class MAB(nn.Module):
    """
    Multihead attention block.

    forward(Q, K, V) = LN2(X + FF(X)) with X = LN1(Q + Attention(Q, K, V)).
    Inputs are batch-first; the output has the shape of Q.
    """
    def __init__(self, d, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True, dropout=dropout)
        self.ln1 = nn.LayerNorm(d)
        feed_forward_layers = [
            nn.Linear(d, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d),   # project back from d_ff to d so the residual add is valid
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, Q, K, V):
        attended = self.attn(Q, K, V, need_weights=False)[0]
        hidden = self.ln1(Q + attended)
        return self.ln2(hidden + self.ff(hidden))

class SAB(nn.Module):
    """Self-attention block: a MAB applied with the same set as query, key and value."""
    def __init__(self, d, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.mab = MAB(d, n_heads, d_ff, dropout)

    def forward(self, X):
        tokens = X
        return self.mab(tokens, tokens, tokens)

class SetDetector(nn.Module):
    """
    Per-token binary detector over a set of tokens.

    Tokens of width d_in are embedded to width d, passed through `depth`
    self-attention blocks, and mapped to one probability per token.
    forward(X) takes (B, P, d_in) and returns probabilities of shape (B, P).
    """
    def __init__(self, d_in=5, d=128, depth=3, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.enc = nn.Sequential(
            nn.LayerNorm(d_in),
            nn.Linear(d_in, d),
            nn.GELU(),
            nn.Linear(d, d),
        )
        attention_stack = [SAB(d, n_heads, d_ff, dropout) for _ in range(depth)]
        self.blocks = nn.ModuleList(attention_stack)
        self.head = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, 1))

    def forward(self, X):
        hidden = self.enc(X)
        for block in self.blocks:
            hidden = block(hidden)
        logits = self.head(hidden).squeeze(-1)
        return torch.sigmoid(logits)  # (B,P)

# Use the GPU when available; model and batches are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fresh model with default hyper-parameters; weights are re-initialised every time this cell runs.
model = SetDetector().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)

def run_epoch(loader, train=True):
    """
    Run one pass over `loader` and return the mean per-batch BCE loss.

    Uses the notebook-level `model`, `opt` and `device`.

    Args:
        loader: yields (X, Y, rows, srcs, modes) batches.
        train: when True the model is put in train mode and the optimizer is
            stepped on every batch; when False the model is put in eval mode
            and gradient tracking is disabled, so no autograd graph is built.

    Returns:
        Average of the batch losses (0.0 if the loader is empty).
    """
    model.train(train)
    total, n = 0.0, 0
    # Disable autograd for evaluation passes to avoid wasted memory/compute.
    with torch.set_grad_enabled(train):
        for X, Y, _, _, _ in loader:
            X, Y = X.to(device), Y.to(device)
            P = model(X)
            # Clamp keeps log() finite when the sigmoid saturates.
            loss = F.binary_cross_entropy(P.clamp(1e-6, 1 - 1e-6), Y)
            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()
            total += loss.item()
            n += 1
    return total / max(1, n)

print("\nTraining on (real + synthetic) …")
# 20 epochs; the test loss is computed every epoch but only printed on epoch 1 and every 5th epoch.
for ep in range(1, 21):
    tr = run_epoch(train_loader, True)
    te = run_epoch(test_loader, False)
    if ep % 5 == 0 or ep == 1:
        print(f"epoch {ep:02d} | train {tr:.4f} | test {te:.4f}")

# ---------- 5) Evaluate (Hits@S / Hits@5) ----------
@torch.no_grad()
def evaluate(loader, s_default=3):
    """
    Score every episode in `loader` with the notebook-level `model`.

    For each episode the nodes are ranked by predicted probability and the
    top-S (S = number of true root causes, or `s_default` when the episode has
    none) and top-5 sets are compared with the labelled set.

    The ranking uses a full argsort rather than np.argpartition: argpartition
    requires kth < len(array) and therefore raised whenever S or 5 was equal
    to the number of nodes in the episode.

    Args:
        loader: yields (X, Y, row_names, srcs, modes) batches.
        s_default: S to use for episodes without any positive label.

    Returns:
        (summary_df, details_df): one row per episode. "hits@S" / "hits@5" are
        counts of true root causes found in the respective top set.
    """
    model.eval()
    rows, details = [], []
    for X, Y, row_names, srcs, modes in loader:
        X, Y = X.to(device), Y.to(device)
        P = model(X)
        B, Pdim = P.shape
        for b in range(B):
            pr = P[b].detach().cpu().numpy()
            y = Y[b].cpu().numpy()
            true = set(np.where(y > 0.5)[0].tolist())
            s = len(true) if true else s_default
            kS = min(s, Pdim)
            k5 = min(5, Pdim)
            # One descending ranking serves both cut-offs and is valid for k == Pdim.
            order = np.argsort(-pr, kind="stable")
            idx_topS = set(order[:kS].tolist())
            idx_top5 = set(order[:k5].tolist())
            rows.append({
                "src": srcs[b],
                "align_mode": modes[b],
                "p": int(Pdim),
                "s_true": int(len(true)),
                "hits@S": int(len(idx_topS & true)),
                "hits@5": int(len(idx_top5 & true)),
            })
            details.append({
                "src": srcs[b],
                "align_mode": modes[b],
                "services": [str(r) for r in row_names[b]],
                "true_idx": sorted(list(true)),
                "pred_topS_idx": sorted(list(idx_topS)),
                "pred_top5_idx": sorted(list(idx_top5)),
                "probs": pr.tolist(),
            })
    return pd.DataFrame(rows), pd.DataFrame(details)

summary_df, details_df = evaluate(test_loader)
print("\n=== Test summary (means) ===")
# Means over test episodes; hits@S and hits@5 are counts per episode, not rates.
print(summary_df.mean(numeric_only=True).to_string())

# Save artifacts
summary_df.to_csv("/content/mixed_real_plus_sem_summary.csv", index=False)
details_df.to_csv("/content/mixed_real_plus_sem_details.csv", index=False)
print("\nSaved: /content/mixed_real_plus_sem_summary.csv, /content/mixed_real_plus_sem_details.csv")

# Optional Colab downloads
# Best effort: outside Colab the import fails and the download step is silently skipped.
try:
    from google.colab import files
    files.download("/content/mixed_real_plus_sem_summary.csv")
    files.download("/content/mixed_real_plus_sem_details.csv")
except Exception:
    pass


Real episodes: 60 | Synthetic episodes added: 100 | Total: 160
Train episodes: 112 | Test episodes: 48 | P=30

Training on (real + synthetic) …
epoch 01 | train 0.3554 | test 0.2127
epoch 05 | train 0.1570 | test 0.1667
epoch 10 | train 0.1108 | test 0.1318
epoch 15 | train 0.0953 | test 0.1160
epoch 20 | train 0.0885 | test 0.1044

=== Test summary (means) ===
p         30.000000
s_true     1.583333
hits@S     0.958333
hits@5     1.125000

Saved: /content/mixed_real_plus_sem_summary.csv, /content/mixed_real_plus_sem_details.csv


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# ===================== PetShop (50/50) + 10 Prompts (real 2..20 + synth → 500) =====================
# Train Set Transformer (BCE only) and evaluate on ORIGINAL test set (no synthetic in test).
# Token per feature i: [mu_obs_i, mu_int_i, dmu_i, logvar_obs_i, logvar_int_i]
# Label y_i: 1 if feature/service i is root cause (if GT unavailable -> heuristic top-|dmu|).

!pip -q install ruamel.yaml

import re, json, random
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from ruamel.yaml import YAML

# -------------------- Repro --------------------
# Seed Python, NumPy and PyTorch global RNGs so shuffles and weight init repeat across runs.
SEED = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# -------------------- User settings --------------------
# 50 vs 50 split on ORIGINAL episodes
PROMPTS_M       = list(range(2, 21, 2))     # 10 prompts: 2,4,...,20 real episodes per prompt
PROMPT_SIZE     = 500                       # top up each prompt to 500 episodes using synthetic
EPOCHS          = 500                       # training epochs over the pooled prompt episodes
BATCH_SIZE      = 8                         # episodes per batch (train and test loaders)
LR              = 2e-4                      # AdamW learning rate

# Model
D_TOKEN = 5      # token width: [mu_obs, mu_int, dmu, logvar_obs, logvar_int]
D_MODEL = 128    # embedding width inside the set transformer
N_HEADS = 8      # attention heads per block
D_FF    = 256    # hidden width of each block's feed-forward layer
DEPTH   = 3      # number of self-attention blocks
DROPOUT = 0.0

# Evaluation
TOPK_5  = 5
HEURISTIC_S = 3   # if GT support not found

# Synthetic SEM episodes (only for TRAIN prompts)
P_SYNTH = 64     # nodes per synthetic episode
S_SYNTH = 3      # root causes per synthetic episode
N_OBS   = 512    # observational samples per synthetic episode
N_INT   = 512    # interventional samples per synthetic episode

# Alignment parameters for PetShop
MIN_FEATURES_KEEP = 30   # minimum number of shared feature columns required to keep an episode

# -------------------- Repo clone --------------------
# Clone only when the checkout is missing, so re-running the cell does not hit the network again.
ROOT = Path("/content")
REPO = ROOT / "petshop-root-cause-analysis"
if not REPO.exists():
    !git -C /content clone -q https://github.com/amazon-science/petshop-root-cause-analysis.git
DATA_ROOT = REPO / "dataset"
assert DATA_ROOT.exists(), "Repo cloned but dataset/ not found."

# -------------------- Helpers: parsing --------------------
def canon_name(s: str) -> str:
    """
    Reduce a service/feature name to a canonical lowercase form for fuzzy matching.

    Steps, applied in order to the lowercased string: blank out ARN-like
    tokens, blank out "aws::..." type tokens, turn every non-alphanumeric run
    into a space, drop stand-alone environment words (prod, stage, stg, dev,
    test, qa), then collapse whitespace and strip. May return "".
    """
    text = str(s).lower()
    # Order matters: ARN / aws:: tokens must go before punctuation is flattened.
    substitutions = (
        r'arn:aws:[^ ]+',
        r'aws::[a-z0-9:_\-]+',
        r'[^a-z0-9]+',
        r'\b(prod|stage|stg|dev|test|qa)\b',
        r'\s+',
    )
    for pattern in substitutions:
        text = re.sub(pattern, ' ', text)
    return text.strip()

def load_metrics_as_time_by_feature(csv_path: Path) -> pd.DataFrame:
    """
    Load a metrics CSV as a (T, F) float32 frame of numeric feature columns.

    Processing:
      1. Columns named like identifiers (timestamp, time, date, datetime, ts,
         microservice, service, index; case-insensitive) are excluded.
      2. Every column is coerced to numeric; unparsable cells become NaN.
      3. A column is kept when it has at least max(10, 10% of rows) non-NaN
         values. If no column qualifies, every column with at least one
         numeric value is kept instead (this fallback ignores the exclusion
         list from step 1).
      4. Gaps are forward-filled, then back-filled, then set to 0.0.

    Uses DataFrame.ffill()/bfill() instead of the deprecated
    fillna(method=...) form, and str() on column labels so non-string labels
    do not break the identifier check.

    Args:
        csv_path: path to the CSV file.

    Returns:
        DataFrame of dtype float32 with one row per CSV row.
    """
    raw = pd.read_csv(csv_path, engine="python")

    # Identifier-like columns that should not be treated as features.
    id_names = {"timestamp", "time", "date", "datetime", "ts", "microservice", "service", "index"}
    drop_cols = {c for c in raw.columns if str(c).lower() in id_names}

    # Convert to numeric
    num = raw.copy()
    for c in num.columns:
        num[c] = pd.to_numeric(num[c], errors="coerce")

    # Choose feature cols: enough numeric support
    min_support = max(10, int(0.1 * len(num)))
    feat_cols = [
        c for c in num.columns
        if c not in drop_cols and num[c].notna().sum() >= min_support
    ]
    if not feat_cols:
        # fallback: take all numeric cols
        feat_cols = [c for c in num.columns if num[c].notna().sum() > 0]

    df = num[feat_cols].ffill().bfill().fillna(0.0)
    return df.astype(np.float32)

def align_feature_columns(df_obs: pd.DataFrame, df_int: pd.DataFrame):
    """
    Select matching feature columns from the observational and interventional frames.

    Strategy, in order:
      1. "exact": columns with identical names, if at least MIN_FEATURES_KEEP exist.
      2. "canonical": columns whose canon_name() values coincide (first column
         wins per canonical key), if at least MIN_FEATURES_KEEP pairs exist.
         The observational column names are used as the feature names.
      3. "exact_small": whatever identical-name columns exist (fewer than
         MIN_FEATURES_KEEP); the caller decides whether that is enough.
      4. "none": no common column at all.

    Returns:
        (obs_aligned, int_aligned, feature_names, align_mode); the first three
        are None when align_mode is "none".
    """
    # Build the lookup set once; the intersection is reused by the final fallback.
    int_cols = set(df_int.columns)
    inter = [c for c in df_obs.columns if c in int_cols]
    if len(inter) >= MIN_FEATURES_KEEP:
        return df_obs[inter], df_int[inter], inter, "exact"

    def _first_column_per_canon(columns):
        # Map canonical name to the first column that produces it; empty keys are ignored.
        mapping = {}
        for c in columns:
            k = canon_name(c)
            if k:
                mapping.setdefault(k, c)
        return mapping

    # canonical match
    map_obs = _first_column_per_canon(df_obs.columns)
    map_int = _first_column_per_canon(df_int.columns)
    keys = [k for k in map_obs if k in map_int]
    if len(keys) >= MIN_FEATURES_KEEP:
        cols_obs = [map_obs[k] for k in keys]
        cols_int = [map_int[k] for k in keys]
        # Use obs names as canonical feature names
        return df_obs[cols_obs], df_int[cols_int], cols_obs, "canonical"

    # final fallback: the (small) exact intersection, if any
    if not inter:
        return None, None, None, "none"
    return df_obs[inter], df_int[inter], inter, "exact_small"

def try_load_side_meta(issue_dir: Path):
    """
    Best-effort load of metadata files located directly in `issue_dir`.

    Files matching *.json are parsed with json, *.yml / *.yaml with
    ruamel's safe YAML loader. Files that cannot be read or parsed are
    skipped without any message.

    Returns:
        List of parsed objects (JSON files first, then .yml, then .yaml).
    """
    loaded = []
    for pattern in ("*.json", "*.yml", "*.yaml"):
        for meta_file in issue_dir.glob(pattern):
            try:
                text = meta_file.read_text()
                if meta_file.suffix == ".json":
                    parsed = json.loads(text)
                else:
                    parsed = YAML(typ="safe").load(text)
            except Exception:
                # Deliberate best effort: unreadable/invalid side files are ignored.
                continue
            loaded.append(parsed)
    return loaded

def extract_root_services(meta_objs):
    """
    Collect candidate root-cause names from parsed metadata objects.

    Every dict/list structure is walked. For each dict key, the dotted,
    lowercased key path is checked for one of the markers "root", "cause",
    "service", "culprit", "node"; on a match, a string value, or the string
    items of a list value, are collected. Nested keys inherit the parent
    path, so anything below a matching key also matches.

    Returns:
        Set of stripped, non-empty strings.
    """
    markers = ("root", "cause", "service", "culprit", "node")
    found = set()
    # Explicit work list of (object, key path) pairs instead of recursion.
    pending = [(obj, "") for obj in meta_objs]
    while pending:
        node, path = pending.pop()
        if isinstance(node, dict):
            for key, value in node.items():
                key_path = (path + "." + str(key)).lower()
                if any(marker in key_path for marker in markers):
                    if isinstance(value, str):
                        found.add(value)
                    if isinstance(value, list):
                        found.update(item for item in value if isinstance(item, str))
                pending.append((value, key_path))
        elif isinstance(node, list):
            # List items keep the path of the list itself.
            pending.extend((item, path) for item in node)
    return {str(name).strip() for name in found if name and str(name).strip()}

def build_episode_from_pair(df_obs: pd.DataFrame, df_int: pd.DataFrame, feature_names, issue_dir: Path, src: str, align_mode: str):
    """
    Build one episode dict from aligned observational / interventional frames.

    Tokens, one row per feature: [mu_obs, mu_int, dmu, logvar_obs, logvar_int].
    Variances are population variances (ddof=0), matching the synthetic
    episodes' `unbiased=False` and staying finite for single-row frames
    (the pandas default ddof=1 yields NaN there).

    Labels: a feature is marked 1 when its canonical name contains, or is
    contained in, the canonical name of a ground-truth name found in the side
    metadata of `issue_dir`. Features whose canonical name is empty are never
    matched; an empty string is a substring of every name and would otherwise
    be labelled as a root cause for any ground truth. If nothing matches, the
    HEURISTIC_S features with the largest |dmu| are labelled instead.

    Args:
        df_obs, df_int: frames with identical column order (one column per feature).
        feature_names: names for the columns, same order.
        issue_dir: directory searched for side metadata.
        src: identifier stored in the episode.
        align_mode: alignment mode stored in the episode.

    Returns:
        dict with "X" (P,5) float32, "y" (P,) float32, "features", "src",
        "align_mode" and "y_source" ("gt_or_partial" or "heuristic").
    """
    mu_obs = df_obs.mean(axis=0).values.astype(np.float32)
    mu_int = df_int.mean(axis=0).values.astype(np.float32)
    dmu    = (mu_int - mu_obs).astype(np.float32)

    eps = 1e-6  # variance floor so log() stays finite for constant columns
    logv_obs = np.log(df_obs.var(axis=0, ddof=0).values.astype(np.float32) + eps)
    logv_int = np.log(df_int.var(axis=0, ddof=0).values.astype(np.float32) + eps)

    X = np.stack([mu_obs, mu_int, dmu, logv_obs, logv_int], axis=1).astype(np.float32)  # (P,5)

    # labels
    y = np.zeros(len(feature_names), dtype=np.float32)
    gt = extract_root_services(try_load_side_meta(issue_dir))
    if gt:
        gt_can = {canon_name(g) for g in gt}
        gt_can.discard("")
        for i, f in enumerate(feature_names):
            fc = canon_name(f)
            if not fc:
                # An empty canonical name would match every ground-truth name.
                continue
            if any((gc in fc) or (fc in gc) for gc in gt_can):
                y[i] = 1.0

    # fallback heuristic if no GT
    if y.sum() == 0:
        order = np.argsort(-np.abs(dmu))
        y[order[:min(HEURISTIC_S, len(y))]] = 1.0
        y_source = "heuristic"
    else:
        y_source = "gt_or_partial"

    return {
        "X": X,
        "y": y,
        "features": list(feature_names),
        "src": src,
        "align_mode": align_mode,
        "y_source": y_source
    }

# -------------------- Build ORIGINAL PetShop episodes --------------------
# One episode per issue metrics.csv, paired with the scenario's noissue/metrics.csv baseline.
episodes_real = []
skipped = 0   # counts unreadable baselines and unusable issue files (reasons are not recorded)

for scen_dir in sorted([p for p in DATA_ROOT.iterdir() if p.is_dir()]):
    noissue_csv = scen_dir / "noissue" / "metrics.csv"
    if not noissue_csv.exists():
        continue

    try:
        obs_df = load_metrics_as_time_by_feature(noissue_csv)
    except Exception as e:
        # Baseline unreadable: the whole scenario is skipped and counted once.
        skipped += 1
        continue

    # Every metrics.csv below the scenario that is not under a "noissue" directory is an issue episode.
    issue_csvs = [p for p in scen_dir.rglob("metrics.csv") if "noissue" not in p.parts]
    for ic in sorted(issue_csvs):
        try:
            int_df = load_metrics_as_time_by_feature(ic)
            A,B,feat_names,mode = align_feature_columns(obs_df, int_df)
            if mode == "none" or A is None or B is None or feat_names is None:
                skipped += 1
                continue
            # "exact_small" alignments can return fewer than MIN_FEATURES_KEEP columns; drop those.
            if len(feat_names) < MIN_FEATURES_KEEP:
                skipped += 1
                continue
            ep = build_episode_from_pair(
                A, B, feat_names,
                issue_dir=ic.parent,
                src=str(ic),
                align_mode=mode
            )
            episodes_real.append(ep)
        except Exception:
            # Any failure while loading/aligning/building is swallowed and only counted.
            skipped += 1

print(f"Original episodes built: {len(episodes_real)} (skipped {skipped})")
assert len(episodes_real) > 0, "No usable PetShop episodes were constructed."

# -------------------- 50/50 split (ORIGINAL only) --------------------
# In-place shuffle driven by the global `random` state seeded above.
random.shuffle(episodes_real)
mid = len(episodes_real)//2   # with an odd count the extra episode goes to the test half
train_real = episodes_real[:mid]
test_real  = episodes_real[mid:]
print(f"Train(real): {len(train_real)} | Test(real): {len(test_real)}")

# -------------------- Synthetic linear-SEM episode generator --------------------
def sample_strict_lower_triangular_B(p: int, exp_degree: float = 3.0, weight_std: float = 0.3, g=None):
    """
    Sample a (p, p) strictly lower-triangular weight matrix.

    Entries below the diagonal are non-zero with probability
    clip(exp_degree / max(1, p-1), 0, 0.999) and then carry a
    N(0, weight_std^2) weight; all other entries are zero.
    `g` is an optional torch.Generator (a fresh one is created when None).
    """
    gen = torch.Generator() if g is None else g
    edge_prob = exp_degree / max(1, p - 1)
    edge_prob = min(0.999, max(0.0, edge_prob))
    # Keep the draw order (uniform coins, then Gaussian weights) so seeded runs are unchanged.
    coins = torch.rand(p, p, generator=gen)
    weights = torch.randn(p, p, generator=gen) * weight_std
    below_diagonal = torch.tril(torch.ones(p, p), diagonal=-1) == 1
    keep = below_diagonal & (coins < edge_prob)
    return torch.where(keep, weights, torch.zeros(p, p))

def sample_D_diag(p: int, low: float = 0.5, high: float = 1.5, g=None):
    """Return a (p,) tensor of noise variances drawn uniformly from [low, high)."""
    gen = g if g is not None else torch.Generator()
    unit = torch.rand(p, generator=gen)
    span = high - low
    return low + span * unit

def simulate_sem(B: torch.Tensor, sig2: torch.Tensor, n: int, delta: torch.Tensor | None, g=None):
    if g is None: g = torch.Generator()
    p = B.shape[0]
    I = torch.eye(p)
    A = torch.linalg.inv(I - B)
    std = torch.sqrt(sig2).view(1, p)
    Z = torch.randn(n, p, generator=g)
    E = Z * std
    if delta is not None:
        E = E + delta.view(1, p)  # soft mean shift
    X = E @ A.T
    return X.to(torch.float32)

def build_sem_episode(seed: int, p=P_SYNTH, n_obs=N_OBS, n_int=N_INT, s=S_SYNTH):
    """
    Build one synthetic linear-SEM episode, fully determined by `seed`.

    `s` nodes are chosen as root causes and receive a positive noise-mean
    shift with magnitude in [0.3, 1.0) (no sign flips). The returned dict has
    the same keys as the real episodes:
      "X" (p,5) float32 tokens [mu_obs, mu_int, dmu, logvar_obs, logvar_int],
      "y" (p,) float32 labels, "features", "src", "align_mode", "y_source".
    """
    gen = torch.Generator().manual_seed(int(seed))
    # RNG draw order (graph, variances, support, magnitudes, obs, int) fixes the output per seed.
    B = sample_strict_lower_triangular_B(p, exp_degree=3.0, weight_std=0.3, g=gen)
    sig2 = sample_D_diag(p, low=0.5, high=1.5, g=gen)

    support = torch.randperm(p, generator=gen)[:s]
    shift_sizes = 0.3 + (1.0 - 0.3) * torch.rand(s, generator=gen)
    delta = torch.zeros(p)
    # NO random signs: always add positive delta
    delta[support] = shift_sizes

    X_obs = simulate_sem(B, sig2, n_obs, delta=None, g=gen)
    X_int = simulate_sem(B, sig2, n_int, delta=delta, g=gen)

    summaries = []
    for sample in (X_obs, X_int):
        # Population variance plus a small floor so the log is finite.
        summaries.append((sample.mean(dim=0), torch.log(sample.var(dim=0, unbiased=False) + 1e-6)))
    (mu_obs, logv_obs), (mu_int, logv_int) = summaries

    token_columns = [mu_obs, mu_int, mu_int - mu_obs, logv_obs, logv_int]
    tokens = torch.stack(token_columns, dim=1).cpu().numpy().astype(np.float32)
    labels = torch.zeros(p)
    labels[support] = 1.0

    return {
        "X": tokens,
        "y": labels.numpy().astype(np.float32),
        "features": [f"node_{i}" for i in range(p)],
        "src": f"synth/sem_{seed}",
        "align_mode": "synthetic",
        "y_source": "synthetic_gt"
    }

# -------------------- Build 10 TRAIN prompts (each prompt has 500 episodes) --------------------
rng = np.random.default_rng(SEED)
train_prompts = []
for pi, m_real in enumerate(PROMPTS_M, start=1):
    # Sample m_real real episodes from train_real. rng.integers draws indices independently,
    # so this is ALWAYS sampling with replacement (duplicates are possible even when
    # m_real <= len(train_real)). Requires train_real to be non-empty.
    ridx = rng.integers(0, len(train_real), size=m_real)
    real_subset = [train_real[i] for i in ridx]

    # Top the prompt up to PROMPT_SIZE with synthetic episodes.
    n_synth = max(0, PROMPT_SIZE - m_real)
    # Seed layout 10_000 + 1000*pi + j gives each prompt its own seed range (as long as n_synth < 1000).
    synth_subset = [build_sem_episode(seed=10_000 + 1000*pi + j) for j in range(n_synth)]

    prompt_eps = real_subset + synth_subset
    rng.shuffle(prompt_eps)
    train_prompts.append(prompt_eps)

    print(f"Prompt {pi:02d}: real={m_real}, synth={n_synth}, total={len(prompt_eps)}")

# pool all prompt episodes for training (10 * 500 = 5000)
# The same real episode dict can appear several times in the pool (within and across prompts).
train_pool = [e for prompt in train_prompts for e in prompt]
print(f"Total training episodes pooled from 10 prompts: {len(train_pool)}")
print(f"Total test episodes (original only): {len(test_real)}")

# -------------------- Dataset + padding collate --------------------
class EpisodeDS(Dataset):
    """
    Thin Dataset wrapper around a list of episode dicts (no truncation, no copy).

    Item i is the tuple
    (X float tensor (P,5), y float tensor (P,), features, src, align_mode, y_source).
    """
    def __init__(self, eps):
        self.E = eps

    def __len__(self):
        return len(self.E)

    def __getitem__(self, i):
        episode = self.E[i]
        tokens = torch.from_numpy(episode["X"]).float()     # (P,5)
        labels = torch.from_numpy(episode["y"]).float()     # (P,)
        meta = tuple(episode[key] for key in ("features", "src", "align_mode", "y_source"))
        return (tokens, labels) + meta

def pad_collate(batch):
    """
    Collate variable-length episodes into zero-padded batch tensors.

    Each batch item is (X (P_i,5), y (P_i,), features, src, mode, y_source).

    Returns:
        Xpad (B, Pmax, D_TOKEN) float32, ypad (B, Pmax) float32,
        mask (B, Pmax) bool with True marking PAD positions,
        then lists: feats, srcs, modes, ysources, lengths (P_i per item).
    """
    lengths = [item[0].shape[0] for item in batch]
    n_items, longest = len(batch), max(lengths)
    Xpad = torch.zeros(n_items, longest, D_TOKEN, dtype=torch.float32)
    ypad = torch.zeros(n_items, longest, dtype=torch.float32)
    mask = torch.ones(n_items, longest, dtype=torch.bool)  # True = PAD
    feats, srcs, modes, ysources = [], [], [], []
    for row, ((X, y, f, src, mode, ysrc), n_tokens) in enumerate(zip(batch, lengths)):
        Xpad[row, :n_tokens] = X
        ypad[row, :n_tokens] = y
        mask[row, :n_tokens] = False   # real tokens are not padding
        feats.append(f)
        srcs.append(src)
        modes.append(mode)
        ysources.append(ysrc)
    return Xpad, ypad, mask, feats, srcs, modes, ysources, lengths

# Training draws shuffled batches from the pooled prompt episodes; the test loader holds only
# real episodes and keeps their order. pad_collate pads each batch to its longest episode.
train_loader = DataLoader(EpisodeDS(train_pool), batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate)
test_loader  = DataLoader(EpisodeDS(test_real),  batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_collate)

# -------------------- Set Transformer with padding mask --------------------
class MAB(nn.Module):
    """
    Self-attention block with a key-padding mask.

    forward(X, pad_mask) attends X to itself while ignoring padded keys,
    applies residual + LayerNorm, a feed-forward layer with residual +
    LayerNorm, and finally zeroes the rows at padded positions.
    """
    def __init__(self, d, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True, dropout=dropout)
        self.ln1 = nn.LayerNorm(d)
        feed_forward_layers = [
            nn.Linear(d, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d),   # important: d_ff -> d
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, X, pad_mask: torch.Tensor):
        # pad_mask: (B,S) True for pad positions
        attended = self.attn(X, X, X, key_padding_mask=pad_mask, need_weights=False)[0]
        hidden = self.ln1(X + attended)
        hidden = self.ln2(hidden + self.ff(hidden))
        # LayerNorm biases would otherwise leak non-zero values into padded rows.
        return hidden.masked_fill(pad_mask.unsqueeze(-1), 0.0)

class SAB(nn.Module):
    """Self-attention block: delegates to a single MAB, forwarding the padding mask."""
    def __init__(self, d, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.mab = MAB(d, n_heads, d_ff, dropout)

    def forward(self, X, pad_mask):
        attended = self.mab(X, pad_mask)
        return attended

class SetDetector(nn.Module):
    """
    Per-token root-cause detector for padded sets.

    forward(X, pad_mask) takes X of shape (B, S, d_in) and a boolean
    pad_mask (B, S) with True at padded positions, and returns probabilities
    of shape (B, S). Padded positions are forced to probability 0.
    """
    def __init__(self, d_in=5, d=D_MODEL, depth=DEPTH, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT):
        super().__init__()
        embed_layers = [
            nn.LayerNorm(d_in),
            nn.Linear(d_in, d),
            nn.GELU(),
            nn.Linear(d, d),
        ]
        self.enc = nn.Sequential(*embed_layers)
        attention_stack = [SAB(d, n_heads, d_ff, dropout) for _ in range(depth)]
        self.blocks = nn.ModuleList(attention_stack)
        self.head = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, 1))

    def forward(self, X, pad_mask):
        pad_rows = pad_mask.unsqueeze(-1)
        # Zero the embeddings of padded tokens before attention.
        hidden = self.enc(X).masked_fill(pad_rows, 0.0)
        for block in self.blocks:
            hidden = block(hidden, pad_mask)
        scores = self.head(hidden).squeeze(-1)  # (B,S)
        return torch.sigmoid(scores).masked_fill(pad_mask, 0.0)

# Fresh model built from the D_MODEL / DEPTH / N_HEADS / D_FF / DROPOUT settings; re-running
# this cell re-initialises the weights and the optimizer state.
model = SetDetector().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

def masked_bce(probs, y, pad_mask):
    """
    Binary cross-entropy averaged over non-padded positions only.

    Args:
        probs: (B, S) probabilities.
        y: (B, S) targets.
        pad_mask: (B, S) bool, True marks positions to ignore.

    Returns:
        Scalar tensor: sum of per-position BCE on kept positions divided by
        (number of kept positions + 1e-6), so an all-padded batch yields 0.
    """
    valid = (~pad_mask).float()
    # Clamp so log() stays finite for saturated or zeroed probabilities.
    safe_probs = probs.clamp(1e-6, 1-1e-6)
    per_position = F.binary_cross_entropy(safe_probs, y, reduction="none")  # (B,S)
    kept_total = (per_position * valid).sum()
    return kept_total / (valid.sum() + 1e-6)

def run_epoch(loader, train=True):
    """
    Run one pass over `loader` and return the mean per-batch masked BCE loss.

    Uses the notebook-level `model`, `opt` and `device`.

    Args:
        loader: yields batches whose first three entries are
            (X, y, pad_mask); remaining entries are ignored.
        train: when True the model is put in train mode and the optimizer is
            stepped on every batch; when False the model is put in eval mode
            and gradient tracking is disabled, so the per-epoch test pass
            does not build an autograd graph.

    Returns:
        Average of the batch losses (0.0 if the loader is empty).
    """
    model.train(train)
    tot, n = 0.0, 0
    # Disable autograd for evaluation passes to avoid wasted memory/compute.
    with torch.set_grad_enabled(train):
        for X, y, pad_mask, *_ in loader:
            X = X.to(device)
            y = y.to(device)
            pad_mask = pad_mask.to(device)
            probs = model(X, pad_mask)
            loss = masked_bce(probs, y, pad_mask)
            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()
            tot += float(loss.item())
            n += 1
    return tot / max(1, n)

print("\nTraining (BCE only) ...")
# The test loss is computed every epoch for monitoring only; it does not influence training.
# Progress is printed on epoch 1 and every 5th epoch.
for ep in range(1, EPOCHS+1):
    tr = run_epoch(train_loader, train=True)
    te = run_epoch(test_loader,  train=False)
    if ep == 1 or ep % 5 == 0:
        print(f"epoch {ep:02d} | train {tr:.4f} | test {te:.4f}")

# -------------------- Evaluation on ORIGINAL test set only --------------------
@torch.no_grad()
def evaluate(loader):
    """Score every episode in `loader` with the global `model` and tabulate top-k hits.

    For each episode the unpadded probabilities are ranked in descending
    order. Two cut-offs are reported: S (the number of true positives, or
    HEURISTIC_S when the episode has none) and TOPK_5, both capped at the
    episode length.

    Args:
        loader: iterable yielding
            (X, y, pad_mask, feats, srcs, modes, ysources, lengths) batches.

    Returns:
        (summary_df, details_df): one row per episode in each frame. In
        details_df the index lists are sorted ascending and `top5_probs[j]`
        is the probability of `pred_top5_idx[j]`.
    """
    model.eval()
    summary_rows = []
    details_rows = []
    for X, y, pad_mask, feats, srcs, modes, ysources, lengths in loader:
        X = X.to(device); pad_mask = pad_mask.to(device)
        probs = model(X, pad_mask).detach().cpu().numpy()
        # Labels are only needed on the host, so they are not moved to the device.
        y_np = y.detach().cpu().numpy()
        for b in range(probs.shape[0]):
            L = int(lengths[b])
            pr = probs[b, :L]
            yy = y_np[b, :L]
            true_idx = np.where(yy > 0.5)[0].tolist()
            true_set = set(true_idx)
            s_true = len(true_set)
            kS = min(s_true if s_true > 0 else HEURISTIC_S, L)
            k5 = min(TOPK_5, L)

            # Rank once and slice for both cut-offs.
            ranking = np.argsort(-pr)
            topS = ranking[:kS].tolist()
            top5 = ranking[:k5].tolist()

            hitsS = len(set(topS) & true_set)
            hits5 = len(set(top5) & true_set)

            # Keep the probabilities in the same (ascending index) order as
            # pred_top5_idx so the two columns can be zipped element-wise.
            top5_sorted = sorted(top5)

            summary_rows.append({
                "src": srcs[b],
                "align_mode": modes[b],
                "y_source": ysources[b],
                "P": L,
                "s_true": s_true,
                "hits@S": hitsS,
                "hits@5": hits5,
            })
            details_rows.append({
                "src": srcs[b],
                "align_mode": modes[b],
                "y_source": ysources[b],
                "P": L,
                "features": list(feats[b]),
                "true_idx": sorted(true_idx),
                "pred_topS_idx": sorted(topS),
                "pred_top5_idx": top5_sorted,
                "top5_probs": [float(pr[i]) for i in top5_sorted],
            })
    return pd.DataFrame(summary_rows), pd.DataFrame(details_rows)

# Score the held-out loader; both frames have one row per test episode.
summary_df, details_df = evaluate(test_loader)

print("\n=== Test summary (mean over incidents) ===")
# Only the numeric columns (P, s_true, hits@S, hits@5) are averaged.
print(summary_df.mean(numeric_only=True).to_string())

# Save CSVs
sum_path = "/content/petshop_prompt500_real2to20_train_summary.csv"
det_path = "/content/petshop_prompt500_real2to20_train_details.csv"
summary_df.to_csv(sum_path, index=False)
details_df.to_csv(det_path, index=False)
print(f"\nSaved:\n  {sum_path}\n  {det_path}")

# Download (Colab)
# Best effort: outside Colab the import fails and the download is silently skipped;
# the CSVs written above remain on disk either way.
try:
    from google.colab import files
    files.download(sum_path)
    files.download(det_path)
except Exception:
    pass


Device: cuda


  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="bfill").fillna(0.0)
  df = df.fillna(method="ffill").fillna(method="

Original episodes built: 68 (skipped 0)
Train(real): 34 | Test(real): 34
Prompt 01: real=2, synth=498, total=500
Prompt 02: real=4, synth=496, total=500
Prompt 03: real=6, synth=494, total=500
Prompt 04: real=8, synth=492, total=500
Prompt 05: real=10, synth=490, total=500
Prompt 06: real=12, synth=488, total=500
Prompt 07: real=14, synth=486, total=500
Prompt 08: real=16, synth=484, total=500
Prompt 09: real=18, synth=482, total=500
Prompt 10: real=20, synth=480, total=500
Total training episodes pooled from 10 prompts: 5000
Total test episodes (original only): 34

Training (BCE only) ...
epoch 01 | train 0.0499 | test 0.2551
epoch 05 | train 0.0354 | test 0.2470
epoch 10 | train 0.0335 | test 0.2411
epoch 15 | train 0.0330 | test 0.2315
epoch 20 | train 0.0320 | test 0.2241
epoch 25 | train 0.0313 | test 0.2360
epoch 30 | train 0.0303 | test 0.2318
epoch 35 | train 0.0291 | test 0.2238
epoch 40 | train 0.0278 | test 0.2249
epoch 45 | train 0.0259 | test 0.2309
epoch 50 | train 0.0234

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# ===================== PetShop: sweep fixed m_int across ALL 10 prompts (m_int=4..20 step 4) =====================
# Each run:
#   - 50/50 split ORIGINAL episodes into train_real / test_real
#   - Build 10 prompts, each prompt has:
#         m_int real episodes (sampled from train_real, with replacement)
#       + (500 - m_int) synthetic linear-SEM episodes  ==> total 500
#   - Train Set Transformer (BCE only) on the union of the 10 prompts
#   - Evaluate on test_real ONLY (no synthetic in test)

!pip -q install ruamel.yaml

import re, json, random
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from ruamel.yaml import YAML

# -------------------- Repro --------------------
# Seed Python, NumPy and PyTorch from a single base seed.
BASE_SEED = 123
random.seed(BASE_SEED); np.random.seed(BASE_SEED); torch.manual_seed(BASE_SEED)

# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# -------------------- Sweep settings --------------------
M_INT_VALUES   = [4, 8, 12, 16, 20]   # number of real episodes per prompt; one full train/eval run per value
N_PROMPTS      = 10                   # prompts built (and pooled for training) per run
PROMPT_SIZE    = 500                 # total episodes per prompt (real + synthetic)
EPOCHS       = 200                    # training epochs per run
BATCH_SIZE     = 16                   # batch size for both the train and test loaders
LR             = 2e-4                 # AdamW learning rate

# Token per feature i: [mu_obs_i, mu_int_i, dmu_i, logvar_obs_i, logvar_int_i]
D_TOKEN = 5

# Evaluation
TOPK_5 = 5          # size of the fixed top-k cut-off reported as hits@5
HEURISTIC_S = 3     # if no GT root-cause label can be extracted, the HEURISTIC_S features with largest |dmu| are used as labels; also the hits@S cut-off when an episode has no positives

# Synthetic SEM episodes (only used in TRAIN prompts)
P_SYNTH = 64        # number of variables per synthetic episode
S_SYNTH = 3         # number of shifted (positive-label) variables per synthetic episode
N_OBS   = 512       # samples drawn without a shift
N_INT   = 512       # samples drawn with the shift applied

# Alignment threshold
# Minimum number of aligned feature columns for a real episode to be kept.
MIN_FEATURES_KEEP = 30

# -------------------- Repo clone --------------------
# The dataset lives in the public PetShop repository; it is cloned into /content
# only when the target directory is not already present, so re-running the cell is cheap.
ROOT = Path("/content")
REPO = ROOT / "petshop-root-cause-analysis"
if not REPO.exists():
    !git -C /content clone -q https://github.com/amazon-science/petshop-root-cause-analysis.git
DATA_ROOT = REPO / "dataset"
# Fail early with a clear message if the checkout does not have the expected layout.
assert DATA_ROOT.exists(), "Repo cloned but dataset/ not found."

# -------------------- Helpers --------------------
def canon_name(s: str) -> str:
    """Reduce a metric/service name to a canonical lowercase token string.

    The input is lowercased, AWS ARNs and `aws::` type prefixes are removed,
    every run of non-alphanumeric characters becomes a space, stand-alone
    environment words (prod, stage, stg, dev, test, qa) are dropped, and
    whitespace is collapsed and trimmed.
    """
    # Applied in order; each match is replaced by a single space.
    blank_out = (
        r'arn:aws:[^ ]+',
        r'aws::[a-z0-9:_\-]+',
        r'[^a-z0-9]+',
        r'\b(prod|stage|stg|dev|test|qa)\b',
    )
    text = str(s).lower()
    for pattern in blank_out:
        text = re.sub(pattern, ' ', text)
    return re.sub(r'\s+', ' ', text).strip()

def load_metrics_as_time_by_feature(csv_path: Path) -> pd.DataFrame:
    """Load a metrics CSV as a dense float32 (time x feature) frame.

    Columns whose lowercased name is an obvious identifier (timestamp, time,
    date, datetime, ts, microservice, service, index) are never used as
    features. The remaining columns are coerced to numeric; a column is kept
    when it has at least max(10, 10% of rows) non-missing values. If no
    column passes that bar, every non-identifier column with at least one
    value is kept instead. Gaps are forward-filled, then back-filled, then
    set to 0.0.

    Args:
        csv_path: path of the CSV file to read.

    Returns:
        DataFrame of dtype float32 with one column per retained feature, in
        the file's column order (possibly with zero columns).
    """
    raw = pd.read_csv(csv_path, engine="python")

    id_like = {"timestamp", "time", "date", "datetime", "ts", "microservice", "service", "index"}
    # str() so that a non-string column label cannot break the lookup.
    candidates = [c for c in raw.columns if str(c).lower() not in id_like]

    # Identifier columns are excluded up front, so they can never come back
    # through the fallback below (which matters for files shorter than 10 rows).
    num = raw[candidates].copy()
    for c in candidates:
        num[c] = pd.to_numeric(num[c], errors="coerce")

    T = len(num)
    min_valid = max(10, int(0.1 * T))
    valid_counts = num.notna().sum()

    feat_cols = [c for c in candidates if valid_counts[c] >= min_valid]
    if len(feat_cols) == 0:
        # Fallback for short or sparse files: accept anything with data.
        feat_cols = [c for c in candidates if valid_counts[c] > 0]

    df = num[feat_cols].ffill().bfill().fillna(0.0).astype(np.float32)
    return df

def align_feature_columns(df_obs: pd.DataFrame, df_int: pd.DataFrame):
    """Select matching feature columns from an observational and an interventional frame.

    Matching is attempted in three stages:
      1. exact column-name intersection, accepted when it has at least
         MIN_FEATURES_KEEP columns ("exact");
      2. intersection of canonical names (see canon_name), accepted under the
         same threshold ("canonical"); the first column seen for each
         canonical key wins on either side;
      3. otherwise the exact intersection is returned even though it is
         small ("exact_small"), or all-None when it is empty ("none").

    Returns:
        (obs_subset, int_subset, feature_names, mode). feature_names always
        uses the obs-side column names; the two subsets have their columns
        in corresponding order.
    """
    # Build the lookup set once instead of once per obs column.
    int_cols = set(df_int.columns)
    inter = [c for c in df_obs.columns if c in int_cols]
    if len(inter) >= MIN_FEATURES_KEEP:
        return df_obs[inter], df_int[inter], inter, "exact"

    def _first_by_canon(columns):
        # Map canonical name -> first original column carrying it; empty keys are skipped.
        mapping = {}
        for c in columns:
            k = canon_name(c)
            if k:
                mapping.setdefault(k, c)
        return mapping

    # canonical matching
    map_obs = _first_by_canon(df_obs.columns)
    map_int = _first_by_canon(df_int.columns)

    keys = [k for k in map_obs if k in map_int]
    if len(keys) >= MIN_FEATURES_KEEP:
        cols_obs = [map_obs[k] for k in keys]
        cols_int = [map_int[k] for k in keys]
        # use obs-side names as feature identifiers
        return df_obs[cols_obs], df_int[cols_int], cols_obs, "canonical"

    # fallback
    if len(inter) == 0:
        return None, None, None, "none"
    return df_obs[inter], df_int[inter], inter, "exact_small"

def try_load_side_meta(issue_dir: Path):
    """Best-effort load of every JSON/YAML file directly inside `issue_dir`.

    Files are visited by pattern (*.json, then *.yml, then *.yaml). A file
    that cannot be read or parsed is skipped without raising.

    Returns:
        List of the parsed objects, one per file that loaded successfully.
    """
    meta_files = []
    for pattern in ("*.json", "*.yml", "*.yaml"):
        meta_files.extend(issue_dir.glob(pattern))

    loaded = []
    for meta_file in meta_files:
        try:
            if meta_file.suffix == ".json":
                parsed = json.loads(meta_file.read_text())
            else:
                parsed = YAML(typ="safe").load(meta_file.read_text())
        except Exception:
            # Metadata is optional; an unreadable file is simply ignored.
            continue
        loaded.append(parsed)
    return loaded

def extract_root_services(meta_objs):
    """Collect candidate root-cause names from parsed metadata objects.

    Every dict/list structure is walked. A dotted, lowercased key path is
    built for each dict entry; when that path contains any of "root",
    "cause", "service", "culprit" or "node", a string value (or the string
    items of a list value) is collected. List nesting does not extend the
    path.

    Returns:
        Set of the collected names, whitespace-stripped, with empty names removed.
    """
    hint_tokens = ("root", "cause", "service", "culprit", "node")
    found = set()

    # Explicit work stack of (node, key-path) pairs instead of recursion.
    pending = [(obj, "") for obj in meta_objs]
    while pending:
        node, path = pending.pop()
        if isinstance(node, dict):
            for key, value in node.items():
                key_path = (path + "." + str(key)).lower()
                if any(tok in key_path for tok in hint_tokens):
                    if isinstance(value, str):
                        found.add(value)
                    elif isinstance(value, list):
                        found.update(item for item in value if isinstance(item, str))
                # Always descend, carrying the extended path.
                pending.append((value, key_path))
        elif isinstance(node, list):
            pending.extend((child, path) for child in node)

    cleaned = set()
    for name in found:
        stripped = str(name).strip()
        if name and stripped:
            cleaned.add(stripped)
    return cleaned

def build_episode_from_pair(df_obs: pd.DataFrame, df_int: pd.DataFrame, feature_names, issue_dir: Path, src: str, align_mode: str):
    """Turn an aligned (observational, interventional) frame pair into one episode.

    Each feature becomes a 5-value token
    [mu_obs, mu_int, mu_int - mu_obs, log var_obs, log var_int].
    Labels come from metadata files in `issue_dir` when a ground-truth name
    matches a feature (canonical-name substring match in either direction);
    otherwise the HEURISTIC_S features with the largest |mean shift| are
    labelled positive.

    Args:
        df_obs, df_int: frames with corresponding columns (one per feature).
        feature_names: names used to identify the features, in column order.
        issue_dir: directory searched for side metadata.
        src: identifier stored with the episode.
        align_mode: how the columns were aligned; stored with the episode.

    Returns:
        dict with keys "X" (P, 5) float32, "y" (P,) float32, "features",
        "src", "align_mode" and "y_source" ("gt_or_partial" or "heuristic").
    """
    mu_obs = df_obs.mean(axis=0).values.astype(np.float32)
    mu_int = df_int.mean(axis=0).values.astype(np.float32)
    dmu    = (mu_int - mu_obs).astype(np.float32)

    # eps keeps the log finite for constant (zero-variance) columns.
    # NOTE(review): DataFrame.var defaults to ddof=1, so a single-row frame
    # yields NaN here, and build_sem_episode uses the population variance
    # (unbiased=False) instead — confirm whether the two should agree.
    eps = 1e-6
    logv_obs = np.log(df_obs.var(axis=0).values.astype(np.float32) + eps)
    logv_int = np.log(df_int.var(axis=0).values.astype(np.float32) + eps)

    X = np.stack([mu_obs, mu_int, dmu, logv_obs, logv_int], axis=1).astype(np.float32)  # (P,5)

    # label y from metadata if possible
    y = np.zeros(len(feature_names), dtype=np.float32)
    gt = extract_root_services(try_load_side_meta(issue_dir))
    if gt:
        feat_can = [canon_name(f) for f in feature_names]
        gt_can = [gc for gc in (canon_name(g) for g in gt) if gc]
        for i, fc in enumerate(feat_can):
            if not fc:
                # An empty canonical name is a substring of every string, so it
                # would otherwise match any ground-truth name and be mislabelled.
                continue
            if any((gc in fc) or (fc in gc) for gc in gt_can):
                y[i] = 1.0

    if y.sum() == 0:
        # fallback: heuristic
        order = np.argsort(-np.abs(dmu))
        y[order[:min(HEURISTIC_S, len(y))]] = 1.0
        y_source = "heuristic"
    else:
        y_source = "gt_or_partial"

    return {
        "X": X,
        "y": y,
        "features": list(feature_names),
        "src": src,
        "align_mode": align_mode,
        "y_source": y_source
    }

# -------------------- Build ORIGINAL PetShop episodes --------------------
# One episode is built per issue metrics.csv, paired with its scenario's noissue/metrics.csv.
episodes_real = []
skipped = 0  # counts unreadable baselines and issue files that failed or had too few features

for scen_dir in sorted([p for p in DATA_ROOT.iterdir() if p.is_dir()]):
    noissue_csv = scen_dir / "noissue" / "metrics.csv"
    # Scenarios without an observational baseline cannot form a pair (not counted as skipped).
    if not noissue_csv.exists():
        continue

    try:
        obs_df = load_metrics_as_time_by_feature(noissue_csv)
    except Exception:
        # An unreadable baseline skips the whole scenario but is counted once.
        skipped += 1
        continue

    # Every metrics.csv below the scenario that is not under a "noissue" directory is an issue episode.
    issue_csvs = [p for p in scen_dir.rglob("metrics.csv") if "noissue" not in p.parts]
    for ic in sorted(issue_csvs):
        try:
            int_df = load_metrics_as_time_by_feature(ic)
            A,B,feat_names,mode = align_feature_columns(obs_df, int_df)
            if mode == "none" or A is None:
                skipped += 1
                continue
            # "exact_small" alignments can come back with too few columns; drop those.
            if len(feat_names) < MIN_FEATURES_KEEP:
                skipped += 1
                continue
            ep = build_episode_from_pair(
                A, B, feat_names,
                issue_dir=ic.parent,
                src=str(ic),
                align_mode=mode
            )
            episodes_real.append(ep)
        except Exception:
            # Any failure while loading/aligning/building just drops this episode.
            skipped += 1

print(f"Original episodes built: {len(episodes_real)} (skipped {skipped})")
assert len(episodes_real) > 0, "No usable PetShop episodes were constructed."

# -------------------- 50/50 split (ORIGINAL only) --------------------
# Shuffle in place with a seeded generator, then cut at the midpoint; with an odd
# count the extra episode goes to the test half.
rng = np.random.default_rng(BASE_SEED)
rng.shuffle(episodes_real)
mid = len(episodes_real)//2
train_real = episodes_real[:mid]
test_real  = episodes_real[mid:]
print(f"Train(real): {len(train_real)} | Test(real): {len(test_real)}")

# -------------------- Synthetic linear-SEM episodes (TRAIN only) --------------------
def sample_strict_lower_triangular_B(p: int, exp_degree: float = 3.0, weight_std: float = 0.3, g=None):
    """Sample a random strictly lower-triangular (p, p) weight matrix.

    Each below-diagonal entry is kept with probability
    exp_degree / (p - 1) (clipped to [0, 0.999]) and, when kept, holds a
    normal draw scaled by weight_std; every other entry is zero.
    A fresh torch.Generator is used when `g` is None.
    """
    gen = torch.Generator() if g is None else g
    edge_prob = min(0.999, max(0.0, exp_degree / max(1, p-1)))
    # Draw order matters for reproducibility: uniform coins first, then weights.
    coins = torch.rand(p, p, generator=gen)
    weights = torch.randn(p, p, generator=gen) * weight_std
    below_diagonal = torch.tril(torch.ones(p, p), diagonal=-1).bool()
    keep = below_diagonal & (coins < edge_prob)
    return torch.where(keep, weights, torch.zeros(p, p))

def sample_D_diag(p: int, low: float = 0.5, high: float = 1.5, g=None):
    """Draw p values from low + (high - low) * U[0, 1), using a fresh generator when `g` is None."""
    gen = torch.Generator() if g is None else g
    unit_draws = torch.rand(p, generator=gen)
    return low + (high - low) * unit_draws

def simulate_sem(B: torch.Tensor, sig2: torch.Tensor, n: int, delta: torch.Tensor | None, g=None):
    if g is None: g = torch.Generator()
    p = B.shape[0]
    A = torch.linalg.inv(torch.eye(p) - B)
    E = torch.randn(n, p, generator=g) * torch.sqrt(sig2).view(1, p)
    if delta is not None:
        E = E + delta.view(1, p)  # soft mean shift
    X = E @ A.T
    return X.to(torch.float32)

def build_sem_episode(seed: int, p=P_SYNTH, n_obs=N_OBS, n_int=N_INT, s=S_SYNTH):
    """Generate one synthetic episode from a random linear SEM.

    A seeded generator drives everything: the weight matrix, the noise
    variances, the choice of `s` shifted variables, the (always positive)
    shift magnitudes in [0.3, 1.0), and the two sample sets. The episode has
    the same token layout as the real ones, with label 1 on the shifted
    variables.
    """
    gen = torch.Generator().manual_seed(int(seed))
    weights = sample_strict_lower_triangular_B(p, exp_degree=3.0, weight_std=0.3, g=gen)
    noise_var = sample_D_diag(p, low=0.5, high=1.5, g=gen)

    shifted = torch.randperm(p, generator=gen)[:s]
    magnitudes = 0.3 + (1.0 - 0.3) * torch.rand(s, generator=gen)
    # NO random signs: always add positive delta
    shift = torch.zeros(p)
    shift[shifted] = magnitudes

    # Observational draw first, interventional second (same generator).
    sample_obs = simulate_sem(weights, noise_var, n_obs, delta=None, g=gen)
    sample_int = simulate_sem(weights, noise_var, n_int, delta=shift, g=gen)

    def _mean_and_logvar(sample):
        # Population variance plus a small epsilon before the log.
        variance = sample.var(dim=0, unbiased=False) + 1e-6
        return sample.mean(dim=0), torch.log(variance)

    mean_obs, logvar_obs = _mean_and_logvar(sample_obs)
    mean_int, logvar_int = _mean_and_logvar(sample_int)
    mean_shift = mean_int - mean_obs

    token_columns = [mean_obs, mean_int, mean_shift, logvar_obs, logvar_int]
    tokens = torch.stack(token_columns, dim=1).cpu().numpy().astype(np.float32)

    labels = torch.zeros(p)
    labels[shifted] = 1.0

    return {
        "X": tokens,
        "y": labels.numpy().astype(np.float32),
        "features": [f"node_{i}" for i in range(p)],
        "src": f"synth/sem_{seed}",
        "align_mode": "synthetic",
        "y_source": "synthetic_gt"
    }

# -------------------- Dataset + padding collate --------------------
class EpisodeDS(torch.utils.data.Dataset):
    """Dataset view over a list of episode dicts.

    Each item is the 6-tuple
    (X tensor, y tensor, features, src, align_mode, y_source),
    with X and y converted from NumPy arrays to float32 tensors.
    """

    def __init__(self, eps):
        self.E = eps

    def __len__(self):
        return len(self.E)

    def __getitem__(self, i):
        episode = self.E[i]
        tensors = tuple(torch.from_numpy(episode[key]).float() for key in ("X", "y"))
        metadata = tuple(episode[key] for key in ("features", "src", "align_mode", "y_source"))
        return tensors + metadata

def pad_collate(batch):
    """Collate variable-length episodes into zero-padded batch tensors.

    Returns:
        (Xpad, ypad, pad_mask, feats, srcs, modes, ysources, lengths) where
        Xpad is (B, Pmax, D_TOKEN), ypad is (B, Pmax), pad_mask is a bool
        (B, Pmax) tensor that is True at padded positions, and the remaining
        items are per-episode Python lists.
    """
    lengths = [item[0].shape[0] for item in batch]
    n_items, longest = len(batch), max(lengths)

    Xpad = torch.zeros(n_items, longest, D_TOKEN, dtype=torch.float32)
    ypad = torch.zeros(n_items, longest, dtype=torch.float32)
    pad_mask = torch.ones(n_items, longest, dtype=torch.bool)  # True = PAD

    feats, srcs, modes, ysources = [], [], [], []
    for row, (X, y, names, src, mode, ysrc) in enumerate(batch):
        used = lengths[row]
        # Fill the leading `used` slots and mark them as real tokens.
        Xpad[row, :used] = X
        ypad[row, :used] = y
        pad_mask[row, :used] = False
        feats.append(names)
        srcs.append(src)
        modes.append(mode)
        ysources.append(ysrc)
    return Xpad, ypad, pad_mask, feats, srcs, modes, ysources, lengths

# -------------------- Set Transformer (padding-aware) --------------------
class MAB(nn.Module):
    """Self-attention block with a post-norm residual layout.

    Multi-head attention over the set (keys at padded positions are ignored),
    a residual + LayerNorm, a two-layer GELU feed-forward with dropout,
    another residual + LayerNorm, and finally zeroing of the padded rows.
    """

    def __init__(self, d, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True, dropout=dropout)
        self.ln1 = nn.LayerNorm(d)
        feed_forward_layers = [
            nn.Linear(d, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, X, pad_mask):
        """X is used as query, key and value; pad_mask is True at padded positions."""
        attended, _ = self.attn(X, X, X, key_padding_mask=pad_mask, need_weights=False)
        after_attention = self.ln1(X + attended)
        after_ff = self.ln2(after_attention + self.ff(after_attention))
        # Padded rows still receive values as queries; reset them to zero.
        return after_ff.masked_fill(pad_mask.unsqueeze(-1), 0.0)

class SetDetector(nn.Module):
    """Scores every token of a padded set with a probability in [0, 1].

    Pipeline: token MLP encoder -> `depth` MAB blocks -> LayerNorm + Linear
    head -> sigmoid. Padded positions are zeroed after the encoder and again
    in the returned probabilities.
    """

    def __init__(self, d_in=5, d=128, depth=3, n_heads=8, d_ff=256, dropout=0.0):
        super().__init__()
        encoder_layers = [
            nn.LayerNorm(d_in),
            nn.Linear(d_in, d),
            nn.GELU(),
            nn.Linear(d, d),
        ]
        self.enc = nn.Sequential(*encoder_layers)
        attention_stack = [MAB(d, n_heads, d_ff, dropout) for _ in range(depth)]
        self.blocks = nn.ModuleList(attention_stack)
        self.head = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, 1))

    def forward(self, X, pad_mask):
        """Return a (B, S) tensor of probabilities; pad_mask is True at padded positions."""
        hidden = self.enc(X).masked_fill(pad_mask.unsqueeze(-1), 0.0)
        for block in self.blocks:
            hidden = block(hidden, pad_mask)
        scores = self.head(hidden).squeeze(-1)
        return torch.sigmoid(scores).masked_fill(pad_mask, 0.0)

def masked_bce(probs, y, pad_mask):
    """Binary cross-entropy averaged over non-padded positions only.

    pad_mask is True where a position must be ignored. Probabilities are
    clamped away from 0 and 1 before the loss, and the denominator carries a
    small epsilon so a fully padded batch gives 0 rather than NaN.
    """
    clipped = probs.clamp(1e-6, 1-1e-6)
    elementwise = F.binary_cross_entropy(clipped, y, reduction="none")
    weight = (~pad_mask).float()
    numerator = (elementwise * weight).sum()
    denominator = weight.sum() + 1e-6
    return numerator / denominator

@torch.no_grad()
def evaluate(model, loader):
    """Score every episode in `loader` with `model` and tabulate top-k hits.

    For each episode the unpadded probabilities are ranked in descending
    order. Two cut-offs are reported: S (the number of true positives, or
    HEURISTIC_S when the episode has none) and TOPK_5, both capped at the
    episode length.

    Args:
        model: module called as model(X, pad_mask) returning (B, S) probabilities.
        loader: iterable yielding
            (X, y, pad_mask, feats, srcs, modes, ysources, lengths) batches.

    Returns:
        (summary_df, details_df): one row per episode in each frame. In
        details_df the index lists are sorted ascending and `top5_probs[j]`
        is the probability of `pred_top5_idx[j]`.
    """
    model.eval()
    summary_rows, details_rows = [], []
    for X, y, pad_mask, feats, srcs, modes, ysources, lengths in loader:
        X = X.to(device); pad_mask = pad_mask.to(device)
        probs = model(X, pad_mask).detach().cpu().numpy()
        # Labels are only needed on the host, so they are not moved to the device.
        y_np = y.detach().cpu().numpy()

        for b in range(probs.shape[0]):
            L = int(lengths[b])
            pr = probs[b, :L]
            yy = y_np[b, :L]
            true_idx = np.where(yy > 0.5)[0].tolist()
            true_set = set(true_idx)
            s_true = len(true_set)

            # support-sized metric: S = |true| (fallback if 0)
            kS = min(L, s_true if s_true > 0 else HEURISTIC_S)
            k5 = min(L, TOPK_5)

            # Rank once and slice for both cut-offs.
            ranking = np.argsort(-pr)
            topS = ranking[:kS].tolist()
            top5 = ranking[:k5].tolist()

            hitsS = len(set(topS) & true_set)
            hits5 = len(set(top5) & true_set)

            # Keep the probabilities in the same (ascending index) order as
            # pred_top5_idx so the two columns can be zipped element-wise.
            top5_sorted = sorted(top5)

            summary_rows.append({
                "P": L,
                "s_true": s_true,
                "hits@S": hitsS,
                "hits@5": hits5,
                "align_mode": modes[b],
                "y_source": ysources[b],
                "src": srcs[b],
            })
            details_rows.append({
                "P": L,
                "features": list(feats[b]),
                "true_idx": sorted(true_idx),
                "pred_topS_idx": sorted(topS),
                "pred_top5_idx": top5_sorted,
                "top5_probs": [float(pr[i]) for i in top5_sorted],
                "align_mode": modes[b],
                "y_source": ysources[b],
                "src": srcs[b],
            })
    return pd.DataFrame(summary_rows), pd.DataFrame(details_rows)

# -------------------- Main sweep --------------------
# Accumulators across the sweep: one summary dict and one details frame per m_int value.
all_summary = []
all_details = []

# fixed test loader (original only)
test_loader = torch.utils.data.DataLoader(EpisodeDS(test_real), batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_collate)

for m_int in M_INT_VALUES:
    print(f"\n==================== RUN: m_int = {m_int} (fixed for all {N_PROMPTS} prompts) ====================")

    # Build 10 prompts, each has m_int real + (500-m_int) synthetic
    prompts = []
    for pid in range(N_PROMPTS):
        # sample real episodes (with replacement)
        idxs = rng.integers(0, len(train_real), size=m_int)
        real_subset = [train_real[i] for i in idxs]

        n_synth = max(0, PROMPT_SIZE - m_int)
        # The seed is a function of (m_int, pid, j), so each synthetic episode is
        # reproducible independently of the order in which runs are executed.
        synth_subset = [build_sem_episode(seed=100000 + 10000*m_int + 1000*pid + j) for j in range(n_synth)]

        eps = real_subset + synth_subset
        rng.shuffle(eps)
        prompts.append(eps)

    # pool all prompt episodes to train (10*500 = 5000)
    train_pool = [e for pmt in prompts for e in pmt]
    train_loader = torch.utils.data.DataLoader(EpisodeDS(train_pool), batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate)

    # init model fresh each run
    model = SetDetector(d_in=D_TOKEN, d=128, depth=3, n_heads=8, d_ff=256, dropout=0.0).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

    # train
    for ep in range(1, EPOCHS+1):
        model.train(True)
        tot = 0.0; n = 0  # running sum of batch losses and batch count for this epoch
        for X, y, pad_mask, *_ in train_loader:
            X = X.to(device); y = y.to(device); pad_mask = pad_mask.to(device)
            probs = model(X, pad_mask)
            loss = masked_bce(probs, y, pad_mask)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tot += float(loss.item()); n += 1
        if ep == 1 or ep % 5 == 0:
            # quick test loss only (optional)
            # Monitoring only: nothing is selected or stopped based on this value.
            model.eval()
            with torch.no_grad():
                ttot = 0.0; tn = 0
                for X, y, pad_mask, *_ in test_loader:
                    X = X.to(device); y = y.to(device); pad_mask = pad_mask.to(device)
                    probs = model(X, pad_mask)
                    ttot += float(masked_bce(probs, y, pad_mask).item()); tn += 1
            print(f"epoch {ep:02d} | train_bce={tot/max(1,n):.4f} | test_bce={ttot/max(1,tn):.4f}")

    # evaluate on original test only
    summary_df, details_df = evaluate(model, test_loader)
    # Per-episode metrics averaged over the test episodes (numeric columns only).
    summary_means = summary_df.mean(numeric_only=True)

    row = {
        "m_int": int(m_int),
        "test_hits@S_mean": float(summary_means["hits@S"]),
        "test_hits@5_mean": float(summary_means["hits@5"]),
        "test_s_true_mean": float(summary_means["s_true"]),
        "n_test_episodes": int(len(summary_df)),
    }
    all_summary.append(row)

    # Tag every detail row with the run it came from before stacking the runs.
    details_df.insert(0, "m_int", int(m_int))
    all_details.append(details_df)

# Save sweep outputs
sweep_summary_df = pd.DataFrame(all_summary).sort_values("m_int")
sweep_details_df = pd.concat(all_details, axis=0, ignore_index=True)

summary_path = "/content/petshop_sweep_summary.csv"
details_path = "/content/petshop_sweep_details.csv"
sweep_summary_df.to_csv(summary_path, index=False)
sweep_details_df.to_csv(details_path, index=False)

print("\n=== SWEEP SUMMARY ===")
print(sweep_summary_df.to_string(index=False))
print(f"\nSaved:\n  {summary_path}\n  {details_path}")

# Download (Colab)
# Best effort: outside Colab the import fails and the download is silently skipped;
# the CSVs written above remain on disk either way.
try:
    from google.colab import files
    files.download(summary_path)
    files.download(details_path)
except Exception:
    pass


Device: cuda
Original episodes built: 68 (skipped 0)
Train(real): 34 | Test(real): 34

epoch 01 | train_bce=0.0494 | test_bce=0.2786
epoch 05 | train_bce=0.0268 | test_bce=0.2907
epoch 10 | train_bce=0.0254 | test_bce=0.3244
epoch 15 | train_bce=0.0251 | test_bce=0.2609
epoch 20 | train_bce=0.0247 | test_bce=0.2537
epoch 25 | train_bce=0.0242 | test_bce=0.2531
epoch 30 | train_bce=0.0239 | test_bce=0.2465
epoch 35 | train_bce=0.0231 | test_bce=0.2483
epoch 40 | train_bce=0.0219 | test_bce=0.2505
epoch 45 | train_bce=0.0208 | test_bce=0.2521
epoch 50 | train_bce=0.0192 | test_bce=0.2515
epoch 55 | train_bce=0.0172 | test_bce=0.2873
epoch 60 | train_bce=0.0151 | test_bce=0.2624
epoch 65 | train_bce=0.0133 | test_bce=0.2597
epoch 70 | train_bce=0.0118 | test_bce=0.2530
epoch 75 | train_bce=0.0097 | test_bce=0.2713
epoch 80 | train_bce=0.0098 | test_bce=0.2546
epoch 85 | train_bce=0.0084 | test_bce=0.2605
epoch 90 | train_bce=0.0077 | test_bce=0.2814
epoch 95 | train_bce=0.0084 | test_bce=

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>