# GridGuard: Cascade Failure Analysis with Graph Neural Networks

This notebook trains a Graph Convolutional Network (GCN) to identify cascade failure patterns in power grid topology and runs batch inference across all simulation scenarios.

**Objectives:**
1. Load grid topology and telemetry data from Snowflake
2. Build PyTorch Geometric graph structures
3. Train a GCN model on historical cascade patterns
4. Run batch inference for all scenarios (BASE_CASE, HIGH_LOAD, WINTER_STORM_2021)
5. Identify "Patient Zero" - the cascade origin node
6. Write results to SIMULATION_RESULTS table


In [None]:
# install_packages: Install PyTorch Geometric and dependencies
import sys
import os
import shlex

# Use os.system for pip installs (subprocess doesn't work in SPCS headless mode)
# torch: Deep learning framework (provides tensors, autograd, neural network modules)
# torch-geometric: PyTorch extension for graph neural networks
# networkx: Graph analysis library (for BFS, shortest paths, etc.)
# scikit-learn: For model evaluation metrics (confusion matrix, precision, recall)
# matplotlib: For visualization of training curves and results
# TODO: pin versions (e.g. "torch==X.Y.Z") so every rerun resolves the same environment.
PIP_PACKAGES = ["torch", "torch-geometric", "networkx", "scikit-learn", "matplotlib"]

# os.system runs through a shell, so quote the interpreter path in case it contains spaces.
pip_command = f"{shlex.quote(sys.executable)} -m pip install {' '.join(PIP_PACKAGES)} -q"
pip_status = os.system(pip_command)

# os.system returns a non-zero status when pip fails. Stop here rather than printing
# "complete" and letting the failure surface later as a confusing ImportError.
if pip_status != 0:
    raise RuntimeError(f"pip install failed (exit status {pip_status}): {pip_command}")

print("Package installation complete")


In [None]:
# import_libraries: Import all required libraries
import numpy as np
import pandas as pd
import networkx as nx
from datetime import datetime
import uuid
import warnings
warnings.filterwarnings('ignore')

# PyTorch and PyG
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool

# Snowflake session
from snowflake.snowpark.context import get_active_session

# Seed every random number generator the notebook draws from (PyTorch weight
# initialisation / dropout, then NumPy) so a fresh kernel reproduces the run.
RANDOM_SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(RANDOM_SEED)

# Report the runtime so readers know whether training used a GPU.
_cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


In [None]:
# snowflake_session_setup: Get and verify Snowflake session
session = get_active_session()

# Confirm which database / schema / user this notebook is bound to before
# any table is read: one row, three columns.
result = session.sql("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA(), CURRENT_USER()").collect()
context_row = result[0]
for label, position in (("Database", 0), ("Schema", 1), ("User", 2)):
    print(f"{label}: {context_row[position]}")


In [None]:
# load_grid_topology: Load grid nodes and edges from Snowflake

# Static topology queries: one row per node, one row per physical connection.
NODES_QUERY = """
    SELECT NODE_ID, NODE_NAME, NODE_TYPE, LAT, LON, REGION, 
           CAPACITY_MW, VOLTAGE_KV, CRITICALITY_SCORE
    FROM GRID_NODES
"""
EDGES_QUERY = """
    SELECT EDGE_ID, SRC_NODE, DST_NODE, EDGE_TYPE, 
           CAPACITY_MW, LENGTH_MILES, VOLTAGE_KV, REDUNDANCY_LEVEL
    FROM GRID_EDGES
"""

# Nodes first, with a short profile of what was loaded
nodes_df = session.sql(NODES_QUERY).to_pandas()
node_type_counts = nodes_df['NODE_TYPE'].value_counts().to_dict()
region_list = nodes_df['REGION'].unique().tolist()
print(f"Loaded {len(nodes_df)} grid nodes")
print(f"Node types: {node_type_counts}")
print(f"Regions: {region_list}")

# Then the edges connecting them
edges_df = session.sql(EDGES_QUERY).to_pandas()
edge_type_counts = edges_df['EDGE_TYPE'].value_counts().to_dict()
print(f"Loaded {len(edges_df)} grid edges")
print(f"Edge types: {edge_type_counts}")


In [None]:
# load_telemetry_data: Load historical telemetry for all scenarios

# Sorted by scenario, then time, then node: downstream cells take the last row
# per node within a time window as that node's most recent reading.
TELEMETRY_QUERY = """
    SELECT TELEMETRY_ID, TIMESTAMP, NODE_ID, SCENARIO_NAME,
           VOLTAGE_KV, LOAD_MW, FREQUENCY_HZ, TEMPERATURE_F, 
           STATUS, ALERT_CODE
    FROM HISTORICAL_TELEMETRY
    ORDER BY SCENARIO_NAME, TIMESTAMP, NODE_ID
"""
telemetry_df = session.sql(TELEMETRY_QUERY).to_pandas()

# Per-scenario count of each STATUS value (scenarios as rows, statuses as columns)
status_counts = (
    telemetry_df
    .groupby(['SCENARIO_NAME', 'STATUS'])
    .size()
    .unstack(fill_value=0)
)

print(f"Loaded {len(telemetry_df)} telemetry records")
print(f"Scenarios: {telemetry_df['SCENARIO_NAME'].unique().tolist()}")
print("Status distribution:")
print(status_counts)


In [None]:
# create_node_mappings: Create node ID to index mapping for PyG
node_ids = nodes_df['NODE_ID'].tolist()

# Row i of every feature matrix and every index in edge_index refers to node_ids[i].
# A duplicated NODE_ID would make the dict below keep only the last index for it,
# leaving some rows unreachable by ID and desynchronising idx_to_node. Fail loudly instead.
duplicated_ids = nodes_df.loc[nodes_df['NODE_ID'].duplicated(), 'NODE_ID'].unique().tolist()
if duplicated_ids:
    raise ValueError(f"GRID_NODES has duplicate NODE_ID values (first few): {duplicated_ids[:10]}")

node_to_idx = {node_id: idx for idx, node_id in enumerate(node_ids)}
# Inverse mapping; iterates in ascending index order
idx_to_node = dict(enumerate(node_ids))

num_nodes = len(node_ids)
print(f"Number of nodes: {num_nodes}")
print(f"Sample mapping: {list(node_to_idx.items())[:5]}")


## Building the Graph Structure: Edge Index

PyTorch Geometric (PyG) represents graphs using a **COO (Coordinate) sparse format** for edges, stored in a tensor called `edge_index`.

### Edge Index Format
- Shape: `[2, num_edges]` where:
  - Row 0 contains **source node indices**
  - Row 1 contains **destination node indices**
- Example: `[[0, 1, 2], [1, 2, 0]]` means edges: 0→1, 1→2, 2→0

### Why Duplicate Edges for Undirected Graphs?
Power grids are modeled here as **undirected** graphs (power can flow either way along a transmission line). PyG interprets each column of `edge_index` as a *directed* edge, so we store every line in both directions:
- Original edge: Substation A → Substation B
- We add: A→B **and** B→A

This ensures message passing flows symmetrically during graph convolution.

### The `.t().contiguous()` Pattern
```python
edge_index = torch.tensor(edge_list).t().contiguous()
```
- `.t()` — Transpose from `[num_edges, 2]` to `[2, num_edges]` (PyG's expected format)
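- `.contiguous()` — `.t()` only returns a transposed *view* of the same memory; `.contiguous()` copies it into a compact row-major layout so downstream PyTorch/PyG operations can work on it efficiently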


In [None]:
# build_edge_index: Build edge index for PyTorch Geometric
edge_list = []
skipped_edges = 0
for src_node, dst_node in zip(edges_df['SRC_NODE'], edges_df['DST_NODE']):
    src_idx = node_to_idx.get(src_node)
    dst_idx = node_to_idx.get(dst_node)
    if src_idx is None or dst_idx is None:
        # Endpoint missing from GRID_NODES: the edge cannot be expressed as node indices
        skipped_edges += 1
        continue
    # Add both directions for undirected graph (PyG treats each column as directed)
    edge_list.append([src_idx, dst_idx])
    edge_list.append([dst_idx, src_idx])

# reshape(-1, 2) keeps the tensor 2-D even when edge_list is empty, so the
# transpose always yields shape [2, num_edges] (an empty list would otherwise
# become a 1-D tensor and `edge_index.shape[1]` would fail).
edge_index = torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()
print(f"Edge index shape: {edge_index.shape}")
print(f"Number of edges (directed): {edge_index.shape[1]}")
if skipped_edges:
    print(f"Skipped {skipped_edges} edges whose endpoints are not in GRID_NODES")


In [None]:
# define_feature_function: Define function to create node feature matrix
def create_node_features(nodes_df, telemetry_snapshot):
    """
    Create node feature matrix from topology and telemetry data.

    Args:
        nodes_df: one row per node with NODE_ID, NODE_TYPE, CAPACITY_MW,
            VOLTAGE_KV and CRITICALITY_SCORE columns. Row order defines the
            row order of the returned matrix.
        telemetry_snapshot: telemetry rows with NODE_ID, LOAD_MW,
            TEMPERATURE_F and STATUS columns. If a TIMESTAMP column is present,
            rows are ordered by it (stable) so each node's "latest" reading is
            the most recent one in time; otherwise the last row per node in the
            frame's current order is used.

    Returns a float tensor of shape [len(nodes_df), 10] (also for zero nodes)
    where each row contains:

    CONTINUOUS FEATURES (indices 0-5):
    - [0] Normalized capacity: max generation/load capacity, indicates node importance
    - [1] Normalized voltage: operating voltage level, higher = more critical infrastructure
    - [2] Criticality score: pre-computed importance metric (0-1)
    - [3] Load ratio: current_load / capacity — values > 1.0 indicate overload stress
    - [4] Temperature: ambient conditions affect equipment failure rates
    - [5] Status encoding: ordinal encoding of operational state

    CATEGORICAL FEATURES (indices 6-9):
    - [6-9] Node type one-hot: SUBSTATION, GENERATOR, LOAD_CENTER, TRANSMISSION_HUB
      (all zeros for an unrecognised type)

    DESIGN DECISIONS:
    - Features are normalized to roughly [0, 1] range to help gradient-based optimization
      (temperature goes negative below 32°F)
    - Status uses ordinal encoding (ACTIVE=0 < WARNING=0.5 < OFFLINE=0.8 < FAILED=1.0)
      instead of one-hot because there's a natural severity ordering
    - Load ratio is capped at 2.0 to prevent extreme outliers from dominating
    - Missing values fall back to: load 0 MW, 70°F, ACTIVE, capacity 500 MW,
      voltage 138 kV, criticality 0.5
    """
    num_features = 10  # 6 continuous + 4 one-hot

    # Status mapping: ordinal encoding preserves severity ordering
    # (unknown or missing statuses map to 0.0, i.e. treated as ACTIVE)
    status_map = {'ACTIVE': 0.0, 'WARNING': 0.5, 'FAILED': 1.0, 'OFFLINE': 0.8}

    # Node type one-hot: no natural ordering, so use one-hot encoding
    # This prevents the model from learning spurious ordinal relationships
    type_map = {'SUBSTATION': [1,0,0,0], 'GENERATOR': [0,1,0,0],
                'LOAD_CENTER': [0,0,1,0], 'TRANSMISSION_HUB': [0,0,0,1]}

    # Reduce telemetry to one row per node up front (O(rows)), instead of
    # filtering the whole snapshot once per node (O(nodes * rows)).
    snapshot = telemetry_snapshot
    if 'TIMESTAMP' in snapshot.columns:
        # Stable sort keeps the input order among equal timestamps; NaT rows go
        # first so they are never chosen as the "latest" reading.
        snapshot = snapshot.sort_values('TIMESTAMP', kind='mergesort', na_position='first')
    # keep='last' keeps the whole latest row (unlike groupby().last(), which
    # would mix the last non-null value of each column from different rows).
    latest_by_node = snapshot.drop_duplicates('NODE_ID', keep='last').set_index('NODE_ID')

    features = []
    for _, node in nodes_df.iterrows():
        node_id = node['NODE_ID']

        if node_id in latest_by_node.index:
            latest = latest_by_node.loc[node_id]
            load_mw = latest['LOAD_MW'] if pd.notna(latest['LOAD_MW']) else 0
            temp_f = latest['TEMPERATURE_F'] if pd.notna(latest['TEMPERATURE_F']) else 70
            status = status_map.get(latest['STATUS'], 0.0)
        else:
            # Default values for nodes without telemetry (use domain-reasonable defaults)
            load_mw = 0
            temp_f = 70  # Room temperature baseline
            status = 0.0  # Assume active if no data

        # Static topology features (from GRID_NODES table)
        capacity = node['CAPACITY_MW'] if pd.notna(node['CAPACITY_MW']) else 500
        voltage = node['VOLTAGE_KV'] if pd.notna(node['VOLTAGE_KV']) else 138
        criticality = node['CRITICALITY_SCORE'] if pd.notna(node['CRITICALITY_SCORE']) else 0.5

        # Assemble feature vector with normalization
        base_features = [
            capacity / 2000,  # Normalize by typical max capacity (2000 MW)
            voltage / 500,    # Normalize by max transmission voltage (500 kV)
            criticality,      # Already in [0, 1] range
            min(load_mw / max(capacity, 1), 2.0),  # Load ratio; cap at 2x; max() avoids /0
            (temp_f - 32) / 100,  # 32°F -> 0, 132°F -> 1
            status            # Already in [0, 1] from ordinal mapping
        ]

        # Append one-hot node type vector
        type_vec = type_map.get(node['NODE_TYPE'], [0,0,0,0])

        features.append(base_features + type_vec)

    # reshape keeps the result 2-D ([0, 10]) when nodes_df is empty
    return torch.tensor(features, dtype=torch.float).reshape(-1, num_features)

print("Feature creation function defined")
print("Features: capacity, voltage, criticality, load_ratio, temperature, status, type_onehot(4)")
print("Total features per node: 10")


## Graph Convolutional Networks (GCNs): Core Concepts

### Why Use GCNs for Power Grid Analysis?

Traditional machine learning treats each node independently, ignoring **network topology**. This is problematic for cascade failure prediction because:
- A node's failure risk depends heavily on its **neighbors' states**
- Failures propagate through **physical connections** (transmission lines)
- Network structure determines which nodes are "upstream" or "downstream" of stress points

GCNs solve this by learning representations that incorporate both **node features** and **graph structure**.

### How GCNs Work: Message Passing

Each GCN layer performs **neighborhood aggregation**:

1. **Gather**: Each node collects feature vectors from its neighbors
2. **Aggregate**: Combine neighbor features (GCN uses a degree-normalized sum, shown in the formula below)
3. **Transform**: Apply a learned linear transformation + nonlinearity

Mathematically, the GCN layer update rule is:

$$H^{(l+1)} = \sigma\left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \right)$$

Where:
- $\tilde{A} = A + I$ — Adjacency matrix with self-loops (so nodes also consider their own features)
- $\tilde{D}$ — Degree matrix of $\tilde{A}$ (diagonal matrix of node degrees)
- $H^{(l)}$ — Node features at layer $l$
- $W^{(l)}$ — Learnable weight matrix
- $\sigma$ — Nonlinear activation (ReLU)

The $\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ term is called **symmetric normalization** — it prevents nodes with many neighbors from dominating.

### Receptive Field and Layer Depth

Each GCN layer lets a node "see" one hop further:
- **1 layer**: Node sees immediate neighbors
- **2 layers**: Node sees 2-hop neighborhood
- **3 layers**: Node sees 3-hop neighborhood (used here)

**Warning**: Too many layers cause **over-smoothing** — all node representations converge to the same vector. 3 layers is a common sweet spot for medium-sized graphs.


In [None]:
# define_gcn_model: Define the Graph Convolutional Network model
class CascadeGCN(nn.Module):
    """
    Three-layer GCN that scores every node of a grid graph with a failure probability.

    Per-node pipeline (each GCNConv mixes in neighbours' features first):

        x[num_features] -> GCNConv(hidden)    -> ReLU -> Dropout
                        -> GCNConv(hidden)    -> ReLU -> Dropout
                        -> GCNConv(hidden//2) -> ReLU
                        -> Linear(1)          -> Sigmoid -> p in (0, 1)

    Args:
        num_features: width of the input node features (10 from create_node_features).
        hidden_dim: width of the first two GCN layers; the third uses hidden_dim // 2
            as a bottleneck before the linear head.
        dropout: drop probability of the shared Dropout module used after layers 1 and 2.

    forward(x, edge_index) returns a 1-D tensor with one probability per node
    (the head's trailing size-1 dimension is squeezed away).

    Design notes:
    - Three layers give each node a 3-hop receptive field; deeper stacks risk
      over-smoothing, where embeddings of different nodes become similar.
    - No dropout after the third layer, so the features reaching the head are
      not randomly zeroed during training.
    - Regularisation beyond dropout comes from the optimizer's weight decay.
    - Because the output already passes through a sigmoid, training must use a
      probability-space loss (e.g. F.binary_cross_entropy), not a logits loss.
    """

    def __init__(self, num_features, hidden_dim=64, dropout=0.3):
        super().__init__()
        bottleneck_dim = hidden_dim // 2
        # Registration order (conv1, conv2, conv3, dropout, fc) is fixed on purpose:
        # it is the order in which layers draw their initial weights from the torch
        # RNG and the order they appear in print(model) / state_dict().
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, bottleneck_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(bottleneck_dim, 1)

    def forward(self, x, edge_index):
        h = x
        # Layers 1-2: propagate over neighbours, activate, regularise
        for conv in (self.conv1, self.conv2):
            h = self.dropout(F.relu(conv(h, edge_index)))
        # Layer 3: propagate and activate only (no dropout before the head)
        h = F.relu(self.conv3(h, edge_index))
        # Linear head -> probability, then [N, 1] -> [N]
        return torch.sigmoid(self.fc(h)).squeeze(-1)

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CascadeGCN(num_features=10, hidden_dim=64, dropout=0.3).to(device)

print(f"Model initialized on device: {device}")
print(model)


## Training Data Preparation: Temporal Split Strategy

### Why Temporal Splitting Matters

For time-series prediction problems, **random train/test splits are invalid** because they allow the model to "see the future." Instead, we use a **temporal split**:

```
Timeline:  ──────────────────────────────────────────────────────────────▶
           │  TRAINING WINDOW (hours 0-5)  │  PREDICTION TARGET (hours 6+)  │
           │  "What conditions existed?"   │  "Which nodes failed?"          │
```

### Our Approach

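1. **Features (X)**: Extract node conditions from the **early period** (the first 6 hours of the scenario, hours 0-5)
   - Load ratios, temperatures, status flags from WINTER_STORM_2021 telemetry
   - Each node contributes its most recent reading within that window
   
2. **Labels (y)**: Identify nodes that **failed during the late period** (hours 6+)
   - Binary: 1 = node reported `STATUS = 'FAILED'` at least once in the late period, 0 = it did not (a node that failed only in the early period and then recovered is labeled 0)

This approximates the real-world task: "Given current grid conditions, predict which nodes will fail in the coming hours." Two caveats apply. First, a node that is already FAILED in the early window carries that status as an input feature, which makes it easy to predict. Second, the diagnostics below are computed on this same graph, so they measure how well the model fits the training data, not how well it generalizes.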

### Why WINTER_STORM_2021 for Training?

- This scenario contains the most **cascade failure events** (stress conditions)
- The model learns patterns that precede failures under extreme conditions
- We then apply the trained model to other scenarios (BASE_CASE, HIGH_LOAD) during inference


In [None]:
# prepare_training_data: Prepare training data from WINTER_STORM_2021 scenario
train_scenario = 'WINTER_STORM_2021'
# Length of the feature ("early") window, measured from the scenario's first timestamp.
# Everything at or after the cutoff is the label ("late") window.
EARLY_WINDOW_HOURS = 6

train_telemetry = telemetry_df[telemetry_df['SCENARIO_NAME'] == train_scenario].copy()
if train_telemetry.empty:
    raise ValueError(f"No telemetry rows for training scenario {train_scenario!r}")

# Get unique timestamps
timestamps = train_telemetry['TIMESTAMP'].unique()
print(f"Training scenario: {train_scenario}")
print(f"Number of time steps: {len(timestamps)}")

# Split on elapsed time rather than a fixed number of timestamps: the previous
# `sorted(timestamps)[:24]` only equalled "hours 0-5" for 15-minute sampling.
# pd.to_datetime is a no-op for datetime columns and parses string timestamps.
parsed_ts = pd.to_datetime(train_telemetry['TIMESTAMP'])
cutoff = parsed_ts.min() + pd.Timedelta(hours=EARLY_WINDOW_HOURS)
is_early = parsed_ts < cutoff
is_late = parsed_ts >= cutoff  # rows with a missing timestamp fall in neither window

early_telemetry = train_telemetry[is_early]
late_telemetry = train_telemetry[is_late]
early_timestamps = sorted(early_telemetry['TIMESTAMP'].unique())
late_timestamps = sorted(late_telemetry['TIMESTAMP'].unique())

# Features from early period (latest reading per node inside the window)
X = create_node_features(nodes_df, early_telemetry)

# Labels from late period: 1 if the node reported FAILED at least once after the cutoff
failure_nodes = late_telemetry[late_telemetry['STATUS'] == 'FAILED']['NODE_ID'].unique()

y = torch.zeros(num_nodes, dtype=torch.float)
for node_id in failure_nodes:
    if node_id in node_to_idx:
        y[node_to_idx[node_id]] = 1.0

# warnings are globally silenced in the imports cell, so report problems with print
if late_telemetry.empty:
    print(f"WARNING: no telemetry after the first {EARLY_WINDOW_HOURS}h; every label is 0")
elif y.sum() == 0:
    print("WARNING: no node FAILED in the label window; the model has no positive examples")

print(f"Feature matrix shape: {X.shape}")
print(f"Number of failure nodes in labels: {int(y.sum())}")
print(f"Failure nodes: {list(failure_nodes)[:5]}...")


In [None]:
# train_model: Train the GCN model
# ============================================================================
# HYPERPARAMETERS
# ============================================================================
num_epochs = 100      # Number of full passes through the training data
learning_rate = 0.01  # Step size for gradient descent (0.01 is typical for Adam)
weight_decay = 5e-4   # L2 regularization coefficient (prevents large weights)
use_class_weights = True  # Up-weight failure nodes in the loss; False = plain BCE

# Move data to device (GPU if available, otherwise CPU)
X = X.to(device)
y = y.to(device)
edge_index = edge_index.to(device)

# ============================================================================
# OPTIMIZER SETUP
# ============================================================================
# Adam optimizer combines:
# - Momentum: accelerates convergence by accumulating gradient direction
# - RMSprop: adapts learning rate per-parameter based on gradient magnitude
# weight_decay adds an L2 penalty on the parameters
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# ============================================================================
# HANDLING CLASS IMBALANCE
# ============================================================================
# In cascade failures, most nodes survive (y=0) and few fail (y=1).
# This creates class imbalance: model could achieve high accuracy by always
# predicting "no failure" while missing all actual failures.
#
# Solution: Weight the positive class more heavily in the loss function.
# pos_weight = (# negative samples) / (# positive samples)
# Example: 95 surviving nodes, 5 failing → pos_weight = 19
#          Each failure contributes 19x more to the loss gradient
num_pos = float(y.sum().item())
num_neg = float(num_nodes) - num_pos
# If either class is absent the ratio is meaningless, and with no negatives it
# would be 0 and erase every positive's gradient, so fall back to 1.0.
pos_weight_value = num_neg / num_pos if num_pos > 0 and num_neg > 0 else 1.0
pos_weight = torch.tensor([pos_weight_value], device=device)

# CascadeGCN ends in a sigmoid, so it outputs probabilities, not logits;
# nn.BCEWithLogitsLoss(pos_weight=...) would apply a second sigmoid. The same
# weighting is applied through F.binary_cross_entropy's per-element `weight`:
# failure nodes get pos_weight, surviving nodes get 1.
sample_weight = (
    torch.where(y > 0.5, pos_weight, torch.ones_like(y)) if use_class_weights else None
)

# ============================================================================
# TRAINING LOOP
# ============================================================================
model.train()  # Enable dropout and batch norm (if present)
losses = []
accuracies = []

print(f"Training for {num_epochs} epochs...")
print(f"Class imbalance ratio (pos_weight): {pos_weight.item():.2f}")
print(f"Class weighting in loss: {'on' if use_class_weights else 'off'}")
print(f"Failure nodes: {int(y.sum().item())} / {num_nodes} ({100*y.sum().item()/num_nodes:.1f}%)")

for epoch in range(num_epochs):
    optimizer.zero_grad()  # Clear gradients from previous iteration

    # Forward pass: compute predictions (dropout active in train mode)
    out = model(X, edge_index)

    # (Weighted) Binary Cross-Entropy between predictions and labels:
    # -w_i * [y_i*log(p_i) + (1-y_i)*log(1-p_i)], averaged over all nodes
    loss = F.binary_cross_entropy(out, y, weight=sample_weight)

    # Backward pass: compute gradients via backpropagation
    loss.backward()

    # Update weights using the Adam update rule
    optimizer.step()

    losses.append(loss.item())

    if (epoch + 1) % 20 == 0:
        # Calculate accuracy (what fraction of predictions are correct)
        # Note: accuracy can be misleading with class imbalance
        pred = (out > 0.5).float()
        correct = (pred == y).sum().item()
        acc = correct / num_nodes
        accuracies.append(acc)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}")

print(f"\nTraining complete!")
print(f"Final loss: {losses[-1]:.4f}")


## Model Diagnostics: Training Convergence

Before running inference, we should verify the model trained properly by examining:
1. **Loss curve**: Should decrease and plateau (not oscillate or diverge)
2. **Confusion matrix**: Reveals if the model is biased toward predicting one class
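3. **Prediction distribution**: Failed and surviving nodes should separate (failed near 1.0, survived near 0.0). If every node receives roughly the same probability, the model has not learned to discriminate.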


In [None]:
# visualize_training: Plot training loss curve and model diagnostics
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

# --- Left panel: loss per epoch with the final value marked ---
final_loss = losses[-1]
ax1.plot(losses, color='#2E86AB', linewidth=2, label='Training Loss')
ax1.axhline(y=final_loss, color='#A23B72', linestyle='--', alpha=0.7,
            label=f'Final Loss: {final_loss:.4f}')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Binary Cross-Entropy Loss', fontsize=12)
ax1.set_title('Training Convergence', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Crude convergence badge: did the loss at least halve since the first epoch?
loss_halved = final_loss < losses[0] * 0.5
convergence_status = "Good convergence" if loss_halved else "May need more epochs"
status_color = 'green' if loss_halved else 'orange'
ax1.text(0.95, 0.95, convergence_status, transform=ax1.transAxes,
         fontsize=10, verticalalignment='top', horizontalalignment='right',
         bbox=dict(boxstyle='round', facecolor=status_color, alpha=0.3))

# --- Right panel: predicted probabilities on the training graph, split by true label ---
# eval() disables dropout so the scores are deterministic; the model is left in eval mode.
model.eval()
with torch.no_grad():
    train_preds = model(X, edge_index).cpu().numpy()

true_labels = y.cpu().numpy()
preds_positive = train_preds[true_labels == 1]
preds_negative = train_preds[true_labels == 0]

ax2.hist(preds_negative, bins=20, alpha=0.7, label=f'Actual Survived (n={len(preds_negative)})',
         color='#28A745', edgecolor='white')
ax2.hist(preds_positive, bins=20, alpha=0.7, label=f'Actual Failed (n={len(preds_positive)})',
         color='#DC3545', edgecolor='white')
ax2.axvline(x=0.5, color='black', linestyle='--', linewidth=2, label='Decision Threshold (0.5)')
ax2.set_xlabel('Predicted Failure Probability', fontsize=12)
ax2.set_ylabel('Number of Nodes', fontsize=12)
ax2.set_title('Prediction Distribution by True Label', fontsize=14, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Reading guide for the two panels above
rule = "=" * 60
print("\n".join([
    "\n" + rule,
    "INTERPRETATION GUIDE:",
    rule,
    "Loss Curve:",
    "  • Decreasing loss indicates the model is learning",
    "  • Plateauing loss suggests convergence (good)",
    "  • Oscillating/increasing loss suggests learning rate issues",
    "",
    "Prediction Distribution:",
    "  • Ideal: Failed nodes (red) cluster near 1.0, survived (green) near 0.0",
    "  • Good separation = model discriminates well",
    "  • Overlapping distributions = harder classification task",
    rule,
]))


In [None]:
# confusion_matrix: Evaluate model performance with confusion matrix and metrics
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# Integer labels so sklearn sees one consistent label type for y_true and y_pred
y_true = y.cpu().numpy().astype(int)
# Get predictions at 0.5 threshold
y_pred = (train_preds > 0.5).astype(int)

# labels=[0, 1] forces a 2x2 matrix. Without it, a run where only one class
# occurs (e.g. no failures and none predicted) yields a 1x1 matrix, and both the
# 4-way unpack and the 2x2 annotation loop below would fail.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

# Create visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# --- Confusion Matrix Heatmap ---
ax1 = axes[0]
im = ax1.imshow(cm, interpolation='nearest', cmap='Blues')
ax1.set_title('Confusion Matrix (Training Set)', fontsize=14, fontweight='bold')

# Add text annotations (white text on dark cells for contrast)
thresh = cm.max() / 2.
for i in range(2):
    for j in range(2):
        ax1.text(j, i, format(cm[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black",
                fontsize=16, fontweight='bold')

ax1.set_ylabel('Actual Label', fontsize=12)
ax1.set_xlabel('Predicted Label', fontsize=12)
ax1.set_xticks([0, 1])
ax1.set_yticks([0, 1])
ax1.set_xticklabels(['Survived (0)', 'Failed (1)'])
ax1.set_yticklabels(['Survived (0)', 'Failed (1)'])

# Add colorbar
plt.colorbar(im, ax=ax1, shrink=0.8)

# --- Metrics Bar Chart ---
ax2 = axes[1]
precision = precision_score(y_true, y_pred, zero_division=0)
recall = recall_score(y_true, y_pred, zero_division=0)
f1 = f1_score(y_true, y_pred, zero_division=0)
accuracy = (tp + tn) / (tp + tn + fp + fn)

metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
values = [accuracy, precision, recall, f1]
colors = ['#3498db', '#2ecc71', '#e74c3c', '#9b59b6']

bars = ax2.bar(metrics, values, color=colors, edgecolor='white', linewidth=2)
ax2.set_ylim(0, 1.1)
ax2.set_ylabel('Score', fontsize=12)
ax2.set_title('Classification Metrics', fontsize=14, fontweight='bold')
# Visual reference only: with imbalanced classes a random classifier's precision
# equals the failure rate, not 0.5, so this line is not a "random baseline".
ax2.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='0.5 reference')

# Add value labels on bars
for bar, val in zip(bars, values):
    ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
             f'{val:.3f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Print detailed interpretation
print("\n" + "="*60)
print("CONFUSION MATRIX INTERPRETATION:")
print("="*60)
print(f"  True Negatives  (TN): {tn:4d} — Correctly predicted 'survived'")
print(f"  False Positives (FP): {fp:4d} — Incorrectly predicted 'failed' (false alarm)")
print(f"  False Negatives (FN): {fn:4d} — Missed actual failures (DANGEROUS)")
print(f"  True Positives  (TP): {tp:4d} — Correctly predicted 'failed'")
print("")
print("KEY METRICS:")
print(f"  • Precision = TP/(TP+FP) = {precision:.3f}")
print(f"    'Of nodes we predicted would fail, what fraction actually did?'")
print(f"  • Recall = TP/(TP+FN) = {recall:.3f}")
print(f"    'Of nodes that actually failed, what fraction did we catch?'")
print(f"  • F1 Score = 2*P*R/(P+R) = {f1:.3f}")
print(f"    'Harmonic mean of precision and recall'")
print("")
print("FOR SAFETY-CRITICAL APPLICATIONS:")
print("  High RECALL is crucial — we must catch failures even at cost of false alarms")
print("="*60)


## Batch Inference and Patient Zero Identification

### What is "Patient Zero"?

In epidemiology, Patient Zero is the first case that starts an outbreak. In power grid cascades, **Patient Zero is the first node to fail** — the domino that triggers the chain reaction.

Identifying Patient Zero is critical for:
- **Root cause analysis**: Understanding what initiated the cascade
- **Prevention**: Reinforcing vulnerable nodes before they trigger failures
- **Simulation**: Running "what-if" scenarios with different starting points

### Our Heuristic for Finding Patient Zero

Since we have the GCN's failure probability predictions, we use this approach:

1. **Filter**: Consider only nodes that actually failed in the scenario
2. **Rank**: Sort these by the model's predicted failure probability (descending)
3. **Select**: The node with the highest probability is likely the origin

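**Intuition**: The model learned to assign high probabilities to nodes under stress. Among the nodes that actually failed, the one with the highest predicted probability therefore looked the most stressed, and we *assume* it is the cascade origin. This is a heuristic: the model has no notion of failure timing, so it never directly observes which node failed first.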

### Cascade Ordering via BFS

Once we identify Patient Zero, we use **Breadth-First Search (BFS)** from that node to establish cascade order:

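```
         Patient Zero                ← Cascade Order 1, Depth 0
             │
    ┌────────┼────────┐
    ▼        ▼        ▼
  Node A   Node B   Node C   ← Depth 1 (1-hop neighbors)
    │        │
    ▼        ▼
  Node D   Node E            ← Depth 2 (2-hop neighbors)
```

**Cascade Order** is the sequential position (1, 2, 3, …) at which a *failed* node is reached by the BFS, with Patient Zero = 1. Nodes that did not fail are traversed but not numbered, so A, B and C above would receive orders 2, 3 and 4 if they all failed. **Cascade Depth** is the shortest-path hop distance from Patient Zero. It measures topological distance along the grid, not elapsed time.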


In [None]:
# batch_inference_all_scenarios: Run inference for ALL scenarios
model.eval()
all_results = []

scenarios = telemetry_df['SCENARIO_NAME'].unique()
simulation_id = str(uuid.uuid4())[:8]
run_timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

print(f"Running batch inference for {len(scenarios)} scenarios...")
print(f"Simulation ID: {simulation_id}")

# Build NetworkX graph for cascade analysis
G = nx.Graph()
G.add_nodes_from(node_ids)
for _, row in edges_df.iterrows():
    G.add_edge(row['SRC_NODE'], row['DST_NODE'])

for scenario in scenarios:
    print(f"\nProcessing scenario: {scenario}")
    
    # Get telemetry for this scenario
    scenario_telemetry = telemetry_df[telemetry_df['SCENARIO_NAME'] == scenario]
    
    # Create features
    X_scenario = create_node_features(nodes_df, scenario_telemetry).to(device)
    
    # Run inference
    with torch.no_grad():
        predictions = model(X_scenario, edge_index).cpu().numpy()
    
    import zlib  # stable (unsalted) string hashing for per-node seeds below

    # Identify actual failures from telemetry: a node counts as FAILED in this
    # scenario if any of its telemetry rows has STATUS == 'FAILED'.
    failure_statuses = scenario_telemetry.groupby('NODE_ID')['STATUS'].apply(
        lambda x: 'FAILED' in x.values
    ).to_dict()

    # Find Patient Zero: among nodes that actually failed, the one the model
    # scores highest. Ties resolve to the first such node in idx_to_node order.
    failure_probs = [
        (node_id, predictions[idx])
        for idx, node_id in idx_to_node.items()
        if failure_statuses.get(node_id, False)
    ]
    patient_zero_id = (
        max(failure_probs, key=lambda pair: pair[1])[0] if failure_probs else None
    )

    # Cascade order and depth from a single BFS rooted at Patient Zero.
    # Failed nodes are numbered 1, 2, ... in BFS visit order; depth is their hop
    # count from Patient Zero. Every node yielded by bfs_tree is reachable, so
    # its distance is always present in `hops` (no per-node path query needed).
    # `is not None` so a falsy-but-valid node id (e.g. 0) is not skipped.
    cascade_order = {}
    cascade_depth = {}
    if patient_zero_id is not None:
        hops = nx.single_source_shortest_path_length(G, patient_zero_id)
        for node in nx.bfs_tree(G, patient_zero_id):
            if failure_statuses.get(node, False):
                cascade_order[node] = len(cascade_order) + 1
                cascade_depth[node] = hops[node]

    # Node metadata indexed once by NODE_ID instead of boolean-filtering
    # nodes_df for every node; if ids repeat, the first row wins.
    node_lookup = nodes_df.drop_duplicates('NODE_ID').set_index('NODE_ID')

    # Generate results for each node
    scenario_results = []
    for idx, node_id in idx_to_node.items():
        node_info = node_lookup.loc[node_id]
        prob = float(predictions[idx])
        is_failed = failure_statuses.get(node_id, False)
        is_patient_zero = (node_id == patient_zero_id)

        # Impact metrics are synthetic: uniform draws scaled by capacity.
        # The seed is CRC32 of the node id, so values are reproducible across
        # kernels (built-in hash() of str is salted per process), and draws come
        # from a local RandomState so numpy's global RNG state is left alone.
        if is_failed:
            capacity = node_info['CAPACITY_MW']
            node_rng = np.random.RandomState(zlib.crc32(str(node_id).encode('utf-8')))
            load_shed = capacity * node_rng.uniform(0.3, 0.8)
            customers = int(capacity * node_rng.uniform(500, 1500))
            repair_cost = capacity * node_rng.uniform(5000, 15000)
        else:
            load_shed = 0
            customers = 0
            repair_cost = 0

        # Only failed nodes reached by the BFS get an order / depth; others stay None.
        c_order = cascade_order.get(node_id)
        c_depth = cascade_depth.get(node_id)

        # Templated explanation for Patient Zero and for high-risk nodes (prob > 0.7).
        ai_explanation = None
        if is_patient_zero:
            ai_explanation = f"Identified as cascade origin. High criticality score ({node_info['CRITICALITY_SCORE']:.2f}) combined with network topology position makes this node a critical failure point."
        elif prob > 0.7:
            ai_explanation = "High failure risk due to proximity to cascade origin and load stress. Recommend preemptive load balancing."

        scenario_results.append({
            'SIMULATION_ID': f"{simulation_id}_{scenario[:3]}",
            'SCENARIO_NAME': scenario,
            'NODE_ID': node_id,
            'RUN_TIMESTAMP': run_timestamp,
            'FAILURE_PROBABILITY': prob,
            'IS_PATIENT_ZERO': is_patient_zero,
            'CASCADE_ORDER': c_order,
            'CASCADE_DEPTH': c_depth,
            'LOAD_SHED_MW': round(load_shed, 2),
            'CUSTOMERS_IMPACTED': customers,
            'REPAIR_COST': round(repair_cost, 2),
            'RISK_SCORE': round(prob * node_info['CRITICALITY_SCORE'], 4),
            'AI_EXPLANATION': ai_explanation
        })

    all_results.extend(scenario_results)

    # Summary for this scenario
    failed_count = sum(1 for r in scenario_results if r['CASCADE_ORDER'] is not None)
    patient_zero = next((r for r in scenario_results if r['IS_PATIENT_ZERO']), None)

    print(f"  - Nodes analyzed: {len(scenario_results)}")
    print(f"  - Nodes in cascade: {failed_count}")
    if patient_zero:
        print(f"  - Patient Zero: {patient_zero['NODE_ID']} (prob: {patient_zero['FAILURE_PROBABILITY']:.4f})")
# Report completion and the total number of per-node result rows collected.
print("\nBatch inference complete!")
print("Total results:", len(all_results))


In [None]:
# visualize_inference_results: Analyze and visualize batch inference results
# Convert results to DataFrame for analysis
results_df_viz = pd.DataFrame(all_results)

# Scenario order as it appears in the results. Every per-scenario panel uses
# this order, and colours are looked up by scenario name (never by position)
# so the histogram legend and the box plot always agree.
scenario_names = list(results_df_viz['SCENARIO_NAME'].unique())
scenarios_colors = {'BASE_CASE': '#3498db', 'HIGH_LOAD': '#f39c12', 'WINTER_STORM_2021': '#e74c3c'}
HIGH_RISK_THRESHOLD = 0.7  # same probability cut-off the inference loop uses for AI explanations

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# --- 1. Risk Score Distribution by Scenario ---
ax1 = axes[0, 0]
for scenario in scenario_names:
    scenario_data = results_df_viz.loc[results_df_viz['SCENARIO_NAME'] == scenario, 'RISK_SCORE']
    ax1.hist(scenario_data, bins=20, alpha=0.6, label=scenario,
             color=scenarios_colors.get(scenario, 'gray'), edgecolor='white')
ax1.set_xlabel('Risk Score (Probability × Criticality)', fontsize=11)
ax1.set_ylabel('Number of Nodes', fontsize=11)
ax1.set_title('Risk Score Distribution by Scenario', fontsize=13, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# --- 2. Failure Probability Comparison ---
ax2 = axes[0, 1]
# Box plot of failure probability by scenario
scenario_probs = [
    results_df_viz.loc[results_df_viz['SCENARIO_NAME'] == s, 'FAILURE_PROBABILITY'].values
    for s in scenario_names
]
# Tick labels are applied separately: boxplot's `labels=` keyword was renamed
# to `tick_labels` in Matplotlib 3.9 and the old name is deprecated.
bp = ax2.boxplot(scenario_probs, patch_artist=True)
ax2.set_xticklabels(scenario_names)
for patch, s in zip(bp['boxes'], scenario_names):
    patch.set_facecolor(scenarios_colors.get(s, 'gray'))
    patch.set_alpha(0.7)
ax2.set_xlabel('Scenario', fontsize=11)
ax2.set_ylabel('Failure Probability', fontsize=11)
ax2.set_title('Failure Probability Distribution by Scenario', fontsize=13, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

# --- 3. Cascade Depth Analysis (for scenarios with failures) ---
ax3 = axes[1, 0]
winter_storm = results_df_viz[
    (results_df_viz['SCENARIO_NAME'] == 'WINTER_STORM_2021') &
    (results_df_viz['CASCADE_DEPTH'].notna())
]
if len(winter_storm) > 0:
    depth_counts = winter_storm['CASCADE_DEPTH'].value_counts().sort_index()
    ax3.bar(depth_counts.index, depth_counts.values, color='#e74c3c', edgecolor='white', alpha=0.8)
    ax3.set_xlabel('Cascade Depth (hops from Patient Zero)', fontsize=11)
    ax3.set_ylabel('Number of Failed Nodes', fontsize=11)
    ax3.set_title('Cascade Propagation Depth (WINTER_STORM_2021)', fontsize=13, fontweight='bold')
    ax3.grid(True, alpha=0.3, axis='y')
else:
    ax3.text(0.5, 0.5, 'No cascade data available', ha='center', va='center', fontsize=12)
    ax3.set_title('Cascade Propagation Depth', fontsize=13, fontweight='bold')

# --- 4. Top Risk Nodes Table ---
ax4 = axes[1, 1]
ax4.axis('off')
# .copy() so the display formatting below edits an independent frame rather
# than a slice of results_df_viz.
top_risk = results_df_viz.nlargest(10, 'RISK_SCORE')[
    ['NODE_ID', 'SCENARIO_NAME', 'FAILURE_PROBABILITY', 'RISK_SCORE', 'IS_PATIENT_ZERO']
].copy()
top_risk['FAILURE_PROBABILITY'] = top_risk['FAILURE_PROBABILITY'].round(4)
top_risk['RISK_SCORE'] = top_risk['RISK_SCORE'].round(4)
top_risk['IS_PATIENT_ZERO'] = top_risk['IS_PATIENT_ZERO'].apply(lambda x: '★' if x else '')

# Create table
table = ax4.table(
    cellText=top_risk.values,
    colLabels=['Node ID', 'Scenario', 'Failure Prob', 'Risk Score', 'P0'],
    loc='center',
    cellLoc='center',
    colWidths=[0.25, 0.3, 0.2, 0.15, 0.1]
)
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1.2, 1.5)
ax4.set_title('Top 10 Highest Risk Nodes', fontsize=13, fontweight='bold', pad=20)

plt.tight_layout()
plt.show()

# Summary statistics
print("\n" + "="*60)
print("INFERENCE RESULTS SUMMARY:")
print("="*60)
for scenario in scenario_names:
    s_data = results_df_viz[results_df_viz['SCENARIO_NAME'] == scenario]
    high_risk = s_data[s_data['FAILURE_PROBABILITY'] > HIGH_RISK_THRESHOLD]
    pz_rows = s_data[s_data['IS_PATIENT_ZERO'] == True]

    print(f"\n{scenario}:")
    print(f"  • Total nodes: {len(s_data)}")
    print(f"  • High-risk nodes (prob > {HIGH_RISK_THRESHOLD}): {len(high_risk)}")
    print(f"  • Mean failure probability: {s_data['FAILURE_PROBABILITY'].mean():.4f}")
    print(f"  • Max risk score: {s_data['RISK_SCORE'].max():.4f}")
    if len(pz_rows) > 0:
        print(f"  • Patient Zero: {pz_rows.iloc[0]['NODE_ID']}")
print("="*60)


In [None]:
# write_simulation_results: Write results to Snowflake
results_df = pd.DataFrame(all_results)

# Handle None values for Snowflake.
# Building the frame turns the None entries of these columns into NaN and the
# column into float64. Mapping back to int/None with .apply() does not help,
# because pandas re-infers that result as float64 again. Casting to pandas'
# nullable Int64 keeps real integers and represents missing values as <NA>.
for int_col in ('CASCADE_ORDER', 'CASCADE_DEPTH'):
    results_df[int_col] = results_df[int_col].astype('Int64')

print(f"Writing {len(results_df)} results to SIMULATION_RESULTS table...")

# Clear existing results and insert new ones.
# NOTE: TRUNCATE and the append below are separate statements (not atomic); if
# the write fails, the table is left empty.
session.sql("TRUNCATE TABLE IF EXISTS SIMULATION_RESULTS").collect()

# Convert to Snowpark DataFrame and write
snowpark_df = session.create_dataframe(results_df)
snowpark_df.write.mode('append').save_as_table('SIMULATION_RESULTS')

# Verify write
count = session.sql("SELECT COUNT(*) FROM SIMULATION_RESULTS").collect()[0][0]
print(f"Successfully wrote {count} rows to SIMULATION_RESULTS")

# Show sample results
print("\nSample results:")
sample = session.sql("""
    SELECT SCENARIO_NAME, NODE_ID, FAILURE_PROBABILITY, IS_PATIENT_ZERO, CASCADE_ORDER
    FROM SIMULATION_RESULTS
    WHERE FAILURE_PROBABILITY > 0.5
    ORDER BY SCENARIO_NAME, FAILURE_PROBABILITY DESC
    LIMIT 10
""").to_pandas()
print(sample)


In [None]:
# save_model_artifacts: Save model metadata to Snowflake
import json


def _py_scalar(value):
    """Unwrap a numpy/torch scalar into a built-in Python number.

    json.dumps rejects numpy/torch scalar types (e.g. a float32 loss or a
    0-d tensor). Built-in ints/floats have no ``.item`` and pass through unchanged.
    """
    return value.item() if hasattr(value, 'item') else value


def _sql_string_literal(text):
    """Render ``text`` as a Snowflake single-quoted string literal.

    Inside single-quoted Snowflake strings the backslash starts an escape
    sequence, so backslashes are doubled first and then single quotes are
    doubled. Without this, a quote or backslash in the value (for example an
    escaped quote inside a JSON string) would corrupt or break the statement.
    """
    return "'" + str(text).replace('\\', '\\\\').replace("'", "''") + "'"


# Training metrics (coerced to built-in Python types so they serialise to JSON)
metrics = {
    'final_loss': round(_py_scalar(losses[-1]), 6),
    'num_epochs': _py_scalar(num_epochs),
    'learning_rate': _py_scalar(learning_rate),
    'hidden_dim': 64,  # NOTE(review): literal, not read from the model; confirm it matches
    'num_nodes': _py_scalar(num_nodes),
    'num_edges': edge_index.shape[1] // 2,  # edge_index stores each undirected edge twice
    'random_seed': RANDOM_SEED
}

artifact_id = f"gcn_cascade_{simulation_id}"
model_version = '1.0.0'
metrics_json = json.dumps(metrics)

# Insert model artifact record using SELECT with PARSE_JSON.
# Dynamic values are escaped as string literals rather than spliced raw.
# NOTE(review): TRAINING_SCENARIOS is hard-coded to WINTER_STORM_2021.
session.sql(f"""
    INSERT INTO MODEL_ARTIFACTS (ARTIFACT_ID, MODEL_NAME, VERSION, TRAINING_SCENARIOS, METRICS, STATUS)
    SELECT
        {_sql_string_literal(artifact_id)},
        'CascadeGCN',
        {_sql_string_literal(model_version)},
        'WINTER_STORM_2021',
        PARSE_JSON({_sql_string_literal(metrics_json)}),
        'ACTIVE'
""").collect()

print(f"Model artifact saved: {artifact_id}")
print(f"Metrics: {metrics}")


In [None]:
# display_training_summary: Final summary of training and inference
# Prints a fixed-width report: configuration taken from kernel variables, then
# per-scenario aggregates and the WINTER_STORM_2021 Patient Zero read back
# from SIMULATION_RESULTS.
_rule = "=" * 60

# NOTE(review): architecture, hidden dimension and input-feature count below are
# literals, not read from the model object; confirm they match the model definition.
_config_lines = [
    _rule,
    "GRIDGUARD - TRAINING & INFERENCE SUMMARY",
    _rule,
    "",
    "Model Configuration:",
    "  - Architecture: GCN (3 layers)",
    "  - Hidden dimension: 64",
    "  - Input features: 10",
    f"  - Training epochs: {num_epochs}",
    f"  - Final loss: {losses[-1]:.6f}",
    "",
    "Data Summary:",
    f"  - Grid nodes: {num_nodes}",
    f"  - Grid edges: {edge_index.shape[1] // 2}",
    f"  - Scenarios processed: {len(scenarios)}",
    "",
    "Results by Scenario:",
]
print("\n".join(_config_lines))

# Per-scenario aggregates over everything written to SIMULATION_RESULTS.
_scenario_summary_sql = """
    SELECT 
        SCENARIO_NAME,
        COUNT(*) AS NODES,
        SUM(CASE WHEN CASCADE_ORDER IS NOT NULL THEN 1 ELSE 0 END) AS CASCADE_NODES,
        ROUND(SUM(LOAD_SHED_MW), 0) AS TOTAL_LOAD_SHED_MW,
        SUM(CUSTOMERS_IMPACTED) AS TOTAL_CUSTOMERS,
        ROUND(SUM(REPAIR_COST), 0) AS TOTAL_REPAIR_COST
    FROM SIMULATION_RESULTS
    GROUP BY SCENARIO_NAME
    ORDER BY SCENARIO_NAME
"""
summary = session.sql(_scenario_summary_sql).to_pandas()
print(summary.to_string(index=False))

print("\nPatient Zero Identification (WINTER_STORM_2021):")
_patient_zero_sql = """
    SELECT NODE_ID, FAILURE_PROBABILITY, AI_EXPLANATION
    FROM SIMULATION_RESULTS
    WHERE SCENARIO_NAME = 'WINTER_STORM_2021' AND IS_PATIENT_ZERO = TRUE
"""
patient_zero = session.sql(_patient_zero_sql).to_pandas()

# Only the first flagged row is reported when several match.
if not patient_zero.empty:
    pz = patient_zero.iloc[0]
    print("\n".join([
        f"  - Node ID: {pz['NODE_ID']}",
        f"  - Failure Probability: {pz['FAILURE_PROBABILITY']:.4f}",
        f"  - Explanation: {pz['AI_EXPLANATION']}",
    ]))

print("\n".join([
    "",
    _rule,
    "Training and inference complete!",
    "Results available in SIMULATION_RESULTS table.",
    _rule,
]))
