# In [None]:
import polars as pl                   # 高效能 DataFrame 套件，類似 pandas，但速度更快
import torch                          # 深度學習主力套件
import torch.nn as nn                 # torch 神經網路子模組
import torch.nn.functional as F       # torch 常用函式
import networkx as nx                 # 圖論處理/計算中心性用
import numpy as np                    # 數值運算
from torch_geometric.nn import GATConv    # 圖神經網路 attention layer
from torch_geometric.data import Data     # PyG 資料格式
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score, roc_curve
from collections import defaultdict
import matplotlib.pyplot as plt
import random

# === Parameters: sliding window, GAT model, contrastive learning ===
file_path = '/content/SAML-D.parquet'    # data file location (path after uploading to Colab)

# Sliding window over the time-sorted transactions
WINDOW_SIZE = 100_000     # transactions per window
STRIDE = 20_000           # rows the window advances per step
epochs_per_window = 20    # training epochs run on each window

# Model dimensions
EMBED_DIM = 16            # size of each group embedding vector
GAT_HIDDEN = 64           # hidden width of the GAT layers

# Triplet contrastive learning
TRIPLET_SAMPLES = 64      # triplets sampled per window per epoch
TRIPLET_MARGIN = 1.0      # margin of the triplet loss
ALPHA = 0.1               # weight of the contrastive loss in the total loss

# === Union-Find structure for dynamic account grouping ===
class UnionFind:
    """Disjoint-set forest over hashable items (here: account ids).

    Items are registered lazily the first time they are looked up. ``union(x, y)``
    always attaches the root of ``y`` under the root of ``x``, so after a merge the
    group is identified by ``x``'s previous root.
    """

    def __init__(self):
        self.parent = {}   # item -> parent item; a root is its own parent

    def find(self, x):
        """Return the root of ``x``'s set, registering ``x`` as a singleton if new.

        Iterative with full path compression. The previous recursive version could
        exceed Python's recursion limit on long parent chains, which arise because
        ``union`` does no balancing.
        """
        parent = self.parent
        if x not in parent:
            parent[x] = x
            return x
        root = x
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the path directly at the root.
        while x != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y``; return True iff they were different sets."""
        xr, yr = self.find(x), self.find(y)
        merged = (xr != yr)                  # only distinct roots are actually merged
        if merged:
            self.parent[yr] = xr             # y's root now hangs under x's root
        return merged                        # tells the caller whether this transaction merged two groups

# === Group embedding table: one embedding vector per group id ===
# NOTE(review): this table is module-global and never cleared, while the
# Union-Find groups are rebuilt per window; group ids are root account ids, so a
# later window can reuse a vector written in an earlier one — confirm intended.
group_embeddings = {}


def get_group_embedding(gid):
    """Return the embedding stored for ``gid``, creating it on first access.

    A missing entry is initialised with ``np.random.normal(size=EMBED_DIM)`` and
    stored, so later calls return that same array object.
    """
    try:
        return group_embeddings[gid]
    except KeyError:
        fresh = np.random.normal(size=EMBED_DIM)
        group_embeddings[gid] = fresh
        return fresh


def set_group_embedding(gid, emb):
    """Store ``emb`` as the embedding of group ``gid``, replacing any previous value."""
    group_embeddings.update({gid: emb})

# === UFGAT: two-layer GAT encoder + edge-level MLP classifier ===
class UFGAT(nn.Module):
    """Score every edge (transaction) from its endpoint embeddings and edge features.

    Two single-head ``GATConv`` layers with ReLU produce node embeddings. For each
    edge, the source and target embeddings are concatenated with the edge's feature
    vector and fed to an MLP that outputs one logit (binary edge classification).

    Args:
        in_node: number of input node features.
        in_edge: number of edge features.
        hidden: width of both GAT layers.
    """

    def __init__(self, in_node, in_edge, hidden=64):
        super().__init__()
        self.gat1 = GATConv(in_node, hidden, heads=1)       # first GAT layer
        self.gat2 = GATConv(hidden, hidden, heads=1)        # second GAT layer
        self.mlp = nn.Sequential(
            nn.Linear(hidden * 2 + in_edge, 64),            # [h_src | h_dst | edge features]
            nn.ReLU(),
            nn.Linear(64, 1),                               # one logit per edge
        )

    def forward(self, x, edge_index, edge_attr):
        """Return ``(edge_logits, node_embeddings)``.

        ``edge_logits`` has one entry per column of ``edge_index``. ``node_embeddings``
        is the output of the second GAT layer, used by the contrastive loss.
        """
        h = F.relu(self.gat1(x, edge_index))
        h = F.relu(self.gat2(h, edge_index))
        src, dst = edge_index[0], edge_index[1]             # endpoints of every edge
        edge_input = torch.cat([h[src], h[dst], edge_attr], dim=-1)
        # squeeze(-1) instead of squeeze(): with a single edge, squeeze() would
        # collapse (1, 1) to a 0-d tensor that no longer matches a length-1 target.
        return self.mlp(edge_input).squeeze(-1), h

# === Triplet contrastive loss on node embeddings ===
def triplet_contrastive_loss(node_emb, triplet_idx, margin=1.0):
    """Mean triplet margin loss over (anchor, positive, negative) node indices.

    Per triplet: ``max(0, ||a - p||_2 - ||a - n||_2 + margin)``, i.e. positives are
    pushed to be closer to the anchor than negatives by at least ``margin``.
    Returns a zero scalar tensor on ``node_emb``'s device when there are no triplets.
    """
    if not len(triplet_idx):
        return torch.tensor(0.0, device=node_emb.device)
    anchors, positives, negatives = (node_emb[list(column)] for column in zip(*triplet_idx))
    d_pos = (anchors - positives).norm(p=2, dim=-1)    # anchor-to-positive distance
    d_neg = (anchors - negatives).norm(p=2, dim=-1)    # anchor-to-negative distance
    return torch.relu(d_pos - d_neg + margin).mean()

# === Scan every account once and map it to a dense node index ===
print("掃描帳戶全集...")
df = pl.read_parquet(file_path).sort(["Date", "Time"])         # load and sort chronologically
# Collect accounts column-wise. iter_rows(named=True) built a Python dict per row,
# which is very slow on the full transaction table; this yields the same set.
account_set = set(df.get_column("Sender_account").to_list())
account_set.update(df.get_column("Receiver_account").to_list())
node_idx_map = {acc: i for i, acc in enumerate(sorted(account_set))}   # account -> integer node id
node_count = len(node_idx_map)
print(f"Total nodes (accounts): {node_count}")

# === Sliding-window training loop ===
# Windows hold WINDOW_SIZE consecutive (time-sorted) rows and advance STRIDE rows at a
# time. Union-Find groups and features are rebuilt per window. One UFGAT model and
# optimizer are created on the first window and keep training on later windows.
# NOTE(review): each window is evaluated on the same rows it was just trained on,
# so the reported metrics are in-sample.
total_samples = df.height
all_pred, all_y_true = [], []
model, optimizer, loss_fn = None, None, None
global_epoch = 1


def get_triplet_samples(merge_edges, nonmerge_edges, k=TRIPLET_SAMPLES):
    """Build up to ``k`` (anchor, positive, negative) node-index triplets.

    Anchor and positive are the endpoints of an edge that merged two groups. The
    negative is the sender of an edge that did not merge groups. Both lists are
    sampled with replacement. Returns ``[]`` if either list is empty.
    """
    if not merge_edges or not nonmerge_edges:
        return []
    sample_merges = random.choices(merge_edges, k=min(k, len(merge_edges)))
    sample_nonmerges = random.choices(nonmerge_edges, k=min(k, len(nonmerge_edges)))
    return [(a, p, n1) for (a, p), (n1, _) in zip(sample_merges, sample_nonmerges)]


for start_idx in range(0, total_samples - WINDOW_SIZE + 1, STRIDE):
    end_idx = start_idx + WINDOW_SIZE
    win_no = start_idx // STRIDE + 1
    print(f"\n=== 處理第 {start_idx+1} ~ {end_idx} 筆資料（window size: {WINDOW_SIZE}, stride: {STRIDE}） ===")
    window_df = df[start_idx:end_idx]
    uf = UnionFind()                                  # groups restart in every window
    group_graphs = defaultdict(nx.DiGraph)            # group root -> transaction subgraph of that group
    edge_records = []
    node_features_cache = np.zeros((node_count, 4))   # only columns 0 (degree) and 1 (merge flag) are filled
    merge_edges, nonmerge_edges = [], []              # (sender_idx, receiver_idx) of merging / non-merging edges

    # === Build per-transaction edge features for this window ===
    for row in window_df.iter_rows(named=True):
        s, r = row["Sender_account"], row["Receiver_account"]
        s_idx, r_idx = node_idx_map[s], node_idx_map[r]
        amount = float(row["Amount"])
        paytype = row["Payment_type"]
        is_laundering = int(row["Is_laundering"])

        # Look both groups up BEFORE union(). Afterwards s and r always share one
        # root, which made emb_s == emb_r (a no-op average and group_sim == 0).
        root_s, root_r = uf.find(s), uf.find(r)
        emb_s, emb_r = get_group_embedding(root_s), get_group_embedding(root_r)
        group_sim = np.linalg.norm(emb_s - emb_r)     # embedding distance between the two groups
        merged = uf.union(s, r)
        root = uf.find(s)                             # surviving root (union keeps the sender's root)
        if merged:
            new_emb = (emb_s + emb_r) / 2             # fuse group embeddings by averaging
            set_group_embedding(root_s, new_emb)
            set_group_embedding(root_r, new_emb)
            merge_flag = 1
            merge_edges.append((s_idx, r_idx))
            # Fold the absorbed group's subgraph into the survivor's, so the group
            # features below describe the whole merged group, not one fragment.
            absorbed = group_graphs.pop(root_r, None)
            if absorbed is not None:
                group_graphs[root].add_edges_from(absorbed.edges(data=True))
        else:
            merge_flag = 0
            nonmerge_edges.append((s_idx, r_idx))
        G = group_graphs[root]
        G.add_edge(s, r, weight=amount)               # group subgraph keeps growing
        group_size = G.number_of_nodes()
        # Centrality only for groups with >= 5 accounts (smaller groups are noise).
        # NOTE(review): recomputed on every transaction; likely to dominate runtime
        # once groups grow large.
        if group_size >= 5:
            avg_closeness = np.mean(list(nx.closeness_centrality(G).values()))
            avg_betweenness = np.mean(list(nx.betweenness_centrality(G).values()))
        else:
            avg_closeness = 0
            avg_betweenness = 0
        # Node features: degree within the group subgraph and the latest merge flag.
        node_features_cache[s_idx, 0] = G.degree(s)
        node_features_cache[r_idx, 0] = G.degree(r)
        node_features_cache[s_idx, 1] = merge_flag
        node_features_cache[r_idx, 1] = merge_flag
        # Edge features: log amount, payment-type one-hot, centralities, group size,
        # merge flag, inter-group embedding distance.
        edge_feat = [
            np.log1p(amount),
            1 if paytype == "Cash Deposit" else 0,
            1 if paytype == "Credit card" else 0,
            1 if paytype == "Cross-border" else 0,
            1 if paytype == "Cheque" else 0,
            avg_closeness, avg_betweenness, group_size,
            merge_flag,
            group_sim,
        ]
        edge_records.append([s_idx, r_idx, edge_feat, is_laundering])
    print(f"本窗口樣本數: {len(edge_records)}，正樣本: {np.sum([rec[3] for rec in edge_records])}")

    # === Convert to PyTorch Geometric format ===
    edges = np.array([[rec[0], rec[1]] for rec in edge_records]).T
    edge_features = np.array([rec[2] for rec in edge_records])
    labels = np.array([rec[3] for rec in edge_records])
    node_features = node_features_cache
    data = Data(
        x=torch.tensor(node_features, dtype=torch.float),
        edge_index=torch.tensor(edges, dtype=torch.long),
        edge_attr=torch.tensor(edge_features, dtype=torch.float),
        y=torch.tensor(labels, dtype=torch.float),
    )

    # === Class imbalance: weight positives by the negative/positive ratio ===
    n_pos = np.sum(labels == 1)
    n_neg = np.sum(labels == 0)
    pos_weight = torch.tensor([n_neg / n_pos]) if n_pos > 0 else torch.tensor([1.0])
    print(f"正負樣本分布：正樣本={n_pos}，負樣本={n_neg}，pos_weight={pos_weight.item():.2f}")

    # === Create the model/optimizer once; reuse them for later windows ===
    if model is None:
        model = UFGAT(in_node=node_features.shape[1], in_edge=edge_features.shape[1], hidden=GAT_HIDDEN)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # === Train on this window ===
    for epoch in range(epochs_per_window):
        model.train()
        optimizer.zero_grad()
        out, node_emb = model(data.x, data.edge_index, data.edge_attr)  # GAT forward
        loss = loss_fn(out, data.y)                                    # binary classification loss
        triplet_idx = get_triplet_samples(merge_edges, nonmerge_edges, TRIPLET_SAMPLES)
        contrastive = triplet_contrastive_loss(node_emb, triplet_idx, margin=TRIPLET_MARGIN)
        total_loss = loss + ALPHA * contrastive                        # classification + contrastive
        total_loss.backward()
        optimizer.step()
        model.eval()
        with torch.no_grad():
            pred_prob = torch.sigmoid(model(data.x, data.edge_index, data.edge_attr)[0]).cpu().numpy()
            pred_label = (pred_prob > 0.5).astype(int)
            y_true_np = data.y.cpu().numpy()
            try:
                auc_val = roc_auc_score(y_true_np, pred_prob)
            except ValueError:                                         # only one class in this window
                auc_val = float('nan')
            f1_val = f1_score(y_true_np, pred_label, zero_division=0)
            precision_val = precision_score(y_true_np, pred_label, zero_division=0)
            recall_val = recall_score(y_true_np, pred_label, zero_division=0)
            acc_val = accuracy_score(y_true_np, pred_label)
        print(f"Win {win_no} | Epoch {global_epoch} - Loss: {loss.item():.4f} | Contrast: {contrastive.item():.4f} | "
              f"AUC: {auc_val:.4f} | F1: {f1_val:.4f} | Pre: {precision_val:.4f} | Rec: {recall_val:.4f} | Acc: {acc_val:.4f}")
        global_epoch += 1

    # === Evaluate this window and show its ROC curve ===
    model.eval()
    with torch.no_grad():
        pred = torch.sigmoid(model(data.x, data.edge_index, data.edge_attr)[0]).cpu().numpy()
        y_true_np = data.y.cpu().numpy()
    all_pred.append(pred)
    all_y_true.append(y_true_np)
    if np.unique(y_true_np).size < 2:
        # ROC/AUC are undefined with a single class (roc_auc_score would raise).
        print(f"Window {win_no}: only one class present, skipping ROC evaluation")
        continue
    fpr, tpr, thresholds = roc_curve(y_true_np, pred)
    # roc_curve counts score >= threshold as positive; use >= so the chosen
    # threshold reproduces the corresponding ROC operating point.
    f1s = [f1_score(y_true_np, (pred >= thr).astype(int), zero_division=0) for thr in thresholds]
    best_idx = int(np.argmax(f1s))
    best_thr = thresholds[best_idx]
    best_pred_label = (pred >= best_thr).astype(int)
    auc_val = roc_auc_score(y_true_np, pred)
    f1_val = f1_score(y_true_np, best_pred_label, zero_division=0)
    precision_val = precision_score(y_true_np, best_pred_label, zero_division=0)
    recall_val = recall_score(y_true_np, best_pred_label, zero_division=0)
    acc_val = accuracy_score(y_true_np, best_pred_label)
    print(f"窗口ROC -- Best Threshold (max F1): {best_thr:.4f}")
    print("AUC:      ", auc_val)
    print("F1:       ", f1_val)
    print("Precision:", precision_val)
    print("Recall:   ", recall_val)
    print("Accuracy: ", acc_val)
    plt.figure(figsize=(7,5))
    plt.plot(fpr, tpr, label='ROC curve')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve (Win {win_no})")
    plt.legend(loc='lower right')
    metrics_text = f"AUC={auc_val:.3f}\nF1={f1_val:.3f}\nPre={precision_val:.3f}\nRec={recall_val:.3f}\nThr={best_thr:.3f}"
    plt.gca().text(0.02, 0.98, metrics_text, fontsize=11, verticalalignment='top', horizontalalignment='left', bbox=dict(facecolor='white', alpha=0.7))
    plt.show()   # display only; the figure is not saved

# === Combine all windows: global (micro-averaged) ROC and best-F1 threshold ===
# NOTE(review): windows overlap whenever STRIDE < WINDOW_SIZE. A transaction is
# therefore counted once per window containing it, and its scores come from the
# model at different training stages.
all_pred = np.concatenate(all_pred) if all_pred else np.array([])
all_y_true = np.concatenate(all_y_true) if all_y_true else np.array([])
if np.unique(all_y_true).size < 2:
    # No window was processed (data shorter than WINDOW_SIZE), or only one class
    # was seen: ROC/AUC are undefined and roc_auc_score would raise.
    print("\n--- Global metrics skipped: predictions for both classes are required ---")
else:
    fpr, tpr, thresholds = roc_curve(all_y_true, all_pred)
    # roc_curve counts score >= threshold as positive; match that here.
    f1s = [f1_score(all_y_true, (all_pred >= thr).astype(int), zero_division=0) for thr in thresholds]
    best_idx = int(np.argmax(f1s))
    best_thr = thresholds[best_idx]
    best_pred_label = (all_pred >= best_thr).astype(int)
    auc_val = roc_auc_score(all_y_true, all_pred)
    f1_val = f1_score(all_y_true, best_pred_label, zero_division=0)
    precision_val = precision_score(all_y_true, best_pred_label, zero_division=0)
    recall_val = recall_score(all_y_true, best_pred_label, zero_division=0)
    acc_val = accuracy_score(all_y_true, best_pred_label)
    print("\n--- 全體微平均(Global)最佳指標 ---")
    print(f"Best Threshold (max F1): {best_thr:.4f}")
    print("AUC:      ", auc_val)
    print("F1:       ", f1_val)
    print("Precision:", precision_val)
    print("Recall:   ", recall_val)
    print("Accuracy: ", acc_val)
    plt.figure(figsize=(8,6))
    plt.plot(fpr, tpr, label='Global ROC')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Global ROC Curve (Micro Avg.)")
    plt.legend(loc='lower right')
    metrics_text = f"AUC={auc_val:.3f}\nF1={f1_val:.3f}\nPre={precision_val:.3f}\nRec={recall_val:.3f}\nThr={best_thr:.3f}"
    plt.gca().text(0.02, 0.98, metrics_text, fontsize=13, verticalalignment='top', horizontalalignment='left', bbox=dict(facecolor='white', alpha=0.7))
    plt.show()  # display only; the figure is not saved
