<a href="https://colab.research.google.com/github/sanatan-dive/blockchain-anomaly-detection/blob/main/Model1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
# Environment setup cell: installs PyTorch from the CUDA 11.8 wheel index, PyTorch
# Geometric plus its optional compiled extensions, and the analysis/plotting stack.
# NOTE(review): no package version is pinned here, while the extension wheels below are
# fetched from the torch-2.0.0+cu118 index - confirm that index matches the torch
# version that actually gets installed.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install torch-geometric
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install scikit-learn pandas numpy matplotlib seaborn

Looking in indexes: https://download.pytorch.org/whl/cu118
Looking in links: https://data.pyg.org/whl/torch-2.0.0+cu118.html


In [32]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix,
    roc_curve, precision_recall_curve
)
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import warnings
import os
# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation/convergence warnings from the libraries
# imported above - consider narrowing the filter.
warnings.filterwarnings('ignore')
# NOTE(review): presumably a debugging aid that makes CUDA kernel launches synchronous
# (at a speed cost) - confirm it is still wanted for normal training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Set random seeds for reproducibility (PyTorch and NumPy global generators)
torch.manual_seed(42)
np.random.seed(42)

# Check GPU availability and report which device will be used
print(f"🚀 Using device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
if torch.cuda.is_available():
    # Device index 0 is queried; other GPUs (if any) are not reported
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

class BlockchainGCN(nn.Module):
    """
    Three-layer graph convolutional classifier for transaction nodes.

    Two hidden GCN layers, each followed by batch normalisation, ReLU and dropout,
    feed a final GCN layer that produces one score per class. ``forward`` returns
    the row-wise log-softmax of those scores.
    """
    def __init__(self, num_features, hidden_dim=64, num_classes=2, dropout=0.5):
        super().__init__()
        # Graph convolutions: input -> hidden -> hidden -> classes
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, num_classes)
        # Dropout probability applied after each hidden layer
        self.dropout = dropout
        # One batch-norm module per hidden layer
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim)
        self.batch_norm2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities for node features ``x`` on ``edge_index``."""
        hidden_stack = (
            (self.conv1, self.batch_norm1),
            (self.conv2, self.batch_norm2),
        )

        h = x
        for conv, norm in hidden_stack:
            # conv -> batch norm -> ReLU -> dropout (dropout only active in training mode)
            h = F.relu(norm(conv(h, edge_index)))
            h = F.dropout(h, p=self.dropout, training=self.training)

        # Output layer: class scores normalised into log-probabilities
        logits = self.conv3(h, edge_index)
        return F.log_softmax(logits, dim=1)

class BlockchainAnomalyDetector:
    """
    Complete pipeline for blockchain anomaly detection using GNN
    """
    def __init__(self, hidden_dim=64, dropout=0.5, lr=0.01, weight_decay=5e-4):
        # Hyper-parameters consumed later by build_model()
        self.hidden_dim, self.dropout = hidden_dim, dropout
        self.lr, self.weight_decay = lr, weight_decay

        # The network is created by build_model(); the scaler is fitted in _prepare_features()
        self.model = None
        self.scaler = StandardScaler()

        # Prefer the GPU when one is visible, otherwise fall back to the CPU
        target = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(target)
        print(f"🔧 Detector initialized on {self.device}")

    def load_and_prepare_data(self, transactions_file='elliptic_txs_features.csv',
                            edges_file='elliptic_txs_edgelist.csv',
                            classes_file='elliptic_txs_classes.csv'):
        """
        Load and prepare blockchain transaction data from Elliptic dataset
        """
        print("🔄 Loading and preparing data...")

        # Load data
        if transactions_file and edges_file and classes_file and os.path.exists(transactions_file) and os.path.exists(edges_file) and os.path.exists(classes_file):
            try:
                # Load transactions with header=None and set the first column as 'txId'
                self.transactions = pd.read_csv(transactions_file, header=None)
                self.transactions.columns = ['txId'] + [f'feature_{i}' for i in range(self.transactions.shape[1] - 1)]
                self.edges = pd.read_csv(edges_file)
                self.classes = pd.read_csv(classes_file)
                print(f"✅ Loaded {len(self.transactions)} transactions, {len(self.edges)} edges, and {len(self.classes)} class labels")

                # Debug: Print column names and detailed class labels
                print("Transactions columns:", self.transactions.columns.tolist())
                print("Classes columns:", self.classes.columns.tolist())
                print("Sample class labels before merge (unique values):", self.classes['class'].unique().tolist())
                print("Sample class labels before merge (value counts):", self.classes['class'].value_counts().to_dict())

                # Identify the correct merge key
                tx_id_col_transactions = 'txId'
                tx_id_col_classes = 'txId' if 'txId' in self.classes.columns else self.classes.columns[0]

                print(f"Merging using: {tx_id_col_transactions} (transactions) and {tx_id_col_classes} (classes)")

                # Merge transactions with class labels using the identified columns
                self.transactions = self.transactions.merge(
                    self.classes,
                    left_on=tx_id_col_transactions,
                    right_on=tx_id_col_classes,
                    how='left'
                )

                # Debug: Check class distribution after merge
                print("Class distribution after merge (before filling):", self.transactions['class'].value_counts().to_dict())

                # Drop the duplicate ID column from classes if it exists and differs from 'txId'
                if tx_id_col_classes != 'txId' and tx_id_col_classes in self.transactions.columns:
                    self.transactions = self.transactions.drop(columns=[tx_id_col_classes])

                # Verify the resulting columns
                print("Merged transactions columns:", self.transactions.columns.tolist())

            except Exception as e:
                print(f"⚠️  Error loading files: {e}")
                raise
        else:
            print("⚠️  Files not provided or not found. Please provide valid transaction, edge, and class files.")
            raise ValueError("Missing or invalid input files")

        # Data cleaning and preparation
        self._clean_data()
        self._prepare_features()
        self._create_graph()

        return self

    def _clean_data(self):
        """
        Clean and validate the data
        """
        print("🧹 Cleaning data...")

        # Remove duplicates using the 'txId' column
        initial_count = len(self.transactions)
        self.transactions = self.transactions.drop_duplicates(subset=['txId'])
        self.edges = self.edges.drop_duplicates()

        # Handle missing values (fill with median for numeric columns)
        numeric_cols = self.transactions.select_dtypes(include=[np.number]).columns
        self.transactions[numeric_cols] = self.transactions[numeric_cols].fillna(
            self.transactions[numeric_cols].median()
        )

        # Only apply outlier removal if columns exist
        for col in ['amount', 'gas_price', 'gas_limit']:
            if col in self.transactions.columns:
                Q1 = self.transactions[col].quantile(0.01)
                Q3 = self.transactions[col].quantile(0.99)
                self.transactions = self.transactions[
                    (self.transactions[col] >= Q1) & (self.transactions[col] <= Q3)
                ]

        print(f"✅ Data cleaned: {initial_count} → {len(self.transactions)} transactions")

    def _prepare_features(self):
        """
        Prepare features for the GNN model
        """
        print("🔧 Preparing features...")

        # Handle NaN in class column
        if self.transactions['class'].isna().any():
            self.transactions['class'] = self.transactions['class'].fillna('unknown')
            print(f"⚠️ Filled NaN class labels with 'unknown'.")

        # Convert class column to string for consistent mapping
        self.transactions['class'] = self.transactions['class'].astype(str)

        # Enhanced class mapping to handle various formats
        class_mapping = {
            'unknown': -1,
            'licit': 0,
            'illicit': 1,
            '1': 0,  # Assuming '1' means licit in Elliptic dataset
            '2': 1,  # Assuming '2' means illicit in Elliptic dataset
            '-1': -1,
            '-1.0': -1,
            'nan': -1
        }

        # Map class labels
        self.transactions['class'] = self.transactions['class'].map(class_mapping)

        # Drop rows where mapping failed (if any)
        unmapped = self.transactions['class'].isna()
        if unmapped.any():
            print(f"⚠️ Found {unmapped.sum()} rows with unmapped class labels. Dropping them.")
            print(f"Unmapped values: {self.transactions.loc[unmapped, 'class'].unique()}")
            self.transactions = self.transactions[~unmapped]

        # Select features (exclude non-numeric columns)
        feature_cols = [col for col in self.transactions.columns
                       if col not in ['txId', 'class', 'time_step'] and
                       self.transactions[col].dtype in [np.number, 'float64', 'int64']]

        if len(feature_cols) == 0:
            raise ValueError("No numeric feature columns found!")

        self.X = self.transactions[feature_cols].values
        self.y = self.transactions['class'].values

        # Filter out unknown class (-1) and store known_mask
        self.known_mask = (self.y != -1)
        if not self.known_mask.any():
            raise ValueError("❌ No known labels (0 or 1) found in the dataset.")

        print(f"📊 Before filtering: {len(self.y)} samples")
        print(f"📊 After filtering unknown: {self.known_mask.sum()} samples")

        self.X = self.X[self.known_mask]
        self.y = self.y[self.known_mask]

        # Handle infinite values
        self.X = np.nan_to_num(self.X, nan=0.0, posinf=0.0, neginf=0.0)

        # Normalize features
        self.X = self.scaler.fit_transform(self.X)

        # Convert to tensors
        self.X = torch.FloatTensor(self.X)
        self.y = torch.LongTensor(self.y)

        # Compute class weights for handling imbalanced data
        class_counts = Counter(self.y.numpy())
        total = len(self.y)
        self.class_weights = torch.FloatTensor([
            total / (2 * class_counts.get(0, 1)),
            total / (2 * class_counts.get(1, 1))
        ]).to(self.device)

        print(f"📊 Final class distribution: {dict(class_counts)}")
        print(f"⚖️ Class weights: {self.class_weights.cpu().numpy()}")
        print(f"📐 Feature matrix shape: {self.X.shape}")

    def _create_graph(self):
        """
        Assemble the PyTorch Geometric ``Data`` object used for training.

        Labelled transactions become nodes, numbered in the order they appear in
        ``self.transactions``. Edges are kept only when both endpoints are
        labelled and are mirrored so the graph is undirected; if none survive, a
        simple chain over the nodes is used instead. A stratified 60/20/20
        train/validation/test split is stored as boolean node masks.

        Raises:
            ValueError: if the edge table has fewer than two columns, or fewer
                than 10 labelled samples are available for splitting.
        """
        print("🕸️  Creating graph structure...")

        # Node id -> row position within self.X / self.y
        known_tx_ids = self.transactions.loc[self.known_mask, 'txId'].values
        tx_id_to_idx = dict(zip(known_tx_ids, range(len(known_tx_ids))))

        # Debug: Check the number of known txIds and their presence in edges
        print(f"Debug: Number of known txIds: {len(known_tx_ids)}")

        # Fall back to the first two columns when the expected names are absent
        has_expected_cols = 'txId1' in self.edges.columns and 'txId2' in self.edges.columns
        if not has_expected_cols:
            edge_cols = self.edges.columns.tolist()
            if len(edge_cols) < 2:
                raise ValueError("Edge file must have at least 2 columns for source and target nodes")
            self.edges = self.edges.rename(columns={edge_cols[0]: 'txId1', edge_cols[1]: 'txId2'})

        edge_tx_ids = set(self.edges['txId1']) | set(self.edges['txId2'])
        known_edge_overlap = len(edge_tx_ids & set(known_tx_ids))
        print(f"Debug: Number of known txIds overlapping with edges: {known_edge_overlap}")

        # Keep edges whose two endpoints are both labelled, expressed as node indices
        valid_edges = [
            (tx_id_to_idx[src], tx_id_to_idx[dst])
            for src, dst in zip(self.edges['txId1'], self.edges['txId2'])
            if src in tx_id_to_idx and dst in tx_id_to_idx
        ]

        if not valid_edges:
            print("⚠️ No valid edges found. Creating a simple sequential graph structure.")
            # Chain graph: node i is linked to node i + 1
            node_count = len(known_tx_ids)
            valid_edges = list(zip(range(node_count - 1), range(1, node_count)))

        # Undirected graph: every edge is immediately followed by its reverse
        undirected_edges = [
            pair
            for src, dst in valid_edges
            for pair in ((src, dst), (dst, src))
        ]

        edge_index = torch.LongTensor(undirected_edges).t().contiguous()

        # Move tensors to device
        x = self.X.to(self.device)
        y = self.y.to(self.device)
        edge_index = edge_index.to(self.device)

        # Stratified split over the labelled node positions
        known_indices = np.arange(len(self.y))

        # Ensure we have enough samples for splitting
        if len(known_indices) < 10:
            raise ValueError("Not enough samples for train/validation/test split")

        # 60% train, then the remaining 40% is halved into validation and test
        train_idx, temp_idx = train_test_split(
            known_indices,
            test_size=0.4,
            stratify=self.y,
            random_state=42
        )
        val_idx, test_idx = train_test_split(
            temp_idx,
            test_size=0.5,
            stratify=self.y[temp_idx],
            random_state=42
        )

        def _mask_for(indices):
            # Boolean node mask that is True exactly at the given positions
            mask = torch.zeros(len(self.y), dtype=torch.bool, device=self.device)
            mask[indices] = True
            return mask

        train_mask = _mask_for(train_idx)
        val_mask = _mask_for(val_idx)
        test_mask = _mask_for(test_idx)

        # Create PyG Data object
        self.data = Data(
            x=x, y=y, edge_index=edge_index,
            train_mask=train_mask, val_mask=val_mask, test_mask=test_mask
        )

        print(f"✅ Graph created: {self.data.num_nodes} nodes, {self.data.num_edges} edges")
        print(f"   Train: {train_mask.sum()}, Val: {val_mask.sum()}, Test: {test_mask.sum()}")

    def build_model(self):
        """
        Instantiate the GCN and its Adam optimizer on the detector's device.

        The input width is taken from ``self.data``; hidden size, dropout,
        learning rate and weight decay come from the constructor arguments.

        Returns:
            self, so calls can be chained.
        """
        print("🏗️  Building GCN model...")

        network = BlockchainGCN(
            num_features=self.data.num_node_features,
            hidden_dim=self.hidden_dim,
            dropout=self.dropout,
        )
        self.model = network.to(self.device)

        self.optimizer = Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # Collect (size, trainable?) once, then derive both totals from it
        param_info = [(p.numel(), p.requires_grad) for p in self.model.parameters()]
        total_params = sum(size for size, _ in param_info)
        trainable_params = sum(size for size, trainable in param_info if trainable)

        print(f"✅ Model built:")
        print(f"   Total parameters: {total_params:,}")
        print(f"   Trainable parameters: {trainable_params:,}")

        return self

    def train(self, epochs=300, verbose=True):
        """
        Train the GCN model
        """
        print("🏋️‍♂️ Starting training...")

        self.model.train()
        criterion = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        train_losses = []
        val_losses = []

        for epoch in range(epochs):
            # Training step
            self.model.train()
            self.optimizer.zero_grad()

            out = self.model(self.data.x, self.data.edge_index)
            loss = criterion(out[self.data.train_mask], self.data.y[self.data.train_mask])

            loss.backward()
            self.optimizer.step()
            train_losses.append(loss.item())

            # Validation step
            self.model.eval()
            with torch.no_grad():
                val_out = self.model(self.data.x, self.data.edge_index)
                val_loss = criterion(val_out[self.data.val_mask], self.data.y[self.data.val_mask])
                val_losses.append(val_loss.item())

            if verbose and (epoch + 1) % 50 == 0:
                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")

        print("✅ Training completed!")

        # Save model
        torch.save(self.model.state_dict(), '/content/best_gnn_model.pth')
        print("💾 Model saved to '/content/best_gnn_model.pth'")

        return train_losses, val_losses

    def evaluate(self):
        """
        Score the trained model on the held-out test nodes.

        Prints accuracy, precision, recall, F1, ROC-AUC and PR-AUC (the two AUC
        values are NaN when the test labels contain a single class), prints the
        confusion matrix and writes the per-node predictions to CSV.

        Returns:
            (metrics, y_true, y_pred, y_prob): the metric dict, the true labels,
            the predicted labels and the predicted probability of class 1 for
            every test node.
        """
        print("📊 Evaluating model...")

        self.model.eval()
        with torch.no_grad():
            # The model emits log-probabilities; exp() recovers probabilities
            log_probs = self.model(self.data.x, self.data.edge_index)
            held_out = self.data.test_mask

            y_true = self.data.y[held_out].cpu().numpy()
            y_pred = log_probs.argmax(dim=1)[held_out].cpu().numpy()
            y_prob = torch.exp(log_probs)[held_out, 1].cpu().numpy()  # Probability of illicit class

        # Ranking metrics are undefined when only one class is present
        both_classes_present = len(np.unique(y_true)) > 1
        metrics = {
            'accuracy': accuracy_score(y_true, y_pred),
            'precision': precision_score(y_true, y_pred, zero_division=0),
            'recall': recall_score(y_true, y_pred, zero_division=0),
            'f1': f1_score(y_true, y_pred, zero_division=0),
            'roc_auc': roc_auc_score(y_true, y_prob) if both_classes_present else np.nan,
            'pr_auc': average_precision_score(y_true, y_prob) if both_classes_present else np.nan
        }

        # Print results
        print("\n🎯 Model Performance:")
        print("=" * 40)
        for metric, value in metrics.items():
            if not np.isnan(value):
                print(f"{metric.upper():>12}: {value:.4f}")
            else:
                print(f"{metric.upper():>12}: NaN")

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        print(f"\n🔍 Confusion Matrix:")
        print("=" * 25)
        if cm.shape != (2, 2):
            print(f"Confusion Matrix:\n{cm}")
        else:
            print(f"True Negative:  {cm[0,0]:4d} | False Positive: {cm[0,1]:4d}")
            print(f"False Negative: {cm[1,0]:4d} | True Positive:  {cm[1,1]:4d}")

        # Save predictions
        self._save_predictions(y_true, y_pred, y_prob)

        return metrics, y_true, y_pred, y_prob

    def _save_predictions(self, y_true, y_pred, y_prob):
        """
        Save predictions to CSV
        """
        test_indices = torch.where(self.data.test_mask)[0].cpu().numpy()

        results_df = pd.DataFrame({
            'transaction_id': test_indices,
            'true_label': y_true,
            'predicted_label': y_pred,
            'illicit_probability': y_prob,
            'is_flagged': (y_prob > 0.5).astype(int),
            'confidence': np.maximum(y_prob, 1 - y_prob)
        })

        # Sort by illicit probability (most suspicious first)
        results_df = results_df.sort_values('illicit_probability', ascending=False)

        results_df.to_csv('/content/gnn_flagged_transactions.csv', index=False)
        print("💾 Predictions saved to '/content/gnn_flagged_transactions.csv'")

        # Show top 10 most suspicious transactions
        print("\n🚨 Top 10 Most Suspicious Transactions:")
        print(results_df.head(10)[['transaction_id', 'true_label', 'illicit_probability', 'is_flagged']])

    def visualize_results(self, y_true, y_pred, y_prob, output_dir='/content'):
        """
        Render the evaluation figures and save each one as a PNG.

        Figures written to ``output_dir``: ``roc_curve.png`` and
        ``precision_recall_curve.png`` (only when both classes occur in
        ``y_true``), ``confusion_matrix.png``, ``prediction_distribution.png``,
        ``feature_importance.png`` and ``performance_metrics.png``.

        The confusion matrix is always computed for labels ``[0, 1]`` so it stays
        2x2 - and matches the Licit/Illicit tick labels - even when the test set
        contains a single class. Every figure is closed after use, including
        when plotting fails.

        Args:
            y_true: true labels of the test nodes (array).
            y_pred: predicted labels of the test nodes (array).
            y_prob: predicted probability of class 1 for the test nodes (array).
            output_dir: directory for the PNG files; created if missing.
        """
        print("📈 Creating visualizations...")

        os.makedirs(output_dir, exist_ok=True)

        def _target(filename):
            # Full path of one output figure
            return os.path.join(output_dir, filename)

        # ROC / PR curves and the AUC metrics need both classes in the test set
        has_both_classes = len(np.unique(y_true)) > 1
        if not has_both_classes:
            print("⚠️ Cannot create ROC/PR curves: only one class present in test set")

        # 1. ROC Curve
        if has_both_classes:
            plt.figure(figsize=(8, 6))
            try:
                fpr, tpr, _ = roc_curve(y_true, y_prob)
                roc_auc = roc_auc_score(y_true, y_prob)
                plt.plot(fpr, tpr, linewidth=2, label=f'ROC Curve (AUC = {roc_auc:.3f})')
                plt.plot([0, 1], [0, 1], 'k--', alpha=0.6, label='Random Classifier')
                plt.fill_between(fpr, tpr, alpha=0.2)
                plt.xlabel('False Positive Rate')
                plt.ylabel('True Positive Rate')
                plt.title('🎯 ROC Curve')
                plt.legend()
                plt.grid(True, alpha=0.3)
                plt.savefig(_target('roc_curve.png'), dpi=300, bbox_inches='tight')
            except Exception as e:
                print(f"Error plotting ROC Curve: {e}")
            finally:
                # Release the figure on failure too, so it does not leak
                plt.close()
        else:
            print("Skipping ROC Curve due to single class in test set")

        # 2. Precision-Recall Curve
        if has_both_classes:
            plt.figure(figsize=(8, 6))
            try:
                precision, recall, _ = precision_recall_curve(y_true, y_prob)
                pr_auc = average_precision_score(y_true, y_prob)
                plt.plot(recall, precision, linewidth=2, label=f'PR Curve (AUC = {pr_auc:.3f})')
                plt.fill_between(recall, precision, alpha=0.2)
                plt.xlabel('Recall')
                plt.ylabel('Precision')
                plt.title('📊 Precision-Recall Curve')
                plt.legend()
                plt.grid(True, alpha=0.3)
                plt.savefig(_target('precision_recall_curve.png'), dpi=300, bbox_inches='tight')
            except Exception as e:
                print(f"Error plotting Precision-Recall Curve: {e}")
            finally:
                plt.close()
        else:
            print("Skipping Precision-Recall Curve due to single class in test set")

        # 3. Confusion Matrix (fixed label set keeps the matrix 2x2)
        plt.figure(figsize=(8, 6))
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        # Row-normalise; the epsilon guards against a class with no true samples
        cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-8)
        sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
                    xticklabels=['Licit', 'Illicit'], yticklabels=['Licit', 'Illicit'])
        plt.title('🔍 Confusion Matrix (Normalized)')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.savefig(_target('confusion_matrix.png'), dpi=300, bbox_inches='tight')
        plt.close()

        # 4. Prediction Distribution
        plt.figure(figsize=(8, 6))
        if has_both_classes:
            plt.hist(y_prob[y_true == 0], bins=30, alpha=0.7, label='Licit', density=True, color='skyblue')
            plt.hist(y_prob[y_true == 1], bins=30, alpha=0.7, label='Illicit', density=True, color='salmon')
        else:
            unique_class = np.unique(y_true)[0]
            plt.hist(y_prob, bins=30, alpha=0.7, label=f'Class {unique_class}', density=True, color='gray')
        plt.axvline(x=0.5, color='red', linestyle='--', alpha=0.8, label='Decision Threshold')
        plt.xlabel('Predicted Probability (Illicit)')
        plt.ylabel('Density')
        plt.title('📈 Prediction Probability Distribution')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(_target('prediction_distribution.png'), dpi=300, bbox_inches='tight')
        plt.close()

        # 5. Feature Importance
        plt.figure(figsize=(10, 6))
        try:
            # Extract weights from the first GCN layer as a proxy for feature importance
            self.model.eval()
            with torch.no_grad():
                weights = self.model.conv1.lin.weight.data.cpu().numpy()
                # Compute absolute mean weights across output features
                feature_importance = np.mean(np.abs(weights), axis=0)
                # Normalize to sum to 1 for relative importance
                feature_importance = feature_importance / np.sum(feature_importance)

            # Numeric, non-id, non-label columns of the transaction table
            excluded = {'txId', 'class', 'time_step'}
            feature_cols = [
                col for col in self.transactions.select_dtypes(include=[np.number]).columns
                if col not in excluded
            ]
            feature_names = feature_cols[:len(feature_importance)]  # Match length of importance scores

            # Create bar plot
            plt.bar(range(len(feature_importance)), feature_importance, color='teal')
            plt.xlabel('Feature Index')
            plt.ylabel('Relative Importance')
            plt.title('🎯 Feature Importance')
            plt.xticks(range(len(feature_importance)), feature_names, rotation=90, fontsize=8)
            plt.tight_layout()
            plt.savefig(_target('feature_importance.png'), dpi=300, bbox_inches='tight')
            plt.close()
        except Exception as e:
            # Best effort: still emit a placeholder image so the file list stays complete
            print(f"Error plotting Feature Importance: {e}")
            plt.text(0.5, 0.5, 'Feature Importance\n(Computation Failed)',
                    ha='center', va='center', fontsize=12)
            plt.title('🎯 Feature Importance')
            plt.axis('off')
            plt.savefig(_target('feature_importance.png'), dpi=300, bbox_inches='tight')
            plt.close()

        # 6. Performance Metrics Bar Chart
        plt.figure(figsize=(10, 6))
        metrics_names = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
        metrics_values = [
            accuracy_score(y_true, y_pred),
            precision_score(y_true, y_pred, zero_division=0),
            recall_score(y_true, y_pred, zero_division=0),
            f1_score(y_true, y_pred, zero_division=0)
        ]
        if has_both_classes:
            metrics_names.extend(['ROC-AUC', 'PR-AUC'])
            metrics_values.extend([
                roc_auc_score(y_true, y_prob),
                average_precision_score(y_true, y_prob)
            ])
        colors = plt.cm.Set3(np.linspace(0, 1, len(metrics_names)))
        bars = plt.bar(metrics_names, metrics_values, color=colors)
        plt.ylim(0, 1)
        plt.title('📋 Performance Metrics Summary')
        plt.ylabel('Score')
        plt.xticks(rotation=45)
        plt.grid(True, alpha=0.3)
        for bar, value in zip(bars, metrics_values):
            if not np.isnan(value):
                plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                        f'{value:.3f}', ha='center', va='bottom', fontweight='bold')
        plt.tight_layout()
        plt.savefig(_target('performance_metrics.png'), dpi=300, bbox_inches='tight')
        plt.close()

        print(f"💾 Individual visualizations saved to '{output_dir}/':")
        print("   - roc_curve.png (if applicable)")
        print("   - precision_recall_curve.png (if applicable)")
        print("   - confusion_matrix.png")
        print("   - prediction_distribution.png")
        print("   - feature_importance.png")
        print("   - performance_metrics.png")

def main():
    """
    Run the whole pipeline: load data, build, train, evaluate and visualise.

    Any exception raised along the way is reported (message plus traceback)
    instead of being propagated, so the function always returns ``None``.
    """
    print("🚀 Starting Blockchain Anomaly Detection Pipeline")
    print("=" * 60)

    try:
        # Initialize detector
        detector = BlockchainAnomalyDetector(
            hidden_dim=128,
            dropout=0.3,
            lr=0.005,
            weight_decay=5e-4
        )

        # Load and prepare data
        detector.load_and_prepare_data()

        # Build model
        detector.build_model()

        # Train model
        train_losses, val_losses = detector.train(epochs=200)

        # Evaluate model
        metrics, y_true, y_pred, y_prob = detector.evaluate()

        # Create visualizations
        detector.visualize_results(y_true, y_pred, y_prob)

        # List the files the steps above write (the pipeline saves individual
        # figures; there is no combined evaluation image)
        print("\n🎉 Pipeline completed successfully!")
        print("\nFiles created:")
        print("- /content/best_gnn_model.pth")
        print("- /content/gnn_flagged_transactions.csv")
        print("- /content/roc_curve.png (if applicable)")
        print("- /content/precision_recall_curve.png (if applicable)")
        print("- /content/confusion_matrix.png")
        print("- /content/prediction_distribution.png")
        print("- /content/feature_importance.png")
        print("- /content/performance_metrics.png")

    except Exception as e:
        # Deliberate best effort: report the failure without crashing the notebook cell
        print(f"❌ Error in pipeline: {str(e)}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()

    # Remind the user which artefacts the run is expected to have produced
    print("\n💾 Files available for download:")
    print("\n".join([
        "   - /content/best_gnn_model.pth (trained model)",
        "   - /content/gnn_flagged_transactions.csv (predictions)",
        "   - /content/roc_curve.png (ROC Curve, if applicable)",
        "   - /content/precision_recall_curve.png (Precision-Recall Curve, if applicable)",
        "   - /content/confusion_matrix.png (Confusion Matrix)",
        "   - /content/prediction_distribution.png (Prediction Distribution)",
        "   - /content/feature_importance.png (Feature Importance)",
        "   - /content/performance_metrics.png (Performance Metrics)",
    ]))

    # Uncomment to download files to the local machine.
    # NOTE(review): `files` is not imported anywhere in the visible source -
    # presumably `from google.colab import files` is needed first.
    # files.download('/content/best_gnn_model.pth')
    # files.download('/content/gnn_flagged_transactions.csv')
    # files.download('/content/roc_curve.png')
    # files.download('/content/precision_recall_curve.png')
    # files.download('/content/confusion_matrix.png')
    # files.download('/content/prediction_distribution.png')
    # files.download('/content/feature_importance.png')
    # files.download('/content/performance_metrics.png')

🚀 Using device: cuda
   GPU: Tesla T4
🚀 Starting Blockchain Anomaly Detection Pipeline
🔧 Detector initialized on cuda
🔄 Loading and preparing data...
✅ Loaded 203769 transactions, 234355 edges, and 203769 class labels
Transactions columns: ['txId', 'feature_0', 'feature_1', 'feature_2', 'feature_3', 'feature_4', 'feature_5', 'feature_6', 'feature_7', 'feature_8', 'feature_9', 'feature_10', 'feature_11', 'feature_12', 'feature_13', 'feature_14', 'feature_15', 'feature_16', 'feature_17', 'feature_18', 'feature_19', 'feature_20', 'feature_21', 'feature_22', 'feature_23', 'feature_24', 'feature_25', 'feature_26', 'feature_27', 'feature_28', 'feature_29', 'feature_30', 'feature_31', 'feature_32', 'feature_33', 'feature_34', 'feature_35', 'feature_36', 'feature_37', 'feature_38', 'feature_39', 'feature_40', 'feature_41', 'feature_42', 'feature_43', 'feature_44', 'feature_45', 'feature_46', 'feature_47', 'feature_48', 'feature_49', 'feature_50', 'feature_51', 'feature_52', 'feature_53', 'feat