# E9: Wallet Fusion - GNN Embeddings + Tabular Features

**Goal:** Combine E7-A3 GNN embeddings with tabular features by concatenating them and training an XGBoost classifier on the fused feature vector

**Hypothesis:** Fusing GNN embeddings (relational signal) with tabular features (statistical signal) outperforms either feature set on its own

**Date:** November 11, 2025

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    precision_recall_curve, auc, roc_auc_score, 
    f1_score, roc_curve
)
import xgboost as xgb

# Seed every RNG this notebook relies on so that reruns are reproducible.
# NumPy and PyTorch keep separate generators, so each one is seeded.
for seed_rng in (np.random.seed, torch.manual_seed):
    seed_rng(42)

print("Libraries imported successfully")

## Step 1: Extract E7-A3 Embeddings

In [None]:
# Load heterogeneous graph (node types 'transaction' and 'address').
#
# weights_only=False: the file is read back as a full graph object (it is
# indexed by node type and exposes .edge_types below), not as a plain
# tensor/state-dict. PyTorch >= 2.6 defaults to weights_only=True, which
# refuses to unpickle arbitrary objects, so the flag is set explicitly.
# SECURITY: this is a full unpickle -- only load graph files you trust.
#
# map_location='cpu': keeps the load working on CPU-only sessions even if
# the graph was saved from a GPU; it is moved to the compute device before
# the embeddings are extracted.
hetero_data = torch.load(
    '/kaggle/input/elliptic-hetero-graph/hetero_graph.pt',
    map_location='cpu',
    weights_only=False,
)

print("Heterogeneous graph loaded:")
print(f"  Transactions: {hetero_data['transaction'].x.shape[0]}")
print(f"  Addresses: {hetero_data['address'].x.shape[0]}")
print(f"  Edge types: {len(hetero_data.edge_types)}")

In [None]:
# Define Simple-HHGTN model (from E7-A3)
from torch_geometric.nn import HeteroConv, SAGEConv

class SimpleHHGTN(nn.Module):
    """Simple heterogeneous GNN over transaction and address nodes (E7-A3).

    Pipeline: a per-node-type linear projection to ``hidden_dim``, then
    ``num_layers`` rounds of ``HeteroConv`` message passing (one ``SAGEConv``
    per edge type, summed per destination node type) each followed by ReLU
    and dropout, then an MLP head applied to transaction nodes only.

    Args:
        hidden_dim: Width of the projected features and of every conv output.
        num_layers: Number of ``HeteroConv`` layers.
        dropout: Dropout probability used after each conv layer and inside
            the classifier head.
        tx_in_dim: Number of raw input features per transaction node.
        addr_in_dim: Number of raw input features per address node.

    Attribute names are part of the checkpoint format (``state_dict`` keys)
    and must not be renamed.
    """

    def __init__(self, hidden_dim=128, num_layers=2, dropout=0.4,
                 tx_in_dim=93, addr_in_dim=55):
        super().__init__()

        # Input projections: bring both node types to a common width.
        self.tx_proj = nn.Linear(tx_in_dim, hidden_dim)
        self.addr_proj = nn.Linear(addr_in_dim, hidden_dim)

        # HeteroConv layers. SAGEConv(-1, ...) lets the input width be
        # inferred lazily from the first batch / loaded weights.
        self.convs = nn.ModuleList([
            HeteroConv({
                ('transaction', 'tx_to_tx', 'transaction'): SAGEConv(-1, hidden_dim),
                ('address', 'addr_to_tx', 'transaction'): SAGEConv(-1, hidden_dim),
                ('transaction', 'tx_to_addr', 'address'): SAGEConv(-1, hidden_dim),
                ('address', 'addr_to_addr', 'address'): SAGEConv(-1, hidden_dim),
            }, aggr='sum')
            for _ in range(num_layers)
        ])

        self.dropout = dropout

        # Classifier head: hidden_dim -> 64 -> 2 logits.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2)
        )

    def get_embeddings(self, x_dict, edge_index_dict):
        """Return per-node-type embeddings taken before the classifier head.

        Args:
            x_dict: Mapping with keys ``'transaction'`` and ``'address'`` to
                the raw node feature tensors.
            edge_index_dict: Mapping from edge-type triple to edge index
                tensor, as consumed by ``HeteroConv``.

        Returns:
            Dict of node type -> embedding tensor produced by the last conv
            layer (after ReLU and dropout).

        Dropout follows the module's train/eval mode, so call
        ``model.eval()`` first to obtain deterministic embeddings.
        """
        # Project inputs
        x_dict = {
            'transaction': self.tx_proj(x_dict['transaction']),
            'address': self.addr_proj(x_dict['address'])
        }

        # Message passing
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            # training=self.training (rather than a hard-coded False) so the
            # `dropout` hyperparameter actually regularises message passing
            # in train mode while staying a no-op in eval mode.
            x_dict = {key: F.dropout(F.relu(x), p=self.dropout, training=self.training)
                      for key, x in x_dict.items()}

        return x_dict

    def forward(self, x_dict, edge_index_dict):
        """Return 2-class logits for every transaction node."""
        # Get embeddings
        embeddings = self.get_embeddings(x_dict, edge_index_dict)

        # Classify transactions only; address embeddings are not scored.
        return self.classifier(embeddings['transaction'])

print("Model architecture defined")

In [None]:
# Load trained E7-A3 checkpoint
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize model. The architecture must match the checkpoint, otherwise
# load_state_dict below fails on missing/unexpected keys or shapes.
model = SimpleHHGTN(hidden_dim=128, num_layers=2, dropout=0.4)
model.to(device)

# Load checkpoint.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if
# this checkpoint stores anything beyond tensors and plain Python values the
# load will be rejected and weights_only=False is needed -- confirm.
checkpoint = torch.load('/kaggle/input/e7-a3-checkpoint/a3_best.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference only: disables dropout

print(f"E7-A3 model loaded (best epoch: {checkpoint.get('epoch', 'N/A')})")

# The metric is optional metadata. Formatting the 'N/A' fallback with ':.4f'
# would raise ValueError, so the missing case is printed separately.
best_val_pr_auc = checkpoint.get('best_val_pr_auc')
if best_val_pr_auc is None:
    print("Best val PR-AUC: N/A")
else:
    print(f"Best val PR-AUC: {float(best_val_pr_auc):.4f}")

In [None]:
# Run the loaded E7-A3 model once over the whole graph and persist the
# per-node embeddings for the downstream XGBoost models.
print("Extracting embeddings from E7-A3 model...")

hetero_data = hetero_data.to(device)

# Gradients are not needed for feature extraction.
with torch.no_grad():
    embeddings = model.get_embeddings(hetero_data.x_dict, hetero_data.edge_index_dict)

# Bring the results back to host memory as NumPy arrays.
tx_embeddings = embeddings['transaction'].cpu().numpy()
addr_embeddings = embeddings['address'].cpu().numpy()

# Shapes the original author expected: [203769, 128] and [100000, 128].
print(f"Transaction embeddings: {tx_embeddings.shape}")
print(f"Address embeddings: {addr_embeddings.shape}")

# Save embeddings
for out_file, array in (
    ('e9_tx_embeddings.npy', tx_embeddings),
    ('e9_addr_embeddings.npy', addr_embeddings),
):
    np.save(out_file, array)
print("Embeddings saved")

## Step 2: Load Tabular Features

In [None]:
# Load the per-transaction feature table.
tx_features_df = pd.read_csv('/kaggle/input/elliptic-plus-plus/txs_features.csv')

# Keep the 93 columns at 0-based positions 2..94 (the slice end is exclusive).
# NOTE(review): presumably position 0 is the transaction id, position 1 the
# time step and positions 2..94 the local features AF1-AF93 -- confirm
# against the CSV header, since this selection is purely positional.
local_feature_cols = slice(2, 95)
tx_features = tx_features_df.iloc[:, local_feature_cols].values

# Shape the original author expected: [203769, 93].
print(f"Transaction features loaded: {tx_features.shape}")
print(f"Feature range: [{tx_features.min():.2f}, {tx_features.max():.2f}]")

In [None]:
# Load labels and splits
labels_df = pd.read_csv('/kaggle/input/elliptic-plus-plus/txs_classes.csv')
with open('/kaggle/input/elliptic-splits/splits.json', encoding='utf-8') as f:
    splits = json.load(f)

# Raw classes: 1 = fraud, 2 = licit, 3 = unknown.
# Binary target: 1 = fraud, 0 = everything else (licit AND unknown).
# NOTE(review): unknown rows therefore carry y == 0; verify that the split
# masks exclude them, otherwise they are trained/evaluated as licit.
labels = labels_df['class'].values
y = (labels == 1).astype(int)

# Get split masks
# NOTE(review): the .sum() calls below only report row counts if the JSON
# stores boolean masks (not index lists) -- confirm.
train_mask = np.array(splits['train'])
val_mask = np.array(splits['val'])
test_mask = np.array(splits['test'])

# Count classes from the raw labels: y == 0 would lump unknown in with licit.
print(f"Labels: Fraud={(labels==1).sum()}, Licit={(labels==2).sum()}, Unknown={(labels==3).sum()}")
print(f"Splits: Train={train_mask.sum()}, Val={val_mask.sum()}, Test={test_mask.sum()}")

## Step 3: Create Fusion Features

In [None]:
# Standardise the tabular block. Mean/scale are estimated on the training
# rows only and then applied to every row, so validation/test statistics
# never leak into the scaler.
scaler = StandardScaler()
scaler.fit(tx_features[train_mask])
tx_features_norm = scaler.transform(tx_features[train_mask])
tx_features_norm_all = scaler.transform(tx_features)

# Sanity check on the training rows.
train_rows_norm = tx_features_norm_all[train_mask]
print(f"Normalized features: {tx_features_norm_all.shape}")
print(f"  Mean: {train_rows_norm.mean():.4f}")
print(f"  Std: {train_rows_norm.std():.4f}")

In [None]:
# Create fusion features: GNN embedding columns first, then the normalised
# tabular columns, concatenated row by row. np.concatenate raises if the two
# blocks do not have the same number of rows.
tx_fusion = np.concatenate([tx_embeddings, tx_features_norm_all], axis=1)

# Report the dimensions actually present in the arrays instead of
# hard-coded numbers, so the log cannot drift from the data.
n_embedding_dims = tx_embeddings.shape[1]
n_tabular_dims = tx_features_norm_all.shape[1]

print(f"Fusion features created: {tx_fusion.shape}")
print(f"  Embeddings: {n_embedding_dims} dims")
print(f"  Tabular: {n_tabular_dims} dims")
print(f"  Total: {tx_fusion.shape[1]} dims")

## Step 4: Train Three XGBoost Models

In [None]:
# Class imbalance weight: ratio of negative to positive training rows,
# passed to XGBoost as scale_pos_weight.
train_labels = y[train_mask]
n_negative = (train_labels == 0).sum()
n_positive = (train_labels == 1).sum()
pos_weight = n_negative / n_positive
print(f"Class weight (pos_weight): {pos_weight:.2f}")

# Shared hyperparameters for all three XGBoost models, so that the only
# difference between them is the feature set.
xgb_params = dict(
    max_depth=6,
    learning_rate=0.1,
    n_estimators=100,
    objective='binary:logistic',
    eval_metric='logloss',
    scale_pos_weight=pos_weight,
    random_state=42,
    tree_method='hist',
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

print(f"XGBoost device: {xgb_params['device']}")

In [None]:
# Model 1: XGBoost on the normalised tabular features only (baseline).
banner = "=" * 60
print("\n" + banner)
print("Training Model 1: Tabular Only (AF1-AF93)")
print(banner)

tab_train = tx_features_norm_all[train_mask]
tab_val = tx_features_norm_all[val_mask]
tab_test = tx_features_norm_all[test_mask]

model_tabular = xgb.XGBClassifier(**xgb_params)
# The validation split is only monitored (logged every 10 rounds).
model_tabular.fit(
    tab_train,
    y[train_mask],
    eval_set=[(tab_val, y[val_mask])],
    verbose=10,
)

# Column 1 of predict_proba is the probability of the positive (fraud) class.
pred_tabular = model_tabular.predict_proba(tab_test)[:, 1]
print(f"\nTabular model trained. Test predictions range: [{pred_tabular.min():.4f}, {pred_tabular.max():.4f}]")

In [None]:
# Model 2: XGBoost on the GNN transaction embeddings only.
banner = "=" * 60
print("\n" + banner)
print("Training Model 2: Embeddings Only (GNN 128-dim)")
print(banner)

emb_train = tx_embeddings[train_mask]
emb_val = tx_embeddings[val_mask]
emb_test = tx_embeddings[test_mask]

model_embeddings = xgb.XGBClassifier(**xgb_params)
# The validation split is only monitored (logged every 10 rounds).
model_embeddings.fit(
    emb_train,
    y[train_mask],
    eval_set=[(emb_val, y[val_mask])],
    verbose=10,
)

# Column 1 of predict_proba is the probability of the positive (fraud) class.
pred_embeddings = model_embeddings.predict_proba(emb_test)[:, 1]
print(f"\nEmbeddings model trained. Test predictions range: [{pred_embeddings.min():.4f}, {pred_embeddings.max():.4f}]")

In [None]:
# Model 3: XGBoost on the concatenated embedding + tabular features.
banner = "=" * 60
print("\n" + banner)
print("Training Model 3: Fusion (Embeddings + Tabular, 221-dim)")
print(banner)

fusion_train = tx_fusion[train_mask]
fusion_val = tx_fusion[val_mask]
fusion_test = tx_fusion[test_mask]

model_fusion = xgb.XGBClassifier(**xgb_params)
# The validation split is only monitored (logged every 10 rounds).
model_fusion.fit(
    fusion_train,
    y[train_mask],
    eval_set=[(fusion_val, y[val_mask])],
    verbose=10,
)

# Column 1 of predict_proba is the probability of the positive (fraud) class.
pred_fusion = model_fusion.predict_proba(fusion_test)[:, 1]
print(f"\nFusion model trained. Test predictions range: [{pred_fusion.min():.4f}, {pred_fusion.max():.4f}]")

## Step 5: Evaluate & Compare

In [None]:
def compute_metrics(y_true, y_pred_proba):
    """Compute PR-AUC, ROC-AUC and F1 for binary classification scores.

    Args:
        y_true: Ground-truth binary labels (1 = positive class).
        y_pred_proba: Predicted scores for the positive class, aligned with
            ``y_true``.

    Returns:
        Dict of plain floats (JSON-serialisable):
            ``pr_auc``    - trapezoidal area under the precision-recall curve,
            ``roc_auc``   - area under the ROC curve,
            ``f1``        - F1 after thresholding at ``threshold``,
            ``threshold`` - score cut-off maximising Youden's J (TPR - FPR).

    Note:
        The threshold is selected on the same data that is passed in, so the
        reported F1 is optimistic when this is called on a test set.
    """
    y_pred_proba = np.asarray(y_pred_proba)

    # PR-AUC
    precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
    pr_auc = auc(recall, precision)

    # ROC-AUC
    roc_auc = roc_auc_score(y_true, y_pred_proba)

    # F1 at the Youden's-J-optimal threshold.
    fpr, tpr, thresholds = roc_curve(y_true, y_pred_proba)
    youden_j = tpr - fpr
    # thresholds[0] is a sentinel prepended by roc_curve that represents
    # "predict nothing positive" (inf in recent scikit-learn). It is not a
    # usable cut-off and would be written as Infinity to JSON, so it is
    # excluded from the search whenever a real threshold exists.
    if len(thresholds) > 1:
        optimal_idx = int(np.argmax(youden_j[1:])) + 1
    else:
        optimal_idx = 0
    optimal_threshold = thresholds[optimal_idx]
    y_pred_binary = (y_pred_proba >= optimal_threshold).astype(int)
    f1 = f1_score(y_true, y_pred_binary)

    return {
        'pr_auc': float(pr_auc),
        'roc_auc': float(roc_auc),
        'f1': float(f1),
        'threshold': float(optimal_threshold)
    }

# Score the three models on the held-out test rows.
y_test = y[test_mask]

results = {
    model_key: compute_metrics(y_test, test_scores)
    for model_key, test_scores in (
        ('tabular_only', pred_tabular),
        ('embeddings_only', pred_embeddings),
        ('fusion', pred_fusion),
    )
}

rule = "=" * 70
print("\n" + rule)
print("E9 WALLET FUSION RESULTS (Transaction-Level Fraud Detection)")
print(rule)

for model_key, model_metrics in results.items():
    heading = model_key.upper().replace('_', ' ')
    print(f"\n{heading}:")
    print(f"  PR-AUC:   {model_metrics['pr_auc']:.4f}")
    print(f"  ROC-AUC:  {model_metrics['roc_auc']:.4f}")
    print(f"  F1:       {model_metrics['f1']:.4f}")

# Relative PR-AUC gain of fusion over each single-source model, in percent.
fusion_pr = results['fusion']['pr_auc']
tabular_pr = results['tabular_only']['pr_auc']
embeddings_pr = results['embeddings_only']['pr_auc']
fusion_vs_tabular = (fusion_pr - tabular_pr) / tabular_pr * 100
fusion_vs_embeddings = (fusion_pr - embeddings_pr) / embeddings_pr * 100

print("\n" + "-" * 70)
print("FUSION IMPROVEMENT:")
print(f"  vs Tabular Only:    {fusion_vs_tabular:+.1f}%")
print(f"  vs Embeddings Only: {fusion_vs_embeddings:+.1f}%")
print(rule)

In [None]:
# Persist the metrics of all three models as pretty-printed JSON.
results_json = json.dumps(results, indent=2)
with open('e9_fusion_results.json', 'w') as f:
    f.write(results_json)

print("Results saved to e9_fusion_results.json")

## Step 6: Visualization

In [None]:
# One bar chart per metric, comparing the three models side by side.
sns.set_style('whitegrid')
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

models = ['Tabular\nOnly', 'Embeddings\nOnly', 'Fusion']
metrics_names = ['PR-AUC', 'ROC-AUC', 'F1']
colors = ['#3498db', '#e74c3c', '#2ecc71']
# Keys into `results`, in the same order as the bar labels in `models`.
result_keys = ('tabular_only', 'embeddings_only', 'fusion')

for ax, metric_key, metric_label in zip(axes, ['pr_auc', 'roc_auc', 'f1'], metrics_names):
    values = [results[model_key][metric_key] for model_key in result_keys]

    bars = ax.bar(models, values, color=colors)
    ax.set_ylabel(metric_label, fontsize=12)
    ax.set_ylim([0, 1])
    ax.set_title(f'{metric_label} Comparison', fontsize=14, fontweight='bold')

    # Print each value just above its bar.
    for bar, val in zip(bars, values):
        label_x = bar.get_x() + bar.get_width()/2.
        label_y = bar.get_height() + 0.02
        ax.text(label_x, label_y, f'{val:.4f}',
                ha='center', va='bottom', fontsize=10)

    # Outline the best-scoring bar in gold.
    winner = bars[np.argmax(values)]
    winner.set_edgecolor('gold')
    winner.set_linewidth(3)

plt.tight_layout()
plt.savefig('e9_fusion_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("Comparison chart saved: e9_fusion_comparison.png")

In [None]:
# Precision-recall (left) and ROC (right) curves for the three models.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
pr_ax, roc_ax = axes

# (test scores, line colour, legend label) for each model, shared by both panels.
curve_specs = [
    (pred_tabular, '#3498db', 'Tabular Only'),
    (pred_embeddings, '#e74c3c', 'Embeddings Only'),
    (pred_fusion, '#2ecc71', 'Fusion'),
]

# PR Curves
for pred, color, label in curve_specs:
    precision, recall, _ = precision_recall_curve(y_test, pred)
    pr_auc = auc(recall, precision)
    pr_ax.plot(recall, precision, color=color, lw=2.5,
               label=f'{label} (PR-AUC={pr_auc:.4f})')

pr_ax.set_xlabel('Recall', fontsize=12)
pr_ax.set_ylabel('Precision', fontsize=12)
pr_ax.set_title('Precision-Recall Curves', fontsize=14, fontweight='bold')
pr_ax.legend(loc='best', fontsize=10)
pr_ax.grid(True, alpha=0.3)

# ROC Curves
for pred, color, label in curve_specs:
    fpr, tpr, _ = roc_curve(y_test, pred)
    roc_auc = auc(fpr, tpr)
    roc_ax.plot(fpr, tpr, color=color, lw=2.5,
                label=f'{label} (ROC-AUC={roc_auc:.4f})')

# Diagonal = chance-level reference.
roc_ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.3)
roc_ax.set_xlabel('False Positive Rate', fontsize=12)
roc_ax.set_ylabel('True Positive Rate', fontsize=12)
roc_ax.set_title('ROC Curves', fontsize=14, fontweight='bold')
roc_ax.legend(loc='best', fontsize=10)
roc_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('e9_fusion_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("PR/ROC curves saved: e9_fusion_curves.png")

## Summary

In [None]:
# Final summary: list the artefacts written by this notebook and state which
# feature set achieved the best test PR-AUC.
print("\n" + "="*70)
print("E9 WALLET FUSION EXPERIMENT COMPLETE")
print("="*70)
print("\nDeliverables created:")
# Report the shapes of the arrays that were actually saved rather than
# hard-coded numbers, so the summary cannot disagree with the files on disk.
tx_shape = list(tx_embeddings.shape)
addr_shape = list(addr_embeddings.shape)
print(f"  ✅ e9_tx_embeddings.npy - Transaction embeddings {tx_shape}")
print(f"  ✅ e9_addr_embeddings.npy - Address embeddings {addr_shape}")
print("  ✅ e9_fusion_results.json - All metrics")
print("  ✅ e9_fusion_comparison.png - Bar chart comparison")
print("  ✅ e9_fusion_curves.png - PR/ROC curves")
print("\nKey Finding:")
# Fusion only "wins" if it is strictly better than both single-source models.
if results['fusion']['pr_auc'] > max(results['tabular_only']['pr_auc'], results['embeddings_only']['pr_auc']):
    print("  ⭐ FUSION WINS: GNN embeddings + tabular features > either alone")
    print(f"  ⭐ Best PR-AUC: {results['fusion']['pr_auc']:.4f}")
else:
    best_model = max(results, key=lambda k: results[k]['pr_auc'])
    # Same display form as the per-model results table (underscores -> spaces).
    best_label = best_model.upper().replace('_', ' ')
    print(f"  ⭐ {best_label} WINS: {results[best_model]['pr_auc']:.4f} PR-AUC")

print("\nNext steps:")
print("  1. Update COMPARISON_REPORT.md with E9 results")
print("  2. Update README.md with fusion findings")
print("  3. Create E9_WALLET_FUSION_DOCUMENTATION.md")
print("  4. Commit to GitHub")
print("="*70)