# GhostBusters: AML Mule Account Detection - A100 GPU Training Pipeline

**Anti-Money Laundering (AML) Fraud Detection using Graph Neural Networks**

This notebook trains multiple ML and GNN models on the AMLSim dataset to detect mule/SAR accounts.

### Pipeline Overview
1. **Environment Setup** - A100 GPU detection, MIG 1g.5gb optimization, dependency checks
2. **Dataset Verification** - Validate all train/val/test splits and graph data
3. **Data Loading & Preprocessing** - Load numpy arrays, normalize features
4. **Baseline Models** - XGBoost, Random Forest, Logistic Regression, Gradient Boosting
5. **GNN Models** - GraphSAGE, GAT, GCN, GIN, HeteroGNN
6. **Model Saving** - PyTorch (.pt), TorchScript, ONNX, Pickle (sklearn/xgb)
7. **Visualizations** - Training curves, graph structure, embeddings, attention, comparisons

### Hardware Target
- NVIDIA A100 GPU (1g.5gb MIG instance or full GPU)
- Mixed precision (FP16/BF16) with TF32 matmul
- Fused Adam optimizer, torch.compile, GradScaler

In [3]:
!pip install pandas



In [6]:
# ============================================================
# CELL 1: Environment Setup & A100 GPU Configuration
# ============================================================
import os
import sys
import json
import time
import pickle
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from datetime import datetime

warnings.filterwarnings('ignore')

# ---- PyTorch ----
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts

# ---- Sklearn ----
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    classification_report, precision_recall_curve, auc,
    f1_score, roc_auc_score, average_precision_score,
    confusion_matrix, roc_curve
)
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE

# ---- XGBoost ----
try:
    from xgboost import XGBClassifier
    HAS_XGB = True
except ImportError:
    HAS_XGB = False
    print('[WARN] XGBoost not installed. pip install xgboost')

# ---- Visualization ----
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend for remote servers
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
try:
    import seaborn as sns
    sns.set_theme(style='whitegrid', font_scale=1.1)
    HAS_SNS = True
except ImportError:
    HAS_SNS = False

try:
    import networkx as nx
    HAS_NX = True
except ImportError:
    HAS_NX = False
    print('[WARN] networkx not installed. pip install networkx')

print('Core imports loaded successfully.')

ModuleNotFoundError: No module named 'torch'

In [5]:
!nvidia-smi

Wed Feb 18 21:54:39 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 566.36                 Driver Version: 566.36         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3060 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   54C    P8             14W /   88W |    2458MiB /   6144MiB |     16%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# ============================================================
# CELL 2: A100 GPU Detection & MIG 1g.5gb Configuration
# ============================================================

def setup_a100_optimizations():
    """Inspect the CUDA device (if any) and build the training configuration.

    Without CUDA the default CPU configuration is returned untouched after a
    warning is printed. With CUDA, device 0 is queried for its name and total
    memory, batch size / hidden width / worker count are chosen from the
    memory size, and several global PyTorch backend switches are flipped
    (TF32 matmul, cuDNN benchmark mode) as a side effect.

    Returns:
        dict: configuration with keys ``device``, ``gpu_name``,
        ``gpu_memory_gb``, ``is_a100``, ``is_mig``, ``mig_profile``,
        ``use_amp``, ``use_bf16``, ``use_tf32``, ``use_compile``,
        ``use_fused_adam``, ``batch_size``, ``hidden_dim`` and
        ``num_workers``.
    """
    config = dict(
        device='cpu',
        gpu_name='N/A',
        gpu_memory_gb=0,
        is_a100=False,
        is_mig=False,
        mig_profile='N/A',
        use_amp=False,
        use_bf16=False,
        use_tf32=False,
        use_compile=False,
        use_fused_adam=False,
        batch_size=32768,
        hidden_dim=128,
        num_workers=2,
    )

    if not torch.cuda.is_available():
        print('WARNING: CUDA not available. Using CPU (training will be slow).')
        return config

    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    config.update(
        device='cuda',
        gpu_name=gpu_name,
        gpu_memory_gb=round(gpu_mem, 1),
    )

    print(f'GPU Detected: {gpu_name}')
    print(f'GPU Memory: {gpu_mem:.1f} GB')
    print(f'CUDA Version: {torch.version.cuda}')
    print(f'PyTorch Version: {torch.__version__}')

    # A100 detection is a plain substring match on the reported device name.
    config['is_a100'] = 'A100' in gpu_name.upper()

    # MIG slices are inferred purely from how much memory the device reports:
    # (exclusive upper bound in GB, profile, batch_size, hidden_dim,
    #  num_workers, message)
    a100_mig_tiers = (
        (7, '1g.5gb', 16384, 128, 1, 'MIG Profile: 1g.5gb (5GB slice)'),
        (12, '2g.10gb', 32768, 192, 2, 'MIG Profile: 2g.10gb (10GB slice)'),
        (25, '3g.20gb', 65536, 256, 4, 'MIG Profile: 3g.20gb (20GB slice)'),
    )
    # Non-A100 devices only get batch_size / hidden_dim adjusted:
    # (exclusive upper bound in GB, batch_size, hidden_dim)
    generic_tiers = (
        (8, 16384, 128),
        (16, 32768, 192),
    )

    if config['is_a100']:
        for limit_gb, profile, batch, hidden, workers, message in a100_mig_tiers:
            if gpu_mem < limit_gb:
                config.update(
                    is_mig=True,
                    mig_profile=profile,
                    batch_size=batch,
                    hidden_dim=hidden,
                    num_workers=workers,
                )
                print(message)
                break
        else:
            # No slice tier matched: treat the device as an unpartitioned A100.
            config.update(
                mig_profile='full',
                batch_size=131072,
                hidden_dim=512,
                num_workers=4,
            )
            print('Full A100 GPU detected (40GB or 80GB)')
    else:
        for limit_gb, batch, hidden in generic_tiers:
            if gpu_mem < limit_gb:
                config.update(batch_size=batch, hidden_dim=hidden)
                break
        else:
            config.update(batch_size=65536, hidden_dim=256)

    # TF32 switches are set for every CUDA device, not only for A100.
    torch.set_float32_matmul_precision('high')
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    config['use_tf32'] = True
    print('TF32 matmul: ENABLED')

    # Mixed precision is requested unconditionally once CUDA is present.
    config['use_amp'] = True
    print('Mixed Precision (AMP): ENABLED')

    # BF16 is flagged for A100 by name, otherwise by asking the runtime.
    if config['is_a100'] or torch.cuda.is_bf16_supported():
        config['use_bf16'] = True
        print('BF16 Support: ENABLED')

    # torch.compile only exists in newer PyTorch builds, so probe for it.
    if hasattr(torch, 'compile'):
        config['use_compile'] = True
        print('torch.compile: AVAILABLE')

    config['use_fused_adam'] = True
    print('Fused Adam: ENABLED')

    # Start from a clean allocator state so later memory reports are meaningful.
    torch.cuda.empty_cache()
    if hasattr(torch.cuda, 'memory_stats'):
        torch.cuda.reset_peak_memory_stats()

    torch.backends.cudnn.benchmark = True
    print('cuDNN benchmark: ENABLED')

    print(f'\nOptimal batch_size: {config["batch_size"]}')
    print(f'Optimal hidden_dim: {config["hidden_dim"]}')

    return config

# Run the hardware probe once at cell execution time. GPU_CONFIG holds the
# resulting settings dict; DEVICE is the torch.device built from its 'device'
# entry ('cpu' or 'cuda').
GPU_CONFIG = setup_a100_optimizations()
DEVICE = torch.device(GPU_CONFIG['device'])
print(f'\nUsing device: {DEVICE}')

In [None]:
# ============================================================
# CELL 3: Project Paths & Directory Setup
# ============================================================

# Project root resolution.
# NOTE: '__file__' below is a quoted string literal, not the module attribute,
# so abspath() resolves it against the current working directory and dirname()
# strips it again. The project root is therefore whatever directory the kernel
# was started in.
NOTEBOOK_DIR = os.path.dirname(os.path.abspath('__file__'))
PROJECT_ROOT = NOTEBOOK_DIR

# Input data lives under data/amlsim_v1; everything else is an output or
# helper location directly below the project root.
BASE_DIR = os.path.join(PROJECT_ROOT, 'data', 'amlsim_v1')
RESULTS_DIR, MODELS_DIR, VIZ_DIR, SCRIPTS_DIR = (
    os.path.join(PROJECT_ROOT, subdir)
    for subdir in ('results', 'models', 'visualizations', 'scripts')
)

# Only the three output directories are created; BASE_DIR and SCRIPTS_DIR are
# left alone.
for d in (RESULTS_DIR, MODELS_DIR, VIZ_DIR):
    os.makedirs(d, exist_ok=True)

print(
    f'Project Root: {PROJECT_ROOT}\n'
    f'Data Dir:     {BASE_DIR}\n'
    f'Models Dir:   {MODELS_DIR}\n'
    f'Results Dir:  {RESULTS_DIR}\n'
    f'Viz Dir:      {VIZ_DIR}\n'
    '\nAll output directories ready.'
)

In [None]:
# ============================================================
# CELL 4: Dataset Verification - Check ALL files exist
# ============================================================

def verify_dataset():
    """Check that every expected dataset file exists in each split and report on it.

    For each of the ``train``/``val``/``test`` directories under ``BASE_DIR``
    this prints one line per expected CSV file and per expected file in the
    ``graph_data`` sub-directory (row count or array shape plus size on
    disk), flags missing files, and prints headline figures from
    ``graph_stats.json`` when that file is present.

    Returns:
        tuple: ``(summary, all_ok)``. ``summary`` maps each split name to a
        dict with keys ``'csv'`` and ``'graph'`` (per-file metadata) and
        ``'missing'`` (names of absent files, graph files prefixed with
        ``graph_data/``). ``all_ok`` is True only when no expected file was
        missing in any split.

    Raises:
        KeyError: if ``graph_stats.json`` exists but lacks ``num_accounts``,
            ``num_sar_accounts`` or ``total_edges``.
    """
    splits = ['train', 'val', 'test']

    # Files expected directly inside each split directory
    csv_files = [
        'accounts.csv', 'transactions.csv', 'sar_accounts.csv',
        'alert_accounts.csv', 'enriched_transactions.csv',
        'app_logins.csv', 'atm_withdrawals.csv', 'wallet_links.csv',
        'upi_handles.csv', 'individuals-bulkload.csv', 'organizations-bulkload.csv'
    ]

    # Files expected inside <split>/graph_data
    graph_files = [
        'account_features.npy', 'account_labels.npy',
        'edge_features.npy', 'edge_labels.npy',
        'transfer_edge_index.npy', 'same_wallet_edge_index.npy',
        'wallet_edge_index.npy', 'vpa_edge_index.npy',
        'atm_edge_index.npy', 'bank_edge_index.npy',
        'login_edge_index.npy', 'shared_device_edge_index.npy',
        'graph_stats.json', 'id_mappings.json'
    ]

    all_ok = True
    summary = {}

    for split in splits:
        split_dir = os.path.join(BASE_DIR, split)
        graph_dir = os.path.join(split_dir, 'graph_data')

        print(f'\n{"="*60}')
        print(f'  Verifying: {split.upper()} split')
        print(f'{"="*60}')

        split_info = {'csv': {}, 'graph': {}, 'missing': []}

        # Check CSV files
        for fname in csv_files:
            path = os.path.join(split_dir, fname)
            if not os.path.exists(path):
                split_info['missing'].append(fname)
                print(f'  MISSING  {fname}')
                all_ok = False
                continue

            size_mb = os.path.getsize(path) / (1024 * 1024)
            try:
                # One parsed row is enough to learn the column count.
                df = pd.read_csv(path, nrows=1)
                # Count physical lines minus the header. The handle is
                # closed deterministically instead of being left to the
                # garbage collector.
                with open(path) as fh:
                    nrows = sum(1 for _ in fh) - 1
                split_info['csv'][fname] = {'rows': nrows, 'cols': len(df.columns), 'size_mb': round(size_mb, 1)}
                print(f'  OK  {fname:<35} {nrows:>10,} rows  {size_mb:>7.1f} MB')
            except Exception:
                # Best effort: the file exists but could not be parsed or
                # decoded, so only its size is recorded.
                split_info['csv'][fname] = {'size_mb': round(size_mb, 1)}
                print(f'  OK  {fname:<35} {size_mb:>7.1f} MB')

        # Check graph data files
        print(f'\n  Graph Data:')
        for fname in graph_files:
            path = os.path.join(graph_dir, fname)
            if not os.path.exists(path):
                split_info['missing'].append(f'graph_data/{fname}')
                print(f'  MISSING  graph_data/{fname}')
                all_ok = False
                continue

            size_mb = os.path.getsize(path) / (1024 * 1024)
            if fname.endswith('.npy'):
                # Only shape and dtype are needed, so map the file instead of
                # reading the whole array into memory; fall back to a normal
                # load if the file cannot be memory-mapped.
                try:
                    arr = np.load(path, mmap_mode='r')
                except ValueError:
                    arr = np.load(path)
                split_info['graph'][fname] = {'shape': list(arr.shape), 'dtype': str(arr.dtype), 'size_mb': round(size_mb, 1)}
                print(f'  OK  {fname:<35} shape={str(arr.shape):<20} {size_mb:>7.1f} MB')
                del arr  # drop the mapping before moving to the next file
            else:
                split_info['graph'][fname] = {'size_mb': round(size_mb, 1)}
                print(f'  OK  {fname:<35} {size_mb:>7.1f} MB')

        # Load and display graph stats
        stats_path = os.path.join(graph_dir, 'graph_stats.json')
        if os.path.exists(stats_path):
            with open(stats_path) as sf:
                stats = json.load(sf)
            print(f'\n  Graph Statistics:')
            print(f'    Accounts: {stats["num_accounts"]:,} ({stats["num_sar_accounts"]:,} SAR)')
            print(f'    Total edges: {stats["total_edges"]:,}')
            # max(..., 1) avoids a division by zero for an empty split.
            sar_rate = stats["num_sar_accounts"] / max(stats["num_accounts"], 1) * 100
            print(f'    SAR rate: {sar_rate:.2f}%')

        summary[split] = split_info

    print(f'\n{"="*60}')
    if all_ok:
        print('  DATASET VERIFICATION: ALL FILES PRESENT')
    else:
        print('  DATASET VERIFICATION: SOME FILES MISSING (see above)')
    print(f'{"="*60}')

    return summary, all_ok

# Run the verification once; dataset_ok is False if any expected file is missing.
dataset_summary, dataset_ok = verify_dataset()

In [None]:
# ============================================================
# CELL 5: Data Loading & Preprocessing
# ============================================================

def load_split(split):
    """Read the graph arrays that exist on disk for one dataset split.

    Args:
        split: name of the split directory under ``BASE_DIR``
            (the callers in this file pass 'train', 'val' and 'test').

    Returns:
        dict: array name (file stem without ``.npy``) mapped to the loaded
        NumPy array. Arrays whose file is absent are simply left out.
    """
    graph_dir = os.path.join(BASE_DIR, split, 'graph_data')
    array_names = (
        'account_features', 'account_labels', 'edge_features', 'edge_labels',
        'transfer_edge_index', 'same_wallet_edge_index',
        'wallet_edge_index', 'vpa_edge_index', 'atm_edge_index',
        'bank_edge_index', 'login_edge_index', 'shared_device_edge_index',
    )

    loaded = {}
    for array_name in array_names:
        npy_path = os.path.join(graph_dir, f'{array_name}.npy')
        if not os.path.exists(npy_path):
            # Missing arrays are tolerated here; callers check for keys.
            continue
        loaded[array_name] = np.load(npy_path)
    return loaded

def prepare_node_data(train_d, val_d, test_d):
    """Turn the raw per-split arrays into device tensors for node classification.

    Features are converted to float32, non-finite entries are zeroed, and all
    three splits are standardised with the mean and standard deviation of the
    training features. Labels become int64 tensors. For each split a single
    edge index is built by concatenating the transfer, same-wallet and
    shared-device edge lists that are present and non-empty.

    Args:
        train_d, val_d, test_d: dicts as returned by ``load_split``; each must
            contain 'account_features' and 'account_labels'.

    Returns:
        dict: ``X_*``, ``y_*`` and ``edge_index_*`` tensors for the train, val
        and test splits (all on ``DEVICE``), plus the training ``mean`` and
        ``std`` used for standardisation.
    """
    relation_keys = (
        'transfer_edge_index', 'same_wallet_edge_index', 'shared_device_edge_index',
    )

    def _features(split_data):
        # float32 features on the target device, NaN and +/-Inf replaced by 0
        feats = torch.tensor(split_data['account_features'], dtype=torch.float32).to(DEVICE)
        feats[torch.isnan(feats) | torch.isinf(feats)] = 0
        return feats

    def _labels(split_data):
        return torch.tensor(split_data['account_labels'], dtype=torch.long).to(DEVICE)

    def build_edges(d):
        # Concatenate along the edge axis; relations that are absent or have
        # zero columns are skipped.
        parts = [
            torch.tensor(d[key], dtype=torch.long)
            for key in relation_keys
            if key in d and d[key].shape[1] > 0
        ]
        if not parts:
            # Nothing usable: hand back the transfer edges as they are.
            return torch.tensor(d['transfer_edge_index'], dtype=torch.long).to(DEVICE)
        return torch.cat(parts, dim=1).to(DEVICE)

    raw_features = {
        'train': _features(train_d),
        'val': _features(val_d),
        'test': _features(test_d),
    }
    labels = {
        'train': _labels(train_d),
        'val': _labels(val_d),
        'test': _labels(test_d),
    }

    # Statistics come from the training split only, so validation and test
    # data never influence the scaling.
    mean = raw_features['train'].mean(dim=0)
    std = raw_features['train'].std(dim=0).clamp(min=1e-6)

    prepared = {}
    for tag, split_data in (('train', train_d), ('val', val_d), ('test', test_d)):
        prepared[f'X_{tag}'] = (raw_features[tag] - mean) / std
        prepared[f'y_{tag}'] = labels[tag]
        prepared[f'edge_index_{tag}'] = build_edges(split_data)
    prepared['mean'] = mean
    prepared['std'] = std
    return prepared

# Load the raw NumPy arrays for the three splits. These module-level dicts
# are read again later by train_all_baselines().
print('Loading graph data from all splits...')
t0 = time.time()
train_d = load_split('train')
val_d = load_split('val')
test_d = load_split('test')
print(f'  Raw data loaded in {time.time()-t0:.1f}s')

# Convert the arrays to normalised tensors on DEVICE and record the feature
# width (in_dim) from the second dimension of the training features.
print('Preparing GPU tensors...')
t0 = time.time()
node_data = prepare_node_data(train_d, val_d, test_d)
in_dim = node_data['X_train'].size(1)
print(f'  Tensors ready in {time.time()-t0:.1f}s')
print(f'  Input features: {in_dim}')
print(f'  Train: {node_data["X_train"].shape[0]:,} nodes, {node_data["edge_index_train"].shape[1]:,} edges')
print(f'  Val:   {node_data["X_val"].shape[0]:,} nodes, {node_data["edge_index_val"].shape[1]:,} edges')
print(f'  Test:  {node_data["X_test"].shape[0]:,} nodes, {node_data["edge_index_test"].shape[1]:,} edges')
print(f'  Class balance (train): {(node_data["y_train"]==1).sum().item():,} SAR / {(node_data["y_train"]==0).sum().item():,} normal')

# Report how much device memory the prepared tensors occupy (CUDA only).
if torch.cuda.is_available():
    mem_alloc = torch.cuda.memory_allocated() / 1024**3
    mem_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f'  GPU Memory: {mem_alloc:.2f} / {mem_total:.1f} GB allocated')

In [None]:
# ============================================================
# CELL 6: Metrics Computation Utility
# ============================================================

# Human-readable names for the 14 account-level features.
# NOTE(review): presumably listed in the column order of account_features.npy;
# confirm against the script that builds the graph data.
FEATURE_NAMES = [
    'acct_type', 'is_active', 'prior_sar', 'initial_deposit',
    'sent_count', 'recv_count', 'sent_amt_mean', 'sent_amt_std',
    'recv_amt_mean', 'recv_amt_std', 'fan_out_ratio', 'fan_in_ratio',
    'recv_send_delay', 'channel_diversity'
]

def compute_metrics(y_true, y_pred, y_prob):
    """Compute classification and ranking metrics for binary SAR detection.

    Args:
        y_true: ground-truth labels (0 = normal, 1 = SAR); any array-like.
        y_pred: hard predictions, same length as ``y_true``.
        y_prob: predicted probability/score of the positive class, or ``None``
            to skip every score-based metric.

    Returns:
        dict: always contains ``accuracy``, ``f1_macro``, ``f1_weighted`` and
        a 2x2 ``confusion_matrix`` (nested lists, rows = true class).
        ``precision_sar`` / ``recall_sar`` / ``f1_sar`` are present when class
        1 occurs in the data. ``pr_auc``, ``avg_precision`` and ``roc_auc``
        are present when ``y_prob`` is given and the metrics are computable.
        ``recall@1x`` / ``recall@2x`` / ``recall@5x`` give the share of
        positives found in the top ``k = multiple * n_pos`` scores, for
        every multiple where ``k`` does not exceed the number of samples.
    """
    # Work on integer NumPy arrays: float labels such as 1.0 would otherwise
    # appear in the report under the key '1.0' and the SAR metrics would be
    # silently dropped; plain lists would break the fancy indexing below.
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)
    if y_prob is not None:
        y_prob = np.asarray(y_prob)

    metrics = {}
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    metrics['accuracy'] = report['accuracy']
    metrics['f1_macro'] = report['macro avg']['f1-score']
    metrics['f1_weighted'] = report['weighted avg']['f1-score']

    if '1' in report:
        metrics['precision_sar'] = report['1']['precision']
        metrics['recall_sar'] = report['1']['recall']
        metrics['f1_sar'] = report['1']['f1-score']

    if y_prob is not None:
        try:
            prec, rec, _ = precision_recall_curve(y_true, y_prob)
            metrics['pr_auc'] = float(auc(rec, prec))
            metrics['avg_precision'] = float(average_precision_score(y_true, y_prob))
            metrics['roc_auc'] = float(roc_auc_score(y_true, y_prob))
        except Exception as exc:
            # Best effort: e.g. a split containing a single class makes some
            # of these undefined. Keep whatever was computed, but say why the
            # rest is missing instead of hiding the failure.
            print(f'  [WARN] ranking metrics incomplete: {exc}')

    # Pin the label set so the matrix is always 2x2, even when one class is
    # absent from both y_true and y_pred; print_metrics indexes [1][1].
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    metrics['confusion_matrix'] = cm.tolist()

    if y_prob is not None:
        n_pos = int(y_true.sum())
        # Ascending order of scores, computed once for all cut-offs.
        ranked = np.argsort(y_prob)
        for k_mult in (1, 2, 5):
            k = n_pos * k_mult
            if k <= len(y_prob):
                # Slice from an explicit start index: ranked[-k:] would
                # select every row when k == 0.
                top_k_idx = ranked[len(ranked) - k:]
                recall_at_k = y_true[top_k_idx].sum() / max(n_pos, 1)
                metrics[f'recall@{k_mult}x'] = float(recall_at_k)

    return metrics

def print_metrics(metrics, name):
    """Print a compact, fixed-layout summary of one model's metrics.

    Args:
        metrics: dict as produced by ``compute_metrics``. Missing scalar
            entries are shown as 0; a missing confusion matrix is shown as
            all zeros.
        name: heading printed above the figures.
    """
    headline_rows = (
        ('PR-AUC:', 'pr_auc'),
        ('ROC-AUC:', 'roc_auc'),
        ('F1 (SAR):', 'f1_sar'),
        ('Precision:', 'precision_sar'),
        ('Recall:', 'recall_sar'),
    )

    print(f'\n  === {name} ===')
    for label, key in headline_rows:
        # Labels are padded to a common width so the values line up.
        print(f'  {label:<12}{metrics.get(key, 0):.4f}')

    # The recall@k lines appear only when the 1x figure was computed.
    if 'recall@1x' in metrics:
        print(f'  Recall@1x:  {metrics["recall@1x"]:.4f}')
        print(f'  Recall@2x:  {metrics.get("recall@2x", 0):.4f}')

    cm = metrics.get('confusion_matrix', [[0, 0], [0, 0]])
    true_negative_row, true_positive_row = cm[0], cm[1]
    tn, fp = true_negative_row[0], true_negative_row[1]
    fn, tp = true_positive_row[0], true_positive_row[1]
    print(f'  Confusion:  TN={tn:,} FP={fp:,} | FN={fn:,} TP={tp:,}')

# Cell-completion marker so the notebook log shows the helpers were defined.
print('Metrics utilities loaded.')

## Part 1: Baseline Models (XGBoost, Random Forest, Gradient Boosting, Logistic Regression)

Traditional ML models trained on the 14-dimensional account feature vectors. These serve as strong baselines to compare against GNN approaches.

In [None]:
# ============================================================
# CELL 7: Train ALL Baseline Models + Save
# ============================================================

def train_all_baselines():
    """Fit, evaluate and persist the tabular baseline classifiers.

    Reads the module-level ``train_d`` / ``val_d`` / ``test_d`` arrays,
    zeroes non-finite feature values, standardises the features with a
    ``StandardScaler`` fitted on the training split, and then trains XGBoost
    (when installed), Random Forest, Gradient Boosting and Logistic
    Regression. Each model is scored on the test split with
    ``compute_metrics`` and reported with ``print_metrics``. The scaler and
    every model are pickled into ``MODELS_DIR``; XGBoost is additionally
    saved in its native JSON format.

    Returns:
        tuple: ``(results, models, baseline_preds)`` where ``results`` maps
        model key to its metrics dict, ``models`` maps model key to the
        fitted estimator, and ``baseline_preds`` holds ``y_test`` plus a
        ``{'prob', 'pred'}`` dict per model (both ``None`` for XGBoost when
        it is not installed).
    """
    banner = '='*60
    print(banner)
    print('BASELINE MODELS - Node Classification (Mule/SAR Detection)')
    print(banner)

    def _finite_copy(features):
        # Work on a copy so the shared raw arrays stay untouched.
        cleaned = features.copy()
        cleaned[np.isnan(cleaned)] = 0
        cleaned[np.isinf(cleaned)] = 0
        return cleaned

    X_train = _finite_copy(train_d['account_features'])
    X_val = _finite_copy(val_d['account_features'])
    X_test = _finite_copy(test_d['account_features'])
    y_train = train_d['account_labels'].copy()
    y_val = val_d['account_labels'].copy()
    y_test = test_d['account_labels'].copy()

    # Fit the scaler on the training split only, then reuse it everywhere.
    scaler = StandardScaler()
    X_train_s = scaler.fit_transform(X_train)
    X_val_s = scaler.transform(X_val)
    X_test_s = scaler.transform(X_test)

    n_pos = y_train.sum()
    n_neg = len(y_train) - n_pos
    # Negative-to-positive ratio, used as XGBoost's positive-class weight.
    scale_pos = n_neg / max(n_pos, 1)
    print(f'  Train: {len(y_train):,} ({n_pos:,} SAR, {n_neg:,} normal, ratio 1:{scale_pos:.1f})')

    results = {}
    models = {}
    # The XGBoost slot is pre-filled so the key exists (with None values) even
    # when the package is unavailable.
    baseline_preds = {
        'y_test': y_test,
        'xgboost': {'prob': None, 'pred': None},
    }

    def _fit_and_score(key, title, build, **fit_kwargs):
        # Shared train -> predict -> score -> record routine. The timer starts
        # before the estimator is constructed and stops when the metrics are
        # printed.
        print(f'\n--- Training {title} ---')
        started = time.time()
        estimator = build()
        estimator.fit(X_train_s, y_train, **fit_kwargs)
        pred = estimator.predict(X_test_s)
        prob = estimator.predict_proba(X_test_s)[:, 1]
        scores = compute_metrics(y_test, pred, prob)
        print_metrics(scores, f'{title} ({time.time()-started:.1f}s)')
        results[key] = scores
        models[key] = estimator
        baseline_preds[key] = {'prob': prob, 'pred': pred}

    # 1. XGBoost (optional dependency); the validation split is passed as an
    #    evaluation set.
    if HAS_XGB:
        _fit_and_score(
            'xgboost', 'XGBoost',
            lambda: XGBClassifier(
                n_estimators=300, max_depth=6, learning_rate=0.1,
                scale_pos_weight=scale_pos, eval_metric='aucpr',
                use_label_encoder=False, random_state=42, n_jobs=-1,
                tree_method='hist',  # histogram-based tree construction
            ),
            eval_set=[(X_val_s, y_val)], verbose=False,
        )

    # 2. Random Forest
    _fit_and_score(
        'random_forest', 'Random Forest',
        lambda: RandomForestClassifier(
            n_estimators=200, max_depth=10, class_weight='balanced',
            random_state=42, n_jobs=-1
        ),
    )

    # 3. Gradient Boosting (sklearn)
    _fit_and_score(
        'gradient_boosting', 'Gradient Boosting',
        lambda: GradientBoostingClassifier(
            n_estimators=200, max_depth=5, learning_rate=0.1,
            subsample=0.8, random_state=42
        ),
    )

    # 4. Logistic Regression
    _fit_and_score(
        'logistic_regression', 'Logistic Regression',
        lambda: LogisticRegression(class_weight='balanced', max_iter=1000, random_state=42),
    )

    print('\n--- Saving Baseline Models ---')

    def _pickle_to(filename, obj):
        target = os.path.join(MODELS_DIR, filename)
        with open(target, 'wb') as handle:
            pickle.dump(obj, handle)
        print(f'  Saved: {target}')

    # The scaler is stored alongside the models because inference needs the
    # same feature standardisation.
    _pickle_to('baseline_scaler.pkl', scaler)
    for name, model in models.items():
        _pickle_to(f'baseline_{name}.pkl', model)

    # XGBoost is additionally written in its own JSON format.
    if HAS_XGB and 'xgboost' in models:
        xgb_path = os.path.join(MODELS_DIR, 'xgboost_native.json')
        models['xgboost'].save_model(xgb_path)
        print(f'  Saved: {xgb_path} (native XGBoost format)')

    return results, models, baseline_preds

# Train, evaluate and persist every baseline; keep the metrics, the fitted
# models and the test-set predictions for later cells.
baseline_results, baseline_models, baseline_preds = train_all_baselines()

## Part 2: GNN Model Definitions

Five Graph Neural Network architectures optimized for A100:
1. **GraphSAGE** - Mean aggregation with self-loop
2. **GAT** - Multi-head attention (4 heads)
3. **GCN** - Spectral graph convolution (symmetric normalization)
4. **GIN** - Graph Isomorphism Network (sum aggregation + MLP)
5. **HeteroGNN** - Multi-relation heterogeneous GNN using all edge types

In [None]:
# ============================================================
# CELL 8: GNN Layer & Model Definitions (all 5 architectures)
# ============================================================

# ---------- Layer Definitions ----------

class SAGEConv(nn.Module):
    """GraphSAGE layer with mean neighbour aggregation.

    The output for a node is ``linear_self(x) + linear_neigh(mean)``, where
    ``mean`` is the average feature vector of the sources of its incoming
    edges (a zero vector for nodes without incoming edges).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear_self = nn.Linear(in_dim, out_dim)
        self.linear_neigh = nn.Linear(in_dim, out_dim)

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: node features, indexed as (num_nodes, in_dim).
            edge_index: integer tensor whose row 0 holds source node ids and
                row 1 holds destination node ids.

        Returns:
            Tensor with one ``out_dim``-wide row per node.
        """
        src_idx = edge_index[0]
        dst_idx = edge_index[1]
        num_nodes = x.size(0)

        # In-degree per node, accumulated in the feature dtype.
        edge_ones = torch.ones(src_idx.size(0), 1, device=x.device, dtype=x.dtype)
        in_degree = torch.zeros(num_nodes, 1, device=x.device, dtype=x.dtype)
        in_degree.index_add_(0, dst_idx, edge_ones)

        # Sum of source features arriving at each destination.
        neighbor_total = torch.zeros(num_nodes, x.size(1), device=x.device, dtype=x.dtype)
        neighbor_total.index_add_(0, dst_idx, x[src_idx])

        # clamp(min=1) keeps isolated nodes from dividing by zero.
        neighbor_mean = neighbor_total / in_degree.clamp(min=1)
        return self.linear_self(x) + self.linear_neigh(neighbor_mean)


class GATConvLayer(nn.Module):
    """Multi-head graph attention layer.

    ``out_dim`` is split evenly across ``num_heads`` heads; the head outputs
    are concatenated, so the layer returns ``out_dim`` features per node.
    """

    def __init__(self, in_dim, out_dim, num_heads=4, dropout=0.2):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = out_dim // num_heads
        assert out_dim % num_heads == 0
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        # Per-head attention vectors for the source and destination side.
        self.a_src = nn.Parameter(torch.zeros(num_heads, self.head_dim))
        self.a_dst = nn.Parameter(torch.zeros(num_heads, self.head_dim))
        # The initialiser writes through a 3-D view of each 2-D parameter.
        nn.init.xavier_uniform_(self.a_src.unsqueeze(0))
        nn.init.xavier_uniform_(self.a_dst.unsqueeze(0))
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Apply attention-weighted aggregation over incoming edges.

        Args:
            x: node features, indexed as (num_nodes, in_dim).
            edge_index: integer tensor whose row 0 holds source node ids and
                row 1 holds destination node ids.

        Returns:
            float32 tensor of shape (num_nodes, num_heads * head_dim).
        """
        num_nodes = x.size(0)
        src_idx = edge_index[0]
        dst_idx = edge_index[1]
        num_edges = src_idx.size(0)

        projected = self.W(x).view(num_nodes, self.num_heads, self.head_dim)
        src_score = (projected * self.a_src).sum(dim=-1)
        dst_score = (projected * self.a_dst).sum(dim=-1)
        logits = self.leaky_relu(src_score[src_idx] + dst_score[dst_idx])

        # Softmax over the incoming edges of every destination node. The
        # running maximum starts at 0 and include_self=True keeps that 0 in
        # the reduction, so the shift is max(0, largest incoming logit).
        shift = torch.zeros(num_nodes, self.num_heads, device=x.device, dtype=logits.dtype)
        shift.index_reduce_(0, dst_idx, logits, 'amax', include_self=True)
        weights = torch.exp(logits - shift[dst_idx])
        normaliser = torch.zeros(num_nodes, self.num_heads, device=x.device, dtype=logits.dtype)
        normaliser.index_add_(0, dst_idx, weights)
        alpha = self.dropout(weights / normaliser[dst_idx].clamp(min=1e-9))

        # Chunked aggregation to avoid OOM on large graphs (3.4M+ edges)
        CHUNK = 500_000
        out = torch.zeros(num_nodes, self.num_heads, self.head_dim, device=x.device, dtype=torch.float32)
        start = 0
        while start < num_edges:
            stop = min(start + CHUNK, num_edges)
            # Messages are accumulated in float32 regardless of input dtype.
            messages = projected[src_idx[start:stop]].float() * alpha[start:stop].unsqueeze(-1).float()
            out.index_add_(0, dst_idx[start:stop], messages)
            start = stop
        return out.reshape(num_nodes, -1)


class GCNConv(nn.Module):
    """Graph convolution layer with degree-normalised message passing.

    Every edge message is scaled by ``deg[src]^-1/2 * deg[dst]^-1/2``, where
    ``deg`` is the number of incoming edges of a node (at least 1). The
    node's own transformed features are then added un-normalised as the
    self-loop term.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: node features, indexed as (num_nodes, in_dim).
            edge_index: integer tensor whose row 0 holds source node ids and
                row 1 holds destination node ids.

        Returns:
            Tensor with one ``out_dim``-wide row per node.
        """
        src, dst = edge_index[0], edge_index[1]
        N = x.size(0)
        # In-degree per node; clamped so isolated nodes do not produce inf.
        deg = torch.zeros(N, device=x.device, dtype=x.dtype)
        deg.index_add_(0, dst, torch.ones(src.size(0), device=x.device, dtype=x.dtype))
        deg = deg.clamp(min=1)
        deg_inv_sqrt = deg.pow(-0.5)
        # Per-edge normalisation coefficient.
        norm = deg_inv_sqrt[src] * deg_inv_sqrt[dst]
        h = self.linear(x)
        messages = h[src] * norm.unsqueeze(-1)
        # The accumulator must share the messages' dtype: index_add_ rejects
        # mixed dtypes, so a default-float32 buffer fails for float64 or half
        # inputs.
        agg = torch.zeros(N, h.size(1), device=x.device, dtype=messages.dtype)
        agg.index_add_(0, dst, messages)
        # Add self-loop
        return agg + h


class GINConv(nn.Module):
    """Graph Isomorphism Network layer: sum aggregation followed by an MLP.

    Computes ``mlp((1 + eps) * x + sum of incoming neighbour features)``,
    with ``eps`` stored as a learnable scalar parameter.
    """

    def __init__(self, in_dim, out_dim, eps=0.0):
        super().__init__()
        self.eps = nn.Parameter(torch.tensor(eps))
        mlp_layers = [
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: node features, indexed as (num_nodes, in_dim).
            edge_index: integer tensor whose row 0 holds source node ids and
                row 1 holds destination node ids.

        Returns:
            Tensor with one ``out_dim``-wide row per node.
        """
        src_idx = edge_index[0]
        dst_idx = edge_index[1]
        # Un-normalised sum of the features arriving at each destination.
        neighbor_sum = torch.zeros(x.size(0), x.size(1), device=x.device, dtype=x.dtype)
        neighbor_sum.index_add_(0, dst_idx, x[src_idx])
        combined = (1 + self.eps) * x + neighbor_sum
        return self.mlp(combined)


# ---------- Full Model Definitions ----------

class GraphSAGEModel(nn.Module):
    """Node classifier built from two SAGEConv layers and a linear head.

    Each convolution is followed by batch normalisation and ReLU; dropout is
    applied after each activation in ``forward`` only.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.3):
        super().__init__()
        # Creation order is kept stable: it determines parameter order and
        # the sequence of random initialisations.
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x, edge_index):
        """Return per-node class logits."""
        hidden = self.conv1(x, edge_index)
        hidden = self.dropout(F.relu(self.bn1(hidden)))
        hidden = self.conv2(hidden, edge_index)
        hidden = self.dropout(F.relu(self.bn2(hidden)))
        return self.classifier(hidden)

    def get_embeddings(self, x, edge_index):
        """Return the second-layer activations, computed without dropout."""
        first = F.relu(self.bn1(self.conv1(x, edge_index)))
        second = self.conv2(first, edge_index)
        return F.relu(self.bn2(second))


class GATModel(nn.Module):
    """Node classifier built from two GATConvLayer layers and a linear head.

    Each attention layer is followed by batch normalisation and ELU; dropout
    is applied after each activation in ``forward`` only. The same dropout
    rate is also handed to the attention layers themselves.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_heads=4, dropout=0.3):
        super().__init__()
        # Creation order is kept stable: it determines parameter order and
        # the sequence of random initialisations.
        self.conv1 = GATConvLayer(in_dim, hidden_dim, num_heads, dropout)
        self.conv2 = GATConvLayer(hidden_dim, hidden_dim, num_heads, dropout)
        self.classifier = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x, edge_index):
        """Return per-node class logits."""
        hidden = self.conv1(x, edge_index)
        hidden = self.dropout(F.elu(self.bn1(hidden)))
        hidden = self.conv2(hidden, edge_index)
        hidden = self.dropout(F.elu(self.bn2(hidden)))
        return self.classifier(hidden)

    def get_embeddings(self, x, edge_index):
        """Return the second-layer activations, skipping the model-level dropout."""
        first = F.elu(self.bn1(self.conv1(x, edge_index)))
        second = self.conv2(first, edge_index)
        return F.elu(self.bn2(second))


class GCNModel(nn.Module):
    """Node classifier built from two GCNConv layers and a linear head.

    Each conv output goes through BatchNorm1d and ReLU. ``forward`` applies
    dropout after every layer and ends with the classifier; ``get_embeddings``
    returns the second layer's activations with neither dropout nor the
    classifier applied.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.3):
        super().__init__()
        # Registration order is kept as-is: it fixes both parameter
        # initialisation order and state_dict key order.
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x, edge_index):
        """Return per-node class logits."""
        hidden = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = self.dropout(F.relu(norm(conv(hidden, edge_index))))
        return self.classifier(hidden)

    def get_embeddings(self, x, edge_index):
        """Return the hidden representation that feeds the classifier (no dropout)."""
        hidden = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = F.relu(norm(conv(hidden, edge_index)))
        return hidden


class GINModel(nn.Module):
    """Node classifier built from two GINConv layers and a linear head.

    Unlike the other models in this cell, no BatchNorm or activation is applied
    here between layers; the conv outputs are used directly. ``forward``
    applies dropout after each conv and ends with the classifier;
    ``get_embeddings`` returns the second conv's output without dropout.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.3):
        super().__init__()
        self.conv1 = GINConv(in_dim, hidden_dim)
        self.conv2 = GINConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Return per-node class logits."""
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = self.dropout(conv(hidden, edge_index))
        return self.classifier(hidden)

    def get_embeddings(self, x, edge_index):
        """Return the hidden representation that feeds the classifier (no dropout)."""
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = conv(hidden, edge_index)
        return hidden


class HeteroGNNModel(nn.Module):
    """Multi-relation GNN that processes each edge type separately then fuses.

    Every relation gets its own two-layer SAGEConv stack (BatchNorm + ReLU after
    the first layer). The per-relation outputs are concatenated along the
    feature axis, projected by ``fusion``, normalised, and classified.

    Args:
        in_dim: Input feature width.
        hidden_dim: Width of every per-relation branch and of the fused vector.
        out_dim: Number of output classes.
        num_relations: Number of relation branches to allocate.
        dropout: Dropout probability used inside the branches and after fusion.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_relations=3, dropout=0.3):
        super().__init__()
        # Registration order is kept as-is: it fixes both parameter
        # initialisation order and state_dict key order.
        self.relation_convs = nn.ModuleList([
            SAGEConv(in_dim, hidden_dim) for _ in range(num_relations)
        ])
        self.relation_convs2 = nn.ModuleList([
            SAGEConv(hidden_dim, hidden_dim) for _ in range(num_relations)
        ])
        self.fusion = nn.Linear(hidden_dim * num_relations, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn1 = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(num_relations)])
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.num_relations = num_relations

    def _encode(self, x, edge_indices_list, apply_dropout):
        """Run every relation branch and return the fused, activated representation.

        ``apply_dropout`` controls only the dropout between the two conv layers
        of each branch; the caller decides what happens after fusion.
        """
        n_used = min(self.num_relations, len(edge_indices_list))
        relation_outs = []
        for i in range(n_used):
            edge_index = edge_indices_list[i]
            h = F.relu(self.bn1[i](self.relation_convs[i](x, edge_index)))
            if apply_dropout:
                h = self.dropout(h)
            relation_outs.append(self.relation_convs2[i](h, edge_index))

        # Pad if fewer edge types were supplied than the fusion layer expects.
        n_missing = self.num_relations - n_used
        if n_missing > 0:
            if relation_outs:
                pad = torch.zeros_like(relation_outs[0])
            else:
                # No relation was supplied at all, so there is no branch output
                # to copy the shape from; derive the width from the fusion layer.
                branch_dim = self.fusion.in_features // self.num_relations
                pad = x.new_zeros(x.size(0), branch_dim)
            relation_outs.extend([pad] * n_missing)

        fused = torch.cat(relation_outs, dim=-1)
        return F.relu(self.bn2(self.fusion(fused)))

    def forward(self, x, edge_indices_list):
        """Return per-node class logits.

        edge_indices_list: list of edge_index tensors, one per relation type.
        Entries beyond ``num_relations`` are ignored; missing ones (including an
        empty list) are replaced by zero-filled branch outputs.
        """
        h = self._encode(x, edge_indices_list, apply_dropout=True)
        h = self.dropout(h)
        return self.classifier(h)

    def get_embeddings(self, x, edge_indices_list):
        """Return the fused representation that feeds the classifier (no dropout)."""
        return self._encode(x, edge_indices_list, apply_dropout=False)

# Five model classes are defined in the "Full Model Definitions" section above:
# GraphSAGEModel, GATModel, GCNModel, GINModel and HeteroGNNModel.
print('All 5 GNN model classes defined.')
print(f'Hidden dim from GPU config: {GPU_CONFIG["hidden_dim"]}')

In [None]:
# ============================================================
# CELL 9: GNN Training Engine (A100-optimized with AMP)
# ============================================================

@torch.no_grad()
def evaluate_gnn(model, X, y, edge_index, is_hetero=False):
    """Score ``model`` on one split and return ``(metrics, probs)``.

    The model is switched to eval mode and run once under CUDA autocast
    (enabled and typed according to ``GPU_CONFIG``). ``probs`` is the softmax
    probability of class 1 as a NumPy array; ``metrics`` is whatever
    ``compute_metrics`` returns for the labels, argmax predictions and
    probabilities. ``is_hetero`` is accepted for call-site symmetry but is not
    read here: ``edge_index`` is passed to the model unchanged.
    """
    model.eval()
    low_precision = torch.bfloat16 if GPU_CONFIG['use_bf16'] else torch.float16
    autocast_ctx = torch.amp.autocast(
        'cuda', enabled=GPU_CONFIG['use_amp'], dtype=low_precision
    )
    with autocast_ctx:
        logits = model(X, edge_index)
        # Softmax is taken in float32 regardless of the autocast dtype.
        sar_probs = F.softmax(logits.float(), dim=1)[:, 1].cpu().numpy()
        hard_preds = logits.argmax(dim=1).cpu().numpy()
    metrics = compute_metrics(y.cpu().numpy(), hard_preds, sar_probs)
    return metrics, sar_probs


def train_gnn_model(model_name, model, data, epochs=50, lr=0.005, patience=15, is_hetero=False):
    """Full GNN training loop with A100 AMP, early stopping, history tracking.

    One full-batch optimisation step is taken per epoch. Validation runs every
    5 epochs (and on the last epoch); the weights with the best validation
    PR-AUC are restored before the final test evaluation.

    Args:
        model_name: Label used in the progress output.
        model: Module called as ``model(X, edges)``. It is trained in place.
        data: Mapping providing ``X_train``/``y_train``, ``X_val``/``y_val``,
            ``X_test``/``y_test`` and either ``edge_index_{train,val,test}`` or,
            when ``is_hetero`` is true, ``edge_indices_{train,val,test}``.
        epochs: Maximum number of epochs.
        lr: Adam learning rate.
        patience: Early-stopping budget in epochs. Because validation only runs
            every 5 epochs it is converted to a number of validation rounds,
            with a floor of one round.
        is_hetero: Selects which edge keys are read from ``data``.

    Returns:
        ``(test_metrics, model, history, test_probs)`` where ``history`` holds
        the training loss, validation PR-AUC and validation F1 recorded at each
        validation round together with the epoch number.
    """
    eval_every = 5  # validation cadence, in epochs

    print(f'\n{"="*60}')
    print(f'Training {model_name}')
    print(f'{"="*60}')
    print(f'  Device: {DEVICE} | Params: {sum(p.numel() for p in model.parameters()):,}')

    # torch.compile for A100 (Linux only, requires Triton)
    compiled_model = model
    if GPU_CONFIG['use_compile'] and os.name != 'nt':
        try:
            compiled_model = torch.compile(model)
            print('  torch.compile: applied')
        except Exception as e:
            print(f'  torch.compile: skipped ({e})')
            compiled_model = model

    # Class weights: up-weight the positive class by the negative/positive ratio.
    n_pos = (data['y_train'] == 1).sum().float()
    n_neg = (data['y_train'] == 0).sum().float()
    weight = torch.tensor([1.0, (n_neg / n_pos).item()], device=DEVICE)
    criterion = nn.CrossEntropyLoss(weight=weight)

    # Optimizer. Fall back to the default implementation when the `fused`
    # keyword is unknown (TypeError) or unsupported for these parameters
    # (RuntimeError).
    try:
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-4, fused=GPU_CONFIG['use_fused_adam'])
    except (TypeError, RuntimeError):
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    scaler = torch.amp.GradScaler('cuda', enabled=GPU_CONFIG['use_amp'])

    # Edge data for hetero vs homogeneous
    if is_hetero:
        ei_train = data['edge_indices_train']
        ei_val = data['edge_indices_val']
        ei_test = data['edge_indices_test']
    else:
        ei_train = data['edge_index_train']
        ei_val = data['edge_index_val']
        ei_test = data['edge_index_test']

    # `patience` is expressed in epochs but checked once per validation round.
    # It must never become 0: `no_improve >= 0` is always true, which would stop
    # training at the first validation even if the score had just improved.
    max_stale_evals = max(1, patience // eval_every)

    best_val_prauc = 0.0
    best_epoch = 0
    no_improve = 0
    best_state = None
    history = {'train_loss': [], 'val_prauc': [], 'val_f1': [], 'epoch': []}

    t0 = time.time()
    amp_dtype = torch.bfloat16 if GPU_CONFIG['use_bf16'] else torch.float16

    for epoch in range(1, epochs + 1):
        compiled_model.train()
        with torch.amp.autocast('cuda', enabled=GPU_CONFIG['use_amp'], dtype=amp_dtype):
            logits = compiled_model(data['X_train'], ei_train)
            loss = criterion(logits, data['y_train'])
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss = loss.item()

        if epoch % eval_every == 0 or epoch == epochs:
            val_metrics, _ = evaluate_gnn(compiled_model, data['X_val'], data['y_val'], ei_val)
            val_prauc = val_metrics.get('pr_auc', 0)
            val_f1 = val_metrics.get('f1_sar', 0)
            scheduler.step(val_prauc)

            history['train_loss'].append(train_loss)
            history['val_prauc'].append(val_prauc)
            history['val_f1'].append(val_f1)
            history['epoch'].append(epoch)

            print(f'  Epoch {epoch:3d} | Loss: {train_loss:.4f} | Val PR-AUC: {val_prauc:.4f} | Val F1: {val_f1:.4f}')

            if val_prauc > best_val_prauc:
                best_val_prauc = val_prauc
                best_epoch = epoch
                # Snapshot from the un-compiled module so the keys match
                # `model.load_state_dict` below.
                best_state = {k: v.clone() for k, v in model.state_dict().items()}
                no_improve = 0
            else:
                no_improve += 1

            if no_improve >= max_stale_evals:
                print(f'  Early stopping at epoch {epoch} (best: {best_epoch})')
                break

    elapsed = time.time() - t0
    if best_state is not None:
        model.load_state_dict(best_state)

    test_metrics, test_probs = evaluate_gnn(model, data['X_test'], data['y_test'], ei_test)
    print_metrics(test_metrics, f'{model_name} Test ({elapsed:.1f}s)')

    return test_metrics, model, history, test_probs

print('GNN training engine ready.')

In [None]:
# ============================================================
# CELL 10: Train ALL 5 GNN Models
# ============================================================

# Hyper-parameters shared by every GNN trained in this cell.
HIDDEN = GPU_CONFIG['hidden_dim']
EPOCHS, LR, PATIENCE = 100, 0.005, 15

# Per-model outputs, each keyed by model name.
gnn_results, gnn_models, gnn_histories, gnn_probs = {}, {}, {}, {}

# Prepare separate edge indices for HeteroGNN
def get_hetero_edges(d, split_key):
    """Return one edge-index tensor per relation type, always in the same order.

    The relations are looked up in ``d`` as ``transfer_edge_index``,
    ``same_wallet_edge_index`` and ``shared_device_edge_index``. A relation that
    is absent, or present with zero columns, is represented by an empty
    ``(2, 0)`` long tensor so the returned list always has three entries.
    All tensors are placed on ``DEVICE``. ``split_key`` is not read here.
    """
    relation_keys = (
        'transfer_edge_index',
        'same_wallet_edge_index',
        'shared_device_edge_index',
    )

    def _edges_for(key):
        has_edges = key in d and d[key].shape[1] > 0
        if not has_edges:
            # Placeholder empty edges
            return torch.zeros(2, 0, dtype=torch.long, device=DEVICE)
        return torch.tensor(d[key], dtype=torch.long).to(DEVICE)

    return [_edges_for(key) for key in relation_keys]

# HeteroGNN reuses the shared node tensors but takes one edge-index list per
# split instead of a single edge_index tensor.
hetero_data = dict(node_data)  # Copy base data
hetero_data.update(
    edge_indices_train=get_hetero_edges(train_d, 'train'),
    edge_indices_val=get_hetero_edges(val_d, 'val'),
    edge_indices_test=get_hetero_edges(test_d, 'test'),
)


def _train_and_record(position, banner, name, build_model, data, is_hetero=False):
    """Announce, build, train and register one GNN.

    ``build_model`` is a zero-argument factory so the network is only created
    after the banner is printed and after the previous model finished training.
    Returns the same 4-tuple as ``train_gnn_model``.
    """
    rule = '#'*60
    print('\n' + rule)
    print(f'# MODEL {position}/5: {banner}')
    print(rule)
    net = build_model().to(DEVICE)
    metrics, net, history, probs = train_gnn_model(
        name, net, data, EPOCHS, LR, PATIENCE, is_hetero=is_hetero
    )
    gnn_results[name] = metrics
    gnn_models[name] = net
    gnn_histories[name] = history
    gnn_probs[name] = probs
    return metrics, net, history, probs


# 1. GraphSAGE
m, sage, h, p = _train_and_record(
    1, 'GraphSAGE', 'GraphSAGE',
    lambda: GraphSAGEModel(in_dim, HIDDEN, 2, dropout=0.3),
    node_data,
)

# 2. GAT (hidden width halved for memory safety)
m, gat, h, p = _train_and_record(
    2, 'GAT', 'GAT',
    lambda: GATModel(in_dim, max(HIDDEN // 2, 64), 2, num_heads=4, dropout=0.3),
    node_data,
)

# 3. GCN
m, gcn, h, p = _train_and_record(
    3, 'GCN', 'GCN',
    lambda: GCNModel(in_dim, HIDDEN, 2, dropout=0.3),
    node_data,
)

# 4. GIN
m, gin, h, p = _train_and_record(
    4, 'GIN', 'GIN',
    lambda: GINModel(in_dim, HIDDEN, 2, dropout=0.3),
    node_data,
)

# 5. HeteroGNN
m, hetero, h, p = _train_and_record(
    5, 'HeteroGNN (Multi-Relation)', 'HeteroGNN',
    lambda: HeteroGNNModel(in_dim, HIDDEN, 2, num_relations=3, dropout=0.3),
    hetero_data,
    is_hetero=True,
)

# Clear GPU cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

_done_rule = '='*60
print('\n' + _done_rule)
print('ALL GNN MODELS TRAINED SUCCESSFULLY')
print(_done_rule)

## Part 3: Save ALL Models in Multiple Formats

Saving each GNN model in:
- **PyTorch checkpoint (.pt)** - state_dict + config + metrics + normalization params
- **TorchScript (.torchscript.pt)** - JIT-traced for C++/production deployment
- **ONNX (.onnx)** - Cross-framework interoperability
- **Full model pickle (.pkl)** - For quick Python reloading

In [None]:
# ============================================================
# CELL 11: Save ALL GNN Models in Multiple Formats
# ============================================================

def save_gnn_all_formats(name, model, metrics, data, is_hetero=False):
    """Save a GNN model in .pt, TorchScript, ONNX, and pickle formats.

    The model is moved to CPU for export and moved back to ``DEVICE`` afterwards,
    even if one of the export steps raises. TorchScript, ONNX and pickle
    failures are reported and skipped; a failure writing the ``.pt`` checkpoint
    propagates.

    Args:
        name: Model label; its lower-cased form is used in the file names and it
            selects the GAT / HeteroGNN specific config entries.
        model: Trained module to export.
        metrics: Metrics object stored alongside the weights.
        data: Mapping whose ``mean`` and ``std`` tensors are stored as the
            normalisation parameters.
        is_hetero: When true, dummy inputs are a list of edge tensors and ONNX
            export is skipped.
    """
    safe_name = name.lower().replace(' ', '_')
    model.eval()
    model_cpu = model.cpu()  # nn.Module.cpu() moves in place and returns the same object

    try:
        # Determine config. The hidden width is read from the classifier head
        # rather than the global HIDDEN because not every model is built with
        # HIDDEN (the GAT is constructed with a reduced width), and the saved
        # config must agree with the saved state_dict.
        head = getattr(model_cpu, 'classifier', None)
        if isinstance(head, nn.Linear):
            hidden_dim, out_dim = head.in_features, head.out_features
        else:
            hidden_dim, out_dim = HIDDEN, 2
        num_relations = getattr(model_cpu, 'num_relations', 3)

        config = {'in_dim': in_dim, 'hidden_dim': hidden_dim, 'out_dim': out_dim}
        if name == 'GAT':
            config['num_heads'] = 4
        if name == 'HeteroGNN':
            config['num_relations'] = num_relations

        norm_info = {
            'mean': data['mean'].cpu().numpy().tolist(),
            'std': data['std'].cpu().numpy().tolist()
        }

        # 1. PyTorch Checkpoint (.pt)
        pt_path = os.path.join(MODELS_DIR, f'{safe_name}_node.pt')
        torch.save({
            'model_state': model_cpu.state_dict(),
            'config': config,
            'metrics': metrics,
            'norm': norm_info,
            'model_class': name,
            'feature_names': FEATURE_NAMES,
            # Run-level setting; the per-model width lives in config['hidden_dim'].
            'hidden_dim': HIDDEN,
            'training_info': {
                'epochs': EPOCHS,
                'lr': LR,
                'device': GPU_CONFIG['gpu_name'],
                'mig_profile': GPU_CONFIG['mig_profile'],
            }
        }, pt_path)
        print(f'  [{name}] PyTorch checkpoint: {pt_path}')

        # 2. TorchScript (.torchscript.pt)
        ts_path = os.path.join(MODELS_DIR, f'{safe_name}_node.torchscript.pt')
        try:
            # Create dummy inputs on CPU
            dummy_x = torch.randn(100, in_dim)
            if is_hetero:
                dummy_edges = [torch.randint(0, 100, (2, 50)) for _ in range(num_relations)]
            else:
                dummy_edges = torch.randint(0, 100, (2, 200))
            traced = torch.jit.trace(model_cpu, (dummy_x, dummy_edges))
            traced.save(ts_path)
            print(f'  [{name}] TorchScript: {ts_path}')
        except Exception as e:
            print(f'  [{name}] TorchScript FAILED: {e}')

        # 3. ONNX (.onnx)
        onnx_path = os.path.join(MODELS_DIR, f'{safe_name}_node.onnx')
        try:
            dummy_x = torch.randn(100, in_dim)
            if not is_hetero:
                dummy_edges = torch.randint(0, 100, (2, 200))
                torch.onnx.export(
                    model_cpu, (dummy_x, dummy_edges), onnx_path,
                    input_names=['node_features', 'edge_index'],
                    output_names=['logits'],
                    dynamic_axes={
                        'node_features': {0: 'num_nodes'},
                        'edge_index': {1: 'num_edges'},
                        'logits': {0: 'num_nodes'}
                    },
                    opset_version=17
                )
                print(f'  [{name}] ONNX: {onnx_path}')
            else:
                print(f'  [{name}] ONNX: skipped (HeteroGNN has variable inputs)')
        except Exception as e:
            print(f'  [{name}] ONNX FAILED: {e}')

        # 4. Pickle (.pkl) - full model object
        # NOTE: loading this file executes pickle; only load files you produced.
        pkl_path = os.path.join(MODELS_DIR, f'{safe_name}_full.pkl')
        try:
            with open(pkl_path, 'wb') as f:
                pickle.dump({
                    'model': model_cpu,
                    'config': config,
                    'norm': norm_info,
                    'metrics': metrics
                }, f)
            print(f'  [{name}] Pickle: {pkl_path}')
        except Exception as e:
            print(f'  [{name}] Pickle FAILED: {e}')
    finally:
        # Move model back to GPU, whatever happened above, so later cells can
        # keep using it with device-resident tensors.
        model.to(DEVICE)


_save_rule = '='*60
print(_save_rule)
print('SAVING ALL GNN MODELS IN MULTIPLE FORMATS')
print(_save_rule)

# Export every trained GNN; only HeteroGNN takes a list of edge tensors.
for name in gnn_models:
    model = gnn_models[name]
    is_h = name == 'HeteroGNN'
    save_gnn_all_formats(name, model, gnn_results[name], node_data, is_hetero=is_h)
    print()

# Save combined results JSON (the 'gnn' section is filled in further down).
all_results = dict(
    baseline=baseline_results,
    gnn={},
    gpu_config=GPU_CONFIG,
    training_config=dict(epochs=EPOCHS, lr=LR, hidden_dim=HIDDEN, patience=PATIENCE),
)

def convert_np(obj):
    """``json.dump(default=...)`` hook that converts NumPy values to builtins.

    NumPy booleans, integers and floats become ``bool``/``int``/``float`` and
    arrays become nested lists. Anything else raises ``TypeError``, which is
    what the json module expects from a ``default`` hook; handing the object
    back unchanged makes the encoder fail with a misleading
    "Circular reference detected" error instead.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

# Fill the 'gnn' section with one entry per trained model.
all_results['gnn'].update(gnn_results)

results_path = os.path.join(RESULTS_DIR, 'all_results.json')
with open(results_path, 'w') as results_file:
    json.dump(all_results, results_file, indent=2, default=convert_np)
print(f'\nCombined results: {results_path}')

# Save training histories
hist_path = os.path.join(RESULTS_DIR, 'gnn_training_histories.json')
with open(hist_path, 'w') as hist_file:
    json.dump(gnn_histories, hist_file, indent=2, default=convert_np)
print(f'Training histories: {hist_path}')

# List all saved models with their size on disk
_list_rule = "="*60
print(f'\n{_list_rule}')
print('ALL SAVED MODEL FILES:')
print(f'{_list_rule}')
for f in sorted(os.listdir(MODELS_DIR)):
    size_mb = os.path.getsize(os.path.join(MODELS_DIR, f)) / (1024*1024)
    print(f'  {f:<45} {size_mb:>7.2f} MB')

## Part 4: Comprehensive Visualizations

1. Training curves (loss, PR-AUC, F1 per epoch)
2. Model comparison bar charts (all 9 models)
3. ROC & PR curves
4. Confusion matrices (heatmaps)
5. Feature importance (XGBoost)
6. GNN node embedding t-SNE
7. Graph structure visualization
8. Attention weight distribution (GAT)
9. Class distribution & data overview
10. Score distribution histograms

In [None]:
# ============================================================
# CELL 12: VIZ 1 - GNN Training Curves
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(20, 5))
colors = plt.cm.Set1(np.linspace(0, 1, len(gnn_histories)))

# One row per panel: history key, y-axis label, title, extra line styling.
panel_specs = [
    ('train_loss', 'Training Loss', 'Training Loss', {}),
    ('val_prauc', 'Val PR-AUC', 'Validation PR-AUC', {'marker': 'o', 'markersize': 3}),
    ('val_f1', 'Val F1 (SAR)', 'Validation F1 Score (SAR Class)', {'marker': 's', 'markersize': 3}),
]

for ax, (series_key, y_label, panel_title, line_style) in zip(axes, panel_specs):
    # Histories only contain the epochs at which validation ran.
    for (name, hist), color in zip(gnn_histories.items(), colors):
        ax.plot(hist['epoch'], hist[series_key], label=name, color=color, linewidth=2, **line_style)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('GNN Training Curves (A100 GPU)', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.join(VIZ_DIR, '01_training_curves.png'), dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 01_training_curves.png')

In [None]:
# ============================================================
# CELL 13: VIZ 2 - All Models Comparison Bar Chart
# ============================================================

# Combine all results
all_model_names, all_pr_auc, all_roc_auc, all_f1 = [], [], [], []

# Baselines first, then GNNs; the prefix keeps the two families apart on the axis.
for prefix, result_set in (('BL', baseline_results), ('GNN', gnn_results)):
    for name, m in result_set.items():
        all_model_names.append(f'{prefix}:{name}')
        all_pr_auc.append(m.get('pr_auc', 0))
        all_roc_auc.append(m.get('roc_auc', 0))
        all_f1.append(m.get('f1_sar', 0))

x = np.arange(len(all_model_names))
width = 0.25

fig, ax = plt.subplots(figsize=(16, 6))
# Three bars per model, centred on the model's tick.
bars1 = ax.bar(x - width, all_pr_auc, width, label='PR-AUC', color='#2196F3')
bars2 = ax.bar(x, all_f1, width, label='F1 (SAR)', color='#FF9800')
bars3 = ax.bar(x + width, all_roc_auc, width, label='ROC-AUC', color='#4CAF50')

ax.set_ylabel('Score')
ax.set_title('Model Comparison: All Baselines vs All GNNs', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(all_model_names, rotation=45, ha='right', fontsize=10)
ax.legend(loc='lower right')
ax.set_ylim(0, 1.05)
ax.grid(axis='y', alpha=0.3)

# Add value labels on bars (bars without a positive height stay unlabeled)
for bar in (*bars1, *bars2, *bars3):
    h = bar.get_height()
    if not h > 0:
        continue
    label_xy = (bar.get_x() + bar.get_width()/2, h)
    ax.annotate(f'{h:.3f}', xy=label_xy,
               xytext=(0, 3), textcoords='offset points', ha='center', fontsize=7)

plt.tight_layout()
plt.savefig(os.path.join(VIZ_DIR, '02_model_comparison.png'), dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 02_model_comparison.png')

In [None]:
# ============================================================
# CELL 14: VIZ 3 - ROC & PR Curves (all models on same plot)
# ============================================================

y_test_np = test_d['account_labels']

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
roc_ax, pr_ax = axes

# Collect all model probabilities. The 'y_test' entry of baseline_preds is not
# a model, and baselines without probabilities are left out.
all_probs = {
    f'BL:{name}': bp['prob']
    for name, bp in baseline_preds.items()
    if name != 'y_test' and bp.get('prob') is not None
}
all_probs.update({f'GNN:{name}': p for name, p in gnn_probs.items()})

colors_map = plt.cm.tab10(np.linspace(0, 1, len(all_probs)))

# ROC Curve
for (name, probs), color in zip(all_probs.items(), colors_map):
    fpr, tpr, _ = roc_curve(y_test_np, probs)
    auc_val = roc_auc_score(y_test_np, probs)
    ax = roc_ax
    ax.plot(fpr, tpr, label=f'{name} (AUC={auc_val:.4f})', color=color, linewidth=1.5)
ax = roc_ax
ax.plot([0,1], [0,1], 'k--', alpha=0.3)  # chance diagonal
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curves - All Models')
ax.legend(fontsize=8, loc='lower right')
ax.grid(True, alpha=0.3)

# PR Curve
ax = pr_ax
for (name, probs), color in zip(all_probs.items(), colors_map):
    prec, rec, _ = precision_recall_curve(y_test_np, probs)
    ap = average_precision_score(y_test_np, probs)
    ax.plot(rec, prec, label=f'{name} (AP={ap:.4f})', color=color, linewidth=1.5)
# Horizontal reference at the positive-class rate of the test labels.
baseline_ratio = y_test_np.sum() / len(y_test_np)
ax.axhline(y=baseline_ratio, color='gray', linestyle='--', alpha=0.5, label=f'Random ({baseline_ratio:.3f})')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curves - All Models')
ax.legend(fontsize=8, loc='upper right')
ax.grid(True, alpha=0.3)

plt.suptitle('ROC & PR Curves Comparison', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.join(VIZ_DIR, '03_roc_pr_curves.png'), dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 03_roc_pr_curves.png')

In [None]:
# ============================================================
# CELL 15: VIZ 4 - Confusion Matrices Heatmaps (GNN models)
# ============================================================

n_models = len(gnn_results)
fig, axes = plt.subplots(1, n_models, figsize=(4 * n_models, 4))
# With a single panel the result is wrapped so the zip below still works.
axes = [axes] if n_models == 1 else axes

class_labels = ['Normal', 'SAR']

for ax, (name, m) in zip(axes, gnn_results.items()):
    # Models without a stored matrix are drawn as an all-zero 2x2 grid.
    cm = np.array(m.get('confusion_matrix', [[0,0],[0,0]]))
    if not HAS_SNS:
        # Plain-matplotlib fallback: image plus one text label per cell.
        im = ax.imshow(cm, cmap='Blues')
        for i in range(2):
            for j in range(2):
                ax.text(j, i, f'{cm[i,j]:,}', ha='center', va='center', fontsize=10)
        ax.set_xticks([0,1])
        ax.set_xticklabels(class_labels)
        ax.set_yticks([0,1])
        ax.set_yticklabels(class_labels)
    else:
        sns.heatmap(cm, annot=True, fmt=',d', cmap='Blues', ax=ax,
                    xticklabels=class_labels, yticklabels=class_labels)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(f'{name}\nF1={m.get("f1_sar",0):.3f}')

plt.suptitle('Confusion Matrices - GNN Models', fontsize=14, fontweight='bold', y=1.05)
plt.tight_layout()
plt.savefig(os.path.join(VIZ_DIR, '04_confusion_matrices.png'), dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 04_confusion_matrices.png')

In [None]:
# ============================================================
# CELL 16: VIZ 5 - Feature Importance (XGBoost)
# ============================================================

if not (HAS_XGB and 'xgboost' in baseline_models):
    print('XGBoost not available, skipping feature importance plot.')
else:
    imp = baseline_models['xgboost'].feature_importances_
    # Ascending order, so the largest importance ends up at the top of the barh plot.
    sorted_idx = np.argsort(imp)

    # Fall back to a positional name when FEATURE_NAMES is shorter than the
    # importance vector.
    names = []
    for i in sorted_idx:
        names.append(FEATURE_NAMES[i] if i < len(FEATURE_NAMES) else f'feat_{i}')

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(names, imp[sorted_idx], color='#2196F3')
    ax.set_xlabel('Feature Importance (Gain)')
    ax.set_title('XGBoost Feature Importance for Mule Detection', fontsize=13, fontweight='bold')
    ax.grid(axis='x', alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(VIZ_DIR, '05_feature_importance.png'), dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved: 05_feature_importance.png')

In [None]:
# ============================================================
# CELL 17: VIZ 6 - GNN Node Embedding t-SNE Visualization
# ============================================================

print('Computing GNN embeddings for t-SNE visualization...')
print('(Using best model, sampling 5000 nodes for speed)')

# Pick the best GNN by test PR-AUC as stored in gnn_results
best_gnn_name = max(gnn_results, key=lambda k: gnn_results[k].get('pr_auc', 0))
best_gnn = gnn_models[best_gnn_name]
print(f'  Best GNN: {best_gnn_name}')

# Get embeddings (HeteroGNN takes the per-relation edge list instead of one edge_index)
best_gnn.eval()
with torch.no_grad():
    if best_gnn_name == 'HeteroGNN':
        emb = best_gnn.get_embeddings(node_data['X_test'], hetero_data['edge_indices_test'])
    else:
        emb = best_gnn.get_embeddings(node_data['X_test'], node_data['edge_index_test'])
    emb_np = emb.cpu().numpy()

labels_np = node_data['y_test'].cpu().numpy()

# Sample for t-SNE (full dataset too large)
n_sample = 5000
sar_idx = np.where(labels_np == 1)[0]
normal_idx = np.where(labels_np == 0)[0]
# Take up to half the budget from the SAR class and fill the rest with normal
# nodes. Both counts are capped by what the split actually contains, because
# sampling without replacement raises when asked for more items than exist.
n_sar_sample = min(len(sar_idx), n_sample // 2)
n_normal_sample = min(len(normal_idx), n_sample - n_sar_sample)
sample_idx = np.concatenate([
    np.random.choice(sar_idx, n_sar_sample, replace=False),
    np.random.choice(normal_idx, n_normal_sample, replace=False)
])
np.random.shuffle(sample_idx)

print(f'  Running t-SNE on {len(sample_idx)} nodes...')
# The iteration count is left at the library default (1000): the keyword was
# renamed from `n_iter` to `max_iter` across scikit-learn releases, so passing
# either name explicitly breaks on some versions.
# Perplexity must stay below the number of samples; 30 is used whenever possible.
tsne_perplexity = min(30, max(1, len(sample_idx) - 1))
tsne = TSNE(n_components=2, random_state=42, perplexity=tsne_perplexity)
emb_2d = tsne.fit_transform(emb_np[sample_idx])
labels_sample = labels_np[sample_idx]

fig, ax = plt.subplots(figsize=(10, 8))
scatter_normal = ax.scatter(
    emb_2d[labels_sample == 0, 0], emb_2d[labels_sample == 0, 1],
    c='#2196F3', alpha=0.3, s=10, label='Normal'
)
# SAR points are drawn second (on top) and larger so they stay visible.
scatter_sar = ax.scatter(
    emb_2d[labels_sample == 1, 0], emb_2d[labels_sample == 1, 1],
    c='#F44336', alpha=0.6, s=25, label='SAR/Mule', edgecolors='darkred', linewidths=0.5
)
ax.set_xlabel('t-SNE 1'); ax.set_ylabel('t-SNE 2')
ax.set_title(f'{best_gnn_name} Node Embeddings (t-SNE)\n'
             f'{n_sar_sample} SAR + {n_normal_sample} Normal accounts',
             fontsize=13, fontweight='bold')
ax.legend(markerscale=3, fontsize=11)
ax.grid(True, alpha=0.2)

plt.tight_layout()
plt.savefig(os.path.join(VIZ_DIR, '06_tsne_embeddings.png'), dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 06_tsne_embeddings.png')

In [None]:
# ============================================================
# CELL 18: VIZ 7 - Graph Structure Visualization (subgraph)
# ============================================================

# Draw the neighbourhood of the best-connected SAR node found among the first
# 200 SAR nodes of the test split. Skipped entirely when networkx is missing.
if HAS_NX:
    print('Visualizing transaction graph subgraph around a SAR node...')
    
    edges = test_d['transfer_edge_index']
    labels_np = test_d['account_labels']
    
    # Find a SAR node with many connections
    sar_nodes = np.where(labels_np == 1)[0]
    # Count edges per SAR node
    sar_degrees = []
    for s in sar_nodes[:200]:  # Check first 200
        # Degree = occurrences as source (row 0) plus occurrences as target (row 1),
        # counted over the full edge array.
        deg = np.sum(edges[0] == s) + np.sum(edges[1] == s)
        sar_degrees.append((s, deg))
    # Highest degree first; the sort is stable, so ties keep their original order.
    sar_degrees.sort(key=lambda x: -x[1])
    center_node = sar_degrees[0][0]
    
    # Build 2-hop subgraph
    G = nx.DiGraph()
    visited = {center_node}
    frontier = {center_node}
    
    for hop in range(2):
        next_frontier = set()
        # Only the first 200000 edge columns are scanned on each hop, so the
        # subgraph can miss edges that the degree count above did include.
        for i in range(min(edges.shape[1], 200000)):
            src, dst = int(edges[0, i]), int(edges[1, i])
            # An edge is kept (with its original direction) when either endpoint
            # is in the current frontier; the other endpoint joins the next one.
            if src in frontier:
                next_frontier.add(dst)
                G.add_edge(src, dst)
            if dst in frontier:
                next_frontier.add(src)
                G.add_edge(src, dst)
        frontier = next_frontier - visited
        visited.update(frontier)
        # Stop expanding once the neighbourhood is already large.
        if len(visited) > 300:
            break
    
    # Limit to 200 nodes for readability
    # The nodes kept are the first 200 in G's node order, so which nodes survive
    # depends on the order in which edges were added above.
    if len(G.nodes) > 200:
        nodes_to_keep = list(G.nodes)[:200]
        G = G.subgraph(nodes_to_keep).copy()
    
    # Color nodes: centre node, other SAR nodes, everything else
    node_colors = []
    node_sizes = []
    for n in G.nodes:
        if n == center_node:
            node_colors.append('#FF0000')
            node_sizes.append(200)
        # The bounds check guards against node ids outside the label array.
        elif n < len(labels_np) and labels_np[n] == 1:
            node_colors.append('#FF6B6B')
            node_sizes.append(80)
        else:
            node_colors.append('#64B5F6')
            node_sizes.append(30)
    
    fig, ax = plt.subplots(figsize=(14, 10))
    # Fixed seed so the layout is the same on every run.
    pos = nx.spring_layout(G, k=0.3, iterations=50, seed=42)
    nx.draw_networkx_edges(G, pos, alpha=0.15, edge_color='gray', arrows=True,
                           arrowsize=5, ax=ax)
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes,
                           alpha=0.8, ax=ax)
    
    # Legend
    legend_elements = [
        mpatches.Patch(color='#FF0000', label=f'Center SAR Node ({center_node})'),
        mpatches.Patch(color='#FF6B6B', label='SAR/Mule Accounts'),
        mpatches.Patch(color='#64B5F6', label='Normal Accounts'),
    ]
    ax.legend(handles=legend_elements, loc='upper left', fontsize=10)
    ax.set_title(f'Transaction Graph: 2-hop Neighborhood of SAR Account {center_node}\n'
                 f'{len(G.nodes)} nodes, {len(G.edges)} edges',
                 fontsize=13, fontweight='bold')
    ax.axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(VIZ_DIR, '07_graph_structure.png'), dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved: 07_graph_structure.png')
else:
    print('networkx not available, skipping graph visualization.')

In [None]:
# ============================================================
# CELL 19: VIZ 8 - Score Distribution & Edge Type Analysis
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 8a: Score distribution for best GNN
ax = axes[0, 0]
best_probs = gnn_probs[best_gnn_name]
sar_mask = test_d['account_labels'] == 1
ax.hist(best_probs[~sar_mask], bins=50, alpha=0.7, label='Normal', color='#2196F3', density=True)
ax.hist(best_probs[sar_mask], bins=50, alpha=0.7, label='SAR/Mule', color='#F44336', density=True)
ax.axvline(x=0.5, color='black', linestyle='--', alpha=0.5, label='Threshold=0.5')
ax.set_xlabel('Predicted SAR Probability')
ax.set_ylabel('Density')
ax.set_title(f'{best_gnn_name} Score Distribution')
ax.legend()
ax.grid(True, alpha=0.3)

# 8b: Edge type distribution
ax = axes[0, 1]
edge_types = []
edge_counts = []
for split_name in ['train', 'val', 'test']:
    stats_path = os.path.join(BASE_DIR, split_name, 'graph_data', 'graph_stats.json')
    if os.path.exists(stats_path):
        with open(stats_path) as f:
            stats = json.load(f)
        if split_name == 'test':
            for etype, count in stats['edges'].items():
                if count > 0:
                    edge_types.append(etype)
                    edge_counts.append(count)

if edge_types:
    y_pos = np.arange(len(edge_types))
    colors_edges = plt.cm.Set2(np.linspace(0, 1, len(edge_types)))
    ax.barh(y_pos, edge_counts, color=colors_edges)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(edge_types, fontsize=9)
    ax.set_xlabel('Number of Edges')
    ax.set_title('Edge Type Distribution (Test Split)')
    for i, v in enumerate(edge_counts):
        ax.text(v + max(edge_counts)*0.01, i, f'{v:,}', va='center', fontsize=8)
    ax.grid(axis='x', alpha=0.3)

# 8c: Class distribution across splits
# Counts SAR vs. normal accounts per split from the saved label arrays.
# NOTE(review): summing the labels assumes they are 0/1 encoded - confirm.
ax = axes[1, 0]
splits_names = ['train', 'val', 'test']
sar_counts = []
normal_counts = []
for split_name in splits_names:
    labels = np.load(os.path.join(BASE_DIR, split_name, 'graph_data', 'account_labels.npy'))
    # Convert to plain ints so the ',' thousands format below never renders
    # a float such as '1,234.0' when the label array has a float dtype.
    n_sar = int(labels.sum())
    sar_counts.append(n_sar)
    normal_counts.append(len(labels) - n_sar)

x_pos = np.arange(len(splits_names))
ax.bar(x_pos - 0.2, normal_counts, 0.4, label='Normal', color='#2196F3')
ax.bar(x_pos + 0.2, sar_counts, 0.4, label='SAR', color='#F44336')
ax.set_xticks(x_pos)
ax.set_xticklabels(splits_names)
ax.set_ylabel('Number of Accounts')
ax.set_title('Class Distribution Across Splits')
ax.legend()
# Offset the value labels by 1% of the tallest bar so the spacing scales with
# the dataset size (a fixed offset of 2000 only suits one particular scale).
label_gap = 0.01 * max(normal_counts + sar_counts)
for i, (n_normal, n_sar) in enumerate(zip(normal_counts, sar_counts)):
    ax.text(i - 0.2, n_normal + label_gap, f'{n_normal:,}', ha='center', va='bottom', fontsize=8)
    ax.text(i + 0.2, n_sar + label_gap, f'{n_sar:,}', ha='center', va='bottom', fontsize=8)
ax.grid(axis='y', alpha=0.3)

# 8d: Recall@Kx comparison
# Grouped bars of recall@1x / 2x / 5x for every model whose metrics dict
# contains a 'recall@1x' entry (GNN entries win on a name clash).
ax = axes[1, 1]
model_names_rk = []
recall_1x, recall_2x, recall_5x = [], [], []
all_results = {**baseline_results, **gnn_results}
for name, m in all_results.items():
    if 'recall@1x' not in m:
        continue
    model_names_rk.append(name)
    recall_1x.append(m.get('recall@1x', 0))
    recall_2x.append(m.get('recall@2x', 0))
    recall_5x.append(m.get('recall@5x', 0))

if model_names_rk:
    w = 0.25
    x_rk = np.arange(len(model_names_rk))
    # (horizontal shift, values, legend label, colour) for each bar series
    bar_series = (
        (-w, recall_1x, 'Recall@1x', '#FF5722'),
        (0, recall_2x, 'Recall@2x', '#FF9800'),
        (w, recall_5x, 'Recall@5x', '#FFC107'),
    )
    for shift, values, series_label, series_color in bar_series:
        ax.bar(x_rk + shift, values, w, label=series_label, color=series_color)
    ax.set_xticks(x_rk)
    ax.set_xticklabels(model_names_rk, rotation=45, ha='right', fontsize=8)
    ax.set_title('Recall@Kx (Top-K Retrieval Quality)')
    ax.set_ylabel('Recall')
    ax.set_ylim(0, 1.05)
    ax.grid(axis='y', alpha=0.3)
    ax.legend()

# Finalise the 2x2 figure and write it to the visualisation directory.
plt.suptitle('Data & Score Analysis', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
viz8_path = os.path.join(VIZ_DIR, '08_score_analysis.png')
plt.savefig(viz8_path, dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 08_score_analysis.png')

In [None]:
# ============================================================
# CELL 20: VIZ 9 - GAT Attention Weights Analysis
# ============================================================

print('Analyzing GAT attention weight distribution...')

gat_model = gnn_models.get('GAT')
if gat_model is None:
    print('GAT model not available.')
else:
    gat_model.eval()
    X_test_t = node_data['X_test']
    ei_test = node_data['edge_index_test']

    # Recompute the attention coefficients of the first GAT layer (conv1)
    # from its parameters: one value per edge and per head.
    with torch.no_grad():
        conv1 = gat_model.conv1
        n_heads = conv1.num_heads
        N = X_test_t.size(0)
        src, dst = ei_test[0], ei_test[1]
        h = conv1.W(X_test_t).view(N, n_heads, conv1.head_dim)
        score_src = (h * conv1.a_src).sum(dim=-1)
        score_dst = (h * conv1.a_dst).sum(dim=-1)
        e = F.leaky_relu(score_src[src] + score_dst[dst], 0.2)
        # Softmax over the incoming edges of each destination node: subtract a
        # per-destination maximum (seeded with 0), exponentiate, then divide
        # by the per-destination sum.
        e_max = torch.zeros(N, n_heads, device=DEVICE)
        e_max.index_reduce_(0, dst, e, 'amax', include_self=True)
        e_exp = torch.exp(e - e_max[dst])
        e_sum = torch.zeros(N, n_heads, device=DEVICE)
        e_sum.index_add_(0, dst, e_exp)
        attention = (e_exp / e_sum[dst].clamp(min=1e-9)).cpu().numpy()

    labels_test = node_data['y_test'].cpu().numpy()

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Panel 1: attention distribution per head (at most the first four heads)
    ax = axes[0]
    heads_shown = min(4, attention.shape[1])
    for head in range(heads_shown):
        ax.hist(attention[:, head], bins=50, alpha=0.5, label=f'Head {head+1}', density=True)
    ax.set_title('Attention Weight Distribution (Layer 1)')
    ax.set_xlabel('Attention Weight')
    ax.set_ylabel('Density')
    ax.grid(True, alpha=0.3)
    ax.legend()

    # Panel 2: head-averaged attention, split by whether the edge's
    # destination node is labelled SAR (== 1) or not
    ax = axes[1]
    dst_np = dst.cpu().numpy()
    sar_dst_mask = labels_test[dst_np] == 1
    normal_dst_mask = ~sar_dst_mask
    mean_att = attention.mean(axis=1)  # Average across heads
    for edge_mask, edge_label, edge_color in (
        (normal_dst_mask, 'To Normal', '#2196F3'),
        (sar_dst_mask, 'To SAR', '#F44336'),
    ):
        ax.hist(mean_att[edge_mask], bins=50, alpha=0.6, label=edge_label,
                color=edge_color, density=True)
    ax.set_title('Attention: Edges to SAR vs Normal Nodes')
    ax.set_xlabel('Mean Attention Weight')
    ax.set_ylabel('Density')
    ax.grid(True, alpha=0.3)
    ax.legend()

    # Panel 3: heatmap of mean attention per head; row 0 = edges into normal
    # nodes, row 1 = edges into SAR nodes
    ax = axes[2]
    att_matrix = np.stack([
        attention[normal_dst_mask].mean(axis=0),
        attention[sar_dst_mask].mean(axis=0),
    ])
    head_labels = [f'Head {i+1}' for i in range(att_matrix.shape[1])]
    class_labels = ['Normal', 'SAR']

    if HAS_SNS:
        sns.heatmap(att_matrix, annot=True, fmt='.4f', cmap='YlOrRd', ax=ax,
                    xticklabels=head_labels,
                    yticklabels=class_labels)
    else:
        # Plain-matplotlib fallback: draw the image and annotate each cell.
        im = ax.imshow(att_matrix, cmap='YlOrRd', aspect='auto')
        for row, row_values in enumerate(att_matrix):
            for col, cell_value in enumerate(row_values):
                ax.text(col, row, f'{cell_value:.4f}', ha='center', va='center')
        ax.set_xticks(range(att_matrix.shape[1]))
        ax.set_xticklabels(head_labels)
        ax.set_yticks([0,1])
        ax.set_yticklabels(class_labels)
    ax.set_title('Mean Attention by Head\n(Edges to SAR vs Normal)')

    plt.suptitle('GAT Attention Analysis', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    viz9_path = os.path.join(VIZ_DIR, '09_gat_attention.png')
    plt.savefig(viz9_path, dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved: 09_gat_attention.png')

In [None]:
# ============================================================
# CELL 21: VIZ 10 - Feature Correlation & Embedding Comparison
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(18, 7))

# 10a: Feature correlation heatmap
# Pairwise correlations between the test-split account features, drawn with
# seaborn when available and with plain matplotlib otherwise.
ax = axes[0]
features_test = test_d['account_features']
feat_df = pd.DataFrame(features_test, columns=FEATURE_NAMES[:features_test.shape[1]])
corr = feat_df.corr()
if HAS_SNS:
    sns.heatmap(corr, annot=True, fmt='.2f', cmap='RdBu_r', center=0, ax=ax,
                square=True, linewidths=0.5, cbar_kws={'shrink': 0.8})
else:
    im = ax.imshow(corr.values, cmap='RdBu_r', vmin=-1, vmax=1)
    ax.set_xticks(range(len(corr.columns)))
    ax.set_xticklabels(corr.columns, rotation=90, fontsize=7)
    ax.set_yticks(range(len(corr.columns)))
    ax.set_yticklabels(corr.columns, fontsize=7)
    plt.colorbar(im, ax=ax)
ax.set_title('Account Feature Correlations', fontsize=12, fontweight='bold')
ax.tick_params(labelsize=7)

# 10b: t-SNE comparison of all GNN embeddings (2x3 subplot would be too complex, use overlay)
ax = axes[1]
# Compare top 3 GNN models by PR-AUC
sorted_gnns = sorted(gnn_results.items(), key=lambda x: x[1].get('pr_auc', 0), reverse=True)[:3]
colors_emb = ['#F44336', '#2196F3', '#4CAF50']

# Sample just SAR nodes for clarity. The sample is drawn once, from a seeded
# generator, so every model is projected on the same accounts and reruns give
# the same picture (the t-SNE seeds below are fixed as well).
sar_idx = np.where(test_d['account_labels'] == 1)[0]
n_sar_sample = min(500, len(sar_idx))

plotted_any = False
if n_sar_sample < 2:
    print('Skipping t-SNE comparison: fewer than 2 SAR nodes in the test split.')
else:
    sar_sample = np.random.RandomState(42).choice(sar_idx, n_sar_sample, replace=False)
    # scikit-learn requires perplexity < n_samples.
    tsne_perplexity = min(20, n_sar_sample - 1)
    # The iteration-count argument was renamed from n_iter to max_iter in
    # scikit-learn 1.5; use whichever name the installed version exposes.
    tsne_iter_arg = 'max_iter' if 'max_iter' in TSNE().get_params() else 'n_iter'

    for i, (name, _) in enumerate(sorted_gnns):
        model_i = gnn_models.get(name)
        if model_i is None:
            print(f'Skipping {name}: model not available.')
            continue
        model_i.eval()
        with torch.no_grad():
            # HeteroGNN takes the per-edge-type indices; the other models take
            # the single homogeneous edge index.
            if name == 'HeteroGNN':
                emb_i = model_i.get_embeddings(node_data['X_test'], hetero_data['edge_indices_test'])
            else:
                emb_i = model_i.get_embeddings(node_data['X_test'], node_data['edge_index_test'])
        # Cast to float32 first: NumPy cannot represent bfloat16 tensors.
        emb_i_np = emb_i.float().cpu().numpy()

        tsne_i = TSNE(n_components=2, random_state=42 + i, perplexity=tsne_perplexity,
                      **{tsne_iter_arg: 500})
        emb_2d_i = tsne_i.fit_transform(emb_i_np[sar_sample])

        ax.scatter(emb_2d_i[:, 0], emb_2d_i[:, 1], c=colors_emb[i], alpha=0.4, s=15,
                   label=f'{name} (SAR nodes)')
        plotted_any = True

ax.set_xlabel('t-SNE 1')
ax.set_ylabel('t-SNE 2')
ax.set_title('SAR Node Embeddings: Top 3 GNN Models', fontsize=12, fontweight='bold')
if plotted_any:
    ax.legend(fontsize=10)
ax.grid(True, alpha=0.2)

plt.tight_layout()
plt.savefig(os.path.join(VIZ_DIR, '10_feature_embedding_analysis.png'), dpi=150, bbox_inches='tight')
plt.show()
print('Saved: 10_feature_embedding_analysis.png')

## Part 5: Final Summary & GPU Memory Report

In [None]:
# ============================================================
# CELL 22: Final Summary Table & GPU Report
# ============================================================

def _print_result_rows(results, model_type):
    """Print one summary-table row per model in *results*.

    results: mapping of model name -> metrics dict; missing metrics print as 0.
    model_type: text shown in the "Type" column (e.g. "Baseline" or "GNN").
    """
    for name, m in results.items():
        print(f'{name:<25} {model_type:<10} {m.get("pr_auc",0):>8.4f} {m.get("roc_auc",0):>8.4f} '
              f'{m.get("f1_sar",0):>8.4f} {m.get("precision_sar",0):>8.4f} '
              f'{m.get("recall_sar",0):>8.4f} {m.get("recall@1x",0):>6.3f}')


def _best_by_pr_auc(results):
    """Return the model name with the highest 'pr_auc', or None if *results* is empty."""
    if not results:
        return None
    return max(results, key=lambda k: results[k].get('pr_auc', 0))


print('='*80)
print('                    GHOSTBUSTERS - FINAL RESULTS SUMMARY')
print('='*80)

# Combined results table
print(f'\n{"Model":<25} {"Type":<10} {"PR-AUC":>8} {"ROC-AUC":>8} {"F1(SAR)":>8} {"Prec":>8} {"Recall":>8} {"R@1x":>6}')
print('-'*80)

_print_result_rows(baseline_results, 'Baseline')

print('-'*80)

_print_result_rows(gnn_results, 'GNN')

# Best models
# max() raises ValueError on an empty mapping, so an empty results dict
# (e.g. every model of one family failed) is reported instead of crashing.
print(f'\n{"="*80}')
best_bl_name = _best_by_pr_auc(baseline_results)
best_gn_name = _best_by_pr_auc(gnn_results)
for heading, best_name, results in (
    ('Best Baseline:', best_bl_name, baseline_results),
    ('Best GNN:', best_gn_name, gnn_results),
):
    if best_name is None:
        print(f'{heading:<14} n/a (no results)')
    else:
        print(f'{heading:<14} {best_name} (PR-AUC: {results[best_name].get("pr_auc",0):.4f})')

# GPU memory summary
# Printed only when CUDA is available; all sizes are converted from bytes to GB
# (1024**3 bytes).
if torch.cuda.is_available():
    bytes_per_gb = 1024**3
    rule = '=' * 80
    print(f'\n{rule}')
    print('GPU MEMORY REPORT')
    print(rule)
    print(f'  GPU: {GPU_CONFIG["gpu_name"]}')
    print(f'  MIG Profile: {GPU_CONFIG["mig_profile"]}')
    peak_allocated_gb = torch.cuda.max_memory_allocated() / bytes_per_gb
    print(f'  Peak Memory Allocated: {peak_allocated_gb:.2f} GB')
    peak_reserved_gb = torch.cuda.max_memory_reserved() / bytes_per_gb
    print(f'  Peak Memory Reserved:  {peak_reserved_gb:.2f} GB')
    current_gb = torch.cuda.memory_allocated() / bytes_per_gb
    print(f'  Current Memory:        {current_gb:.2f} GB')

# Saved files summary
def _print_saved_files(directory, prefix):
    """Print every entry of *directory* with its size in MB, sorted by name.

    directory: path of the output directory to list.
    prefix: display prefix shown before each entry name (e.g. "models").

    A missing directory is reported on one line instead of raising, so the
    final summary cell still completes when an output stage produced nothing.
    """
    if not os.path.isdir(directory):
        print(f'  ({prefix}/ directory not found: {directory})')
        return
    for entry_name in sorted(os.listdir(directory)):
        entry_path = os.path.join(directory, entry_name)
        size_mb = os.path.getsize(entry_path) / (1024*1024)
        # Pad "<prefix>/<name>" to 49 columns so the size column lines up
        # across the three sections.
        shown_name = f'{prefix}/{entry_name}'
        print(f'  {shown_name:<49} {size_mb:>7.2f} MB')


print(f'\n{"="*80}')
print('SAVED FILES')
print(f'{"="*80}')

print('\nModels:')
_print_saved_files(MODELS_DIR, 'models')

print('\nResults:')
_print_saved_files(RESULTS_DIR, 'results')

print('\nVisualizations:')
_print_saved_files(VIZ_DIR, 'visualizations')

print(f'\n{"="*80}')
print('PIPELINE COMPLETE - All models trained, saved, and visualized.')
print(f'{"="*80}')