# Heterogeneous Graph Neural Network - Network Criticality

Implements a Heterogeneous GNN using PyTorch Geometric on Snowpark Container Services.
This notebook builds a heterogeneous graph with flights, aircraft, crew duty periods, and airports
as nodes (passenger load enters as flight-node features), then uses attention-based message passing
to compute network criticality scores.

**Target:** `network_disruption_label` (downstream cascade impact); the training cell below regresses on the `NETWORK_CRITICALITY_SCORE` column of `FLIGHT_INSTANCE`
**Algorithm:** HGTConv (Heterogeneous Graph Transformer)
**Output:** `IROP_GNN_RISK.ML_PROCESSING.GNN_FLIGHT_EMBEDDINGS`

**Runtime:** SPCS GPU Compute Pool (IROP_GNN_RISK_GPU_POOL)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HGTConv, Linear
import numpy as np
import pandas as pd
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
from snowflake.ml.registry import Registry
import uuid

In [None]:
# `get_active_session` is provided by snowflake.snowpark.context; the notebook's
# import cell only brings in `Session`, so import it here to avoid a NameError.
from snowflake.snowpark.context import get_active_session

# Attach to the notebook's Snowpark session and point it at the source schema.
session = get_active_session()
session.use_database('IROP_GNN_RISK')
session.use_schema('ATOMIC')

# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Connected: {session.get_current_database()}.{session.get_current_schema()}")
print(f"Device: {device}")

In [None]:
# Materialise each source table from the session's current schema as a pandas
# DataFrame. The tables are fetched one after another, in the order listed.
flights_pd, rotations_pd, crew_pd, assignments_pd, airports_pd = (
    session.table(table_name).to_pandas()
    for table_name in (
        'FLIGHT_INSTANCE',
        'AIRCRAFT_ROTATION',
        'CREW_DUTY_PERIOD',
        'CREW_ASSIGNMENT',
        'AIRPORT_CAPABILITY',
    )
)

# Row counts as a quick sanity check on what was loaded.
print(f"Loaded: {len(flights_pd)} flights, {len(rotations_pd)} rotations")
print(f"        {len(crew_pd)} crew duties, {len(assignments_pd)} assignments")
print(f"        {len(airports_pd)} airports")

In [None]:
def _first_seen_index(column):
    """Map each distinct value of `column` to its position in first-appearance order."""
    distinct = column.unique()
    return dict(zip(distinct, range(len(distinct))))


# Business key -> contiguous node id, one map per node type.
flight_idx = _first_seen_index(flights_pd['FLIGHT_KEY'])
tail_idx = _first_seen_index(rotations_pd['TAIL_NUMBER'])
duty_idx = _first_seen_index(crew_pd['DUTY_ID'])
airport_idx = _first_seen_index(airports_pd['STATION_CODE'])

print(f"Node counts: Flights={len(flight_idx)}, Tails={len(tail_idx)}")
print(f"             Duties={len(duty_idx)}, Airports={len(airport_idx)}")

In [None]:
def _to_float(value):
    """Convert a possibly-missing scalar to float, mapping None/NaN/NA to 0.0."""
    # `value or 0` is not enough: NaN is truthy and would leak into the tensor.
    if value is None or pd.isna(value):
        return 0.0
    return float(value)


def _is_set(value):
    """Return True for a present, truthy flag; missing values count as not set."""
    return value is not None and not pd.isna(value) and bool(value)


def _node_features(frame, key_col, feature_cols, index_map):
    """Build the float feature matrix for a node type whose rows are positional.

    Row i of `frame` becomes node i, so `frame` must hold exactly one row per
    entry of `index_map` (which was built from the distinct values of `key_col`).

    Raises:
        ValueError: if the row count and the number of distinct keys disagree,
            which would silently misalign features, edges and outputs.
    """
    if len(frame) != len(index_map):
        raise ValueError(
            f"{key_col} must be unique: found {len(frame)} rows "
            f"for {len(index_map)} distinct keys"
        )
    values = frame[feature_cols].fillna(0).to_numpy(dtype=np.float32)
    return torch.tensor(values, dtype=torch.float)


def _edge_index(frame, src_col, src_map, dst_col, dst_map):
    """Return a (2, num_edges) long tensor for rows whose two keys are both indexed."""
    pairs = [
        (src_map[src_key], dst_map[dst_key])
        for src_key, dst_key in zip(frame[src_col], frame[dst_col])
        if src_key in src_map and dst_key in dst_map
    ]
    if not pairs:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(pairs, dtype=torch.long).t().contiguous()


def build_hetero_graph():
    """Assemble the heterogeneous graph from the loaded pandas tables.

    Node types and feature widths: 'flight' (8), 'aircraft' (4), 'crew' (5),
    'airport' (3). Relations: flight-operated_by->aircraft,
    flight-next_leg->flight, flight-assigned_to->crew,
    flight-departs_from->airport and flight-arrives_at->airport.

    Reads the module-level DataFrames and index maps; takes no arguments.

    Returns:
        HeteroData with `.x` set for every node type and `.edge_index` set for
        every relation.

    Raises:
        ValueError: if FLIGHT_KEY, DUTY_ID or STATION_CODE is duplicated in its
            node table.
    """
    # NOTE(review): every relation points away from 'flight', and only
    # 'next_leg' points into it, so aircraft/crew/airport features presumably
    # never reach the flight embeddings. Adding reverse relations would also
    # require updating the model's hard-coded metadata - confirm intent.
    data = HeteroData()

    data['flight'].x = _node_features(
        flights_pd,
        'FLIGHT_KEY',
        ['CURRENT_DELAY_DEPARTURE', 'TURN_BUFFER_MINUTES',
         'PAX_COUNT', 'CONNECTING_PAX_PCT', 'REVENUE_AT_RISK_USD',
         'DELAY_RISK_SCORE', 'TURN_SUCCESS_PROB', 'MISCONNECT_PROB'],
        flight_idx,
    )

    # Aircraft attributes come from the first rotation row of each tail; one
    # lookup table replaces a full DataFrame scan per tail.
    first_rotation = (
        rotations_pd.drop_duplicates(subset='TAIL_NUMBER')
        .set_index('TAIL_NUMBER')
        .to_dict('index')
    )
    tail_features = []
    for tail in tail_idx:
        # Tails without a matching row (e.g. a null tail number) get all zeros.
        record = first_rotation.get(tail, {})
        tail_features.append([
            _to_float(record.get('AIRCRAFT_AGE_YEARS')),
            _to_float(record.get('UTILIZATION_HOURS_24H')),
            1.0 if _is_set(record.get('MEL_APU_FLAG')) else 0.0,
            _to_float(record.get('AOG_RISK_SCORE')),
        ])
    # reshape keeps the (num_tails, 4) layout even when there are no tails.
    data['aircraft'].x = torch.tensor(tail_features, dtype=torch.float).reshape(-1, 4)

    data['crew'].x = _node_features(
        crew_pd,
        'DUTY_ID',
        ['FDP_LIMIT_MINUTES', 'FDP_TIME_USED_MINUTES',
         'FDP_REMAINING_MINUTES', 'NUM_SEGMENTS',
         'CREW_TIMEOUT_RISK_SCORE'],
        duty_idx,
    )

    data['airport'].x = _node_features(
        airports_pd,
        'STATION_CODE',
        ['GATE_COUNT', 'ATC_CONGESTION_INDEX', 'AIRPORT_DISRUPTION_INDEX'],
        airport_idx,
    )

    data['flight', 'operated_by', 'aircraft'].edge_index = _edge_index(
        rotations_pd, 'FLIGHT_KEY', flight_idx, 'TAIL_NUMBER', tail_idx)
    data['flight', 'next_leg', 'flight'].edge_index = _edge_index(
        rotations_pd, 'FLIGHT_KEY', flight_idx, 'NEXT_FLIGHT_KEY', flight_idx)
    data['flight', 'assigned_to', 'crew'].edge_index = _edge_index(
        assignments_pd, 'FLIGHT_KEY', flight_idx, 'DUTY_ID', duty_idx)
    data['flight', 'departs_from', 'airport'].edge_index = _edge_index(
        flights_pd, 'FLIGHT_KEY', flight_idx, 'DEPARTURE_STATION', airport_idx)
    data['flight', 'arrives_at', 'airport'].edge_index = _edge_index(
        flights_pd, 'FLIGHT_KEY', flight_idx, 'ARRIVAL_STATION', airport_idx)

    return data

hetero_data = build_hetero_graph()
print(f"Graph built: {hetero_data}")
print(f"Node types: {hetero_data.node_types}")
print(f"Edge types: {hetero_data.edge_types}")

In [None]:
class HGNNNetworkCriticality(nn.Module):
    """Heterogeneous graph transformer that scores flights for network criticality.

    Each node type is projected to `hidden_dim` by its own linear encoder, passed
    through `num_layers` HGTConv layers (each followed by ReLU), and the resulting
    flight embeddings are fed to a small MLP head ending in a sigmoid.
    """

    def __init__(self, hidden_dim=64, num_heads=4, num_layers=2):
        """Create the encoders, the HGTConv stack and the criticality head.

        Args:
            hidden_dim: Width of every node embedding after encoding and after
                each convolution.
            num_heads: Attention heads passed to each HGTConv.
            num_layers: Number of stacked HGTConv layers.
        """
        super().__init__()

        # Raw feature width expected for each node type's `x`.
        input_widths = {'flight': 8, 'aircraft': 4, 'crew': 5, 'airport': 3}
        self.node_encoders = nn.ModuleDict({
            node_type: Linear(width, hidden_dim)
            for node_type, width in input_widths.items()
        })

        # Graph schema handed to every HGTConv layer.
        node_types = list(input_widths)
        edge_types = [
            ('flight', 'operated_by', 'aircraft'),
            ('flight', 'next_leg', 'flight'),
            ('flight', 'assigned_to', 'crew'),
            ('flight', 'departs_from', 'airport'),
            ('flight', 'arrives_at', 'airport'),
        ]
        self.convs = nn.ModuleList([
            HGTConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                metadata=(node_types, edge_types),
                heads=num_heads,
            )
            for _ in range(num_layers)
        ])

        # hidden_dim -> 32 -> 1, squashed into (0, 1) by the sigmoid.
        self.criticality_head = nn.Sequential(
            Linear(hidden_dim, 32),
            nn.ReLU(),
            Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x_dict, edge_index_dict):
        """Run the network.

        Args:
            x_dict: Mapping from node type to its feature tensor; must contain
                every node type that has an encoder.
            edge_index_dict: Mapping from edge type to its edge index tensor.

        Returns:
            A pair `(h_dict, criticality)`: the per-node-type embeddings after
            the last ReLU, and the flight scores scaled to the 0-100 range with
            a trailing dimension of size 1.
        """
        h_dict = {}
        for node_type, encoder in self.node_encoders.items():
            h_dict[node_type] = encoder(x_dict[node_type])

        for conv in self.convs:
            propagated = conv(h_dict, edge_index_dict)
            h_dict = {node_type: F.relu(hidden) for node_type, hidden in propagated.items()}

        flight_score = self.criticality_head(h_dict['flight'])
        return h_dict, flight_score * 100

model = HGNNNetworkCriticality(hidden_dim=64, num_heads=4, num_layers=2)
model = model.to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Regression target: one score per FLIGHT_INSTANCE row as a column vector;
# missing scores are replaced with 50.
observed_scores = flights_pd['NETWORK_CRITICALITY_SCORE'].fillna(50).values
target_criticality = torch.tensor(observed_scores, dtype=torch.float).unsqueeze(1).to(device)

hetero_data = hetero_data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Full-batch training: every step runs the whole graph through the model.
model.train()
for step in range(1, 101):
    optimizer.zero_grad()
    h_dict, pred_criticality = model(hetero_data.x_dict, hetero_data.edge_index_dict)
    loss = criterion(pred_criticality, target_criticality)
    loss.backward()
    optimizer.step()

    # Report progress every 20th step.
    if step % 20 == 0:
        print(f"Epoch {step}: Loss = {loss.item():.4f}")

In [None]:
# Switch to evaluation mode and run one gradient-free pass over the full graph.
model.eval()
with torch.no_grad():
    embeddings_dict, criticality_scores = model(
        hetero_data.x_dict,
        hetero_data.edge_index_dict,
    )

# Bring the flight outputs back to host memory as NumPy arrays; the score
# column vector is flattened to one dimension.
flight_embeddings = embeddings_dict['flight'].cpu().numpy()
criticality_scores = criticality_scores.cpu().numpy().flatten()

print(f"Flight embeddings shape: {flight_embeddings.shape}")
print(f"Criticality scores range: [{criticality_scores.min():.1f}, {criticality_scores.max():.1f}]")

In [None]:
def _short_unique_ids(count):
    """Return `count` distinct 8-character upper-case hex IDs.

    Truncating a UUID to 8 hex characters leaves only 32 bits, so collisions
    within one batch are plausible; candidates already issued are redrawn.
    """
    issued = set()
    ids = []
    while len(ids) < count:
        candidate = uuid.uuid4().hex[:8].upper()
        if candidate not in issued:
            issued.add(candidate)
            ids.append(candidate)
    return ids


flight_keys = list(flight_idx.keys())

# Tail number of the first rotation row per flight, looked up through a dict
# instead of filtering the whole DataFrame once per flight.
tail_by_flight = (
    rotations_pd.drop_duplicates(subset='FLIGHT_KEY')
    .set_index('FLIGHT_KEY')['TAIL_NUMBER']
    .to_dict()
)
# Flights without a rotation row, or with a null tail, are written as NULL.
tail_numbers = [
    None if pd.isna(tail) else tail
    for tail in (tail_by_flight.get(fk) for fk in flight_keys)
]

# Number of non-null NEXT_FLIGHT_KEY values recorded for each flight.
next_leg_counts = rotations_pd.groupby('FLIGHT_KEY')['NEXT_FLIGHT_KEY'].count().to_dict()

# One timestamp for the whole snapshot, as a plain datetime (a type Snowpark's
# schema inference handles natively).
snapshot_ts = pd.Timestamp.now().to_pydatetime()
embedding_ids = _short_unique_ids(len(flight_keys))

output_data = [
    {
        'EMBEDDING_ID': embedding_id,
        'FLIGHT_KEY': fk,
        'TAIL_NUMBER': tail,
        'SNAPSHOT_TS': snapshot_ts,
        'GNN_EMBEDDING': embedding.tolist(),
        'GNN_NETWORK_CRITICALITY': float(score),
        'ATTENTION_WEIGHTS': None,
        'DOWNLINE_LEGS_AFFECTED_COUNT': int(next_leg_counts.get(fk, 0)),
        'MODEL_VERSION': 'v1.0'
    }
    for embedding_id, fk, tail, embedding, score in zip(
        embedding_ids, flight_keys, tail_numbers, flight_embeddings, criticality_scores
    )
]

output_df = session.create_dataframe(output_data)
session.use_schema('ML_PROCESSING')
output_df.write.mode('overwrite').save_as_table('GNN_FLIGHT_EMBEDDINGS')
print(f"Saved {len(output_data)} GNN embeddings to ML_PROCESSING.GNN_FLIGHT_EMBEDDINGS")

In [None]:
# Persist only the learned weights; the model class is needed to reload them.
torch.save(model.state_dict(), '/tmp/hgnn_model.pt')

# Manual follow-up steps for registering the model.
for line in (
    "Model saved. To register in Snowflake ML Registry:",
    "  1. Upload model artifacts to stage",
    "  2. Use Registry.log_model() with custom model class",
    "  3. Deploy for inference via Model Registry",
):
    print(line)

print(f"\nTop 10 flights by network criticality:")
# Rank (flight key, score) pairs from highest to lowest score and keep ten.
ranked_flights = list(zip(flight_keys, criticality_scores))
ranked_flights.sort(key=lambda pair: pair[1], reverse=True)
top_flights = ranked_flights[:10]
for fk, score in top_flights:
    print(f"  {fk}: {score:.1f}")