
# 04 — Sanity Checks & Quick Stats

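Verifies:
- Balanced IDs exist and overlap with the IDs found in the JSONL files.
- Coverage: the balanced MM JSONL contains every client present in TRX or GEO.
- No client appears in more than one fold (leakage check on the TRX/GEO parquet parts).
- Class distribution in the raw (derived) targets and in the balanced subset.
- Text length statistics and a preview of sample records from the balanced JSONLs.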


In [5]:
# 04_check_data.py — robust sanity + coverage + leakage checks

import os, glob, json, re, random
from collections import defaultdict
import pandas as pd
import numpy as np

# ====== CONFIG (keep in sync with 00/01/02/03) ======
# Output root; every derived path below hangs off this directory.
BASE_OUT = "/Users/tree/Projects/recommemdation_bank/outputs"

# Balanced client list and the raw targets table.
BAL_DIR            = BASE_OUT + "/balanced"
BALANCED_PATH      = BAL_DIR + "/mbd_targets_balanced.parquet"
TARGETS_RAW_PATH   = BAL_DIR + "/targets_raw.parquet"

# Per-modality JSONL exports: the "all" files ...
JSON_TRX_ALL       = BASE_OUT + "/json/trx/mbd_all.jsonl"
JSON_GEO_ALL       = BASE_OUT + "/json/geo/mbd_all.jsonl"
JSON_MM_ALL        = BASE_OUT + "/json/mm/mbd_all.jsonl"

# ... and their balanced counterparts.
JSON_TRX_BAL       = BASE_OUT + "/json/trx/json_balanced_trx.jsonl"
JSON_GEO_BAL       = BASE_OUT + "/json/geo/json_balanced_geo.jsonl"
JSON_MM_BAL        = BASE_OUT + "/json/mm/json_balanced_mm.jsonl"

# Source parquet parts; the fold number is encoded in the `fold=<n>` directory.
TRX_GLOB           = "/Users/tree/Projects/recommemdation_bank/data/mbd_mini/detail/trx/fold=*/part-*.parquet"
GEO_GLOB           = "/Users/tree/Projects/recommemdation_bank/data/mbd_mini/detail/geo/fold=*/part-*.parquet"

FOLDS              = list(range(5))

def _exists(p):
    ok = os.path.exists(p)
    print(f"[{'OK' if ok else 'MISS'}] {p}")
    return ok

def _load_jsonl_ids(path):
    """Load unique client_ids from JSONL; return empty set if missing."""
    if not os.path.exists(path):
        print(f"[WARN] JSONL missing: {path}")
        return set()
    ids=set()
    with open(path, "r") as f:
        for line in f:
            try:
                rec=json.loads(line)
                cid=str(rec.get("client_id",""))
                if cid:
                    ids.add(cid)
            except Exception:
                continue
    return ids

def _count_jsonl(path):
    if not os.path.exists(path):
        return 0
    n=0
    with open(path,"r") as f:
        for _ in f: n+=1
    return n

def _read_with_inferred_fold(glob_pattern, need_cols=("client_id","fold")):
    """Read parquet parts; infer 'fold' from path if missing. Returns empty DF if no files."""
    paths = sorted(glob.glob(glob_pattern))
    if not paths:
        print(f"[WARN] No files matched: {glob_pattern}")
        return pd.DataFrame(columns=list(need_cols))
    dfs=[]
    for p in paths:
        try:
            cols_req = [c for c in need_cols if c != "fold"]
            dfp = pd.read_parquet(p, columns=cols_req if cols_req else None)
        except Exception:
            dfp = pd.read_parquet(p)
            dfp = dfp[[c for c in need_cols if c in dfp.columns]]
        if "fold" in need_cols and "fold" not in dfp.columns:
            m = re.search(r"fold=(\d+)", p)
            dfp["fold"] = int(m.group(1)) if m else -1
        if "client_id" in dfp.columns:
            dfp["client_id"] = dfp["client_id"].astype(str)
        dfs.append(dfp)
    if not dfs:
        return pd.DataFrame(columns=list(need_cols))
    return pd.concat(dfs, ignore_index=True)

def _text_length_stats(path, sample=5000):
    """Approx word counts; safe if file missing."""
    if not os.path.exists(path):
        return {"count": 0}
    L=[]
    with open(path) as f:
        lines = f.readlines()
    if not lines:
        return {"count": 0}
    for line in random.sample(lines, min(sample, len(lines))):
        try:
            t = json.loads(line).get("text","")
        except Exception:
            t = ""
        L.append(len(str(t).split()))
    if not L:
        return {"count": 0}
    a=np.array(L)
    return {
        "count": len(L),
        "mean": float(np.mean(a)),
        "p50":  float(np.percentile(a,50)),
        "p90":  float(np.percentile(a,90)),
        "p99":  float(np.percentile(a,99)),
        "max":  int(np.max(a)),
    }

def _sample_json(path, k=2, trim=800):
    """Return k example records (trimmed) from a JSONL, or [] if missing."""
    if not os.path.exists(path):
        return []
    with open(path,"r") as f:
        lines=f.readlines()
    out=[]
    for line in random.sample(lines, min(k,len(lines))):
        try:
            rec=json.loads(line)
        except Exception:
            continue
        rec2 = {
            "client_id": rec.get("client_id"),
            "text": str(rec.get("text","")).replace("\n"," ")[:trim]
        }
        out.append(rec2)
    return out

# ====== 0) Show what's present ======
# One [OK]/[MISS] line per expected artefact; the results themselves are not used.
print("=== Presence check ===")
_ = list(map(_exists, (
    BALANCED_PATH, TARGETS_RAW_PATH,
    JSON_TRX_ALL, JSON_TRX_BAL,
    JSON_GEO_ALL, JSON_GEO_BAL,
    JSON_MM_ALL,  JSON_MM_BAL,
)))

# ====== 1) Balanced list & targets ======
# Both inputs are optional: later sections test `bal_ids` / `targets` for emptiness.
print("\n=== Balanced IDs & Targets ===")
if not os.path.exists(BALANCED_PATH):
    bal_ids = set()
    print("[WARN] Balanced parquet not found.")
else:
    bal = pd.read_parquet(BALANCED_PATH)
    bal_ids = set(bal["client_id"].astype(str))
    print("Balanced clients:", len(bal_ids))

if not os.path.exists(TARGETS_RAW_PATH):
    targets = pd.DataFrame()
    tcols = []
    print("[WARN] targets_raw parquet not found.")
else:
    targets = pd.read_parquet(TARGETS_RAW_PATH)
    targets["client_id"] = targets["client_id"].astype(str)
    # Target columns are recognised purely by their "target_" name prefix.
    tcols = [name for name in targets.columns if name.startswith("target_")]
    print("Targets shape:", targets.shape, "| target cols:", tcols)

# ====== 2) Load ids from JSONLs ======
print("\n=== JSONL ID coverage ===")
ids_trx_all, ids_geo_all, ids_mm_all = [
    _load_jsonl_ids(_p) for _p in (JSON_TRX_ALL, JSON_GEO_ALL, JSON_MM_ALL)
]
ids_trx_bal, ids_geo_bal, ids_mm_bal = [
    _load_jsonl_ids(_p) for _p in (JSON_TRX_BAL, JSON_GEO_BAL, JSON_MM_BAL)
]

for _tag, _full_ids, _subset_ids in (
    ("TRX", ids_trx_all, ids_trx_bal),
    ("GEO", ids_geo_all, ids_geo_bal),
    ("MM ", ids_mm_all, ids_mm_bal),
):
    print(f"#{_tag} all: {len(_full_ids)} | balanced: {len(_subset_ids)}")

# ====== 3) Coverage: Union(TRX,GEO) vs MM ======
# Every client seen in TRX or GEO is expected to be present in MM as well.
print("\n=== Union(TRX,GEO) vs MM (balanced) ===")
union_bal = ids_trx_bal.union(ids_geo_bal)
print("Union(TRX,GEO) balanced:", len(union_bal))
miss_in_mm = set()
if not ids_mm_bal:
    print("[WARN] MM balanced JSON missing or empty; cannot compare.")
else:
    print("MM covers union(TRX,GEO)?", union_bal.issubset(ids_mm_bal))
    miss_in_mm = union_bal.difference(ids_mm_bal)
    print("Missing from MM (should be 0):", len(miss_in_mm))

# ====== 4) Balanced clients with no TRX & no GEO ======
print("\n=== Balanced clients with no TRX & no GEO ===")
# An empty `bal_ids` naturally yields an empty difference.
missing_any_modality = bal_ids.difference(union_bal)
print("Count:", len(missing_any_modality))
peek = sorted(missing_any_modality)[:20]
if peek:
    print("Sample:", peek)

# ====== 5) Fold membership for clients with no TRX/GEO ======
print("\n=== Fold membership for clients with no TRX/GEO ===")
if targets.empty or not missing_any_modality:
    print("(skip) No targets or no missing clients.")
else:
    _is_missing = targets["client_id"].isin(missing_any_modality)
    fold_counts = targets.loc[_is_missing, "fold"].value_counts().sort_index()
    print(fold_counts.to_string())

# ====== 6) Leakage: same client in multiple folds ======
print("\n=== Leakage checks (client in multiple folds) ===")
trx_ids_df = _read_with_inferred_fold(TRX_GLOB, need_cols=("client_id","fold"))
geo_ids_df = _read_with_inferred_fold(GEO_GLOB, need_cols=("client_id","fold"))

def _report_multi_fold(ids_df, label):
    """Print and return how many clients of one modality occur in more than one fold.

    Args:
        ids_df: Frame with "client_id" and "fold" columns (one row per source record).
        label: Modality tag used in the printed messages, e.g. "TRX".

    Returns:
        Number of offending clients; 0 when *ids_df* is empty (check skipped).
    """
    if ids_df.empty:
        print(f"[WARN] {label} source empty; skip leakage check.")
        return 0
    # Computed once: the detail tables can be large and this groupby is the costly step.
    folds_per_client = ids_df.groupby("client_id")["fold"].nunique()
    offenders = folds_per_client[folds_per_client > 1].index.tolist()
    print(f"{label} clients in >1 fold:", len(offenders))
    if offenders:
        print(f"Sample {label} offenders:", offenders[:20])
    return len(offenders)

num_trx_multi = _report_multi_fold(trx_ids_df, "TRX")
num_geo_multi = _report_multi_fold(geo_ids_df, "GEO")

# ====== 7) Class distribution (raw + balanced subset) ======
print("\n=== Class distributions (raw targets vs balanced subset) ===")
if targets.empty or not tcols:
    print("(skip) No targets found or no target_* columns.")
else:
    print("Raw targets:")
    for _col in tcols:
        print(f"  {_col}: {targets[_col].value_counts().to_dict()}")
    if bal_ids:
        # Same table, restricted to the clients of the balanced list.
        t_bal = targets[targets["client_id"].isin(bal_ids)]
        print("Balanced subset:")
        for _col in tcols:
            print(f"  {_col}: {t_bal[_col].value_counts().to_dict()}")

# ====== 8) Text length stats (balanced JSONLs) ======
print("\n=== Text length stats (balanced JSONLs) ===")
for _tag, _path in (
    ("TRX:", JSON_TRX_BAL),
    ("GEO:", JSON_GEO_BAL),
    ("MM :", JSON_MM_BAL),
):
    print(_tag, _text_length_stats(_path))

# ====== 9) Peek a couple of balanced records (trimmed) ======
print("\n=== Sample balanced records (trimmed) ===")
for _tag, _path in (
    ("TRX samples:", JSON_TRX_BAL),
    ("GEO samples:", JSON_GEO_BAL),
    ("MM  samples:", JSON_MM_BAL),
):
    print(_tag, _sample_json(_path, k=2))

=== Presence check ===
[OK] /Users/tree/Projects/recommemdation_bank/outputs/balanced/mbd_targets_balanced.parquet
[OK] /Users/tree/Projects/recommemdation_bank/outputs/balanced/targets_raw.parquet
[OK] /Users/tree/Projects/recommemdation_bank/outputs/json/trx/mbd_all.jsonl
[OK] /Users/tree/Projects/recommemdation_bank/outputs/json/trx/json_balanced_trx.jsonl
[OK] /Users/tree/Projects/recommemdation_bank/outputs/json/geo/mbd_all.jsonl
[OK] /Users/tree/Projects/recommemdation_bank/outputs/json/geo/json_balanced_geo.jsonl
[OK] /Users/tree/Projects/recommemdation_bank/outputs/json/mm/mbd_all.jsonl
[OK] /Users/tree/Projects/recommemdation_bank/outputs/json/mm/json_balanced_mm.jsonl

=== Balanced IDs & Targets ===
Balanced clients: 2132
Targets shape: (100224, 6) | target cols: ['target_1', 'target_2', 'target_3', 'target_4']

=== JSONL ID coverage ===
#TRX all: 98721 | balanced: 2118
#GEO all: 72573 | balanced: 1623
#MM  all: 99647 | balanced: 2127

=== Union(TRX,GEO) vs MM (balanced) ===


In [7]:
# Build labels aligned to json_balanced_mm.jsonl and attach real folds from TRX/GEO

import re, glob, json
import pandas as pd

# ====== CONFIG (match 01–04) ======
BASE_OUT = "/Users/tree/Projects/recommemdation_bank/outputs"
# Inputs: balanced multimodal JSONL and the raw targets table.
JSON_MM_BAL   = "/".join((BASE_OUT, "json", "mm", "json_balanced_mm.jsonl"))
TARGETS_RAW   = "/".join((BASE_OUT, "balanced", "targets_raw.parquet"))
# Output written by the last step of this cell.
OUT_LABELS    = "/".join((BASE_OUT, "balanced", "labels_mm_folded.parquet"))

# Detail parquet parts; the fold number is taken from the `fold=<n>` directory.
TRX_GLOB = "/Users/tree/Projects/recommemdation_bank/data/mbd_mini/detail/trx/fold=*/part-*.parquet"
GEO_GLOB = "/Users/tree/Projects/recommemdation_bank/data/mbd_mini/detail/geo/fold=*/part-*.parquet"
FOLDS    = list(range(5))

def load_mm_ids(path):
    """Return the set of ``client_id`` values (as strings) found in a JSONL file.

    Args:
        path: JSONL file with one JSON object per line. Must exist.

    Returns:
        Set of non-empty client ids. Blank lines and records whose
        ``client_id`` is absent, ``null`` or empty are ignored.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON. This is
            left strict on purpose, since labels are built from these ids.
    """
    ids = set()
    with open(path, encoding="utf-8") as f:
        for line in f:
            # A trailing or stray blank line is not a record and must not crash the load.
            if not line.strip():
                continue
            rec = json.loads(line)
            raw = rec.get("client_id")
            # A JSON null must not turn into the literal id "None".
            if raw is None:
                continue
            cid = str(raw)
            if cid:
                ids.add(cid)
    return ids

def read_fold_map(glob_pat):
    """Map each client to the fold in which most of its records live.

    The fold of a file is taken from the ``fold=<n>`` component of its path
    (-1 when the path has none). Records are counted per (client, fold) one
    file at a time, so only the per-file counts are held in memory rather
    than every detail row.

    Args:
        glob_pat: Glob pattern of the parquet parts to scan.

    Returns:
        DataFrame with columns "client_id" (str) and "fold" (int), one row per
        client, sorted by client id. When two folds are equally common the
        lowest fold number is chosen. Empty frame with the same columns when
        no file matches.
    """
    per_file_counts = []
    for p in sorted(glob.glob(glob_pat)):
        # infer fold from path
        m = re.search(r"fold=(\d+)", p)
        fold = int(m.group(1)) if m else -1
        # read only client_id; fall back to a full read if the restricted read fails
        try:
            ids = pd.read_parquet(p, columns=["client_id"])["client_id"]
        except Exception:
            ids = pd.read_parquet(p)["client_id"]
        counts = (ids.astype(str).value_counts()
                     .rename_axis("client_id")
                     .reset_index(name="n"))
        counts["fold"] = fold
        per_file_counts.append(counts)
    if not per_file_counts:
        return pd.DataFrame(columns=["client_id","fold"])
    votes = (pd.concat(per_file_counts, ignore_index=True)
               .groupby(["client_id","fold"], as_index=False)["n"].sum())
    # Most records first; "fold" as last key makes the tie-break explicit.
    votes = votes.sort_values(["client_id","n","fold"], ascending=[True,False,True])
    return votes.drop_duplicates("client_id")[["client_id","fold"]]

# 1) ids that actually have MM text
mm_ids = load_mm_ids(JSON_MM_BAL)
print("MM balanced ids:", len(mm_ids))

# 2) base targets (drop to per-client unique)
t = pd.read_parquet(TARGETS_RAW)
t["client_id"] = t["client_id"].astype(str)
target_cols = [c for c in t.columns if c.startswith("target_")]
t = t[["client_id"] + target_cols].drop_duplicates("client_id")

# keep only those in MM
labels = t[t["client_id"].isin(mm_ids)].copy()
print("Labels (no fold yet):", len(labels))

# 3) fold map from TRX + GEO
# Each modality contributes at most one fold per client, so this is a
# two-voter agreement check rather than a real majority vote.
trx_map = read_fold_map(TRX_GLOB)
geo_map = read_fold_map(GEO_GLOB)
fold_map = pd.concat([trx_map, geo_map], ignore_index=True)
if not fold_map.empty:
    # Surface clients whose TRX and GEO folds disagree instead of resolving them silently.
    n_conflicts = int((fold_map.groupby("client_id")["fold"].nunique() > 1).sum())
    if n_conflicts:
        print(f"[WARN] {n_conflicts} client(s) have different folds in TRX and GEO; "
              "keeping the lowest fold.")
    # "fold" as the last sort key makes the tie-break (lowest fold) explicit.
    fold_map = (fold_map.groupby(["client_id","fold"]).size()
                        .reset_index(name="n")
                        .sort_values(["client_id","n","fold"], ascending=[True,False,True])
                        .drop_duplicates("client_id")[["client_id","fold"]])

# 4) attach fold; filter to valid folds
labels = labels.merge(fold_map, on="client_id", how="left")
missing = labels["fold"].isna().sum()
print("Missing fold after merge:", int(missing))

labels = labels[labels["fold"].isin(FOLDS)].copy()
# The left merge turns "fold" into float/object whenever a client has no fold;
# after filtering every value is a valid fold, so store it as int consistently.
labels["fold"] = labels["fold"].astype(int)
print("Final label rows:", len(labels))
print("Fold counts:", labels["fold"].value_counts().sort_index().to_dict())
print("Targets:", target_cols)

# 5) save
labels.to_parquet(OUT_LABELS, index=False)
print("Wrote:", OUT_LABELS)

MM balanced ids: 2127
Labels (no fold yet): 2127
Missing fold after merge: 0
Final label rows: 2127
Fold counts: {0: 425, 1: 423, 2: 419, 3: 446, 4: 414}
Targets: ['target_1', 'target_2', 'target_3', 'target_4']
Wrote: /Users/tree/Projects/recommemdation_bank/outputs/balanced/labels_mm_folded.parquet
