
# 03 — TRX+GEO (Concat) → JSONL + Balanced Filter

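- Reads both the TRX and GEO parquet partitions and builds, per client, a `<TRX>...</TRX>` block followed by a `<GEO>...</GEO>` block as one concatenated text.
- Outputs per-fold JSONL (`mbd_fold_<k>.jsonl`), a combined `mbd_all.jsonl`, and `json_balanced_mm.jsonl` (the combined records restricted to the client ids in the balanced targets file).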


In [1]:
# ====== CONFIG (keep in sync with 00/01/02) ======
import os, glob, json, re, math
import pandas as pd
import numpy as np
from datetime import datetime

# Input partitions are laid out as .../fold=<k>/part-*.parquet; the readers
# below parse <k> from the path when a file has no "fold" column.
TRX_GLOB = "/Users/tree/Projects/recommemdation_bank/data/mbd_mini/detail/trx/fold=*/part-*.parquet"
GEO_GLOB = "/Users/tree/Projects/recommemdation_bank/data/mbd_mini/detail/geo/fold=*/part-*.parquet"
FOLDS = [0,1,2,3,4]  # folds to emit; any other discovered fold value is skipped
BASE_OUT = "/Users/tree/Projects/recommemdation_bank/outputs"
# Balanced target set; only its "client_id" column is read here.
BALANCED_PATH = f"{BASE_OUT}/balanced/mbd_targets_balanced.parquet"

# All JSONL outputs of this notebook are written under OUT_DIR (created if missing).
OUT_DIR = f"{BASE_OUT}/json/mm"
os.makedirs(OUT_DIR, exist_ok=True)

# ====== HELPERS (same behavior as 01/02) ======
def pretty_date(ts):
    """Format a timestamp-like value as a ``YYYY-MM-DD`` string.

    Numeric inputs (Python or NumPy ints/floats, including float32) are
    treated as Unix epoch values in UTC, with the unit guessed from the
    magnitude:

    * ``|ts| > 1e17`` -> nanoseconds
    * ``|ts| > 1e14`` -> microseconds
    * ``|ts| > 1e10`` -> milliseconds (1e10..1e12 is ambiguous; ms is assumed)
    * otherwise       -> seconds

    ``datetime``/``pd.Timestamp`` objects are formatted directly; anything
    else (e.g. strings, ``np.datetime64``) goes through ``pd.to_datetime``.

    Returns ``""`` for missing values (None/NaN/NaT), unparseable input and
    epochs outside the representable ``datetime`` range.
    """
    from datetime import timedelta  # file level imports only the datetime class

    if pd.isna(ts):
        return ""
    if isinstance(ts, (np.integer, np.floating, int, float)):
        seconds = float(ts)
        magnitude = abs(seconds)
        if magnitude > 1e17:      # ns
            seconds /= 1e9
        elif magnitude > 1e14:    # us
            seconds /= 1e6
        elif magnitude > 1e10:    # ms (also the ambiguous 1e10..1e12 range)
            seconds /= 1e3
        try:
            # epoch + offset: avoids the deprecated utcfromtimestamp() and the
            # platform-specific failures of fromtimestamp() on negative input
            dt = datetime(1970, 1, 1) + timedelta(seconds=seconds)
        except OverflowError:
            return ""
        return dt.strftime("%Y-%m-%d")
    if isinstance(ts, datetime):  # includes pd.Timestamp
        return ts.strftime("%Y-%m-%d")
    try:
        return pd.to_datetime(ts).strftime("%Y-%m-%d")
    except Exception:
        # best-effort: anything pandas cannot parse renders as an empty date
        return ""

def read_trx_with_inferred_fold(glob_pattern):
    """Load and concatenate every TRX parquet file matching *glob_pattern*.

    Files are read in sorted path order. A frame without a ``fold`` column
    gets one parsed from the ``fold=<k>`` path segment (``-1`` when the path
    has none), and ``client_id`` (if present) is cast to ``str``.

    Raises:
        FileNotFoundError: if the pattern matches no files.
    """
    paths = sorted(glob.glob(glob_pattern))
    if len(paths) == 0:
        raise FileNotFoundError(f"No TRX parquet matched: {glob_pattern}")

    def _load_one(path):
        frame = pd.read_parquet(path)
        if "fold" not in frame.columns:
            match = re.search(r"fold=(\d+)", path)
            frame["fold"] = -1 if match is None else int(match.group(1))
        if "client_id" in frame.columns:
            frame["client_id"] = frame["client_id"].astype(str)
        return frame

    return pd.concat([_load_one(path) for path in paths], ignore_index=True)

def read_geo_with_inferred_fold(glob_pattern):
    """Load and concatenate every GEO parquet file matching *glob_pattern*.

    Each file is first read restricted to client_id / event_time /
    geohash_4..6; if that read fails for any reason, the whole file is read
    instead. A frame without a ``fold`` column gets one parsed from the
    ``fold=<k>`` path segment (``-1`` when absent), and ``client_id`` (if
    present) is cast to ``str``.

    Raises:
        FileNotFoundError: if the pattern matches no files.
    """
    paths = sorted(glob.glob(glob_pattern))
    if len(paths) == 0:
        raise FileNotFoundError(f"No GEO parquet matched: {glob_pattern}")

    def _load_one(path):
        try:
            frame = pd.read_parquet(
                path,
                columns=["client_id", "event_time", "geohash_4", "geohash_5", "geohash_6"],
            )
        except Exception:
            # column subset unavailable -> fall back to reading everything
            frame = pd.read_parquet(path)
        if "fold" not in frame.columns:
            match = re.search(r"fold=(\d+)", path)
            frame["fold"] = -1 if match is None else int(match.group(1))
        if "client_id" in frame.columns:
            frame["client_id"] = frame["client_id"].astype(str)
        return frame

    return pd.concat([_load_one(path) for path in paths], ignore_index=True)

def trx_to_text(df, max_rows=256):
    """Render one client's transactions as a ``<TRX>`` ... ``</TRX>`` text block.

    Args:
        df: rows of a single client. Whichever of ``event_time``,
            ``event_type``, ``amount`` and ``src_type32`` are present are
            used; all other columns are ignored. ``df`` is not modified.
        max_rows: keep only the last *max_rows* events (after sorting by
            ``event_time`` when that column exists). Defaults to 256.

    Returns:
        ``"<TRX>"``, then one line per kept event, then ``"</TRX>"``, joined
        by newlines. Each line is ``"<date> t<event_type> a<log10 amount>
        s<src_type32>"``:

        * the date (via ``pretty_date``) is always emitted when
          ``event_time`` exists, even if it renders as ``""``;
        * ``t`` is always emitted when ``event_type`` exists (bare ``t`` for
          a missing type);
        * ``a`` is omitted for missing or non-positive amounts;
        * ``s`` is omitted for a missing source type.
    """
    cols = [c for c in ["event_time", "event_type", "amount", "src_type32"] if c in df.columns]
    d = df[cols]
    if "event_time" in d.columns:
        # stable sort: events with identical timestamps keep their input
        # order, so both the tail() cut and the rendered order are well-defined
        d = d.sort_values("event_time", kind="stable")
    d = d.tail(max_rows)

    # One token list per field, aligned with the rows; None = omit the token.
    fields = []
    if "event_time" in d.columns:
        fields.append([str(x) for x in d["event_time"].apply(pretty_date)])
    if "event_type" in d.columns:
        fields.append([f"t{int(v)}" if pd.notna(v) else "t" for v in d["event_type"]])
    if "amount" in d.columns:
        amounts = pd.to_numeric(d["amount"], errors="coerce")
        fields.append(
            [None if pd.isna(x) or x <= 0 else f"a{math.log10(x):.2f}" for x in amounts]
        )
    if "src_type32" in d.columns:
        fields.append([f"s{int(v)}" if pd.notna(v) else None for v in d["src_type32"]])

    # Column-wise assembly instead of DataFrame.iterrows(), which builds a
    # Series per row and is slow when this runs once per client.
    if fields:
        rows = [" ".join(tok for tok in rec if tok is not None) for rec in zip(*fields)]
    else:
        rows = [""] * len(d)  # no usable columns: one empty line per row
    body = "\n".join(rows)
    return f"<TRX>\n{body}\n</TRX>"

def _safe_tag(val, prefix):
    if pd.isna(val):
        return ""
    try:
        return f"{prefix}{int(val)}"
    except Exception:
        return f"{prefix}{str(val)}"

def dedupe_consecutive(df, cols):
    """
    Drop rows whose *cols* values equal those of the immediately preceding row.

    Operates on THIS frame in its current order (no grouping); geo_to_text()
    calls it per client after sorting, so grouping is unnecessary.
    Missing values compare equal to each other, so repeated rows with a NaN
    geohash collapse too. The first row is always kept. Returns ``df``
    unchanged when it is empty or *cols* is empty.
    """
    if df.empty or not cols:
        return df
    cur = df[cols]
    prev = cur.shift()
    # NaN != NaN under ==, so treat "both missing" as equal explicitly;
    # fillna(False) neutralises pd.NA produced by nullable dtypes.
    both_missing = cur.isna() & prev.isna()
    equal = cur.eq(prev).fillna(False).astype(bool)
    keep = ~(equal | both_missing).all(axis=1)
    keep.iloc[0] = True  # first row has no predecessor
    return df.loc[keep.to_numpy()]

def geo_to_text(df, max_rows=64):
    """Render one client's geo events as a ``<GEO>`` ... ``</GEO>`` text block.

    Args:
        df: rows of a single client. Whichever of ``event_time``,
            ``geohash_4``, ``geohash_5`` and ``geohash_6`` are present are
            used; ``df`` is not modified.
        max_rows: keep only the last *max_rows* rows. Defaults to 64. The cap
            is applied after sorting by ``event_time`` (when present) and
            after collapsing consecutive rows with identical geohash columns
            (``dedupe_consecutive``).

    Returns:
        ``"<GEO>"``, then one ``"<date> g4<..> g5<..> g6<..>"`` line per kept
        row, then ``"</GEO>"``, joined by newlines.

        * The date (via ``pretty_date``) is always emitted when
          ``event_time`` exists.
        * A geohash token is omitted when its value is missing
          (``_safe_tag`` returns ``""``).
    """
    cols = [c for c in ["event_time", "geohash_4", "geohash_5", "geohash_6"] if c in df.columns]
    d = df[cols]

    if "event_time" in d.columns:
        # stable sort keeps input order among identical timestamps
        d = d.sort_values("event_time", kind="stable")

    # (column, token prefix) pairs for the geohash columns actually present
    geo_specs = [
        (col, prefix)
        for col, prefix in (("geohash_4", "g4"), ("geohash_5", "g5"), ("geohash_6", "g6"))
        if col in d.columns
    ]
    d = dedupe_consecutive(d, [col for col, _ in geo_specs])
    d = d.tail(max_rows)

    # One token list per field, aligned with the rows; None = omit the token.
    fields = []
    if "event_time" in d.columns:
        fields.append([str(x) for x in d["event_time"].apply(pretty_date)])
    for col, prefix in geo_specs:
        # prefix is non-empty, so _safe_tag yields "" only for missing values
        fields.append([_safe_tag(v, prefix) or None for v in d[col]])

    # Column-wise assembly instead of the slow per-row DataFrame.iterrows().
    if fields:
        rows = [" ".join(tok for tok in rec if tok is not None) for rec in zip(*fields)]
    else:
        rows = [""] * len(d)  # no usable columns: one empty line per row
    body = "\n".join(rows)
    return f"<GEO>\n{body}\n</GEO>"

# ====== BUILDERS ======
def _write_jsonl(path, records):
    """Write dict *records* to *path* as UTF-8 JSON Lines (one object per line)."""
    # explicit UTF-8: ensure_ascii=False may emit non-ASCII characters that the
    # platform default encoding is not guaranteed to support
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


def _texts_by_client(frame, to_text):
    """Map ``str(client_id)`` -> ``to_text(<that client's rows>)`` for *frame*."""
    return {str(cid): to_text(g) for cid, g in frame.groupby("client_id")}


def write_jsonl_per_fold():
    """Build per-client TRX+GEO text and write it as JSONL.

    Reads TRX_GLOB and GEO_GLOB, then prints the discovered folds and each
    source's overlap with the balanced client set.

    For every discovered fold that is listed in FOLDS, writes
    ``{OUT_DIR}/mbd_fold_{fold}.jsonl`` with one
    ``{"client_id": ..., "text": "<TRX>...</TRX>\\n<GEO>...</GEO>"}`` record
    per client seen in either source. An empty block stands in for a missing
    source.

    All records are also written, in fold order, to
    ``{OUT_DIR}/mbd_all.jsonl``.
    """
    trx = read_trx_with_inferred_fold(TRX_GLOB)
    geo = read_geo_with_inferred_fold(GEO_GLOB)

    # debug coverage
    print("Discovered folds (TRX):", sorted(trx["fold"].unique()))
    print("Discovered folds (GEO):", sorted(geo["fold"].unique()))
    bal = pd.read_parquet(BALANCED_PATH)
    bal_ids = set(bal["client_id"].astype(str))
    trx_ids = set(trx["client_id"].astype(str))
    geo_ids = set(geo["client_id"].astype(str))
    print("Balanced∩TRX:", len(bal_ids & trx_ids))
    print("Balanced∩GEO:", len(bal_ids & geo_ids))

    empty_trx = "<TRX>\n</TRX>"
    empty_geo = "<GEO>\n</GEO>"
    out_all = []

    for fold in sorted(set(trx["fold"].unique()) | set(geo["fold"].unique())):
        if fold not in FOLDS:
            continue
        # the text builders do not mutate their input, so no .copy() is needed
        tr_map = _texts_by_client(trx[trx["fold"] == fold], trx_to_text)
        ge_map = _texts_by_client(geo[geo["fold"] == fold], geo_to_text)

        records = [
            {
                "client_id": cid,
                "text": f"{tr_map.get(cid, empty_trx)}\n{ge_map.get(cid, empty_geo)}",
            }
            for cid in sorted(tr_map.keys() | ge_map.keys())
        ]

        out_path = f"{OUT_DIR}/mbd_fold_{fold}.jsonl"
        _write_jsonl(out_path, records)
        print("Wrote", out_path, len(records))
        out_all.extend(records)

    # combined
    out_path_all = f"{OUT_DIR}/mbd_all.jsonl"
    _write_jsonl(out_path_all, out_all)
    print("Wrote", out_path_all, len(out_all))

def filter_by_balanced():
    """Keep only the combined records whose client_id is in the balanced set.

    Client ids are read from BALANCED_PATH and compared as strings.
    ``{OUT_DIR}/mbd_all.jsonl`` (produced by ``write_jsonl_per_fold``) is
    streamed line by line. Matching lines are copied unchanged to
    ``{OUT_DIR}/json_balanced_mm.jsonl``, and the number kept is printed.
    """
    bal = pd.read_parquet(BALANCED_PATH)
    ids = set(bal["client_id"].astype(str))
    src = f"{OUT_DIR}/mbd_all.jsonl"
    dst = f"{OUT_DIR}/json_balanced_mm.jsonl"
    kept = 0
    # explicit UTF-8 on both sides: the source is written with ensure_ascii=False
    with open(src, "r", encoding="utf-8") as fin, open(dst, "w", encoding="utf-8") as fout:
        for line in fin:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            if json.loads(line)["client_id"] in ids:
                fout.write(line)
                kept += 1
    print("Balanced MM json written:", dst, "rows:", kept)

# ====== RUN ======
# Order matters: filter_by_balanced() reads {OUT_DIR}/mbd_all.jsonl, which is
# written by write_jsonl_per_fold().
write_jsonl_per_fold()
filter_by_balanced()

Discovered folds (TRX): [np.int64(0), np.int64(1), np.int64(2), np.int64(3), np.int64(4)]
Discovered folds (GEO): [np.int64(0), np.int64(1), np.int64(2), np.int64(3), np.int64(4)]
Balanced∩TRX: 2118
Balanced∩GEO: 1623
Wrote /Users/tree/Projects/recommemdation_bank/outputs/json/mm/mbd_fold_0.jsonl 20217
Wrote /Users/tree/Projects/recommemdation_bank/outputs/json/mm/mbd_fold_1.jsonl 19783
Wrote /Users/tree/Projects/recommemdation_bank/outputs/json/mm/mbd_fold_2.jsonl 19807
Wrote /Users/tree/Projects/recommemdation_bank/outputs/json/mm/mbd_fold_3.jsonl 19885
Wrote /Users/tree/Projects/recommemdation_bank/outputs/json/mm/mbd_fold_4.jsonl 19955
Wrote /Users/tree/Projects/recommemdation_bank/outputs/json/mm/mbd_all.jsonl 99647
Balanced MM json written: /Users/tree/Projects/recommemdation_bank/outputs/json/mm/json_balanced_mm.jsonl rows: 2127
