# Heterogeneous Graph Neural Networks for Fraud Detection

## Tutorial 6: Graph-Based Fraud Detection with Multi-Type Relationships

In this tutorial, you'll learn how to apply Graph Neural Networks (GNNs) to fraud detection:
- **Heterogeneous Graphs**: Multiple node types (users, merchants, transactions)
- **Graph Attention Networks**: Focus on important relationships
- **Message Passing**: Information flow between connected entities
- **Advanced Graph Construction**: Building a financial network from transaction data using synthetic user and merchant entities

## Learning Objectives

By the end of this tutorial, you'll understand:

1. **Graph-Based Fraud Detection**: Why graphs are powerful for fraud detection
2. **Heterogeneous Graphs**: Multiple node types and relationship types
3. **Graph Attention Networks**: Attention mechanisms in graph neural networks
4. **Message Passing**: How information flows through graph structures
5. **Graph Construction**: Building a financial network from transaction data, with synthetic users and merchants derived by clustering
6. **Advanced Embeddings**: Learning representations that capture graph structure
7. **Scalability**: Handling large-scale financial networks

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, classification_report
)
from sklearn.metrics.pairwise import cosine_similarity
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, HeteroConv, global_mean_pool
from torch_geometric.data import HeteroData
import warnings
warnings.filterwarnings('ignore')

# Set style for better visualizations
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Part 1: Understanding Graph-Based Fraud Detection

### Why Graphs for Fraud Detection?

Traditional ML approaches treat each transaction independently. But fraud detection is inherently about relationships:
- **Fraudsters often target similar merchants**
- **Compromised cards show unusual behavioral patterns**
- **Fraudulent transactions cluster in time and space**
- **Legitimate users have consistent behavioral patterns**

### Heterogeneous vs Homogeneous Graphs

**Homogeneous Graph**: One type of node, one type of edge
- Simple but limited representation
- Example: User-User connections only

**Heterogeneous Graph**: Multiple node types, multiple edge types
- Rich representation of real-world complexity
- Example: Users, Merchants, Transactions with various relationships
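To make the distinction concrete, here is a minimal sketch of a heterogeneous graph stored as plain NumPy arrays. It mirrors the convention `HeteroData` uses in Part 2: node sets keyed by type (each with its own feature width) and edge sets keyed by `(source_type, relation, target_type)` triples, each a `2 × num_edges` index array. The toy sizes and the `validate_hetero_graph` helper are hypothetical, for illustration only.

```python
import numpy as np

# Node sets keyed by type; each type may have a different number of features
rng = np.random.default_rng(0)
nodes = {
    'user': rng.normal(size=(2, 3)),         # 2 users, 3 features each
    'merchant': rng.normal(size=(2, 2)),     # 2 merchants, 2 features each
    'transaction': rng.normal(size=(3, 4)),  # 3 transactions, 4 features each
}

# Edge sets keyed by (source_type, relation, target_type);
# row 0 holds source node indices, row 1 holds target node indices
edges = {
    ('user', 'makes', 'transaction'): np.array([[0, 0, 1],
                                                [0, 1, 2]]),
    ('merchant', 'processes', 'transaction'): np.array([[1, 0, 1],
                                                        [0, 1, 2]]),
}

def validate_hetero_graph(nodes, edges):
    """Hypothetical helper: check that every edge points at an existing node."""
    for (src, _, dst), edge_index in edges.items():
        assert edge_index.shape[0] == 2
        assert edge_index[0].max() < len(nodes[src])
        assert edge_index[1].max() < len(nodes[dst])
    return True

print(validate_hetero_graph(nodes, edges))  # True
```

A homogeneous graph is the special case with a single node type and a single relation, which is exactly what makes it unable to tell a user from a merchant.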

In [None]:
# Load and examine the data
df = pd.read_csv('creditcard.csv')
print(f"Dataset shape: {df.shape}")
print(f"Fraud rate: {df['Class'].mean()*100:.3f}%")

# Visualize the challenge
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Transaction timeline
df_sample = df.sample(1000, random_state=42)
normal_transactions = df_sample[df_sample['Class'] == 0]
fraud_transactions = df_sample[df_sample['Class'] == 1]

axes[0].scatter(normal_transactions['Time'], normal_transactions['Amount'], 
               alpha=0.6, s=20, label='Normal', color='blue')
axes[0].scatter(fraud_transactions['Time'], fraud_transactions['Amount'], 
               alpha=0.8, s=50, label='Fraud', color='red')
axes[0].set_xlabel('Time')
axes[0].set_ylabel('Amount')
axes[0].set_title('Transactions in Time-Amount Space')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Feature space visualization (PCA)
features = ['V1', 'V2', 'V3', 'V4', 'V5']
X_sample = df_sample[features]
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_sample)

normal_mask = df_sample['Class'] == 0
fraud_mask = df_sample['Class'] == 1

axes[1].scatter(X_pca[normal_mask, 0], X_pca[normal_mask, 1], 
               alpha=0.6, s=20, label='Normal', color='blue')
axes[1].scatter(X_pca[fraud_mask, 0], X_pca[fraud_mask, 1], 
               alpha=0.8, s=50, label='Fraud', color='red')
axes[1].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2f})')
axes[1].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2f})')
axes[1].set_title('Feature Space (PCA)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Graph concept illustration
axes[2].text(0.5, 0.8, 'Traditional ML', ha='center', va='center', 
            fontsize=14, fontweight='bold', transform=axes[2].transAxes)
axes[2].text(0.5, 0.7, 'Each transaction\nprocessed independently', ha='center', va='center', 
            fontsize=12, transform=axes[2].transAxes)
axes[2].text(0.5, 0.5, 'vs', ha='center', va='center', 
            fontsize=16, fontweight='bold', transform=axes[2].transAxes)
axes[2].text(0.5, 0.3, 'Graph-Based ML', ha='center', va='center', 
            fontsize=14, fontweight='bold', transform=axes[2].transAxes)
axes[2].text(0.5, 0.2, 'Transactions connected\nthrough relationships', ha='center', va='center', 
            fontsize=12, transform=axes[2].transAxes)
axes[2].set_xlim(0, 1)
axes[2].set_ylim(0, 1)
axes[2].axis('off')

plt.tight_layout()
plt.show()

print("\nKey Insight: Fraud detection benefits from understanding relationships between entities!")

## Part 2: Building a Heterogeneous Graph

### Graph Design for Fraud Detection

Our heterogeneous graph will have:
- **3 Node Types**: Users, Merchants, Transactions
- **7 Edge Types**: 
  - User → Transaction (makes)
  - Transaction → User (made_by)
  - Merchant → Transaction (processes)
  - Transaction → Merchant (processed_by)
  - User ↔ User (similar_to)
  - Merchant ↔ Merchant (similar_to)
  - Transaction ↔ Transaction (temporal)
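Each user/merchant to transaction relation above has a reverse twin (`makes`/`made_by`, `processes`/`processed_by`) so that messages can flow in both directions. Given one entity ID per transaction, both directions follow from a single stacked array. A NumPy sketch with made-up IDs:

```python
import numpy as np

# Toy data: one synthetic user ID per transaction; transaction i is row i
user_id = np.array([2, 0, 2, 1, 0])
transaction_idx = np.arange(len(user_id))

# ('user', 'makes', 'transaction'): row 0 = source user, row 1 = target transaction
makes = np.stack([user_id, transaction_idx])

# ('transaction', 'made_by', 'user'): the same edges with the two rows swapped.
# Fancy indexing copies the data, so the result has positive strides
# (torch.from_numpy rejects negative-stride views such as makes[::-1]).
made_by = makes[[1, 0]]

print(makes)
# [[2 0 2 1 0]
#  [0 1 2 3 4]]
print(made_by)
# [[0 1 2 3 4]
#  [2 0 2 1 0]]
```

Every transaction contributes exactly one edge per direction, so each of these four relations has as many edges as there are transactions.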

In [None]:
class HeterogeneousGraphBuilder:
    """
    Build heterogeneous graphs from transaction data.
    
    Key challenge: Transaction data doesn't have user/merchant IDs.
    Solution: Use clustering to create synthetic entities based on behavioral patterns.
    """
    
    def __init__(self, random_state=42):
        self.random_state = random_state
        self.user_clusters = None
        self.merchant_clusters = None
        self.scaler = StandardScaler()
    
    def build_heterogeneous_graph(self, df, max_users=1000, max_merchants=500):
        """
        Build complete heterogeneous graph from transaction data.
        
        Args:
            df: Transaction dataframe
            max_users: Maximum number of synthetic users
            max_merchants: Maximum number of synthetic merchants
        
        Returns:
            HeteroData object with all nodes and edges
        """
        print("Building heterogeneous graph...")
        
        # Create synthetic entities
        df_with_entities = self._create_entities(df, max_users, max_merchants)
        
        # Initialize graph data structure
        data = HeteroData()
        
        # Add node features
        data = self._add_node_features(data, df_with_entities)
        
        # Add edges
        data = self._add_edges(data, df_with_entities)
        
        print(f"Graph constructed with:")
        print(f"  - {data['transaction'].num_nodes:,} transaction nodes")
        print(f"  - {data['user'].num_nodes:,} user nodes")
        print(f"  - {data['merchant'].num_nodes:,} merchant nodes")
        
        return data, df_with_entities
    
    def _create_entities(self, df, max_users, max_merchants):
        """
        Create synthetic user and merchant entities using clustering.
        
        Strategy:
        - Users: Cluster by behavioral patterns (V1-V5)
        - Merchants: Cluster by business patterns (time, amount, V6-V8)
        """
        print("Creating synthetic entities...")
        
        df_entities = df.copy()
        
        # Create user clusters based on behavioral patterns
        user_features = ['V1', 'V2', 'V3', 'V4', 'V5']
        X_user = self.scaler.fit_transform(df[user_features])
        
        n_user_clusters = min(max_users, len(df) // 10)  # ~10 transactions per user on average
        self.user_clusters = KMeans(n_clusters=n_user_clusters, random_state=self.random_state)
        df_entities['user_id'] = self.user_clusters.fit_predict(X_user)
        
        # Create merchant clusters based on business patterns.
        # 'Time' counts seconds since the first transaction in the dataset, so this
        # "hour" is an hour offset within each 24-hour period, not the clock hour.
        df_entities['hour'] = (df['Time'] % (24 * 3600)) // 3600
        
        merchant_features = ['hour', 'Amount', 'V6', 'V7', 'V8']
        X_merchant = self.scaler.fit_transform(df_entities[merchant_features])
        
        n_merchant_clusters = min(max_merchants, len(df) // 20)  # ~20 transactions per merchant on average
        self.merchant_clusters = KMeans(n_clusters=n_merchant_clusters, random_state=self.random_state)
        df_entities['merchant_id'] = self.merchant_clusters.fit_predict(X_merchant)
        
        print(f"  - Created {n_user_clusters} synthetic users")
        print(f"  - Created {n_merchant_clusters} synthetic merchants")
        
        return df_entities
    
    def _add_node_features(self, data, df):
        """
        Add features for each node type.
        
        Node features:
        - Transaction: Original features (V1-V28, Amount, Time)
        - User: Aggregated behavioral statistics
        - Merchant: Aggregated business statistics
        """
        print("Adding node features...")
        
        # Transaction features (original features). Standardize them so the raw
        # 'Time' (seconds) and 'Amount' columns do not dwarf the V1-V28 components.
        transaction_features = [f'V{i}' for i in range(1, 29)] + ['Amount', 'Time']
        data['transaction'].x = torch.FloatTensor(
            StandardScaler().fit_transform(df[transaction_features].values)
        )
        data['transaction'].y = torch.LongTensor(df['Class'].values)
        
        # User features (aggregated statistics).
        # The per-user fraud rate (mean of 'Class') is deliberately left out: it is
        # computed from every label, test transactions included, so it would leak
        # the target into the node features.
        user_stats = []
        for user_id in range(df['user_id'].nunique()):
            user_transactions = df[df['user_id'] == user_id]
            
            stats = [
                len(user_transactions),  # transaction count
                user_transactions['Amount'].mean(),  # avg amount
                user_transactions['Amount'].std(),  # amount std
                user_transactions['Time'].max() - user_transactions['Time'].min(),  # time range
                user_transactions[[f'V{i}' for i in range(1, 6)]].mean().mean(),  # avg V features
                user_transactions['Amount'].quantile(0.95),  # 95th percentile amount
                user_transactions['Time'].std()  # time std
            ]
            
            # Handle NaN values (e.g. the std of a single transaction)
            stats = [0 if pd.isna(x) else x for x in stats]
            user_stats.append(stats)
        
        # Standardize: counts, amounts and time spans live on very different scales
        data['user'].x = torch.FloatTensor(StandardScaler().fit_transform(np.array(user_stats)))
        
        # Merchant features (aggregated statistics); fraud rate omitted for the same reason
        merchant_stats = []
        for merchant_id in range(df['merchant_id'].nunique()):
            merchant_transactions = df[df['merchant_id'] == merchant_id]
            
            stats = [
                len(merchant_transactions),  # transaction count
                merchant_transactions['Amount'].mean(),  # avg amount
                merchant_transactions['Amount'].std(),  # amount std
                merchant_transactions['user_id'].nunique(),  # unique customers
                merchant_transactions['Time'].max() - merchant_transactions['Time'].min(),  # time span
                merchant_transactions[[f'V{i}' for i in range(6, 11)]].mean().mean(),  # avg V features
                merchant_transactions['hour'].mode().iloc[0] if len(merchant_transactions) > 0 else 12  # peak hour
            ]
            
            # Handle NaN values
            stats = [0 if pd.isna(x) else x for x in stats]
            merchant_stats.append(stats)
        
        data['merchant'].x = torch.FloatTensor(StandardScaler().fit_transform(np.array(merchant_stats)))
        
        print(f"  - Transaction features: {data['transaction'].x.shape}")
        print(f"  - User features: {data['user'].x.shape}")
        print(f"  - Merchant features: {data['merchant'].x.shape}")
        
        return data
    
    def _add_edges(self, data, df):
        """
        Add all edge types to the graph.
        
        Edge types:
        1. User-Transaction relationships
        2. Merchant-Transaction relationships
        3. User-User similarity
        4. Merchant-Merchant similarity
        5. Transaction-Transaction temporal
        """
        print("Adding edges...")
        
        # Transactions are numbered by row position 0..N-1, matching data['transaction'].x.
        # Building the index arrays directly avoids a slow iterrows() loop and keeps the
        # IDs integral (iterrows() upcasts each mixed-dtype row, user_id included, to float).
        transaction_idx = np.arange(len(df))
        user_ids = df['user_id'].to_numpy()
        merchant_ids = df['merchant_id'].to_numpy()
        
        # 1. User-Transaction edges (forward and reverse)
        data['user', 'makes', 'transaction'].edge_index = torch.tensor(
            np.stack([user_ids, transaction_idx]), dtype=torch.long)
        data['transaction', 'made_by', 'user'].edge_index = torch.tensor(
            np.stack([transaction_idx, user_ids]), dtype=torch.long)
        
        # 2. Merchant-Transaction edges (forward and reverse)
        data['merchant', 'processes', 'transaction'].edge_index = torch.tensor(
            np.stack([merchant_ids, transaction_idx]), dtype=torch.long)
        data['transaction', 'processed_by', 'merchant'].edge_index = torch.tensor(
            np.stack([transaction_idx, merchant_ids]), dtype=torch.long)
        
        # 3. User-User similarity edges
        user_similarity_edges = self._create_similarity_edges(data['user'].x, threshold=0.8)
        if len(user_similarity_edges) > 0:
            data['user', 'similar_to', 'user'].edge_index = torch.LongTensor(user_similarity_edges).t().contiguous()
        
        # 4. Merchant-Merchant similarity edges
        merchant_similarity_edges = self._create_similarity_edges(data['merchant'].x, threshold=0.8)
        if len(merchant_similarity_edges) > 0:
            data['merchant', 'similar_to', 'merchant'].edge_index = torch.LongTensor(merchant_similarity_edges).t().contiguous()
        
        # 5. Transaction-Transaction temporal edges
        temporal_edges = self._create_temporal_edges(df, max_time_diff=3600)  # 1 hour
        if len(temporal_edges) > 0:
            data['transaction', 'temporal', 'transaction'].edge_index = torch.LongTensor(temporal_edges).t().contiguous()
        
        # Print edge statistics
        print(f"  - User-Transaction edges: {data['user', 'makes', 'transaction'].edge_index.shape[1]:,}")
        print(f"  - Merchant-Transaction edges: {data['merchant', 'processes', 'transaction'].edge_index.shape[1]:,}")
        print(f"  - User-User similarity edges: {len(user_similarity_edges):,}")
        print(f"  - Merchant-Merchant similarity edges: {len(merchant_similarity_edges):,}")
        print(f"  - Transaction-Transaction temporal edges: {len(temporal_edges):,}")
        
        return data
    
    def _create_similarity_edges(self, features, threshold=0.8):
        """
        Create similarity edges based on cosine similarity.
        
        Args:
            features: Node features tensor
            threshold: Similarity threshold
        
        Returns:
            List of edge pairs
        """
        if len(features) > 1000:  # Limit for computational efficiency
            return []
        
        # Calculate cosine similarity
        similarity_matrix = cosine_similarity(features.numpy())
        
        # Find pairs above threshold
        edges = []
        for i in range(len(similarity_matrix)):
            for j in range(i + 1, len(similarity_matrix)):
                if similarity_matrix[i, j] > threshold:
                    edges.append([i, j])
                    edges.append([j, i])  # Undirected edge
        
        return edges
    
    def _create_temporal_edges(self, df, max_time_diff=3600):
        """
        Create undirected temporal edges between transactions that are consecutive in time.
        
        Args:
            df: Transaction dataframe
            max_time_diff: Maximum time difference (seconds)
        
        Returns:
            List of temporal edge pairs (each link stored in both directions)
        """
        # Row positions sorted by time (stable, so ties keep their original order)
        times = df['Time'].to_numpy()
        order = np.argsort(times, kind='stable')
        
        # Link each transaction to the next one when the gap is small enough
        gaps = np.diff(times[order])
        keep = gaps <= max_time_diff
        src, dst = order[:-1][keep], order[1:][keep]
        
        # Store both directions so the relation is symmetric (Transaction ↔ Transaction)
        edges = np.concatenate([np.stack([src, dst], axis=1), np.stack([dst, src], axis=1)])
        return edges.tolist()

# Build the graph
graph_builder = HeterogeneousGraphBuilder()
graph_data, df_with_entities = graph_builder.build_heterogeneous_graph(df, max_users=500, max_merchants=250)

print("\nGraph construction complete!")
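The double loop in `_create_similarity_edges` makes O(n²) Python-level comparisons, which is why it returns no edges above 1,000 nodes. A vectorized sketch of the same thresholded cosine-similarity rule is shown below; `similarity_edges_vectorized` is a hypothetical name, and it yields the same edge set as the loop, only in a different order.

```python
import numpy as np

def similarity_edges_vectorized(features, threshold=0.8):
    """Return a (2, E) array of cosine-similarity edges, both directions included."""
    features = np.asarray(features, dtype=float)
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    normalized = features / np.clip(norms, 1e-12, None)  # guard against zero vectors
    sim = normalized @ normalized.T
    # Strict upper triangle: skip self-pairs and visit each pair once
    upper = np.triu(np.ones_like(sim, dtype=bool), k=1)
    i, j = np.nonzero(upper & (sim > threshold))
    # Emit both directions, as the loop version does
    return np.concatenate([np.stack([i, j]), np.stack([j, i])], axis=1)

feats = np.array([[1.0, 0.0],
                  [0.9, 0.1],
                  [0.0, 1.0]])
print(similarity_edges_vectorized(feats).T.tolist())  # [[0, 1], [1, 0]]
```

Memory still grows as n² for the dense similarity matrix, so for very large node sets a nearest-neighbor index would be the usual next step.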

## Part 3: Visualizing the Heterogeneous Graph

Let's visualize our constructed graph to understand its structure:

In [None]:
# Visualize graph structure
def visualize_graph_structure(graph_data, df_with_entities):
    """
    Create visualizations to understand the graph structure.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # 1. Node type distribution
    node_counts = {
        'Transactions': graph_data['transaction'].num_nodes,
        'Users': graph_data['user'].num_nodes,
        'Merchants': graph_data['merchant'].num_nodes
    }
    
    axes[0, 0].bar(node_counts.keys(), node_counts.values(), color=['skyblue', 'lightgreen', 'salmon'])
    axes[0, 0].set_ylabel('Number of Nodes')
    axes[0, 0].set_title('Node Type Distribution')
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. Edge type distribution
    edge_counts = {}
    for edge_type in graph_data.edge_types:
        edge_counts[f"{edge_type[0]}-{edge_type[2]}"] = graph_data[edge_type].edge_index.shape[1]
    
    axes[0, 1].bar(range(len(edge_counts)), list(edge_counts.values()), color='lightcoral')
    axes[0, 1].set_xticks(range(len(edge_counts)))
    axes[0, 1].set_xticklabels(list(edge_counts.keys()), rotation=45, ha='right')
    axes[0, 1].set_ylabel('Number of Edges')
    axes[0, 1].set_title('Edge Type Distribution')
    axes[0, 1].grid(True, alpha=0.3)
    
    # 3. User transaction distribution
    user_transaction_counts = df_with_entities['user_id'].value_counts()
    axes[1, 0].hist(user_transaction_counts.values, bins=50, alpha=0.7, color='lightgreen')
    axes[1, 0].set_xlabel('Transactions per User')
    axes[1, 0].set_ylabel('Number of Users')
    axes[1, 0].set_title('User Activity Distribution')
    axes[1, 0].grid(True, alpha=0.3)
    
    # 4. Merchant transaction distribution
    merchant_transaction_counts = df_with_entities['merchant_id'].value_counts()
    axes[1, 1].hist(merchant_transaction_counts.values, bins=50, alpha=0.7, color='salmon')
    axes[1, 1].set_xlabel('Transactions per Merchant')
    axes[1, 1].set_ylabel('Number of Merchants')
    axes[1, 1].set_title('Merchant Activity Distribution')
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print("\nGraph Statistics:")
    print(f"Average transactions per user: {user_transaction_counts.mean():.1f}")
    print(f"Average transactions per merchant: {merchant_transaction_counts.mean():.1f}")
    print(f"Most active user: {user_transaction_counts.max()} transactions")
    print(f"Most active merchant: {merchant_transaction_counts.max()} transactions")

# Visualize the graph
visualize_graph_structure(graph_data, df_with_entities)

# Show graph schema
print("\n" + "="*60)
print("HETEROGENEOUS GRAPH SCHEMA")
print("="*60)
print("\nNode Types:")
for node_type in graph_data.node_types:
    print(f"  - {node_type}: {graph_data[node_type].num_nodes} nodes, {graph_data[node_type].x.shape[1]} features")
    
print("\nEdge Types:")
for edge_type in graph_data.edge_types:
    print(f"  - {edge_type[0]} --[{edge_type[1]}]--> {edge_type[2]}: {graph_data[edge_type].edge_index.shape[1]} edges")
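The "Transactions per User" histogram counts rows per `user_id` with `value_counts()`. The same quantity is each user node's out-degree in the `('user', 'makes', 'transaction')` relation, so it can also be read straight from an edge index. A sketch on a toy edge index (values are made up):

```python
import numpy as np

# Toy ('user', 'makes', 'transaction') edge index: row 0 = user, row 1 = transaction
makes = np.array([[2, 0, 2, 1, 0, 2],
                  [0, 1, 2, 3, 4, 5]])
num_users = 4  # user 3 makes no transactions

# Out-degree of each user = number of transactions it makes
user_degree = np.bincount(makes[0], minlength=num_users)
print(user_degree)         # [2 1 3 0]
print(user_degree.mean())  # 1.5
```

Unlike `value_counts()`, `np.bincount(..., minlength=...)` also reports users with zero transactions, which matters when computing averages over all user nodes.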

## Part 4: Graph Attention Networks (GAT)

### Understanding Graph Attention

Graph Attention Networks use attention mechanisms to:
- **Focus on important neighbors**: Not all connections are equally important
- **Learn relationship weights**: Dynamically determine edge importance
- **Handle heterogeneous graphs**: Different attention for different edge types

### Mathematical Foundation

For a node $i$ with neighbors $j \in N(i)$:

1. **Attention coefficient**: $e_{ij} = a(W h_i, W h_j)$
2. **Normalized attention**: $\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in N(i)} \exp(e_{ik})}$
3. **Aggregated features**: $h_i' = \sigma\left(\sum_{j \in N(i)} \alpha_{ij} W h_j\right)$
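The formulas leave the scoring function $a(\cdot,\cdot)$ abstract. The sketch below assumes the choice made in the original GAT paper, $e_{ij} = \text{LeakyReLU}(\mathbf{a}^\top [W h_i \,\|\, W h_j])$ with negative slope 0.2 (also `GATConv`'s default `negative_slope`), and computes one node's update in NumPy. Weights and features are random, $\sigma$ is taken to be ELU as an assumption, and `gat_update` is a hypothetical single-head helper.

```python
import numpy as np

def leaky_relu(x, negative_slope=0.2):
    return np.where(x > 0, x, negative_slope * x)

def elu(x):
    return np.where(x > 0, x, np.exp(x) - 1.0)

def gat_update(h, i, neighbors, W, a):
    """Single-head GAT update for node i (illustrative sketch).

    h: (num_nodes, in_dim) node features
    neighbors: indices j in N(i)
    W: (out_dim, in_dim) shared linear projection
    a: (2 * out_dim,) attention vector applied to [W h_i || W h_j]
    """
    Wh = h @ W.T  # W h for every node
    # Step 1: e_ij = LeakyReLU(a^T [W h_i || W h_j]) for each neighbor j
    e = np.array([leaky_relu(a @ np.concatenate([Wh[i], Wh[j]])) for j in neighbors])
    # Step 2: softmax over N(i); subtracting the max keeps exp() stable
    alpha = np.exp(e - e.max())
    alpha /= alpha.sum()
    # Step 3: attention-weighted sum of projected neighbors, then sigma (ELU here)
    h_i_new = elu(alpha @ Wh[neighbors])
    return h_i_new, alpha

rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))  # 4 nodes with 3 input features each
W = rng.normal(size=(2, 3))  # project to 2 hidden features
a = rng.normal(size=4)       # length 2 * out_dim
h0_new, alpha = gat_update(h, i=0, neighbors=[1, 2, 3], W=W, a=a)
print(np.isclose(alpha.sum(), 1.0))  # True
print(h0_new.shape)                  # (2,)
```

Setting `a` to all zeros makes every $e_{ij}$ equal, so the attention collapses to a plain mean over neighbors: a quick way to see what the learned attention adds.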

In [None]:
class HeterogeneousGAT(nn.Module):
    """
    Heterogeneous Graph Attention Network for fraud detection.
    
    Key components:
    1. Input projection layers for each node type
    2. Multiple GAT layers with heterogeneous message passing
    3. Multi-head attention for capturing different aspects
    4. Final classifier for fraud prediction
    """
    
    def __init__(self, node_types, edge_types, hidden_dim=64, num_heads=4, num_layers=2, dropout=0.2):
        super(HeterogeneousGAT, self).__init__()
        
        self.node_types = node_types
        self.edge_types = edge_types
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        
        # Input projections for each node type
        self.input_projections = nn.ModuleDict()
        for node_type, input_dim in node_types.items():
            self.input_projections[node_type] = nn.Linear(input_dim, hidden_dim)
        
        # Heterogeneous convolution layers
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            conv_dict = {}
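            for edge_type in edge_types:
                src_type, _, dst_type = edge_type
                conv_dict[edge_type] = GATConv(
                    hidden_dim, hidden_dim // num_heads, 
                    heads=num_heads, dropout=dropout, concat=True,
                    # Self-loops only make sense when source and target share a node type;
                    # for bipartite relations (e.g. user -> transaction) they should be disabled
                    add_self_loops=(src_type == dst_type)
                )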
            self.convs.append(HeteroConv(conv_dict, aggr='sum'))
        
        # Final classifier (for transaction nodes)
        self.transaction_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, 2)  # Binary classification
        )
        
        # Optional: global attention for graph-level features (defined but not used in forward())
        self.global_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)
        
    def forward(self, x_dict, edge_index_dict, batch_dict=None):
        """
        Forward pass through the heterogeneous GAT.
        
        Args:
            x_dict: Dictionary of node features for each node type
            edge_index_dict: Dictionary of edge indices for each edge type
            batch_dict: Optional batch information
        
        Returns:
            Fraud predictions for transaction nodes
        """
        # Input projections
        x_dict = {node_type: self.input_projections[node_type](x) 
                 for node_type, x in x_dict.items()}
        
        # Graph convolution layers
        for conv in self.convs:
            x_dict_new = conv(x_dict, edge_index_dict)
            
            # Apply residual connections and activation
            for node_type in x_dict.keys():
                if node_type in x_dict_new:
                    x_dict[node_type] = F.relu(x_dict_new[node_type] + x_dict[node_type])
        
        # Get transaction node embeddings
        transaction_embeddings = x_dict['transaction']
        
        # Classification
        fraud_predictions = self.transaction_classifier(transaction_embeddings)
        
        return fraud_predictions
    
    def get_embeddings(self, x_dict, edge_index_dict):
        """
        Get node embeddings for visualization and analysis.
        """
        # Input projections
        x_dict = {node_type: self.input_projections[node_type](x) 
                 for node_type, x in x_dict.items()}
        
        # Graph convolution layers
        for conv in self.convs:
            x_dict_new = conv(x_dict, edge_index_dict)
            
            # Apply residual connections and activation
            for node_type in x_dict.keys():
                if node_type in x_dict_new:
                    x_dict[node_type] = F.relu(x_dict_new[node_type] + x_dict[node_type])
        
        return x_dict

# Visualize attention mechanism concept
def visualize_attention_concept():
    """
    Visualize how attention works in graph neural networks.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # 1. Traditional aggregation
    axes[0].text(0.5, 0.9, 'Traditional Aggregation', ha='center', fontsize=14, fontweight='bold')
    axes[0].text(0.5, 0.7, 'All neighbors\nequal weight', ha='center', fontsize=12)
    axes[0].text(0.5, 0.5, 'h\'ᵢ = Σ Wh_j', ha='center', fontsize=12, family='monospace')
    axes[0].text(0.5, 0.3, 'Simple but limited', ha='center', fontsize=10, style='italic')
    axes[0].set_xlim(0, 1)
    axes[0].set_ylim(0, 1)
    axes[0].axis('off')
    
    # 2. Attention mechanism
    axes[1].text(0.5, 0.9, 'Attention Mechanism', ha='center', fontsize=14, fontweight='bold')
    axes[1].text(0.5, 0.7, 'Learned weights\nfor each neighbor', ha='center', fontsize=12)
    axes[1].text(0.5, 0.5, 'h\'ᵢ = Σ αᵢⱼ Wh_j', ha='center', fontsize=12, family='monospace')
    axes[1].text(0.5, 0.3, 'Adaptive and flexible', ha='center', fontsize=10, style='italic')
    axes[1].set_xlim(0, 1)
    axes[1].set_ylim(0, 1)
    axes[1].axis('off')
    
    # 3. Multi-head attention
    axes[2].text(0.5, 0.9, 'Multi-Head Attention', ha='center', fontsize=14, fontweight='bold')
    axes[2].text(0.5, 0.7, 'Multiple attention\nsubspaces', ha='center', fontsize=12)
    axes[2].text(0.5, 0.5, 'h\'ᵢ = concat(head_1, ..., head_H)', ha='center', fontsize=12, family='monospace')
    axes[2].text(0.5, 0.3, 'Captures different aspects', ha='center', fontsize=10, style='italic')
    axes[2].set_xlim(0, 1)
    axes[2].set_ylim(0, 1)
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print("Key Benefits of Attention:")
    print("1. Focus on important connections")
    print("2. Adaptive to data patterns")
    print("3. Interpretable attention weights")
    print("4. Often more accurate when neighbor importance varies")

visualize_attention_concept()

# Initialize the model
node_types = {
    'transaction': graph_data['transaction'].x.shape[1],
    'user': graph_data['user'].x.shape[1],
    'merchant': graph_data['merchant'].x.shape[1]
}

edge_types = graph_data.edge_types

model = HeterogeneousGAT(
    node_types=node_types,
    edge_types=edge_types,
    hidden_dim=64,
    num_heads=4,
    num_layers=2,
    dropout=0.2
).to(device)

print(f"\nModel initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Hidden dimension: {model.hidden_dim}")
print(f"Number of attention heads: {model.num_heads}")
print(f"Number of layers: {model.num_layers}")
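`HeteroConv(conv_dict, aggr='sum')` runs one convolution per edge type and sums the outputs that land on the same destination node type. The NumPy sketch below mimics only that bookkeeping: it swaps `GATConv` for a plain mean over incoming neighbors (a deliberate simplification, not what the model computes), and the helper names are hypothetical.

```python
import numpy as np

def mean_messages(x_src, edge_index, num_dst):
    """Mean of source features over each destination's incoming edges (0 if none)."""
    out = np.zeros((num_dst, x_src.shape[1]))
    counts = np.zeros(num_dst)
    np.add.at(out, edge_index[1], x_src[edge_index[0]])
    np.add.at(counts, edge_index[1], 1)
    return out / np.maximum(counts, 1)[:, None]

def hetero_sum_layer(x_dict, edge_index_dict):
    """One relation-wise pass; outputs that share a destination type are summed."""
    out = {}
    for (src, _, dst), edge_index in edge_index_dict.items():
        msg = mean_messages(x_dict[src], edge_index, len(x_dict[dst]))
        out[dst] = out.get(dst, 0) + msg
    return out

x_dict = {'user': np.array([[1.0], [3.0]]),
          'transaction': np.array([[10.0], [20.0], [30.0]])}
edge_index_dict = {
    ('user', 'makes', 'transaction'): np.array([[0, 0, 1], [0, 1, 2]]),
    ('transaction', 'made_by', 'user'): np.array([[0, 1, 2], [0, 0, 1]]),
    ('transaction', 'temporal', 'transaction'): np.array([[0, 1], [1, 2]]),
}
out = hetero_sum_layer(x_dict, edge_index_dict)
print(out['transaction'].ravel().tolist())  # [1.0, 11.0, 23.0]
print(out['user'].ravel().tolist())         # [15.0, 30.0]
```

Transactions receive a sum of two relations (`makes` and `temporal`), while users receive only `made_by`. Node types that receive no relation at all are absent from the result, which is why the residual loop in `forward()` checks `if node_type in x_dict_new`.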

## Part 5: Training the Heterogeneous GNN

### Training Strategy

Training heterogeneous GNNs requires careful consideration of:
- **Class imbalance**: Fraud is rare, so the loss must up-weight the fraud class
- **Graph structure**: Different edge types contribute differently
- **Scalability**: Full-batch training keeps the whole graph in memory, which caps the graph size that fits on one device
- **Overfitting**: Complex models need regularization
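The class-imbalance point can be made concrete. Weighting the fraud class by $N_{\text{legit}} / N_{\text{fraud}}$ (the ratio `train_model` uses below) gives both classes the same total weight. PyTorch's `CrossEntropyLoss` with `weight` and the default `reduction='mean'` divides by the summed weights of the targets; the NumPy sketch below reproduces that rule on toy labels (the 98/2 split is made up).

```python
import numpy as np

def weighted_cross_entropy(logits, labels, class_weights):
    """Cross-entropy averaged with per-class weights: sum(w * nll) / sum(w)."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]        # per-example loss
    w = class_weights[labels]                               # weight of each target class
    return (w * nll).sum() / w.sum()

# Toy labels: 98 legitimate, 2 fraudulent transactions
labels = np.array([0] * 98 + [1] * 2)
fraud_weight = (labels == 0).sum() / (labels == 1).sum()
class_weights = np.array([1.0, fraud_weight])
print(class_weights)  # [ 1. 49.]

# After weighting, both classes carry the same total weight
w = class_weights[labels]
print(w[labels == 0].sum(), w[labels == 1].sum())  # 98.0 98.0

# A model that confidently predicts "legitimate" for everything
logits = np.tile([2.0, -2.0], (len(labels), 1))
unweighted = weighted_cross_entropy(logits, labels, np.array([1.0, 1.0]))
weighted = weighted_cross_entropy(logits, labels, class_weights)
print(weighted > unweighted)  # True: the missed frauds now dominate the loss
```

Without the weights, the always-legitimate model scores a low loss despite missing every fraud; with them, those two misses account for half of the total weight.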

In [None]:
class HeterogeneousFraudDetector:
    """
    Complete fraud detection system using heterogeneous GNNs.
    
    Features:
    - Automatic graph construction
    - Training with class imbalance handling
    - Comprehensive evaluation
    - Embedding visualization
    """
    
    def __init__(self, random_state=42):
        self.random_state = random_state
        self.graph_builder = HeterogeneousGraphBuilder(random_state)
        self.model = None
        self.graph_data = None
        self.df_with_entities = None
        self.train_mask = None
        self.test_mask = None
        self.device = device
    
    def prepare_data(self, df, test_size=0.2):
        """
        Prepare data for training: build graph and create splits.
        """
        print("Preparing data for heterogeneous GNN...")
        
        # Build graph
        self.graph_data, self.df_with_entities = self.graph_builder.build_heterogeneous_graph(df)
        
        # Create train/test splits
        n_transactions = len(df)
        indices = np.arange(n_transactions)
        labels = df['Class'].values
        
        train_indices, test_indices = train_test_split(
            indices, test_size=test_size, stratify=labels, random_state=self.random_state
        )
        
        # Create masks
        self.train_mask = torch.zeros(n_transactions, dtype=torch.bool)
        self.test_mask = torch.zeros(n_transactions, dtype=torch.bool)
        
        self.train_mask[train_indices] = True
        self.test_mask[test_indices] = True
        
        # Move data to device
        self.graph_data = self.graph_data.to(self.device)
        self.train_mask = self.train_mask.to(self.device)
        self.test_mask = self.test_mask.to(self.device)
        
        print(f"Data prepared:")
        print(f"  - Training transactions: {self.train_mask.sum():,}")
        print(f"  - Test transactions: {self.test_mask.sum():,}")
        print(f"  - Training fraud rate: {self.graph_data['transaction'].y[self.train_mask].float().mean():.4f}")
        print(f"  - Test fraud rate: {self.graph_data['transaction'].y[self.test_mask].float().mean():.4f}")
    
    def train_model(self, hidden_dim=64, num_heads=4, num_layers=2, epochs=200, lr=0.01):
        """
        Train the heterogeneous GNN model.
        """
        print(f"Training heterogeneous GNN...")
        
        # Initialize model
        node_types = {
            'transaction': self.graph_data['transaction'].x.shape[1],
            'user': self.graph_data['user'].x.shape[1],
            'merchant': self.graph_data['merchant'].x.shape[1]
        }
        
        self.model = HeterogeneousGAT(
            node_types=node_types,
            edge_types=self.graph_data.edge_types,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            num_layers=num_layers
        ).to(self.device)
        
        # Calculate class weights for imbalanced data: the fraud class is weighted by
        # the ratio of legitimate to fraudulent training transactions
        train_labels = self.graph_data['transaction'].y[self.train_mask]
        fraud_weight = ((train_labels == 0).sum().float() / (train_labels == 1).sum().float()).item()
        
        # Loss function and optimizer
        criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, fraud_weight], device=self.device))
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
        
        # Training loop
        history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
        
        for epoch in range(epochs):
            # Training
            self.model.train()
            optimizer.zero_grad()
            
            # Forward pass
            out = self.model(self.graph_data.x_dict, self.graph_data.edge_index_dict)
            train_out = out[self.train_mask]
            train_labels = self.graph_data['transaction'].y[self.train_mask]
            
            # Calculate loss
            loss = criterion(train_out, train_labels)
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            
            # Calculate training accuracy
            with torch.no_grad():
                train_pred = train_out.argmax(dim=1)
                train_acc = (train_pred == train_labels).float().mean()
            
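            # Evaluation: rerun the forward pass in eval mode so that layers such as
            # dropout behave as they do at inference time. Note that this monitors
            # the test mask, so these "validation" curves are not independent of
            # the final test metrics.
            self.model.eval()
            with torch.no_grad():
                eval_out = self.model(self.graph_data.x_dict, self.graph_data.edge_index_dict)
                val_out = eval_out[self.test_mask]
                val_labels = self.graph_data['transaction'].y[self.test_mask]
                val_loss = criterion(val_out, val_labels)
                val_pred = val_out.argmax(dim=1)
                val_acc = (val_pred == val_labels).float().mean()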
            
            # Update scheduler
            scheduler.step()
            
            # Save history
            history['train_loss'].append(loss.item())
            history['train_acc'].append(train_acc.item())
            history['val_loss'].append(val_loss.item())
            history['val_acc'].append(val_acc.item())
            
            # Print progress
            if epoch % 25 == 0 or epoch == epochs - 1:
                print(f"Epoch {epoch:3d}: Train Loss={loss:.4f}, Train Acc={train_acc:.4f}, "
                      f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")
        
        return history
    
    def evaluate_model(self):
        """
        Comprehensive evaluation of the trained model.
        """
        if self.model is None:
            raise ValueError("Model not trained yet. Call train_model() first.")
        
        self.model.eval()
        
        with torch.no_grad():
            # Get predictions
            out = self.model(self.graph_data.x_dict, self.graph_data.edge_index_dict)
            test_out = out[self.test_mask]
            test_labels = self.graph_data['transaction'].y[self.test_mask]
            
            # Get probabilities and predictions
            probabilities = F.softmax(test_out, dim=1)[:, 1]  # Fraud probabilities
            predictions = test_out.argmax(dim=1)
            
            # Convert to numpy for sklearn metrics
            y_true = test_labels.cpu().numpy()
            y_pred = predictions.cpu().numpy()
            y_prob = probabilities.cpu().numpy()
            
            # Calculate metrics
            metrics = {
                'accuracy': accuracy_score(y_true, y_pred),
                'precision': precision_score(y_true, y_pred),
                'recall': recall_score(y_true, y_pred),
                'f1_score': f1_score(y_true, y_pred),
                'roc_auc': roc_auc_score(y_true, y_prob)
            }
            
            # Print results
            print("\n" + "="*50)
            print("HETEROGENEOUS GNN EVALUATION RESULTS")
            print("="*50)
            print(f"Accuracy:  {metrics['accuracy']:.4f}")
            print(f"Precision: {metrics['precision']:.4f}")
            print(f"Recall:    {metrics['recall']:.4f}")
            print(f"F1-Score:  {metrics['f1_score']:.4f}")
            print(f"ROC-AUC:   {metrics['roc_auc']:.4f}")
            print("="*50)
            
            # Confusion matrix
            cm = confusion_matrix(y_true, y_pred)
            print(f"\nConfusion Matrix:")
            print(f"                Predicted")
            print(f"Actual   Normal  Fraud")
            print(f"Normal   {cm[0,0]:6d}  {cm[0,1]:5d}")
            print(f"Fraud    {cm[1,0]:6d}  {cm[1,1]:5d}")
            
            # Classification report
            print(f"\nDetailed Classification Report:")
            print(classification_report(y_true, y_pred, target_names=['Normal', 'Fraud']))
            
            return metrics, y_true, y_pred, y_prob

# Initialize and train the detector
detector = HeterogeneousFraudDetector()
detector.prepare_data(df.iloc[:50000])  # Use subset for faster training

# Train the model
training_history = detector.train_model(epochs=100, lr=0.01)

# Evaluate the model
metrics, y_true, y_pred, y_prob = detector.evaluate_model()
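Accuracy is a weak guide here because fraud is rare. As a toy illustration with synthetic labels (not this dataset), a model that predicts "Normal" for every transaction reaches roughly 99% accuracy while catching no fraud at all, which is why `evaluate_model` also reports precision, recall, F1, and ROC-AUC.

```python
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, f1_score

# Synthetic labels with roughly a 1% fraud rate (illustrative only, not the tutorial's data)
rng = np.random.default_rng(0)
y_true_toy = (rng.random(10_000) < 0.01).astype(int)

# A degenerate "model" that labels every transaction as normal
y_pred_all_normal = np.zeros_like(y_true_toy)

acc = accuracy_score(y_true_toy, y_pred_all_normal)
rec = recall_score(y_true_toy, y_pred_all_normal, zero_division=0)
f1 = f1_score(y_true_toy, y_pred_all_normal, zero_division=0)
print(f"Accuracy: {acc:.4f}  Recall: {rec:.4f}  F1: {f1:.4f}")
```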

## Part 6: Advanced Analysis and Visualization

### Embedding Analysis

Let's analyze what the model learned by examining the node embeddings:

In [None]:
def analyze_embeddings(detector):
    """
    Analyze and visualize learned embeddings.
    """
    print("Analyzing learned embeddings...")
    
    # Get embeddings
    detector.model.eval()
    with torch.no_grad():
        embeddings = detector.model.get_embeddings(
            detector.graph_data.x_dict, 
            detector.graph_data.edge_index_dict
        )
    
    # Extract transaction embeddings
    transaction_embeddings = embeddings['transaction'].cpu().numpy()
    transaction_labels = detector.graph_data['transaction'].y.cpu().numpy()
    
    # Dimensionality reduction
    print("Performing dimensionality reduction...")
    pca = PCA(n_components=2)
    tsne = TSNE(n_components=2, random_state=42)
    
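    # Use a seeded random subset for t-SNE (computationally expensive);
    # min() avoids an error when there are fewer than 2,000 transactions
    rng = np.random.default_rng(42)
    n_subset = min(2000, len(transaction_embeddings))
    subset_indices = rng.choice(len(transaction_embeddings), n_subset, replace=False)
    subset_embeddings = transaction_embeddings[subset_indices]
    subset_labels = transaction_labels[subset_indices]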
    
    embeddings_pca = pca.fit_transform(transaction_embeddings)
    embeddings_tsne = tsne.fit_transform(subset_embeddings)
    
    # Visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # PCA visualization
    normal_mask = transaction_labels == 0
    fraud_mask = transaction_labels == 1
    
    axes[0, 0].scatter(embeddings_pca[normal_mask, 0], embeddings_pca[normal_mask, 1], 
                      alpha=0.6, s=10, label='Normal', color='blue')
    axes[0, 0].scatter(embeddings_pca[fraud_mask, 0], embeddings_pca[fraud_mask, 1], 
                      alpha=0.8, s=30, label='Fraud', color='red')
    axes[0, 0].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2f})')
    axes[0, 0].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2f})')
    axes[0, 0].set_title('Transaction Embeddings (PCA)')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # t-SNE visualization
    subset_normal_mask = subset_labels == 0
    subset_fraud_mask = subset_labels == 1
    
    axes[0, 1].scatter(embeddings_tsne[subset_normal_mask, 0], embeddings_tsne[subset_normal_mask, 1], 
                      alpha=0.6, s=10, label='Normal', color='blue')
    axes[0, 1].scatter(embeddings_tsne[subset_fraud_mask, 0], embeddings_tsne[subset_fraud_mask, 1], 
                      alpha=0.8, s=30, label='Fraud', color='red')
    axes[0, 1].set_xlabel('t-SNE 1')
    axes[0, 1].set_ylabel('t-SNE 2')
    axes[0, 1].set_title('Transaction Embeddings (t-SNE)')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # User embeddings
    user_embeddings = embeddings['user'].cpu().numpy()
    user_pca = PCA(n_components=2).fit_transform(user_embeddings)
    
    axes[1, 0].scatter(user_pca[:, 0], user_pca[:, 1], alpha=0.7, s=50, color='green')
    axes[1, 0].set_xlabel('PC1')
    axes[1, 0].set_ylabel('PC2')
    axes[1, 0].set_title('User Embeddings (PCA)')
    axes[1, 0].grid(True, alpha=0.3)
    
    # Merchant embeddings
    merchant_embeddings = embeddings['merchant'].cpu().numpy()
    merchant_pca = PCA(n_components=2).fit_transform(merchant_embeddings)
    
    axes[1, 1].scatter(merchant_pca[:, 0], merchant_pca[:, 1], alpha=0.7, s=50, color='orange')
    axes[1, 1].set_xlabel('PC1')
    axes[1, 1].set_ylabel('PC2')
    axes[1, 1].set_title('Merchant Embeddings (PCA)')
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Embedding statistics
    print("\nEmbedding Analysis:")
    print(f"Transaction embeddings shape: {transaction_embeddings.shape}")
    print(f"User embeddings shape: {user_embeddings.shape}")
    print(f"Merchant embeddings shape: {merchant_embeddings.shape}")
    
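    # Silhouette analysis on the same subset for both projections: scoring all
    # transactions is quadratic in the number of points. These scores measure
    # separation in the 2-D projections, not in the full embedding space.
    from sklearn.metrics import silhouette_score
    silhouette_pca = silhouette_score(embeddings_pca[subset_indices], subset_labels)
    silhouette_tsne = silhouette_score(embeddings_tsne, subset_labels)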
    
    print(f"\nSeparation Quality:")
    print(f"PCA silhouette score: {silhouette_pca:.4f}")
    print(f"t-SNE silhouette score: {silhouette_tsne:.4f}")
    
    return embeddings

# Analyze embeddings
learned_embeddings = analyze_embeddings(detector)

# Visualize training history
def plot_training_history(history):
    """
    Plot training history.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    epochs = range(len(history['train_loss']))
    
    # Loss curves
    ax1.plot(epochs, history['train_loss'], 'b-', label='Training Loss', linewidth=2)
    ax1.plot(epochs, history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Accuracy curves
    ax2.plot(epochs, history['train_acc'], 'b-', label='Training Accuracy', linewidth=2)
    ax2.plot(epochs, history['val_acc'], 'r-', label='Validation Accuracy', linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Training and Validation Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

plot_training_history(training_history)
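Scoring all transactions with `silhouette_score` is quadratic in the number of points, and a plain random subsample may contain very few fraud cases. A hypothetical helper, `stratified_silhouette`, caps the sample size per class instead; it is shown here on synthetic embeddings. Because it rebalances the classes, its score describes how separable fraud and normal embeddings are, not the separation you would see at the natural fraud rate.

```python
import numpy as np
from sklearn.metrics import silhouette_score

def stratified_silhouette(embeddings, labels, per_class=1000, seed=42):
    """Silhouette score on a per-class subsample (hypothetical helper).

    Keeps up to `per_class` points from each label so that rare classes
    such as fraud are always represented and the quadratic cost stays bounded.
    """
    rng = np.random.default_rng(seed)
    idx = []
    for cls in np.unique(labels):
        cls_idx = np.flatnonzero(labels == cls)
        take = min(per_class, len(cls_idx))
        idx.append(rng.choice(cls_idx, size=take, replace=False))
    idx = np.concatenate(idx)
    return silhouette_score(embeddings[idx], labels[idx])

# Synthetic demo: two well-separated Gaussian clusters in 16 dimensions
rng = np.random.default_rng(0)
normal = rng.normal(0.0, 1.0, size=(5000, 16))
fraud = rng.normal(8.0, 1.0, size=(50, 16))
X = np.vstack([normal, fraud])
y = np.array([0] * 5000 + [1] * 50)
print(f"Silhouette (stratified subsample): {stratified_silhouette(X, y):.3f}")
```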

## Part 7: Performance Comparison and Analysis

Let's compare the heterogeneous GNN with simpler baselines:

In [None]:
def compare_with_baselines(detector, y_true, y_pred, y_prob):
    """
    Compare heterogeneous GNN with traditional ML baselines.
    """
    print("Comparing with traditional ML baselines...")
    
    # Baselines see only the transaction node features, not the graph structure
    X_test = detector.graph_data['transaction'].x[detector.test_mask].cpu().numpy()
    y_test = y_true
    
    # Get training data
    X_train = detector.graph_data['transaction'].x[detector.train_mask].cpu().numpy()
    y_train = detector.graph_data['transaction'].y[detector.train_mask].cpu().numpy()
    
    # Train baseline models
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC
    
    # Scale features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    baselines = {
        'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced'),
        'Logistic Regression': LogisticRegression(random_state=42, class_weight='balanced', max_iter=1000),
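        # probability=True adds an internal cross-validated calibration step, and kernel
        # SVMs scale poorly with sample count, so this baseline can be slow on large subsets
        'SVM': SVC(random_state=42, class_weight='balanced', probability=True)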
    }
    
    baseline_results = {}
    
    for name, model in baselines.items():
        print(f"Training {name}...")
        model.fit(X_train_scaled, y_train)
        
        # Predictions
        y_pred_baseline = model.predict(X_test_scaled)
        y_prob_baseline = model.predict_proba(X_test_scaled)[:, 1]
        
        # Metrics
        baseline_results[name] = {
            'accuracy': accuracy_score(y_test, y_pred_baseline),
            'precision': precision_score(y_test, y_pred_baseline),
            'recall': recall_score(y_test, y_pred_baseline),
            'f1_score': f1_score(y_test, y_pred_baseline),
            'roc_auc': roc_auc_score(y_test, y_prob_baseline)
        }
    
    # Add heterogeneous GNN results
    baseline_results['Heterogeneous GNN'] = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'f1_score': f1_score(y_true, y_pred),
        'roc_auc': roc_auc_score(y_true, y_prob)
    }
    
    # Create comparison table
    comparison_df = pd.DataFrame(baseline_results).T
    comparison_df = comparison_df.round(4)
    
    print("\n" + "="*80)
    print("MODEL COMPARISON RESULTS")
    print("="*80)
    print(comparison_df.to_string())
    print("="*80)
    
    # Visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    metrics = ['accuracy', 'precision', 'recall', 'f1_score']
    titles = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
    
    for i, (metric, title) in enumerate(zip(metrics, titles)):
        ax = axes[i//2, i%2]
        
        models = list(baseline_results.keys())
        values = [baseline_results[model][metric] for model in models]
        
        colors = ['lightblue', 'lightgreen', 'lightcoral', 'gold']
        bars = ax.bar(models, values, color=colors)
        
        # Highlight best performer
        best_idx = np.argmax(values)
        bars[best_idx].set_color('darkgreen')
        
        ax.set_ylabel(title)
        ax.set_title(f'{title} Comparison')
        ax.set_ylim(0, 1)
        
        # Add value labels
        for j, (model, value) in enumerate(zip(models, values)):
            ax.text(j, value + 0.01, f'{value:.3f}', ha='center', va='bottom')
        
        ax.grid(True, alpha=0.3)
        
        # Rotate x-axis labels for better readability
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    
    plt.tight_layout()
    plt.show()
    
    return comparison_df

# Compare with baselines
comparison_results = compare_with_baselines(detector, y_true, y_pred, y_prob)

# ROC curve comparison
def plot_roc_comparison(detector, y_true, y_prob):
    """
    Plot ROC curves for different models.
    """
    from sklearn.metrics import roc_curve, auc
    from sklearn.ensemble import RandomForestClassifier
    
    plt.figure(figsize=(10, 8))
    
    # Heterogeneous GNN
    fpr_gnn, tpr_gnn, _ = roc_curve(y_true, y_prob)
    roc_auc_gnn = auc(fpr_gnn, tpr_gnn)
    
    plt.plot(fpr_gnn, tpr_gnn, color='red', lw=2, 
             label=f'Heterogeneous GNN (AUC = {roc_auc_gnn:.3f})')
    
    # Baseline comparison (Random Forest)
    X_test = detector.graph_data['transaction'].x[detector.test_mask].cpu().numpy()
    X_train = detector.graph_data['transaction'].x[detector.train_mask].cpu().numpy()
    y_train = detector.graph_data['transaction'].y[detector.train_mask].cpu().numpy()
    
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    rf = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced')
    rf.fit(X_train_scaled, y_train)
    y_prob_rf = rf.predict_proba(X_test_scaled)[:, 1]
    
    fpr_rf, tpr_rf, _ = roc_curve(y_true, y_prob_rf)
    roc_auc_rf = auc(fpr_rf, tpr_rf)
    
    plt.plot(fpr_rf, tpr_rf, color='blue', lw=2, 
             label=f'Random Forest (AUC = {roc_auc_rf:.3f})')
    
    # Random classifier
    plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--', 
             label='Random Classifier (AUC = 0.500)')
    
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve Comparison')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.show()
    
    # Signed difference: a negative value means the Random Forest scored higher
    print(f"\nROC-AUC difference (GNN - RF): {roc_auc_gnn - roc_auc_rf:+.4f}")
    print(f"Relative difference: {((roc_auc_gnn - roc_auc_rf) / roc_auc_rf * 100):+.1f}%")

plot_roc_comparison(detector, y_true, y_prob)
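The printed ROC-AUC difference is a single point estimate on one test split, and with few fraud cases it can move noticeably between splits. A paired bootstrap over the test set gives a rough interval for it. The sketch below uses a hypothetical helper, `bootstrap_auc_diff`, on synthetic scores; to use it here you would pass `y_true`, the GNN's `y_prob`, and the Random Forest's test-set probabilities (which `plot_roc_comparison` currently keeps local as `y_prob_rf`).

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def bootstrap_auc_diff(y_true, prob_a, prob_b, n_boot=1000, alpha=0.05, seed=0):
    """Paired bootstrap interval for ROC-AUC(model A) - ROC-AUC(model B).

    Both models are scored on the same resampled test indices, so the
    interval reflects the uncertainty of the comparison itself.
    """
    rng = np.random.default_rng(seed)
    y_true = np.asarray(y_true)
    prob_a, prob_b = np.asarray(prob_a), np.asarray(prob_b)
    n = len(y_true)
    diffs = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        # ROC-AUC is undefined when a resample contains only one class
        if y_true[idx].min() == y_true[idx].max():
            continue
        diffs.append(roc_auc_score(y_true[idx], prob_a[idx])
                     - roc_auc_score(y_true[idx], prob_b[idx]))
    lo, hi = np.percentile(diffs, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return float(np.mean(diffs)), float(lo), float(hi)

# Synthetic demo: model A's scores are much less noisy than model B's
rng = np.random.default_rng(1)
y = (rng.random(500) < 0.2).astype(int)
scores_a = y + rng.normal(0, 0.3, size=500)
scores_b = y + rng.normal(0, 2.0, size=500)
mean_diff, lo, hi = bootstrap_auc_diff(y, scores_a, scores_b, n_boot=300)
print(f"AUC difference: {mean_diff:.3f} (95% interval {lo:.3f} to {hi:.3f})")
```

If the interval straddles zero on the real test set, the two models are not clearly distinguishable on this split.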

## Practice Exercises

Now it's your turn! Try these exercises to deepen your understanding:

### Exercise 1: Graph Architecture Experiments
Modify the heterogeneous GNN architecture:
- Try different numbers of attention heads (2, 4, 8)
- Experiment with different hidden dimensions (32, 64, 128)
- Add more GNN layers (3, 4, 5)

How do these changes affect performance and training time?

In [None]:
# Your code here
# Hint: Create multiple models with different architectures
# Compare their performance on the same dataset
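One way to organize Exercise 1 is a small grid runner that records a score and the wall-clock time for each configuration. `run_architecture_grid` and `dummy_train` below are hypothetical names; the dummy scorer exists only so the sketch runs without the detector.

```python
import itertools
import time

def run_architecture_grid(train_fn, grid):
    """Call train_fn on every combination in `grid`, timing each run.

    `train_fn` accepts the grid's keys as keyword arguments and returns a score.
    """
    keys = list(grid)
    results = []
    for values in itertools.product(*(grid[k] for k in keys)):
        config = dict(zip(keys, values))
        start = time.perf_counter()
        score = train_fn(**config)
        results.append({**config, "score": score,
                        "seconds": time.perf_counter() - start})
    return results

grid = {"num_heads": [2, 4, 8], "hidden_dim": [32, 64, 128], "num_layers": [2, 3, 4, 5]}

# Stand-in scorer for a dry run; replace with a function that trains and evaluates the GNN
def dummy_train(num_heads, hidden_dim, num_layers):
    return hidden_dim / 128 - 0.01 * num_layers

results = run_architecture_grid(dummy_train, grid)
best = max(results, key=lambda r: r["score"])
print(len(results), best)
```

For a real run, `train_fn` could call `detector.train_model(hidden_dim=..., num_heads=..., num_layers=...)` and return `detector.evaluate_model()[0]['f1_score']`. Since `train_model` builds a fresh `HeterogeneousGAT` on each call, configurations do not share weights.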

### Exercise 2: Edge Type Ablation Study
Study the importance of different edge types:
- Remove similarity edges (user-user, merchant-merchant)
- Remove temporal edges (transaction-transaction)
- Train with only user-transaction and merchant-transaction edges

Which edge types are most important for fraud detection?

In [None]:
# Your code here
# Hint: Modify the graph construction to exclude certain edge types
# Compare performance with and without each edge type
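PyG keys edge types as `(source_type, relation, destination_type)` tuples, so an ablation can filter `edge_index_dict` by the node-type pair without knowing the relation names. The helper `drop_edge_pairs` and the relation names in the toy dictionary are placeholders; substitute the names from `detector.graph_data.edge_types`. Because `HeterogeneousGAT` receives `edge_types` at construction, you would also build the model from the reduced key list and pass the reduced dictionary to the forward pass, which means adapting `train_model` and `evaluate_model`.

```python
def drop_edge_pairs(edge_index_dict, drop_pairs):
    """Drop every edge type whose (source, destination) node types are in drop_pairs."""
    drop_pairs = set(drop_pairs)
    return {edge_type: edge_index
            for edge_type, edge_index in edge_index_dict.items()
            if (edge_type[0], edge_type[2]) not in drop_pairs}

# Placeholder relation names and values, standing in for real edge_index tensors
toy_edges = {
    ("user", "similar_to", "user"): "edges_uu",
    ("merchant", "similar_to", "merchant"): "edges_mm",
    ("transaction", "followed_by", "transaction"): "edges_tt",
    ("user", "performs", "transaction"): "edges_ut",
    ("merchant", "processes", "transaction"): "edges_mt",
}
no_similarity = drop_edge_pairs(toy_edges, [("user", "user"), ("merchant", "merchant")])
no_temporal = drop_edge_pairs(toy_edges, [("transaction", "transaction")])
entity_only = drop_edge_pairs(toy_edges, [("user", "user"), ("merchant", "merchant"),
                                          ("transaction", "transaction")])
print(sorted(edge_type[1] for edge_type in entity_only))
```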

### Exercise 3: Attention Weight Analysis
Analyze the learned attention weights:
- Extract attention weights from the GAT layers
- Visualize which relationships the model focuses on
- Compare attention patterns for fraud vs normal transactions

What patterns does the model learn?

In [None]:
# Your code here
# Hint: Modify the forward pass to return attention weights
# Analyze the weights for different transaction types
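PyG's `GATConv` can return its attention coefficients directly: calling it with `return_attention_weights=True` yields `(out, (edge_index, alpha))`, where `alpha` has one column per head. Assuming you have converted those to NumPy, one simple comparison is the entropy of each transaction's incoming attention: low entropy means the node concentrates on a few neighbors. `attention_entropy` is a hypothetical helper, shown on a hand-made toy graph.

```python
import numpy as np

def attention_entropy(edge_index, alpha, num_nodes):
    """Entropy of the incoming attention distribution for each destination node.

    edge_index: (2, E) array of [source, destination] indices.
    alpha: (E,) or (E, heads) attention weights; heads are averaged first.
    num_nodes: number of destination nodes. Nodes without incoming edges get NaN.
    """
    alpha = np.asarray(alpha, dtype=float)
    if alpha.ndim == 2:
        alpha = alpha.mean(axis=1)
    dst = np.asarray(edge_index)[1]
    # Sum -a * log(a) over each node's incoming edges (clip avoids log(0))
    contrib = -alpha * np.log(np.clip(alpha, 1e-12, None))
    entropy = np.bincount(dst, weights=contrib, minlength=num_nodes)
    has_incoming = np.bincount(dst, minlength=num_nodes) > 0
    entropy[~has_incoming] = np.nan
    return entropy

# Toy graph: node 0 spreads attention evenly over 4 sources, node 1 focuses on one
edge_index = np.array([[2, 3, 4, 5, 2, 3],
                       [0, 0, 0, 0, 1, 1]])
alpha = np.array([0.25, 0.25, 0.25, 0.25, 1.0, 0.0])
labels = np.array([0, 1, 0, 0, 0, 0])  # node 1 plays a fraudulent transaction
ent = attention_entropy(edge_index, alpha, num_nodes=6)
summary = {name: float(np.nanmean(ent[labels == cls]))
           for cls, name in [(0, "normal"), (1, "fraud")]}
print(summary)
```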

## Key Takeaways

### 1. Graph-Based Fraud Detection
- **Relationships Matter**: Fraud detection benefits from understanding connections between entities
- **Heterogeneous Graphs**: Multiple node types provide richer representation
- **Message Passing**: Propagating information along edges lets each prediction draw on a transaction's neighborhood, not just its own features
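A single round of message passing fits in a few lines. The NumPy sketch below uses mean aggregation (GraphSAGE-style) rather than the attention used by this tutorial's model, so it shows how information flows but not the learned neighbor weighting; `mean_message_passing` is a hypothetical name.

```python
import numpy as np

def mean_message_passing(h, edge_index, w_self, w_neigh):
    """One round of mean-aggregation message passing.

    h: (N, F) node features; edge_index: (2, E) [source, destination] pairs.
    Each node averages its in-neighbors' features, then combines that
    message with its own transformed features and applies a ReLU.
    """
    src, dst = edge_index
    n = h.shape[0]
    agg = np.zeros_like(h, dtype=float)
    np.add.at(agg, dst, h[src])  # sum incoming messages per destination
    deg = np.bincount(dst, minlength=n).reshape(-1, 1)
    agg = agg / np.maximum(deg, 1)  # mean; nodes with no in-edges get a zero message
    return np.maximum(0.0, h @ w_self + agg @ w_neigh)

h = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edge_index = np.array([[1, 2],   # sources
                       [0, 0]])  # destinations: node 0 hears from nodes 1 and 2
out = mean_message_passing(h, edge_index, np.eye(2), np.eye(2))
print(out)
```

Stacking two such rounds lets a transaction see its neighbors' neighbors, for example other transactions of the same user.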

### 2. Graph Construction Strategies
- **Synthetic Entities**: Use clustering to create realistic user/merchant entities
- **Multiple Edge Types**: Different relationships capture different aspects of fraud
- **Temporal Connections**: Time-based edges capture sequential patterns
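The graph-construction code is not part of this section, so the temporal rule used earlier may differ. The sketch below assumes one common choice: link each transaction to the same user's next transaction in time. `temporal_edges` is a hypothetical helper.

```python
import numpy as np

def temporal_edges(user_ids, timestamps):
    """Link each transaction to the same user's next transaction in time.

    Returns a (2, E) array of [earlier, later] transaction indices.
    """
    user_ids = np.asarray(user_ids)
    timestamps = np.asarray(timestamps)
    # np.lexsort treats the last key as primary: sort by user, then by time
    order = np.lexsort((timestamps, user_ids))
    same_user = user_ids[order][:-1] == user_ids[order][1:]
    return np.vstack([order[:-1][same_user], order[1:][same_user]])

edges = temporal_edges(user_ids=[0, 1, 0, 1, 0], timestamps=[10, 5, 3, 7, 20])
print(edges)
```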

### 3. Graph Attention Networks
- **Attention Mechanism**: Learns a weight for each neighbor, so informative connections contribute more to a node's update
- **Multi-head Attention**: Captures different aspects of relationships
- **Heterogeneous Message Passing**: Separate attention parameters for each edge type
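The attention step from the original GAT formulation scores each neighbor $j$ of node $i$ as $e_{ij} = \mathrm{LeakyReLU}(\mathbf{a}^\top [W h_i \,\|\, W h_j])$, normalizes the scores with a softmax over $i$'s neighbors, and returns the attention-weighted sum of the transformed neighbor features. Below is a dense, single-head NumPy sketch (hypothetical `gat_attention`); multi-head attention runs several such heads with separate $W$ and $\mathbf{a}$ and concatenates or averages the results.

```python
import numpy as np

def gat_attention(h, neighbors, W, a, negative_slope=0.2):
    """Single-head GAT attention (dense NumPy sketch).

    h: (N, F) features; neighbors: dict node -> list of neighbor indices
    (include the node itself for a self-loop); W: (F, F'); a: (2F',).
    Returns (new_h, attention); nodes absent from `neighbors` keep zero rows.
    """
    z = h @ W
    f_out = z.shape[1]
    new_h = np.zeros((h.shape[0], f_out))
    attention = {}
    for i, nbrs in neighbors.items():
        nbrs = np.asarray(nbrs)
        # a^T [z_i || z_j] split into the target half and the neighbor half
        e = z[i] @ a[:f_out] + z[nbrs] @ a[f_out:]
        e = np.where(e > 0, e, negative_slope * e)  # LeakyReLU
        alpha = np.exp(e - e.max())                 # numerically stable softmax
        alpha /= alpha.sum()
        new_h[i] = alpha @ z[nbrs]
        attention[i] = dict(zip(nbrs.tolist(), alpha.tolist()))
    return new_h, attention

h = np.eye(3)
a = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 2.0])  # rewards neighbors with a large 3rd feature
new_h, attention = gat_attention(h, {0: [0, 1, 2]}, np.eye(3), a)
print(attention[0])
```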

### 4. Training Considerations
- **Class Imbalance**: Weighted loss functions handle rare fraud cases
- **Graph Structure**: Densely connected graphs call for regularization; `train_model` uses weight decay and gradient clipping
- **Scalability**: Large graphs need efficient training strategies
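The weighting in `train_model` sets the fraud-class weight to (normal count) / (fraud count) over the training transactions. With `reduction='mean'`, `nn.CrossEntropyLoss` divides the weighted sum of per-sample losses by the sum of the applied weights, so both classes contribute equally in aggregate. The NumPy re-implementation below (hypothetical helper names) makes that arithmetic explicit on a 9:1 toy batch.

```python
import numpy as np

def class_weights(labels):
    """[1, n_normal / n_fraud], the same ratio train_model computes."""
    labels = np.asarray(labels)
    return np.array([1.0, (labels == 0).sum() / (labels == 1).sum()])

def weighted_cross_entropy(logits, labels, weights):
    """Weighted mean cross-entropy, normalized by the sum of the sample weights."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    w = weights[labels]
    return float((w * nll).sum() / w.sum())

labels = np.array([0] * 9 + [1])
logits = np.tile([2.0, 0.0], (10, 1))  # a model that always favors "normal"
w = class_weights(labels)
weighted = weighted_cross_entropy(logits, labels, w)
unweighted = weighted_cross_entropy(logits, labels, np.ones(2))
print(w, round(weighted, 3), round(unweighted, 3))
```

The weighted loss is much larger for this "always normal" model, so the optimizer is pushed to stop ignoring the minority class.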

### 5. Evaluation and Analysis
- **Embedding Visualization**: t-SNE and PCA project the learned representations to 2-D for inspection
- **Comparative Analysis**: Compare with traditional ML baselines
- **Attention Analysis**: Understand what the model learned

### 6. Practical Applications
- **Financial Networks**: Credit cards, bank transfers, insurance claims
- **Social Networks**: Fake accounts, spam detection, recommendation systems
- **E-commerce**: Fake reviews, seller verification, transaction monitoring

### 7. Production Considerations
- **Real-time Inference**: Optimize for low-latency predictions
- **Graph Updates**: Handle dynamic graphs with new entities
- **Scalability**: Use graph sampling for large-scale deployment
- **Interpretability**: Provide explanations for fraud predictions
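Graph sampling keeps each mini-batch bounded by expanding only a fixed number of neighbors per hop. PyG provides `torch_geometric.loader.NeighborLoader` for this, including for `HeteroData`; the NumPy sketch below (hypothetical `sample_neighborhood`) only illustrates the idea and returns node IDs, whereas a real loader also builds the sampled subgraph's edges and features.

```python
import numpy as np

def sample_neighborhood(adj, seeds, fanouts, seed=0):
    """Multi-hop neighbor sampling (GraphSAGE-style sketch).

    adj: dict node -> list of neighbors; seeds: starting nodes;
    fanouts: max neighbors sampled per node at each hop, e.g. [10, 5].
    Returns the set of nodes whose features the mini-batch needs.
    """
    rng = np.random.default_rng(seed)
    frontier = list(seeds)
    visited = set(frontier)
    for k in fanouts:
        next_frontier = []
        for node in frontier:
            nbrs = adj.get(node, [])
            if len(nbrs) > k:
                nbrs = rng.choice(nbrs, size=k, replace=False).tolist()
            for nb in nbrs:
                if nb not in visited:
                    visited.add(nb)
                    next_frontier.append(nb)
        frontier = next_frontier
    return visited

# Hub node 0 with 100 neighbors; each neighbor i also links to a leaf i + 100
adj = {0: list(range(1, 101))}
for i in range(1, 101):
    adj[i] = [0, i + 100]
batch_nodes = sample_neighborhood(adj, seeds=[0], fanouts=[5, 2])
print(len(batch_nodes))
```

Without sampling, the two-hop neighborhood of node 0 here would cover all 201 nodes; with fanouts `[5, 2]` the batch needs 11.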

## Next Steps

In the next tutorial, we'll explore:
- Online learning and streaming fraud detection
- Handling concept drift in fraud patterns
- Real-time model updates and deployment
- Scalable architectures for production systems

Remember: Graph neural networks are powerful tools for fraud detection, but they require careful design and understanding of the underlying relationships in your data!