# GNN Supply Chain Risk Analysis

This notebook implements a **Graph Neural Network (GNN)** for supply chain risk analysis using PyTorch Geometric.

## Why Graph Neural Networks for Supply Chain?

Traditional risk models treat suppliers independently, missing critical **network effects**:
- A Tier-2 supplier failure can cascade through multiple Tier-1 suppliers
- Geographic concentration creates correlated risks
- Hidden dependencies exist in trade data that aren't in master data

**GNNs address this** by learning from the graph structure itself. Each node (vendor, material, region, external supplier) learns a representation that incorporates information from its neighbors through **message passing**.

## Key Concepts

| Term | Definition |
|------|------------|
| **Heterogeneous Graph** | A graph with multiple node types (Vendor, Material, Region) and edge types (SUPPLIES, LOCATED_IN) |
| **GraphSAGE** | "Graph Sample and Aggregate" - a GNN that learns by sampling neighbors and aggregating their features |
| **Link Prediction** | Predicting whether an edge should exist between two nodes (used to discover hidden Tier-2 relationships) |
| **Node Embedding** | A learned vector representation of a node that captures its position and role in the network |
| **Message Passing** | The mechanism by which nodes exchange information with neighbors across edges |

## Objectives

1. **Build a heterogeneous graph** from supply chain data (Vendors, Materials, Regions, External Suppliers)
2. **Train a GraphSAGE model** for link prediction to discover hidden Tier-2 dependencies
3. **Propagate risk scores** through the network using learned embeddings
4. **Identify bottlenecks** (single points of failure) where many vendors depend on one external supplier
5. **Write results** back to Snowflake tables for downstream analysis
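Objective 4 comes down to counting, for each external supplier, how many distinct vendors depend on it. Below is a minimal plain-Python sketch over observed SHIPS_TO pairs. `find_bottlenecks`, its `min_vendors` threshold, and the sample names are hypothetical. The notebook's own bottleneck logic may also draw on predicted links, which this sketch ignores.

```python
from collections import defaultdict

def find_bottlenecks(ships_to_pairs, min_vendors=2):
    """Return (external_supplier, n_vendors) for suppliers feeding >= min_vendors vendors.

    ships_to_pairs: iterable of (external_supplier, vendor) tuples, e.g. derived
    from trade-data SHIPS_TO edges (hypothetical input format).
    """
    vendors_per_supplier = defaultdict(set)
    for supplier, vendor in ships_to_pairs:
        vendors_per_supplier[supplier].add(vendor)  # a set ignores repeat shipments

    # Most concentrated suppliers first; ties broken alphabetically
    return sorted(
        ((s, len(v)) for s, v in vendors_per_supplier.items() if len(v) >= min_vendors),
        key=lambda item: (-item[1], item[0]),
    )

pairs = [
    ("ACME_METALS", "VENDOR_001"),
    ("ACME_METALS", "VENDOR_002"),
    ("ACME_METALS", "VENDOR_002"),  # repeat shipment, counted once
    ("ACME_METALS", "VENDOR_003"),
    ("BOLT_CO", "VENDOR_001"),
]
print(find_bottlenecks(pairs))  # [('ACME_METALS', 3)]
```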


## 1. Environment Setup

Install PyTorch Geometric and its dependencies. The cell uses `os.system` rather than `!pip` so that it also runs during headless execution on Snowpark Container Services (SPCS).


In [None]:
import sys
import os

# =============================================================================
# PACKAGE INSTALLATION
# =============================================================================
# PyTorch Geometric (PyG) is the leading library for deep learning on graphs.
# It provides efficient implementations of GNN layers like GraphSAGE, GAT, GCN.
#
# Key packages:
# - torch: PyTorch deep learning framework (tensors, autograd, neural networks)
# - torch-geometric: GNN layers, graph data structures, and utilities
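# - torch-scatter: Efficient scatter operations for aggregating neighbor messages
# - torch-sparse: Sparse matrix operations for large graphs
#   (both are optional in recent PyG releases and need wheels matching the
#   installed torch/CUDA build; a plain pip install may fall back to a source build)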
#
# We use os.system() instead of !pip because Snowflake notebook headless 
# execution doesn't support shell commands via the ! prefix.
# =============================================================================

packages = [
    "torch",
    "torch-geometric",
    "torch-scatter",
    "torch-sparse",
]

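failed = []
for pkg in packages:
    print(f"Installing {pkg}...")
    # os.system returns the command's exit status; non-zero means pip failed
    if os.system(f"{sys.executable} -m pip install {pkg} -q") != 0:
        failed.append(pkg)

if failed:
    print(f"[WARN] Failed to install: {', '.join(failed)}")
else:
    print("[OK] Package installation complete")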
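Because `os.system` does not raise when pip fails, it can help to confirm the modules are importable before continuing. Below is a minimal standard-library sketch. The `IMPORT_NAMES` mapping and the `missing_modules` helper are illustrative additions, not part of the original notebook.

```python
import importlib.util

# pip package names differ from import names (e.g. torch-geometric -> torch_geometric)
IMPORT_NAMES = {
    "torch": "torch",
    "torch-geometric": "torch_geometric",
    "torch-scatter": "torch_scatter",
    "torch-sparse": "torch_sparse",
}

def missing_modules(import_names):
    """Return the top-level import names that cannot be found in this environment."""
    return [name for name in import_names if importlib.util.find_spec(name) is None]

missing = missing_modules(IMPORT_NAMES.values())
print("Missing modules:", missing if missing else "none")
```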


In [None]:
# =============================================================================
# LIBRARY IMPORTS
# =============================================================================

# Standard data science stack
import numpy as np
import pandas as pd
from datetime import datetime
from collections import defaultdict
import json
import warnings
warnings.filterwarnings('ignore')

# -----------------------------------------------------------------------------
# PyTorch: The deep learning framework
# - torch.nn: Neural network modules (layers, loss functions)
# - torch.nn.functional (F): Stateless functions (ReLU, dropout, softmax)
# - torch.optim: Optimizers (Adam, SGD) for gradient descent
# -----------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# -----------------------------------------------------------------------------
# PyTorch Geometric: Graph Neural Network library
# - HeteroData: Data structure for heterogeneous graphs (multiple node/edge types)
# - SAGEConv: GraphSAGE convolution layer - aggregates neighbor features
# - HeteroConv: Wrapper to apply different convolutions per edge type
# - Linear: Linear transformation layer for projecting features
# - ToUndirected: Adds reverse edges (A->B becomes A<->B) for bidirectional 
#                 message passing - critical for learning from relationships
# -----------------------------------------------------------------------------
from torch_geometric.data import HeteroData
from torch_geometric.nn import SAGEConv, HeteroConv, Linear
from torch_geometric.transforms import ToUndirected
import torch_geometric.transforms as T

# Visualization libraries for model diagnostics
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Snowflake session for data access
from snowflake.snowpark.context import get_active_session

# Report compute environment
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


In [None]:
# Get Snowflake session
session = get_active_session()

# Configuration
DATABASE = "GNN_SUPPLY_CHAIN_RISK"
SCHEMA = "GNN_SUPPLY_CHAIN_RISK"

# Version tracking enables A/B testing different model configs and rolling back if a deployment fails
MODEL_VERSION = "v1.0.0"

# Set context
session.sql(f"USE DATABASE {DATABASE}").collect()
session.sql(f"USE SCHEMA {SCHEMA}").collect()

print(f"[OK] Connected to {DATABASE}.{SCHEMA}")


## 2. Load Data from Snowflake

Load the supply chain data from Snowflake tables into pandas DataFrames.


In [None]:
# Load data from Snowflake tables
print("Loading data from Snowflake...")

vendors_df = session.table("VENDORS").to_pandas()
materials_df = session.table("MATERIALS").to_pandas()
purchase_orders_df = session.table("PURCHASE_ORDERS").to_pandas()
bom_df = session.table("BILL_OF_MATERIALS").to_pandas()

# Trade (shipping) records can reveal hidden Tier-2 suppliers that are missing from master data
trade_data_df = session.table("TRADE_DATA").to_pandas()
regions_df = session.table("REGIONS").to_pandas()

print(f"[OK] Loaded data:")
print(f"  - Vendors: {len(vendors_df)}")
print(f"  - Materials: {len(materials_df)}")
print(f"  - Purchase Orders: {len(purchase_orders_df)}")
print(f"  - BOM: {len(bom_df)}")
print(f"  - Trade Data: {len(trade_data_df)}")
print(f"  - Regions: {len(regions_df)}")


## 3. Build Heterogeneous Graph

A **heterogeneous graph** has multiple types of nodes and edges. This is essential for supply chain modeling where relationships are diverse.

### Graph Schema

```
┌──────────────┐     SUPPLIES      ┌──────────────┐
│    VENDOR    │ ─────────────────▶│   MATERIAL   │
│  (Tier 1-3)  │                   │  (Parts/Raw) │
└──────────────┘                   └──────────────┘
       │                                  │
       │ LOCATED_IN                       │ COMPONENT_OF
       ▼                                  ▼
┌──────────────┐                   ┌──────────────┐
│    REGION    │                   │   MATERIAL   │
│  (Country)   │                   │   (Parent)   │
└──────────────┘                   └──────────────┘
       ▲
       │ (inferred via trade)
┌──────────────┐     SHIPS_TO      ┌──────────────┐
│   EXTERNAL   │ ─────────────────▶│    VENDOR    │
│  (Shipper)   │                   │              │
└──────────────┘                   └──────────────┘
```

### Why This Matters

- **SUPPLIES** edges: Which vendors supply which materials (from purchase orders)
- **COMPONENT_OF** edges: Bill of materials hierarchy (child → parent)
- **LOCATED_IN** edges: Geographic risk exposure (vendor ‚Üí country)
- **SHIPS_TO** edges: Trade data reveals hidden Tier-2 relationships

The GNN will learn to propagate information across ALL these edge types simultaneously.


In [None]:
# =============================================================================
# NODE MAPPINGS: Converting IDs to Graph Indices
# =============================================================================
# PyTorch Geometric represents nodes as contiguous integer indices [0, 1, 2, ...]
# We need to map business IDs (like "VENDOR_001") to these indices.
#
# This is analogous to creating a vocabulary in NLP - each unique entity gets
# a unique integer ID that we'll use throughout the model.
# =============================================================================

def create_node_mappings(vendors_df, materials_df, regions_df, trade_data_df):
    """
    Create mappings from business IDs to contiguous graph indices.
    
    Returns a dict of dicts, one per node type:
        mappings['vendor']['VENDOR_001'] = 0  # ID to index
        
    Reverse (index -> ID) mappings are built later when outputting results.
    """
    # Each unique vendor gets an index 0, 1, 2, ...
    vendor_to_idx = {v: i for i, v in enumerate(vendors_df['VENDOR_ID'].unique())}
    
    # Same for materials (parts, raw materials, finished goods)
    material_to_idx = {m: i for i, m in enumerate(materials_df['MATERIAL_ID'].unique())}
    
    # Regions (countries) for geographic risk
    region_to_idx = {r: i for i, r in enumerate(regions_df['REGION_CODE'].unique())}
    
    # External suppliers are discovered from trade/shipping data
    # These are potential Tier-2+ suppliers not in our master vendor list
    # This is the key innovation - discovering unknown suppliers from shipping patterns
    external_suppliers = trade_data_df['SHIPPER_NAME'].unique()
    external_to_idx = {e: i for i, e in enumerate(external_suppliers)}
    
    return {
        'vendor': vendor_to_idx,
        'material': material_to_idx,
        'region': region_to_idx,
        'external': external_to_idx
    }

mappings = create_node_mappings(vendors_df, materials_df, regions_df, trade_data_df)

print("Node mappings created:")
print("=" * 40)
for k, v in mappings.items():
    print(f"  {k:12s}: {len(v):5d} nodes")
print("=" * 40)
print(f"  TOTAL      : {sum(len(v) for v in mappings.values()):5d} nodes")
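The docstring defers the reverse direction (index back to business ID) to the output stage. Below is a small plain-Python sketch of building both directions at once. `build_mappings` is a hypothetical helper. It relies on `dict.fromkeys` keeping first-seen order, which matches how pandas `.unique()` orders its results.

```python
def build_mappings(ids):
    """Map each unique business ID to a contiguous index, preserving first-seen order."""
    # dict.fromkeys de-duplicates while keeping insertion order
    to_idx = {business_id: i for i, business_id in enumerate(dict.fromkeys(ids))}
    # Reverse mapping: graph index -> business ID, needed when writing results back
    to_id = {i: business_id for business_id, i in to_idx.items()}
    return to_idx, to_id

vendor_to_idx, idx_to_vendor = build_mappings(["VENDOR_001", "VENDOR_007", "VENDOR_001", "VENDOR_003"])
print(vendor_to_idx)     # {'VENDOR_001': 0, 'VENDOR_007': 1, 'VENDOR_003': 2}
print(idx_to_vendor[1])  # VENDOR_007
```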


### Understanding Node Features in Graph Neural Networks

Unlike traditional ML where each sample is independent, in GNNs the **features are just the starting point**. The model will transform these initial features by combining them with neighbor information.

**What makes good node features?**

1. **Intrinsic properties**: Attributes of the node itself (financial health, criticality score)
2. **Categorical indicators**: One-hot encodings help the model learn type-specific patterns
3. **Normalized scales**: All features should be in similar ranges (typically [0,1]) so no single feature dominates

**A key insight**: The GNN will learn to combine features from connected nodes. A vendor's final embedding will incorporate:
- Its own features (financial health, tier)
- Features of materials it supplies (criticality)
- Features of its region (geopolitical risk)
- Features of external suppliers shipping to it (volume)

This is why even simple initial features can produce powerful representations - the graph structure provides the context.
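Principles 2 and 3 can be illustrated with two tiny helpers. This is a sketch for illustration only. The notebook itself divides tier by a fixed constant and shipment volume by its maximum rather than min-max scaling, so `min_max_scale` is an alternative. Both helper names are hypothetical.

```python
def min_max_scale(values):
    """Rescale numbers to [0, 1]; a constant column maps to all zeros."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]

def one_hot(category, categories):
    """One-hot encode `category` against an ordered list of known categories."""
    return [1.0 if c == category else 0.0 for c in categories]

print(min_max_scale([200.0, 1200.0, 700.0]))    # [0.0, 1.0, 0.5]
print(one_hot("SEMI", ["RAW", "SEMI", "FIN"]))  # [0.0, 1.0, 0.0]
```

An unknown category yields an all-zero vector, which keeps it from being silently confused with a real class.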


In [None]:
# =============================================================================
# NODE FEATURES: Initial Representations for Each Node
# =============================================================================
# Each node needs an initial feature vector. The GNN will transform these
# through message passing, but we need to start with meaningful attributes.
#
# Feature Engineering Choices:
# - Continuous scores are assumed to already lie in [0, 1] in the source
#   tables; tier and shipment statistics are rescaled to [0, 1] below
# - Categorical features use one-hot encoding
# - Different node types can have different feature dimensions
#   (the model will project them to a common hidden dimension)
#
# Keeping features in [0, 1] prevents large-magnitude features from dominating gradient updates
# =============================================================================

def create_node_features(vendors_df, materials_df, regions_df, trade_data_df, mappings):
    """
    Create feature tensors for each node type.
    
    Feature matrices have shape [num_nodes, num_features].
    Each row is one node's feature vector.
    """
    num_vendors = len(mappings['vendor'])
    num_regions = len(mappings['region'])
    
    # -------------------------------------------------------------------------
    # VENDOR FEATURES: [financial_health, tier_normalized, one-hot_country...]
    # -------------------------------------------------------------------------
    # - Financial health (0-1): FINANCIAL_HEALTH_SCORE from the VENDORS table, used as-is
    # - Tier normalized (0-1): Tier 1 = 0.33, Tier 2 = 0.67, Tier 3 = 1.0
    # - Country one-hot: Binary indicator for each region (captures geography)
    vendor_features = torch.zeros(num_vendors, 2 + num_regions)
    for _, row in vendors_df.iterrows():
        idx = mappings['vendor'][row['VENDOR_ID']]
        vendor_features[idx, 0] = row['FINANCIAL_HEALTH_SCORE']
        vendor_features[idx, 1] = row['TIER'] / 3.0  # Normalize tier to [0, 1]
        if row['COUNTRY_CODE'] in mappings['region']:
            region_idx = mappings['region'][row['COUNTRY_CODE']]
            vendor_features[idx, 2 + region_idx] = 1.0  # One-hot encoding
    
    # -------------------------------------------------------------------------
    # MATERIAL FEATURES: [criticality, material_group_one_hot]
    # -------------------------------------------------------------------------
    # - Criticality score: How essential is this material? (0 = commodity, 1 = critical)
    # - Material group: RAW (raw materials), SEMI (semi-finished), FIN (finished goods)
    num_materials = len(mappings['material'])
    material_features = torch.zeros(num_materials, 4)
    group_map = {'RAW': 0, 'SEMI': 1, 'FIN': 2}
    for _, row in materials_df.iterrows():
        idx = mappings['material'][row['MATERIAL_ID']]
        material_features[idx, 0] = row['CRITICALITY_SCORE']
        group = row['MATERIAL_GROUP']
        if group in group_map:  # unknown groups stay all-zero rather than being mislabelled RAW
            material_features[idx, 1 + group_map[group]] = 1.0
    
    # -------------------------------------------------------------------------
    # REGION FEATURES: [base_risk, geopolitical, natural_disaster, infrastructure]
    # -------------------------------------------------------------------------
    # Multi-dimensional risk profile for each country/region
    region_features = torch.zeros(num_regions, 4)
    for _, row in regions_df.iterrows():
        if row['REGION_CODE'] in mappings['region']:
            idx = mappings['region'][row['REGION_CODE']]
            region_features[idx, 0] = row['BASE_RISK_SCORE']       # Overall risk
            region_features[idx, 1] = row['GEOPOLITICAL_RISK']     # Political stability
            region_features[idx, 2] = row['NATURAL_DISASTER_RISK'] # Climate/earthquake
            region_features[idx, 3] = row['INFRASTRUCTURE_SCORE']  # Logistics quality
    
    # -------------------------------------------------------------------------
    # EXTERNAL SUPPLIER FEATURES: [relative_volume, shipment_share]
    # -------------------------------------------------------------------------
    # These are discovered from trade data, so we have less info than internal vendors
    num_external = len(mappings['external'])
    external_features = torch.zeros(num_external, 2)
    shipper_stats = trade_data_df.groupby('SHIPPER_NAME').agg({
        'WEIGHT_KG': 'sum',    # Total volume shipped
        'BOL_ID': 'count'      # Number of shipments
    }).reset_index()
    max_volume = shipper_stats['WEIGHT_KG'].max() if len(shipper_stats) > 0 else 1
    
    for _, row in shipper_stats.iterrows():
        if row['SHIPPER_NAME'] in mappings['external']:
            idx = mappings['external'][row['SHIPPER_NAME']]
            # Total shipped weight relative to the largest shipper (0-1)
            external_features[idx, 0] = row['WEIGHT_KG'] / max_volume if max_volume > 0 else 0
            # Share of all trade records attributed to this shipper
            external_features[idx, 1] = row['BOL_ID'] / len(trade_data_df)
    
    return {
        'vendor': vendor_features, 
        'material': material_features, 
        'region': region_features, 
        'external': external_features
    }

node_features = create_node_features(vendors_df, materials_df, regions_df, trade_data_df, mappings)

print("Node features created:")
print("=" * 50)
print(f"  {'Node Type':<12s}  {'Nodes':>6s}  {'Features':>8s}  Description")
print("-" * 50)
for node_type, features in node_features.items():
    desc = {
        'vendor': 'financial + tier + region one-hot',
        'material': 'criticality + group one-hot',
        'region': 'multi-risk profile',
        'external': 'volume + frequency'
    }
    print(f"  {node_type:<12s}  {features.shape[0]:>6d}  {features.shape[1]:>8d}  {desc[node_type]}")
print("=" * 50)
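The two external-supplier features can be checked by hand on a few toy records. Below is a plain-Python sketch mirroring the `groupby` step above. It assumes one non-null bill of lading per record, since pandas `'count'` skips nulls while the notebook divides by the total row count. `shipper_features` and the sample data are hypothetical.

```python
from collections import defaultdict

def shipper_features(records):
    """records: list of (shipper_name, weight_kg) tuples, one per bill of lading.

    Returns {shipper: [volume_relative_to_max, share_of_records]}.
    """
    total_weight = defaultdict(float)
    shipment_count = defaultdict(int)
    for shipper, weight_kg in records:
        total_weight[shipper] += weight_kg
        shipment_count[shipper] += 1

    max_volume = max(total_weight.values(), default=0.0)
    n_records = len(records)
    return {
        shipper: [
            total_weight[shipper] / max_volume if max_volume > 0 else 0.0,
            shipment_count[shipper] / n_records,
        ]
        for shipper in total_weight
    }

records = [("ACME", 500.0), ("ACME", 500.0), ("BOLT", 250.0), ("BOLT", 0.0)]
print(shipper_features(records))  # {'ACME': [1.0, 0.5], 'BOLT': [0.25, 0.5]}
```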


In [None]:
# =============================================================================
# EDGE INDICES: Defining Graph Connectivity
# =============================================================================
# Edges are stored in COO (Coordinate) format as a 2xN tensor:
#   edge_index[0] = source node indices
#   edge_index[1] = target node indices
#
# For heterogeneous graphs, edges are keyed by (src_type, relation, dst_type)
# tuples, e.g., ('vendor', 'supplies', 'material')
# =============================================================================

def create_edge_indices(purchase_orders_df, bom_df, trade_data_df, vendors_df, mappings):
    """
    Build edge index tensors for each relationship type.
    
    Edge indices are 2xN tensors where:
        edge_index[0, i] = source node index for edge i
        edge_index[1, i] = destination node index for edge i
    """
    edges = {}
    
    # -------------------------------------------------------------------------
    # SUPPLIES edges: Vendor --> Material
    # -------------------------------------------------------------------------
    # Created from purchase orders: if vendor V supplied material M, add edge V->M
    # This captures the direct Tier-1 supply relationships
    supplies_src, supplies_dst = [], []
    for _, row in purchase_orders_df.iterrows():
        if row['VENDOR_ID'] in mappings['vendor'] and row['MATERIAL_ID'] in mappings['material']:
            supplies_src.append(mappings['vendor'][row['VENDOR_ID']])
            supplies_dst.append(mappings['material'][row['MATERIAL_ID']])
    edges[('vendor', 'supplies', 'material')] = torch.tensor([supplies_src, supplies_dst], dtype=torch.long)
    
    # -------------------------------------------------------------------------
    # COMPONENT_OF edges: Material --> Material (child -> parent)
    # -------------------------------------------------------------------------
    # Bill of Materials (BOM) hierarchy: raw materials -> semi-finished -> finished
    # This allows risk to propagate up the BOM tree
    component_src, component_dst = [], []
    for _, row in bom_df.iterrows():
        if row['CHILD_MATERIAL_ID'] in mappings['material'] and row['PARENT_MATERIAL_ID'] in mappings['material']:
            component_src.append(mappings['material'][row['CHILD_MATERIAL_ID']])
            component_dst.append(mappings['material'][row['PARENT_MATERIAL_ID']])
    edges[('material', 'component_of', 'material')] = torch.tensor([component_src, component_dst], dtype=torch.long)
    
    # -------------------------------------------------------------------------
    # LOCATED_IN edges: Vendor --> Region
    # -------------------------------------------------------------------------
    # Geographic relationship: links vendors to their country
    # Enables propagation of regional risk (geopolitical, natural disaster, etc.)
    located_src, located_dst = [], []
    for _, row in vendors_df.iterrows():
        if row['VENDOR_ID'] in mappings['vendor'] and row['COUNTRY_CODE'] in mappings['region']:
            located_src.append(mappings['vendor'][row['VENDOR_ID']])
            located_dst.append(mappings['region'][row['COUNTRY_CODE']])
    edges[('vendor', 'located_in', 'region')] = torch.tensor([located_src, located_dst], dtype=torch.long)
    
    # -------------------------------------------------------------------------
    # SHIPS_TO edges: External --> Vendor
    # -------------------------------------------------------------------------
    # Discovered from trade/shipping data (e.g., import records, BOL data)
    # This reveals hidden Tier-2+ suppliers that aren't in master vendor data
    # Key insight: if external supplier E ships to vendor V, E is likely V's supplier
    # These edges become our training signal - we learn to predict them, then find missing ones
    ships_src, ships_dst = [], []
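    # Normalize names (str() guards against missing values) so matching ignores case and surrounding whitespace
    vendor_name_to_id = {str(row['NAME']).strip().upper(): row['VENDOR_ID'] for _, row in vendors_df.iterrows()}
    for _, row in trade_data_df.iterrows():
        shipper = row['SHIPPER_NAME']
        consignee = str(row['CONSIGNEE_NAME']).strip().upper()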
        if shipper in mappings['external'] and consignee in vendor_name_to_id:
            vendor_id = vendor_name_to_id[consignee]
            if vendor_id in mappings['vendor']:
                ships_src.append(mappings['external'][shipper])
                ships_dst.append(mappings['vendor'][vendor_id])
    edges[('external', 'ships_to', 'vendor')] = torch.tensor([ships_src, ships_dst], dtype=torch.long)
    
    return edges

edge_indices = create_edge_indices(purchase_orders_df, bom_df, trade_data_df, vendors_df, mappings)

print("Edge indices created:")
print("=" * 60)
print(f"  {'Edge Type':<45s}  {'Edges':>8s}")
print("-" * 60)
for edge_type, edge_index in edge_indices.items():
    edge_str = f"{edge_type[0]} --[{edge_type[1]}]--> {edge_type[2]}"
    print(f"  {edge_str:<45s}  {edge_index.shape[1]:>8d}")
print("=" * 60)
print(f"  {'TOTAL EDGES':<45s}  {sum(e.shape[1] for e in edge_indices.values()):>8d}")
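The cell above appends one column per purchase order or trade record, so a vendor that ordered the same material several times contributes parallel edges. With mean aggregation, a repeated neighbor is weighted proportionally more. Whether to keep that is a modeling choice. Below is a NumPy sketch of the COO layout and of collapsing duplicates, using made-up indices. In PyTorch, `torch.unique(edge_index, dim=1)` or `torch_geometric.utils.coalesce` serve the same purpose.

```python
import numpy as np

# Toy SUPPLIES edges in COO format: row 0 = source (vendor), row 1 = target (material).
# One column per purchase order, so the same vendor/material pair can repeat.
edge_index = np.array([
    [0, 0, 1, 0],   # vendor indices
    [2, 2, 0, 1],   # material indices
])

# np.unique over columns keeps one copy of each (src, dst) pair, sorted lexicographically
unique_edges, counts = np.unique(edge_index, axis=1, return_counts=True)
print(unique_edges)   # [[0 0 1]
                      #  [1 2 0]]
print(counts)         # [1 2 1]  -> (vendor 0, material 2) appeared twice
```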


In [None]:
# =============================================================================
# BUILD HETERODATA OBJECT
# =============================================================================
# HeteroData is PyG's container for heterogeneous graphs. It stores:
#   - data['node_type'].x = node feature matrix
#   - data[('src', 'relation', 'dst')].edge_index = edge connectivity
#
# ToUndirected() adds reverse edges. For relations between different node types
# it creates a new 'rev_*' edge type:
#   e.g., ('vendor', 'supplies', 'material') → ('material', 'rev_supplies', 'vendor')
#
# Why undirected? In GNNs, message passing happens along edges. If we only have
# A→B edges, information can only flow from A to B. By adding B→A, we enable
# bidirectional information flow, which is critical for learning representations
# that capture the full graph context.
# =============================================================================

def build_hetero_graph(node_features, edge_indices):
    """
    Assemble the HeteroData object from features and edges.
    
    This is the core data structure that gets fed to the GNN.
    """
    data = HeteroData()
    
    # Add node features (each node type has its own feature matrix)
    data['vendor'].x = node_features['vendor']
    data['material'].x = node_features['material']
    data['region'].x = node_features['region']
    data['external'].x = node_features['external']
    
    # Add edge indices (only if edges exist)
    for edge_type, edge_index in edge_indices.items():
        if edge_index.shape[1] > 0:
            data[edge_type].edge_index = edge_index
    
    return data

data = build_hetero_graph(node_features, edge_indices)

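# Add reverse edges for bidirectional message passing (A→B becomes A↔B).
# merge=False gives the material→material relation its own 'rev_component_of'
# edge type, matching the conv keys declared in HeteroGraphSAGE; with the
# default merge=True that relation would be made undirected in place instead.
data = ToUndirected(merge=False)(data)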

print("\n" + "=" * 60)
print("HETEROGENEOUS GRAPH CONSTRUCTED")
print("=" * 60)
print(data)
print("\nNote: 'rev_*' edge types are reverse edges added by ToUndirected()")
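Below is a NumPy sketch of what the reverse-edge step produces, including the same-node-type case that `ToUndirected` handles differently depending on its `merge` argument. `add_reverse_edges` is hypothetical. It mimics the `rev_` naming convention but is not PyG's implementation.

```python
import numpy as np

def add_reverse_edges(edge_dict, merge_same_type=False):
    """edge_dict maps (src_type, relation, dst_type) -> 2 x E int array in COO format."""
    out = dict(edge_dict)
    for (src, rel, dst), ei in edge_dict.items():
        flipped = ei[::-1].copy()  # swap the source and target rows
        if src == dst and merge_same_type:
            # Same node type: fold reverse edges into the original relation (deduplicated)
            out[(src, rel, dst)] = np.unique(np.concatenate([ei, flipped], axis=1), axis=1)
        else:
            out[(dst, f"rev_{rel}", src)] = flipped
    return out

edges = {
    ("vendor", "supplies", "material"): np.array([[0, 1], [2, 2]]),
    ("material", "component_of", "material"): np.array([[2], [5]]),
}
for key, ei in add_reverse_edges(edges).items():
    print(key, ei.tolist())
```

With `merge_same_type=False` the BOM relation gets its own `rev_component_of` type, so child→parent and parent→child messages can learn separate weights.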


In [None]:
# =============================================================================
# GRAPH STRUCTURE VISUALIZATION
# =============================================================================
# Visualize the graph statistics to understand the network structure
# Imbalanced node/edge counts can cause learning issues - catch problems early before training

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Plot 1: Node counts by type
node_counts = {k: len(v) for k, v in mappings.items()}
colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']
axes[0].bar(node_counts.keys(), node_counts.values(), color=colors)
axes[0].set_title('Node Counts by Type', fontweight='bold')
axes[0].set_ylabel('Count')
for i, (k, v) in enumerate(node_counts.items()):
    axes[0].text(i, v + max(node_counts.values())*0.02, str(v), ha='center', fontweight='bold')

# Plot 2: Edge counts by type
edge_counts = {f"{e[0][:3]}→{e[2][:3]}\n({e[1][:6]})": idx.shape[1]
               for e, idx in edge_indices.items()}
axes[1].bar(range(len(edge_counts)), edge_counts.values(), color='#34495e')
axes[1].set_xticks(range(len(edge_counts)))
axes[1].set_xticklabels(edge_counts.keys(), fontsize=8)
axes[1].set_title('Edge Counts by Relationship', fontweight='bold')
axes[1].set_ylabel('Count')

# Plot 3: Feature dimensions by node type
feat_dims = {k: v.shape[1] for k, v in node_features.items()}
axes[2].barh(list(feat_dims.keys()), list(feat_dims.values()), color=colors)
axes[2].set_title('Feature Dimensions by Node Type', fontweight='bold')
axes[2].set_xlabel('Number of Features')
for i, (k, v) in enumerate(feat_dims.items()):
    axes[2].text(v + 0.5, i, str(v), va='center', fontweight='bold')

plt.tight_layout()
plt.savefig('/tmp/graph_structure.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nüìä Graph Density Analysis:")
total_nodes = sum(node_counts.values())
total_edges = sum(e.shape[1] for e in edge_indices.values())
print(f"   Total nodes: {total_nodes}")
print(f"   Total edges: {total_edges}")
print(f"   Average degree: {2 * total_edges / total_nodes:.2f}")
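Average degree (2E / N, since every edge touches two endpoints) summarizes the whole graph. For a heterogeneous graph, density is often more informative per relation: the fraction of possible source/destination pairs that are actually connected. Below is a small sketch with made-up counts. Both helper names are hypothetical.

```python
def relation_density(num_edges, num_src, num_dst):
    """Fraction of possible src->dst pairs that are connected for one relation."""
    possible = num_src * num_dst
    return num_edges / possible if possible > 0 else 0.0

def average_degree(total_edges, total_nodes):
    """Each edge touches two endpoints, so mean degree = 2E / N."""
    return 2 * total_edges / total_nodes if total_nodes > 0 else 0.0

# Toy numbers: 50 vendors, 200 materials, 400 SUPPLIES edges
print(relation_density(400, 50, 200))  # 0.04
print(average_degree(400, 250))        # 3.2
```

For a same-type relation such as material→material, pass the material count as both `num_src` and `num_dst`.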


## 4. Define GNN Model

### GraphSAGE: How It Works

**GraphSAGE** (Graph SAmple and aggreGatE) learns node representations by aggregating features from neighboring nodes. Unlike transductive embedding methods (e.g., DeepWalk, node2vec) that learn a fixed vector for each node seen during training, GraphSAGE learns an **aggregation function** that can generalize to unseen nodes.

**Message Passing (one layer):**
```
h_v^(l+1) = σ( W · AGGREGATE({h_u^(l) : u ∈ N(v)}) + B · h_v^(l) )
```

Where:
- `h_v^(l)` = embedding of node v at layer l
- `N(v)` = neighbors of node v  
- `AGGREGATE` = mean/sum/max of neighbor embeddings
- `W, B` = learnable weight matrices
- `σ` = activation function (ReLU)
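The update rule above can be traced numerically on a three-node toy graph. Below is a minimal NumPy sketch with mean aggregation. It uses the row-vector convention (`h @ W`), which is equivalent to `W · h` in the formula up to a transpose. The graph, weights, and `sage_layer` name are made up for illustration.

```python
import numpy as np

def sage_layer(h, neighbors, W, B):
    """One GraphSAGE layer with mean aggregation.

    h: [num_nodes, d_in] embeddings at layer l
    neighbors: dict node -> list of neighbor nodes N(v)
    W, B: [d_in, d_out] weights for the neighbor mean and for the node itself
    """
    out = np.zeros((h.shape[0], W.shape[1]))
    for v in range(h.shape[0]):
        nbrs = neighbors.get(v, [])
        # AGGREGATE = mean of neighbor embeddings (zero message for isolated nodes)
        agg = h[nbrs].mean(axis=0) if nbrs else np.zeros(h.shape[1])
        out[v] = agg @ W + h[v] @ B
    return np.maximum(out, 0.0)  # σ = ReLU

h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
neighbors = {0: [1, 2], 1: [0], 2: [0]}
W = np.eye(2)          # identity weights keep the arithmetic easy to follow
B = 0.5 * np.eye(2)
print(sage_layer(h, neighbors, W, B))
```

For node 0, the neighbor mean is `[0.5, 1.0]` and the self term is `[0.5, 0.0]`, giving `[1.0, 1.0]`.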

**Why 2 Layers?**
- Layer 1: Each node sees its immediate neighbors (1-hop)
- Layer 2: Each node sees neighbors of neighbors (2-hop)
- This captures Tier-1 AND Tier-2 relationships in one forward pass
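The number of layers sets the receptive field. After k layers, a node's embedding can only depend on nodes within k hops. Below is a breadth-first-search sketch on a made-up undirected graph. It shows that Vendor_B, which shares a shipper with Vendor_A, first becomes visible at two hops. `k_hop_neighborhood` is a hypothetical helper.

```python
from collections import deque

def k_hop_neighborhood(adj, start, k):
    """Nodes reachable from `start` in at most k hops (excluding start itself)."""
    seen = {start}
    frontier = deque([(start, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue  # do not expand beyond k hops
        for nbr in adj.get(node, ()):
            if nbr not in seen:
                seen.add(nbr)
                frontier.append((nbr, depth + 1))
    return seen - {start}

# Undirected toy graph: two vendors share one external shipper
adj = {
    "Vendor_A": ["External_Z", "Region_CN"],
    "Vendor_B": ["External_Z"],
    "External_Z": ["Vendor_A", "Vendor_B"],
    "Region_CN": ["Vendor_A"],
}
print(sorted(k_hop_neighborhood(adj, "Vendor_A", 1)))  # ['External_Z', 'Region_CN']
print(sorted(k_hop_neighborhood(adj, "Vendor_A", 2)))  # ['External_Z', 'Region_CN', 'Vendor_B']
```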

**Heterogeneous Extension:**
For each edge type, we have separate weight matrices. This allows the model to learn that "SUPPLIES" relationships matter differently than "LOCATED_IN" relationships.
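Below is a NumPy sketch of one destination node's update in the heterogeneous case. Each edge type applies its own weights `(W_r, B_r)`, and the per-relation results are then averaged, in the spirit of `HeteroConv(..., aggr='mean')` wrapping one SAGE-style convolution per edge type. This is a simplified illustration, not PyG's implementation. `hetero_sage_update` and the numbers are hypothetical.

```python
import numpy as np

def hetero_sage_update(h_self, messages_by_relation, weights):
    """h_self: [d_in]; messages_by_relation: relation -> [k, d_in] neighbor embeddings;
    weights: relation -> (W_r, B_r), each [d_in, d_out]."""
    per_relation = []
    for rel, nbr_h in messages_by_relation.items():
        W_r, B_r = weights[rel]
        # SAGE step for this edge type: neighbor mean plus its own self term
        per_relation.append(nbr_h.mean(axis=0) @ W_r + h_self @ B_r)
    # Combine edge types by averaging, then apply ReLU
    return np.maximum(np.mean(per_relation, axis=0), 0.0)

h_self = np.array([1.0, 0.0])  # a vendor's current embedding
msgs = {
    "rev_supplies": np.array([[1.0, 1.0], [3.0, 1.0]]),  # two materials it supplies
    "rev_located_in": np.array([[0.0, 4.0]]),            # its region
}
weights = {
    "rev_supplies": (np.eye(2), np.zeros((2, 2))),
    "rev_located_in": (0.5 * np.eye(2), np.eye(2)),
}
result = hetero_sage_update(h_self, msgs, weights)
print(result)  # [1.5 1.5]
```

Because the weights differ per relation, the same neighbor vector can push the vendor's embedding in different directions depending on whether it arrives via SUPPLIES or LOCATED_IN.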


### Deep Dive: How Message Passing Actually Works

This is the core of what makes GNNs powerful. Let's trace through what happens to a single vendor node:

**Layer 1 (1-hop):**
```
Vendor_A receives messages from:
  ├── Material_X (via SUPPLIES edge) → "I'm a critical component"
  ├── Material_Y (via SUPPLIES edge) → "I'm a commodity"
  ├── Region_CN (via LOCATED_IN edge) → "I have high geopolitical risk"
  └── External_Z (via SHIPS_TO edge) → "I ship large volumes"

Per edge type r: out_r = W_r × MEAN(neighbor embeddings via r) + B_r × h_A + b_r
Across edge types: final = ReLU(MEAN(out_r over all r))
```

**Layer 2 (2-hop):**
```
Now Vendor_A's neighbors have ALREADY incorporated THEIR neighbors:
  ├── Material_X now knows about OTHER vendors that supply it
  ├── Region_CN now knows about OTHER vendors in the same region
  └── External_Z now knows about OTHER vendors it ships to

When Vendor_A aggregates again, it indirectly learns about:
  - Competitors (other vendors supplying same materials)
  - Geographic peers (other vendors in same region)
  - Shared suppliers (other vendors using same external supplier)
```

**This is how hidden Tier-2 dependencies emerge**: Two vendors that share an external supplier receive overlapping messages, which tends to pull their embeddings closer together even if they have no direct connection.
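Below is a small NumPy sketch of that effect. Two vendors start with orthogonal features and share one shipper. After a single hand-weighted mean-aggregation step, their cosine similarity rises from 0 to 0.5. The update is untrained and made up (half self, half neighbor mean), so it shows only the mechanism. A trained model's weights can still keep such vendors apart.

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Made-up layer-0 features: the two vendors start out orthogonal
h = {
    "Vendor_A": np.array([1.0, 0.0, 0.0]),
    "Vendor_B": np.array([0.0, 1.0, 0.0]),
    "External_Z": np.array([0.0, 0.0, 1.0]),
}
# Both vendors receive shipments from the same external supplier
ships_to = {"Vendor_A": ["External_Z"], "Vendor_B": ["External_Z"]}

def layer1(v):
    """Untrained update: half the node's own vector, half the mean of its shippers."""
    msg = np.mean([h[u] for u in ships_to[v]], axis=0)
    return 0.5 * h[v] + 0.5 * msg

before = cosine(h["Vendor_A"], h["Vendor_B"])
after = cosine(layer1("Vendor_A"), layer1("Vendor_B"))
print(round(before, 3), round(after, 3))  # 0.0 0.5
```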


In [None]:
# =============================================================================
# HETEROGENEOUS GRAPHSAGE MODEL
# =============================================================================
# This model has three main components:
#   1. Input projection: Map each node type's features to a common dimension
#   2. Message passing layers: Aggregate information from neighbors
#   3. Task heads: Predict risk scores or link probabilities
# =============================================================================

class HeteroGraphSAGE(nn.Module):
    """
    Heterogeneous GraphSAGE model for supply chain risk analysis.
    
    Architecture:
        Input Features (varying dim) 
            → Linear projection (hidden_channels)
            → ReLU
            → GraphSAGE Conv Layer 1 (hidden_channels)
            → ReLU + Dropout
            → GraphSAGE Conv Layer 2 (out_channels)
            → Task-specific heads
    
    The model learns:
        - Node embeddings that capture graph structure
        - How risk propagates across different relationship types
        - Which external suppliers are likely connected to which vendors
    """
    
    def __init__(self, in_channels_dict, hidden_channels=64, out_channels=32):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        
        # ---------------------------------------------------------------------
        # INPUT PROJECTION LAYERS
        # ---------------------------------------------------------------------
        # Each node type has different input feature dimensions.
        # We project all of them to a common hidden dimension.
        # This is like creating a shared "embedding space" for all node types.
        self.lin_dict = nn.ModuleDict()
        for node_type, in_channels in in_channels_dict.items():
            self.lin_dict[node_type] = Linear(in_channels, hidden_channels)
        
        # ---------------------------------------------------------------------
        # FIRST GRAPH CONVOLUTION LAYER (1-hop neighborhood)
        # ---------------------------------------------------------------------
        # HeteroConv wraps separate convolutions for each edge type.
        # Each edge type has its own weights, allowing the model to learn
        # that "supplies" relationships differ from "located_in" relationships.
        #
        # aggr='mean': If a node receives messages from multiple edge types,
        # we average them. Alternatives: 'sum', 'max'.
        self.conv1 = HeteroConv({
            # Forward edges (original relationships)
            ('vendor', 'supplies', 'material'): SAGEConv(hidden_channels, hidden_channels),
            ('material', 'component_of', 'material'): SAGEConv(hidden_channels, hidden_channels),
            ('vendor', 'located_in', 'region'): SAGEConv(hidden_channels, hidden_channels),
            ('external', 'ships_to', 'vendor'): SAGEConv(hidden_channels, hidden_channels),
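            # Reverse edges (added by ToUndirected). HeteroConv skips any edge type
            # that is absent from edge_index_dict, so listing an unused one is harmless.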
            ('material', 'rev_supplies', 'vendor'): SAGEConv(hidden_channels, hidden_channels),
            ('material', 'rev_component_of', 'material'): SAGEConv(hidden_channels, hidden_channels),
            ('region', 'rev_located_in', 'vendor'): SAGEConv(hidden_channels, hidden_channels),
            ('vendor', 'rev_ships_to', 'external'): SAGEConv(hidden_channels, hidden_channels),
        }, aggr='mean')
        
        # ---------------------------------------------------------------------
        # SECOND GRAPH CONVOLUTION LAYER (2-hop neighborhood)
        # ---------------------------------------------------------------------
        # After this layer, each node's embedding contains information from
        # nodes up to 2 hops away. This captures Tier-2 relationships.
        self.conv2 = HeteroConv({
            ('vendor', 'supplies', 'material'): SAGEConv(hidden_channels, out_channels),
            ('material', 'component_of', 'material'): SAGEConv(hidden_channels, out_channels),
            ('vendor', 'located_in', 'region'): SAGEConv(hidden_channels, out_channels),
            ('external', 'ships_to', 'vendor'): SAGEConv(hidden_channels, out_channels),
            ('material', 'rev_supplies', 'vendor'): SAGEConv(hidden_channels, out_channels),
            ('material', 'rev_component_of', 'material'): SAGEConv(hidden_channels, out_channels),
            ('region', 'rev_located_in', 'vendor'): SAGEConv(hidden_channels, out_channels),
            ('vendor', 'rev_ships_to', 'external'): SAGEConv(hidden_channels, out_channels),
        }, aggr='mean')
        
        # ---------------------------------------------------------------------
        # RISK PREDICTION HEADS
        # ---------------------------------------------------------------------
        # Simple linear layers that map embeddings to a single risk score.
        # Sigmoid activation bounds output to [0, 1].
        # Note: 'region' doesn't have a risk head - we use raw region features.
        self.risk_head = nn.ModuleDict({
            'vendor': nn.Linear(out_channels, 1),
            'material': nn.Linear(out_channels, 1),
            'external': nn.Linear(out_channels, 1),
        })
    
    def forward(self, x_dict, edge_index_dict):
        """
        Forward pass: compute node embeddings.
        
        Args:
            x_dict: Dict mapping node_type -> feature tensor
            edge_index_dict: Dict mapping edge_type -> edge indices
            
        Returns:
            Dict mapping node_type -> embedding tensor
        """
        # Step 1: Project all node types to hidden dimension
        x_dict = {key: self.lin_dict[key](x) for key, x in x_dict.items()}
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        
        # Step 2: First round of message passing (1-hop)
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        
        # Dropout for regularization (only during training)
        x_dict = {key: F.dropout(x, p=0.3, training=self.training) for key, x in x_dict.items()}
        
        # Step 3: Second round of message passing (2-hop)
        x_dict = self.conv2(x_dict, edge_index_dict)
        
        return x_dict
    
    def predict_risk(self, x_dict):
        """
        Predict risk scores for each node.
        
        Returns:
            Dict mapping node_type -> risk scores in [0, 1]
        """
        risk_scores = {}
        for node_type, head in self.risk_head.items():
            if node_type in x_dict:
                # Sigmoid squashes output to [0, 1] for probability interpretation
                risk_scores[node_type] = torch.sigmoid(head(x_dict[node_type]))
        return risk_scores
    
    def predict_link(self, z_src, z_dst):
        """
        Predict probability of a link between source and destination nodes.
        
        This uses a simple dot-product decoder:
            P(edge exists) = sigmoid(z_src · z_dst)
        
        Intuition: If two nodes have similar embeddings (high dot product),
        they're likely connected.
        """
        return torch.sigmoid((z_src * z_dst).sum(dim=-1))


# =============================================================================
# MODEL INITIALIZATION
# =============================================================================
# Hyperparameters:
#   hidden_channels=64: Dimension of intermediate representations
#   out_channels=32: Dimension of final embeddings (used for link prediction)
# These are reasonable starting points for small supply chain graphs (<10K nodes);
# tune them on held-out edges if the model under- or overfits.
# =============================================================================

in_channels_dict = {key: feat.shape[1] for key, feat in node_features.items()}
model = HeteroGraphSAGE(in_channels_dict, hidden_channels=64, out_channels=32)

# Move to GPU if available (significant speedup for large graphs)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)

# Model summary
print("=" * 60)
print("MODEL ARCHITECTURE")
print("=" * 60)
print(f"Device: {device}")
print(f"Hidden dimension: {model.hidden_channels}")
print(f"Output embedding dimension: {model.out_channels}")
print(f"Number of GNN layers: 2 (captures 2-hop neighborhoods)")
print(f"Dropout rate: 0.3")
print("-" * 60)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print("=" * 60)
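
To sanity-check the "Total parameters" line, the count can be derived by hand. This is a sketch that assumes PyG's default `SAGEConv` configuration (`root_weight=True`, `project=False`), under which each conv holds a neighbor weight with bias plus a root weight without bias; the helper names and feature widths are hypothetical. All eight convs per layer are registered, and therefore counted, even if an edge type such as `rev_component_of` never appears in `edge_index_dict`.

```python
def sage_conv_params(in_ch: int, out_ch: int) -> int:
    # Assumed SAGEConv defaults: neighbor weight + bias, root weight without bias
    return (in_ch * out_ch + out_ch) + in_ch * out_ch

def expected_param_count(feature_dims, hidden=64, out=32, edge_types_per_layer=8, risk_heads=3):
    """Hand count for HeteroGraphSAGE as defined above (hypothetical helper)."""
    projections = sum(d * hidden + hidden for d in feature_dims.values())  # lin_dict
    conv1 = edge_types_per_layer * sage_conv_params(hidden, hidden)
    conv2 = edge_types_per_layer * sage_conv_params(hidden, out)
    heads = risk_heads * (out + 1)  # one nn.Linear(out_channels, 1) per risk head
    return projections + conv1 + conv2 + heads

# Hypothetical input feature widths, for illustration only
print(expected_param_count({'vendor': 8, 'material': 5, 'region': 4, 'external': 3}))  # 100707
```

Comparing this number with `total_params` for the real `in_channels_dict` is a quick structural check; a mismatch could mean the assumed `SAGEConv` defaults differ in the installed PyG version.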


## 5. Model Training

### Training Objective: Link Prediction

We train the model using **link prediction** as a self-supervised task:
- **Positive samples**: Edges that exist (External → Vendor shipments)
- **Negative samples**: Random external-vendor pairs, which are almost always non-edges (the sampler does not check)

**Loss Function**: Binary Cross-Entropy
```
L = -[y·log(p) + (1-y)·log(1-p)]
```
Where y=1 for real edges, y=0 for fake edges.
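
As a numeric check of the formula, the per-sample loss can be computed by hand. This is a stdlib sketch with made-up probabilities; `bce` is an illustrative helper (the notebook itself calls `F.binary_cross_entropy`), and the clamp simply keeps `log` away from 0.

```python
import math

def bce(p: float, y: int, eps: float = 1e-12) -> float:
    """Binary cross-entropy for one predicted probability p and label y in {0, 1}."""
    p = min(max(p, eps), 1 - eps)  # clamp so log() never sees 0
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

print(round(bce(0.9, 1), 4))  # real edge scored 0.9: 0.1054 (cheap)
print(round(bce(0.9, 0), 4))  # sampled negative scored 0.9: 2.3026 (expensive)
print(round(bce(0.5, 1), 4))  # uninformative 0.5: 0.6931 = ln 2, for either label
```

The ln 2 ≈ 0.693 value is a handy reference when reading the training log: positive and negative losses stuck near 0.69 mean the model is still scoring every pair at about 0.5.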

**Why Link Prediction?**
1. It's self-supervised (no manual labels needed)
2. Forces the model to learn meaningful embeddings
3. The trained model can predict NEW links (hidden Tier-2 relationships)

**Regularization**:
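- An L2 penalty on the embeddings (0.001 × the norm of each node type's embedding matrix) keeps their magnitude in check; without it, the dot-product decoder could push scores toward 0 or 1 simply by scaling embeddings up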
- Dropout (30%) during training


### The Self-Supervised Training Trick

A common question: **"How can we train this model without labeled risk data?"**

The answer is **link prediction as a proxy task**. Here's the insight:

1. We have some SHIPS_TO edges from trade data (external supplier → vendor)
2. We train the model to predict: "Given embeddings of External_A and Vendor_B, are they connected?"
3. The model learns embeddings where **connected nodes are similar** (high dot product)
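
The decoder behind step 3 is simple enough to check by hand. The sketch below reimplements the notebook's `predict_link` formula, sigmoid(z_src · z_dst), on plain Python lists; the 3-dimensional embeddings are made up for illustration.

```python
import math

def predict_link(z_src, z_dst):
    """sigmoid(z_src · z_dst), the same decoder as HeteroGraphSAGE.predict_link."""
    dot = sum(a * b for a, b in zip(z_src, z_dst))
    return 1.0 / (1.0 + math.exp(-dot))

external_a = [1.0, 2.0, 0.0]
vendor_b = [1.5, 1.0, 0.5]    # points the same way: dot product 3.5
vendor_c = [-1.0, -1.5, 0.0]  # points the opposite way: dot product -4.0

print(round(predict_link(external_a, vendor_b), 3))  # 0.971
print(round(predict_link(external_a, vendor_c), 3))  # 0.018
```

Training pushes connected pairs toward the first situation and sampled negatives toward the second.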

**But here's the magic**: To predict links well, the model MUST learn meaningful representations. It has to understand:
- What makes a vendor likely to receive shipments from a particular external supplier?
- What patterns in the graph indicate a supply relationship?

Once trained, these embeddings capture the **latent structure of the supply chain**. We can then:
- Use embeddings to predict risk (via the risk head)
- Find missing links (hidden dependencies)
- Identify similar nodes (potential alternative suppliers)

**The negative sampling is critical**: If we only showed positive examples (real edges), the model would learn to predict 1.0 for everything. By showing random non-edges as negatives, we force it to learn discriminative features.
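
The loop in `train_model` draws negatives uniformly with `torch.randint` and does not check them against the real edges, so an occasional "negative" is actually a positive. In sparse graphs this noise is small, but rejection sampling removes it. A stdlib sketch; `sample_negatives` is a hypothetical helper, and PyG's `torch_geometric.utils.negative_sampling` is an existing library option for the same job.

```python
import random

def sample_negatives(pos_edges, num_src, num_dst, num_neg, seed=0, max_tries=100):
    """Draw (src, dst) pairs uniformly at random, rejecting pairs that are known edges.

    pos_edges: iterable of (external_idx, vendor_idx) tuples for real SHIPS_TO edges.
    Negatives may repeat; only known positives are rejected. Gives up after
    num_neg * max_tries draws so a dense graph cannot loop forever.
    """
    rng = random.Random(seed)
    known = set(pos_edges)
    negatives = []
    attempts = 0
    while len(negatives) < num_neg and attempts < num_neg * max_tries:
        attempts += 1
        pair = (rng.randrange(num_src), rng.randrange(num_dst))
        if pair not in known:
            negatives.append(pair)
    return negatives

pos = [(0, 0), (0, 1), (1, 2)]
neg = sample_negatives(pos, num_src=3, num_dst=4, num_neg=3)
print(neg)  # three pairs, none of which appear in pos
```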


In [None]:
# =============================================================================
# MODEL TRAINING
# =============================================================================
# We use link prediction as the training objective. The model learns to
# distinguish real edges (positive samples) from fake edges (negative samples).
#
# Training Loop:
#   1. Forward pass: Compute embeddings for all nodes
#   2. Sample positive edges (real SHIPS_TO relationships)
#   3. Sample negative edges (random external-vendor pairs)
#   4. Compute loss: BCE for positive + BCE for negative + L2 regularization
#   5. Backward pass: Compute gradients
#   6. Update weights with Adam optimizer
# =============================================================================

def train_model(model, data, epochs=100, lr=0.01):
    """
    Train the GNN model using link prediction.
    
    Args:
        model: HeteroGraphSAGE model
        data: HeteroData graph object
        epochs: Number of training iterations
        lr: Learning rate for Adam optimizer
        
    Returns:
        Tuple of (total, positive, negative) loss lists, one value per epoch (for plotting)
    """
    optimizer = Adam(model.parameters(), lr=lr)
    model.train()  # Enable dropout and training mode
    
    # We train on the SHIPS_TO edges (external supplier → vendor)
    # These represent the relationships we want to predict for hidden suppliers
    edge_type = ('external', 'ships_to', 'vendor')
    pos_edge_index = data[edge_type].edge_index if edge_type in data.edge_index_dict else None
    
    losses = []
    pos_losses = []
    neg_losses = []
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        
        # Forward pass: compute node embeddings
        z_dict = model(data.x_dict, data.edge_index_dict)
        
        loss = 0
        pos_loss_val = 0
        neg_loss_val = 0
        
        if pos_edge_index is not None and pos_edge_index.shape[1] > 0:
            # -----------------------------------------------------------------
            # POSITIVE SAMPLES: Real edges that exist
            # -----------------------------------------------------------------
            # For each real edge (external[i] → vendor[j]), we want
            # the model to predict high probability (close to 1)
            z_src = z_dict['external'][pos_edge_index[0]]
            z_dst = z_dict['vendor'][pos_edge_index[1]]
            pos_pred = model.predict_link(z_src, z_dst)
            pos_loss = F.binary_cross_entropy(pos_pred, torch.ones_like(pos_pred))
            pos_loss_val = pos_loss.item()
            
            # -----------------------------------------------------------------
            # NEGATIVE SAMPLES: Random pairs (likely non-edges)
            # -----------------------------------------------------------------
            # For random pairs, we want the model to predict low probability
            # We sample the same number of negatives as positives (balanced)
            # Without negatives, the model would predict 1.0 for everything - negatives teach it to discriminate
            num_neg = pos_edge_index.shape[1]
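            # Sample on the same device as the edge tensor instead of relying on a global
            neg_src = torch.randint(0, z_dict['external'].shape[0], (num_neg,), device=pos_edge_index.device)
            neg_dst = torch.randint(0, z_dict['vendor'].shape[0], (num_neg,), device=pos_edge_index.device)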
            neg_pred = model.predict_link(z_dict['external'][neg_src], z_dict['vendor'][neg_dst])
            neg_loss = F.binary_cross_entropy(neg_pred, torch.zeros_like(neg_pred))
            neg_loss_val = neg_loss.item()
            
            loss += pos_loss + neg_loss
        
        # -----------------------------------------------------------------
        # L2 REGULARIZATION: Prevent embeddings from exploding
        # -----------------------------------------------------------------
        # Small penalty on the magnitude of all embeddings
        # Coefficient 0.001 is typical (not too strong)
        for node_type, z in z_dict.items():
            loss += 0.001 * torch.norm(z, p=2)
        
        # Backward pass and optimization step
        loss.backward()
        optimizer.step()
        
        # Track losses for visualization
        losses.append(loss.item())
        pos_losses.append(pos_loss_val)
        neg_losses.append(neg_loss_val)
        
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1:3d}/{epochs} | Total Loss: {loss.item():.4f} | "
                  f"Pos: {pos_loss_val:.4f} | Neg: {neg_loss_val:.4f}")
    
    return losses, pos_losses, neg_losses

print("=" * 60)
print("TRAINING GNN MODEL")
print("=" * 60)
print(f"Epochs: 100 | Learning Rate: 0.01 | Optimizer: Adam")
print("-" * 60)

losses, pos_losses, neg_losses = train_model(model, data, epochs=100, lr=0.01)

print("-" * 60)
print("✅ Training complete!")
print(f"   Final loss: {losses[-1]:.4f}")
print(f"   Loss reduction: {(losses[0] - losses[-1]) / losses[0] * 100:.1f}%")
print("=" * 60)


In [None]:
# =============================================================================
# TRAINING DIAGNOSTICS VISUALIZATION
# =============================================================================
# Visualize training progress to check for convergence and potential issues

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Total Loss Curve
axes[0].plot(losses, color='#2c3e50', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Total Loss')
axes[0].set_title('Training Loss Curve', fontweight='bold')
axes[0].axhline(y=losses[-1], color='#e74c3c', linestyle='--', alpha=0.7, label=f'Final: {losses[-1]:.4f}')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot 2: Positive vs Negative Loss
axes[1].plot(pos_losses, label='Positive (real edges)', color='#27ae60', linewidth=2)
axes[1].plot(neg_losses, label='Negative (fake edges)', color='#e74c3c', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('BCE Loss')
axes[1].set_title('Link Prediction Loss Components', fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Plot 3: Loss ratio (should converge to ~1 for balanced learning)
# Ratio >> 1 means positives are poorly fit (model underpredicts links); << 1 means negatives are poorly fit (model overpredicts links)
loss_ratio = [p / (n + 1e-8) for p, n in zip(pos_losses, neg_losses)]
axes[2].plot(loss_ratio, color='#9b59b6', linewidth=2)
axes[2].axhline(y=1.0, color='#34495e', linestyle='--', alpha=0.7, label='Balanced (1.0)')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Pos/Neg Loss Ratio')
axes[2].set_title('Loss Balance', fontweight='bold')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
axes[2].set_ylim(0, 3)

plt.tight_layout()
plt.savefig('/tmp/training_diagnostics.png', dpi=150, bbox_inches='tight')
plt.show()

# Training quality assessment
print("\n📊 Training Quality Assessment:")
print("-" * 40)

# Check convergence
if losses[-1] < losses[0] * 0.5:
    print("✅ Good convergence: Loss reduced by >50%")
else:
    print("⚠️  Limited convergence: Consider more epochs or lower LR")

# Check balance
final_ratio = pos_losses[-1] / (neg_losses[-1] + 1e-8)
if 0.5 < final_ratio < 2.0:
    print("✅ Balanced learning: Pos/Neg ratio near 1.0")
else:
    print(f"⚠️  Imbalanced: Pos/Neg ratio = {final_ratio:.2f}")

# Check the absolute loss level. This is a training-set check only: it can flag
# underfitting, but detecting overfitting would require a held-out edge split.
if losses[-1] < 0.5:
    print("✅ Low final loss: Model learned the task")
else:
    print("⚠️  High final loss: May need more capacity or data")


## 6. Compute Risk Scores

### Risk Scoring Approach

The GNN learns embeddings that capture each node's position in the supply chain network. We convert these embeddings to risk scores using:

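1. **Learned Risk Head**: A linear layer plus sigmoid that maps each embedding to a score in [0, 1]. Note that `train_model` optimizes only the link-prediction and L2 terms, so this head receives no gradient and keeps its random initial weights. Unless it is trained separately on risk labels, its output is a fixed projection of the learned embeddings rather than a calibrated risk estimate.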
2. **Regional Risk Blending**: Vendor risk is adjusted based on their region's risk profile

**Risk Score Formula for Vendors:**
```
final_risk = 0.6 × learned_risk + 0.4 × regional_risk
```

**Regional Risk Calculation:**
```
regional_risk = 0.3×base + 0.4×geopolitical + 0.2×natural_disaster + 0.1×(1-infrastructure)
```

The weights reflect that geopolitical risk (tariffs, sanctions, instability) tends to be more impactful than infrastructure quality.
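
A worked example of the two formulas, using hypothetical region values. The helper names are illustrative (the notebook computes the same quantities inline in `compute_risk_scores`), and the sketch assumes every factor is already scaled to [0, 1], in which case the regional score also stays in [0, 1] because the four weights sum to 1.

```python
def regional_risk(base, geopolitical, natural_disaster, infrastructure):
    # Weights from the formula above; infrastructure is inverted (poor infrastructure = high risk)
    return (0.3 * base + 0.4 * geopolitical + 0.2 * natural_disaster
            + 0.1 * (1 - infrastructure))

def blend_vendor_risk(learned, region, w_learned=0.6):
    # final_risk = 0.6 × learned_risk + 0.4 × regional_risk
    return w_learned * learned + (1 - w_learned) * region

r = regional_risk(base=0.5, geopolitical=0.8, natural_disaster=0.3, infrastructure=0.6)
print(round(r, 3))                          # 0.15 + 0.32 + 0.06 + 0.04 = 0.57
print(round(blend_vendor_risk(0.4, r), 3))  # 0.24 + 0.228 = 0.468
```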


In [None]:
# =============================================================================
# RISK SCORE COMPUTATION
# =============================================================================
# We combine learned embeddings with domain knowledge (regional risk factors)
# to produce interpretable risk scores for each node.
# =============================================================================

def compute_risk_scores(model, data, regions_df, mappings):
    """
    Compute risk scores for all nodes with region risk propagation.
    
    The model's risk head produces a raw learned risk score.
    We then blend vendor risk with their region's risk profile.
    
    Returns:
        risk_scores: Dict mapping node_type -> risk tensor [0, 1]
        z_dict: Dict mapping node_type -> embedding tensor
    """
    model.eval()  # Disable dropout for inference
    
    with torch.no_grad():
        # Forward pass to get embeddings
        z_dict = model(data.x_dict, data.edge_index_dict)
        # Apply risk prediction heads
        risk_scores = model.predict_risk(z_dict)
    
    # -------------------------------------------------------------------------
    # REGIONAL RISK CALCULATION
    # -------------------------------------------------------------------------
    # Weighted combination of region risk factors
    # Weights based on typical supply chain risk impact:
    #   - Geopolitical (40%): Tariffs, sanctions, political instability
    #   - Base risk (30%): Overall country risk assessment
    #   - Natural disasters (20%): Earthquakes, typhoons, flooding
    #   - Infrastructure (10%): Ports, roads, logistics quality (inverted)
    region_risks = {}
    for _, row in regions_df.iterrows():
        if row['REGION_CODE'] in mappings['region']:
            region_risk = (
                row['BASE_RISK_SCORE'] * 0.3 + 
                row['GEOPOLITICAL_RISK'] * 0.4 +
                row['NATURAL_DISASTER_RISK'] * 0.2 + 
                (1 - row['INFRASTRUCTURE_SCORE']) * 0.1  # Invert: low infrastructure = high risk
            )
            region_risks[mappings['region'][row['REGION_CODE']]] = region_risk
    
    # -------------------------------------------------------------------------
    # BLEND VENDOR RISK WITH REGIONAL RISK
    # -------------------------------------------------------------------------
    # Final vendor risk = 60% learned + 40% regional
    # This ensures geographic factors are incorporated
    # Pure ML risk misses domain knowledge; blending adds interpretability and captures known risk factors
    if ('vendor', 'located_in', 'region') in data.edge_index_dict:
        edge_index = data[('vendor', 'located_in', 'region')].edge_index
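        # view(-1) keeps a 1-D array even with a single vendor; copy() so every
        # blend starts from the untouched learned scores
        learned_risk = risk_scores['vendor'].view(-1).cpu().numpy().copy()
        vendor_risk = learned_risk.copy()
        
        for i in range(edge_index.shape[1]):
            vendor_idx = edge_index[0, i].item()
            region_idx = edge_index[1, i].item()
            if region_idx in region_risks:
                # Weighted blend of learned risk and regional risk. Blending from
                # learned_risk (not vendor_risk) avoids compounding when a vendor
                # has several LOCATED_IN edges; in that case the last region wins.
                vendor_risk[vendor_idx] = (
                    learned_risk[vendor_idx] * 0.6 + 
                    region_risks[region_idx] * 0.4
                )
        
        risk_scores['vendor'] = torch.tensor(vendor_risk, dtype=torch.float32).unsqueeze(1)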
    
    return risk_scores, z_dict

risk_scores, embeddings = compute_risk_scores(model, data, regions_df, mappings)

print("=" * 60)
print("RISK SCORES COMPUTED")
print("=" * 60)
print(f"{'Node Type':<15s} {'Min':>8s} {'Max':>8s} {'Mean':>8s} {'Std':>8s}")
print("-" * 60)
for node_type, scores in risk_scores.items():
    scores_np = scores.squeeze().cpu().numpy()
    print(f"{node_type:<15s} {scores_np.min():>8.3f} {scores_np.max():>8.3f} "
          f"{scores_np.mean():>8.3f} {scores_np.std():>8.3f}")
print("=" * 60)


### Inference Mode: From Embeddings to Risk Scores

Training and inference use the model differently:

| Aspect | Training | Inference |
|--------|----------|-----------|
| Mode | `model.train()` | `model.eval()` |
| Dropout | Active (30% of neurons randomly zeroed) | Disabled (all neurons active) |
| Gradients | Computed for backprop | Disabled (`torch.no_grad()`) |
| Output | Loss value | Embeddings + predictions |

**Why disable dropout at inference?** During training, dropout regularizes the model by preventing it from relying on any single activation. At inference we want deterministic predictions from the full network. Training with dropout (p = 0.3) effectively trains an ensemble of thinned sub-networks, and the eval-mode network approximates their average. Because PyTorch rescales the surviving activations by 1/(1-p) during training, no extra rescaling is needed at inference.
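
The rescaling point can be made concrete. PyTorch implements *inverted* dropout: during training each activation is zeroed with probability p and the survivors are divided by (1-p), so the expected activation is unchanged and eval mode is a plain identity. A stdlib sketch of that scheme; the `dropout` helper here is illustrative, not the notebook's code.

```python
import random

def dropout(values, p, training, rng):
    """Inverted dropout: zero with probability p, scale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(values)  # eval mode: identity, no rescaling needed
    keep = 1.0 - p
    return [v / keep if rng.random() >= p else 0.0 for v in values]

rng = random.Random(0)
ones = [1.0] * 100_000
dropped = dropout(ones, p=0.3, training=True, rng=rng)
print(sum(v == 0.0 for v in dropped) / len(dropped))  # close to 0.3
print(sum(dropped) / len(dropped))                    # close to 1.0: mean preserved
print(dropout([2.0, 3.0], p=0.3, training=False, rng=rng))  # [2.0, 3.0]
```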

**The risk blending formula** deserves attention:
```
final_risk = 0.6 × learned_risk + 0.4 × regional_risk
```

The 0.6/0.4 split is a hand-chosen weighting, not a fitted parameter. Learned risk gets the larger weight because it is computed from embeddings that reflect network structure, which regional risk cannot see. Regional risk keeps a 40% share because:
1. It provides interpretability ("high risk due to geopolitical factors")
2. It acts as a prior when the model has limited data for a vendor
3. It captures domain knowledge that may not be in the training signal
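
The blending loop in `compute_risk_scores` handles LOCATED_IN edges one at a time, so a vendor located in several regions gets no explicit treatment. One option is to average the vendor's regional risks before blending. A stdlib sketch under that assumption; `blend_with_regions` and its inputs are hypothetical.

```python
from collections import defaultdict

def blend_with_regions(learned, located_in, region_risk, w_learned=0.6):
    """Blend each vendor's learned risk with the mean risk of its regions.

    learned: learned vendor risks, indexed by vendor.
    located_in: iterable of (vendor_idx, region_idx) edges.
    region_risk: dict region_idx -> regional risk in [0, 1].
    Vendors with no known region keep their learned risk unchanged.
    """
    regions_of = defaultdict(list)
    for v, r in located_in:
        if r in region_risk:
            regions_of[v].append(region_risk[r])
    blended = list(learned)
    for v, risks in regions_of.items():
        mean_region = sum(risks) / len(risks)
        blended[v] = w_learned * learned[v] + (1 - w_learned) * mean_region
    return blended

learned = [0.2, 0.5, 0.9]
edges = [(0, 0), (1, 0), (1, 1)]  # vendor 1 sits in two regions; vendor 2 in none
region_risk = {0: 0.4, 1: 0.8}
print([round(x, 3) for x in blend_with_regions(learned, edges, region_risk)])  # [0.28, 0.54, 0.9]
```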


In [None]:
# =============================================================================
# RISK SCORE DISTRIBUTION VISUALIZATION
# =============================================================================

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Color scheme for risk levels
risk_colors = {'LOW': '#27ae60', 'MEDIUM': '#f39c12', 'HIGH': '#e67e22', 'CRITICAL': '#e74c3c'}

# Plot 1: Risk score distributions by node type
ax = axes[0, 0]
for i, (node_type, scores) in enumerate(risk_scores.items()):
    scores_np = scores.squeeze().cpu().numpy()
    ax.hist(scores_np, bins=20, alpha=0.6, label=node_type, density=True)
ax.set_xlabel('Risk Score')
ax.set_ylabel('Density')
ax.set_title('Risk Score Distribution by Node Type', fontweight='bold')
ax.axvline(x=0.25, color='green', linestyle='--', alpha=0.5, label='LOW/MEDIUM')
ax.axvline(x=0.5, color='orange', linestyle='--', alpha=0.5, label='MEDIUM/HIGH')
ax.axvline(x=0.75, color='red', linestyle='--', alpha=0.5, label='HIGH/CRITICAL')
ax.legend()  # called after axvline so the threshold lines appear in the legend

# Plot 2: Risk category counts
ax = axes[0, 1]
def categorize_risk(score):
    if score < 0.25: return 'LOW'
    elif score < 0.5: return 'MEDIUM'
    elif score < 0.75: return 'HIGH'
    else: return 'CRITICAL'

category_counts = defaultdict(lambda: defaultdict(int))
for node_type, scores in risk_scores.items():
    for score in scores.squeeze().cpu().numpy():
        cat = categorize_risk(score)
        category_counts[node_type][cat] += 1

x = np.arange(len(risk_scores))
width = 0.2
categories = ['LOW', 'MEDIUM', 'HIGH', 'CRITICAL']
for i, cat in enumerate(categories):
    values = [category_counts[nt][cat] for nt in risk_scores.keys()]
    ax.bar(x + i*width, values, width, label=cat, color=risk_colors[cat])
ax.set_xticks(x + width * 1.5)
ax.set_xticklabels(risk_scores.keys())
ax.set_ylabel('Count')
ax.set_title('Risk Categories by Node Type', fontweight='bold')
ax.legend()

# Plot 3: Embedding visualization (PCA)
# Node types separate partly by construction (each type has its own input projection),
# so structure within a type is the more informative sign of what the model learned
ax = axes[1, 0]
# Combine all embeddings for visualization
all_emb = []
all_labels = []
all_risks = []
for node_type, emb in embeddings.items():
    emb_np = emb.cpu().numpy()
    all_emb.append(emb_np)
    all_labels.extend([node_type] * len(emb_np))
    if node_type in risk_scores:
        all_risks.extend(risk_scores[node_type].squeeze().cpu().numpy().tolist())
    else:
        all_risks.extend([0.5] * len(emb_np))

all_emb = np.vstack(all_emb)
pca = PCA(n_components=2)
emb_2d = pca.fit_transform(all_emb)

# Color by node type (risk scores are collected in all_risks but not encoded in this plot)
node_type_colors = {'vendor': '#2ecc71', 'material': '#3498db', 'region': '#e74c3c', 'external': '#9b59b6'}
for node_type in set(all_labels):
    mask = [l == node_type for l in all_labels]
    ax.scatter(emb_2d[mask, 0], emb_2d[mask, 1], c=node_type_colors.get(node_type, 'gray'), 
               alpha=0.6, label=node_type, s=30)
ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% var)')
ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% var)')
ax.set_title('Node Embeddings (PCA)', fontweight='bold')
ax.legend()

# Plot 4: High-risk nodes
ax = axes[1, 1]
high_risk_data = []
for node_type, scores in risk_scores.items():
    scores_np = scores.squeeze().cpu().numpy()
    high_risk_count = np.sum(scores_np >= 0.5)  # HIGH or CRITICAL
    high_risk_data.append({'type': node_type, 'high_risk': high_risk_count, 
                           'total': len(scores_np), 'pct': high_risk_count/len(scores_np)*100})

df_hr = pd.DataFrame(high_risk_data)
bars = ax.barh(df_hr['type'], df_hr['pct'], color=['#e74c3c' if p > 30 else '#f39c12' if p > 15 else '#27ae60' 
                                                    for p in df_hr['pct']])
ax.set_xlabel('% High Risk (score ≥ 0.5)')
ax.set_title('High-Risk Node Proportion', fontweight='bold')
ax.set_xlim(0, 100)
# Label each bar with the exact count instead of reconstructing it from the percentage
for i, (p, n, t) in enumerate(zip(df_hr['pct'], df_hr['high_risk'], df_hr['total'])):
    ax.text(p + 2, i, f'{p:.1f}% ({n}/{t})', va='center')

plt.tight_layout()
plt.savefig('/tmp/risk_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n🎯 Key Findings:")
for row in high_risk_data:
    if row['pct'] > 20:
        print(f"   ⚠️  {row['type']}: {row['pct']:.1f}% are HIGH/CRITICAL risk")


## 7. Identify Hidden Dependencies and Bottlenecks

### Link Prediction for Hidden Dependencies

The trained model can predict **missing edges** - supplier relationships that exist but aren't in our master data:

```
P(edge exists) = sigmoid(z_external · z_vendor)
```

If two nodes have similar learned embeddings (high dot product), they're likely connected even if we don't have explicit records.
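
Because `predict_hidden_links` scores every (external, vendor) pair, its output mixes relationships already in the trade data with genuinely new ones. A sketch of keeping only candidate hidden links, ranked by probability; it is stdlib Python on plain lists, and `rank_candidate_links` and the toy embeddings are hypothetical. With real embedding tensors, all pair scores come from one product, `torch.sigmoid(z_ext @ z_ven.T)`.

```python
import math

def rank_candidate_links(z_external, z_vendor, known_edges, threshold=0.3):
    """Score all (external, vendor) pairs with sigmoid(dot) and drop known edges.

    Returns (probability, external_idx, vendor_idx) tuples, highest probability first.
    """
    known = set(known_edges)
    candidates = []
    for e, ze in enumerate(z_external):
        for v, zv in enumerate(z_vendor):
            if (e, v) in known:
                continue  # already recorded, so not a hidden dependency
            dot = sum(a * b for a, b in zip(ze, zv))
            prob = 1.0 / (1.0 + math.exp(-dot))
            if prob >= threshold:
                candidates.append((prob, e, v))
    candidates.sort(reverse=True)
    return candidates

z_ext = [[1.0, 0.0], [0.0, 1.0]]
z_ven = [[2.0, 0.0], [0.0, -2.0], [1.0, 1.0]]
for prob, e, v in rank_candidate_links(z_ext, z_ven, known_edges={(0, 0)}):
    print(f"external {e} -> vendor {v}: {prob:.3f}")
```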

### Bottleneck Detection

A **bottleneck** (single point of failure) is an external supplier that many vendors depend on. If this supplier fails:
- All dependent vendors are impacted
- Risk cascades through the network
- Alternative suppliers may not exist

**Impact Score**: `min(1.0, dependent_count / 10)`
- 10+ dependents = maximum impact (1.0)
- Scales linearly below that
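
The impact score is a one-liner. Since the scoring code is not shown in this section, the sketch below is illustrative: `impact_score` and `saturation` are hypothetical names, and the records are trimmed-down versions of what `identify_bottlenecks` returns.

```python
def impact_score(dependent_count: int, saturation: int = 10) -> float:
    """Linear in the number of dependent vendors, capped at 1.0 from `saturation` upward."""
    return min(1.0, dependent_count / saturation)

# Hypothetical bottleneck records (identify_bottlenecks also returns 'dependent_vendors')
example_bottlenecks = [
    {'external_idx': 7, 'dependent_count': 12},
    {'external_idx': 3, 'dependent_count': 4},
]
for bn in example_bottlenecks:
    bn['impact_score'] = impact_score(bn['dependent_count'])
print(example_bottlenecks)  # impact scores 1.0 and 0.4
```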


In [None]:
# =============================================================================
# HIDDEN LINK PREDICTION
# =============================================================================
# Use the trained model to discover supplier relationships that aren't in our
# master data. This reveals hidden Tier-2+ dependencies.
#
# Method: For every (external, vendor) pair, compute the link probability.
# If probability > threshold, we predict a hidden relationship exists.
# =============================================================================

def predict_hidden_links(model, embeddings, mappings, threshold=0.3):
    """
    Predict hidden links between external suppliers and vendors.
    
    Args:
        model: Trained GNN model (its predict_link decoder is applied in vectorized form)
        embeddings: Dict of node embeddings from forward pass
        mappings: Node ID to index mappings (not used in the computation)
        threshold: Minimum probability to consider a link (0.3 = 30% confidence)
        
    Returns:
        List of predicted links with probabilities. Every pair is scored,
        so pairs that are already SHIPS_TO edges can appear as well.
    """
    model.eval()
    predicted_links = []
    z_external = embeddings['external']
    z_vendor = embeddings['vendor']
    
    with torch.no_grad():
        # Score all (external, vendor) pairs in one matrix product. Entry [i, j]
        # equals model.predict_link(z_external[i], z_vendor[j]) = sigmoid(z_i · z_j).
        # The dense [num_external, num_vendor] matrix must fit in memory; very
        # large graphs would need to score it in row blocks.
        probs = torch.sigmoid(z_external @ z_vendor.T)
        ext_idx, ven_idx = torch.nonzero(probs >= threshold, as_tuple=True)
        pair_probs = probs[ext_idx, ven_idx]
        
        # nonzero() yields pairs in row-major order, like a nested loop would
        for e, v, p in zip(ext_idx.tolist(), ven_idx.tolist(), pair_probs.tolist()):
            predicted_links.append({
                'external_idx': e, 
                'vendor_idx': v, 
                'probability': p
            })
    
    return predicted_links


def identify_bottlenecks(predicted_links, mappings, threshold_dependents=2):
    """
    Identify bottlenecks - external suppliers that serve multiple vendors.
    
    A bottleneck is a single point of failure. If an external supplier
    serves many vendors, its failure would have outsized impact.
    
    Args:
        predicted_links: List of predicted (external → vendor) links
        mappings: Node ID mappings
        threshold_dependents: Minimum vendors to be considered a bottleneck
        
    Returns:
        List of bottlenecks sorted by dependent count (descending)
    """
    # Count how many vendors each external supplier serves
    external_dependents = defaultdict(list)
    for link in predicted_links:
        external_dependents[link['external_idx']].append(link['vendor_idx'])
    
    # Flag suppliers with multiple dependents
    bottlenecks = []
    for ext_idx, dependents in external_dependents.items():
        if len(dependents) >= threshold_dependents:
            bottlenecks.append({
                'external_idx': ext_idx, 
                'dependent_count': len(dependents), 
                'dependent_vendors': dependents
            })
    
    # Sort by impact (most dependents first)
    bottlenecks.sort(key=lambda x: x['dependent_count'], reverse=True)
    return bottlenecks


# Execute link prediction
print("=" * 60)
print("HIDDEN LINK PREDICTION")
print("=" * 60)
print("Scanning all (external, vendor) pairs for hidden relationships...")

# Using 0.3 threshold favors recall over precision - missing a real dependency is costlier than a false positive
predicted_links = predict_hidden_links(model, embeddings, mappings, threshold=0.3)

# Analyze predictions
total_possible = len(mappings['external']) * len(mappings['vendor'])
print(f"Total possible pairs: {total_possible:,}")
print(f"Predicted links (≥30% confidence): {len(predicted_links):,}")
print(f"Link density: {len(predicted_links)/total_possible*100:.2f}%")

# Identify bottlenecks
bottlenecks = identify_bottlenecks(predicted_links, mappings, threshold_dependents=2)

print("-" * 60)
print(f"🚨 BOTTLENECKS IDENTIFIED: {len(bottlenecks)}")
print("-" * 60)

idx_to_external = {v: k for k, v in mappings['external'].items()}
idx_to_vendor = {v: k for k, v in mappings['vendor'].items()}

print("\nTop 5 Single Points of Failure:")
for i, bn in enumerate(bottlenecks[:5]):
    ext_name = idx_to_external.get(bn['external_idx'], f"External-{bn['external_idx']}")
    print(f"  {i+1}. {ext_name[:40]:<40s} → {bn['dependent_count']:>3d} vendors depend on this supplier")

print("=" * 60)


In [None]:
# =============================================================================
# BOTTLENECK AND LINK PREDICTION VISUALIZATION
# =============================================================================

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot 1: Link probability distribution
ax = axes[0]
link_probs = [l['probability'] for l in predicted_links]
ax.hist(link_probs, bins=20, color='#3498db', edgecolor='white', alpha=0.8)
ax.axvline(x=0.5, color='#e74c3c', linestyle='--', linewidth=2, label='Medium confidence')
ax.axvline(x=0.7, color='#27ae60', linestyle='--', linewidth=2, label='High confidence')
ax.set_xlabel('Link Probability')
ax.set_ylabel('Count')
ax.set_title('Predicted Link Confidence Distribution', fontweight='bold')
ax.legend()

# Plot 2: Bottleneck impact distribution
ax = axes[1]
if bottlenecks:
    dependent_counts = [bn['dependent_count'] for bn in bottlenecks]
    ax.hist(dependent_counts, bins=range(2, max(dependent_counts)+2), 
            color='#e74c3c', edgecolor='white', alpha=0.8)
    ax.set_xlabel('Number of Dependent Vendors')
    ax.set_ylabel('Number of External Suppliers')
    ax.set_title('Bottleneck Size Distribution', fontweight='bold')
else:
    ax.text(0.5, 0.5, 'No bottlenecks identified', ha='center', va='center', fontsize=12)
    ax.set_title('Bottleneck Size Distribution', fontweight='bold')

# Plot 3: Top bottlenecks bar chart
ax = axes[2]
if bottlenecks:
    top_bn = bottlenecks[:10]
    names = [idx_to_external.get(bn['external_idx'], f"Ext-{bn['external_idx']}")[:20] for bn in top_bn]
    counts = [bn['dependent_count'] for bn in top_bn]
    colors = ['#e74c3c' if c >= 5 else '#f39c12' if c >= 3 else '#27ae60' for c in counts]
    ax.barh(range(len(names)), counts, color=colors)
    ax.set_yticks(range(len(names)))
    ax.set_yticklabels(names, fontsize=8)
    ax.set_xlabel('Dependent Vendor Count')
    ax.set_title('Top 10 Bottlenecks', fontweight='bold')
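    ax.invert_yaxis()
else:
    ax.text(0.5, 0.5, 'No bottlenecks identified', ha='center', va='center', fontsize=12)
    ax.set_title('Top 10 Bottlenecks', fontweight='bold')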

plt.tight_layout()
plt.savefig('/tmp/bottleneck_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

# Summary statistics
print("\n📊 Link Prediction Summary:")
print(f"   Low confidence (30-50%): {sum(1 for l in predicted_links if 0.3 <= l['probability'] < 0.5)}")
print(f"   Medium confidence (50-70%): {sum(1 for l in predicted_links if 0.5 <= l['probability'] < 0.7)}")
print(f"   High confidence (≥70%): {sum(1 for l in predicted_links if l['probability'] >= 0.7)}")


### Interpreting Link Predictions: What the Probabilities Mean

The link prediction scores require careful interpretation:

**What the model is actually computing:**
```python
P(link exists) = sigmoid(embedding_external · embedding_vendor)
```

This dot product measures **similarity in embedding space**. High similarity → high probability. But what does "similar" mean here?

**Two nodes are similar if they have similar graph neighborhoods.** An external supplier and vendor are predicted to be connected if:
- The external supplier ships to vendors that look like this vendor
- The vendor receives from external suppliers that look like this one
- They share common patterns in their 2-hop neighborhoods
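To make the "similar neighborhoods" argument concrete, here is a toy NumPy sketch of one GraphSAGE-style mean-aggregation layer followed by the dot-product decoder from the formula above. It is an illustration under stated assumptions, not the notebook's model: the weights are random placeholders rather than trained parameters, only external-supplier-to-vendor edges are aggregated (the real heterogeneous model also sees materials and regions), and all names (`sage_mean`, `link_probability`, `neighbors`) are hypothetical. Two vendors with identical features and identical supplier neighborhoods receive identical embeddings, and therefore identical link probabilities against every external supplier.

```python
import numpy as np

rng = np.random.default_rng(0)
dim_in, dim_out = 4, 8

# Toy inputs: 3 external suppliers (E0-E2) and 3 vendors (V0-V2).
x_external = rng.normal(size=(3, dim_in))
x_vendor = np.ones((3, dim_in))  # identical raw vendor features, so only structure differs

# Hypothetical incoming SUPPLIES edges: which external suppliers ship to each vendor.
# V0 and V1 share the same neighborhood; V2 has a different one.
neighbors = {0: [0, 1], 1: [0, 1], 2: [2]}

# Random placeholder weights (NOT trained parameters).
w_self = rng.normal(size=(dim_in, dim_out))
w_neigh = rng.normal(size=(dim_in, dim_out))
w_ext = rng.normal(size=(dim_in, dim_out))  # simple projection for external nodes

def sage_mean(v):
    """One SAGE-style layer: ReLU(x_v @ W_self + mean(x_u for u in N(v)) @ W_neigh)."""
    neigh_mean = x_external[neighbors[v]].mean(axis=0)  # aggregate neighbor messages
    return np.maximum(0.0, x_vendor[v] @ w_self + neigh_mean @ w_neigh)

h_vendor = np.stack([sage_mean(v) for v in range(3)])
h_external = x_external @ w_ext

def link_probability(h_ext, h_ven):
    """sigmoid(dot product) for every (external, vendor) pair."""
    logits = h_ext @ h_ven.T  # shape (num_external, num_vendor)
    return 1.0 / (1.0 + np.exp(-logits))

probs = link_probability(h_external, h_vendor)
print(probs.round(3))
# Same neighborhood -> same embedding -> same score against every supplier.
print(np.allclose(probs[:, 0], probs[:, 1]))
```

Stacking k such layers widens each node's receptive field to k hops, which is how 2-hop patterns like the ones listed above can reach the embeddings.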

**Thresholds and confidence:**

| Probability | Evidence | Interpretation |
|-------------|----------|----------------|
| 0.30 - 0.50 | WEAK | Model sees some similarity; could be noise |
| 0.50 - 0.70 | MODERATE | Likely real relationship; worth investigating |
| 0.70 - 1.00 | STRONG | High confidence hidden dependency |
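The bands in this table can also be applied to a whole batch of probabilities at once. The sketch below uses `pandas.cut` with left-closed intervals, so a value sitting exactly on a boundary (0.5 or 0.7) lands in the higher band, matching the `< 0.5` / `< 0.7` comparisons used in `write_predicted_links`. `evidence_strength` is a hypothetical helper; unlike `write_predicted_links`, it returns a missing value for probabilities below 0.30, on the assumption that such candidates are discarded upstream, as the 30% floor in the summary counts suggests.

```python
import numpy as np
import pandas as pd

def evidence_strength(probs):
    """Bucket link probabilities into WEAK / MODERATE / STRONG.

    Intervals are [0.3, 0.5), [0.5, 0.7) and [0.7, inf); anything below 0.3
    falls outside every band and comes back as NaN.
    """
    return pd.cut(
        pd.Series(probs, dtype=float),
        bins=[0.3, 0.5, 0.7, np.inf],
        right=False,  # left-closed: 0.5 -> MODERATE, 0.7 -> STRONG
        labels=['WEAK', 'MODERATE', 'STRONG'],
    )

buckets = evidence_strength([0.2, 0.3, 0.49, 0.5, 0.69, 0.7, 1.0])
print(buckets.tolist())
print(buckets.value_counts(sort=False))
```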

**Important caveat**: A high probability doesn't prove a relationship exists; it means the model found strong structural evidence. Always validate high-confidence predictions against business knowledge or additional data sources.


## 8. Write Results to Snowflake

### Output Tables

| Table | Description | Key Columns |
|-------|-------------|-------------|
| `RISK_SCORES` | Risk scores for all nodes | NODE_ID, NODE_TYPE, RISK_SCORE, RISK_CATEGORY, EMBEDDING |
| `PREDICTED_LINKS` | Inferred supplier relationships | SOURCE_NODE_ID, TARGET_NODE_ID, PROBABILITY, EVIDENCE_STRENGTH |
| `BOTTLENECKS` | Single points of failure | NODE_ID, DEPENDENT_COUNT, IMPACT_SCORE, MITIGATION_STATUS |

The embeddings are stored as JSON arrays for potential downstream use (e.g., similarity search, clustering).
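Because `str()` of a Python list of finite floats produces text that `json.loads` can parse, the stored embeddings can be pulled back into NumPy for similarity search. A minimal sketch, assuming a DataFrame with the `RISK_SCORES` columns `NODE_ID` and `EMBEDDING` (for example, one obtained from `session.table('RISK_SCORES').to_pandas()`); `most_similar` is a hypothetical helper, and the toy rows are illustrative values, not model output.

```python
import json

import numpy as np
import pandas as pd

def most_similar(df: pd.DataFrame, node_id: str, top_k: int = 3) -> pd.DataFrame:
    """Rank other nodes by cosine similarity of their stored embeddings to `node_id`."""
    vectors = np.array([json.loads(s) for s in df['EMBEDDING']], dtype=float)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    unit = vectors / np.where(norms == 0, 1.0, norms)  # guard against all-zero vectors
    ids = df['NODE_ID'].tolist()
    sims = unit @ unit[ids.index(node_id)]  # cosine similarity to the query node
    result = pd.DataFrame({'NODE_ID': ids, 'SIMILARITY': sims})
    result = result[result['NODE_ID'] != node_id]  # drop the query node itself
    return result.sort_values('SIMILARITY', ascending=False).head(top_k).reset_index(drop=True)

# Toy rows in the same string format that write_risk_scores produces
scores = pd.DataFrame({
    'NODE_ID': ['V-1', 'V-2', 'V-3', 'V-4'],
    'EMBEDDING': [str([1.0, 0.0, 0.0]), str([0.9, 0.1, 0.0]),
                  str([0.0, 1.0, 0.0]), str([-1.0, 0.0, 0.0])],
})
print(most_similar(scores, 'V-1'))
```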


In [None]:
# =============================================================================
# HELPER FUNCTIONS FOR DATA OUTPUT
# =============================================================================

def categorize_risk(score):
    """
    Categorize risk score into discrete buckets for business interpretation.
    
    Thresholds:
        < 0.25: LOW - acceptable risk, standard monitoring
        < 0.50: MEDIUM - elevated risk, enhanced monitoring recommended
        < 0.75: HIGH - significant risk, mitigation planning required
        >= 0.75: CRITICAL - immediate attention required
    """
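    # NaN/inf would fail every comparison below and fall through to 'CRITICAL',
    # so sanitize first (consistent with the RISK_SCORE written via safe_float)
    score = safe_float(score)
    if score < 0.25:
        return 'LOW'
    elif score < 0.5:
        return 'MEDIUM'
    elif score < 0.75:
        return 'HIGH'
    return 'CRITICAL'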

def safe_float(val, default=0.0):
    """Convert value to float, handling inf/nan edge cases from model output."""
    if val is None or np.isnan(val) or np.isinf(val): 
        return default
    return float(val)


# =============================================================================
# WRITE RISK SCORES TO SNOWFLAKE
# =============================================================================

def write_risk_scores(session, risk_scores, embeddings, mappings, model_version):
    """
    Write risk scores to RISK_SCORES table.
    
    Each record contains:
        - NODE_ID: Business identifier
        - NODE_TYPE: SUPPLIER, PART, or EXTERNAL_SUPPLIER
        - RISK_SCORE: Continuous score [0, 1]
        - RISK_CATEGORY: Discrete category (LOW/MEDIUM/HIGH/CRITICAL)
        - CONFIDENCE: Model confidence in the prediction
        - EMBEDDING: 32-dimensional vector as JSON array
        - MODEL_VERSION: For tracking which model produced the results
    """
    # Delete previous results for this model version (idempotent updates)
    session.sql(f"DELETE FROM RISK_SCORES WHERE MODEL_VERSION = '{model_version}'").collect()
    
    records = []
    
    # Reverse mappings: index → business ID
    idx_to_vendor = {v: k for k, v in mappings['vendor'].items()}
    idx_to_material = {v: k for k, v in mappings['material'].items()}
    idx_to_external = {v: k for k, v in mappings['external'].items()}
    
    # Process vendor risk scores (highest confidence - most data available)
    # Confidence reflects data quality: internal vendors have rich data (0.85), externals are inferred (0.75)
    if 'vendor' in risk_scores:
        scores = risk_scores['vendor'].reshape(-1).cpu().numpy()  # reshape(-1) stays 1-D even for a single node; squeeze() would not
        emb = embeddings['vendor'].cpu().numpy()
        for idx, score in enumerate(scores):
            records.append({
                'NODE_ID': idx_to_vendor.get(idx, f'V-{idx}'), 
                'NODE_TYPE': 'SUPPLIER',
                'RISK_SCORE': safe_float(score), 
                'RISK_CATEGORY': categorize_risk(score),
                'CONFIDENCE': 0.85,  # High confidence for internal vendors
                'EMBEDDING': str(emb[idx].tolist()), 
                'MODEL_VERSION': model_version
            })
    
    # Process material risk scores (slightly lower confidence)
    if 'material' in risk_scores:
        scores = risk_scores['material'].reshape(-1).cpu().numpy()
        emb = embeddings['material'].cpu().numpy()
        for idx, score in enumerate(scores):
            records.append({
                'NODE_ID': idx_to_material.get(idx, f'M-{idx}'), 
                'NODE_TYPE': 'PART',
                'RISK_SCORE': safe_float(score), 
                'RISK_CATEGORY': categorize_risk(score),
                'CONFIDENCE': 0.80,  # Less direct risk data for materials
                'EMBEDDING': str(emb[idx].tolist()), 
                'MODEL_VERSION': model_version
            })
    
    # Process external supplier risk scores (lowest confidence - inferred from trade data)
    if 'external' in risk_scores:
        scores = risk_scores['external'].reshape(-1).cpu().numpy()
        emb = embeddings['external'].cpu().numpy()
        for idx, score in enumerate(scores):
            records.append({
                'NODE_ID': idx_to_external.get(idx, f'E-{idx}'), 
                'NODE_TYPE': 'EXTERNAL_SUPPLIER',
                'RISK_SCORE': safe_float(score), 
                'RISK_CATEGORY': categorize_risk(score),
                'CONFIDENCE': 0.75,  # Lower confidence for discovered suppliers
                'EMBEDDING': str(emb[idx].tolist()), 
                'MODEL_VERSION': model_version
            })
    
    # Write to Snowflake
    if records:
        df = pd.DataFrame(records)
        session.write_pandas(df, 'RISK_SCORES', auto_create_table=False, overwrite=False)
    
    return len(records)


print("=" * 60)
print("WRITING RESULTS TO SNOWFLAKE")
print("=" * 60)

print("\n📝 Writing risk scores...")
num_scores = write_risk_scores(session, risk_scores, embeddings, mappings, MODEL_VERSION)
print(f"   ✅ Wrote {num_scores:,} risk score records")


In [None]:
# =============================================================================
# WRITE PREDICTED LINKS AND BOTTLENECKS
# =============================================================================

def write_predicted_links(session, predicted_links, mappings, model_version):
    """
    Write predicted links to PREDICTED_LINKS table.
    
    Evidence strength categories:
        WEAK: 30-50% probability - might be noise
        MODERATE: 50-70% probability - likely real relationship
        STRONG: 70%+ probability - high confidence hidden dependency
    """
    session.sql(f"DELETE FROM PREDICTED_LINKS WHERE MODEL_VERSION = '{model_version}'").collect()
    
    idx_to_external = {v: k for k, v in mappings['external'].items()}
    idx_to_vendor = {v: k for k, v in mappings['vendor'].items()}
    
    records = []
    for link in predicted_links:
        # Categorize evidence strength based on probability
        if link['probability'] < 0.5:
            evidence = 'WEAK'
        elif link['probability'] < 0.7:
            evidence = 'MODERATE'
        else:
            evidence = 'STRONG'
            
        records.append({
            'SOURCE_NODE_ID': idx_to_external.get(link['external_idx'], f"E-{link['external_idx']}"),
            'SOURCE_NODE_TYPE': 'EXTERNAL_SUPPLIER',
            'TARGET_NODE_ID': idx_to_vendor.get(link['vendor_idx'], f"V-{link['vendor_idx']}"),
            'TARGET_NODE_TYPE': 'SUPPLIER', 
            'LINK_TYPE': 'INFERRED_SUPPLIES',
            'PROBABILITY': safe_float(link['probability']), 
            'EVIDENCE_STRENGTH': evidence, 
            'MODEL_VERSION': model_version
        })
    
    if records:
        df = pd.DataFrame(records)
        session.write_pandas(df, 'PREDICTED_LINKS', auto_create_table=False, overwrite=False)
    return len(records)


def write_bottlenecks(session, bottlenecks, mappings):
    """
    Write identified bottlenecks to BOTTLENECKS table.
    
    Impact scoring:
        - Scales linearly with dependent count
        - Capped at 1.0 for 10+ dependents
        - Used for prioritizing mitigation efforts
    """
    import json  # DEPENDENT_NODES is stored as a JSON array string

    # Only delete unmitigated bottlenecks (preserve mitigation history).
    # Preserving mitigated records maintains the audit trail; skipping their IDs below
    # prevents re-alerting on issues that have already been resolved.
    session.sql("DELETE FROM BOTTLENECKS WHERE MITIGATION_STATUS = 'UNMITIGATED'").collect()
    mitigated_ids = {
        row[0] for row in session.sql(
            "SELECT DISTINCT NODE_ID FROM BOTTLENECKS WHERE MITIGATION_STATUS <> 'UNMITIGATED'"
        ).collect()
    }
    
    idx_to_external = {v: k for k, v in mappings['external'].items()}
    idx_to_vendor = {v: k for k, v in mappings['vendor'].items()}
    
    records = []
    for bn in bottlenecks:
        ext_name = idx_to_external.get(bn['external_idx'], f"External-{bn['external_idx']}")
        if ext_name in mitigated_ids:
            continue  # already tracked with a non-default mitigation status
        dependent_names = [idx_to_vendor.get(v, f"V-{v}") for v in bn['dependent_vendors']]
        
        # Impact score: linear scale capped at 1.0
        impact = min(1.0, bn['dependent_count'] / 10.0)
        
        records.append({
            'NODE_ID': ext_name, 
            'NODE_TYPE': 'EXTERNAL_SUPPLIER', 
            'DEPENDENT_COUNT': bn['dependent_count'],
            # json.dumps gives a valid JSON array; str(list) would emit single-quoted strings
            'DEPENDENT_NODES': json.dumps([str(name) for name in dependent_names]),
            'IMPACT_SCORE': safe_float(impact),
            'DESCRIPTION': f"External supplier '{ext_name}' is a single point of failure for {bn['dependent_count']} vendors",
            'MITIGATION_STATUS': 'UNMITIGATED'  # Default status for new bottlenecks
        })
    
    if records:
        df = pd.DataFrame(records)
        session.write_pandas(df, 'BOTTLENECKS', auto_create_table=False, overwrite=False)
    return len(records)


print("\n📝 Writing predicted links...")
num_links = write_predicted_links(session, predicted_links, mappings, MODEL_VERSION)
print(f"   ✅ Wrote {num_links:,} predicted link records")

print("\n📝 Writing bottlenecks...")
num_bn = write_bottlenecks(session, bottlenecks, mappings)
print(f"   ✅ Wrote {num_bn:,} bottleneck records")

print("\n" + "=" * 60)
print("✅ ALL RESULTS WRITTEN TO SNOWFLAKE")
print("=" * 60)


## 9. Summary & Verification

Query the output tables to verify results and provide a final summary of the analysis.
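It can help to know what the verification query should return before reading it. The sketch below recomputes the same per-node-type summary (row count, average risk, and counts per risk category) with pandas on an in-memory DataFrame that has the `RISK_SCORES` columns; if the numbers match, the write step is fine, and if they differ, the problem is in the write rather than the model. `summarize_risk` is a hypothetical helper. Note that `write_risk_scores` builds its records internally and does not return them, so you would need to keep that DataFrame to make the comparison.

```python
import pandas as pd

CATEGORY_ORDER = ['CRITICAL', 'HIGH', 'MEDIUM', 'LOW']

def summarize_risk(records: pd.DataFrame) -> pd.DataFrame:
    """Per-NODE_TYPE count, mean risk, and category counts, sorted by mean risk."""
    stats = records.groupby('NODE_TYPE')['RISK_SCORE'].agg(COUNT='size', AVG_RISK='mean')
    stats['AVG_RISK'] = stats['AVG_RISK'].round(3)
    # Zero-fill categories that never occur so every summary has the same columns
    by_category = pd.crosstab(records['NODE_TYPE'], records['RISK_CATEGORY'])
    by_category = by_category.reindex(columns=CATEGORY_ORDER, fill_value=0)
    summary = stats.join(by_category).sort_values('AVG_RISK', ascending=False)
    return summary.reset_index()

# Toy records (illustrative values only)
toy = pd.DataFrame({
    'NODE_TYPE': ['SUPPLIER', 'SUPPLIER', 'PART'],
    'RISK_SCORE': [0.10, 0.80, 0.60],
    'RISK_CATEGORY': ['LOW', 'CRITICAL', 'HIGH'],
})
print(summarize_risk(toy).to_string(index=False))
```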


In [None]:
# =============================================================================
# FINAL SUMMARY & VERIFICATION
# =============================================================================
# Querying Snowflake tables validates the write succeeded and provides a sanity check on results

print("╔" + "═"*58 + "╗")
print("║" + " GNN SUPPLY CHAIN RISK ANALYSIS - FINAL SUMMARY ".center(58) + "║")
print("╚" + "═"*58 + "╝")

# Query and display risk score summary from Snowflake
risk_summary = session.sql("""
    SELECT 
        NODE_TYPE, 
        COUNT(*) as COUNT, 
        ROUND(AVG(RISK_SCORE), 3) as AVG_RISK,
        SUM(CASE WHEN RISK_CATEGORY = 'CRITICAL' THEN 1 ELSE 0 END) as CRITICAL,
        SUM(CASE WHEN RISK_CATEGORY = 'HIGH' THEN 1 ELSE 0 END) as HIGH,
        SUM(CASE WHEN RISK_CATEGORY = 'MEDIUM' THEN 1 ELSE 0 END) as MEDIUM,
        SUM(CASE WHEN RISK_CATEGORY = 'LOW' THEN 1 ELSE 0 END) as LOW
    FROM RISK_SCORES 
    WHERE MODEL_VERSION = '{}'
    GROUP BY NODE_TYPE
    ORDER BY AVG_RISK DESC
""".format(MODEL_VERSION)).to_pandas()

print("\n📊 RISK SCORE SUMMARY BY NODE TYPE")
print("-" * 60)
print(risk_summary.to_string(index=False))

# Top bottlenecks from Snowflake
top_bottlenecks = session.sql("""
    SELECT 
        NODE_ID, 
        DEPENDENT_COUNT, 
        ROUND(IMPACT_SCORE, 2) as IMPACT_SCORE 
    FROM BOTTLENECKS
    WHERE MITIGATION_STATUS = 'UNMITIGATED'
    ORDER BY DEPENDENT_COUNT DESC  -- IMPACT_SCORE is capped at 1.0, so it ties for 10+ dependents
    LIMIT 5
""").to_pandas()

print("\n🚨 TOP 5 BOTTLENECKS (Single Points of Failure)")
print("-" * 60)
if len(top_bottlenecks) > 0:
    print(top_bottlenecks.to_string(index=False))
else:
    print("No bottlenecks identified above threshold")

# High confidence predicted links
high_conf_links = session.sql("""
    SELECT 
        SOURCE_NODE_ID, 
        TARGET_NODE_ID, 
        ROUND(PROBABILITY, 3) as PROBABILITY,
        EVIDENCE_STRENGTH
    FROM PREDICTED_LINKS
    WHERE PROBABILITY >= 0.5 
    AND MODEL_VERSION = '{}'
    ORDER BY PROBABILITY DESC 
    LIMIT 10
""".format(MODEL_VERSION)).to_pandas()

print("\n🔗 TOP 10 PREDICTED LINKS (PROBABILITY ≥ 0.5)")
print("-" * 60)
if len(high_conf_links) > 0:
    print(high_conf_links.to_string(index=False))
else:
    print("No predicted links at or above the 50% threshold")

# Final statistics
print("\n" + "═" * 60)
print("📈 ANALYSIS STATISTICS")
print("-" * 60)
print(f"   Model Version: {MODEL_VERSION}")
print(f"   Total nodes analyzed: {sum(len(m) for m in mappings.values()):,}")
print(f"   Total edges in graph: {sum(e.shape[1] for e in edge_indices.values()):,}")
print(f"   Risk scores generated: {num_scores:,}")
print(f"   Hidden links predicted: {num_links:,}")
print(f"   Bottlenecks identified: {num_bn:,}")
print("═" * 60)

print("\n✅ Analysis complete! Results are available in Snowflake tables:")
print("   • RISK_SCORES")
print("   • PREDICTED_LINKS")
print("   • BOTTLENECKS")
