# Bitcoin Transaction Fraud Detection: Model Training

This notebook focuses on implementing and training Graph Neural Network (GNN) models for Bitcoin transaction fraud detection. We'll use the processed data and engineered features from the previous notebooks to train several GNN architectures and compare their performance.

## Table of Contents
1. [Setup](#Setup)
2. [Loading Processed Data](#Loading-Processed-Data)
3. [GNN Model Implementation](#GNN-Model-Implementation)
   - [Graph Convolutional Network (GCN)](#Graph-Convolutional-Network-GCN)
   - [GraphSAGE](#GraphSAGE)
   - [Graph Attention Network (GAT)](#Graph-Attention-Network-GAT)
4. [Training Functions](#Training-Functions)
5. [Model Training](#Model-Training)
   - [Training GCN Model](#Training-GCN-Model)
   - [Training GraphSAGE Model](#Training-GraphSAGE-Model)
   - [Training GAT Model](#Training-GAT-Model)
6. [Model Comparison](#Model-Comparison)
7. [Saving Models](#Saving-Models)

## Setup

Let's import the necessary libraries and configure the environment.

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import copy
import json

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Create directories
os.makedirs('models', exist_ok=True)

# Set device (GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

## Loading Processed Data

First, let's load the processed data and engineered features from the previous notebooks.

In [None]:
def load_processed_data(input_dir='data/processed'):
    """
    Load processed data from disk.
    
    Parameters:
    -----------
    input_dir : str
        Directory with processed data
        
    Returns:
    --------
    data : torch_geometric.data.Data
        PyTorch Geometric Data object
    split_idx : dict
        Dictionary containing indices for train/val/test splits
    """
    logger.info(f"Loading processed data from {input_dir}")
    
    # Load data object
    data_path = os.path.join(input_dir, 'data.pt')
    
    try:
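        # Try direct loading first. A pickled Data object needs weights_only=False
        # on PyTorch >= 2.6, where the default changed to True; only load trusted files.
        data = torch.load(data_path, weights_only=False)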
    except Exception as e:
        logger.warning(f"Failed to load data directly: {e}. Reconstructing from components...")
        # If that fails, try to reconstruct the data object from components
        features_path = os.path.join(input_dir, 'features.npy')
        labels_path = os.path.join(input_dir, 'labels.npy')
        edge_index_path = os.path.join(input_dir, 'edge_index.pt')
        
        # Check if we have combined features (from feature engineering)
        combined_features_path = os.path.join(input_dir, 'combined_features.npy')
        if os.path.exists(combined_features_path):
            logger.info("Loading combined features from feature engineering")
            features = torch.FloatTensor(np.load(combined_features_path))
        else:
            features = torch.FloatTensor(np.load(features_path))
            
        labels = torch.LongTensor(np.load(labels_path))
        edge_index = torch.load(edge_index_path)
        
        data = Data(x=features, edge_index=edge_index, y=labels)
    
    # Load splits
    split_idx = {}
    for split in ['train', 'val', 'test']:
        split_path = os.path.join(input_dir, f'{split}_idx.npy')
        if os.path.exists(split_path):
            split_idx[split] = np.load(split_path)
            logger.info(f"Loaded {split} indices with {len(split_idx[split])} samples")
    
    logger.info(f"Successfully loaded processed data from {input_dir}")
    logger.info(f"Data contains {data.num_nodes} nodes, {data.num_edges} edges, and {data.num_features} features")
    
    return data, split_idx

# Load processed data
try:
    data, split_idx = load_processed_data()
    
    # Print information about the data
    print("Data information:")
    print(f"Number of nodes: {data.num_nodes}")
    print(f"Number of edges: {data.num_edges}")
    print(f"Number of features: {data.num_features}")
    print(f"Number of classes: {len(torch.unique(data.y))}")
    
    # Print split sizes
    print("\nSplit sizes:")
    for split_name, indices in split_idx.items():
        print(f"{split_name.capitalize()}: {len(indices)} nodes")
        
    # Move data to device
    data = data.to(device)
    
except FileNotFoundError:
    logger.error("Processed data not found. Please run Data Preparation notebook first.")

## GNN Model Implementation

Now, let's implement the GNN models for our task. We'll implement three popular GNN architectures:
1. Graph Convolutional Network (GCN)
2. GraphSAGE
3. Graph Attention Network (GAT)

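These models differ mainly in how a node weighs its neighbors: GCN averages neighbor features with fixed weights determined by node degrees, GraphSAGE keeps a node's own representation separate from an aggregate of its neighbors, and GAT learns a separate weight for each edge through attention.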

### Graph Convolutional Network (GCN)

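GCN updates each node with a degree-normalized sum over its neighbors and itself, followed by a shared linear transformation. Because each layer amounts to one sparse matrix product, GCN is cheap to train on the full graph and is a common first baseline for node classification.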
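To make the normalization concrete, here is a pure-Python sketch of one GCN aggregation step, $\hat{D}^{-1/2}(A + I)\hat{D}^{-1/2}X$, where $\hat{D}$ is the degree matrix of $A + I$. It assumes an undirected graph and leaves out the learned weight matrix and the nonlinearity. `gcn_propagate` is an illustrative helper, not PyG's `GCNConv` code, although `GCNConv` by default also adds self-loops and applies this symmetric normalization.

```python
import math

def gcn_propagate(features, edges):
    """One GCN aggregation step with W = I and no nonlinearity.

    features: list of per-node feature lists
    edges: list of undirected (u, v) pairs
    Returns D^{-1/2} (A + I) D^{-1/2} X as a list of lists.
    """
    n = len(features)
    # Adjacency with self-loops (A + I), stored as neighbor sets
    neigh = [{i} for i in range(n)]
    for u, v in edges:
        neigh[u].add(v)
        neigh[v].add(u)
    deg = [len(neigh[i]) for i in range(n)]  # degrees in A + I

    out = []
    for i in range(n):
        row = [0.0] * len(features[i])
        for j in neigh[i]:
            norm = 1.0 / math.sqrt(deg[i] * deg[j])  # symmetric normalization
            for k, value in enumerate(features[j]):
                row[k] += norm * value
        out.append(row)
    return out
```

On the path 0-1-2 with a signal only on node 0, one step reaches node 1 and a second step is needed to reach node 2, which is why the 3-layer models in this notebook see up to 3 hops around each transaction.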

In [None]:
class GCNModel(nn.Module):
    """
    Graph Convolutional Network model for transaction classification.
    
    Features:
    - Multiple GCN layers
    - Batch normalization
    - Residual connections
    - Dropout for regularization
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, 
                dropout=0.5, residual=True, batch_norm=True):
        """
        Initialize the GCN model.
        
        Parameters:
        -----------
        input_dim : int
            Dimension of input features
        hidden_dim : int
            Dimension of hidden layers
        output_dim : int
            Dimension of output (number of classes)
        num_layers : int
            Number of GCN layers
        dropout : float
            Dropout probability
        residual : bool
            Whether to use residual connections
        batch_norm : bool
            Whether to use batch normalization
        """
        super(GCNModel, self).__init__()
        
        self.num_layers = num_layers
        self.dropout = dropout
        self.residual = residual
        self.batch_norm = batch_norm
        
        # Input layer
        self.convs = nn.ModuleList([GCNConv(input_dim, hidden_dim)])
        
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        
        # Output layer
        self.convs.append(GCNConv(hidden_dim, output_dim))
        
        # Batch normalization layers: forward() applies one after each hidden
        # conv (convs[1:-1]), so num_layers - 2 layers are needed
        if batch_norm:
            self.bns = nn.ModuleList([
                nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 2)
            ])
        
        # Initialize parameters
        self.reset_parameters()
        
        logger.info(f"Initialized GCN model with {num_layers} layers")
        logger.info(f"Input dim: {input_dim}, Hidden dim: {hidden_dim}, Output dim: {output_dim}")
        
    def reset_parameters(self):
        """Reset all parameters for better initialization"""
        for conv in self.convs:
            conv.reset_parameters()
        
        if self.batch_norm:
            for bn in self.bns:
                bn.reset_parameters()
    
    def forward(self, x, edge_index):
        """
        Forward pass through the network.
        
        Parameters:
        -----------
        x : torch.Tensor
            Node features [num_nodes, input_dim]
        edge_index : torch.LongTensor
            Graph connectivity [2, num_edges]
            
        Returns:
        --------
        x : torch.Tensor
            Output predictions [num_nodes, output_dim]
        """
        # Input layer
        h = self.convs[0](x, edge_index)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        
        # Hidden layers
        for i in range(1, self.num_layers - 1):
            h_prev = h
            h = self.convs[i](h, edge_index)
            
            if self.batch_norm:
                h = self.bns[i-1](h)
            
            h = F.relu(h)
            
            if self.residual:
                h = h + h_prev  # Residual connection
                
            h = F.dropout(h, p=self.dropout, training=self.training)
        
        # Output layer
        h = self.convs[-1](h, edge_index)
        
        return F.log_softmax(h, dim=1)

    def get_embeddings(self, x, edge_index, layer=-2):
        """
        Get embeddings from an intermediate layer.
        
        Parameters:
        -----------
        x : torch.Tensor
            Node features [num_nodes, input_dim]
        edge_index : torch.LongTensor
            Graph connectivity [2, num_edges]
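        layer : int
            Index into self.convs of the layer whose output is returned
            (negative indices count from the end; the default -2 is the
            last hidden layer)

        Returns:
        --------
        embeddings : torch.Tensor
            Node embeddings
        """
        # Resolve a negative index into a position in self.convs
        target = layer if layer >= 0 else self.num_layers + layer
        if not 0 <= target < self.num_layers:
            raise ValueError(f"layer must index one of {self.num_layers} layers, got {layer}")

        h = x
        for i in range(target + 1):
            h_prev = h
            h = self.convs[i](h, edge_index)

            if i == self.num_layers - 1:  # Output layer: return raw logits
                break

            # Same post-processing as forward(): BN and residual only on hidden layers
            if self.batch_norm and i > 0:
                h = self.bns[i-1](h)

            h = F.relu(h)

            if self.residual and i > 0:
                h = h + h_prev  # Residual connection to the previous layer's output

            h = F.dropout(h, p=self.dropout, training=self.training)

        return h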

### GraphSAGE

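GraphSAGE (SAmple and aggreGatE) was designed to scale to large graphs by sampling a fixed number of neighbors per node and learning how to aggregate their features. Each layer combines a node's own features with an aggregate (here, by default, the mean) of its neighbors' features, using separate weights for the two. Note that `SAGEConv` itself does not sample: because this notebook trains full-batch on the complete `edge_index`, every neighbor is aggregated.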
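A pure-Python sketch of one mean-aggregation layer, $h_i' = W_1 x_i + W_2 \cdot \mathrm{mean}_{j \in N(i)} x_j$, with optional uniform neighbor sampling. To keep it short, the weight matrices $W_1, W_2$ are replaced by scalars `w_self` and `w_neigh`, and both function names are hypothetical. Passing `k=None` corresponds to the full-batch setting used in this notebook.

```python
import random

def sample_neighbors(adj, node, k, rng):
    """Uniformly sample up to k neighbors of `node` without replacement."""
    neighbors = adj.get(node, [])
    if len(neighbors) <= k:
        return list(neighbors)
    return rng.sample(neighbors, k)

def sage_mean_layer(x, adj, w_self, w_neigh, k=None, seed=0):
    """One GraphSAGE layer with mean aggregation and scalar weights.

    x: dict node -> list of floats
    adj: dict node -> list of neighbor nodes
    w_self, w_neigh: scalar stand-ins for the weight matrices W1 and W2
    k: neighbors sampled per node (None = use all, i.e. full-batch)
    """
    rng = random.Random(seed)
    out = {}
    for node, feats in x.items():
        if k is None:
            nbrs = adj.get(node, [])
        else:
            nbrs = sample_neighbors(adj, node, k, rng)
        dim = len(feats)
        if nbrs:
            mean = [sum(x[j][d] for j in nbrs) / len(nbrs) for d in range(dim)]
        else:
            mean = [0.0] * dim  # isolated node: empty aggregate
        # Self term and neighbor term use separate weights
        out[node] = [w_self * feats[d] + w_neigh * mean[d] for d in range(dim)]
    return out
```

With sampling, a node's output depends on which neighbors were drawn. In PyTorch Geometric, mini-batch sampling of this kind is usually done with `torch_geometric.loader.NeighborLoader` rather than inside the convolution.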

In [None]:
class SAGEModel(nn.Module):
    """
    GraphSAGE model for transaction classification.
    
    Features:
    - GraphSAGE convolutions
    - Batch normalization
    - Residual connections
    - Dropout for regularization
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, 
                dropout=0.5, residual=True, batch_norm=True, aggr='mean'):
        """
        Initialize the GraphSAGE model.
        
        Parameters:
        -----------
        input_dim : int
            Dimension of input features
        hidden_dim : int
            Dimension of hidden layers
        output_dim : int
            Dimension of output (number of classes)
        num_layers : int
            Number of SAGE layers
        dropout : float
            Dropout probability
        residual : bool
            Whether to use residual connections
        batch_norm : bool
            Whether to use batch normalization
        aggr : str
            Aggregation method ('mean', 'max', or 'sum')
        """
        super(SAGEModel, self).__init__()
        
        self.num_layers = num_layers
        self.dropout = dropout
        self.residual = residual
        self.batch_norm = batch_norm
        
        # Input layer
        self.convs = nn.ModuleList([SAGEConv(input_dim, hidden_dim, aggr=aggr)])
        
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim, aggr=aggr))
        
        # Output layer
        self.convs.append(SAGEConv(hidden_dim, output_dim, aggr=aggr))
        
        # Batch normalization layers: forward() applies one after each hidden
        # conv (convs[1:-1]), so num_layers - 2 layers are needed
        if batch_norm:
            self.bns = nn.ModuleList([
                nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 2)
            ])
        
        # Initialize parameters
        self.reset_parameters()
        
        logger.info(f"Initialized GraphSAGE model with {num_layers} layers")
        logger.info(f"Input dim: {input_dim}, Hidden dim: {hidden_dim}, Output dim: {output_dim}")
    
    def reset_parameters(self):
        """Reset all parameters for better initialization"""
        for conv in self.convs:
            conv.reset_parameters()
        
        if self.batch_norm:
            for bn in self.bns:
                bn.reset_parameters()
    
    def forward(self, x, edge_index):
        """
        Forward pass through the network.
        
        Parameters:
        -----------
        x : torch.Tensor
            Node features [num_nodes, input_dim]
        edge_index : torch.LongTensor
            Graph connectivity [2, num_edges]
            
        Returns:
        --------
        x : torch.Tensor
            Output predictions [num_nodes, output_dim]
        """
        # Input layer
        h = self.convs[0](x, edge_index)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        
        # Hidden layers
        for i in range(1, self.num_layers - 1):
            h_prev = h
            h = self.convs[i](h, edge_index)
            
            if self.batch_norm:
                h = self.bns[i-1](h)
            
            h = F.relu(h)
            
            if self.residual:
                h = h + h_prev  # Residual connection
                
            h = F.dropout(h, p=self.dropout, training=self.training)
        
        # Output layer
        h = self.convs[-1](h, edge_index)
        
        return F.log_softmax(h, dim=1)

    def get_embeddings(self, x, edge_index, layer=-2):
        """
        Get embeddings from an intermediate layer.
        
        Parameters:
        -----------
        x : torch.Tensor
            Node features [num_nodes, input_dim]
        edge_index : torch.LongTensor
            Graph connectivity [2, num_edges]
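        layer : int
            Index into self.convs of the layer whose output is returned
            (negative indices count from the end; the default -2 is the
            last hidden layer)

        Returns:
        --------
        embeddings : torch.Tensor
            Node embeddings
        """
        # Resolve a negative index into a position in self.convs
        target = layer if layer >= 0 else self.num_layers + layer
        if not 0 <= target < self.num_layers:
            raise ValueError(f"layer must index one of {self.num_layers} layers, got {layer}")

        h = x
        for i in range(target + 1):
            h_prev = h
            h = self.convs[i](h, edge_index)

            if i == self.num_layers - 1:  # Output layer: return raw logits
                break

            # Same post-processing as forward(): BN and residual only on hidden layers
            if self.batch_norm and i > 0:
                h = self.bns[i-1](h)

            h = F.relu(h)

            if self.residual and i > 0:
                h = h + h_prev  # Residual connection to the previous layer's output

            h = F.dropout(h, p=self.dropout, training=self.training)

        return h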

### Graph Attention Network (GAT)

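GAT uses attention to learn a separate weight for each neighbor during message passing, instead of fixing the weights by node degree as GCN does. For fraud detection this may help, since the model can in principle learn to emphasize some transaction links over others, though any benefit has to be confirmed on the validation set.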
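For a single head, GAT scores each edge as $e_{ij} = \mathrm{LeakyReLU}(a^\top [W h_i \,\|\, W h_j])$ and normalizes with a softmax over the neighborhood (including the node itself): $\alpha_{ij} = \exp(e_{ij}) / \sum_{k} \exp(e_{ik})$. The pure-Python sketch below assumes the features have already been multiplied by $W$ and splits $a$ into the half that multiplies $W h_i$ (`a_self`) and the half that multiplies $W h_j$ (`a_neigh`). The 0.2 negative slope matches `GATConv`'s default; the function names are illustrative only.

```python
import math

def leaky_relu(v, slope=0.2):
    return v if v > 0 else slope * v

def gat_attention(wh, neighbors, a_self, a_neigh):
    """Single-head GAT attention coefficients for every node.

    wh: dict node -> transformed features W h (list of floats)
    neighbors: dict node -> list of neighbor nodes (the self-loop is added here)
    a_self, a_neigh: the two halves of the attention vector a
    Returns dict node -> dict neighbor -> alpha (sums to 1 per node).
    """
    alphas = {}
    for i, hi in wh.items():
        candidates = [i] + [j for j in neighbors.get(i, []) if j != i]
        scores = {}
        for j in candidates:
            # e_ij = LeakyReLU(a^T [W h_i || W h_j]) = LeakyReLU(a_self.W h_i + a_neigh.W h_j)
            raw = sum(a * v for a, v in zip(a_self, hi)) + sum(a * v for a, v in zip(a_neigh, wh[j]))
            scores[j] = leaky_relu(raw)
        # Softmax over the neighborhood, shifted by the max for numerical stability
        m = max(scores.values())
        weights = {j: math.exp(s - m) for j, s in scores.items()}
        total = sum(weights.values())
        alphas[i] = {j: w / total for j, w in weights.items()}
    return alphas
```

`GATModel` runs `heads` independent copies of this computation (each with its own $W$ and $a$) in every hidden layer and concatenates their outputs.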

In [None]:
class GATModel(nn.Module):
    """
    Graph Attention Network model for transaction classification.
    
    Features:
    - GAT layers with attention
    - Batch normalization
    - Residual connections
    - Dropout for regularization
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, 
                heads=8, dropout=0.5, residual=True, batch_norm=True):
        """
        Initialize the GAT model.
        
        Parameters:
        -----------
        input_dim : int
            Dimension of input features
        hidden_dim : int
            Dimension of hidden layers
        output_dim : int
            Dimension of output (number of classes)
        num_layers : int
            Number of GAT layers
        heads : int
            Number of attention heads
        dropout : float
            Dropout probability
        residual : bool
            Whether to use residual connections
        batch_norm : bool
            Whether to use batch normalization
        """
        super(GATModel, self).__init__()
        
        self.num_layers = num_layers
        self.dropout = dropout
        self.residual = residual
        self.batch_norm = batch_norm
        
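        # Input layer (with multiple heads). GATConv concatenates the heads by default,
        # so each head gets hidden_dim // heads channels and the output is hidden_dim wide.
        if hidden_dim % heads != 0:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by heads ({heads})")
        self.convs = nn.ModuleList([GATConv(input_dim, hidden_dim // heads, heads=heads)])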
        
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_dim, hidden_dim // heads, heads=heads))
        
        # Output layer (with 1 head)
        self.convs.append(GATConv(hidden_dim, output_dim, heads=1))
        
        # Batch normalization layers: forward() applies one after each hidden
        # conv (convs[1:-1]), so num_layers - 2 layers are needed
        if batch_norm:
            self.bns = nn.ModuleList([
                nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 2)
            ])
        
        # Initialize parameters
        self.reset_parameters()
        
        logger.info(f"Initialized GAT model with {num_layers} layers and {heads} heads")
        logger.info(f"Input dim: {input_dim}, Hidden dim: {hidden_dim}, Output dim: {output_dim}")
    
    def reset_parameters(self):
        """Reset all parameters for better initialization"""
        for conv in self.convs:
            conv.reset_parameters()
        
        if self.batch_norm:
            for bn in self.bns:
                bn.reset_parameters()
    
    def forward(self, x, edge_index):
        """
        Forward pass through the network.
        
        Parameters:
        -----------
        x : torch.Tensor
            Node features [num_nodes, input_dim]
        edge_index : torch.LongTensor
            Graph connectivity [2, num_edges]
            
        Returns:
        --------
        x : torch.Tensor
            Output predictions [num_nodes, output_dim]
        """
        # Input layer
        h = self.convs[0](x, edge_index)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        
        # Hidden layers
        for i in range(1, self.num_layers - 1):
            h_prev = h
            h = self.convs[i](h, edge_index)
            
            if self.batch_norm:
                h = self.bns[i-1](h)
            
            h = F.relu(h)
            
            if self.residual:
                h = h + h_prev  # Residual connection
                
            h = F.dropout(h, p=self.dropout, training=self.training)
        
        # Output layer
        h = self.convs[-1](h, edge_index)
        
        return F.log_softmax(h, dim=1)

    def get_embeddings(self, x, edge_index, layer=-2):
        """
        Get embeddings from an intermediate layer.
        
        Parameters:
        -----------
        x : torch.Tensor
            Node features [num_nodes, input_dim]
        edge_index : torch.LongTensor
            Graph connectivity [2, num_edges]
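        layer : int
            Index into self.convs of the layer whose output is returned
            (negative indices count from the end; the default -2 is the
            last hidden layer)

        Returns:
        --------
        embeddings : torch.Tensor
            Node embeddings
        """
        # Resolve a negative index into a position in self.convs
        target = layer if layer >= 0 else self.num_layers + layer
        if not 0 <= target < self.num_layers:
            raise ValueError(f"layer must index one of {self.num_layers} layers, got {layer}")

        h = x
        for i in range(target + 1):
            h_prev = h
            h = self.convs[i](h, edge_index)

            if i == self.num_layers - 1:  # Output layer: return raw logits
                break

            # Same post-processing as forward(): BN and residual only on hidden layers
            if self.batch_norm and i > 0:
                h = self.bns[i-1](h)

            h = F.relu(h)

            if self.residual and i > 0:
                h = h + h_prev  # Residual connection to the previous layer's output

            h = F.dropout(h, p=self.dropout, training=self.training)

        return h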

## Training Functions

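Next, we define a single training function shared by all three GNN models. It runs full-batch training; records loss, accuracy and (for binary labels) ROC AUC on the training and validation nodes; stops early once the validation loss has not improved for `patience` consecutive epochs; and finally restores and saves the best weights along with the training history.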
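The early-stopping rule inside `train_model` can be isolated into a few lines. The `EarlyStopping` class below is a hypothetical helper (neither PyTorch nor PyG ships one under this name) that mirrors the same logic: a strictly lower validation loss resets the counter, anything else increments it, and training stops once the counter reaches `patience`.

```python
class EarlyStopping:
    """Patience-based early stopping on a validation loss (lower is better)."""

    def __init__(self, patience=20):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.counter = 0

    def step(self, epoch, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss:
            # New best: remember it and reset the patience counter
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```

For validation losses `[1.0, 0.8, 0.9, 0.85, 0.95]` and `patience=3`, `step` first returns `True` at epoch 4, and `best_epoch` is 1, the epoch whose weights `train_model` would restore.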

In [None]:
def train_model(model, data, split_idx, optimizer, criterion, 
               scheduler=None, epochs=200, patience=20, 
               device='cpu', model_dir='models', model_name='gnn'):
    """
    Train a GNN model with early stopping.
    
    Parameters:
    -----------
    model : torch.nn.Module
        The model to train
    data : torch_geometric.data.Data
        The graph data
    split_idx : dict
        Dictionary containing indices for train/val/test splits
    optimizer : torch.optim.Optimizer
        The optimizer to use
    criterion : torch.nn.Module
        The loss function
    scheduler : torch.optim.lr_scheduler._LRScheduler, optional
        Learning rate scheduler
    epochs : int
        Maximum number of epochs
    patience : int
        Patience for early stopping
    device : str
        Device to use ('cpu' or 'cuda')
    model_dir : str
        Directory to save the model
    model_name : str
        Name of the model for saving
        
    Returns:
    --------
    model : torch.nn.Module
        The trained model
    history : dict
        Training history
    """
    # Prepare model directory
    os.makedirs(model_dir, exist_ok=True)
    
    # Move data to device
    data = data.to(device)
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': [],
        'train_auc': [],
        'val_auc': []
    }
    
    # Best model tracking
    best_val_loss = float('inf')
    best_model_state = None
    best_epoch = 0
    patience_counter = 0
    
    # Training loop
    logger.info(f"Starting training for {epochs} epochs (early stopping patience: {patience})")
    start_time = time.time()
    
    for epoch in range(epochs):
        epoch_start = time.time()
        
        # Train phase
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        train_loss = criterion(out[split_idx['train']], data.y[split_idx['train']])
        train_loss.backward()
        optimizer.step()
        
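        # Evaluate in eval mode (no dropout, batch norm uses running statistics)
        # to compute train/validation metrics without tracking gradients
        with torch.no_grad():
            model.eval()
            out = model(data.x, data.edge_index)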
            
            # Validation loss
            val_loss = criterion(out[split_idx['val']], data.y[split_idx['val']])
            
            # Accuracy
            pred = out.argmax(dim=1)
            train_correct = pred[split_idx['train']].eq(data.y[split_idx['train']]).sum().item()
            train_acc = train_correct / len(split_idx['train'])
            
            val_correct = pred[split_idx['val']].eq(data.y[split_idx['val']]).sum().item()
            val_acc = val_correct / len(split_idx['val'])
            
            # Calculate AUC if binary classification
            num_classes = torch.unique(data.y).shape[0]
            if num_classes == 2:
                try:
                    # The model outputs log-probabilities; exp() recovers P(class 1) as the score
                    train_probs = torch.exp(out[split_idx['train'], 1]).cpu().numpy()
                    train_labels = data.y[split_idx['train']].cpu().numpy()
                    val_probs = torch.exp(out[split_idx['val'], 1]).cpu().numpy()
                    val_labels = data.y[split_idx['val']].cpu().numpy()
                    
                    train_auc = roc_auc_score(train_labels, train_probs)
                    val_auc = roc_auc_score(val_labels, val_probs)
                    
                    history['train_auc'].append(train_auc)
                    history['val_auc'].append(val_auc)
                except Exception as e:
                    # If AUC calculation fails, skip it
                    logger.warning(f"AUC calculation failed: {str(e)}")
                    history['train_auc'].append(0)
                    history['val_auc'].append(0)
        
        # Update learning rate
        if scheduler is not None:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
        
        # Track history
        history['train_loss'].append(train_loss.item())
        history['val_loss'].append(val_loss.item())
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        
        # Early stopping check (keep the best loss as a Python float, not a device tensor)
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = epoch
            patience_counter = 0
        else:
            patience_counter += 1
        
        # Print progress
        epoch_time = time.time() - epoch_start
        if epoch % 10 == 0 or epoch == epochs - 1:
            logger.info(f"Epoch {epoch+1}/{epochs} | "
                       f"Train Loss: {train_loss:.4f} | "
                       f"Val Loss: {val_loss:.4f} | "
                       f"Train Acc: {train_acc:.4f} | "
                       f"Val Acc: {val_acc:.4f} | "
                       f"Time: {epoch_time:.2f}s")
        
        # Check early stopping
        if patience_counter >= patience:
            logger.info(f"Early stopping triggered after {epoch+1} epochs")
            break
    
    total_time = time.time() - start_time
    logger.info(f"Training completed in {total_time:.2f}s | Best epoch: {best_epoch+1}")
    
    # Load best model
    model.load_state_dict(best_model_state)
    
    # Save model
    model_path = os.path.join(model_dir, f'{model_name}_best.pt')
    torch.save(best_model_state, model_path)
    logger.info(f"Saved best model to {model_path}")
    
    # Save training history
    history_path = os.path.join(model_dir, f'{model_name}_history.json')
    with open(history_path, 'w') as f:
        json.dump(history, f)
    logger.info(f"Saved training history to {history_path}")
    
    return model, history
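`train_model` scores each node with `torch.exp(out[:, 1])`, the predicted probability of class 1, and passes it to scikit-learn's `roc_auc_score`. ROC AUC equals the probability that a randomly chosen positive node is scored above a randomly chosen negative one (ties count half). It therefore depends only on the ranking of the scores, so taking `exp` of the log-probabilities does not change it. The quadratic-time sketch below (`pairwise_auc` is a hypothetical name) computes that definition directly on toy data.

```python
def pairwise_auc(labels, scores):
    """ROC AUC as P(score of a random positive > score of a random negative).

    Ties count as half a correct ordering. O(n_pos * n_neg), so only
    suitable for small illustrative inputs.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined when only one class is present")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))
```

The `ValueError` mirrors the case in which `roc_auc_score` fails (only one class among the selected nodes), which is what the `try`/`except` around the AUC computation in `train_model` guards against.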

In [None]:
def plot_training_history(history, model_name):
    """Plot training and validation loss and accuracy."""
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Plot training and validation loss
    epochs = range(1, len(history['train_loss']) + 1)
    ax1.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    ax1.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    ax1.set_title(f'{model_name} - Loss')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot training and validation accuracy
    ax2.plot(epochs, history['train_acc'], 'b-', label='Training Accuracy')
    ax2.plot(epochs, history['val_acc'], 'r-', label='Validation Accuracy')
    ax2.set_title(f'{model_name} - Accuracy')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # If AUC is available (binary classification), plot it as well
    if 'train_auc' in history and history['train_auc']:
        plt.figure(figsize=(8, 6))
        plt.plot(epochs, history['train_auc'], 'b-', label='Training AUC')
        plt.plot(epochs, history['val_auc'], 'r-', label='Validation AUC')
        plt.title(f'{model_name} - ROC AUC')
        plt.xlabel('Epochs')
        plt.ylabel('AUC')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

## Model Training

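Now, let's train each of the GNN models on our transaction dataset. Each model gets a small wrapper function that builds it, sets up an Adam optimizer, an `NLLLoss` criterion and a `ReduceLROnPlateau` scheduler, and then calls `train_model`.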
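All three models end in `F.log_softmax`, and the wrappers pair that output with `torch.nn.NLLLoss`; together these compute the usual cross-entropy on the raw logits (PyTorch's `nn.CrossEntropyLoss` combines the same two steps). The pure-Python sketch below checks that equivalence on toy numbers; the function names are illustrative, not PyTorch APIs.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def nll_loss(log_probs_batch, targets):
    """Mean negative log-likelihood (NLLLoss with its default 'mean' reduction)."""
    return -sum(lp[t] for lp, t in zip(log_probs_batch, targets)) / len(targets)

def cross_entropy(logits_batch, targets):
    """Cross-entropy on raw logits, written directly as mean of -log p(target)."""
    total = 0.0
    for logits, t in zip(logits_batch, targets):
        z = sum(math.exp(v) for v in logits)  # naive normalizer, fine for small logits
        total += -math.log(math.exp(logits[t]) / z)
    return total / len(targets)
```

For example, `nll_loss([log_softmax(l) for l in logits], targets)` and `cross_entropy(logits, targets)` agree up to floating-point rounding.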

### Training GCN Model

First, let's train the Graph Convolutional Network (GCN) model.

In [None]:
def train_gcn(data, split_idx, hidden_dim=256, num_layers=3, 
             dropout=0.5, lr=0.01, weight_decay=5e-4, 
             epochs=200, patience=20, device='cpu', model_dir='models'):
    """
    Train a GCN model with the given parameters.
    
    Parameters:
    -----------
    data : torch_geometric.data.Data
        The graph data
    split_idx : dict
        Dictionary containing indices for train/val/test splits
    hidden_dim : int
        Dimension of hidden layers
    num_layers : int
        Number of GCN layers
    dropout : float
        Dropout probability
    lr : float
        Learning rate
    weight_decay : float
        Weight decay factor
    epochs : int
        Maximum number of epochs
    patience : int
        Patience for early stopping
    device : str
        Device to use ('cpu' or 'cuda')
    model_dir : str
        Directory to save the model
        
    Returns:
    --------
    model : torch.nn.Module
        The trained model
    history : dict
        Training history
    """
    # Create model
    input_dim = data.x.shape[1]
    output_dim = len(torch.unique(data.y))
    
    model = GCNModel(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        num_layers=num_layers,
        dropout=dropout
    ).to(device)
    
    # Setup training
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = torch.nn.NLLLoss()
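    # Halve the LR when validation loss has not improved for more than 10 epochs
    # (the `verbose` flag is omitted because recent PyTorch releases deprecate it)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, min_lr=1e-5)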
    
    # Train model
    model, history = train_model(
        model=model,
        data=data,
        split_idx=split_idx,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        epochs=epochs,
        patience=patience,
        device=device,
        model_dir=model_dir,
        model_name='gcn'
    )
    
    return model, history

# Define hyperparameters
gcn_params = {
    'hidden_dim': 256,
    'num_layers': 3,
    'dropout': 0.5,
    'lr': 0.01,
    'weight_decay': 5e-4,
    'epochs': 200,
    'patience': 20,
    'device': device,
    'model_dir': 'models'
}

# Train GCN model
logger.info("Training GCN model")
gcn_model, gcn_history = train_gcn(data, split_idx, **gcn_params)

# Visualize training history
plot_training_history(gcn_history, 'GCN')

### Training GraphSAGE Model

Next, let's train the GraphSAGE model.

In [None]:
def train_sage(data, split_idx, hidden_dim=256, num_layers=3, 
              dropout=0.5, lr=0.01, weight_decay=5e-4, 
              epochs=200, patience=20, device='cpu', model_dir='models'):
    """
    Train a GraphSAGE model with the given parameters.
    """
    # Create model
    input_dim = data.x.shape[1]
    output_dim = len(torch.unique(data.y))
    
    model = SAGEModel(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        num_layers=num_layers,
        dropout=dropout
    ).to(device)
    
    # Setup training
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = torch.nn.NLLLoss()
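    # Halve the LR when validation loss has not improved for more than 10 epochs
    # (the `verbose` flag is omitted because recent PyTorch releases deprecate it)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, min_lr=1e-5)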
    
    # Train model
    model, history = train_model(
        model=model,
        data=data,
        split_idx=split_idx,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        epochs=epochs,
        patience=patience,
        device=device,
        model_dir=model_dir,
        model_name='sage'
    )
    
    return model, history

# Define hyperparameters
sage_params = {
    'hidden_dim': 256,
    'num_layers': 3,
    'dropout': 0.5,
    'lr': 0.01,
    'weight_decay': 5e-4,
    'epochs': 200,
    'patience': 20,
    'device': device,
    'model_dir': 'models'
}

# Train GraphSAGE model
logger.info("Training GraphSAGE model")
sage_model, sage_history = train_sage(data, split_idx, **sage_params)

# Visualize training history
plot_training_history(sage_history, 'GraphSAGE')
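The training functions combine two patience counters: `ReduceLROnPlateau` halves the learning rate once validation loss has failed to improve for more than `patience=10` epochs, and `train_model` (not shown in this section) is assumed here to stop after `patience=20` such epochs. The pure-Python sketch below replays a made-up validation-loss curve through both rules. It follows the documented `ReduceLROnPlateau` behaviour in `'min'` mode but uses a strict `<` improvement test instead of PyTorch's default relative threshold, and the early-stopping rule is a hypothetical stand-in for whatever `train_model` actually does.

```python
def simulate_schedule(val_losses, lr=0.01, factor=0.5, min_lr=1e-5,
                      lr_patience=10, stop_patience=20):
    """Replay a validation-loss curve through a reduce-on-plateau rule and early stopping.

    Returns (lrs, stop_epoch, best_epoch): lrs[i] is the learning rate in effect
    after the scheduler step at the end of epoch i (0-indexed); stop_epoch is
    None if the curve ends before early stopping triggers.
    """
    best = float('inf')
    best_epoch = None
    lr_bad = 0    # epochs without improvement, reset whenever the LR is reduced
    stop_bad = 0  # epochs without improvement, reset only by an improvement
    lrs = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:  # strict improvement (PyTorch also applies a small relative threshold)
            best, best_epoch = loss, epoch
            lr_bad = stop_bad = 0
        else:
            lr_bad += 1
            stop_bad += 1
        # As documented for ReduceLROnPlateau: reduce once the bad-epoch count exceeds patience
        if lr_bad > lr_patience:
            lr = max(lr * factor, min_lr)
            lr_bad = 0
        lrs.append(lr)
        # Hypothetical early-stopping rule standing in for train_model's logic
        if stop_bad >= stop_patience:
            return lrs, epoch, best_epoch
    return lrs, None, best_epoch


# Made-up curve: steady improvement for 5 epochs, then a flat plateau
curve = [1.0, 0.9, 0.8, 0.7, 0.6] + [0.65] * 30
lrs, stop_epoch, best_epoch = simulate_schedule(curve)
print(f"best epoch {best_epoch}, stopped at epoch {stop_epoch}, final LR {lrs[-1]}")
```

With this curve the learning rate is halved once (to 0.005, after epoch 15) and early stopping fires at epoch 24, twenty epochs after the best validation loss at epoch 4. On an uninterrupted plateau the scheduler therefore gets only one reduction before the run ends; raising the early-stopping patience or lowering the scheduler patience would give it more.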

### Training GAT Model

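Finally, let's train the Graph Attention Network (GAT) model. GAT is typically the most memory-hungry of the three architectures: each layer computes an attention coefficient and an attention-weighted message for every edge and every attention head, so memory grows with the number of edges times the number of heads. The cell below therefore catches out-of-memory errors, retries with 4 attention heads instead of 8, and skips GAT if that still fails.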

In [None]:
def train_gat(data, split_idx, hidden_dim=256, num_layers=3, 
             heads=8, dropout=0.5, lr=0.01, weight_decay=5e-4, 
             epochs=200, patience=20, device='cpu', model_dir='models'):
    """
    Train a GAT model with the given parameters.
    """
    # Create model
    input_dim = data.x.shape[1]
    output_dim = len(torch.unique(data.y))
    
    model = GATModel(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        num_layers=num_layers,
        heads=heads,
        dropout=dropout
    ).to(device)
    
    # Setup training
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
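    # NLLLoss expects log-probabilities, so GATModel's forward pass must return them (e.g. via log_softmax)
    criterion = torch.nn.NLLLoss()
    # Halve the LR after 10 epochs without val-loss improvement
    # (`verbose` is omitted because it is deprecated in recent PyTorch releases)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, min_lr=1e-5)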
    
    # Train model
    model, history = train_model(
        model=model,
        data=data,
        split_idx=split_idx,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        epochs=epochs,
        patience=patience,
        device=device,
        model_dir=model_dir,
        model_name='gat'
    )
    
    return model, history

# Define hyperparameters
gat_params = {
    'hidden_dim': 256,
    'num_layers': 3,
    'heads': 8,
    'dropout': 0.5,
    'lr': 0.01,
    'weight_decay': 5e-4,
    'epochs': 200,
    'patience': 20,
    'device': device,
    'model_dir': 'models'
}

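# Train GAT model. If the GPU runs out of memory, retry once with fewer attention heads.
has_gat = False
gat_oom = False
try:
    logger.info("Training GAT model")
    gat_model, gat_history = train_gat(data, split_idx, **gat_params)
    
    # Visualize training history
    plot_training_history(gat_history, 'GAT')
    
    has_gat = True
except RuntimeError as e:
    # torch.cuda.OutOfMemoryError is a subclass of RuntimeError, so OOM errors land here
    if 'out of memory' in str(e).lower():
        gat_oom = True
    else:
        logger.error(f"Error training GAT model: {e}")

# Retry outside the except block: once the exception (and the frames its traceback
# keeps alive) is released, the failed attempt's tensors can actually be freed
if gat_oom:
    logger.warning("Not enough memory for GAT model. Trying with fewer attention heads...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gat_params['heads'] = 4
    try:
        gat_model, gat_history = train_gat(data, split_idx, **gat_params)
        plot_training_history(gat_history, 'GAT (4 heads)')
        has_gat = True
    except RuntimeError as e:
        if 'out of memory' in str(e).lower():
            logger.warning("Still not enough memory. Skipping GAT model.")
        else:
            logger.error(f"Error training GAT model with 4 heads: {e}")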
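To see why dropping from 8 to 4 heads is a sensible first fallback, here is a back-of-envelope estimate of the largest per-edge tensor a GAT layer builds. It assumes a PyG-style `GATConv` that adds self-loops and materializes one message of `channels_per_head` float32 values per edge and per head. Whether `GATModel` uses `hidden_dim` per head or splits it across heads is not shown here, so `channels_per_head` is a free parameter, and the graph size in the example is made up. Autograd keeps several tensors of this size per layer, so real usage is a multiple of the figure.

```python
def gat_message_bytes(num_nodes, num_edges, heads, channels_per_head,
                      bytes_per_value=4, self_loops=True):
    """Approximate size in bytes of one GAT layer's per-edge message tensor.

    Assumed shape: [num_edges (+ one self-loop per node), heads, channels_per_head],
    stored as float32 (4 bytes per value) by default.
    """
    edges = num_edges + (num_nodes if self_loops else 0)
    return edges * heads * channels_per_head * bytes_per_value


# Made-up graph size for illustration
n_nodes, n_edges = 200_000, 250_000
for heads in (8, 4):
    gib = gat_message_bytes(n_nodes, n_edges, heads, channels_per_head=256) / 2**30
    print(f"{heads} heads: ~{gib:.2f} GiB per layer for messages alone")
```

The estimate is linear in `heads`, so the fallback halves this term; if 4 heads still do not fit, lowering `hidden_dim` shrinks it further.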

## Model Comparison

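Now that the models are trained, let's compare them on the validation set. For each model we take the best value each metric reached during training: the lowest validation loss, the highest validation accuracy and, if it was recorded, the highest validation AUC. These best values can come from different epochs, so they summarize what each model achieved rather than describing a single checkpoint.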

In [None]:
# Compare model performance using the best value each metric reached during training
def best_metrics(history):
    """Summarize a training history: lowest val loss, highest val accuracy and (if recorded) AUC."""
    metrics = {
        'val_loss': float(min(history['val_loss'])) if history.get('val_loss') else float('inf'),
        'val_acc': float(max(history['val_acc'])) if history.get('val_acc') else 0.0,
    }
    if history.get('val_auc'):
        metrics['val_auc'] = float(max(history['val_auc']))
    return metrics

histories = {'gcn': gcn_history, 'sage': sage_history}
if 'has_gat' in locals() and has_gat:
    histories['gat'] = gat_history

comparison = {name: best_metrics(hist) for name, hist in histories.items()}

# Only report/plot AUC if every model recorded it
has_auc = all('val_auc' in metrics for metrics in comparison.values())

# One row per model, built in one step instead of concatenating onto an empty DataFrame
comparison_df = (pd.DataFrame.from_dict(comparison, orient='index')
                 .rename_axis('model')
                 .reset_index())

# Sort by validation loss
comparison_df = comparison_df.sort_values('val_loss')

# Display comparison results
print("Model comparison:")
print(comparison_df)

# Determine the best model
best_model = comparison_df.iloc[0]['model']
print(f"\nBest model based on validation loss: {best_model}")
best_metrics = comparison[best_model]
for metric, value in best_metrics.items():
    print(f"  {metric}: {value:.4f}")

# Save model comparison results
with open(os.path.join('models', 'model_comparison.json'), 'w') as f:
    json.dump(comparison, f, indent=4)
print("\nSaved model comparison results to models/model_comparison.json")
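The models are ranked by validation loss, with accuracy reported next to it. In fraud detection the fraudulent class is usually a small minority of the labeled transactions, so accuracy can look excellent for a model that never flags anything; this is why the AUC column (when recorded) is worth checking as well. A minimal pure-Python illustration with made-up labels (10% fraudulent), using the pairwise definition of ROC AUC, which for binary labels agrees with `roc_auc_score` from the setup cell:

```python
def accuracy(labels, preds):
    """Fraction of predictions that match the labels."""
    return sum(int(l == p) for l, p in zip(labels, preds)) / len(labels)


def roc_auc(labels, scores):
    """ROC AUC as the probability that a random positive outscores a random negative (ties count 1/2)."""
    pos = [s for l, s in zip(labels, scores) if l == 1]
    neg = [s for l, s in zip(labels, scores) if l == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Made-up validation labels: 10 fraudulent (1) out of 100 transactions
labels = [1] * 10 + [0] * 90

# A "model" that scores every transaction the same and predicts legitimate for all
constant_scores = [0.1] * 100
constant_preds = [0] * 100
acc_const = accuracy(labels, constant_preds)
auc_const = roc_auc(labels, constant_scores)

# A model whose scores rank every fraudulent transaction above every legitimate one
ranking_scores = [0.9] * 10 + [0.2] * 90
auc_rank = roc_auc(labels, ranking_scores)

print(f"constant model: accuracy={acc_const:.2f}, AUC={auc_const:.2f}")
print(f"ranking model:  AUC={auc_rank:.2f}")
```

The constant model reaches 90% accuracy while catching no fraud at all, and its AUC of 0.5 exposes that it carries no ranking information.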

Let's create a visualization to compare the performance of the models.

In [None]:
# Visualize model comparison
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

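# Bar plot for validation loss (hue='model' with legend=False colors bars per model;
# passing palette without hue is deprecated as of seaborn 0.13)
sns.barplot(x='model', y='val_loss', data=comparison_df, hue='model', legend=False, ax=axes[0], palette='viridis')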
axes[0].set_title('Validation Loss by Model')
axes[0].set_xlabel('Model')
axes[0].set_ylabel('Validation Loss')
axes[0].grid(True, axis='y', alpha=0.3)

# Bar plot for validation accuracy
sns.barplot(x='model', y='val_acc', data=comparison_df, hue='model', legend=False, ax=axes[1], palette='viridis')
axes[1].set_title('Validation Accuracy by Model')
axes[1].set_xlabel('Model')
axes[1].set_ylabel('Validation Accuracy')
axes[1].grid(True, axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

# Plot validation AUC if available
if has_auc:
    plt.figure(figsize=(8, 6))
    sns.barplot(x='model', y='val_auc', data=comparison_df, hue='model', legend=False, palette='viridis')
    plt.title('Validation AUC by Model')
    plt.xlabel('Model')
    plt.ylabel('Validation AUC')
    plt.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()

## Saving Models

Finally, let's save the best model for use in evaluation.

In [None]:
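import shutil

# Copy the best model's checkpoint ('<name>_best.pt', the path train_model is expected to write) to 'best_model.pt'
best_model_path = os.path.join('models', f'{best_model}_best.pt')
if not os.path.exists(best_model_path):
    raise FileNotFoundError(f"Expected checkpoint {best_model_path} was not found")
shutil.copy(best_model_path, os.path.join('models', 'best_model.pt'))

# Save the name of the best model so the evaluation notebook can rebuild the right architecture
with open(os.path.join('models', 'best_model_name.txt'), 'w') as f:
    f.write(best_model)

print(f"Saved best model ({best_model}) as 'best_model.pt'")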
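`best_model.pt` is only useful to the evaluation notebook if it can rebuild the same architecture before loading the weights (assuming `train_model` saves a `state_dict` rather than a pickled module, which this section does not show). One option is to store the winning model's hyperparameters next to the checkpoint. The sketch below is hypothetical: `save_model_config` and the `best_model_config.json` filename are not part of the original pipeline. It skips values JSON cannot encode, such as the `torch.device` stored in the `*_params` dicts.

```python
import json
import os


def save_model_config(model_name, params, model_dir='models',
                      filename='best_model_config.json'):
    """Write the architecture name and JSON-serializable hyperparameters to disk.

    Entries that json cannot encode (e.g. a torch.device) are skipped.
    Returns the path that was written.
    """
    config = {'model': model_name}
    for key, value in params.items():
        try:
            json.dumps(value)
        except TypeError:
            continue  # skip non-serializable values such as the device
        config[key] = value
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, filename)
    with open(path, 'w') as f:
        json.dump(config, f, indent=4)
    return path
```

In this notebook it could be called as `save_model_config(best_model, {'gcn': gcn_params, 'sage': sage_params, 'gat': gat_params}[best_model])`.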

## Summary

In this notebook, we've implemented and trained three different Graph Neural Network architectures for Bitcoin transaction fraud detection:

1. Graph Convolutional Network (GCN)
2. GraphSAGE
3. Graph Attention Network (GAT)

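We've compared their performance on the validation set and selected the best model by validation loss (GAT is included only if it could be trained within the available memory). The best model has been saved for use in the evaluation notebook.

Key findings:
- The architectures aggregate neighbor information differently: GCN uses a degree-normalized sum over each node's neighborhood, GraphSAGE (with the default mean aggregator) transforms a node's own features separately from the mean of its neighbors' features, and GAT learns attention weights per edge and per head, at a noticeably higher memory cost
- The best model by validation loss is printed by the comparison cell above and recorded in `models/best_model_name.txt`
- All models were trained with early stopping (patience of 20 epochs) and a learning rate halved on validation-loss plateaus, to limit overfitting
- We've visualized each model's training history to better understand the learning process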

In the next notebook, we'll evaluate the best model on the test set and analyze its performance in detail.