# 🧪 E7 — HHGTN Ablation Study

**Purpose:** Isolate which components hurt TRD-HHGTN performance in E6

## Research Questions
1. Does heterogeneous architecture hurt vs homogeneous?
2. Which edge types contribute/hurt most?
3. Are address features harmful?

## Experiments
- **A1:** tx→tx only (homogeneous-like)
- **A2:** addr↔tx only (bipartite)
- **A3:** Full E6 (all 4 edge types)
- **A4:** Simplified HHGTN (reduced params)

## Baseline (E3)
- Model: TRD-GraphSAGE
- PR-AUC: 0.5618
- Features: tx only (AF1-93)

## E6 Result (Failed)
- Model: TRD-HHGTN
- PR-AUC: 0.2806 (-50%)
- Features: tx + addr

---

In [None]:
# ==================== SETUP ====================
import sys
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_recall_curve, roc_curve, auc, f1_score
from torch_geometric.nn import SAGEConv, GCNConv
from torch_geometric.data import HeteroData

# Runtime device: use the GPU when CUDA is usable, otherwise fall back to the CPU
cuda_ok = torch.cuda.is_available()
DEVICE = torch.device('cuda' if cuda_ok else 'cpu')

# Environment report for the run log
print(f"Device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")

## 📦 Load HeteroData
Built in E5 (notebook 02)

In [None]:
print("Loading HeteroData...")
# UPDATE THIS PATH for Kaggle dataset upload
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, which
# can execute code embedded in the file. Only load a graph file you produced
# yourself (this one is built in E5 / notebook 02).
# map_location='cpu' lets a graph that was saved from a GPU session load on a
# CPU-only runtime; the explicit .to(DEVICE) below then places it.
data = torch.load(
    '/kaggle/input/elliptic-dataset/hetero_graph.pt',
    map_location='cpu',
    weights_only=False,
)

print("\nHeteroData:")
print(data)

data = data.to(DEVICE)
print(f"\nData moved to: {DEVICE}")

## 🎯 Extract Masks for Transaction Nodes
Only transactions have labels

In [None]:
# Labels and split masks are read from the 'transaction' node store
tx_nodes = data['transaction']
y = tx_nodes.y
train_mask, val_mask, test_mask = (
    tx_nodes.train_mask,
    tx_nodes.val_mask,
    tx_nodes.test_mask,
)

# Size of each split (number of selected transaction nodes)
for split_label, split_mask in (("Train", train_mask), ("Val", val_mask), ("Test", test_mask)):
    print(f"{split_label} nodes: {split_mask.sum().item()}")

# Mean label over the training nodes (assumes y is encoded 0/1 — TODO confirm)
train_labels = y[train_mask].float()
print(f"Fraud rate: {train_labels.mean():.4f}")

## 🔧 Define Simplified HHGTN Model
Reduced params version for A4

In [None]:
class SimpleHHGTN(nn.Module):
    """Simplified heterogeneous GNN for ablation.

    Every node type is projected to ``hidden_dim``. Then ``num_layers`` rounds
    of per-relation ``SAGEConv`` message passing are applied; when a node type
    receives messages over several relations they are fused with a single-head
    attention. Only ``'transaction'`` nodes are classified.

    Args:
        metadata: ``(node_types, edge_types)`` pair, e.g. ``HeteroData.metadata()``.
        hidden_dim: Width of all hidden representations.
        num_layers: Number of message-passing rounds (>= 1). Each round has its
            own set of relation convolutions.
        dropout: Dropout probability used between rounds and before the classifier.
        in_dims: Optional mapping ``node_type -> input feature width``. When
            omitted, the widths are read from the module-level ``data`` object
            (the original behaviour of this class).

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    def __init__(self, metadata, hidden_dim=64, num_layers=2, dropout=0.3, in_dims=None):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.node_types = metadata[0]
        self.edge_types = metadata[1]
        self.num_layers = num_layers

        if in_dims is None:
            # Legacy fallback: infer input widths from the notebook-global graph
            in_dims = {
                node_type: data[node_type].x.size(1)
                for node_type in self.node_types
            }

        # Input projections (smaller)
        self.input_projs = nn.ModuleDict({
            node_type: nn.Linear(in_dims[node_type], hidden_dim)
            for node_type in self.node_types
        })

        # Per-relation convs (SAGE only, lighter), one ModuleDict per round so
        # that num_layers is actually honoured
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            layer_convs = nn.ModuleDict()
            for src, rel, dst in self.edge_types:
                layer_convs[f'{src}__{rel}__{dst}'] = SAGEConv(hidden_dim, hidden_dim)
            self.convs.append(layer_convs)

        # Attention fusion (single head, shared across rounds)
        self.attn = nn.Linear(hidden_dim, 1)

        # Output
        self.classifier = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout)

    def _propagate(self, layer_convs, x_dict, edge_index_dict):
        """Run one round of relation-wise message passing and fuse per node type."""
        messages = {k: [] for k in x_dict}

        for edge_type, edge_index in edge_index_dict.items():
            src_type, rel, dst_type = edge_type
            key = f'{src_type}__{rel}__{dst_type}'

            # Relations the model was not built for are skipped
            if key in layer_convs:
                out = layer_convs[key]((x_dict[src_type], x_dict[dst_type]), edge_index)
                messages[dst_type].append(out)

        fused = {}
        for node_type, outs in messages.items():
            if not outs:
                # No incoming relation: carry the representation through unchanged
                fused[node_type] = x_dict[node_type]
            elif len(outs) == 1:
                fused[node_type] = outs[0]
            else:
                # Simple attention fusion over the relation axis
                stacked = torch.stack(outs, dim=1)  # [N, R, D]
                attn_scores = F.softmax(self.attn(stacked).squeeze(-1), dim=1)  # [N, R]
                fused[node_type] = (stacked * attn_scores.unsqueeze(-1)).sum(dim=1)
        return fused

    def forward(self, x_dict, edge_index_dict, edge_types_to_use=None):
        """Return one logit per transaction node.

        Args:
            x_dict: Mapping ``node_type -> feature tensor``.
            edge_index_dict: Mapping ``(src, rel, dst) -> edge_index``.
            edge_types_to_use: Optional iterable of ``(src, rel, dst)`` tuples.
                ``None`` keeps every relation; any other value (including an
                empty list) keeps only the listed relations.
        """
        # Project inputs
        x_dict = {k: F.relu(self.input_projs[k](v)) for k, v in x_dict.items()}

        # Filter edge types if specified. `is not None` so that an explicit
        # empty selection means "no edges" rather than "all edges".
        if edge_types_to_use is not None:
            allowed = {tuple(et) for et in edge_types_to_use}
            edge_index_dict = {k: v for k, v in edge_index_dict.items() if k in allowed}

        for layer_idx, layer_convs in enumerate(self.convs):
            if layer_idx > 0:
                # Non-linearity + dropout between consecutive rounds
                x_dict = {k: self.dropout(F.relu(v)) for k, v in x_dict.items()}
            x_dict = self._propagate(layer_convs, x_dict, edge_index_dict)

        # Classify transactions only
        x_tx = self.dropout(F.relu(x_dict['transaction']))
        return self.classifier(x_tx).squeeze(-1)

## 📊 Evaluation Function

In [None]:
def evaluate_model(model, data, mask, device, edge_types=None):
    """Score ``model`` on the transaction nodes selected by ``mask``.

    Args:
        model: Module called as ``model(x_dict, edge_index_dict)`` that returns
            one logit per transaction node.
        data: Graph object exposing ``x_dict``, ``edge_index_dict`` and
            ``data['transaction'].y``.
        mask: Selector applied to both the logits and the labels
            (presumably a boolean node mask — verify against caller).
        device: Device the feature and edge tensors are moved to before the
            forward pass.
        edge_types: Optional iterable of ``(src, rel, dst)`` edge types. When
            given, only these relations are passed to the model, so a model
            trained on an ablated graph is evaluated on that same graph.
            ``None`` (default) keeps every relation, as before.

    Returns:
        dict with ``pr_auc`` (area under the precision-recall curve),
        ``roc_auc``, ``f1`` (at a 0.5 probability threshold), and the raw
        ``probs`` / ``labels`` arrays.
    """
    allowed = None if edge_types is None else {tuple(et) for et in edge_types}

    model.eval()
    with torch.no_grad():
        x_dict = {k: v.to(device) for k, v in data.x_dict.items()}
        edge_index_dict = {
            k: v.to(device)
            for k, v in data.edge_index_dict.items()
            if allowed is None or k in allowed
        }

        logits = model(x_dict, edge_index_dict)
        probs = torch.sigmoid(logits[mask]).cpu().numpy()
        labels = data['transaction'].y[mask].cpu().numpy()

    # Metrics (computed on NumPy arrays, so no autograd context is needed)
    precision, recall, _ = precision_recall_curve(labels, probs)
    pr_auc = auc(recall, precision)

    fpr, tpr, _ = roc_curve(labels, probs)
    roc_auc = auc(fpr, tpr)

    preds = (probs > 0.5).astype(int)
    f1 = f1_score(labels, preds, zero_division=0)

    return {
        'pr_auc': pr_auc,
        'roc_auc': roc_auc,
        'f1': f1,
        'probs': probs,
        'labels': labels
    }

## 🧪 A1: tx→tx Only (Homogeneous-like)
Tests if heterogeneous architecture itself hurts

In [None]:
print("="*60)
print("A1: tx‚Üítx ONLY")
print("="*60)

# Filter to tx edges only
edge_types_a1 = [('transaction', 'tx_to_tx', 'transaction')]

# Fail fast on a misspelled relation: filtering by an unknown edge type would
# silently train on a graph with no edges at all.
missing_a1 = [et for et in edge_types_a1 if et not in data.edge_types]
if missing_a1:
    raise ValueError(f"Edge types not found in the graph: {missing_a1}")

# View of the graph restricted to the ablated relations. It is used for BOTH
# training and evaluation: evaluating on the full graph would send messages
# through relation convs that are never trained in this ablation.
data_a1 = data.edge_type_subgraph(edge_types_a1)

# Built from the restricted metadata so the reported parameter count only
# covers modules that this ablation actually uses.
model_a1 = SimpleHHGTN(data_a1.metadata(), hidden_dim=64, num_layers=2, dropout=0.3).to(DEVICE)
optimizer = torch.optim.Adam(model_a1.parameters(), lr=0.001, weight_decay=1e-5)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]).to(DEVICE))

print(f"\nModel params: {sum(p.numel() for p in model_a1.parameters()):,}")
print(f"Edge types: {edge_types_a1}")

# Loop-invariant inputs, built once instead of on every epoch
x_dict = {k: v.to(DEVICE) for k, v in data_a1.x_dict.items()}
edge_index_dict = {k: v.to(DEVICE) for k, v in data_a1.edge_index_dict.items()}
train_targets = data['transaction'].y[train_mask].float()

# Training
best_val_pr = 0
eval_every = 5        # validate every N epochs
patience = 20         # epochs without a validation improvement before stopping
patience_counter = 0  # epochs elapsed since the last improvement

for epoch in range(100):
    model_a1.train()
    optimizer.zero_grad()

    logits = model_a1(x_dict, edge_index_dict)
    loss = criterion(logits[train_mask], train_targets)
    loss.backward()
    optimizer.step()

    # Validate
    if epoch % eval_every == 0:
        val_results = evaluate_model(model_a1, data_a1, val_mask, DEVICE)
        print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f} | Val PR-AUC: {val_results['pr_auc']:.4f}")

        if val_results['pr_auc'] > best_val_pr:
            best_val_pr = val_results['pr_auc']
            patience_counter = 0
            torch.save(model_a1.state_dict(), 'a1_best.pt')
        else:
            # Count in epochs: counting validations against an epoch-sized
            # patience could never trigger within the 100-epoch budget.
            patience_counter += eval_every
            if patience_counter >= patience:
                print("Early stopping")
                break

# Load best and evaluate (on the same restricted graph used for training)
model_a1.load_state_dict(torch.load('a1_best.pt'))
test_results_a1 = evaluate_model(model_a1, data_a1, test_mask, DEVICE)

print(f"\n‚úÖ A1 Results:")
print(f"   Test PR-AUC: {test_results_a1['pr_auc']:.4f}")
print(f"   Test ROC-AUC: {test_results_a1['roc_auc']:.4f}")
print(f"   Test F1: {test_results_a1['f1']:.4f}")

## 🧪 A2: addr↔tx Only (Bipartite)
Tests if address features help/hurt

In [None]:
print("\n" + "="*60)
print("A2: addr‚Üîtx ONLY")
print("="*60)

edge_types_a2 = [
    ('address', 'addr_to_tx', 'transaction'),
    ('transaction', 'tx_to_addr', 'address')
]

# Fail fast on a misspelled relation: filtering by an unknown edge type would
# silently drop it and train on fewer edges than intended.
missing_a2 = [et for et in edge_types_a2 if et not in data.edge_types]
if missing_a2:
    raise ValueError(f"Edge types not found in the graph: {missing_a2}")

# View of the graph restricted to the ablated relations. It is used for BOTH
# training and evaluation: evaluating on the full graph would send messages
# through relation convs that are never trained in this ablation.
data_a2 = data.edge_type_subgraph(edge_types_a2)

model_a2 = SimpleHHGTN(data_a2.metadata(), hidden_dim=64, num_layers=2, dropout=0.3).to(DEVICE)
optimizer = torch.optim.Adam(model_a2.parameters(), lr=0.001, weight_decay=1e-5)
# Defined here (same settings as A1) so this cell does not depend on A1 having run
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]).to(DEVICE))

print(f"Edge types: {edge_types_a2}")

# Loop-invariant inputs, built once instead of on every epoch
x_dict = {k: v.to(DEVICE) for k, v in data_a2.x_dict.items()}
edge_index_dict = {k: v.to(DEVICE) for k, v in data_a2.edge_index_dict.items()}
train_targets = data['transaction'].y[train_mask].float()

best_val_pr = 0
eval_every = 5        # validate every N epochs
patience = 20         # epochs without a validation improvement before stopping
patience_counter = 0  # epochs elapsed since the last improvement

for epoch in range(100):
    model_a2.train()
    optimizer.zero_grad()

    logits = model_a2(x_dict, edge_index_dict)
    loss = criterion(logits[train_mask], train_targets)
    loss.backward()
    optimizer.step()

    if epoch % eval_every == 0:
        val_results = evaluate_model(model_a2, data_a2, val_mask, DEVICE)
        print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f} | Val PR-AUC: {val_results['pr_auc']:.4f}")

        if val_results['pr_auc'] > best_val_pr:
            best_val_pr = val_results['pr_auc']
            patience_counter = 0
            torch.save(model_a2.state_dict(), 'a2_best.pt')
        else:
            # Count in epochs: counting validations against an epoch-sized
            # patience could never trigger within the 100-epoch budget.
            patience_counter += eval_every
            if patience_counter >= patience:
                print("Early stopping")
                break

# Restore the best checkpoint and test on the same restricted graph
model_a2.load_state_dict(torch.load('a2_best.pt'))
test_results_a2 = evaluate_model(model_a2, data_a2, test_mask, DEVICE)

print(f"\n‚úÖ A2 Results:")
print(f"   Test PR-AUC: {test_results_a2['pr_auc']:.4f}")
print(f"   Test ROC-AUC: {test_results_a2['roc_auc']:.4f}")
print(f"   Test F1: {test_results_a2['f1']:.4f}")

## 🧪 A3: Full E6 (All Edge Types)
Reproduce E6 with simplified architecture

In [None]:
print("\n" + "="*60)
print("A3: FULL E6 (All 4 edge types)")
print("="*60)

# All edge types
edge_types_a3 = list(data.edge_index_dict.keys())

model_a3 = SimpleHHGTN(data.metadata(), hidden_dim=64, num_layers=2, dropout=0.3).to(DEVICE)
optimizer = torch.optim.Adam(model_a3.parameters(), lr=0.001, weight_decay=1e-5)
# Defined here (same settings as A1) so this cell does not depend on A1 having run
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]).to(DEVICE))

print(f"Edge types: {edge_types_a3}")

# Loop-invariant inputs, built once instead of on every epoch
x_dict = {k: v.to(DEVICE) for k, v in data.x_dict.items()}
edge_index_dict = {k: v.to(DEVICE) for k, v in data.edge_index_dict.items()}
train_targets = data['transaction'].y[train_mask].float()

best_val_pr = 0
eval_every = 5        # validate every N epochs
patience = 20         # epochs without a validation improvement before stopping
patience_counter = 0  # epochs elapsed since the last improvement

for epoch in range(100):
    model_a3.train()
    optimizer.zero_grad()

    logits = model_a3(x_dict, edge_index_dict)
    loss = criterion(logits[train_mask], train_targets)
    loss.backward()
    optimizer.step()

    if epoch % eval_every == 0:
        val_results = evaluate_model(model_a3, data, val_mask, DEVICE)
        print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f} | Val PR-AUC: {val_results['pr_auc']:.4f}")

        if val_results['pr_auc'] > best_val_pr:
            best_val_pr = val_results['pr_auc']
            patience_counter = 0
            torch.save(model_a3.state_dict(), 'a3_best.pt')
        else:
            # Count in epochs: counting validations against an epoch-sized
            # patience could never trigger within the 100-epoch budget.
            patience_counter += eval_every
            if patience_counter >= patience:
                print("Early stopping")
                break

# Restore the best checkpoint and test on the full graph
model_a3.load_state_dict(torch.load('a3_best.pt'))
test_results_a3 = evaluate_model(model_a3, data, test_mask, DEVICE)

print(f"\n‚úÖ A3 Results:")
print(f"   Test PR-AUC: {test_results_a3['pr_auc']:.4f}")
print(f"   Test ROC-AUC: {test_results_a3['roc_auc']:.4f}")
print(f"   Test F1: {test_results_a3['f1']:.4f}")

## 📊 Ablation Summary Table

In [None]:
# One tuple per experiment, in display order. The E3 and E6 rows are fixed
# reference numbers; A1-A3 come from the test evaluations run in this notebook.
summary_columns = ['Experiment', 'Model', 'Edge Types', 'Features', 'PR-AUC', 'ROC-AUC', 'F1']
summary_rows = [
    ('Baseline (E3)', 'TRD-GraphSAGE', 'tx‚Üítx', 'tx only', 0.5618, 0.8841, 0.6050),
    ('A1', 'Simple-HHGTN', 'tx‚Üítx', 'tx only',
     test_results_a1['pr_auc'], test_results_a1['roc_auc'], test_results_a1['f1']),
    ('A2', 'Simple-HHGTN', 'addr‚Üîtx', 'tx + addr',
     test_results_a2['pr_auc'], test_results_a2['roc_auc'], test_results_a2['f1']),
    ('A3', 'Simple-HHGTN', 'all 4', 'tx + addr',
     test_results_a3['pr_auc'], test_results_a3['roc_auc'], test_results_a3['f1']),
    ('E6 (Original)', 'TRD-HHGTN (full)', 'all 4', 'tx + addr', 0.2806, 0.8250, 0.3913),
]
ablation_results = pd.DataFrame(summary_rows, columns=summary_columns)

wide_rule = "=" * 80

print("\n" + wide_rule)
print("üìä ABLATION SUMMARY")
print(wide_rule)
print(ablation_results.to_string(index=False))

# Calculate deltas from baseline (added after the first printout so the
# summary table above is shown without the delta column)
baseline_pr = 0.5618
ablation_results['ŒîPR-AUC'] = ablation_results['PR-AUC'].sub(baseline_pr)

delta_view = ablation_results[['Experiment', 'Edge Types', 'PR-AUC', 'ŒîPR-AUC']]
print("\n" + wide_rule)
print("üìà DELTA FROM BASELINE (E3)")
print(wide_rule)
print(delta_view.to_string(index=False))

# Save
ablation_results.to_csv('ablation_results.csv', index=False)
print("\n‚úÖ Saved: ablation_results.csv")

## 📊 Visualization: Ablation Comparison

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

experiment_names = ablation_results['Experiment']
pr_values = ablation_results['PR-AUC']
delta_values = ablation_results['ŒîPR-AUC']

# Left panel: absolute PR-AUC per experiment, with the E3 baseline drawn as a
# dashed reference line. One colour per row of the summary table.
colors = ['#2ecc71', '#3498db', '#e74c3c', '#f39c12', '#9b59b6']
bars = ax1.barh(experiment_names, pr_values, color=colors)
ax1.axvline(baseline_pr, color='green', linestyle='--', linewidth=2, label='E3 Baseline')
ax1.set_title('PR-AUC Comparison Across Ablations', fontsize=14, fontweight='bold')
ax1.set_xlabel('PR-AUC', fontsize=12)
ax1.grid(axis='x', alpha=0.3)
ax1.legend()

# Right panel: difference to the baseline; green when not below it, red otherwise
delta_colors = []
for delta in delta_values:
    delta_colors.append('green' if delta >= 0 else 'red')
bars2 = ax2.barh(experiment_names, delta_values, color=delta_colors)
ax2.axvline(0, color='black', linestyle='-', linewidth=1)
ax2.set_title('Performance Delta from Baseline', fontsize=14, fontweight='bold')
ax2.set_xlabel('ŒîPR-AUC from E3', fontsize=12)
ax2.grid(axis='x', alpha=0.3)

# Write the figure to disk before showing it
plt.tight_layout()
plt.savefig('ablation_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Saved: ablation_comparison.png")

## 🔍 Analysis & Insights

In [None]:
section_rule = "=" * 80

# Test PR-AUC of each ablation, pulled out once for the comparisons below
pr_a1 = test_results_a1['pr_auc']
pr_a2 = test_results_a2['pr_auc']
pr_a3 = test_results_a3['pr_auc']

print("\n" + section_rule)
print("üîç KEY INSIGHTS")
print(section_rule)

# 1) A1 against the E3 reference value; |delta| below 0.05 is reported as minimal
a1_vs_e3 = pr_a1 - 0.5618
print("\n1. Architecture Impact (A1 vs E3):")
print(f"   - A1 (HHGTN on tx‚Üítx): {pr_a1:.4f}")
print("   - E3 (GraphSAGE on tx‚Üítx): 0.5618")
print(f"   - Delta: {a1_vs_e3:+.4f}")
print(
    "   ‚Üí Architecture change has MINIMAL impact (<5pp)"
    if abs(a1_vs_e3) < 0.05
    else "   ‚Üí Architecture matters significantly"
)

# 2) A2 against A1, with a +/-0.05 band around zero reported as minimal
a2_vs_a1 = pr_a2 - pr_a1
print("\n2. Address Feature Impact (A2 vs A1):")
print(f"   - A2 (with addr‚Üîtx): {pr_a2:.4f}")
print(f"   - A1 (tx only): {pr_a1:.4f}")
print(f"   - Delta: {a2_vs_a1:+.4f}")
if a2_vs_a1 < -0.05:
    address_verdict = "   ‚Üí Address features HURT performance (>5pp drop)"
elif a2_vs_a1 > 0.05:
    address_verdict = "   ‚Üí Address features HELP performance (>5pp gain)"
else:
    address_verdict = "   ‚Üí Address features have MINIMAL impact"
print(address_verdict)

# 3) A3 against A2; only the sign of the delta decides the message
a3_vs_a2 = pr_a3 - pr_a2
print("\n3. Additional Edge Types (A3 vs A2):")
print(f"   - A3 (all 4 edge types): {pr_a3:.4f}")
print(f"   - A2 (addr‚Üîtx only): {pr_a2:.4f}")
print(f"   - Delta: {a3_vs_a2:+.4f}")
print(
    "   ‚Üí Adding tx‚Üítx + addr‚Üíaddr edges HURTS"
    if a3_vs_a2 < 0
    else "   ‚Üí Adding more edges helps"
)

# 4) A3 against the E6 reference value; a message is printed only for a gain above 0.05
a3_vs_e6 = pr_a3 - 0.2806
print("\n4. Simplified vs Full Architecture (A3 vs E6):")
print(f"   - A3 (simplified): {pr_a3:.4f}")
print("   - E6 (full HHGTN): 0.2806")
print(f"   - Delta: {a3_vs_e6:+.4f}")
if a3_vs_e6 > 0.05:
    for overfit_note in (
        "   ‚Üí Overparameterization in E6 caused overfitting",
        "   ‚Üí Simpler model generalizes better",
    ):
        print(overfit_note)

print("\n" + section_rule)
print("üéØ CONCLUSION")
print(section_rule)

# Row of the summary table with the highest PR-AUC
best_row_label = ablation_results['PR-AUC'].idxmax()
best_ablation = ablation_results.loc[best_row_label]
best_name = best_ablation['Experiment']
print(f"\nBest performer: {best_name}")
print(f"  - PR-AUC: {best_ablation['PR-AUC']:.4f}")
print(f"  - Configuration: {best_ablation['Edge Types']} with {best_ablation['Features']}")

if best_name == 'Baseline (E3)':
    print("\n‚Üí E3 (TRD-GraphSAGE) remains the champion")
    print("‚Üí Heterogeneous extensions did not improve performance")
    print("‚Üí Simpler homogeneous model is superior for this task")

## 📝 Export Summary for Reporting

In [None]:
# Metrics copied into the report, cast to built-in float before JSON export
metric_names = ('pr_auc', 'roc_auc', 'f1')
runs_by_tag = {
    'A1_tx_only': test_results_a1,
    'A2_addr_tx': test_results_a2,
    'A3_all_edges': test_results_a3,
}

# Headline conclusion depends on whether the E3 baseline row won the comparison
if best_ablation['Experiment'] == 'Baseline (E3)':
    conclusion_text = 'E3 remains champion'
else:
    conclusion_text = 'New best found'

summary = {
    'experiment': 'E7_Ablation_Study',
    'date': pd.Timestamp.now().isoformat(),
    'results': {
        tag: {metric: float(run[metric]) for metric in metric_names}
        for tag, run in runs_by_tag.items()
    },
    'insights': {
        'architecture_impact': float(a1_vs_e3),
        'address_feature_impact': float(a2_vs_a1),
        'additional_edges_impact': float(a3_vs_a2),
        'simplification_benefit': float(a3_vs_e6),
    },
    'conclusion': conclusion_text,
}

with open('e7_ablation_summary.json', 'w') as summary_file:
    summary_file.write(json.dumps(summary, indent=2))

closing_rule = "=" * 80
print("\n‚úÖ Saved: e7_ablation_summary.json")
print("\n" + closing_rule)
print("üéâ E7 ABLATION STUDY COMPLETE")
print(closing_rule)