# Final Notebook for BTCGraphGuard

**Authors: Xuhui Zhan, Tianhao Qu, Siyu Yang**


## Import Libraries

In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import pandas as pd
import networkx as nx
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
from sklearn.preprocessing import LabelEncoder
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import from_networkx
import time

In [2]:
# Prepare Data
data_root = 'data/elliptic_bitcoin_dataset'
elliptic_txs_features = pd.read_csv(os.path.join(data_root, 'elliptic_txs_features.csv'), header=None)
elliptic_txs_edgelist = pd.read_csv(os.path.join(data_root, 'elliptic_txs_edgelist.csv'))
elliptic_txs_classes = pd.read_csv(os.path.join(data_root, 'elliptic_txs_classes.csv'))

elliptic_txs_features.columns = ['txId'] + [f'V{i}' for i in range(1, 167)]


In [3]:
print(elliptic_txs_features.shape)
print(elliptic_txs_edgelist.shape)
print(elliptic_txs_classes.shape)


(203769, 167)
(234355, 2)
(203769, 2)


In [4]:
elliptic_txs_classes['class_mapped'] = elliptic_txs_classes['class'].replace({'1': 'illicit', '2': 'licit'})

In [5]:
# Create Graph
G = nx.from_pandas_edgelist(elliptic_txs_edgelist, 'txId1', 'txId2')

## Random seed settings

In [6]:
RANDOM_STATE = 42
NUM_EPOCHS = 30

In [7]:
def set_seed_for_torch(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)      # For single-GPU.
        torch.cuda.manual_seed_all(seed)  # For multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
def set_seed_for_numpy(seed):
    np.random.seed(seed) 
    
def set_seed_for_random(seed):
    random.seed(seed)  

In [8]:
set_seed_for_torch(RANDOM_STATE)
set_seed_for_numpy(RANDOM_STATE)
set_seed_for_random(RANDOM_STATE)

## EDA

In [9]:
# Spaceholders for EDA

## Preprocess Data

In [10]:
tx_id_mapping = {tx_id: idx for idx, tx_id in enumerate(elliptic_txs_features['txId'])}

# Create an explicit copy of the filtered DataFrame
edges_with_features = elliptic_txs_edgelist[elliptic_txs_edgelist['txId1'].isin(list(tx_id_mapping.keys())) & 
                                           elliptic_txs_edgelist['txId2'].isin(list(tx_id_mapping.keys()))].copy()

# Now use loc to set values (though with copy() above, direct assignment would also work)
edges_with_features.loc[:, 'Id1'] = edges_with_features['txId1'].map(tx_id_mapping)
edges_with_features.loc[:, 'Id2'] = edges_with_features['txId2'].map(tx_id_mapping)

In [11]:
edge_index = torch.tensor(edges_with_features[['Id1', 'Id2']].values.T, dtype=torch.long)
node_features = torch.tensor(elliptic_txs_features.drop(columns=['txId']).values, 
                             dtype=torch.float)

In [12]:
le = LabelEncoder()
class_labels = le.fit_transform(elliptic_txs_classes['class'])
node_labels = torch.tensor(class_labels, dtype=torch.long)
original_labels = le.inverse_transform(class_labels)

In [13]:
data = Data(x=node_features, 
            edge_index=edge_index, 
            y=node_labels)

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
known_mask   = (data.y == 0) | (data.y == 1)  # Only nodes with known labels licit or illicit
unknown_mask = data.y == 2                    # Nodes with unknown labels

In [16]:
num_known_nodes = known_mask.sum().item()
permutations = torch.randperm(num_known_nodes)
train_size = int(0.8 * num_known_nodes)
val_size = int(0.1 * num_known_nodes)
test_size = num_known_nodes - train_size - val_size

total = np.sum([train_size, val_size, test_size])

print(f"""Number of observations per split
    Training   : {train_size:10,} ({100*train_size/total:0.2f} %)
    Validation : {val_size:10,} ({100*val_size/total:0.2f} %)
    Testing    : {test_size:10,} ({100*test_size/total:0.2f} %)
""")

Number of observations per split
    Training   :     37,251 (80.00 %)
    Validation :      4,656 (10.00 %)
    Testing    :      4,657 (10.00 %)



In [17]:
data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)

train_indices = known_mask.nonzero(as_tuple=True)[0][permutations[:train_size]]
val_indices = known_mask.nonzero(as_tuple=True)[0][permutations[train_size:train_size + val_size]]
test_indices = known_mask.nonzero(as_tuple=True)[0][permutations[train_size + val_size:]]

data.train_mask[train_indices] = True
data.val_mask[val_indices] = True
data.test_mask[test_indices] = True

print(len(data.train_mask))

203769


In [18]:
print(data.keys())

['train_mask', 'x', 'test_mask', 'edge_index', 'y', 'val_mask']


## Graph attention network (GAT)

In [19]:
class GAT(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, heads=8):
        super(GAT, self).__init__()
        self.conv1 = GATConv(input_dim, hidden_dim, heads=heads, dropout=0.6)
        self.conv2 = GATConv(hidden_dim * heads, output_dim, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x

## GraphSAGE

In [20]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

In [21]:
def train_gnn_model(graph_data, checkpoint_path, model_args=None, num_epochs=200, lr=0.005, weight_decay=5e-4, verbose=True):
    """
    Train a Graph Neural Network model and save checkpoints.
    
    Args:
        graph_data (torch_geometric.data.Data): The prepared graph data
        checkpoint_path (str): Path to save model checkpoints
        model_args (dict, optional): Dictionary containing model parameters:
            - model_name: Type of GNN model ('GAT' or 'GraphSAGE')
            - input_dim: Input feature dimension
            - hidden_dim: Hidden layer dimension
            - output_dim: Output dimension (number of classes)
            - heads: Number of attention heads (for GAT, default: 8)
        num_epochs (int): Maximum number of training epochs
        lr (float): Learning rate for Adam optimizer
        weight_decay (float): Weight decay for regularization
        verbose (bool): Whether to print training progress
        
    Returns:
        dict: Dictionary containing trained model and training metrics
    """
    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    
    # Move data to device
    graph_data = graph_data.to(device)
    
    # Set default model args if not provided
    if model_args is None:
        model_args = {
            'model_name': 'GAT',  # Default to GAT if not specified
            'input_dim': graph_data.x.shape[1],
            'hidden_dim': 64,
            'output_dim': len(torch.unique(graph_data.y[graph_data.y != 2])),  # Exclude unknown class
            'heads': 8
        }
    
    # Initialize model based on model_name
    model_name = model_args.get('model_name', 'GAT')
    
    if model_name.upper() == 'GAT':
        model = GAT(
            input_dim=model_args['input_dim'],
            hidden_dim=model_args['hidden_dim'],
            output_dim=model_args['output_dim'],
            heads=model_args.get('heads', 8)
        ).to(device)
        model_type = 'GAT'
    elif model_name.upper() == 'GRAPHSAGE' or model_name.upper() == 'SAGE':
        model = GraphSAGE(
            input_dim=model_args['input_dim'],
            hidden_dim=model_args['hidden_dim'],
            output_dim=model_args['output_dim']
        ).to(device)
        model_type = 'GraphSAGE'
    else:
        raise ValueError(f"Unsupported model type: {model_name}. Use 'GAT' or 'GraphSAGE'.")
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = torch.nn.CrossEntropyLoss()
    
    # Training function
    def train_step():
        model.train()
        optimizer.zero_grad()
        out = model(graph_data)
        loss = criterion(out[graph_data.train_mask], graph_data.y[graph_data.train_mask])
        loss.backward()
        optimizer.step()
        return loss.item()
    
    # Evaluation function
    def evaluate(mask):
        model.eval()
        with torch.no_grad():
            out = model(graph_data)
            pred = out.argmax(dim=1)
            correct = pred[mask] == graph_data.y[mask]
            acc = int(correct.sum()) / int(mask.sum())
        return acc
    
    # Training loop
    best_val_acc = 0
    best_model_state = None
    train_history = {
        'losses': [],
        'train_acc': [],
        'val_acc': []
    }
    
    start_time = time.time()
    for epoch in range(1, num_epochs + 1):
        loss = train_step()
        train_acc = evaluate(graph_data.train_mask)
        val_acc = evaluate(graph_data.val_mask)
        
        train_history['losses'].append(loss)
        train_history['train_acc'].append(train_acc)
        train_history['val_acc'].append(val_acc)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = model.state_dict()
            torch.save(best_model_state, checkpoint_path)
            
        if verbose and epoch % 10 == 0:
            print(f'{model_type} Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')
    
    training_time = time.time() - start_time
    
    # Load best model and evaluate
    model.load_state_dict(best_model_state)
    test_acc = evaluate(graph_data.test_mask)
    
    if verbose:
        print(f'{model_type} Test Accuracy: {test_acc:.4f}')
        print(f'{model_type} Training Time: {training_time:.2f} seconds')
    
    return {
        'model': model,
        'best_model_state': best_model_state,
        'test_accuracy': test_acc,
        'val_accuracy': best_val_acc,
        'training_time': training_time,
        'training_history': train_history
    }

In [22]:
def inspect_model_results(results, save_dir=None, model_name="model"):
    """
    Inspects and visualizes the results from the train_gnn_model function.
    
    Args:
        results (dict): Results dictionary returned by train_gnn_model
        save_dir (str, optional): Directory to save visualizations and metrics
                                 If None, extracts directory from checkpoint path
        model_name (str): Name of the model for labeling plots and files
    """
    # Extract training history and metrics
    history = results['training_history']
    test_acc = results['test_accuracy']
    val_acc = results['val_accuracy']
    training_time = results['training_time']
    
    # Determine save directory
    if save_dir is None:
        if 'best_model_state' in results and isinstance(results['best_model_state'], str):
            save_dir = os.path.dirname(results['best_model_state'])
        else:
            save_dir = 'output'
    
    os.makedirs(save_dir, exist_ok=True)
    
    # Create figure with 2 subplots
    plt.figure(figsize=(15, 6))
    
    # Plot training loss
    plt.subplot(1, 2, 1)
    epochs = range(1, len(history['losses']) + 1)
    plt.plot(epochs, history['losses'], 'bo-', label='Training Loss')
    plt.title(f'{model_name} - Training Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    
    # Plot training and validation accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, history['train_acc'], 'go-', label='Training Accuracy')
    plt.plot(epochs, history['val_acc'], 'ro-', label='Validation Accuracy')
    plt.title(f'{model_name} - Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    
    # Save figure
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{model_name}_training_history.png'), dpi=300)
    plt.close()
    
    # Print and save metrics
    metrics = {
        'Model': model_name,
        'Test Accuracy': f'{test_acc:.4f}',
        'Validation Accuracy': f'{val_acc:.4f}',
        'Training Time (s)': f'{training_time:.2f}',
        'Final Training Loss': f'{history["losses"][-1]:.4f}',
        'Number of Epochs': len(history['losses'])
    }
    
    print(f"\n{'='*20} {model_name} Results {'='*20}")
    for key, value in metrics.items():
        print(f"{key}: {value}")
    print(f"{'='*50}")
    
    # Save metrics to CSV
    pd.DataFrame([metrics]).to_csv(os.path.join(save_dir, f'{model_name}_metrics.csv'), index=False)
    
    # Create additional visualization: accuracy vs loss
    plt.figure(figsize=(10, 6))
    plt.scatter(history['losses'], history['val_acc'], c=range(len(history['losses'])), cmap='viridis', 
                s=100, alpha=0.7, edgecolors='black', linewidth=1)
    plt.colorbar(label='Epoch')
    plt.title(f'{model_name} - Validation Accuracy vs Training Loss')
    plt.xlabel('Training Loss')
    plt.ylabel('Validation Accuracy')
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.savefig(os.path.join(save_dir, f'{model_name}_acc_vs_loss.png'), dpi=300)
    plt.close()
    
    return metrics

In [23]:
# input_dim = data.x.shape[1]
# hidden_dim = 64
# output_dim = len(torch.unique(data.y[data.y != 2]))

# gat_output_dir = 'output/GAT'

# # For GAT model
# gat_args = {
#     'model_name': 'GAT',
#     'input_dim': input_dim,
#     'hidden_dim': hidden_dim,
#     'output_dim': output_dim,
#     'heads': 8
# }

# gat_results = train_gnn_model(
#     graph_data=data,
#     checkpoint_path=os.path.join(gat_output_dir, "gat_best_model.pt"),
#     model_args=gat_args,
#     num_epochs=200,
#     verbose=True
# )



In [24]:
# sage_output_dir = 'output/GraphSAGE'

# # For GraphSAGE model
# sage_args = {
#     'model_name': 'GraphSAGE',
#     'input_dim': input_dim,
#     'hidden_dim': hidden_dim,
#     'output_dim': output_dim
# }

# sage_results = train_gnn_model(
#     graph_data=data,
#     checkpoint_path=os.path.join(sage_output_dir, "sage_best_model.pt"),
#     model_args=sage_args,
#     num_epochs=200,
#     verbose=True
# )

# metrics = inspect_model_results(sage_results)

## Augmented Graph labels by self-supervised-learning

In [25]:
def augment_labels(data, percentage=0.3, model=None):
    """
    Augments the training data by adding predicted labels for previously unknown nodes.
    
    Args:
        data (torch_geometric.data.Data): The input graph data
        percentage (float, optional): Percentage of unknown nodes to add to training. Defaults to 0.3.
        model (torch.nn.Module, optional): Pre-trained model to use for predictions. If None,
                                          uses label propagation for self-supervised learning.
    
    Returns:
        torch_geometric.data.Data: Augmented data with expanded training set
    """
    # Create a copy of the data to avoid modifying the original
    augmented_data = data.clone()
    device = augmented_data.x.device
    
    # Get masks for known and unknown labels
    known_mask = (augmented_data.y == 0) | (augmented_data.y == 1)
    unknown_mask = augmented_data.y == 2
    
    # Count unknown nodes and calculate how many to add
    num_unknown = unknown_mask.sum().item()
    num_to_add = int(num_unknown * percentage)
    
    if model is None:
        # Self-supervised learning approach using simple propagation
        print(f"Using self-supervised learning for label augmentation")
        
        # Create graph from data
        G = nx.Graph()
        edge_index = augmented_data.edge_index.cpu().numpy()
        
        # Add nodes and edges to the graph
        for i in range(augmented_data.num_nodes):
            G.add_node(i)
            
        for i in range(edge_index.shape[1]):
            G.add_edge(edge_index[0, i], edge_index[1, i])
        
        # Simple label propagation implementation
        predicted_labels = torch.zeros_like(augmented_data.y)
        confidence_scores = torch.zeros(len(augmented_data.y))
        
        # For each unknown node, check neighbors' labels
        for node in range(len(augmented_data.y)):
            if unknown_mask[node]:
                # Get all neighbors
                neighbors = list(G.neighbors(node))
                if not neighbors:
                    continue
                
                # Count labels of neighbors
                label_count = {0: 0, 1: 0}
                neighbor_features = []
                
                for neighbor in neighbors:
                    if known_mask[neighbor]:
                        neighbor_label = augmented_data.y[neighbor].item()
                        label_count[neighbor_label] = label_count.get(neighbor_label, 0) + 1
                        neighbor_features.append(augmented_data.x[neighbor])
                
                # If we have neighbors with known labels
                if sum(label_count.values()) > 0:
                    # Assign most frequent label
                    if label_count[0] > label_count[1]:
                        predicted_labels[node] = 0
                        confidence_scores[node] = label_count[0] / (label_count[0] + label_count[1])
                    elif label_count[1] > 0:
                        predicted_labels[node] = 1
                        confidence_scores[node] = label_count[1] / (label_count[0] + label_count[1])
                else:
                    # If no neighbors have known labels, look at second-degree neighbors
                    second_degree_neighbors = []
                    for neighbor in neighbors:
                        second_degree_neighbors.extend(list(G.neighbors(neighbor)))
                    
                    # Remove duplicates and the original node
                    second_degree_neighbors = list(set(second_degree_neighbors))
                    if node in second_degree_neighbors:
                        second_degree_neighbors.remove(node)
                    
                    # Count labels of second-degree neighbors
                    second_label_count = {0: 0, 1: 0}
                    for neighbor in second_degree_neighbors:
                        if known_mask[neighbor]:
                            neighbor_label = augmented_data.y[neighbor].item()
                            second_label_count[neighbor_label] = second_label_count.get(neighbor_label, 0) + 1
                    
                    # Assign most frequent label from second-degree neighbors
                    if sum(second_label_count.values()) > 0:
                        if second_label_count[0] > second_label_count[1]:
                            predicted_labels[node] = 0
                            confidence_scores[node] = second_label_count[0] / (second_label_count[0] + second_label_count[1])
                        elif second_label_count[1] > 0:
                            predicted_labels[node] = 1
                            confidence_scores[node] = second_label_count[1] / (second_label_count[0] + second_label_count[1])
    else:
        # Use the provided model for predictions
        print(f"Using provided model for label augmentation")
        model.eval()
        with torch.no_grad():
            out = model(augmented_data)
            probabilities = F.softmax(out, dim=1)
            
            # Get the highest probability and corresponding class
            confidence_scores, predicted_labels = probabilities.max(dim=1)
    
    # Select the top confident predictions among unknown nodes
    unknown_indices = unknown_mask.nonzero(as_tuple=True)[0]
    unknown_confidence = confidence_scores[unknown_indices]
    
    # Sort by confidence
    sorted_indices = unknown_confidence.argsort(descending=True)
    top_indices = unknown_indices[sorted_indices[:num_to_add]]
    
    # Update labels and training mask for selected nodes
    augmented_data.y[top_indices] = predicted_labels[top_indices]
    augmented_data.train_mask[top_indices] = True
    
    # Print statistics
    added_illicit = (predicted_labels[top_indices] == 1).sum().item()
    added_licit = (predicted_labels[top_indices] == 0).sum().item()
    
    print(f"Added {num_to_add} previously unknown nodes to training set:")
    print(f"  - Predicted illicit: {added_illicit} ({100*added_illicit/num_to_add:.2f}%)")
    print(f"  - Predicted licit: {added_licit} ({100*added_licit/num_to_add:.2f}%)")
    
    return augmented_data

In [26]:
# augment_data = augment_labels(data)

In [27]:
# sage_model = GraphSAGE(input_dim, hidden_dim, output_dim)
# sage_model.load_state_dict(torch.load(os.path.join(sage_output_dir, "sage_best_model.pt")))
# augment_data_sage = augment_labels(data, model=sage_model)

## Subgraph Property Analysis

In [28]:
def analyze_graph_properties(data, model=None, save_dir='output/graph_analysis', verbose=True):
    """
    Analyzes subgraph properties for different node types (train/val/test/unknown) and labels.
    
    Args:
        data (torch_geometric.data.Data): The graph data object
        model (torch.nn.Module, optional): Model to predict unknown node labels
        save_dir (str): Directory to save results and plots
        verbose (bool): Whether to print detailed information
        
    Returns:
        dict: Dictionary containing all computed metrics
    """
    import os
    import pandas as pd
    import numpy as np
    import networkx as nx
    import matplotlib.pyplot as plt
    import seaborn as sns
    from torch_geometric.utils import to_networkx
    import torch.nn.functional as F
    
    # Create output directory
    os.makedirs(save_dir, exist_ok=True)
    
    # Convert to NetworkX graph for analysis
    G = to_networkx(data, to_undirected=True)
    
    # Define mask names for clear labeling
    mask_names = {
        'train': data.train_mask,
        'val': data.val_mask,
        'test': data.test_mask,
        'unknown': data.y == 2
    }
    
    # Define label names
    label_names = {
        0: 'licit',
        1: 'illicit',
        2: 'unknown'
    }
    
    # Process unknown nodes if model is provided
    if model is not None:
        model.eval()
        with torch.no_grad():
            out = model(data)
            pred_probs = F.softmax(out, dim=1)
            pred_labels = pred_probs.argmax(dim=1)
            
            # Create predicted label mask
            pred_unknown_licit = (data.y == 2) & (pred_labels == 0)
            pred_unknown_illicit = (data.y == 2) & (pred_labels == 1)
            
            # Add to masks dictionary
            mask_names['pred_unknown_licit'] = pred_unknown_licit
            mask_names['pred_unknown_illicit'] = pred_unknown_illicit
            
            if verbose:
                print(f"Unknown nodes predicted as licit: {pred_unknown_licit.sum().item()}")
                print(f"Unknown nodes predicted as illicit: {pred_unknown_illicit.sum().item()}")
    
    # Calculate graph-level metrics for each mask and label combination
    metrics = {}
    node_metrics = []
    
    # Calculate degree centrality for all nodes first (to avoid recalculation)
    degree_centrality = nx.degree_centrality(G)
    betweenness_centrality = nx.betweenness_centrality(G, k=min(100, len(G.nodes())), seed=42)
    
    try:
        closeness_centrality = nx.closeness_centrality(G)
    except:
        if verbose:
            print("Warning: Could not compute closeness centrality for the full graph (likely disconnected)")
        closeness_centrality = {node: 0 for node in G.nodes()}
    
    try:
        eigenvector_centrality = nx.eigenvector_centrality(G, max_iter=1000)
    except:
        if verbose:
            print("Warning: Eigenvector centrality did not converge, using approximate method")
        try:
            eigenvector_centrality = nx.eigenvector_centrality_numpy(G)
        except:
            if verbose:
                print("Warning: Could not compute eigenvector centrality, using zeros")
            eigenvector_centrality = {node: 0 for node in G.nodes()}
            
    clustering_coefficients = nx.clustering(G)
    
    # Start processing each mask type
    for mask_name, mask in mask_names.items():
        # Skip if no nodes in this mask
        if mask.sum() == 0:
            continue
            
        # Get node indices for this mask
        node_indices = mask.nonzero(as_tuple=True)[0].cpu().numpy()
        
        # For train, val, test - further split by label
        if mask_name in ['train', 'val', 'test']:
            for label in [0, 1]:  # licit and illicit
                label_mask = (data.y == label) & mask
                label_indices = label_mask.nonzero(as_tuple=True)[0].cpu().numpy()
                
                if len(label_indices) == 0:
                    continue
                
                # Create label name for this group
                group_name = f"{mask_name}_{label_names[label]}"
                
                # Process this subgroup
                process_node_group(G, label_indices, group_name, metrics, node_metrics,
                                 degree_centrality, betweenness_centrality, 
                                 closeness_centrality, eigenvector_centrality,
                                 clustering_coefficients, verbose)
        else:
            # For unknown and predicted groups, process directly
            process_node_group(G, node_indices, mask_name, metrics, node_metrics,
                             degree_centrality, betweenness_centrality, 
                             closeness_centrality, eigenvector_centrality,
                             clustering_coefficients, verbose)
    
    # Create a DataFrame with all node-level metrics
    node_df = pd.DataFrame(node_metrics)
    
    # Save node-level metrics
    node_df.to_csv(os.path.join(save_dir, 'node_metrics.csv'), index=False)
    
    # Save graph-level metrics
    graph_metrics_df = pd.DataFrame.from_dict(metrics, orient='index')
    graph_metrics_df.to_csv(os.path.join(save_dir, 'graph_metrics.csv'))
    
    # Create visualizations
    create_visualizations(metrics, node_df, save_dir)
    
    if verbose:
        print(f"Analysis complete. Results saved to {save_dir}")
    
    return metrics

def process_node_group(G, node_indices, group_name, metrics, node_metrics,
                     degree_centrality, betweenness_centrality, 
                     closeness_centrality, eigenvector_centrality,
                     clustering_coefficients, verbose):
    """Helper function to process a group of nodes and calculate metrics"""
    
    # Extract the subgraph
    subgraph = G.subgraph(node_indices)
    
    # Calculate graph-level metrics
    group_metrics = {
        'num_nodes': len(subgraph),
        'num_edges': subgraph.number_of_edges(),
        'density': nx.density(subgraph),
        'avg_degree': np.mean([d for _, d in subgraph.degree()]) if len(subgraph) > 0 else 0,
        'avg_clustering': np.mean([clustering_coefficients.get(node, 0) for node in subgraph.nodes()]),
    }
    
    # Calculate connected components
    connected_components = list(nx.connected_components(subgraph))
    group_metrics['num_components'] = len(connected_components)
    
    if len(connected_components) > 0:
        largest_cc_size = max([len(cc) for cc in connected_components])
        group_metrics['largest_component_size'] = largest_cc_size
        group_metrics['largest_component_ratio'] = largest_cc_size / len(subgraph) if len(subgraph) > 0 else 0
    else:
        group_metrics['largest_component_size'] = 0
        group_metrics['largest_component_ratio'] = 0
    
    # Calculate centrality averages
    group_metrics['avg_degree_centrality'] = np.mean([degree_centrality.get(node, 0) for node in subgraph.nodes()])
    group_metrics['avg_betweenness_centrality'] = np.mean([betweenness_centrality.get(node, 0) for node in subgraph.nodes()])
    group_metrics['avg_closeness_centrality'] = np.mean([closeness_centrality.get(node, 0) for node in subgraph.nodes()])
    group_metrics['avg_eigenvector_centrality'] = np.mean([eigenvector_centrality.get(node, 0) for node in subgraph.nodes()])
    
    # Add homophily measure - how many edges connect to same-group nodes
    internal_edges = 0
    external_edges = 0
    
    node_set = set(node_indices)
    for node in subgraph:
        for neighbor in G.neighbors(node):
            if neighbor in node_set:
                internal_edges += 1
            else:
                external_edges += 1
    
    # Each internal edge is counted twice (once from each end)
    internal_edges = internal_edges / 2
    group_metrics['internal_edges'] = internal_edges
    group_metrics['external_edges'] = external_edges
    group_metrics['homophily'] = internal_edges / (internal_edges + external_edges) if (internal_edges + external_edges) > 0 else 0
    
    # Store the metrics for this group
    metrics[group_name] = group_metrics
    
    # Also store node-level metrics for later visualization
    for node in subgraph.nodes():
        node_metrics.append({
            'node_id': node,
            'group': group_name,
            'degree': G.degree(node),
            'degree_centrality': degree_centrality.get(node, 0),
            'betweenness_centrality': betweenness_centrality.get(node, 0),
            'closeness_centrality': closeness_centrality.get(node, 0),
            'eigenvector_centrality': eigenvector_centrality.get(node, 0),
            'clustering_coefficient': clustering_coefficients.get(node, 0)
        })
    
    if verbose:
        print(f"Processed {group_name}: {len(subgraph)} nodes, {subgraph.number_of_edges()} edges")

def create_visualizations(metrics, node_df, save_dir):
    """Create and save visualizations of the graph metrics"""
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import os
    
    # Set the style
    plt.style.use('seaborn-v0_8-whitegrid')
    
    # Convert metrics to DataFrame for easier plotting
    metrics_df = pd.DataFrame.from_dict(metrics, orient='index')
    
    # 1. Centrality comparison between groups
    centrality_metrics = ['avg_degree_centrality', 'avg_betweenness_centrality', 
                          'avg_closeness_centrality', 'avg_eigenvector_centrality']
    
    plt.figure(figsize=(12, 8))
    metrics_df[centrality_metrics].plot(kind='bar', figsize=(15, 6))
    plt.title('Average Centrality Metrics by Group')
    plt.ylabel('Centrality Value')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'centrality_by_group.png'), dpi=300)
    plt.close()
    
    # 2. Homophily and connectivity
    connectivity_metrics = ['density', 'homophily', 'avg_clustering', 'largest_component_ratio']
    
    plt.figure(figsize=(12, 8))
    metrics_df[connectivity_metrics].plot(kind='bar', figsize=(15, 6))
    plt.title('Connectivity Metrics by Group')
    plt.ylabel('Value')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'connectivity_by_group.png'), dpi=300)
    plt.close()
    
    # 3. Node and edge count
    plt.figure(figsize=(12, 8))
    ax = metrics_df[['num_nodes', 'num_edges']].plot(kind='bar', figsize=(15, 6))
    plt.title('Node and Edge Count by Group')
    plt.ylabel('Count')
    plt.xticks(rotation=45, ha='right')
    
    # Add value labels on bars
    for container in ax.containers:
        ax.bar_label(container, fmt='%.0f')
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'node_edge_count.png'), dpi=300)
    plt.close()
    
    # 4. Scatter plot of centrality metrics for nodes
    centrality_pairs = [
        ('degree_centrality', 'betweenness_centrality'),
        ('degree_centrality', 'eigenvector_centrality'),
        ('closeness_centrality', 'eigenvector_centrality')
    ]
    
    # Define a color map for consistent group coloring
    group_order = node_df['group'].unique()
    colors = plt.cm.tab10(np.linspace(0, 1, len(group_order)))
    group_colors = {group: colors[i] for i, group in enumerate(group_order)}
    
    for x_metric, y_metric in centrality_pairs:
        plt.figure(figsize=(10, 8))
        
        for group in group_order:
            group_data = node_df[node_df['group'] == group]
            plt.scatter(
                group_data[x_metric], 
                group_data[y_metric],
                alpha=0.6, 
                label=group,
                color=group_colors[group],
                s=50
            )
        
        plt.title(f'{y_metric.replace("_", " ").title()} vs {x_metric.replace("_", " ").title()}')
        plt.xlabel(x_metric.replace("_", " ").title())
        plt.ylabel(y_metric.replace("_", " ").title())
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f'scatter_{x_metric}_vs_{y_metric}.png'), dpi=300)
        plt.close()
    
    # 5. Distribution of degree centrality by group
    plt.figure(figsize=(12, 8))
    for group in group_order:
        group_data = node_df[node_df['group'] == group]
        sns.kdeplot(group_data['degree_centrality'], label=group)
    
    plt.title('Distribution of Degree Centrality by Group')
    plt.xlabel('Degree Centrality')
    plt.ylabel('Density')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'degree_centrality_distribution.png'), dpi=300)
    plt.close()
    
    # 6. Heatmap of metrics correlation
    plt.figure(figsize=(10, 8))
    correlation = metrics_df.corr()
    mask = np.triu(np.ones_like(correlation, dtype=bool))
    sns.heatmap(correlation, mask=mask, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
    plt.title('Correlation Between Graph Metrics')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'metrics_correlation.png'), dpi=300)
    plt.close()

In [29]:
# # Analysis with original data
# original_metrics = analyze_graph_properties(
#     data, 
#     save_dir='output/graph_analysis/original',
#     verbose=True
# )

# # Analysis with model predictions
# sage_metrics = analyze_graph_properties(
#     data,
#     model=sage_model,  # Using SAGE model for predictions
#     save_dir='output/graph_analysis/model_predictions',
#     verbose=True
# )

# Compare metrics between original and predicted
def compare_metrics(original_metrics, predicted_metrics, save_path='output/graph_analysis/metrics_comparison.csv'):
    """Compare metrics between original and predicted data"""
    import pandas as pd
    
    # Create DataFrames
    original_df = pd.DataFrame.from_dict(original_metrics, orient='index')
    predicted_df = pd.DataFrame.from_dict(predicted_metrics, orient='index')
    
    # Common groups (like train_licit, train_illicit, etc.)
    common_groups = set(original_df.index).intersection(set(predicted_df.index))
    
    # Compare metrics for common groups
    comparison = {}
    for group in common_groups:
        group_comparison = {}
        for metric in original_df.columns:
            original_value = original_df.loc[group, metric]
            predicted_value = predicted_df.loc[group, metric]
            difference = predicted_value - original_value
            percent_change = (difference / original_value) * 100 if original_value != 0 else float('inf')
            
            group_comparison[f"{metric}_original"] = original_value
            group_comparison[f"{metric}_predicted"] = predicted_value
            group_comparison[f"{metric}_diff"] = difference
            group_comparison[f"{metric}_pct_change"] = percent_change
            
        comparison[group] = group_comparison
    
    # Convert to DataFrame and save
    comparison_df = pd.DataFrame.from_dict(comparison, orient='index')
    comparison_df.to_csv(save_path)
    
    print(f"Metrics comparison saved to {save_path}")
    return comparison_df

# # Execute comparison
# comparison = compare_metrics(original_metrics, sage_metrics)

# # Print key findings
# print("\n==== KEY FINDINGS ====")
# print("Top metrics differences between original and predicted data:")

# # Find top 5 metrics with largest percent changes
# for group in comparison.index:
#     pct_change_cols = [col for col in comparison.columns if col.endswith('_pct_change')]
#     largest_changes = comparison.loc[group, pct_change_cols].abs().nlargest(5)
    
#     print(f"\nGroup: {group}")
#     for metric in largest_changes.index:
#         base_metric = metric.replace('_pct_change', '')
#         original = comparison.loc[group, f"{base_metric}_original"]
#         predicted = comparison.loc[group, f"{base_metric}_predicted"]
#         pct_change = comparison.loc[group, metric]
        
#         change_direction = "increased" if pct_change > 0 else "decreased"
#         print(f"  {base_metric}: {original:.4f} → {predicted:.4f} ({change_direction} by {abs(pct_change):.2f}%)")

## Overall comparison

In [30]:
def comprehensive_graph_analysis(
    data, 
    output_dir='output/comprehensive_analysis',
    model_configs=None,
    augment_percentage=0.3,
    epochs=100,
    random_seed=42,
    verbose=True
):
    """
    Performs comprehensive analysis of graph augmentation methods.
    
    Args:
        data (torch_geometric.data.Data): The input graph data
        output_dir (str): Root directory for saving results
        model_configs (dict, optional): Model configurations
        augment_percentage (float): Percentage of unknown nodes to add (default: 0.3)
        epochs (int): Number of training epochs
        random_seed (int): Random seed for reproducibility
        verbose (bool): Whether to print detailed information
    
    Returns:
        dict: Dictionary containing all results and comparisons
    """
    import os
    import time
    import pandas as pd
    import numpy as np
    import torch
    import matplotlib.pyplot as plt
    import seaborn as sns
    from datetime import datetime
    
    # Set random seeds
    set_seed_for_torch(random_seed)
    set_seed_for_numpy(random_seed)
    set_seed_for_random(random_seed)
    
    # Create output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"{output_dir}_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)
    
    # Setup default model configurations if not provided
    if model_configs is None:
        model_configs = {
            'input_dim': data.x.shape[1],
            'hidden_dim': 64,
            'output_dim': len(torch.unique(data.y[data.y != 2])),
            'gat_heads': 8
        }
    
    # Create subdirectories
    original_dir = os.path.join(output_dir, "original")
    models_dir = os.path.join(output_dir, "models")
    augmented_dir = os.path.join(output_dir, "augmented")
    comparison_dir = os.path.join(output_dir, "comparison")
    
    for directory in [original_dir, models_dir, augmented_dir, comparison_dir]:
        os.makedirs(directory, exist_ok=True)
    
    # Dictionary to store all results
    results = {
        'original': {},
        'base_models': {},
        'augmented': {},
        'augmented_models': {},
        'comparison': {}
    }
    
    # Step 1: Analyze original graph properties
    if verbose:
        print("\n==== Step 1: Analyzing Original Graph Properties ====")
    
    original_metrics = analyze_graph_properties(
        data, 
        save_dir=os.path.join(original_dir, "graph_properties"),
        verbose=verbose
    )
    results['original']['metrics'] = original_metrics
    
    # Step 2: Train GAT and GraphSAGE on known graph
    if verbose:
        print("\n==== Step 2: Training Base Models ====")
    
    # For GAT model
    gat_args = {
        'model_name': 'GAT',
        'input_dim': model_configs['input_dim'],
        'hidden_dim': model_configs['hidden_dim'],
        'output_dim': model_configs['output_dim'],
        'heads': model_configs['gat_heads']
    }
    
    gat_path = os.path.join(models_dir, "base_gat")
    os.makedirs(gat_path, exist_ok=True)
    
    base_gat_results = train_gnn_model(
        graph_data=data,
        checkpoint_path=os.path.join(gat_path, "base_gat_model.pt"),
        model_args=gat_args,
        num_epochs=epochs,
        verbose=verbose
    )
    
    # For GraphSAGE model
    sage_args = {
        'model_name': 'GraphSAGE',
        'input_dim': model_configs['input_dim'],
        'hidden_dim': model_configs['hidden_dim'],
        'output_dim': model_configs['output_dim']
    }
    
    sage_path = os.path.join(models_dir, "base_sage")
    os.makedirs(sage_path, exist_ok=True)
    
    base_sage_results = train_gnn_model(
        graph_data=data,
        checkpoint_path=os.path.join(sage_path, "base_sage_model.pt"),
        model_args=sage_args,
        num_epochs=epochs,
        verbose=verbose
    )
    
    # Step 3: Inspect results and analyze model properties
    if verbose:
        print("\n==== Step 3: Inspecting Base Model Results ====")
        
    gat_metrics = inspect_model_results(base_gat_results, save_dir=gat_path, model_name="base_gat")
    sage_metrics = inspect_model_results(base_sage_results, save_dir=sage_path, model_name="base_sage")
    
    results['base_models']['gat'] = {
        'model': base_gat_results['model'],
        'metrics': gat_metrics,
        'training_history': base_gat_results['training_history']
    }
    
    results['base_models']['sage'] = {
        'model': base_sage_results['model'],
        'metrics': sage_metrics,
        'training_history': base_sage_results['training_history']
    }
    
    # Analyze graph properties with base models
    gat_model = base_gat_results['model']
    sage_model = base_sage_results['model']
    
    gat_graph_metrics = analyze_graph_properties(
        data, 
        model=gat_model,
        save_dir=os.path.join(gat_path, "graph_properties"),
        verbose=verbose
    )
    
    sage_graph_metrics = analyze_graph_properties(
        data, 
        model=sage_model,
        save_dir=os.path.join(sage_path, "graph_properties"),
        verbose=verbose
    )
    
    results['base_models']['gat']['graph_metrics'] = gat_graph_metrics
    results['base_models']['sage']['graph_metrics'] = sage_graph_metrics
    
    # Step 4: Augment data using different methods
    if verbose:
        print("\n==== Step 4: Augmenting Data with Different Methods ====")
    
    # Method 1: Label propagation (default method)
    lp_augmented_data = augment_labels(data, percentage=augment_percentage)
    
    # Method 2: GAT-based augmentation
    gat_augmented_data = augment_labels(data, percentage=augment_percentage, model=gat_model)
    
    # Method 3: GraphSAGE-based augmentation
    sage_augmented_data = augment_labels(data, percentage=augment_percentage, model=sage_model)
    
    # Store augmented data
    augmented_data = {
        'label_propagation': lp_augmented_data,
        'gat': gat_augmented_data,
        'sage': sage_augmented_data
    }
    results['augmented']['data'] = augmented_data
    
    # Step 5: Analyze graph properties for each augmented method
    if verbose:
        print("\n==== Step 5: Analyzing Augmented Graph Properties ====")
    
    for method_name, aug_data in augmented_data.items():
        method_dir = os.path.join(augmented_dir, method_name)
        os.makedirs(method_dir, exist_ok=True)
        
        # Analyze without model predictions
        aug_metrics = analyze_graph_properties(
            aug_data, 
            save_dir=os.path.join(method_dir, "graph_properties"),
            verbose=verbose
        )
        
        results['augmented'][f'{method_name}_metrics'] = aug_metrics
    
    # Step 6: Train models on augmented graphs
    if verbose:
        print("\n==== Step 6: Training Models on Augmented Graphs ====")
    
    augmented_models = {}
    
    for method_name, aug_data in augmented_data.items():
        method_dir = os.path.join(augmented_dir, method_name)
        
        # Train GAT on augmented data
        gat_aug_path = os.path.join(method_dir, "gat")
        os.makedirs(gat_aug_path, exist_ok=True)
        
        gat_aug_results = train_gnn_model(
            graph_data=aug_data,
            checkpoint_path=os.path.join(gat_aug_path, f"{method_name}_gat_model.pt"),
            model_args=gat_args,
            num_epochs=epochs,
            verbose=verbose
        )
        
        # Train GraphSAGE on augmented data
        sage_aug_path = os.path.join(method_dir, "sage")
        os.makedirs(sage_aug_path, exist_ok=True)
        
        sage_aug_results = train_gnn_model(
            graph_data=aug_data,
            checkpoint_path=os.path.join(sage_aug_path, f"{method_name}_sage_model.pt"),
            model_args=sage_args,
            num_epochs=epochs,
            verbose=verbose
        )
        
        # Inspect model results
        gat_aug_metrics = inspect_model_results(
            gat_aug_results, 
            save_dir=gat_aug_path, 
            model_name=f"{method_name}_gat"
        )
        
        sage_aug_metrics = inspect_model_results(
            sage_aug_results, 
            save_dir=sage_aug_path, 
            model_name=f"{method_name}_sage"
        )
        
        # Store model results
        augmented_models[f'{method_name}_gat'] = {
            'model': gat_aug_results['model'],
            'metrics': gat_aug_metrics,
            'training_history': gat_aug_results['training_history']
        }
        
        augmented_models[f'{method_name}_sage'] = {
            'model': sage_aug_results['model'],
            'metrics': sage_aug_metrics,
            'training_history': sage_aug_results['training_history']
        }
    
    results['augmented_models'] = augmented_models
    
    # Step 7: Analyze graph properties using augmented models
    if verbose:
        print("\n==== Step 7: Analyzing Graph Properties with Augmented Models ====")
    
    for model_name, model_info in augmented_models.items():
        model = model_info['model']
        method_name = model_name.split('_')[0]
        model_type = model_name.split('_')[1]
        
        save_dir = os.path.join(augmented_dir, method_name, model_type, "graph_properties_with_model")
        os.makedirs(save_dir, exist_ok=True)
        
        # Analyze graph properties using this model
        graph_metrics = analyze_graph_properties(
            data,  # Use original data to see how model predicts
            model=model,
            save_dir=save_dir,
            verbose=verbose
        )
        
        results['augmented_models'][model_name]['graph_metrics'] = graph_metrics
    
    # Step 8: Save important metrics
    if verbose:
        print("\n==== Step 8: Saving Important Metrics ====")
    
    # Compile accuracy metrics for all models
    accuracy_metrics = {
        'base_gat': results['base_models']['gat']['metrics'],
        'base_sage': results['base_models']['sage']['metrics']
    }
    
    for model_name in augmented_models.keys():
        accuracy_metrics[model_name] = augmented_models[model_name]['metrics']
    
    # Convert to DataFrame and save
    accuracy_df = pd.DataFrame.from_dict(accuracy_metrics, orient='index')
    accuracy_df.to_csv(os.path.join(comparison_dir, "model_accuracy_comparison.csv"))
    
    # Step 9: Create comprehensive comparison results
    if verbose:
        print("\n==== Step 9: Creating Comprehensive Comparison ====")
    
    # 1. Compare test accuracy across all models
    test_accuracies = {}
    for model_name, metrics in accuracy_metrics.items():
        if isinstance(metrics, dict) and 'Test Accuracy' in metrics:
            test_accuracies[model_name] = float(metrics['Test Accuracy'])
    
    # Create and save accuracy comparison plot
    plt.figure(figsize=(12, 6))
    models = list(test_accuracies.keys())
    accuracies = [test_accuracies[model] for model in models]
    
    # Use different colors for base models vs augmented models
    colors = ['blue' if 'base' in model else 'green' for model in models]
    
    bars = plt.bar(models, accuracies, color=colors)
    plt.title('Test Accuracy Comparison Across Models')
    plt.xlabel('Model')
    plt.ylabel('Test Accuracy')
    plt.xticks(rotation=45, ha='right')
    
    # Add value labels
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.005,
                f'{height:.4f}', ha='center', va='bottom', rotation=0)
    
    plt.tight_layout()
    plt.savefig(os.path.join(comparison_dir, 'test_accuracy_comparison.png'), dpi=300)
    plt.close()
    
    # 2. Compare graph property consistency
    
    # Collect all graph metrics
    all_graph_metrics = {
        'original': results['original']['metrics'],
        'base_gat': results['base_models']['gat']['graph_metrics'],
        'base_sage': results['base_models']['sage']['graph_metrics']
    }
    
    for model_name, model_info in augmented_models.items():
        if 'graph_metrics' in model_info:
            all_graph_metrics[model_name] = model_info['graph_metrics']
    
    # Calculate similarity scores between original and predicted metrics
    similarity_scores = {}
    
    for model_name, metrics in all_graph_metrics.items():
        if model_name == 'original':
            continue
            
        # Compare with original metrics
        similarity = calculate_metrics_similarity(
            all_graph_metrics['original'], 
            metrics,
            save_path=os.path.join(comparison_dir, f"{model_name}_vs_original_similarity.csv")
        )
        
        similarity_scores[model_name] = similarity
    
    # Create and save similarity comparison plot
    plt.figure(figsize=(12, 6))
    models = list(similarity_scores.keys())
    scores = [similarity_scores[model]['overall_similarity'] for model in models]
    
    # Use different colors for base models vs augmented models
    colors = ['blue' if 'base' in model else 'green' for model in models]
    
    bars = plt.bar(models, scores, color=colors)
    plt.title('Graph Property Consistency Scores')
    plt.xlabel('Model')
    plt.ylabel('Similarity Score (higher is better)')
    plt.xticks(rotation=45, ha='right')
    
    # Add value labels
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{height:.4f}', ha='center', va='bottom', rotation=0)
    
    plt.tight_layout()
    plt.savefig(os.path.join(comparison_dir, 'graph_property_similarity.png'), dpi=300)
    plt.close()
    
    # 3. Create summary table with key metrics
    summary_metrics = []
    
    # Base models
    summary_metrics.append({
        'model': 'base_gat',
        'test_accuracy': float(results['base_models']['gat']['metrics']['Test Accuracy']),
        'training_time': float(results['base_models']['gat']['metrics']['Training Time (s)']),
        'consistency_score': similarity_scores['base_gat']['overall_similarity'] if 'base_gat' in similarity_scores else 0
    })
    
    summary_metrics.append({
        'model': 'base_sage',
        'test_accuracy': float(results['base_models']['sage']['metrics']['Test Accuracy']),
        'training_time': float(results['base_models']['sage']['metrics']['Training Time (s)']),
        'consistency_score': similarity_scores['base_sage']['overall_similarity'] if 'base_sage' in similarity_scores else 0
    })
    
    # Augmented models
    for model_name, model_info in augmented_models.items():
        summary_metrics.append({
            'model': model_name,
            'test_accuracy': float(model_info['metrics']['Test Accuracy']),
            'training_time': float(model_info['metrics']['Training Time (s)']),
            'consistency_score': similarity_scores[model_name]['overall_similarity'] if model_name in similarity_scores else 0
        })
    
    # Create DataFrame and save
    summary_df = pd.DataFrame(summary_metrics)
    summary_df.to_csv(os.path.join(comparison_dir, 'summary_metrics.csv'), index=False)
    
    # 4. Find the best model based on combined metrics
    summary_df['combined_score'] = (
        summary_df['test_accuracy'] * 0.6 +  # Weight accuracy more
        summary_df['consistency_score'] * 0.4  # Weight consistency less
    )
    
    best_model = summary_df.sort_values('combined_score', ascending=False).iloc[0]
    
    # Print and save conclusion
    conclusion = f"""
    ===== ANALYSIS CONCLUSION =====
    
    Best Overall Model: {best_model['model']}
    Test Accuracy: {best_model['test_accuracy']:.4f}
    Graph Consistency Score: {best_model['consistency_score']:.4f}
    Combined Score: {best_model['combined_score']:.4f}
    Training Time: {best_model['training_time']:.2f} seconds
    
    Key Findings:
    - The best model for test accuracy was: {summary_df.loc[summary_df['test_accuracy'].idxmax(), 'model']} ({summary_df['test_accuracy'].max():.4f})
    - The best model for graph consistency was: {summary_df.loc[summary_df['consistency_score'].idxmax(), 'model']} ({summary_df['consistency_score'].max():.4f})
    - The fastest model was: {summary_df.loc[summary_df['training_time'].idxmin(), 'model']} ({summary_df['training_time'].min():.2f} seconds)
    
    Recommendation:
    The {best_model['model']} model provides the best balance between prediction accuracy and graph structure preservation.
    """
    
    if verbose:
        print(conclusion)
    
    # Save conclusion
    with open(os.path.join(comparison_dir, 'conclusion.txt'), 'w') as f:
        f.write(conclusion)
    
    # Save overall results dictionary (without model objects which aren't serializable)
    serializable_results = {}
    for key, value in results.items():
        if key in ['original', 'comparison']:
            serializable_results[key] = value
        else:
            serializable_results[key] = {}
            for subkey, subvalue in value.items():
                if isinstance(subvalue, dict) and 'model' in subvalue:
                    serializable_results[key][subkey] = {k: v for k, v in subvalue.items() if k != 'model'}
                else:
                    serializable_results[key][subkey] = subvalue
    
    # Save as pickle if comprehensive results are needed
    import pickle
    with open(os.path.join(output_dir, 'serializable_results.pkl'), 'wb') as f:
        pickle.dump(serializable_results, f)
    
    return results


def calculate_metrics_similarity(original_metrics, predicted_metrics, 
                               save_path=None, key_metrics=None):
    """
    Calculate similarity/consistency between original and predicted graph metrics.
    
    Args:
        original_metrics (dict): Original graph metrics
        predicted_metrics (dict): Predicted graph metrics
        save_path (str, optional): Path to save the comparison
        key_metrics (list, optional): List of key metrics to prioritize
        
    Returns:
        dict: Similarity scores
    """
    import pandas as pd
    import numpy as np
    
    # Define key metrics if not provided
    if key_metrics is None:
        key_metrics = [
            'homophily', 'density', 'avg_clustering', 'avg_degree_centrality',
            'avg_betweenness_centrality', 'largest_component_ratio'
        ]
    
    # Common groups
    common_groups = set(original_metrics.keys()).intersection(set(predicted_metrics.keys()))
    
    # Calculate similarity for each metric
    similarity_scores = {}
    all_scores = []
    
    for group in common_groups:
        group_similarity = {}
        group_scores = []
        
        # Check if group exists in both metrics
        if group in original_metrics and group in predicted_metrics:
            orig_group = original_metrics[group]
            pred_group = predicted_metrics[group]
            
            # Find common metrics
            common_metrics = set(orig_group.keys()).intersection(set(pred_group.keys()))
            
            for metric in common_metrics:
                if metric in ['num_nodes', 'num_edges', 'num_components']:
                    continue  # Skip count metrics
                
                orig_val = orig_group.get(metric, 0)
                pred_val = pred_group.get(metric, 0)
                
                # Calculate similarity (1 - relative difference)
                if orig_val != 0:
                    rel_diff = abs(pred_val - orig_val) / abs(orig_val)
                    similarity = max(0, 1 - min(rel_diff, 1))  # Cap at 0-1 range
                else:
                    # If original is 0, check if prediction is also close to 0
                    similarity = 1 if abs(pred_val) < 0.01 else 0
                
                group_similarity[metric] = similarity
                group_scores.append(similarity)
                
                # If this is a key metric, add it to the overall scores with higher weight
                if metric in key_metrics:
                    all_scores.append(similarity)
                    all_scores.append(similarity)  # Add twice for more weight
                else:
                    all_scores.append(similarity)
        
        # Calculate group average
        if group_scores:
            group_similarity['average'] = np.mean(group_scores)
            similarity_scores[group] = group_similarity
    
    # Calculate overall similarity
    overall_similarity = np.mean(all_scores) if all_scores else 0
    similarity_scores['overall_similarity'] = overall_similarity
    
    # Save comparison if path provided
    if save_path:
        comparison_data = []
        
        for group in common_groups:
            if group in similarity_scores:
                for metric, score in similarity_scores[group].items():
                    if metric != 'average':
                        comparison_data.append({
                            'group': group,
                            'metric': metric,
                            'original_value': original_metrics[group].get(metric, np.nan),
                            'predicted_value': predicted_metrics[group].get(metric, np.nan),
                            'similarity_score': score
                        })
        
        # Convert to DataFrame and save
        df = pd.DataFrame(comparison_data)
        df.to_csv(save_path, index=False)
    
    return similarity_scores

In [31]:
results = comprehensive_graph_analysis(data, epochs=200)


==== Step 1: Analyzing Original Graph Properties ====
Processed train_licit: 3602 nodes, 651 edges
Processed train_illicit: 33649 nodes, 21207 edges
Processed val_licit: 508 nodes, 18 edges
Processed val_illicit: 4148 nodes, 379 edges
Processed test_licit: 435 nodes, 8 edges
Processed test_illicit: 4222 nodes, 394 edges
Processed unknown: 157205 nodes, 131778 edges
Analysis complete. Results saved to output/comprehensive_analysis_20250410_171856/original/graph_properties

==== Step 2: Training Base Models ====
GAT Epoch: 010, Loss: 0.8470, Train Acc: 0.7809, Val Acc: 0.7788
GAT Epoch: 020, Loss: 0.6171, Train Acc: 0.7802, Val Acc: 0.7801
GAT Epoch: 030, Loss: 0.4966, Train Acc: 0.7180, Val Acc: 0.7118
GAT Epoch: 040, Loss: 0.4165, Train Acc: 0.8239, Val Acc: 0.8232
GAT Epoch: 050, Loss: 0.3771, Train Acc: 0.8344, Val Acc: 0.8344
GAT Epoch: 060, Loss: 0.3567, Train Acc: 0.8602, Val Acc: 0.8582
GAT Epoch: 070, Loss: 0.3427, Train Acc: 0.8573, Val Acc: 0.8567
GAT Epoch: 080, Loss: 0.3331

  fig = self.plt.figure(figsize=self.figsize)
  plt.figure(figsize=(12, 8))
  sns.kdeplot(group_data['degree_centrality'], label=group)


Analysis complete. Results saved to output/comprehensive_analysis_20250410_171856/augmented/label/propagation/graph_properties_with_model
Unknown nodes predicted as licit: 4989
Unknown nodes predicted as illicit: 152216
Processed train_licit: 3602 nodes, 651 edges
Processed train_illicit: 33649 nodes, 21207 edges
Processed val_licit: 508 nodes, 18 edges
Processed val_illicit: 4148 nodes, 379 edges
Processed test_licit: 435 nodes, 8 edges
Processed test_illicit: 4222 nodes, 394 edges
Processed unknown: 157205 nodes, 131778 edges
Processed pred_unknown_licit: 4989 nodes, 1299 edges
Processed pred_unknown_illicit: 152216 nodes, 124705 edges
Analysis complete. Results saved to output/comprehensive_analysis_20250410_171856/augmented/label/propagation/graph_properties_with_model
Unknown nodes predicted as licit: 5862
Unknown nodes predicted as illicit: 151343
Processed train_licit: 3602 nodes, 651 edges
Processed train_illicit: 33649 nodes, 21207 edges
Processed val_licit: 508 nodes, 18 edge

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>

<Figure size 1200x800 with 0 Axes>