In [1]:
import json
import statistics
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

import networkx as nx

# Alias for a scalar numeric value: either an int or a float.
Numeric = Union[int, float]

In [2]:
def load_json(path: Path) -> dict:
    """Read the UTF-8 encoded JSON document at ``path`` and return the parsed content."""
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)


def _flatten_curve_to_scalars(
    x_values: Sequence[Numeric],
    y_values: Sequence[Numeric],
    prefix: str,
) -> Dict[str, float]:
    """Convert curve data into scalar features for ML models."""
    if not x_values or not y_values or len(x_values) != len(y_values):
        return {}
    
    x_vals = [float(v) for v in x_values if isinstance(v, (int, float))]
    y_vals = [float(v) for v in y_values if isinstance(v, (int, float))]
    
    if not x_vals or not y_vals or len(x_vals) != len(y_vals):
        return {}
    
    features: Dict[str, float] = {
        f"{prefix}_min_x": min(x_vals),
        f"{prefix}_max_x": max(x_vals),
        f"{prefix}_min_y": min(y_vals),
        f"{prefix}_max_y": max(y_vals),
    }
    
    # Compute average slope
    slopes: List[float] = []
    for i in range(1, len(x_vals)):
        dx = x_vals[i] - x_vals[i - 1]
        if abs(dx) > 1e-9:
            slopes.append((y_vals[i] - y_vals[i - 1]) / dx)
    if slopes:
        features[f"{prefix}_avg_slope"] = statistics.fmean(slopes)
    
    return features


def _make_ml_ready_attributes(attrs: dict, time_index: Optional[int] = None) -> dict:
    """Convert raw attributes to numeric features.
    
    If time_index is provided, extracts values from time series at that index.
    """
    ml_attrs: dict = {}
    
    for key, value in attrs.items():
        # Skip non-numeric/non-convertible attributes
        if key in {"Bus", "id", "node_type", "parent_bus", "line_id", "edge_type"}:
            ml_attrs[key] = value
            continue
        
        # Handle numeric values
        if isinstance(value, (int, float)):
            ml_attrs[key] = float(value)
        
        # Handle time series - extract value at time_index
        elif isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            if value and all(isinstance(v, (int, float)) for v in value):
                if time_index is not None and 0 <= time_index < len(value):
                    ml_attrs[key] = float(value[time_index])
        
        # Handle reserve eligibility lists - convert to count
        elif key == "Reserve eligibility" and isinstance(value, list):
            ml_attrs["reserve_eligibility_count"] = float(len(value))
        
        # Keep everything else for visualization
        else:
            ml_attrs[key] = value
    
    return ml_attrs


def build_graph(case_data: dict, include_generator_nodes: bool = False, time_index: int = 0) -> nx.Graph:
    """Create a NetworkX graph from case data.

    Buses become nodes and transmission lines become edges. Node and edge
    attributes are passed through ``_make_ml_ready_attributes`` so scalar
    values become floats and time series data (e.g., Load) is extracted at
    the specified time_index. Curve data (production cost, startup cost) is
    flattened into scalar statistics; the raw curves are kept on the
    generator payload for visualization.

    When include_generator_nodes=True, every generator becomes its own node
    linked to its bus by a "generator_link" edge. When it is False, generator
    features are added directly to bus nodes (assumes one generator per bus).

    Generators without a "Bus" entry and transmission lines without a source
    or target bus are skipped.

    Args:
        case_data: Parsed case file with optional "Buses", "Generators",
            "Transmission lines" and "Contingencies" sections.
        include_generator_nodes: Whether generators get dedicated nodes.
        time_index: Time step used when sampling time-series attributes.

    Returns:
        The populated undirected graph.

    Raises:
        ValueError: If include_generator_nodes is False and some bus has more
            than one generator (raised by the aggregation step).
    """
    graph = nx.Graph()

    # Add bus nodes
    for bus_id, attrs in case_data.get("Buses", {}).items():
        node_attrs = _make_ml_ready_attributes(attrs, time_index=time_index)
        node_attrs["generators"] = []
        node_attrs["node_type"] = "bus"
        graph.add_node(bus_id, **node_attrs)

    # Add generators
    for gen_id, attrs in case_data.get("Generators", {}).items():
        bus = attrs.get("Bus")
        if bus is None:
            # A generator without a bus cannot be attached anywhere (and
            # NetworkX does not accept None as a node), so skip it the same
            # way a line with a missing endpoint is skipped below.
            continue

        # Separate curve data for special handling
        prod_mw = attrs.get("Production cost curve (MW)", [])
        prod_cost = attrs.get("Production cost curve ($)", [])
        startup_costs = attrs.get("Startup costs ($)", [])
        startup_delays = attrs.get("Startup delays (h)", [])

        # Convert base attributes to ML-ready format
        payload = {k: v for k, v in attrs.items() if k != "Bus"}
        payload = _make_ml_ready_attributes(payload, time_index=time_index)
        payload["id"] = gen_id

        # Add flattened curve features
        prod_features = _flatten_curve_to_scalars(prod_mw, prod_cost, "prod_cost")
        startup_features = _flatten_curve_to_scalars(startup_delays, startup_costs, "startup_cost")
        payload.update(prod_features)
        payload.update(startup_features)

        # Keep original curves for visualization (not ML features). This also
        # overwrites any single value the generic conversion sampled from them.
        payload["Production cost curve (MW)"] = prod_mw
        payload["Production cost curve ($)"] = prod_cost
        payload["Startup costs ($)"] = startup_costs
        payload["Startup delays (h)"] = startup_delays

        # The generator may reference a bus that has no entry under "Buses"
        if bus not in graph.nodes:
            graph.add_node(bus, node_type="bus", generators=[])

        if include_generator_nodes:
            # Generator as separate node - don't add to bus's generators list
            gen_attrs = dict(payload)
            gen_attrs["node_type"] = "generator"
            gen_attrs["parent_bus"] = bus
            graph.add_node(gen_id, **gen_attrs)
            graph.add_edge(bus, gen_id, edge_type="generator_link")
        else:
            # Generator combined with bus - add to bus's generators list
            graph.nodes[bus].setdefault("generators", []).append(dict(payload))

    # Add transmission lines
    for line_id, attrs in case_data.get("Transmission lines", {}).items():
        source = attrs.get("Source bus")
        target = attrs.get("Target bus")
        if not source or not target:
            continue

        edge_attrs = {
            k: v
            for k, v in attrs.items()
            if k not in {"Source bus", "Target bus"}
        }
        edge_attrs = _make_ml_ready_attributes(edge_attrs, time_index=time_index)
        edge_attrs["line_id"] = line_id
        edge_attrs["edge_type"] = "transmission"
        graph.add_edge(source, target, **edge_attrs)

    # Mark contingency candidates: a transmission edge is flagged when its
    # line id is listed under any contingency scenario's "Affected lines".
    contingencies = case_data.get("Contingencies", {})
    affected_lines = {
        line
        for scenario in contingencies.values()
        for line in scenario.get("Affected lines", [])
    }
    for _, _, data in graph.edges(data=True):
        if data.get("edge_type") == "transmission":
            is_contingency = data.get("line_id") in affected_lines
            data["is_contingency_candidate"] = 1.0 if is_contingency else 0.0

    # If generators are not separate nodes, add their features to bus nodes
    if not include_generator_nodes:
        _aggregate_generator_features_to_buses(graph)

    return graph


def _aggregate_generator_features_to_buses(graph: nx.Graph) -> None:
    """Add generator features directly to bus nodes.
    
    Assumes one generator per bus. Raises ValueError if more than one generator
    is attached to any bus.
    """
    for node_id, attrs in graph.nodes(data=True):
        if attrs.get("node_type") != "bus":
            continue
        
        generators = attrs.get("generators", [])
        
        if not generators:
            # No generators at this bus
            attrs["has_generator"] = 0.0
            continue
        
        if len(generators) > 1:
            raise ValueError(
                f"Bus '{node_id}' has {len(generators)} generators. "
                f"Expected at most one generator per bus."
            )
        
        # Single generator - add its features directly to the bus
        attrs["has_generator"] = 1.0
        gen = generators[0]
        
        # Keys to skip (not numeric features)
        skip_keys = {"id", "Production cost curve (MW)", "Production cost curve ($)",
                     "Startup costs ($)", "Startup delays (h)", "Reserve eligibility"}
        
        # Add generator features with "gen_" prefix
        for key, value in gen.items():
            if key in skip_keys or not isinstance(value, (int, float)):
                continue
            attrs[f"gen_{key}"] = float(value)


def extract_generator_metric(
    solution_data: Optional[dict],
    metric: Optional[str],
    *,
    time_index: int,
) -> Dict[str, float]:
    """Pull a per-generator value from the solution payload at the specified time index.

    Args:
        solution_data: Parsed solution mapping metric names to
            ``{generator_id: series}`` dicts; may be None or empty.
        metric: Name of the metric to read; may be None or empty.
        time_index: Zero-based time step to read from each series.

    Returns:
        ``{generator_id: value}`` for every generator whose series is a
        sequence holding an int/float at ``time_index``. Generators with a
        malformed series, an out-of-range (including negative) index or a
        non-numeric value are omitted. An empty dict is returned when the
        solution, the metric, or a per-generator mapping is missing.
    """
    if not solution_data or not metric or metric not in solution_data:
        return {}

    series_by_generator = solution_data[metric]
    if not isinstance(series_by_generator, dict):
        # Not a per-generator mapping, so there is nothing to extract
        return {}

    results: Dict[str, float] = {}

    for gen_id, series in series_by_generator.items():
        if not isinstance(series, Sequence) or isinstance(series, (str, bytes)):
            continue
        # A negative index would silently wrap around to the end of the
        # series; accept only real, in-range time steps, consistent with how
        # time series attributes are sampled when the graph is built.
        if not 0 <= time_index < len(series):
            continue

        value = series[time_index]
        if isinstance(value, (int, float)):
            results[gen_id] = float(value)

    return results


def annotate_generators_from_solution(
    graph: nx.Graph,
    solution_data: Optional[dict],
    *,
    metric: Optional[str],
    time_index: int,
) -> None:
    """Attach solution data to generator nodes/attributes.

    Extracts values at the specified time index from the solution time series
    and writes them onto generator nodes and onto the generator entries held
    in bus nodes' "generators" lists. If every extracted value is 0 or 1
    (within 1e-6) the values are stored under "is_on"; otherwise under
    "current_output_mw". When the graph has no separate generator nodes, the
    bus-level generator features are re-aggregated afterwards.

    Args:
        graph: Graph to annotate in place.
        solution_data: Parsed solution mapping metric names to
            ``{generator_id: series}`` dicts; may be None.
        metric: Name of the metric to attach; nothing happens when empty.
        time_index: Zero-based time step to read. Out-of-range and negative
            indices yield no value for the affected generator.
    """
    if not metric:
        return

    # Extract metric values at time index
    outputs: Dict[str, float] = {}
    if solution_data and metric in solution_data:
        for gen_id, series in solution_data.get(metric, {}).items():
            if not isinstance(series, Sequence) or isinstance(series, (str, bytes)):
                continue
            # Reject negative indices instead of letting them wrap around to
            # the end of the series, matching how time series attributes are
            # sampled when the graph is built.
            if not 0 <= time_index < len(series):
                continue
            value = series[time_index]
            if isinstance(value, (int, float)):
                outputs[gen_id] = float(value)

    # Check if values are binary (0 or 1) - if so, use as status instead of output.
    # NOTE(review): the decision only looks at the values at time_index, so a
    # MW metric whose values all happen to be 0 or 1 at this step would be
    # stored as "is_on" - confirm whether that matters for the callers.
    is_binary = all(
        round(value) in (0, 1) and abs(value - round(value)) <= 1e-6
        for value in outputs.values()
    )
    target_key = "is_on" if is_binary else "current_output_mw"

    # Update graph nodes
    for node, attrs in graph.nodes(data=True):
        # Update generator nodes
        if attrs.get("node_type") == "generator":
            gen_id = attrs.get("id", node)
            if gen_id in outputs:
                attrs[target_key] = outputs[gen_id]

        # Update generators in bus nodes
        for generator in attrs.get("generators") or []:
            gen_id = generator.get("id")
            if gen_id and gen_id in outputs:
                generator[target_key] = outputs[gen_id]

    # Re-aggregate generator features if they're combined with buses
    has_separate_generator_nodes = any(
        attrs.get("node_type") == "generator"
        for _, attrs in graph.nodes(data=True)
    )
    if not has_separate_generator_nodes:
        _aggregate_generator_features_to_buses(graph)


In [3]:
# === Data Sources ===
# Both paths are relative to the notebook's working directory.
CASE_FILE = Path("data/case14_data/2017-01-01.json")
SOLUTION_FILE = Path("data/case14_solutions/2017-01-01_solution.json")  # Set to None to skip loading a solution

# === Graph Generation Parameters ===
TIME_INDEX = 0  # Zero-based time step read from the case and solution time series
GENERATORS_AS_NODES = True  # True: dedicated generator nodes linked to their bus; False: generator features are added to bus nodes
SOLUTION_METRIC = "Is on"  # Metric key in the solution file, e.g., "Thermal production (MW)" or "Is on"


In [4]:
# Load the case definition; the solution is optional and stays None when SOLUTION_FILE is unset.
case_data = load_json(CASE_FILE)
solution_data = load_json(SOLUTION_FILE) if SOLUTION_FILE else None

In [5]:
# Build the graph at TIME_INDEX, then attach the selected solution metric to its generators (in place).
graph = build_graph(case_data, include_generator_nodes=GENERATORS_AS_NODES, time_index=TIME_INDEX)
annotate_generators_from_solution(
    graph,
    solution_data,
    metric=SOLUTION_METRIC,
    time_index=TIME_INDEX,
 )

In [6]:
def get_ml_node_features(graph: nx.Graph, node_id: str) -> Dict[str, float]:
    """Return the numeric attributes of one node as a ``{name: float}`` dict.

    Identifier, bookkeeping and raw-curve attributes are always left out, and
    so is every attribute whose value is not an int or float. For bus nodes
    that carry aggregated generator data (when GENERATORS_AS_NODES=False) the
    result therefore also contains:
    - has_generator: 1.0 if generator present, 0.0 otherwise
    - gen_*: individual generator feature values (e.g., gen_Pmax, gen_Pmin)
    """
    # Metadata and raw curve entries that never count as features
    non_feature_keys = {
        "id",
        "node_type",
        "parent_bus",
        "generators",
        "Production cost curve (MW)",
        "Production cost curve ($)",
        "Startup costs ($)",
        "Startup delays (h)",
        "Reserve eligibility",
    }

    return {
        name: float(raw)
        for name, raw in graph.nodes[node_id].items()
        if name not in non_feature_keys and isinstance(raw, (int, float))
    }


def get_ml_edge_features(graph: nx.Graph, u: str, v: str) -> Dict[str, float]:
    """Return the numeric attributes of the edge ``(u, v)`` as floats.

    The "line_id" and "edge_type" metadata entries are never included, nor is
    any attribute whose value is not an int or float.
    """
    metadata_keys = {"line_id", "edge_type"}
    edge_attrs = graph.edges[u, v]

    numeric: Dict[str, float] = {}
    for name in edge_attrs:
        raw = edge_attrs[name]
        if name not in metadata_keys and isinstance(raw, (int, float)):
            numeric[name] = float(raw)
    return numeric


def print_sample_features(graph: nx.Graph, max_nodes: int = 2, max_edges: int = 2) -> None:
    """Print a small sample of node and edge features for a visual sanity check.

    At most ``max_nodes`` bus nodes and ``max_nodes`` generator nodes are
    shown, followed by at most ``max_edges`` edges. Only the numeric features
    returned by ``get_ml_node_features`` / ``get_ml_edge_features`` are
    listed, sorted by name.
    """

    def _print_sorted(feature_map: Dict[str, float], indent: str) -> None:
        # One "name: value" line per feature, in alphabetical order
        for name in sorted(feature_map):
            print(f"{indent}{name}: {feature_map[name]}")

    # Split the nodes by type in a single pass
    buses = []
    generators = []
    for node_id, attrs in graph.nodes(data=True):
        kind = attrs.get('node_type')
        if kind == 'bus':
            buses.append((node_id, attrs))
        elif kind == 'generator':
            generators.append((node_id, attrs))

    print("=== Sample Node Features ===")

    # Bus nodes
    if buses:
        print("\n--- Bus Nodes ---")
        shown = 0
        for node_id, attrs in buses:
            if shown >= max_nodes:
                break
            shown += 1

            features = get_ml_node_features(graph, node_id)
            attached = attrs.get('generators', [])
            gen_count = len(attached) if isinstance(attached, list) else 0
            print(f"\nNode '{node_id}' (type=bus, generators={gen_count}):")

            # Features aggregated from a generator carry the "gen_" prefix
            own_features = {}
            aggregated_features = {}
            for name, value in features.items():
                if name.startswith('gen_'):
                    aggregated_features[name] = value
                else:
                    own_features[name] = value

            if own_features:
                print("  Bus features:")
                _print_sorted(own_features, "    ")

            if aggregated_features:
                print("  Generator features (aggregated to bus):")
                _print_sorted(aggregated_features, "    ")

    # Generator nodes
    if generators:
        print("\n--- Generator Nodes ---")
        shown = 0
        for node_id, attrs in generators:
            if shown >= max_nodes:
                break
            shown += 1

            features = get_ml_node_features(graph, node_id)
            parent_bus = attrs.get('parent_bus', 'unknown')

            print(f"\nNode '{node_id}' (type=generator, parent_bus={parent_bus}):")
            print("  Generator features:")
            _print_sorted(features, "    ")

    # Edges
    print("\n=== Sample Edge Features ===")
    shown = 0
    for u, v, attrs in graph.edges(data=True):
        if shown >= max_edges:
            break
        shown += 1

        features = get_ml_edge_features(graph, u, v)
        is_link = attrs.get("edge_type") == "generator_link"
        label = "generator link" if is_link else "transmission"
        print(f"\nEdge '{u}' -> '{v}' ({label}):")

        if features:
            _print_sorted(features, "  ")
        else:
            print("  (no numeric features)")

In [7]:
# Optional sanity check: print up to 3 nodes per node type and up to 5 edges with their numeric features
print_sample_features(graph, max_nodes=3, max_edges=5)


=== Sample Node Features ===

--- Bus Nodes ---

Node 'b1' (type=bus, generators=0):
  Bus features:
    Load (MW): 0.0

Node 'b2' (type=bus, generators=0):
  Bus features:
    Load (MW): 19.33301

Node 'b3' (type=bus, generators=0):
  Bus features:
    Load (MW): 83.92488

--- Generator Nodes ---

Node 'g1' (type=generator, parent_bus=b1):
  Generator features:
    Initial power (MW): 230.74887999999999
    Initial status (h): 24.0
    Minimum downtime (h): 1.0
    Minimum uptime (h): 1.0
    Ramp down limit (MW): 231.12
    Ramp up limit (MW): 231.12
    Shutdown limit (MW): 231.12
    Startup limit (MW): 231.12
    is_on: 1.0
    prod_cost_avg_slope: 44.761038740610545
    prod_cost_max_x: 330.1716826503023
    prod_cost_max_y: 14293.941140664052
    prod_cost_min_x: 36.75123365792352
    prod_cost_min_y: 1160.1370560288465
    startup_cost_avg_slope: 3296.1450000000004
    startup_cost_max_x: 4.0
    startup_cost_max_y: 34238.45
    startup_cost_min_x: 1.0
    startup_cost_min_y: 2