# Bitcoin Transaction Fraud Detection: Model Evaluation

This notebook evaluates the Graph Neural Network (GNN) models trained in the previous notebook. We conduct a comprehensive evaluation on the test set. It reports accuracy, precision, recall, and F1-score, and visualizes ROC curves, precision-recall curves, and confusion matrices.

## Table of Contents
1. [Setup](#Setup)
2. [Loading Data and Models](#Loading-Data-and-Models)
3. [Evaluation Metrics](#Evaluation-Metrics)
4. [Model Evaluation](#Model-Evaluation)
5. [Performance Analysis](#Performance-Analysis)
   - [Confusion Matrix](#Confusion-Matrix)
   - [ROC Curve](#ROC-Curve)
   - [Precision-Recall Curve](#Precision-Recall-Curve)
   - [Class-specific Performance](#Class-specific-Performance)
6. [Error Analysis](#Error-Analysis)
7. [Model Comparison](#Model-Comparison)
8. [Generate Evaluation Report](#Generate-Evaluation-Report)

## Setup

Let's import the necessary libraries and configure the environment.

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    classification_report, roc_curve, auc, 
    precision_recall_curve, average_precision_score,
    confusion_matrix, f1_score, roc_auc_score
)
import logging
import json

# Import the GNN models
import sys
sys.path.append('.') # Add current directory to path to import from notebooks

# Logging: timestamped records at INFO level and above
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Fixed seeds for NumPy and PyTorch so repeated runs are comparable
np.random.seed(42)
torch.manual_seed(42)

# Output locations for the written reports and their figures
for _report_dir in ('reports', 'reports/figures'):
    os.makedirs(_report_dir, exist_ok=True)

# Prefer the GPU when PyTorch can see one
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logger.info(f"Using device: {device}")

## Loading Data and Models

First, let's load the processed data and the trained GNN models.

In [None]:
# Placeholder model classes. They register no layers and define no forward();
# the architectures actually used for evaluation are built inside
# load_and_evaluate_model.
class GCNModel(torch.nn.Module):
    """Layer-less GCN placeholder; constructor arguments are accepted but unused."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.5):
        super().__init__()

class SAGEModel(torch.nn.Module):
    """Layer-less GraphSAGE placeholder; constructor arguments are accepted but unused."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.5):
        super().__init__()

class GATModel(torch.nn.Module):
    """Layer-less GAT placeholder; constructor arguments are accepted but unused."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, heads=8, dropout=0.5):
        super().__init__()

def load_processed_data(input_dir='data/processed'):
    """
    Load processed data from disk.

    First tries ``<input_dir>/data.pt``. If that cannot be loaded, rebuilds a
    PyTorch Geometric ``Data`` object from ``features.npy`` (or
    ``combined_features.npy`` when present), ``labels.npy`` and
    ``edge_index.pt``. Split indices are read from ``<split>_idx.npy`` for
    ``train``, ``val`` and ``test``.

    Parameters:
    -----------
    input_dir : str
        Directory with processed data

    Returns:
    --------
    data : torch_geometric.data.Data
        PyTorch Geometric Data object
    split_idx : dict
        Dictionary containing indices for train/val/test splits. A split whose
        index file is missing is omitted and a warning is logged.

    Raises:
    -------
    FileNotFoundError
        If ``data.pt`` cannot be loaded and a required component file is missing.
    """
    logger.info(f"Loading processed data from {input_dir}")

    # Load data object
    data_path = os.path.join(input_dir, 'data.pt')

    try:
        # Try direct loading first.
        # NOTE(review): torch.load unpickles, so only use it on locally produced
        # files. On PyTorch versions whose default is weights_only=True, a pickled
        # Data object likely cannot be loaded this way and we fall through to the
        # component rebuild below. Confirm against the installed torch version.
        data = torch.load(data_path)
    except Exception as e:
        logger.warning(f"Failed to load data directly: {e}. Reconstructing from components...")
        # If that fails, try to reconstruct the data object from components
        features_path = os.path.join(input_dir, 'features.npy')
        labels_path = os.path.join(input_dir, 'labels.npy')
        edge_index_path = os.path.join(input_dir, 'edge_index.pt')

        # Prefer the combined features written by feature engineering when present
        combined_features_path = os.path.join(input_dir, 'combined_features.npy')
        if os.path.exists(combined_features_path):
            logger.info("Loading combined features from feature engineering")
            features = torch.FloatTensor(np.load(combined_features_path))
        else:
            features = torch.FloatTensor(np.load(features_path))

        labels = torch.LongTensor(np.load(labels_path))
        edge_index = torch.load(edge_index_path)

        from torch_geometric.data import Data
        data = Data(x=features, edge_index=edge_index, y=labels)

    # Load splits. Warn instead of silently skipping so a missing split is
    # visible here rather than as a KeyError in a later cell.
    split_idx = {}
    for split in ['train', 'val', 'test']:
        split_path = os.path.join(input_dir, f'{split}_idx.npy')
        if not os.path.exists(split_path):
            logger.warning(f"Split file not found: {split_path}; the '{split}' split will be unavailable")
            continue
        split_idx[split] = np.load(split_path)
        logger.info(f"Loaded {split} indices with {len(split_idx[split])} samples")

    logger.info(f"Successfully loaded processed data from {input_dir}")
    logger.info(f"Data contains {data.num_nodes} nodes, {data.num_edges} edges, and {data.num_features} features")

    return data, split_idx

# Load processed data.
# best_model_name is bound before the try so that later cells can test it
# even when loading fails. Otherwise they would raise NameError.
best_model_name = None
try:
    data, split_idx = load_processed_data()

    # Print information about the data
    print("Data information:")
    print(f"Number of nodes: {data.num_nodes}")
    print(f"Number of edges: {data.num_edges}")
    print(f"Number of features: {data.num_features}")
    print(f"Number of classes: {len(torch.unique(data.y))}")

    # Print split sizes
    print("\nSplit sizes:")
    for split_name, indices in split_idx.items():
        print(f"{split_name.capitalize()}: {len(indices)} nodes")

    # Move data to device
    data = data.to(device)

    # Load the name of the best model recorded by the training notebook.
    # An empty file is treated the same as a missing one.
    best_model_name_path = os.path.join('models', 'best_model_name.txt')
    if os.path.exists(best_model_name_path):
        with open(best_model_name_path, 'r') as f:
            best_model_name = f.read().strip() or None

    if best_model_name is not None:
        print(f"\nBest model: {best_model_name}")
    else:
        logger.warning("Best model name not found. Will evaluate individual models if available.")

except FileNotFoundError as e:
    logger.error(f"Error loading data: {e}. Please make sure you've run the previous notebooks.")

## Evaluation Metrics

Let's define functions to evaluate our models and compute various performance metrics.

In [None]:
def _one_vs_rest_mean(score_fn, y_true, probs, num_classes, present_classes):
    """Mean of ``score_fn`` over one-vs-rest problems, one per class present in ``y_true``.

    ``probs`` is a 2-D array with one probability column per model output class.
    Returns NaN when no class qualifies.
    """
    scores = [
        score_fn((y_true == cls).astype(int), probs[:, cls])
        for cls in range(num_classes)
        if cls in present_classes
    ]
    return np.mean(scores) if scores else float('nan')


def _split_score(score_fn, y_true, probs, num_classes, present_classes):
    """Score one split: binary models use class-1 probabilities directly, others use one-vs-rest."""
    if num_classes == 2:
        return score_fn(y_true, probs)
    return _one_vs_rest_mean(score_fn, y_true, probs, num_classes, present_classes)


def evaluate_model(model, data, split_idx, criterion=None, device='cpu'):
    """
    Evaluate model performance.

    Runs a single full-graph forward pass and scores every split in
    ``split_idx``. The model output is treated as log-probabilities, because
    ``torch.exp`` is applied to it to obtain probabilities.

    Parameters:
    -----------
    model : torch.nn.Module
        The trained model; called as ``model(data.x, data.edge_index)``
    data : torch_geometric.data.Data
        The graph data
    split_idx : dict
        Dictionary containing indices for train/val/test splits
    criterion : torch.nn.Module, optional
        Loss function to calculate loss
    device : str
        Device to use ('cpu' or 'cuda')

    Returns:
    --------
    metrics : dict
        Per-split dict. It holds the flattened ``classification_report`` entries
        (e.g. ``'1_precision'``, ``'weighted avg_f1-score'``, ``'accuracy'``),
        ``'confusion_matrix'`` (nested list), ``'loss'`` when ``criterion`` is
        given, and ``'roc_auc'`` / ``'pr_auc'`` when the split has more than one
        true class. Macro one-vs-rest averages are used for multi-class models.
    raw_data : dict
        Per-split dict with ``'y_true'``, ``'y_pred'`` and ``'probabilities'``
        arrays. For two output classes the probabilities are those of class 1;
        otherwise they are the full per-class matrix.
    """
    # Move data to device
    data = data.to(device)

    # Set model to evaluation mode
    model.eval()

    with torch.no_grad():
        out = model(data.x, data.edge_index)

        loss = {}
        if criterion is not None:
            for split, idx in split_idx.items():
                loss[split] = criterion(out[idx], data.y[idx]).item()

        raw_probs = torch.exp(out)
        num_classes = raw_probs.shape[1]
        # Argmax over all nodes once; each split just slices it
        all_preds = out.argmax(dim=1)

        preds, probs, y_true = {}, {}, {}
        for split, idx in split_idx.items():
            preds[split] = all_preds[idx].cpu().numpy()
            # Binary: keep P(class 1). Multi-class: keep every class column.
            if num_classes == 2:
                probs[split] = raw_probs[idx, 1].cpu().numpy()
            else:
                probs[split] = raw_probs[idx].cpu().numpy()
            y_true[split] = data.y[idx].cpu().numpy()

    # (metric key, name used in log messages, binary scoring function)
    auc_metrics = (
        ('roc_auc', 'ROC AUC', roc_auc_score),
        ('pr_auc', 'PR AUC', average_precision_score),
    )

    metrics = {split: {} for split in split_idx}
    for split, split_metrics in metrics.items():
        if split in loss:
            split_metrics['loss'] = loss[split]

        # Flatten the classification report: per-class dicts become "<class>_<metric>" keys
        try:
            report = classification_report(y_true[split], preds[split], output_dict=True)
            for key, value in report.items():
                if isinstance(value, dict):
                    for metric, metric_value in value.items():
                        split_metrics[f"{key}_{metric}"] = metric_value
                else:
                    split_metrics[key] = value
        except Exception as e:
            logger.warning(f"Error generating classification report for {split}: {str(e)}")
            split_metrics['accuracy'] = (y_true[split] == preds[split]).mean()

        split_metrics['confusion_matrix'] = confusion_matrix(y_true[split], preds[split]).tolist()

        # Ranking metrics are undefined when only one true class is present
        present_classes = np.unique(y_true[split])
        if len(present_classes) < 2:
            continue
        for key, label, score_fn in auc_metrics:
            try:
                split_metrics[key] = _split_score(
                    score_fn, y_true[split], probs[split], num_classes, present_classes
                )
            except Exception as e:
                logger.warning(f"Error calculating {label} for {split}: {str(e)}")
                split_metrics[key] = float('nan')

    # Store raw predictions for further analysis
    raw_data = {
        split: {
            'y_true': y_true[split],
            'y_pred': preds[split],
            'probabilities': probs[split]
        } for split in split_idx
    }

    return metrics, raw_data

In [None]:
def plot_roc_curve(y_true, y_score, output_path=None):
    """
    Draw the ROC curve of a binary classifier, with its AUC in the legend.

    Parameters:
    -----------
    y_true : numpy.ndarray
        Ground-truth labels; exactly two distinct values are required
    y_score : numpy.ndarray
        Scores or probabilities for the positive class
    output_path : str, optional
        If given, the figure is also written there (300 dpi, tight bounding box)

    Returns:
    --------
    fig : matplotlib.figure.Figure or None
        The new figure, or None (with a logged warning) when y_true is not binary
    """
    if len(np.unique(y_true)) != 2:
        logger.warning("ROC curve requires binary classification. Skipping.")
        return None

    false_pos, true_pos, _ = roc_curve(y_true, y_score)
    area = auc(false_pos, true_pos)

    fig, ax = plt.subplots(figsize=(8, 6))
    # Model curve, then the dashed chance-level diagonal
    ax.plot(false_pos, true_pos, color='darkorange', lw=2,
            label=f'ROC curve (area = {area:.3f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right")

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

    return fig

def plot_precision_recall_curve(y_true, y_score, output_path=None):
    """
    Draw the precision-recall curve of a binary classifier, with its average
    precision (AP) in the legend.

    Parameters:
    -----------
    y_true : numpy.ndarray
        Ground-truth labels; exactly two distinct values are required
    y_score : numpy.ndarray
        Scores or probabilities for the positive class
    output_path : str, optional
        If given, the figure is also written there (300 dpi, tight bounding box)

    Returns:
    --------
    fig : matplotlib.figure.Figure or None
        The new figure, or None (with a logged warning) when y_true is not binary
    """
    if len(np.unique(y_true)) != 2:
        logger.warning("Precision-Recall curve requires binary classification. Skipping.")
        return None

    prec, rec, _ = precision_recall_curve(y_true, y_score)
    ap = average_precision_score(y_true, y_score)

    fig, ax = plt.subplots(figsize=(8, 6))
    # Step curve with the area beneath it shaded in the same colour
    ax.step(rec, prec, where='post', color='darkorange', lw=2,
            label=f'AP = {ap:.3f}')
    ax.fill_between(rec, prec, step='post', color='darkorange', alpha=0.2)
    ax.set_title('Precision-Recall Curve')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower left")

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

    return fig

def plot_confusion_matrix(y_true, y_pred, output_path=None):
    """
    Plot confusion matrix.

    Rows and columns cover the sorted union of labels seen in ``y_true`` and
    ``y_pred``. Class 0 is shown as "Legitimate" and class 1 as "Fraudulent";
    any other class is shown by its numeric value.

    Parameters:
    -----------
    y_true : numpy.ndarray
        True labels
    y_pred : numpy.ndarray
        Predicted labels
    output_path : str, optional
        Path to save the plot

    Returns:
    --------
    fig : matplotlib.figure.Figure
        The generated figure
    """
    # One explicit label list drives both the matrix and the tick labels, so
    # they cannot drift apart.
    classes = sorted(np.unique(np.concatenate((y_true, y_pred))))
    cm = confusion_matrix(y_true, y_pred, labels=classes)

    # Only 0/1 have domain names; other classes keep their number instead of
    # all being mislabelled "Fraudulent".
    class_names = {0: 'Legitimate', 1: 'Fraudulent'}
    class_labels = [class_names.get(c, str(c)) for c in classes]

    fig, ax = plt.subplots(figsize=(8, 6))
    # Passing the labels to heatmap places one label per cell, even when
    # seaborn would otherwise thin out the ticks on larger matrices.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=class_labels, yticklabels=class_labels)
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title('Confusion Matrix')

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

    return fig

In [None]:
def _json_default(obj):
    """json.dump fallback: convert NumPy scalars/arrays to built-in Python types."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _format_metric(value):
    """Render floats with 4 decimals; pass anything else (e.g. 'N/A') through unchanged."""
    return f"{value:.4f}" if isinstance(value, float) else value


def _report_classes(split_metrics):
    """Integer class labels that have per-class metric keys such as '1_precision'."""
    return {
        int(key.split('_')[0])
        for key in split_metrics
        if '_' in key and key.split('_')[0].isdigit()
    }


def generate_evaluation_report(metrics, raw_data, model_name, output_dir='reports'):
    """
    Generate a comprehensive evaluation report.

    Files are written to ``<output_dir>/<model_name>/``:
    ``metrics.json``, ``evaluation_report.md`` and ``figures/*.png``.
    A confusion matrix is drawn for every split. ROC and PR curves are drawn for
    splits whose ground truth has exactly two classes. Each figure is closed
    after saving so repeated reports do not accumulate open figures.

    Parameters:
    -----------
    metrics : dict
        Dictionary containing evaluation metrics (as returned by evaluate_model)
    raw_data : dict
        Dictionary containing raw predictions (as returned by evaluate_model)
    model_name : str
        Name of the model
    output_dir : str
        Directory to save report files

    Returns:
    --------
    report_path : str
        Path to the generated report
    """
    report_dir = os.path.join(output_dir, model_name)
    figures_dir = os.path.join(report_dir, 'figures')
    # makedirs also creates output_dir and report_dir as intermediate directories
    os.makedirs(figures_dir, exist_ok=True)

    # Splits whose ROC/PR images were written. The markdown below links exactly
    # these, so saved curves are never left unlinked.
    curve_splits = set()
    for split, split_data in raw_data.items():
        y_true = split_data['y_true']
        y_pred = split_data['y_pred']
        probs = split_data['probabilities']

        # Confusion matrix (works for any number of classes)
        figures = [plot_confusion_matrix(
            y_true, y_pred,
            output_path=os.path.join(figures_dir, f'{split}_confusion_matrix.png'))]

        # ROC and PR curves (only for binary ground truth)
        if len(np.unique(y_true)) == 2:
            figures.append(plot_roc_curve(
                y_true, probs,
                output_path=os.path.join(figures_dir, f'{split}_roc_curve.png')))
            figures.append(plot_precision_recall_curve(
                y_true, probs,
                output_path=os.path.join(figures_dir, f'{split}_pr_curve.png')))
            curve_splits.add(split)

        for fig in figures:
            if fig is not None:
                plt.close(fig)

    # Save metrics to JSON
    metrics_path = os.path.join(report_dir, 'metrics.json')
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=4, default=_json_default)

    sections = [
        f"# Evaluation Report for {model_name}\n\n",
        "## Summary\n\n",
        "| Metric | Train | Validation | Test |\n",
        "|--------|-------|------------|------|\n",
    ]

    # Key metrics to include in summary
    key_metrics = ['accuracy', 'weighted avg_f1-score', 'weighted avg_precision', 'weighted avg_recall']
    # ROC AUC and PR AUC exist only when the test split has more than one true class
    if 'roc_auc' in metrics.get('test', {}):
        key_metrics.extend(['roc_auc', 'pr_auc'])

    for metric in key_metrics:
        train_val, val_val, test_val = (
            _format_metric(metrics.get(split, {}).get(metric, 'N/A'))
            for split in ('train', 'val', 'test')
        )
        sections.append(f"| {metric} | {train_val} | {val_val} | {test_val} |\n")

    # Class-specific metrics: every class seen in any split, values from the test split
    classes = set().union(*(_report_classes(split_metrics) for split_metrics in metrics.values()))
    if classes:
        test_metrics = metrics.get('test', {})
        sections.append("\n## Class-specific Metrics (Test Set)\n\n")
        sections.append("| Class | Precision | Recall | F1-Score | Support |\n")
        sections.append("|-------|-----------|--------|----------|--------|\n")

        for cls in sorted(classes):
            precision = _format_metric(test_metrics.get(f"{cls}_precision", 'N/A'))
            recall = _format_metric(test_metrics.get(f"{cls}_recall", 'N/A'))
            f1 = _format_metric(test_metrics.get(f"{cls}_f1-score", 'N/A'))
            support = test_metrics.get(f"{cls}_support", 'N/A')
            # Depending on the sklearn version, support may come back as a float (e.g. 120.0)
            if isinstance(support, float) and support.is_integer():
                support = int(support)
            sections.append(f"| {cls} | {precision} | {recall} | {f1} | {support} |\n")

    # Add detailed metrics section
    sections.append("\n## Detailed Metrics\n\n")
    for split in metrics:
        sections.append(f"### {split.capitalize()} Set\n\n")

        sections.append("#### Confusion Matrix\n\n")
        sections.append(f"![Confusion Matrix](figures/{split}_confusion_matrix.png)\n\n")

        if split in curve_splits:
            sections.append("#### ROC Curve\n\n")
            sections.append(f"![ROC Curve](figures/{split}_roc_curve.png)\n\n")

            sections.append("#### Precision-Recall Curve\n\n")
            sections.append(f"![PR Curve](figures/{split}_pr_curve.png)\n\n")

    # Save report
    report_path = os.path.join(report_dir, 'evaluation_report.md')
    with open(report_path, 'w') as f:
        f.write(''.join(sections))

    logger.info(f"Evaluation report generated at {report_path}")

    return report_path

## Model Evaluation

Now, let's load the best trained model and evaluate it on the train, validation, and test splits. The analysis sections that follow focus on the test set.

In [None]:
class _ResidualGNN(torch.nn.Module):
    """Residual GNN stack shared by the GCN, GraphSAGE and GAT evaluators.

    The layer pattern is assumed to mirror the training notebook; checkpoints
    only load if it does.
    - First conv, then ReLU and dropout.
    - Each hidden conv: conv, BatchNorm, ReLU, residual add, dropout.
    - Final conv, then log-softmax.

    The attribute names ``convs`` and ``bns`` determine the state_dict keys,
    so renaming them would break checkpoint loading.
    """

    def __init__(self, convs, hidden_dim, dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList(convs)
        # One BatchNorm per hidden conv (all convs except the first and last)
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(hidden_dim) for _ in range(len(convs) - 2)]
        )
        self.num_layers = len(convs)
        self.dropout = dropout

    def forward(self, x, edge_index):
        # Input layer
        h = self.convs[0](x, edge_index)
        h = torch.relu(h)
        h = torch.nn.functional.dropout(h, p=self.dropout, training=self.training)

        # Hidden layers with residual connections
        for conv, bn in zip(self.convs[1:-1], self.bns):
            h_prev = h
            h = torch.relu(bn(conv(h, edge_index)))
            h = h + h_prev
            h = torch.nn.functional.dropout(h, p=self.dropout, training=self.training)

        # Output layer
        h = self.convs[-1](h, edge_index)
        return torch.nn.functional.log_softmax(h, dim=1)


def _build_convs(model_name, input_dim, hidden_dim, output_dim, num_layers=3, heads=8):
    """Return the conv layers for 'gcn', 'sage' or 'gat' (case-insensitive), or None if unknown."""
    name = model_name.lower()
    if name in ('gcn', 'sage'):
        if name == 'gcn':
            from torch_geometric.nn import GCNConv as Conv
        else:
            from torch_geometric.nn import SAGEConv as Conv
        return ([Conv(input_dim, hidden_dim)]
                + [Conv(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]
                + [Conv(hidden_dim, output_dim)])
    if name == 'gat':
        from torch_geometric.nn import GATConv
        # Multi-head layers concatenate `heads` outputs of size hidden_dim // heads
        return ([GATConv(input_dim, hidden_dim // heads, heads=heads)]
                + [GATConv(hidden_dim, hidden_dim // heads, heads=heads) for _ in range(num_layers - 2)]
                + [GATConv(hidden_dim, output_dim, heads=1)])
    return None


def load_and_evaluate_model(model_name, data, split_idx, device):
    """
    Load and evaluate a model.

    Builds the architecture selected by ``model_name`` ('gcn', 'sage' or 'gat',
    case-insensitive). It then loads ``models/<model_name>_best.pt`` and runs
    ``evaluate_model`` with an NLL loss.

    Parameters:
    -----------
    model_name : str
        Name of the model to load
    data : torch_geometric.data.Data
        The graph data
    split_idx : dict
        Dictionary containing indices for train/val/test splits
    device : str
        Device to use ('cpu' or 'cuda')

    Returns:
    --------
    metrics : dict
        Dictionary containing evaluation metrics, or None on failure
    raw_data : dict
        Dictionary containing raw predictions, or None on failure
    """
    # Path to model weights
    model_path = os.path.join('models', f'{model_name}_best.pt')

    if not os.path.exists(model_path):
        logger.error(f"Model weights not found at {model_path}")
        return None, None

    input_dim = data.x.shape[1]
    hidden_dim = 256  # Same as in training
    # NOTE(review): this must equal the output size used at training time for
    # the checkpoint to load; it assumes the labels are exactly 0..K-1.
    output_dim = len(torch.unique(data.y))

    convs = _build_convs(model_name, input_dim, hidden_dim, output_dim)
    if convs is None:
        logger.error(f"Unknown model type: {model_name}")
        return None, None
    model = _ResidualGNN(convs, hidden_dim).to(device)

    try:
        logger.info(f"Loading model weights from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        # Models emit log-softmax, so NLL loss gives the cross-entropy
        criterion = torch.nn.NLLLoss()

        logger.info(f"Evaluating {model_name} model")
        return evaluate_model(
            model=model,
            data=data,
            split_idx=split_idx,
            criterion=criterion,
            device=device
        )
    except Exception as e:
        # logger.exception keeps the traceback (e.g. state_dict shape mismatches)
        logger.exception(f"Error evaluating model: {e}")
        return None, None

# Evaluate the best model. If it is unknown or fails to evaluate, fall back to
# the other checkpoints in models/ (files named '<model_name>_best.pt').
# globals().get(...) keeps this cell from raising NameError when the
# data-loading cell above failed.
metrics, raw_data = None, None
best_model_name = globals().get('best_model_name')

if globals().get('data') is None:
    logger.error("Processed data is not available. Please run the data-loading cell first.")
else:
    if best_model_name is not None:
        logger.info(f"Evaluating best model: {best_model_name}")
        metrics, raw_data = load_and_evaluate_model(best_model_name, data, split_idx, device)
        if metrics is None:
            logger.warning("Failed to evaluate best model. Will try individual models.")
    else:
        logger.warning("Best model name not found. Will evaluate individual models.")

    if metrics is None:
        checkpoint_suffix = '_best.pt'
        checkpoint_files = os.listdir('models') if os.path.isdir('models') else []
        candidate_names = sorted(
            fname[:-len(checkpoint_suffix)]
            for fname in checkpoint_files
            if fname.endswith(checkpoint_suffix)
        )
        for candidate in candidate_names:
            if candidate == best_model_name:
                continue  # already tried above
            logger.info(f"Evaluating fallback model: {candidate}")
            metrics, raw_data = load_and_evaluate_model(candidate, data, split_idx, device)
            if metrics is not None:
                # Later cells label their output with best_model_name
                best_model_name = candidate
                break

if metrics is not None:
    # Generate evaluation report
    report_path = generate_evaluation_report(metrics, raw_data, best_model_name)
    print(f"Evaluation report generated at {report_path}")
else:
    logger.error("No model could be evaluated; the analysis cells below require `metrics` and `raw_data`.")

## Performance Analysis

Let's analyze the performance of the best model on the test set.

### Confusion Matrix

First, let's visualize the confusion matrix to understand the model's performance.

In [None]:
# Display confusion matrix for test set
y_true = raw_data['test']['y_true']
y_pred = raw_data['test']['y_pred']

# One sorted label set drives both the matrix and its tick labels, so they stay
# aligned even when a class is missing from y_true or y_pred.
cm_classes = sorted(np.unique(np.concatenate((y_true, y_pred))))
cm_names = {0: 'Legitimate', 1: 'Fraudulent'}
cm_labels = [cm_names.get(c, str(c)) for c in cm_classes]
cm = confusion_matrix(y_true, y_pred, labels=cm_classes)

# Row-normalize by true-class counts. A class with no true samples (it only
# appears in predictions) gets a row of zeros instead of NaN from 0/0.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm.astype('float'), row_totals,
                    out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

# Plot confusion matrix (one figure with two panels)
fig, axes = plt.subplots(1, 2, figsize=(18, 8))

# Absolute counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0],
            xticklabels=cm_labels, yticklabels=cm_labels)
axes[0].set_xlabel('Predicted labels')
axes[0].set_ylabel('True labels')
axes[0].set_title('Confusion Matrix - Absolute Counts')

# Normalized
sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues', ax=axes[1],
            xticklabels=cm_labels, yticklabels=cm_labels)
axes[1].set_xlabel('Predicted labels')
axes[1].set_ylabel('True labels')
axes[1].set_title('Confusion Matrix - Normalized by True Labels')

plt.tight_layout()
plt.show()

### ROC Curve

Let's plot the ROC curve to evaluate the model's ability to distinguish between legitimate and fraudulent transactions.

In [None]:
# Plot ROC curve for test set (binary ground truth only).
# fpr, tpr, thresholds and roc_auc stay as notebook globals for later cells.
y_true = raw_data['test']['y_true']
y_score = raw_data['test']['probabilities']

if len(np.unique(y_true)) != 2:
    print("ROC curve is only applicable to binary classification.")
else:
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(10, 8))
    # Model curve, then the dashed chance-level diagonal
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.grid(True, alpha=0.3)
    plt.legend(loc="lower right")
    plt.show()

### Precision-Recall Curve

In fraud detection the classes are imbalanced, so precision and recall are often more informative than accuracy. Let's plot the precision-recall curve.

In [None]:
# Plot precision-recall curve for test set
y_true = raw_data['test']['y_true']
y_score = raw_data['test']['probabilities']

# The PR curve and the imbalance baseline below assume exactly two classes.
if len(np.unique(y_true)) != 2:
    print("Precision-Recall curve is only applicable to binary classification.")
else:
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    average_precision = average_precision_score(y_true, y_score)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.step(recall, precision, color='darkorange', lw=2, where='post',
            label=f'Precision-Recall curve (AP = {average_precision:.3f})')
    ax.fill_between(recall, precision, step='post', alpha=0.2, color='darkorange')
    ax.set(
        xlabel='Recall',
        ylabel='Precision',
        ylim=[0.0, 1.05],
        xlim=[0.0, 1.0],
        title='Precision-Recall Curve',
    )
    ax.legend(loc="lower left")
    ax.grid(True, alpha=0.3)
    plt.show()

    # A random classifier's AP equals the positive-class prevalence, so the
    # fraud ratio doubles as the baseline to compare the model's AP against.
    class_counts = np.bincount(y_true)
    if len(class_counts) > 1:
        fraud_ratio = class_counts[1] / len(y_true)
    else:
        fraud_ratio = 0
    print(f"Class imbalance: {fraud_ratio:.2%} fraudulent transactions")
    print(f"Random classifier baseline AP: {fraud_ratio:.4f}")
    print(f"Model AP: {average_precision:.4f} (improvement: {average_precision/fraud_ratio:.2f}x)")

### Class-specific Performance

Let's analyze the model's performance for each class separately.

In [None]:
# Print class-specific metrics
test_metrics = metrics['test']

print(f"Classification Report for {best_model_name} on Test Set:\n")

# Per-class entries are stored under "<label>_<metric>" keys (e.g. "1_precision");
# collect the integer labels that appear.
classes = {
    int(key.split('_')[0])
    for key in test_metrics
    if '_' in key and key.split('_')[0].isdigit()
}

if classes:
    # Build every row first and create the DataFrame once. Growing a DataFrame with
    # pd.concat inside a loop copies it on each iteration, and newer pandas versions
    # warn about concatenating onto an empty DataFrame.
    class_rows = [
        {
            'Class': 'Legitimate' if cls == 0 else 'Fraudulent',
            'Precision': test_metrics.get(f"{cls}_precision", 'N/A'),
            'Recall': test_metrics.get(f"{cls}_recall", 'N/A'),
            'F1-Score': test_metrics.get(f"{cls}_f1-score", 'N/A'),
            'Support': test_metrics.get(f"{cls}_support", 'N/A'),
        }
        for cls in sorted(classes)
    ]
    class_metrics = pd.DataFrame(class_rows)

    print(class_metrics)

    # Add overall metrics. As in sklearn's classification_report, accuracy is shown
    # in the F1-Score column.
    overall_metrics = pd.DataFrame([
        {
            'Class': 'Overall (accuracy)',
            'Precision': 'N/A',
            'Recall': 'N/A',
            'F1-Score': test_metrics.get('accuracy', 'N/A'),
            'Support': len(y_true)
        },
        {
            'Class': 'Overall (weighted avg)',
            'Precision': test_metrics.get('weighted avg_precision', 'N/A'),
            'Recall': test_metrics.get('weighted avg_recall', 'N/A'),
            'F1-Score': test_metrics.get('weighted avg_f1-score', 'N/A'),
            'Support': len(y_true)
        }
    ])

    print("\nOverall metrics:")
    print(overall_metrics)

    # Print ROC AUC and PR AUC if available
    print("\nAdditional metrics:")
    if 'roc_auc' in test_metrics:
        print(f"ROC AUC: {test_metrics['roc_auc']:.4f}")
    if 'pr_auc' in test_metrics:
        print(f"PR AUC: {test_metrics['pr_auc']:.4f}")
else:
    print("No class-specific metrics found.")

## Error Analysis

Let's analyze the model's errors to understand where it's failing.

In [None]:
# Analyze errors
y_true = raw_data['test']['y_true']
y_pred = raw_data['test']['y_pred']
probs = raw_data['test']['probabilities']

if len(np.unique(y_true)) == 2:
    # Binary case: `probs` is treated as a 1-D score for class 1 (Fraudulent), as the
    # threshold cell below also assumes. The confidence in the *predicted* class is
    # therefore p when the prediction is 1 and 1 - p when it is 0.
    score = np.asarray(probs)
    confidence = np.where(np.asarray(y_pred) == 1, score, 1.0 - score)
else:
    # Multi-class: the highest class probability is the predicted-class confidence.
    score = np.max(probs, axis=1)
    confidence = score

# Create a DataFrame for error analysis.
# 'Probability' is the raw score (P(fraud) in the binary case); 'Confidence' is the
# probability the model assigned to the class it actually predicted.
error_df = pd.DataFrame({
    'True_Label': y_true,
    'Predicted_Label': y_pred,
    'Probability': score,
    'Confidence': confidence,
    'Is_Error': y_true != y_pred
})

# Calculate error rates
overall_error_rate = error_df['Is_Error'].mean()

# Error rates by class (fraction of each true class that was misclassified)
class_error_rates = error_df.groupby('True_Label')['Is_Error'].mean().to_dict()

print(f"Overall error rate: {overall_error_rate:.4f} ({overall_error_rate*100:.2f}%)")
print("Error rates by class:")
for cls, rate in class_error_rates.items():
    cls_name = 'Legitimate' if cls == 0 else 'Fraudulent'
    print(f"Class {cls} ({cls_name}): {rate:.4f} ({rate*100:.2f}%)")

# Analyze false positives and false negatives (fractions are relative to all test samples)
false_positives = error_df[(error_df['True_Label'] == 0) & (error_df['Predicted_Label'] == 1)]
false_negatives = error_df[(error_df['True_Label'] == 1) & (error_df['Predicted_Label'] == 0)]

print(f"\nFalse positives: {len(false_positives)} ({len(false_positives)/len(error_df):.4f})")
print(f"False negatives: {len(false_negatives)} ({len(false_negatives)/len(error_df):.4f})")

# Distribution of probabilities for errors
plt.figure(figsize=(14, 6))

# Left: confidence in the predicted class, split by correct vs. incorrect predictions
plt.subplot(1, 2, 1)
sns.histplot(error_df.loc[~error_df['Is_Error'], 'Confidence'], bins=20, alpha=0.5, label='Correct', color='green')
sns.histplot(error_df.loc[error_df['Is_Error'], 'Confidence'], bins=20, alpha=0.5, label='Incorrect', color='red')
plt.title('Distribution of Probabilities for Correct vs. Incorrect Predictions')
plt.xlabel('Probability of Predicted Class')
plt.ylabel('Count')
plt.legend()

# Right: raw fraud score, split by true class
plt.subplot(1, 2, 2)
sns.histplot(error_df.loc[error_df['True_Label'] == 0, 'Probability'], bins=20, alpha=0.5, label='True Legitimate', color='blue')
sns.histplot(error_df.loc[error_df['True_Label'] == 1, 'Probability'], bins=20, alpha=0.5, label='True Fraudulent', color='red')
plt.title('Distribution of Probabilities by True Class')
plt.xlabel('Probability (of being Fraudulent)')
plt.ylabel('Count')
plt.legend()

plt.tight_layout()
plt.show()

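Let's find the decision threshold that maximizes the F1 score. Cost-based threshold selection, which weights false positives and false negatives differently, is not implemented below. Note that the threshold is chosen on the test set itself, so the metrics reported at that threshold are optimistic. For an unbiased estimate, tune the threshold on a held-out validation split.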

In [None]:
# Find optimal threshold based on F1 score
if len(np.unique(y_true)) == 2:
    # Calculate precision, recall, and F1 score for different thresholds
    precisions, recalls, thresholds = precision_recall_curve(y_true, probs)

    # precision_recall_curve returns one more precision/recall point than thresholds
    # (the final P=1, R=0 point has no threshold), so drop it to align with `thresholds`.
    # The 1e-10 term avoids 0/0 where precision and recall are both zero.
    f1_scores = 2 * (precisions[:-1] * recalls[:-1]) / (precisions[:-1] + recalls[:-1] + 1e-10)

    # Find threshold with the highest F1 score
    best_idx = int(np.argmax(f1_scores))
    best_threshold = thresholds[best_idx]
    best_f1 = f1_scores[best_idx]

    # NOTE: the threshold is tuned on the test set and then evaluated on the same
    # data, so the metrics below are optimistic; prefer tuning on a validation split.
    print(f"Optimal threshold for F1 score: {best_threshold:.4f} (F1 = {best_f1:.4f})")

    # Calculate metrics at the optimal threshold
    y_pred_opt = (np.asarray(probs) >= best_threshold).astype(int)
    # labels=[0, 1] guarantees a 2x2 matrix, so the four-way unpacking always works.
    cm_opt = confusion_matrix(y_true, y_pred_opt, labels=[0, 1])
    tn, fp, fn, tp = cm_opt.ravel()

    # Calculate metrics, reporting 0.0 rather than NaN (plus a warning) on a zero denominator
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    print("Metrics at optimal threshold:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    # Plot precision-recall curve and optimal threshold
    plt.figure(figsize=(10, 8))
    plt.plot(recalls, precisions, color='darkorange', lw=2)
    plt.scatter(recalls[best_idx], precisions[best_idx], color='red', s=100, label=f'Optimal Threshold: {best_threshold:.4f}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve with Optimal Threshold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    # Plot F1 score vs. threshold
    plt.figure(figsize=(10, 8))
    plt.plot(thresholds, f1_scores, color='darkorange', lw=2)
    plt.scatter(best_threshold, best_f1, color='red', s=100, label=f'Optimal Threshold: {best_threshold:.4f}')
    plt.xlabel('Threshold')
    plt.ylabel('F1 Score')
    plt.title('F1 Score vs. Threshold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
else:
    print("Threshold optimization is only applicable to binary classification.")

## Model Comparison

Let's compare the performance of all trained models on the test set.

In [None]:
# Define models to evaluate
model_names = ['gcn', 'sage', 'gat']

# Store metrics for each model
all_metrics = {}
available_models = []

# Evaluate each model. Per-model results go into loop-local names
# (`model_eval_metrics`, `model_raw_data`) so this loop does not overwrite the
# best model's `metrics` / `raw_data` that the analysis cells above rely on.
for model_name in model_names:
    model_path = os.path.join('models', f'{model_name}_best.pt')
    if not os.path.exists(model_path):
        logger.warning(f"{model_name} model weights not found at {model_path}")
        continue

    logger.info(f"Evaluating {model_name} model")
    try:
        # Load and evaluate model
        model_eval_metrics, model_raw_data = load_and_evaluate_model(model_name, data, split_idx, device)
        if model_eval_metrics is None:
            logger.warning(f"Failed to evaluate {model_name} model")
            continue

        # Store metrics
        all_metrics[model_name] = model_eval_metrics
        available_models.append(model_name)

        # Generate evaluation report
        generate_evaluation_report(model_eval_metrics, model_raw_data, model_name)
    except Exception as e:
        # Best-effort: keep evaluating the remaining models, but log the full traceback.
        logger.exception(f"Error evaluating {model_name} model: {e}")

# Compare models if multiple models are available
if len(available_models) > 1:
    logger.info("Comparing model performance")

    # Create a DataFrame for comparison (missing metrics become NaN)
    comparison_data = []
    for model_name in available_models:
        model_metrics = all_metrics[model_name]['test']
        comparison_data.append({
            'Model': model_name,
            'Accuracy': model_metrics.get('accuracy', float('nan')),
            'Precision': model_metrics.get('weighted avg_precision', float('nan')),
            'Recall': model_metrics.get('weighted avg_recall', float('nan')),
            'F1-Score': model_metrics.get('weighted avg_f1-score', float('nan')),
            'ROC AUC': model_metrics.get('roc_auc', float('nan')),
            'PR AUC': model_metrics.get('pr_auc', float('nan')),
        })

    comparison_df = pd.DataFrame(comparison_data)

    # Display comparison
    print("Model comparison (test set performance):")
    print(comparison_df)

    # Visualize comparison
    metrics_to_plot = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
    if not comparison_df['ROC AUC'].isna().all():
        metrics_to_plot.extend(['ROC AUC', 'PR AUC'])

    # Reshape to long format: one row per (Model, Metric) pair
    comparison_melted = pd.melt(comparison_df, id_vars=['Model'], value_vars=metrics_to_plot,
                                var_name='Metric', value_name='Value')

    # One bar chart per metric
    fig, axs = plt.subplots(1, len(metrics_to_plot), figsize=(20, 6))
    for i, metric in enumerate(metrics_to_plot):
        metric_data = comparison_melted[comparison_melted['Metric'] == metric]
        sns.barplot(x='Model', y='Value', data=metric_data, ax=axs[i], palette='viridis')
        axs[i].set_title(metric)
        axs[i].set_ylim([0, 1])
        axs[i].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Generate comparison report
    comparison_report_dir = os.path.join('reports', 'comparison')
    os.makedirs(comparison_report_dir, exist_ok=True)

    # Save comparison to CSV
    comparison_df.to_csv(os.path.join(comparison_report_dir, 'model_comparison.csv'), index=False)

    # Create markdown report
    report_md = "# Model Comparison Report\n\n"
    report_md += "## Test Set Performance\n\n"

    # Table header and separator row
    report_md += "| Model | " + " | ".join(metrics_to_plot) + " |\n"
    report_md += "|-------|" + "|".join(["---" for _ in metrics_to_plot]) + "|\n"

    for _, row in comparison_df.iterrows():
        table_row = f"| {row['Model']} |"
        for metric in metrics_to_plot:
            value = row[metric]
            # Accept any real number, including NumPy scalars such as float32 that are
            # not `float` subclasses. NaN or non-numeric values are shown as 'N/A'.
            is_number = isinstance(value, (int, float, np.integer, np.floating))
            if is_number and not np.isnan(value):
                table_row += f" {value:.4f} |"
            else:
                table_row += " N/A |"
        report_md += table_row + "\n"

    # Save report
    report_path = os.path.join(comparison_report_dir, 'model_comparison.md')
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report_md)

    print(f"\nModel comparison report saved to {report_path}")
else:
    print("Not enough models available for comparison.")

## Generate Evaluation Report

We've already generated comprehensive evaluation reports for each model. Let's summarize the key findings.

In [None]:
# Summarize key findings


def _fmt_metric(value):
    """Return *value* formatted to 4 decimals, or 'N/A' if it is missing or non-numeric.

    Applying ``:.4f`` directly to a ``.get(key, 'N/A')`` fallback raises ValueError
    when the metric is absent, so formatting goes through this helper instead.
    """
    if isinstance(value, (int, float, np.integer, np.floating)):
        return f"{value:.4f}"
    return 'N/A'


if best_model_name is not None and best_model_name in all_metrics:
    test_metrics = all_metrics[best_model_name]['test']

    print(f"Summary of {best_model_name.upper()} model performance on test set:")
    print(f"Accuracy: {_fmt_metric(test_metrics.get('accuracy'))}")
    print(f"Weighted Precision: {_fmt_metric(test_metrics.get('weighted avg_precision'))}")
    print(f"Weighted Recall: {_fmt_metric(test_metrics.get('weighted avg_recall'))}")
    print(f"Weighted F1-Score: {_fmt_metric(test_metrics.get('weighted avg_f1-score'))}")

    if 'roc_auc' in test_metrics:
        print(f"ROC AUC: {_fmt_metric(test_metrics['roc_auc'])}")
    if 'pr_auc' in test_metrics:
        print(f"PR AUC: {_fmt_metric(test_metrics['pr_auc'])}")

    # Derive the class labels from this model's own "<label>_<metric>" keys rather
    # than from the `classes` set built in an earlier cell for a different dict.
    summary_classes = sorted({
        int(key.split('_')[0])
        for key in test_metrics
        if '_' in key and key.split('_')[0].isdigit()
    })

    # Class-specific metrics
    print("\nClass-specific metrics:")
    for cls in summary_classes:
        cls_name = 'Legitimate' if cls == 0 else 'Fraudulent'
        print(f"Class {cls} ({cls_name}):")
        print(f"  Precision: {_fmt_metric(test_metrics.get(f'{cls}_precision'))}")
        print(f"  Recall: {_fmt_metric(test_metrics.get(f'{cls}_recall'))}")
        print(f"  F1-Score: {_fmt_metric(test_metrics.get(f'{cls}_f1-score'))}")
        print(f"  Support: {test_metrics.get(f'{cls}_support', 'N/A')}")
else:
    print("No best model metrics available.")