# 🧪 E7 — HHGTN Ablation Study

**Purpose:** Isolate which components hurt TRD-HHGTN performance in E6

## Research Questions
1. Does heterogeneous architecture hurt vs homogeneous?
2. Which edge types contribute/hurt most?
3. Are address features harmful?

## Experiments
- **A1:** tx→tx only (homogeneous-like)
- **A2:** addr↔tx only (bipartite)
- **A3:** Full E6 (all 4 edge types)
- **A4:** Simplified HHGTN (reduced params)

## Baseline (E3)
- Model: TRD-GraphSAGE
- PR-AUC: 0.5582
- Features: tx only (AF1-93)

## E6 Result (Failed)
- Model: TRD-HHGTN
- PR-AUC: 0.2806 (-50%)
- Features: tx + addr

---

In [None]:
# ==================== INSTALL DEPENDENCIES ====================
# Notebook shell escape (`!`): installs the packages imported by the setup cell below.
# `-q` keeps pip's output quiet.
!pip install -q torch torch-geometric pandas numpy scikit-learn matplotlib seaborn tqdm

In [None]:
# ==================== SETUP ====================
import sys
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_recall_curve, roc_curve, auc, f1_score
from torch_geometric.nn import SAGEConv, HeteroConv
from torch_geometric.data import HeteroData
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device: CUDA when PyTorch reports it as available, otherwise CPU.
_cuda_ok = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _cuda_ok else 'cpu')

# Log the runtime environment for the record.
print(f"Device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")

## 📦 Load HeteroData
Built in E5 (notebook 02)

In [None]:
print("Loading HeteroData...")
# Graph file uploaded as a Kaggle dataset.
# NOTE(review): weights_only=False makes torch.load run the full unpickler, which can
# execute arbitrary code -- only load files that come from a trusted source.
data = torch.load('/kaggle/input/elliptic-dataset/hetero_graph.pt', weights_only=False)

print("\nHeteroData:")
print(data)

# Put the whole graph on the selected device.
data = data.to(DEVICE)
print(f"\nData moved to: {DEVICE}")

# Expose the transaction split masks and labels as notebook-level names.
_tx_store = data['transaction']
train_mask, val_mask, test_mask = _tx_store.train_mask, _tx_store.val_mask, _tx_store.test_mask
y = _tx_store.y

print("\nSplits:")
for _split_name, _split_mask in (("Train", train_mask), ("Val", val_mask), ("Test", test_mask)):
    print(f"  {_split_name}: {_split_mask.sum():,}")

# Share of training nodes whose label equals 1.
_n_train_pos = (y[train_mask] == 1).sum().item()
_n_train = train_mask.sum().item()
print(f"  Fraud rate: {_n_train_pos / _n_train:.2%}")

## 🧠 Simplified HHGTN Model

In [None]:
class SimplifiedHHGTN(nn.Module):
    """Simplified heterogeneous GNN for ablation studies.

    Projects transaction and address features to a common hidden size, applies
    two ``HeteroConv`` layers (one ``SAGEConv`` per enabled edge type, summed per
    destination node type) and scores every transaction node with a small MLP.

    Args:
        tx_in_dim: Input feature size of ``'transaction'`` nodes.
        addr_in_dim: Input feature size of ``'address'`` nodes.
        hidden_dim: Width of the input projections and of both conv layers.
        edge_types_to_use: Collection of ``(src, rel, dst)`` edge types to enable.
            Only the four types listed in ``_SUPPORTED_EDGE_TYPES`` get a
            convolution; any other entry is ignored.
        dropout: Dropout probability applied after each conv layer and inside
            the classifier.
    """

    # Edge types this model knows how to convolve over (order fixes module registration order).
    _SUPPORTED_EDGE_TYPES = (
        ('transaction', 'to', 'transaction'),
        ('address', 'to', 'transaction'),
        ('transaction', 'to', 'address'),
        ('address', 'to', 'address'),
    )

    def __init__(self, tx_in_dim, addr_in_dim, hidden_dim, edge_types_to_use, dropout=0.3):
        super().__init__()

        self.edge_types_to_use = edge_types_to_use
        self.dropout = dropout

        # Input projections
        self.tx_proj = nn.Linear(tx_in_dim, hidden_dim)
        self.addr_proj = nn.Linear(addr_in_dim, hidden_dim)

        # Each layer gets its OWN SAGEConv instances. Handing one shared dict to both
        # HeteroConv layers would register the same modules twice and silently tie
        # the weights of layer 1 and layer 2.
        self.conv1 = HeteroConv(self._build_convs(hidden_dim), aggr='sum')
        self.conv2 = HeteroConv(self._build_convs(hidden_dim), aggr='sum')

        # Classifier for transactions
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def _build_convs(self, hidden_dim):
        """Return a fresh ``{edge_type: SAGEConv}`` dict for the enabled edge types."""
        return {
            edge_type: SAGEConv(hidden_dim, hidden_dim)
            for edge_type in self._SUPPORTED_EDGE_TYPES
            if edge_type in self.edge_types_to_use
        }

    def _hetero_layer(self, conv, x_dict, edge_index_dict):
        """Apply one HeteroConv layer followed by ReLU and dropout."""
        out_dict = conv(x_dict, edge_index_dict)
        out_dict = {
            key: F.dropout(F.relu(x), p=self.dropout, training=self.training)
            for key, x in out_dict.items()
        }
        # HeteroConv only returns node types that are the destination of an evaluated
        # edge type. Carry the previous features forward for the rest so a node type
        # does not disappear between layers (e.g. 'address' when only addr->tx is on).
        return {**x_dict, **out_dict}

    def forward(self, x_dict, edge_index_dict):
        """Return one logit per transaction node.

        Args:
            x_dict: Mapping with ``'transaction'`` and ``'address'`` feature tensors.
            edge_index_dict: Mapping from edge type to edge index; entries whose
                type is not in ``edge_types_to_use`` are dropped before message
                passing.

        Returns:
            Tensor of logits with the trailing size-1 classifier dimension removed.
        """
        # Project inputs
        h_dict = {
            'transaction': F.relu(self.tx_proj(x_dict['transaction'])),
            'address': F.relu(self.addr_proj(x_dict['address'])),
        }

        # Filter edge_index_dict to only use specified edge types
        filtered_edges = {k: v for k, v in edge_index_dict.items() if k in self.edge_types_to_use}

        h_dict = self._hetero_layer(self.conv1, h_dict, filtered_edges)
        h_dict = self._hetero_layer(self.conv2, h_dict, filtered_edges)

        # Classify transactions
        logits = self.classifier(h_dict['transaction'])

        # squeeze(-1) rather than squeeze(): keeps a 1-D result even for a single node.
        return logits.squeeze(-1)

## 📊 Evaluation Function

In [None]:
def evaluate_model(model, data, mask):
    """Evaluate model on given mask.

    Runs one full-graph forward pass in eval mode with gradients disabled and
    scores the sigmoid outputs of the transaction nodes selected by ``mask``.
    The model is left in eval mode afterwards.

    Args:
        model: Module called as ``model(x_dict, edge_index_dict)``; its output is
            indexed with ``mask`` alongside the transaction labels.
        data: Graph object providing ``data['transaction'].x`` / ``.y``,
            ``data['address'].x``, ``data.edge_types`` and
            ``data[edge_type].edge_index``.
        mask: Mask selecting the transaction nodes to score.

    Returns:
        Dict of floats:
        ``'pr_auc'`` (average precision), ``'roc_auc'`` (NaN when the selected
        labels contain fewer than two distinct classes) and ``'best_f1'`` (the
        maximum F1 over all thresholds of the precision-recall curve computed on
        this same split).
    """
    from sklearn.metrics import average_precision_score, roc_auc_score

    model.eval()

    with torch.no_grad():
        x_dict = {
            'transaction': data['transaction'].x,
            'address': data['address'].x
        }
        edge_index_dict = {
            k: data[k].edge_index for k in data.edge_types
        }
        probs = torch.sigmoid(model(x_dict, edge_index_dict))

    # Filter to mask
    y_true = data['transaction'].y[mask].cpu().numpy()
    y_pred = probs[mask].cpu().numpy()

    pr_auc = average_precision_score(y_true, y_pred)

    # ROC-AUC is undefined when only one class is present; report NaN instead of
    # letting a degenerate split abort the whole run.
    if np.unique(y_true).size < 2:
        roc_auc = float('nan')
    else:
        roc_auc = roc_auc_score(y_true, y_pred)

    # F1 at best threshold (epsilon avoids 0/0 where precision == recall == 0)
    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    f1_scores = 2 * precision * recall / (precision + recall + 1e-10)

    return {
        'pr_auc': float(pr_auc),
        'roc_auc': float(roc_auc),
        'best_f1': float(np.max(f1_scores))
    }

## 🎯 Training Function

In [None]:
def train_ablation_model(model, data, train_mask, val_mask, test_mask, max_epochs=100, patience=15, lr=0.001,
                         pos_weight=10.0, weight_decay=1e-5):
    """Train model with early stopping on validation PR-AUC.

    Full-batch training: every epoch does one forward/backward pass over the
    whole graph with the loss restricted to ``train_mask``, then evaluates on
    ``val_mask``. The weights of the best validation epoch are restored before
    the final train/val/test evaluation.

    Args:
        model: Module called as ``model(x_dict, edge_index_dict)``.
        data: Graph object providing ``data['transaction'].x`` / ``.y``,
            ``data['address'].x``, ``data.edge_types`` and
            ``data[edge_type].edge_index``.
        train_mask, val_mask, test_mask: Masks over transaction nodes.
        max_epochs: Upper bound on the number of epochs.
        patience: Stop after this many consecutive epochs without a validation
            PR-AUC improvement.
        lr: Adam learning rate.
        pos_weight: Weight of the positive class in ``BCEWithLogitsLoss``
            (compensates for class imbalance).
        weight_decay: Adam weight decay.

    Returns:
        Dict with ``'train'``, ``'val'`` and ``'test'`` metric dicts (from
        ``evaluate_model``), ``'best_epoch'`` (0-based; stays 0 if no epoch ever
        improved) and ``'history'`` (per-epoch ``'train_loss'`` and
        ``'val_pr_auc'`` lists).
    """

    def snapshot_state():
        # CPU copy so the checkpoint does not occupy accelerator memory.
        return {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # Inputs and targets do not change between epochs, so build them once.
    x_dict = {
        'transaction': data['transaction'].x,
        'address': data['address'].x
    }
    edge_index_dict = {
        k: data[k].edge_index for k in data.edge_types
    }
    train_targets = data['transaction'].y[train_mask].float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Place pos_weight next to the targets instead of relying on a global DEVICE.
    criterion = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([float(pos_weight)], device=train_targets.device)
    )

    # -inf so the first epoch with a valid score always becomes the checkpoint;
    # the initial snapshot is the fallback when no epoch qualifies at all
    # (max_epochs == 0 or a NaN validation score), so best_state is never unbound.
    best_val_pr_auc = float('-inf')
    best_state = snapshot_state()
    patience_counter = 0
    best_epoch = 0

    history = {'train_loss': [], 'val_pr_auc': []}

    for epoch in range(max_epochs):
        model.train()
        optimizer.zero_grad()

        # Forward
        logits = model(x_dict, edge_index_dict)

        # Loss (only on labeled train nodes)
        loss = criterion(logits[train_mask], train_targets)

        # Backward
        loss.backward()
        optimizer.step()
        loss_value = loss.item()

        # Evaluate on val (switches the model to eval mode; train() is re-set next epoch)
        val_pr_auc = evaluate_model(model, data, val_mask)['pr_auc']

        history['train_loss'].append(loss_value)
        history['val_pr_auc'].append(val_pr_auc)

        # Early stopping
        if val_pr_auc > best_val_pr_auc:
            best_val_pr_auc = val_pr_auc
            patience_counter = 0
            best_epoch = epoch
            best_state = snapshot_state()
        else:
            patience_counter += 1

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {loss_value:.4f} | Val PR-AUC: {val_pr_auc:.4f} | Best: {best_val_pr_auc:.4f}")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Load best model
    model.load_state_dict(best_state)

    # Final evaluation
    train_metrics = evaluate_model(model, data, train_mask)
    val_metrics = evaluate_model(model, data, val_mask)
    test_metrics = evaluate_model(model, data, test_mask)

    return {
        'train': train_metrics,
        'val': val_metrics,
        'test': test_metrics,
        'best_epoch': best_epoch,
        'history': history
    }

## 🧪 Ablation Experiments

### A1: tx→tx only (Homogeneous-like)

In [None]:
print("\n" + "="*70)
print("A1: tx‚Üítx ONLY (Homogeneous-like)")
print("="*70)

# Only the transaction->transaction relation is enabled for this arm.
edge_types_a1 = [('transaction', 'to', 'transaction')]

model_a1 = SimplifiedHHGTN(
    # Read the feature widths from the loaded graph instead of hard-coding them,
    # so the input projections always match the data.
    tx_in_dim=data['transaction'].x.shape[-1],
    addr_in_dim=data['address'].x.shape[-1],
    hidden_dim=128,
    edge_types_to_use=edge_types_a1,
    dropout=0.4
).to(DEVICE)

print(f"\nModel parameters: {sum(p.numel() for p in model_a1.parameters()):,}")

results_a1 = train_ablation_model(
    model_a1, data, train_mask, val_mask, test_mask,
    max_epochs=100, patience=15, lr=0.001
)

print(f"\n{'='*70}")
print("A1 RESULTS")
print(f"{'='*70}")
print(f"Test PR-AUC: {results_a1['test']['pr_auc']:.4f}")
print(f"Test ROC-AUC: {results_a1['test']['roc_auc']:.4f}")
print(f"Test F1: {results_a1['test']['best_f1']:.4f}")
# best_epoch is the 0-based loop index (the training log prints epoch + 1).
print(f"Best epoch: {results_a1['best_epoch']}")

### A2: addr↔tx only (Bipartite)

In [None]:
print("\n" + "="*70)
print("A2: addr‚Üîtx ONLY (Bipartite)")
print("="*70)

# Both directions of the address/transaction relation; no tx->tx or addr->addr.
edge_types_a2 = [
    ('address', 'to', 'transaction'),
    ('transaction', 'to', 'address')
]

model_a2 = SimplifiedHHGTN(
    # Read the feature widths from the loaded graph instead of hard-coding them,
    # so the input projections always match the data.
    tx_in_dim=data['transaction'].x.shape[-1],
    addr_in_dim=data['address'].x.shape[-1],
    hidden_dim=128,
    edge_types_to_use=edge_types_a2,
    dropout=0.4
).to(DEVICE)

print(f"\nModel parameters: {sum(p.numel() for p in model_a2.parameters()):,}")

results_a2 = train_ablation_model(
    model_a2, data, train_mask, val_mask, test_mask,
    max_epochs=100, patience=15, lr=0.001
)

print(f"\n{'='*70}")
print("A2 RESULTS")
print(f"{'='*70}")
print(f"Test PR-AUC: {results_a2['test']['pr_auc']:.4f}")
print(f"Test ROC-AUC: {results_a2['test']['roc_auc']:.4f}")
print(f"Test F1: {results_a2['test']['best_f1']:.4f}")
# best_epoch is the 0-based loop index (the training log prints epoch + 1).
print(f"Best epoch: {results_a2['best_epoch']}")

### A3: All edge types (Full E6)

In [None]:
print("\n" + "="*70)
print("A3: ALL EDGE TYPES (Full E6)")
print("="*70)

# All four relations enabled.
edge_types_a3 = [
    ('transaction', 'to', 'transaction'),
    ('address', 'to', 'transaction'),
    ('transaction', 'to', 'address'),
    ('address', 'to', 'address')
]

model_a3 = SimplifiedHHGTN(
    # Read the feature widths from the loaded graph instead of hard-coding them,
    # so the input projections always match the data.
    tx_in_dim=data['transaction'].x.shape[-1],
    addr_in_dim=data['address'].x.shape[-1],
    hidden_dim=128,
    edge_types_to_use=edge_types_a3,
    dropout=0.4
).to(DEVICE)

print(f"\nModel parameters: {sum(p.numel() for p in model_a3.parameters()):,}")

results_a3 = train_ablation_model(
    model_a3, data, train_mask, val_mask, test_mask,
    max_epochs=100, patience=15, lr=0.001
)

print(f"\n{'='*70}")
print("A3 RESULTS")
print(f"{'='*70}")
print(f"Test PR-AUC: {results_a3['test']['pr_auc']:.4f}")
print(f"Test ROC-AUC: {results_a3['test']['roc_auc']:.4f}")
print(f"Test F1: {results_a3['test']['best_f1']:.4f}")
# best_epoch is the 0-based loop index (the training log prints epoch + 1).
print(f"Best epoch: {results_a3['best_epoch']}")

## 📊 Summary Table

In [None]:
print("\n" + "="*70)
print("E7 ABLATION SUMMARY")
print("="*70)

# One row per experiment: (name, edge types, test PR-AUC, test ROC-AUC, test F1).
# The E3 / E6 rows are hard-coded reference numbers; A1-A3 come from the runs above.
_summary_rows = [
    ('E3 (Baseline)', 'tx‚Üítx', 0.5582, 0.8055, 0.5860),
    ('E6 (HHGTN)', 'all 4', 0.2806, 0.8250, 0.4927),
]
_summary_rows += [
    (name, edges, res['test']['pr_auc'], res['test']['roc_auc'], res['test']['best_f1'])
    for name, edges, res in (
        ('A1 (tx‚Üítx only)', 'tx‚Üítx', results_a1),
        ('A2 (addr‚Üîtx)', 'addr‚Üîtx', results_a2),
        ('A3 (All edges)', 'all 4', results_a3),
    )
]

# Transpose the rows into the column-oriented dict the DataFrame is built from.
_summary_columns = ['Experiment', 'Edge Types', 'Test PR-AUC', 'Test ROC-AUC', 'Test F1']
summary_data = {
    column: list(values)
    for column, values in zip(_summary_columns, zip(*_summary_rows))
}

df_summary = pd.DataFrame(summary_data)
# Difference to the E3 baseline PR-AUC (0.5582).
df_summary['Œî PR-AUC from E3'] = df_summary['Test PR-AUC'] - 0.5582

print("\n", df_summary.to_string(index=False))

# Persist the table next to the notebook.
df_summary.to_csv('e7_ablation_results.csv', index=False)
print("\nSaved: e7_ablation_results.csv")

## 📈 Visualization

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_bar, ax_delta = axes

# --- Left panel: absolute test PR-AUC per experiment ---
# Green = E3 baseline, red = E6, blue = the three ablation arms.
bar_colors = ['green', 'red'] + ['blue'] * 3
ax_bar.bar(df_summary['Experiment'], df_summary['Test PR-AUC'], color=bar_colors, alpha=0.7)
ax_bar.axhline(0.5582, color='green', linestyle='--', label='E3 Baseline', linewidth=2)
ax_bar.set_title('E7 Ablation: Test PR-AUC Comparison', fontsize=14, fontweight='bold')
ax_bar.set_ylabel('Test PR-AUC', fontsize=12, fontweight='bold')
ax_bar.set_ylim(0, 0.7)
ax_bar.grid(axis='y', alpha=0.3)
ax_bar.legend()
plt.setp(ax_bar.xaxis.get_majorticklabels(), rotation=45, ha='right')

# --- Right panel: PR-AUC difference to the E3 baseline ---
pr_auc_delta = df_summary['Œî PR-AUC from E3']
# Non-negative deltas are drawn green, negative ones red.
delta_bar_colors = pr_auc_delta.ge(0).map({True: 'green', False: 'red'}).tolist()
ax_delta.barh(df_summary['Experiment'], pr_auc_delta, color=delta_bar_colors, alpha=0.7)
ax_delta.axvline(0, color='black', linestyle='-', linewidth=1)
ax_delta.set_title('Performance Delta (Relative to E3)', fontsize=14, fontweight='bold')
ax_delta.set_xlabel('Œî PR-AUC from E3 Baseline', fontsize=12, fontweight='bold')
ax_delta.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.savefig('e7_ablation_comparison.png', dpi=300, bbox_inches='tight')
print("\nSaved: e7_ablation_comparison.png")
plt.show()

## 🔍 Key Findings

**Questions Answered:**

1. **Does heterogeneous architecture hurt?**
   - Compare A1 (tx→tx only in hetero framework) vs E3 (tx→tx in homogeneous)
   - If A1 < E3: Yes, heterogeneous framework adds overhead

2. **Which edge types hurt most?**
   - A1 (tx→tx): Baseline within hetero framework
   - A2 (addr↔tx): Tests if bipartite structure helps
   - A3 (all): Full E6 configuration

3. **Are address features harmful?**
   - A2 directly uses address features via bipartite edges
   - If A2 << A1: Address features likely noisy/harmful

**Expected Insights:**
- If A1 ≈ E3: Hetero framework OK, problem is address features
- If A1 < E3: Hetero framework itself adds complexity
- If A2 << A1: Bipartite edges (address features) are the culprit
- If A3 ≈ A2: Adding tx→tx and addr→addr on top of the bipartite edges doesn't help/hurt much