# GNN Cascade Failure Prediction - Production Scale

## Engineering: PyTorch Geometric GCN for Grid Resilience

This notebook implements a production-grade Graph Convolutional Network (GCN) for cascade failure prediction,
designed to run on Snowpark Container Services with GPU acceleration.

**Key Features:**
- 3-layer GCN architecture (10 → 64 → 64 → 32 → 1)
- BFS-based cascade simulation (wave depths are computed per start node, then collapsed into binary affected/unaffected labels)
- Graph centrality features integration
- Snowflake ML model registry integration

**Based on:** Original GNN demo at `/Documents/gnn_resilient_energy_digital_twin/notebooks/grid_cascade_analysis.ipynb`

In [None]:
# Install dependencies (run once)
# !pip install torch torch-geometric snowflake-ml-python pandas numpy scikit-learn

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score
from snowflake.snowpark import Session
from snowflake.ml.registry import Registry
import warnings
# Silence library warnings for the rest of the notebook session
warnings.filterwarnings('ignore')

# Report the framework version and whether a CUDA device is visible
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Connect to Snowflake and Load Grid Data

In [None]:
# Open a Snowpark session from a named connection profile.
# SNOWFLAKE_CONNECTION_NAME overrides the "cpe_demo_CLI" default.
connection_name = os.environ.get("SNOWFLAKE_CONNECTION_NAME", "cpe_demo_CLI")
session_builder = Session.builder.config("connection_name", connection_name)
session = session_builder.create()

# Echo the session context so it is obvious which account/database is in use
current_account = session.get_current_account()
print(f"Connected to: {current_account}")
current_database = session.get_current_database()
print(f"Database: {current_database}")
current_warehouse = session.get_current_warehouse()
print(f"Warehouse: {current_warehouse}")

In [None]:
# Load grid nodes with centrality features.
# Nodes without coordinates are excluded by the WHERE clause; centrality
# columns default to 0 for nodes with no row in NODE_CENTRALITY_FEATURES.
nodes_query = """
SELECT 
    n.NODE_ID,
    n.NODE_TYPE,
    n.LAT,
    n.LON,
    n.CAPACITY_KW,
    n.VOLTAGE_KV,
    n.CRITICALITY_SCORE,
    n.DOWNSTREAM_TRANSFORMERS,
    n.DOWNSTREAM_CAPACITY_KVA,
    COALESCE(c.DEGREE_CENTRALITY, 0) AS DEGREE_CENTRALITY,
    COALESCE(c.NORMALIZED_DEGREE, 0) AS NORMALIZED_DEGREE,
    COALESCE(c.REACH_EXPANSION_RATIO, 0) AS REACH_EXPANSION_RATIO,
    COALESCE(c.NEIGHBORS_1HOP, 0) AS NEIGHBORS_1HOP,
    COALESCE(c.NEIGHBORS_2HOP, 0) AS NEIGHBORS_2HOP
FROM SI_DEMOS.ML_DEMO.GRID_NODES n
LEFT JOIN SI_DEMOS.CASCADE_ANALYSIS.NODE_CENTRALITY_FEATURES c ON n.NODE_ID = c.NODE_ID
WHERE n.LAT IS NOT NULL AND n.LON IS NOT NULL
"""

nodes_df = session.sql(nodes_query).to_pandas()

# The LEFT JOIN returns several rows for one node if the centrality table ever
# holds duplicate NODE_IDs. Later cells map NODE_ID -> row position, which only
# works with exactly one row per node and a 0..N-1 index, so enforce both here.
duplicate_rows = int(nodes_df['NODE_ID'].duplicated().sum())
if duplicate_rows:
    print(f"Warning: dropping {duplicate_rows} duplicate NODE_ID rows")
    nodes_df = nodes_df.drop_duplicates(subset='NODE_ID', keep='first')
nodes_df = nodes_df.reset_index(drop=True)

print(f"Loaded {len(nodes_df)} nodes")
nodes_df.head()

In [None]:
# Pull the full edge list (endpoints plus type, length and impedance columns)
edges_query = """
SELECT 
    EDGE_ID,
    FROM_NODE,
    TO_NODE,
    EDGE_TYPE,
    DISTANCE_KM,
    IMPEDANCE_PU
FROM SI_DEMOS.ML_DEMO.GRID_EDGES
"""

edges_result = session.sql(edges_query)
edges_df = edges_result.to_pandas()
edge_count = len(edges_df)
print(f"Loaded {edge_count} edges")
edges_df.head()

## 2. Build PyTorch Geometric Graph

In [None]:
# Map each NODE_ID to its row position in nodes_df; that position is the
# node's index in the graph tensors built below
node_id_to_idx = dict(zip(nodes_df['NODE_ID'], range(len(nodes_df))))
print(f"Node ID mapping created: {len(node_id_to_idx)} nodes")

# Keep only edges whose two endpoints are both present in the node set
from_known = edges_df['FROM_NODE'].isin(node_id_to_idx)
to_known = edges_df['TO_NODE'].isin(node_id_to_idx)
valid_edges = edges_df[from_known & to_known]
print(f"Valid edges: {len(valid_edges)} / {len(edges_df)}")

In [None]:
# Translate edge endpoints from NODE_IDs to graph indices
source_nodes = list(map(node_id_to_idx.__getitem__, valid_edges['FROM_NODE']))
target_nodes = list(map(node_id_to_idx.__getitem__, valid_edges['TO_NODE']))

# COO edge index, shape (2, 2 * num_valid_edges): the directed edges first,
# followed by the same edges reversed so the graph is treated as undirected
directed_edges = torch.tensor([source_nodes, target_nodes], dtype=torch.long)
edge_index = torch.cat([directed_edges, directed_edges.flip(0)], dim=1)
print(f"Edge index shape: {edge_index.shape}")

In [None]:
# Build node feature matrix (10 features)
feature_columns = [
    'CAPACITY_KW', 'VOLTAGE_KV', 'CRITICALITY_SCORE',
    'DOWNSTREAM_TRANSFORMERS', 'DOWNSTREAM_CAPACITY_KVA',
    'DEGREE_CENTRALITY', 'NORMALIZED_DEGREE', 'REACH_EXPANSION_RATIO',
    'NEIGHBORS_1HOP', 'NEIGHBORS_2HOP'
]

# Fill NaN with 0 and force a float64 matrix. The explicit cast guards against
# object-dtype columns (e.g. Decimal values coming back from the driver), which
# would otherwise fail later in torch.tensor().
features = nodes_df[feature_columns].fillna(0).astype(np.float64).to_numpy()

# Z-score each column over all nodes; the epsilon keeps constant columns
# (zero standard deviation) from dividing by zero
features_normalized = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)

x = torch.tensor(features_normalized, dtype=torch.float32)
print(f"Node features shape: {x.shape}")

In [None]:
# Create cascade failure labels using BFS from high-criticality nodes
from collections import deque

def simulate_cascade(start_idx, adjacency, max_depth=5):
    """Spread a cascade breadth-first from ``start_idx`` and record wave depths.

    Args:
        start_idx: Node where the cascade starts (wave depth 0).
        adjacency: Mapping of node -> iterable of neighbouring nodes. Nodes
            missing from the mapping are treated as having no neighbours.
        max_depth: Nodes at this depth are recorded but not expanded further.

    Returns:
        Dict of node -> wave depth for every node reached, in discovery order.
    """
    wave_depths = {start_idx: 0}
    frontier = [start_idx]
    depth = 0

    # Process one wave at a time; each pass discovers the nodes at depth + 1
    while frontier and depth < max_depth:
        next_frontier = []
        for node in frontier:
            if node not in adjacency:
                continue
            for neighbor in adjacency[node]:
                if neighbor in wave_depths:
                    continue
                wave_depths[neighbor] = depth + 1
                next_frontier.append(neighbor)
        frontier = next_frontier
        depth += 1

    return wave_depths

# Build adjacency list: node index -> list of neighbour indices, in the order
# the edges appear in edge_index.
# The tensor is converted to Python lists once up front; indexing it element by
# element with .item() costs a tensor operation per edge endpoint.
adjacency = {}
edge_sources, edge_targets = edge_index.tolist()
for src, dst in zip(edge_sources, edge_targets):
    adjacency.setdefault(src, []).append(dst)

print(f"Adjacency list built for {len(adjacency)} nodes")

In [None]:
# Generate cascade labels from top 10 highest criticality nodes.
# The start nodes are used as graph indices (row positions), so the scores are
# re-indexed 0..N-1 first: nlargest().index returns index *labels*, which only
# coincide with positions when nodes_df has a default RangeIndex.
criticality = nodes_df['CRITICALITY_SCORE'].fillna(0).reset_index(drop=True)
high_crit_indices = criticality.nlargest(10).index.tolist()
print(f"High criticality Patient Zero candidates: {high_crit_indices[:5]}...")

# Aggregate cascade labels from multiple starting points: a node is labelled 1
# if any simulated cascade reaches it within 3 hops (the start node included).
# The wave depth itself is not used for the label.
cascade_labels = np.zeros(len(nodes_df))
for start_idx in high_crit_indices:
    wave_depths = simulate_cascade(start_idx, adjacency, max_depth=3)
    cascade_labels[list(wave_depths)] = 1

y = torch.tensor(cascade_labels, dtype=torch.float32)
print(f"Cascade labels: {y.sum().item():.0f} affected / {len(y)} total ({100*y.mean().item():.1f}%)")

In [None]:
# Bundle node features, connectivity and labels into one PyG Data object
data = Data(x=x, edge_index=edge_index, y=y)

# Quick structural summary of the assembled graph
print("Graph Data:")
print(f"  - Nodes: {data.num_nodes}")
print(f"  - Edges: {data.num_edges}")
print(f"  - Node features: {data.num_node_features}")
print(f"  - Is directed: {data.is_directed()}")

## 3. Define GCN Model Architecture

In [None]:
class CascadeGCN(nn.Module):
    """
    Three stacked GCN layers followed by a linear head that scores every node.

    Layer widths: in_features → hidden_dim → hidden_dim → 32 → 1
    (10 → 64 → 64 → 32 → 1 with the defaults).
    - Each GCN layer is followed by BatchNorm1d and ReLU
    - Dropout is applied after the first two layers only
    - forward() returns sigmoid outputs (values in [0, 1]), not logits
    """

    def __init__(self, in_features=10, hidden_dim=64, dropout=0.3):
        super().__init__()
        embed_dim = 32  # width of the last GCN layer / input to the linear head

        # Graph convolutions
        self.conv1 = GCNConv(in_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, embed_dim)

        # One batch-norm per convolution, matching its output width
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.bn3 = nn.BatchNorm1d(embed_dim)

        # Per-node scoring head
        self.fc = nn.Linear(embed_dim, 1)

        # Shared dropout module
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Return one sigmoid score per node, shape (num_nodes,)."""
        h = x

        # First two stages: conv → batch-norm → ReLU → dropout
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            h = self.dropout(F.relu(norm(conv(h, edge_index))))

        # Last stage has no dropout
        h = F.relu(self.bn3(self.conv3(h, edge_index)))

        # Linear head, squashed to [0, 1]; drop the trailing size-1 dimension
        scores = torch.sigmoid(self.fc(h))
        return scores.squeeze(-1)

    def predict_cascade_risk(self, x, edge_index):
        """Switch to eval mode and score all nodes without tracking gradients."""
        self.eval()
        with torch.no_grad():
            risk = self.forward(x, edge_index)
        return risk

# Use the GPU when one is visible, otherwise the CPU, and place the model there
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CascadeGCN(in_features=10, hidden_dim=64, dropout=0.3)
model = model.to(device)

parameter_count = sum(p.numel() for p in model.parameters())
print(f"Model on device: {device}")
print(f"Total parameters: {parameter_count:,}")

## 4. Train the GCN Model

In [None]:
# Node-level split: 20% of nodes held out for test, then 15% of the remainder
# for validation; both splits use a fixed seed
num_nodes = data.num_nodes
indices = np.arange(num_nodes)

train_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_idx, test_size=0.15, random_state=42)


def _index_mask(selected):
    """Boolean mask over all nodes that is True at the selected indices."""
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    mask[selected] = True
    return mask


train_mask = _index_mask(train_idx)
val_mask = _index_mask(val_idx)
test_mask = _index_mask(test_idx)

split_sizes = (train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item())
print(f"Train: {split_sizes[0]} | Val: {split_sizes[1]} | Test: {split_sizes[2]}")

In [None]:
# Training configuration
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# Halve the learning rate after 10 epochs without improvement in the monitored loss
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5)

# Class imbalance handling: weight positives by the negative/positive ratio.
# Kept as a plain float; falls back to 1.0 when the labels are degenerate
# (no positives, no negatives, or no nodes), where the ratio is undefined.
positive_fraction = float(y.mean())
if 0.0 < positive_fraction < 1.0:
    pos_weight = (1.0 - positive_fraction) / positive_fraction
else:
    pos_weight = 1.0


def criterion(probs, targets):
    """Class-weighted binary cross-entropy on probabilities.

    CascadeGCN.forward already applies a sigmoid, so a logits-based loss such
    as BCEWithLogitsLoss would squash the outputs a second time. This applies
    the same positive-class weighting directly in probability space.
    """
    sample_weight = torch.ones_like(targets)
    sample_weight[targets > 0.5] = pos_weight
    return F.binary_cross_entropy(probs, targets, weight=sample_weight)


print(f"Positive class weight: {pos_weight:.2f}")

In [None]:
# Training loop: full-batch training with early stopping on validation loss
epochs = 200
best_val_loss = float('inf')
patience = 30
patience_counter = 0

train_losses = []
val_losses = []

# Move the split masks to the model's device once instead of on every epoch
train_mask_dev = train_mask.to(device)
val_mask_dev = val_mask.to(device)

# Apply the positive-class weight computed in the training configuration.
# The model outputs probabilities, so the weight goes in as a per-node weight
# on binary_cross_entropy. An unusable weight (inf/NaN/non-positive, which
# happens with degenerate labels) falls back to an unweighted loss.
positive_weight = float(pos_weight)
if not np.isfinite(positive_weight) or positive_weight <= 0:
    positive_weight = 1.0
node_weight = torch.ones_like(data.y)
node_weight[data.y > 0.5] = positive_weight

for epoch in range(epochs):
    # Training step on the training nodes (the forward pass covers the whole graph)
    model.train()
    optimizer.zero_grad()

    out = model(data.x, data.edge_index)
    train_loss = F.binary_cross_entropy(
        out[train_mask_dev],
        data.y[train_mask_dev],
        weight=node_weight[train_mask_dev],
    )

    train_loss.backward()
    optimizer.step()

    # Validation with dropout disabled and batch-norm in inference mode
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        val_loss = F.binary_cross_entropy(
            out[val_mask_dev],
            data.y[val_mask_dev],
            weight=node_weight[val_mask_dev],
        )

    # Track plain floats so no tensors are kept alive across epochs
    train_loss_value = train_loss.item()
    val_loss_value = val_loss.item()
    train_losses.append(train_loss_value)
    val_losses.append(val_loss_value)

    scheduler.step(val_loss_value)

    # Early stopping: checkpoint on improvement, otherwise count towards patience
    if val_loss_value < best_val_loss:
        best_val_loss = val_loss_value
        patience_counter = 0
        # Save best model
        torch.save(model.state_dict(), 'best_cascade_gcn.pt')
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}")
        break

    if epoch % 20 == 0:
        print(f"Epoch {epoch:3d} | Train Loss: {train_loss_value:.4f} | Val Loss: {val_loss_value:.4f}")

print(f"\nTraining complete. Best validation loss: {best_val_loss:.4f}")

## 5. Evaluate Model Performance

In [None]:
# Load best model and evaluate on test set.
# map_location keeps the restore working when the checkpoint was written on a
# different device than the current one (e.g. saved on GPU, reloaded on CPU).
model.load_state_dict(torch.load('best_cascade_gcn.pt', map_location=device))
model.eval()

with torch.no_grad():
    predictions = model(data.x, data.edge_index)

# Test metrics, computed on the held-out test nodes only.
# y_pred holds the model's sigmoid outputs; both metrics use them as scores.
test_mask_dev = test_mask.to(device)
y_test = data.y[test_mask_dev].cpu().numpy()
y_pred = predictions[test_mask_dev].cpu().numpy()

auc_score = roc_auc_score(y_test, y_pred)
ap_score = average_precision_score(y_test, y_pred)

print(f"Test Set Performance:")
print(f"  - AUC-ROC: {auc_score:.4f}")
print(f"  - Average Precision: {ap_score:.4f}")

In [None]:
# Attach the model's score for every node to the node table
cascade_risk_scores = predictions.cpu().numpy()
nodes_df['GNN_CASCADE_RISK'] = cascade_risk_scores

# Rank nodes by predicted risk and keep the 20 highest with a few context columns
report_columns = [
    'NODE_ID', 'NODE_TYPE', 'CRITICALITY_SCORE', 'GNN_CASCADE_RISK',
    'DEGREE_CENTRALITY', 'DOWNSTREAM_TRANSFORMERS',
]
ranked_nodes = nodes_df.nlargest(20, 'GNN_CASCADE_RISK')
top_risk_nodes = ranked_nodes[report_columns]

print("\nTop 20 Patient Zero Candidates (GNN-based):")
top_risk_nodes

## 6. Register Model in Snowflake ML Registry

In [None]:
# Create a wrapper class for Snowflake ML Registry
class CascadeGCNPredictor:
    """Wrapper for Snowflake ML Registry deployment.

    Rebuilds a CascadeGCN from saved weights and keeps the metadata needed to
    prepare inputs (node-id mapping and feature column order).
    """

    def __init__(self, model_state_dict, node_id_mapping, feature_columns):
        # Architecture must match the one the weights were trained with
        self.model = CascadeGCN(in_features=10, hidden_dim=64)
        self.model.load_state_dict(model_state_dict)
        self.model.eval()
        self.node_id_mapping = node_id_mapping
        self.feature_columns = feature_columns

    def predict(self, df):
        """Predict cascade risk for nodes in dataframe.

        Raises:
            NotImplementedError: Always. Batch inference has not been written
                yet (the model also needs the graph's edge index, which a node
                dataframe alone does not carry). Failing loudly avoids callers
                mistaking a silent ``None`` for a prediction result.
        """
        # This would be adapted for batch inference in SPCS
        raise NotImplementedError(
            "CascadeGCNPredictor.predict is not implemented yet"
        )

# Save model artifacts: weights plus everything needed to rebuild the model
# and to map inputs (node-id mapping, feature order, architecture, metrics).
model_artifacts = {
    'state_dict': model.state_dict(),
    'node_id_mapping': node_id_to_idx,
    'feature_columns': feature_columns,
    'architecture': {
        'in_features': 10,
        'hidden_dim': 64,
        'dropout': 0.3
    },
    'metrics': {
        # Stored as built-in floats: the metric values may be NumPy scalars,
        # which torch.load(weights_only=True) refuses to unpickle unless they
        # are explicitly allow-listed.
        'auc_roc': float(auc_score),
        'average_precision': float(ap_score)
    }
}

torch.save(model_artifacts, 'cascade_gcn_full.pt')
print("Model artifacts saved to cascade_gcn_full.pt")

In [None]:
# Best-effort registration in the Snowflake ML Registry: any failure is
# reported and the notebook continues with the locally saved artifacts.
# NOTE(review): log_model is called without sample input data or a signature;
# confirm the registry accepts a PyTorch model logged this way.
try:
    registry = Registry(session=session, database_name="SI_DEMOS", schema_name="ML_DEMO")

    version_comment = "3-layer GCN for cascade failure prediction. AUC-ROC: {:.4f}".format(auc_score)
    version_metrics = {
        "auc_roc": auc_score,
        "average_precision": ap_score,
        "num_nodes": len(nodes_df),
        "num_edges": len(valid_edges),
        "num_features": 10,
    }

    # Log the trained model as a new version
    mv = registry.log_model(
        model_name="CASCADE_GCN_MODEL",
        version_name="v1_gcn_3layer",
        model=model,
        conda_dependencies=["pytorch", "torch_geometric"],
        comment=version_comment,
        metrics=version_metrics,
    )
    print(f"Model registered: {mv.model_name}.{mv.version_name}")
except Exception as e:
    print(f"Note: ML Registry registration skipped: {e}")
    print("Model saved locally as cascade_gcn_full.pt")

## 7. Export Predictions to Snowflake

In [None]:
# Assemble the export frame: one row per node, stamped with the local time
export_columns = ['NODE_ID', 'NODE_TYPE', 'GNN_CASCADE_RISK', 'CRITICALITY_SCORE']
predictions_df = nodes_df[export_columns].copy()
predictions_df['PREDICTION_TIMESTAMP'] = pd.Timestamp.now()

# Replace the Snowflake table with this run's predictions; if that fails for
# any reason, keep the results by writing a local CSV instead
try:
    snowpark_df = session.create_dataframe(predictions_df)
    table_writer = snowpark_df.write.mode("overwrite")
    table_writer.save_as_table("SI_DEMOS.CASCADE_ANALYSIS.GNN_PREDICTIONS")
    print(f"Predictions written to SI_DEMOS.CASCADE_ANALYSIS.GNN_PREDICTIONS")
except Exception as e:
    print(f"Note: Could not write to Snowflake: {e}")
    predictions_df.to_csv('gnn_predictions.csv', index=False)
    print("Predictions saved to gnn_predictions.csv")

In [None]:
# Final run summary: headline metrics, the five highest-risk nodes, artifacts
banner = "=" * 60
print(banner)
print("GNN CASCADE FAILURE PREDICTION - TRAINING COMPLETE")
print(banner)
print(f"\nModel Architecture: 10 → 64 → 64 → 32 → 1 (GCN)")
print(f"Training Nodes: {train_mask.sum().item():,}")
print(f"Test AUC-ROC: {auc_score:.4f}")
print(f"Test Average Precision: {ap_score:.4f}")
print(f"\nTop Patient Zero Candidates:")
top_five = top_risk_nodes.head(5)
for _, row in top_five.iterrows():
    print(f"  - {row['NODE_ID']}: Risk={row['GNN_CASCADE_RISK']:.3f}, Type={row['NODE_TYPE']}")
print("\nArtifacts:")
print("  - Model: cascade_gcn_full.pt")
print("  - Predictions: SI_DEMOS.CASCADE_ANALYSIS.GNN_PREDICTIONS")

In [None]:
# Cleanup: close the Snowpark session opened at the top of the notebook.
# Any cell that uses `session` needs a fresh connection after this runs.
session.close()