In [1]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components
from sklearn.model_selection import train_test_split


In [2]:
# Load the raw SAML-D transaction table (path is relative to the working directory).
data = pd.read_csv(filepath_or_buffer="SAML-D.csv")

def _label_components(df, sender_col, receiver_col):
    """Assign a connected-component id to every account and every transaction.

    Accounts are graph nodes and each row of ``df`` is an undirected edge
    between its sender and receiver.

    Returns
    -------
    n_components : int
        Number of connected components.
    account_comp : np.ndarray
        Component id of each distinct account (in ``pd.factorize`` order).
    tx_comp : np.ndarray
        Component id of each row of ``df``. A transaction's sender and receiver
        are joined by that very transaction, so one id per row is enough.

    Raises
    ------
    ValueError
        If either account column contains missing values.
    """
    n_rows = len(df)
    endpoints = pd.concat([df[sender_col], df[receiver_col]], ignore_index=True)
    codes, accounts = pd.factorize(endpoints)
    if (codes < 0).any():  # factorize marks missing values with -1
        raise ValueError(
            f"columns {sender_col!r}/{receiver_col!r} contain missing account ids"
        )
    senders, receivers = codes[:n_rows], codes[n_rows:]
    n_accounts = len(accounts)
    # Duplicate edges are summed when scipy converts to CSR; int32 weights are
    # wide enough that the sum can never wrap round to 0 and drop an edge.
    adjacency = coo_matrix(
        (np.ones(n_rows, dtype=np.int32), (senders, receivers)),
        shape=(n_accounts, n_accounts),
    )
    n_components, account_comp = connected_components(adjacency, directed=False)
    return n_components, account_comp, account_comp[senders]


def _greedy_assign(n_tx, n_pos, train_frac, rng):
    """Return a boolean array marking which components go to the train split.

    Components are visited largest-first (random order among equal sizes), and
    each is placed on the side that keeps train's share of the transactions
    *and* of the positives assigned so far closest to ``train_frac``.
    """
    total_tx = n_tx.sum()
    total_pos = n_pos.sum()
    # With no positives the class term is identically zero; avoid dividing by 0.
    pos_scale = total_pos if total_pos > 0 else 1.0

    # Shuffle first so that the stable size sort breaks ties randomly.
    shuffled = rng.permutation(len(n_tx))
    order = shuffled[np.argsort(-n_tx[shuffled], kind="stable")]

    in_train = np.zeros(len(n_tx), dtype=bool)
    train_tx = train_pos = seen_tx = seen_pos = 0.0
    for comp in order:
        n, k = float(n_tx[comp]), float(n_pos[comp])
        seen_tx += n
        seen_pos += k
        # Signed deviation of train's share from target if `comp` goes to test ...
        dev_tx = train_tx - train_frac * seen_tx
        dev_pos = train_pos - train_frac * seen_pos
        cost_test = (dev_tx / total_tx) ** 2 + (dev_pos / pos_scale) ** 2
        # ... and if it goes to train instead. Both terms are relative errors,
        # so rare positives weigh as much as the (much larger) transaction count.
        cost_train = ((dev_tx + n) / total_tx) ** 2 + ((dev_pos + k) / pos_scale) ** 2
        if cost_train <= cost_test:
            in_train[comp] = True
            train_tx += n
            train_pos += k
    return in_train


def split_by_components_stratified(
    df,
    sender_col="Sender_account",
    receiver_col="Receiver_account",
    label_col="is_laundering",
    train_frac=0.8,
    random_state=None,
):
    """Split transactions into train/test sets that share no accounts.

    Accounts are nodes and transactions are undirected edges. Each connected
    component is assigned wholesale to one side, so no account appears in both
    splits. Components are assigned greedily, largest first, so that both

      - the transaction ratio is approximately ``train_frac : 1 - train_frac``, and
      - the positive (label) count is split in the same ratio, i.e. both sides
        keep roughly the overall positive rate.

    How closely these targets can be met is limited by component sizes. A
    single component holding more than ``train_frac`` of the data cannot be
    split.

    Parameters
    ----------
    df : pd.DataFrame
        One row per transaction.
    sender_col, receiver_col : str
        Columns holding the two account ids of each transaction.
    label_col : str
        Label column. Its values are summed as the positive count, so 0/1 (or
        bool) labels are expected. NOTE: the call in this notebook passes
        ``"Is_laundering"``; the lowercase default is kept for compatibility.
    train_frac : float
        Target fraction of transactions in the train split, strictly in (0, 1).
    random_state : int, np.random.Generator or None
        Seed for ``np.random.default_rng``. It is only used to break ties
        between equal-sized components.

    Returns
    -------
    train_df, test_df : pd.DataFrame
        Row subsets of ``df`` with the original index preserved.
    info : dict
        Split sizes, per-split and overall positive rates, the number of
        components, and the account count of the largest component.

    Raises
    ------
    ValueError
        If ``train_frac`` is outside (0, 1), ``df`` is empty, or an account id
        is missing.
    """
    if not 0.0 < train_frac < 1.0:
        raise ValueError(f"train_frac must be in (0, 1), got {train_frac!r}")
    if df.empty:
        raise ValueError("df contains no transactions to split")

    rng = np.random.default_rng(random_state)

    # 1. Connected components of the account graph (vectorised, no per-component scans).
    n_components, account_comp, tx_comp = _label_components(df, sender_col, receiver_col)

    # 2. Per-component transaction and positive counts in a single pass.
    labels = df[label_col].to_numpy(dtype=float)
    n_tx = np.bincount(tx_comp, minlength=n_components)
    n_pos = np.bincount(tx_comp, weights=labels, minlength=n_components)
    overall_pos_rate = n_pos.sum() / n_tx.sum()

    # 3. Assign whole components to train/test, balancing size and class ratio.
    in_train = _greedy_assign(n_tx, n_pos, train_frac, rng)

    # 4. Every row belongs to exactly one component, hence to exactly one split.
    train_mask = in_train[tx_comp]
    train_df = df[train_mask]
    test_df = df[~train_mask]

    # 5. Summaries.
    info = {
        "n_train": len(train_df),
        "n_test": len(test_df),
        "train_frac_actual": len(train_df) / (len(train_df) + len(test_df)),
        "train_pos_rate": train_df[label_col].mean(),
        "test_pos_rate": test_df[label_col].mean(),
        "num_components": n_components,
        "largest_component": int(np.bincount(account_comp).max()),
        "overall_pos_rate": overall_pos_rate,
    }
    return train_df, test_df, info


In [None]:
# Split parameters gathered in one place so they are easy to tweak.
# label_col must be passed explicitly: the function's default is "is_laundering".
split_params = {
    "random_state": 42,
    "train_frac": 0.8,
    "label_col": "Is_laundering",
    "receiver_col": "Receiver_account",
    "sender_col": "Sender_account",
}
train_df, test_df, info = split_by_components_stratified(data, **split_params)

print(info)