# Train Baseline GAT Model

This notebook trains a Graph Attention Network (GAT) for fraud detection on the transaction graph.

## Overview
1. Import libraries and set up device and random seeds
2. Load graph and create train/validation/test splits
3. Initialize GAT model
4. Define training and evaluation functions
5. Train with early stopping
6. Load best model and evaluate
7. Save metrics to JSON

**Estimated time:** 10-20 minutes

In [1]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import json
import time
from pathlib import Path
import os

# Always add project root to sys.path for src imports (Jupyter-safe)
project_root = str(Path().resolve().parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.models import GAT
from src.config import MODEL_CONFIG, TRAINING_CONFIG, ARTIFACTS_DIR, ARTIFACT_FILES
from src.utils import set_random_seeds, get_device

# Set random seeds
set_random_seeds(TRAINING_CONFIG['random_seed'])
device = get_device()

Device: cpu


## 1. Import Libraries & Setup

Import required packages and initialize device and random seeds.

In [2]:
graph_path = ARTIFACTS_DIR / ARTIFACT_FILES['baseline_graph']
data = torch.load(graph_path, weights_only=False).to(device)
print(data)

# Get labeled indices
labeled_indices = torch.where(data.labeled_mask)[0].cpu().numpy()
labeled_y = data.y[data.labeled_mask].cpu().numpy()

# Split using config values
train_val_idx, test_idx = train_test_split(
    labeled_indices, 
    test_size=TRAINING_CONFIG['train_test_split'], 
    random_state=TRAINING_CONFIG['random_seed'], 
    stratify=labeled_y
)

train_val_y = data.y[train_val_idx].cpu().numpy()
train_idx, val_idx = train_test_split(
    train_val_idx, 
    test_size=TRAINING_CONFIG['train_val_split'], 
    random_state=TRAINING_CONFIG['random_seed'], 
    stratify=train_val_y
)

# Create masks
train_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
val_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
test_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)

train_mask[train_idx] = True
val_mask[val_idx] = True
test_mask[test_idx] = True

data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

print(f"Train: {train_mask.sum()}, Val: {val_mask.sum()}, Test: {test_mask.sum()}")

Data(x=[203769, 182], edge_index=[2, 672479], y=[203769], timestep=[203769], labeled_mask=[203769], unlabeled_mask=[203769])
Train: 27938, Val: 9313, Test: 9313


## 2. Load Graph & Create Data Splits

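Load the saved graph and create stratified train/validation/test splits.
Only labeled nodes are split and contribute to the loss and metrics; unlabeled nodes stay in the graph and still take part in message passing.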
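The split is done in two stages (first test, then validation out of the remainder), so the second fraction applies to what is left after the first. The sketch below shows how the two fractions compose. The values 0.2 and 0.25 are assumptions inferred from the printed split sizes, not values read from `TRAINING_CONFIG`.

```python
# Sketch: how two chained splits compose into overall fractions.
# test_frac and val_frac_of_rest are assumed values, inferred from the
# printed sizes (Train: 27938, Val: 9313, Test: 9313).

def composed_split_fractions(test_frac, val_frac_of_rest):
    """Return (train, val, test) as fractions of the full labeled set."""
    test = test_frac
    val = (1.0 - test_frac) * val_frac_of_rest  # validation is carved from the remainder
    train = 1.0 - test - val
    return train, val, test

n_labeled = 27938 + 9313 + 9313  # total labeled nodes, from the printed sizes
train, val, test = composed_split_fractions(test_frac=0.2, val_frac_of_rest=0.25)

print(f"train={train:.2f}, val={val:.2f}, test={test:.2f}")
print(f"approx. counts: {train * n_labeled:.0f} / {val * n_labeled:.0f} / {test * n_labeled:.0f}")
```

Under these assumed fractions the result is a 60/20/20 split, which is consistent with the sizes printed above.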

In [3]:
model = GAT(
    in_channels=data.num_node_features,
    hidden_channels=MODEL_CONFIG['hidden_channels'],
    out_channels=MODEL_CONFIG['out_channels'],
    num_heads=MODEL_CONFIG['num_heads'],
    num_layers=MODEL_CONFIG['num_layers'],
    dropout=MODEL_CONFIG['dropout']
).to(device)

optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=TRAINING_CONFIG['learning_rate'], 
    weight_decay=TRAINING_CONFIG['weight_decay']
)
criterion = nn.CrossEntropyLoss()

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

Parameters: 47,878


In [4]:
# Check for NaN values in features and handle them
nan_mask = torch.isnan(data.x)
if nan_mask.any():
    print(f"Found {nan_mask.sum().item()} NaN values in features")
    # Replace NaN with 0
    data.x = torch.where(nan_mask, torch.zeros_like(data.x), data.x)
    print("NaN values replaced with 0")
else:
    print("No NaN values found in features")

# Check for inf values
inf_mask = torch.isinf(data.x)
if inf_mask.any():
    print(f"Found {inf_mask.sum().item()} inf values in features")
    # Replace inf with 0
    data.x = torch.where(inf_mask, torch.zeros_like(data.x), data.x)
    print("Inf values replaced with 0")

print(f"Feature stats - Min: {data.x.min():.4f}, Max: {data.x.max():.4f}, Mean: {data.x.mean():.4f}")

Found 16405 NaN values in features
NaN values replaced with 0
Feature stats - Min: -13.0934, Max: 445268.0000, Mean: 4.6095


In [5]:
# Normalize features using standardization (mean=0, std=1) on training set
train_x = data.x[data.train_mask]
mean = train_x.mean(dim=0, keepdim=True)
std = train_x.std(dim=0, keepdim=True)
std = torch.where(std == 0, torch.ones_like(std), std)  # Avoid division by zero

data.x = (data.x - mean) / std
print(f"Features normalized - Min: {data.x.min():.4f}, Max: {data.x.max():.4f}, Mean: {data.x.mean():.4f}")

Features normalized - Min: -12.8573, Max: 10486.3457, Mean: -0.0565


In [6]:
# Clip extreme values to prevent gradient issues
data.x = torch.clamp(data.x, min=-10, max=10)
print(f"Features clipped - Min: {data.x.min():.4f}, Max: {data.x.max():.4f}, Mean: {data.x.mean():.4f}")

Features clipped - Min: -10.0000, Max: 10.0000, Mean: -0.0682
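The three preprocessing cells above fit the mean and standard deviation on training nodes only, apply them to every node, and then clip. A minimal single-column sketch in plain Python illustrates why a value far outside the training range can still be huge after standardization, and why clipping follows. The helper names and toy values are hypothetical.

```python
import statistics

def fit_standardizer(train_values):
    """Estimate mean and std from training values only."""
    mean = statistics.fmean(train_values)
    std = statistics.stdev(train_values)  # sample std, like torch.std's default
    if std == 0:
        std = 1.0  # constant column: avoid division by zero
    return mean, std

def transform(values, mean, std, clip=10.0):
    """Standardize with the fitted statistics, then clip to [-clip, clip]."""
    return [max(-clip, min(clip, (v - mean) / std)) for v in values]

train_col = [1.0, 2.0, 3.0, 4.0]   # toy training values (hypothetical)
all_col = train_col + [1000.0]     # one outlier outside the training set

mean, std = fit_standardizer(train_col)
scaled = transform(all_col, mean, std)
print(scaled)
```

The outlier is several hundred standard deviations from the training mean before clipping, so it ends up at the upper bound of 10.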


## 3. Initialize GAT Model

Create the Graph Attention Network with configuration from `src/config.py`.
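As a sanity check on model size, the reported count of 47,878 parameters can be reproduced with simple arithmetic. The sketch below assumes that `src.models.GAT` (not shown here) stacks two layers with the usual GATConv layout: a bias-free linear projection, two attention vectors per head, and an output bias. It also assumes a concatenated hidden width of 256, written here as 64 channels times 4 heads. That factorization is a guess; 32 times 8 gives the same total.

```python
def gat_layer_params(in_dim, out_per_head, heads, concat=True):
    """Parameter count of one attention layer under the assumed GATConv layout."""
    weight = in_dim * out_per_head * heads      # linear projection, no bias
    attention = 2 * out_per_head * heads        # source and target attention vectors
    bias = out_per_head * heads if concat else out_per_head
    return weight + attention + bias

in_features = 182       # from data.x shape [203769, 182]
hidden_channels = 64    # assumption
num_heads = 4           # assumption
num_classes = 2         # licit vs. illicit logits (assumption for out_channels)

layer1 = gat_layer_params(in_features, hidden_channels, num_heads)
layer2 = gat_layer_params(hidden_channels * num_heads, num_classes, heads=1)
total = layer1 + layer2

print(f"layer1={layer1:,}  layer2={layer2:,}  total={total:,}")
# layer1=47,360  layer2=518  total=47,878
```

The match suggests, but does not prove, a two-layer model whose output layer uses a single head.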

In [7]:
def train_epoch():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(mask):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out[mask].argmax(dim=1)
    prob = F.softmax(out[mask], dim=1)[:, 1]
    
    y_true = data.y[mask].cpu().numpy()
    y_pred = pred.cpu().numpy()
    y_prob = prob.cpu().numpy()
    
    # Check for NaN values and handle them
    if np.isnan(y_prob).any():
        # Replace NaN with 0.5 (neutral probability)
        y_prob = np.nan_to_num(y_prob, nan=0.5)
    
    # Only compute ROC AUC if we have both classes
    try:
        roc_auc = roc_auc_score(y_true, y_prob)
    except (ValueError, RuntimeError):
        roc_auc = 0.0
    
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0),
        'roc_auc': roc_auc
    }

print("Training functions ready")

Training functions ready


## 4. Define Training & Evaluation Functions

Set up functions for training epochs and computing evaluation metrics.
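For reference, the thresholded metrics reported by `evaluate` all derive from the confusion counts. The plain-Python sketch below mirrors the `zero_division=0` behaviour used above. The function name and toy labels are hypothetical. It also shows how accuracy can look high on an imbalanced set while F1 stays low.

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 for labels in {0, 1}."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    accuracy = (tp + tn) / len(pairs)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Toy imbalanced case: 10 positives out of 100, only 2 of them detected.
y_true = [0] * 90 + [1] * 10
y_pred = [0] * 90 + [1] * 2 + [0] * 8

metrics = binary_metrics(y_true, y_pred)
print({k: round(v, 3) for k, v in metrics.items()})
```

ROC AUC is computed from the ranking of predicted probabilities rather than from thresholded predictions, which is why it can be high while F1 is still low early in training.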

In [None]:
history = {'train_loss': [], 'val_metrics': []}
best_val_f1 = -1.0  # below any real F1, so the first epoch always writes a checkpoint
patience_counter = 0
EPOCHS = TRAINING_CONFIG['epochs']
PATIENCE = TRAINING_CONFIG['patience']

print("Training...")
start_time = time.time()

for epoch in range(1, EPOCHS + 1):
    loss = train_epoch()
    val_metrics = evaluate(data.val_mask)
    
    history['train_loss'].append(loss)
    history['val_metrics'].append(val_metrics)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch:3d} | Loss: {loss:.4f} | Val F1: {val_metrics['f1']:.4f} | "
              f"Val AUC: {val_metrics['roc_auc']:.4f}")
    
    if val_metrics['f1'] > best_val_f1:
        best_val_f1 = val_metrics['f1']
        patience_counter = 0
        save_path = ARTIFACTS_DIR / ARTIFACT_FILES['baseline_model']
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_f1': best_val_f1,
            'val_metrics': val_metrics
        }, save_path)
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"Early stopping at epoch {epoch}")
            break

training_time = time.time() - start_time
print(f"Training done in {training_time:.2f}s. Best Val F1: {best_val_f1:.4f}")

Training...
Epoch  10 | Loss: 0.3320 | Val F1: 0.1689 | Val AUC: 0.8313
Epoch  20 | Loss: 0.2668 | Val F1: 0.4790 | Val AUC: 0.9014
Epoch  30 | Loss: 0.2403 | Val F1: 0.5807 | Val AUC: 0.9216
Epoch  40 | Loss: 0.2261 | Val F1: 0.6159 | Val AUC: 0.9367
Epoch  50 | Loss: 0.2124 | Val F1: 0.6750 | Val AUC: 0.9444
Epoch  60 | Loss: 0.2044 | Val F1: 0.6965 | Val AUC: 0.9493
Epoch  70 | Loss: 0.1956 | Val F1: 0.7169 | Val AUC: 0.9522
Epoch  80 | Loss: 0.1928 | Val F1: 0.7249 | Val AUC: 0.9550
Epoch  90 | Loss: 0.1874 | Val F1: 0.7337 | Val AUC: 0.9578
Epoch 100 | Loss: 0.1823 | Val F1: 0.7381 | Val AUC: 0.9602
Epoch 110 | Loss: 0.1785 | Val F1: 0.7558 | Val AUC: 0.9617
Epoch 120 | Loss: 0.1780 | Val F1: 0.7619 | Val AUC: 0.9633
Epoch 130 | Loss: 0.1727 | Val F1: 0.7656 | Val AUC: 0.9652
Epoch 140 | Loss: 0.1705 | Val F1: 0.7729 | Val AUC: 0.9667
Epoch 150 | Loss: 0.1682 | Val F1: 0.7734 | Val AUC: 0.9682
Epoch 160 | Loss: 0.1645 | Val F1: 0.7795 | Val AUC: 0.9693
Epoch 170 | Loss: 0.1629 | V

## 5. Train Model with Early Stopping

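Train the model for up to 200 epochs with early stopping based on validation F1 score: the checkpoint is overwritten whenever validation F1 improves, and training stops once it has failed to improve for `patience` consecutive epochs.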
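The stopping rule can be isolated from the model to make its behaviour easy to check. The sketch below replays a list of validation scores through the same counter logic as the training loop. The function name and the score sequence are hypothetical.

```python
def run_early_stopping(val_scores, patience):
    """Return (best_epoch, last_epoch_run) for a sequence of validation scores."""
    best_score = float("-inf")
    best_epoch = None
    waited = 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            # Improvement: remember it (this is where a checkpoint would be saved).
            best_score, best_epoch, waited = score, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return best_epoch, epoch  # stop early
    return best_epoch, len(val_scores)    # ran out of epochs

scores = [0.10, 0.30, 0.50, 0.49, 0.48, 0.51, 0.50, 0.50, 0.50]
best_epoch, last_epoch = run_early_stopping(scores, patience=3)
print(best_epoch, last_epoch)  # 6 9
```

Because the best epoch is generally earlier than the last epoch run, the weights in memory at the end of training are not the best ones, which is why the checkpoint is reloaded before final evaluation.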

In [9]:
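model_path = ARTIFACTS_DIR / ARTIFACT_FILES['baseline_model']
# The checkpoint stores a metrics dict alongside the weights, so load it the same way as the graph
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])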

train_metrics = evaluate(data.train_mask)
val_metrics = evaluate(data.val_mask)
test_metrics = evaluate(data.test_mask)

print("\nBaseline GAT Results:")
print(f"Train - F1: {train_metrics['f1']:.4f}, AUC: {train_metrics['roc_auc']:.4f}")
print(f"Val   - F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['roc_auc']:.4f}")
print(f"Test  - F1: {test_metrics['f1']:.4f}, AUC: {test_metrics['roc_auc']:.4f}")


Baseline GAT Results:
Train - F1: 0.8105, AUC: 0.9750
Val   - F1: 0.7886, AUC: 0.9702
Test  - F1: 0.7938, AUC: 0.9698


## 6. Load Best Model & Evaluate

Load the best checkpoint and evaluate on train/val/test sets.

In [18]:
metrics_dict = {
    'model_type': 'GAT_Baseline',
    'training_time': training_time,
    'best_epoch': checkpoint['epoch'],
    'performance': {
        'train': train_metrics,
        'val': val_metrics,
        'test': test_metrics
    }
}

metrics_path = ARTIFACTS_DIR / ARTIFACT_FILES['baseline_metrics']
with open(metrics_path, 'w') as f:
    json.dump(metrics_dict, f, indent=2)

print(f"\nMetrics saved to {metrics_path}")
print("Proceed to 04_eval_baseline.ipynb")


Metrics saved to c:\Users\tusha\Documents\UT_Dallas\ACM_SP26\imple2\artifacts\gat_baseline_metrics.json
Proceed to 04_eval_baseline.ipynb


---

## ✅ Baseline Training Complete!

Model saved to `artifacts/gat_baseline.pt`. Proceed to **04_eval_baseline.ipynb** for detailed evaluation.

## 7. Save Metrics to JSON
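Metric values may arrive as NumPy or tensor scalars depending on library versions, and not all of those serialize to JSON directly. The sketch below is one defensive way to handle that: it converts anything exposing `.item()` to a plain Python number before writing, then reads the file back. The helper `to_plain` and the temporary file name are hypothetical. The two metric values are copied from the test results printed above.

```python
import json
import tempfile
from pathlib import Path

def to_plain(obj):
    """Recursively convert scalars exposing .item() into plain Python numbers."""
    if isinstance(obj, dict):
        return {str(k): to_plain(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_plain(v) for v in obj]
    if hasattr(obj, "item"):  # NumPy and torch scalars expose .item()
        return obj.item()
    return obj

metrics = {
    "model_type": "GAT_Baseline",
    "performance": {"test": {"f1": 0.7938, "roc_auc": 0.9698}},
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "metrics.json"  # stand-in for the real artifact path
    path.write_text(json.dumps(to_plain(metrics), indent=2))
    loaded = json.loads(path.read_text())

print(loaded["performance"]["test"])
```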