# Baseline GAT Training

Train a standard Graph Attention Network (GAT) for Bitcoin fraud detection.

**Architecture:**
- Input: 182 node features
- 2 GAT layers with 64 hidden channels and 4 attention heads
- Output: 2 classes (fraud/non-fraud)

**Training:**
- Weighted CrossEntropyLoss for class imbalance
- Adam optimizer with cosine annealing scheduler
- Early stopping on validation F1 with patience=50

## Setup

In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report
)
import json
import time
from torch_geometric.nn import GATConv

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

from src.models import GAT
from src.utils import set_random_seeds

# Set random seeds
set_random_seeds(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

### What This Cell Does (Setup)
This cell imports all necessary libraries and initializes the environment:

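1. **Set environment variable**: `KMP_DUPLICATE_LIB_OK=TRUE` suppresses the "OMP: Error #15" abort raised when more than one copy of the Intel OpenMP runtime is loaded (common in some conda/macOS setups). It is a workaround rather than a fix; the error message itself warns that it may cause crashes or silently incorrect results.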

2. **Import core libraries**:
   - `torch`, `torch.nn`, `torch.nn.functional`: Deep learning
   - `sklearn.metrics`: F1, accuracy, precision, recall, ROC-AUC calculation
   - `torch_geometric`: Graph neural network layers (GATConv)

3. **Import project code**:
   - `src.models.GAT`: Custom GAT model class
   - `src.utils.set_random_seeds`: Ensure reproducible results

4. **Set seeds and device**:
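   - Random seed = 42 (runs are reproducible on CPU; on GPU, the scatter operations used in message passing can be nondeterministic, so small run-to-run differences are possible)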
   - Select GPU if available, otherwise CPU

## Load Data

In [None]:
print("Loading graph...")
graph = torch.load('../artifacts/elliptic_graph.pt', weights_only=False).to(device)

labeled_mask = (graph.y != -1)
labeled_indices = torch.where(labeled_mask)[0].cpu().numpy()
labeled_y = graph.y[labeled_mask].cpu().numpy()

print(f"Graph: {graph.num_nodes:,} nodes, {graph.num_edges:,} edges, {graph.num_node_features} features")
print(f"Labeled nodes: {len(labeled_indices):,}")
print(f"Fraud (class 1): {(labeled_y == 1).sum():,}")
print(f"Non-fraud (class 0): {(labeled_y == 0).sum():,}")

### What This Cell Does (Load Data)
This cell loads the **pre-constructed graph** created by `create_graph.ipynb`:

1. **Load saved graph**:
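   - Loads `elliptic_graph.pt` (200MB file with all nodes, edges, features)
   - `weights_only=False` lets `torch.load` unpickle the full PyG `Data` object rather than only tensors; because unpickling can execute arbitrary code, only load trusted files this way
   - Moves the graph to the selected device (GPU/CPU)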

2. **Extract labeled nodes**:
   - Find nodes with labels (not -1/unknown)
   - Separate indices and labels for splitting

3. **Print dataset statistics**:
   - Total nodes: 203k
   - Total edges: 230k
   - Features per node: 182
   - Labeled nodes and class distribution
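The labeled-node extraction is just a filter on the `-1` placeholder. A pure-Python sketch of the same logic, using a made-up label list (not Elliptic data):

```python
from collections import Counter

# Toy labels: -1 = unknown, 0 = non-fraud, 1 = fraud (illustrative values only)
y = [-1, 0, 1, -1, 0, 0, -1, 1, 0, -1]

# Equivalent of torch.where(graph.y != -1)[0]
labeled_indices = [i for i, label in enumerate(y) if label != -1]
labeled_y = [y[i] for i in labeled_indices]

# Class distribution among labeled nodes only
counts = Counter(labeled_y)
print(f"Labeled nodes: {len(labeled_indices)}")
print(f"Fraud (class 1): {counts[1]}, Non-fraud (class 0): {counts[0]}")
```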

## Data Preprocessing & Splitting

In [None]:
# Stratified split: 70% train/val, 30% test
train_val_idx, test_idx, train_val_y, test_y = train_test_split(
    labeled_indices, labeled_y,
    test_size=0.30,
    random_state=42,
    stratify=labeled_y
)

# Split train/val into 70% train, 30% val
train_idx, val_idx, _, _ = train_test_split(
    train_val_idx, train_val_y,
    test_size=0.30,
    random_state=42,
    stratify=train_val_y
)

# Create mask tensors
train_mask = torch.zeros(graph.num_nodes, dtype=torch.bool, device=device)
val_mask = torch.zeros(graph.num_nodes, dtype=torch.bool, device=device)
test_mask = torch.zeros(graph.num_nodes, dtype=torch.bool, device=device)

train_mask[train_idx] = True
val_mask[val_idx] = True
test_mask[test_idx] = True

print(f"Train: {train_mask.sum():,}, Val: {val_mask.sum():,}, Test: {test_mask.sum():,}")

### What This Cell Does (Data Splitting)
This cell creates **train/validation/test splits** for evaluation:

1. **Two-step stratified split**:
   - First: Hold out 30% of labeled nodes for test; keep 70% for train/val
   - Second: From train/val, hold out 30% for val; keep 70% for train
   - Result: ~49% train (0.7 × 0.7), 21% val (0.7 × 0.3), 30% test, as fractions of labeled nodes

2. **Stratified splitting**:
   - Maintains class ratio (fraud/non-fraud) in each split
   - Other experiments that repeat this procedure with the same labels and `random_state` get identical node splits, so their results are directly comparable
   - Reproducible with random_state=42

3. **Create mask tensors**:
   - Boolean tensors marking which nodes belong to train/val/test
   - Used to select subsets during training and evaluation
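The two-step stratified split can be sketched in pure Python to check the resulting proportions. This is an illustration, not a drop-in replacement for `train_test_split`: the helper `stratified_split` is hypothetical, the toy dataset (1,000 nodes, 10% fraud) is made up, and sklearn shuffles differently, so exact node membership will differ. The proportions are the point:

```python
import random
from collections import defaultdict

def stratified_split(indices, labels, test_size, seed):
    """Hypothetical helper: hold out `test_size` of each class separately."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in zip(indices, labels):
        by_class[label].append(idx)
    keep, held_out = [], []
    for label in sorted(by_class):          # deterministic class order
        members = by_class[label][:]
        rng.shuffle(members)
        n_held = round(len(members) * test_size)
        held_out.extend(members[:n_held])
        keep.extend(members[n_held:])
    return keep, held_out

# Toy data: 900 non-fraud, 100 fraud (illustrative only)
indices = list(range(1000))
labels = [1 if i < 100 else 0 for i in indices]
label_of = dict(zip(indices, labels))

# Step 1: 30% test; step 2: 30% of the remainder for validation
train_val, test = stratified_split(indices, labels, test_size=0.30, seed=42)
train, val = stratified_split(train_val, [label_of[i] for i in train_val],
                              test_size=0.30, seed=42)

for name, split in [("train", train), ("val", val), ("test", test)]:
    frac = len(split) / len(indices)
    fraud_rate = sum(label_of[i] for i in split) / len(split)
    print(f"{name}: {frac:.0%} of labeled nodes, fraud rate {fraud_rate:.1%}")
```

Each split keeps the 10% fraud rate of the full toy set, which is exactly what `stratify=` guarantees in the notebook.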

## Feature Normalization

In [None]:
# Handle NaN values
nan_count = torch.isnan(graph.x).sum().item()
if nan_count > 0:
    graph.x = torch.nan_to_num(graph.x, nan=0.0)
    print(f"Replaced {nan_count} NaN values")

# Normalize based on training set
train_x = graph.x[train_mask]
mean = train_x.mean(dim=0, keepdim=True)
std = train_x.std(dim=0, keepdim=True)
std = torch.where(std == 0, torch.ones_like(std), std)

graph.x = (graph.x - mean) / std
graph.x = torch.clamp(graph.x, min=-10, max=10)

print("Features normalized")

### What This Cell Does (Normalize Features)
This cell applies **z-score normalization** to node features:

1. **Handle missing values**:
   - Check for NaN (missing) values
   - Replace with 0.0 if found

2. **Compute normalization statistics**:
   - Calculate mean and std from TRAINING data only
   - This prevents data leakage (val/test shouldn't affect normalization)

3. **Apply normalization**:
   - Transform: x_norm = (x - mean) / std
   - Clamp to [-10, +10] to handle outliers
   - Training-set features end up with mean ≈ 0 and std ≈ 1 (before clamping); val/test features are only approximately standardized

4. **Why normalize?**:
   - Neural networks train better with normalized inputs
   - Prevents large-magnitude features from dominating
   - Improves gradient stability during backprop
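A minimal pure-Python sketch of the same normalization, assuming each node's features are a list of floats. The helper names `fit_zscore` and `apply_zscore` are hypothetical; the sketch mirrors the notebook's choices: statistics from training rows only, the unbiased (n - 1) standard deviation that `torch.std` uses by default, a zero-std guard, and clamping to [-10, 10]:

```python
import statistics

def fit_zscore(train_rows):
    """Per-feature mean/std computed from training rows only."""
    columns = list(zip(*train_rows))
    means = [statistics.fmean(col) for col in columns]
    # statistics.stdev uses the (n - 1) denominator, like torch.std's default
    stds = [statistics.stdev(col) for col in columns]
    # Constant features would divide by zero; use 1.0 instead
    stds = [s if s != 0 else 1.0 for s in stds]
    return means, stds

def apply_zscore(rows, means, stds, clip=10.0):
    """Standardize every row with the training statistics, then clamp."""
    return [
        [max(-clip, min(clip, (x - m) / s)) for x, m, s in zip(row, means, stds)]
        for row in rows
    ]

train_rows = [[1.0, 5.0], [3.0, 5.0]]    # second feature is constant in training
test_rows = [[2.0, 5.0], [100.0, 7.0]]   # 100.0 is an outlier

means, stds = fit_zscore(train_rows)
train_norm = apply_zscore(train_rows, means, stds)
test_norm = apply_zscore(test_rows, means, stds)
print(train_norm, test_norm)
```

The constant training feature maps to 0 instead of producing a division by zero, and the outlier is clamped to 10 rather than dominating the input scale.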

## Compute Class Weights

In [None]:
n_class_0 = (graph.y[train_mask] == 0).sum().item()
n_class_1 = (graph.y[train_mask] == 1).sum().item()

class_weight = torch.tensor(
    [1.0 / n_class_0, 1.0 / n_class_1],
    device=device,
    dtype=torch.float32
)
class_weight = class_weight / class_weight.sum()

print(f"Class weights: {class_weight}")

### What This Cell Does (Compute Class Weights)
This cell computes **weights to handle class imbalance**:

1. **Count class frequencies**:
   - How many non-fraud nodes (class 0) in training
   - How many fraud nodes (class 1) in training
   - Non-fraud nodes far outnumber fraud nodes

2. **Calculate inverse frequency weights**:
   - weight = 1 / count (rarer classes get higher weight)
   - Rare fraud class gets more importance

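3. **Normalize weights**:
   - Rescale so the two weights sum to 1.0
   - Pass them to CrossEntropyLoss via `weight=`
   - The overall scale does not change the loss: with the default `reduction='mean'`, PyTorch divides the weighted sum of per-sample losses by the sum of the targets' weights, so only the ratio between class weights matters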

4. **Why weighted loss?**:
   - Without weights: the model can reach a low loss by predicting mostly the majority (non-fraud) class
   - With weights: misclassifying a fraud node costs more than misclassifying a non-fraud node
   - This pushes the model to learn fraud patterns despite their rarity
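The sketch below reproduces the weight computation and checks that rescaling the weights leaves the mean-reduced loss unchanged. `inverse_frequency_weights` and `weighted_cross_entropy` are hypothetical pure-Python helpers; the second re-implements what `nn.CrossEntropyLoss(weight=...)` computes with `reduction='mean'`. Labels and logits are toy values:

```python
import math

def inverse_frequency_weights(labels, num_classes=2):
    """1 / count per class, normalized to sum to 1 (as in the notebook)."""
    counts = [labels.count(c) for c in range(num_classes)]
    raw = [1.0 / n for n in counts]
    total = sum(raw)
    return [w / total for w in raw]

def weighted_cross_entropy(logits, targets, weights):
    """sum_i w[y_i] * nll_i / sum_i w[y_i], i.e. weighted mean reduction."""
    num, den = 0.0, 0.0
    for row, y in zip(logits, targets):
        m = max(row)  # stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        nll = log_sum_exp - row[y]
        num += weights[y] * nll
        den += weights[y]
    return num / den

labels = [0] * 90 + [1] * 10            # 10% positive class
w = inverse_frequency_weights(labels)   # roughly [0.1, 0.9]

logits = [[2.0, 0.5], [0.1, 1.2], [1.0, 1.0]]
targets = [0, 1, 1]
loss_a = weighted_cross_entropy(logits, targets, w)
loss_b = weighted_cross_entropy(logits, targets, [100 * x for x in w])
print(w, loss_a, loss_b)
```

`loss_a` and `loss_b` agree because the common scale factor cancels between numerator and denominator.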

## Define Model

In [None]:
model = GAT(
    in_channels=graph.num_node_features,
    hidden_channels=64,
    out_channels=2,
    num_heads=4,
    num_layers=2,
    dropout=0.3
).to(device)

# Count trainable parameters
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {params:,}")

### What This Cell Does (Create Model)
This cell constructs the **Graph Attention Network (GAT) model**:

1. **Model architecture**:
   - Input: 182 node features
   - 2 GAT layers with 64 hidden channels and 4 attention heads (the exact per-layer layout, such as how the final layer produces the 2 outputs, is defined in `src.models.GAT`)
   - Output: 2 classes (fraud/non-fraud)
   - Dropout: 0.3 (reduces overfitting)

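2. **What GAT does**:
   - Uses an attention mechanism to weight neighbor features before aggregating them
   - Each attention head learns its own attention weights over the same neighborhood
   - Combining several heads (concatenating or averaging their outputs) stabilizes learning and lets the model capture different relational patterns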

3. **Move to device**:
   - Transfer model from CPU to GPU (or keep on CPU)

4. **Count parameters**:
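   - The cell prints the exact count, which depends on how `src.models.GAT` sizes each layer (if the 64 channels are per head and the 4 heads are concatenated, the first layer alone has about 47,000 parameters, so roughly 50,000 in total is a reasonable expectation)
   - Relatively small model (cheap to train full-batch)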
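To make the attention step concrete, here is a single-head, single-node sketch of the GAT update: score each neighbor with `e_ij = LeakyReLU(a · [W h_i || W h_j])`, softmax the scores over the neighborhood, and take the attention-weighted sum of the transformed features. Splitting `a` into a target part (`a_dst`) and a source part (`a_src`), adding a self-loop, and the 0.2 negative slope follow PyG's `GATConv` defaults; the function `gat_aggregate` and all numbers are hypothetical, and `src.models.GAT` may differ in details:

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def matvec(W, h):
    # W is a list of rows (out_dim x in_dim); returns W @ h
    return [dot(row, h) for row in W]

def leaky_relu(x, negative_slope=0.2):
    return x if x > 0 else negative_slope * x

def gat_aggregate(i, neighbors, h, W, a_src, a_dst):
    """Single-head GAT update for target node i (illustrative sketch)."""
    candidates = neighbors + [i]             # include a self-loop
    Wh = {j: matvec(W, h[j]) for j in candidates}
    dst_score = dot(a_dst, Wh[i])            # target-node half of a · [Wh_i || Wh_j]
    scores = [leaky_relu(dst_score + dot(a_src, Wh[j])) for j in candidates]
    # Numerically stable softmax over the neighborhood
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # Attention-weighted sum of transformed neighbor features
    h_new = [sum(alpha[k] * Wh[j][d] for k, j in enumerate(candidates))
             for d in range(len(W))]
    return alpha, h_new

h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
W = [[1.0, 0.0], [0.0, 1.0]]   # identity transform for readability
a_src = [1.0, -1.0]
a_dst = [0.5, 0.5]

alpha, h_new = gat_aggregate(0, [1, 2], h, W, a_src, a_dst)
print("attention over [1, 2, self]:", [round(x, 4) for x in alpha])
print("updated features:", [round(x, 4) for x in h_new])
```

With 4 heads, this computation runs 4 times with independent `W`, `a_src`, and `a_dst`; the head outputs are then concatenated or averaged, depending on the layer configuration.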

## Training Setup

In [None]:
criterion = nn.CrossEntropyLoss(weight=class_weight)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=5e-4
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

print("Training configuration:")
print(f"  Loss: Weighted CrossEntropy")
print(f"  Optimizer: Adam (lr=0.001, weight_decay=5e-4)")
print(f"  Scheduler: Cosine Annealing (T_max=300)")
print(f"  Early stopping: patience=50")

### What This Cell Does (Training Setup)
This cell configures the **training algorithm and learning schedule**:

1. **Loss function**:
   - Weighted CrossEntropyLoss with class weights
   - Weight parameter: assigns higher penalty to fraud misclassification
   - Handles class imbalance

2. **Optimizer**:
   - Adam optimizer (adaptive learning rate)
   - Learning rate: 0.001
   - Weight decay: 5e-4 (L2 regularization to prevent overfitting)

3. **Learning rate scheduler**:
   - Cosine Annealing: gradually reduces learning rate
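   - T_max=300: the learning rate decays from 0.001 to the minimum (0 by default) over 300 scheduler steps, one per epoch; if early stopping triggers first, the minimum is never reached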
   - Helps model converge better in late training

4. **Early stopping**:
   - Patience: 50 epochs
   - Stops training if validation F1 doesn't improve for 50 consecutive epochs
   - Prevents overfitting to training data
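CosineAnnealingLR follows the closed form `lr(t) = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2`, where `t` counts `scheduler.step()` calls. With `eta_min` at its default of 0 and one step per epoch, a short sketch shows the learning rate at any point in training (`cosine_annealing_lr` is a hypothetical helper, not a PyTorch function):

```python
import math

def cosine_annealing_lr(step, base_lr=0.001, t_max=300, eta_min=0.0):
    """Learning rate after `step` scheduler steps (closed-form cosine schedule)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

# scheduler.step() runs after optimizer.step(), so epoch 1 trains with step 0
for step in (0, 50, 100, 150, 200, 250, 300):
    print(f"after {step:3d} steps: lr = {cosine_annealing_lr(step):.6f}")
```

The decay is slow at the start and end and fastest around the midpoint, where the rate is half the base value.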

## Evaluation Function

In [None]:
def evaluate(model, mask):
    model.eval()
    with torch.no_grad():
        out = model(graph.x, graph.edge_index)
        pred = out[mask].argmax(dim=1)
        prob = F.softmax(out[mask], dim=1)[:, 1]
        
        y_true = graph.y[mask].cpu().numpy()
        y_pred = pred.cpu().numpy()
        y_prob = prob.cpu().numpy()
        y_prob = np.nan_to_num(y_prob, nan=0.5)
        
        # roc_auc_score raises ValueError if y_true contains only one class
        try:
            roc_auc = roc_auc_score(y_true, y_prob)
        except ValueError:
            roc_auc = 0.0
    
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0),
        'roc_auc': roc_auc,
        'y_true': y_true,
        'y_pred': y_pred,
        'y_prob': y_prob
    }

print("Evaluation function defined")

### What This Cell Does (Evaluation Function)
This cell defines a function to **evaluate model performance** on any subset:

1. **Set model to evaluation mode**:
   - `model.eval()`: Disables dropout (and switches any batch normalization layers to their running statistics)
   - Ensures predictions do not depend on dropout randomness

2. **Make predictions**:
   - Forward pass through model (no gradients)
   - Get class probabilities using softmax
   - Extract fraud probability (class 1)

3. **Compute metrics**:
   - Accuracy: % correct predictions
   - Precision: % of predicted fraud that's actually fraud
   - Recall: % of actual fraud detected
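   - F1: Harmonic mean of precision and recall for the fraud class (`pos_label=1`; primary metric)
   - ROC-AUC: Area under the receiver operating characteristic curve, computed from fraud probabilities (reported as 0.0 if the subset contains only one class)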

4. **Return results**:
   - Metrics dictionary
   - True labels, predictions, and probabilities
   - Used for train/val/test evaluation
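For reference, these metrics reduce to simple counts. This pure-Python sketch (hypothetical helpers `binary_metrics` and `roc_auc`, toy labels) mirrors sklearn's binary defaults with fraud as the positive class and `zero_division=0`; ROC-AUC is computed in its rank form, as the probability that a random fraud node receives a higher fraud probability than a random non-fraud node:

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall, F1 with class 1 as the positive class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = len(y_true) - tp - fp - fn
    precision = tp / (tp + fp) if tp + fp else 0.0   # zero_division=0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / len(y_true)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

def roc_auc(y_true, y_prob):
    """Fraction of (fraud, non-fraud) pairs ranked correctly; ties count 1/2."""
    pos = [p for t, p in zip(y_true, y_prob) if t == 1]
    neg = [p for t, p in zip(y_true, y_prob) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC is undefined when only one class is present")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

y_true = [0, 0, 1, 1, 0, 1]
y_pred = [0, 1, 1, 0, 0, 1]
y_prob = [0.1, 0.6, 0.8, 0.4, 0.2, 0.9]

m = binary_metrics(y_true, y_pred)
auc = roc_auc(y_true, y_prob)
print(m, auc)
```

The `ValueError` branch is the same situation the notebook's `evaluate` catches when a subset contains only one class.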

## Training Loop

In [None]:
print("="*70)
print("TRAINING BASELINE GAT")
print("="*70 + "\n")

best_val_f1 = -1
patience = 0
max_patience = 50
history = {
    'train_loss': [],
    'val_f1': [],
    'val_acc': [],
    'val_gap': []
}

start_time = time.time()

for epoch in range(1, 301):
    model.train()
    optimizer.zero_grad()
    
    out = model(graph.x, graph.edge_index)
    loss = criterion(out[train_mask], graph.y[train_mask])
    
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    
    # Validation
    val_metrics = evaluate(model, val_mask)
    train_metrics = evaluate(model, train_mask)
    
    val_gap = train_metrics['f1'] - val_metrics['f1']
    
    history['train_loss'].append(loss.item())
    history['val_f1'].append(val_metrics['f1'])
    history['val_acc'].append(val_metrics['accuracy'])
    history['val_gap'].append(val_gap)
    
    if val_metrics['f1'] > best_val_f1:
        best_val_f1 = val_metrics['f1']
        patience = 0
        torch.save(model.state_dict(), '../artifacts/baseline_gat_best.pt')
    else:
        patience += 1
    
    if epoch % 20 == 0 or epoch < 10:
        print(f"Epoch {epoch:3d} | Loss: {loss:.4f} | "
              f"Train F1: {train_metrics['f1']:.4f} | Val F1: {val_metrics['f1']:.4f} | "
              f"Gap: {val_gap:.4f} | Patience: {patience}/{max_patience}")
    
    if patience >= max_patience:
        print(f"\nEarly stopping at epoch {epoch}")
        break

training_time = time.time() - start_time
print(f"\n✓ Training completed in {training_time:.2f}s")
print(f"✓ Best Val F1: {best_val_f1:.4f}")

### What This Cell Does (Training Loop)
This cell **trains the GAT model** for up to 300 epochs:

1. **Initialize state**:
   - Track best validation F1
   - Patience counter (for early stopping)
   - History lists for loss and metrics

2. **Each epoch**:
   - **Forward pass**: Run the model on the full graph (all nodes and edges)
   - **Compute loss**: Weighted CrossEntropyLoss on training nodes only
   - **Backward pass**: Compute gradients
   - **Gradient clipping**: Rescale gradients so their global norm is at most 1.0 (prevents exploding gradients)
   - **Update**: Optimizer step
   - **Schedule**: Advance the cosine learning-rate schedule by one step

3. **Validation**:
   - Evaluate on training and validation sets
   - Track generalization gap (train F1 - val F1)
   - If val F1 improves: save model, reset patience
   - If no improvement: increment patience

4. **Early stopping**:
   - If validation F1 fails to improve for 50 consecutive epochs, stop training
   - Prevents overfitting to training data
   - Saves training time

5. **Monitoring**:
   - Print progress for each of the first 9 epochs, then every 20 epochs
   - Show loss, F1 scores, generalization gap
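The early-stopping bookkeeping in the loop can be isolated into a small tracker. This is a sketch with a hypothetical `EarlyStopping` class and a toy score sequence; like the notebook, it uses a strict `>` comparison, so a tie with the best score does not reset patience:

```python
class EarlyStopping:
    """Tracks the best validation score and signals when patience runs out."""

    def __init__(self, max_patience=50):
        self.max_patience = max_patience
        self.best = float("-inf")
        self.best_epoch = None
        self.patience = 0

    def step(self, epoch, score):
        """Return (improved, should_stop) for this epoch's validation score."""
        if score > self.best:
            self.best, self.best_epoch, self.patience = score, epoch, 0
            return True, False   # the caller would save a checkpoint here
        self.patience += 1
        return False, self.patience >= self.max_patience

# Toy validation F1 sequence with a small patience for illustration
scores = [0.5, 0.6, 0.6, 0.55, 0.58]
stopper = EarlyStopping(max_patience=3)
stopped_at = None
for epoch, score in enumerate(scores, start=1):
    improved, should_stop = stopper.step(epoch, score)
    if should_stop:
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best F1 {stopper.best} at epoch {stopper.best_epoch}")
```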

## Final Evaluation

In [None]:
print("\n" + "="*70)
print("FINAL EVALUATION")
print("="*70 + "\n")

model.load_state_dict(torch.load('../artifacts/baseline_gat_best.pt', map_location=device))

baseline_train = evaluate(model, train_mask)
baseline_val = evaluate(model, val_mask)
baseline_test = evaluate(model, test_mask)

print("Baseline GAT Results:")
print(f"  Train - F1: {baseline_train['f1']:.4f}, Acc: {baseline_train['accuracy']:.4f}")
print(f"  Val   - F1: {baseline_val['f1']:.4f}, Acc: {baseline_val['accuracy']:.4f}")
print(f"  Test  - F1: {baseline_test['f1']:.4f}, Acc: {baseline_test['accuracy']:.4f}")

print(f"\nGeneralization Gaps:")
print(f"  Train→Val: {baseline_train['f1'] - baseline_val['f1']:.4f}")
print(f"  Val→Test:  {baseline_val['f1'] - baseline_test['f1']:.4f}")

print(f"\nDetailed Test Report:")
print(classification_report(baseline_test['y_true'], baseline_test['y_pred'],
                          target_names=['Non-Fraud', 'Fraud']))

### What This Cell Does (Final Evaluation)
This cell **evaluates the best model** on all three data splits:

1. **Load best model**:
   - Load state_dict from checkpoint file
   - Restore the best weights from training

2. **Evaluate on all splits**:
   - Train: Performance on training nodes
   - Val: Performance on validation nodes
   - Test: Performance on held-out test nodes (final metric)

3. **Print results**:
   - F1 and accuracy for each split
   - Generalization gaps:
     - Train→Val gap: Training vs validation performance
     - Val→Test gap: Validation vs test performance
   - Detailed classification report: Precision, recall per class

4. **Expected results**:
   - Baseline GAT is expected to reach a test F1 of roughly 0.87; exact values vary with seed and hardware
   - A small train→val gap suggests the model is not heavily overfitting
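The single test F1 in the summary corresponds to the Fraud row of the classification report, not to its macro average. A hypothetical pure-Python `per_class_report` on toy predictions shows how the per-class rows and the macro average are formed:

```python
def per_class_report(y_true, y_pred, names=("Non-Fraud", "Fraud")):
    """Per-class precision/recall/F1/support plus an unweighted macro average."""
    rows = {}
    for c, name in enumerate(names):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        predicted = sum(1 for p in y_pred if p == c)
        actual = sum(1 for t in y_true if t == c)
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        rows[name] = {"precision": precision, "recall": recall, "f1": f1, "support": actual}
    rows["macro avg"] = {
        k: sum(rows[n][k] for n in names) / len(names)
        for k in ("precision", "recall", "f1")
    }
    return rows

y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 0]
report = per_class_report(y_true, y_pred)
for name, row in report.items():
    print(name, row)
```

Because non-fraud nodes dominate, the Non-Fraud row usually scores much higher than the Fraud row; reporting the fraud-class F1 keeps the headline metric focused on the rare class.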

## Save Results

In [None]:
report = {
    'model': 'Baseline GAT',
    'description': 'Standard Graph Attention Network with 2 layers, 64 hidden channels, 4 heads',
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'training_time': training_time,
    'architecture': {
        'input_features': graph.num_node_features,
        'hidden_channels': 64,
        'num_layers': 2,
        'num_heads': 4,
        'dropout': 0.3,
        'optimizer': 'Adam (lr=0.001, weight_decay=5e-4)',
        'scheduler': 'CosineAnnealingLR (T_max=300)',
        'loss': 'Weighted CrossEntropy'
    },
    'test_metrics': {
        'f1': baseline_test['f1'],
        'accuracy': baseline_test['accuracy'],
        'precision': baseline_test['precision'],
        'recall': baseline_test['recall'],
        'roc_auc': baseline_test['roc_auc']
    },
    'generalization': {
        'train_f1': baseline_train['f1'],
        'val_f1': baseline_val['f1'],
        'train_to_val_gap': baseline_train['f1'] - baseline_val['f1'],
        'val_to_test_gap': baseline_val['f1'] - baseline_test['f1']
    }
}

with open('../artifacts/baseline_gat_metrics.json', 'w') as f:
    json.dump(report, f, indent=2)

print(f"\n✓ Report saved to ../artifacts/baseline_gat_metrics.json")
print(f"✓ Model saved to ../artifacts/baseline_gat_best.pt")
print("\n" + "="*70)
print("✅ BASELINE GAT TRAINING COMPLETE!")
print("="*70)

### What This Cell Does (Save Results)
This cell saves the **training results and model weights** to disk:

1. **Create report dictionary**:
   - Model name: "Baseline GAT"
   - Model architecture details (layers, hidden dims, heads, dropout)
   - Training configuration (optimizer, loss, scheduler)
   - Test metrics (F1, accuracy, precision, recall, ROC-AUC)
   - Generalization metrics (train/val F1, train→val and val→test gaps)
   - Timestamp and training time

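2. **Report the model checkpoint**:
   - The best weights were already saved to `artifacts/baseline_gat_best.pt` by the training loop whenever validation F1 improved; this cell only prints the path
   - The checkpoint is a `state_dict` that can be reloaded later with `model.load_state_dict(...)`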

3. **Save metrics**:
   - Save report to `artifacts/baseline_gat_metrics.json`
   - Can be compared with other models (quantum, etc.)
   - Readable format for reports/documentation

4. **Completion message**:
   - Confirms training finished successfully
   - Ready for next phase (quantum training)