# Preference-Based EBM Training for Power System Decisions

This notebook trains a conditional Energy-Based Model (EBM) using preference learning with an LP economic oracle.

**Methodology (from preference_learning.tex):**
1. **HTE Embeddings**: Encode scenarios into context vectors `h = HTE(x)`
2. **EBM with GRU**: Learn an energy function `E_θ(u | h)` over decisions, using a GRU to capture their temporal structure
3. **Langevin Sampling**: Generate K candidate decisions `{u^(k)} ~ S_θ(h)`
4. **LP Worker**: Evaluate candidates with physics-aware economic oracle
5. **Preference Learning**: Margin ranking loss to shape energy landscape

**Target Hardware:** Colab A100 80GB VRAM

**Contents:**
1. Setup & Installation
2. Configuration
3. Data Loading
4. Model Initialization
5. Training Loop
6. Evaluation & Visualization

## 1. Setup & Installation

In [1]:
# ============================================================================
# 1. SETUP & INSTALLATION
# ============================================================================

import sys
import os

# Check if running on Colab
IN_COLAB = 'google.colab' in sys.modules

# Optional override for the repository location so a different checkout can be
# used without editing this cell (set BENCHMARK_REPO_PATH before launching).
_REPO_PATH_OVERRIDE = os.environ.get('BENCHMARK_REPO_PATH')

if IN_COLAB:
    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)

    # Set repository path
    REPO_PATH = _REPO_PATH_OVERRIDE or '/content/drive/MyDrive/benchmark'

    if not os.path.exists(REPO_PATH):
        print("Please upload the benchmark folder to Google Drive at: MyDrive/benchmark/")
        print("Or modify REPO_PATH to point to your repository location")
        print("Or set the BENCHMARK_REPO_PATH environment variable")
    else:
        # Make `src.*` importable and resolve the relative `outputs/...` paths
        # used by later cells against the repository root.
        sys.path.insert(0, REPO_PATH)
        os.chdir(REPO_PATH)
        print(f"‚úì Working directory: {os.getcwd()}")

    # Install dependencies by shelling out to pip (the `!pip` form, not the
    # `%pip` magic). The torch line points pip at the CUDA 12.1 wheel index.
    get_ipython().system('pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121')
    get_ipython().system('pip install -q torch-geometric pyomo highspy tqdm matplotlib seaborn pandas numpy')

else:
    # Local setup
    REPO_PATH = _REPO_PATH_OVERRIDE or r'C:\Users\Dell\projects\multilayer_milp_gnn\benchmark'
    sys.path.insert(0, REPO_PATH)
    os.chdir(REPO_PATH)
    print(f"‚úì Working directory: {os.getcwd()}")

# Verify GPU
import torch
print(f"‚úì PyTorch version: {torch.__version__}")
print(f"‚úì CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"‚úì GPU: {torch.cuda.get_device_name(0)}")
    print(f"‚úì GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

Mounted at /content/drive
‚úì Working directory: /content/drive/MyDrive/benchmark
‚úì PyTorch version: 2.9.0+cu126
‚úì CUDA available: True
‚úì GPU: NVIDIA A100-SXM4-80GB
‚úì GPU Memory: 85.2 GB


## 2. Configuration

In [None]:
# ============================================================================
# 2. CONFIGURATION (TWO-PHASE TRAINING)
# ============================================================================

from dataclasses import dataclass, field
from typing import Optional, Dict, Any
from pathlib import Path

@dataclass
class Config:
    """Training configuration for two-phase preference-based EBM.

    Phase 1 (``pretrain_*``) trains without the LP oracle; phase 2
    (``finetune_*``) uses the LP oracle. The top-level ``learning_rate`` and
    ``langevin_*`` / ``num_candidates`` fields are the defaults used when the
    model, sampler and optimizer are first built.

    ``n_zones``, ``n_timesteps`` and ``h_dim`` are overwritten by the
    data-loading cells after the scenarios/embeddings have been inspected
    (``n_zones`` becomes ``None`` because zone counts vary per scenario).
    """

    # Paths
    scenarios_dir: str = "outputs/scenarios_v1_filtered"
    milp_reports_dir: str = "outputs/scenarios_v1_filtered/reports"
    # Folder name is "hierarchical_temporal", matching the encoder outputs
    # referenced elsewhere in this notebook.
    embeddings_path: Optional[str] = "outputs/encoders/hierarchical_temporal/embeddings_multiscale_normalized.pt"
    prebuilt_lp_dir: str = "outputs/lp_models_v1"  # Pre-built LP models
    output_dir: str = "outputs/preference_training"

    # Data
    train_split: float = 0.8

    # Model architecture
    architecture: str = "gru"  # "mlp", "gru", "transformer"
    h_dim: int = 128
    hidden_dim: int = 128
    gru_layers: int = 2
    num_heads: int = 4
    dropout: float = 0.1

    # Sampling (defaults for model init - phases override these)
    langevin_steps: int = 20
    langevin_step_size: float = 0.01
    langevin_noise: float = 0.01
    num_candidates: int = 3

    # ============= PHASE 1: PRETRAIN (fast, no LP) =============
    pretrain_epochs: int = 15
    pretrain_lr: float = 1e-3
    pretrain_langevin_steps: int = 10
    pretrain_num_candidates: int = 3

    # ============= PHASE 2: FINETUNE (with LP) =============
    finetune_epochs: int = 10
    finetune_lr: float = 1e-4
    finetune_langevin_steps: int = 20
    finetune_lp_ratio: float = 0.2           # LP on 20% of batches
    finetune_uncertainty_top: float = 0.3     # Only top 30% uncertain scenarios
    finetune_num_candidates_min: int = 2
    finetune_num_candidates_max: int = 8      # Adaptive K
    finetune_cost_proxy_top_k: int = 2        # Filter to top-2 before LP

    # Loss
    margin: float = 1.0
    alpha: float = 1.0
    w_max: float = 5.0
    energy_reg: float = 0.01

    # Optimization
    learning_rate: float = 1e-4  # Default for model init
    weight_decay: float = 1e-5
    gradient_clip: float = 1.0
    batch_size: int = 4

    # Training
    epochs: int = 25  # Total epochs (pretrain + finetune)
    eval_every: int = 5
    save_every: int = 5
    log_every: int = 1

    # LP Oracle
    use_lp_worker: bool = True
    lp_solver: str = "appsi_highs"
    lp_time_limit: float = 0.5
    lp_cache_size: int = 100

    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Scenario dimensions (set from data)
    n_zones: Optional[int] = 100  # set to None once zone counts are found to vary
    n_timesteps: int = 24
    n_features: int = 8

# Instantiate the run configuration with its defaults.
config = Config()

# Ensure the output directory exists before anything writes into it.
Path(config.output_dir).mkdir(exist_ok=True, parents=True)

# Human-readable summary of both training phases, printed line by line.
_banner = "=" * 70
_summary_lines = [
    _banner,
    "TWO-PHASE TRAINING CONFIGURATION",
    _banner,
    f"üìç Device: {config.device}",
    f"üì¶ Batch size: {config.batch_size}",
    "",
    "üöÄ PHASE 1: PRETRAIN (fast, no LP)",
    f"   Epochs: {config.pretrain_epochs}",
    f"   Learning rate: {config.pretrain_lr}",
    f"   Langevin steps: {config.pretrain_langevin_steps}",
    f"   Candidates (K): {config.pretrain_num_candidates}",
    "",
    "üéØ PHASE 2: FINETUNE (with LP Oracle)",
    f"   Epochs: {config.finetune_epochs}",
    f"   Learning rate: {config.finetune_lr}",
    f"   Langevin steps: {config.finetune_langevin_steps}",
    f"   LP batch ratio: {config.finetune_lp_ratio:.0%}",
    f"   Uncertainty top: {config.finetune_uncertainty_top:.0%}",
    f"   Candidates (K): {config.finetune_num_candidates_min} ‚Üí {config.finetune_num_candidates_max}",
    f"   Cost proxy filter: top-{config.finetune_cost_proxy_top_k}",
    "",
    f"üìÅ Pre-built LP models: {config.prebuilt_lp_dir}",
    _banner,
]
for _line in _summary_lines:
    print(_line)

CONFIGURATION
Architecture: gru
Device: cuda
Epochs: 50
Batch size: 2
Candidates per scenario: 3
Langevin steps: 30
Use LP Worker: True


## 3. Data Loading

In [None]:
# ============================================================================
# 3. DATA LOADING
# ============================================================================

import json
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Import preference learning modules
from src.preference.data_models import ScenarioData, DecisionVector

# Resolve the configured data locations and report whether each one exists.
scenarios_path = Path(config.scenarios_dir)
reports_path = Path(config.milp_reports_dir)
for _label, _dir in (("Scenarios", scenarios_path), ("Reports", reports_path)):
    print(f"{_label} directory: {_dir} (exists: {_dir.exists()})")

# Enumerate available files. Scenario files are sorted so their index order
# (used later for the train/eval split) is deterministic.
scenario_files = []
report_files = []

if scenarios_path.exists():
    scenario_files = sorted(scenarios_path.glob("scenario_*.json"))
    print(f"Found {len(scenario_files)} scenario files")

if reports_path.exists():
    report_files = list(reports_path.glob("scenario_*.json"))
    print(f"Found {len(report_files)} MILP reports")

# Analyze zone counts across scenarios (they are DYNAMIC)
def get_n_zones(scenario_path):
    """Return the total number of zones in a scenario JSON file.

    The count is the sum of ``graph.zones_per_region``. A scenario where
    either level is missing or null yields 0.

    Args:
        scenario_path: path (str or ``pathlib.Path``) to a ``scenario_*.json`` file.

    Returns:
        int: total zone count across all regions.
    """
    # JSON is UTF-8 by spec; pin the encoding so reads do not depend on the
    # platform default (the local setup in this notebook is Windows).
    with open(scenario_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # `or` also covers keys that are present but null.
    zones_per_region = (data.get("graph") or {}).get("zones_per_region") or []
    return sum(zones_per_region)

# Sample zone counts from the first N scenarios. Zone counts are DYNAMIC, so the
# sampled max (plus a buffer) is what padding and model capacity are sized for.
ZONE_SAMPLE_SIZE = 100
ZONE_PADDING_BUFFER = 10

if scenario_files:
    sample_size = min(ZONE_SAMPLE_SIZE, len(scenario_files))
    zone_counts = [get_n_zones(path) for path in scenario_files[:sample_size]]

    print(f"\nüìä Zone count analysis (from {sample_size} scenarios):")
    print(f"   Min zones: {min(zone_counts)}")
    print(f"   Max zones: {max(zone_counts)}")
    print(f"   Mean zones: {np.mean(zone_counts):.1f}")
    print(f"   Median zones: {np.median(zone_counts):.0f}")

    # Set max_zones for padding (use max + buffer).
    # NOTE(review): this is the max over a sample only; a scenario outside the
    # sample with more zones could exceed the capacity the model is built with.
    config.max_zones = max(zone_counts) + ZONE_PADDING_BUFFER
    config.n_zones = None  # Dynamic - will be determined per batch

    # Horizon length from the first scenario (default 24). JSON is UTF-8;
    # pin the encoding so the local (Windows) setup reads it the same way.
    with open(scenario_files[0], 'r', encoding='utf-8') as f:
        sample_scenario = json.load(f)
    config.n_timesteps = sample_scenario.get("horizon_hours", 24)

    print(f"\n‚úì Config updated:")
    print(f"   max_zones: {config.max_zones} (for padding)")
    print(f"   n_zones: DYNAMIC (varies per scenario)")
    print(f"   n_timesteps: {config.n_timesteps}")
    print(f"   n_features: {config.n_features}")
else:
    print("‚ö†Ô∏è No scenario files found!")

Scenarios directory: outputs/scenarios_v1_filtered (exists: True)
Reports directory: outputs/scenarios_v1_filtered/reports (exists: True)
Found 2960 scenario files
Found 2960 MILP reports

üìä Zone count analysis (from 100 scenarios):
   Min zones: 8
   Max zones: 119
   Mean zones: 58.4
   Median zones: 56

‚úì Config updated:
   max_zones: 129 (for padding)
   n_zones: DYNAMIC (varies per scenario)
   n_timesteps: 24
   n_features: 8


In [4]:
# Create dataset with zone-level embeddings support
from src.preference.embedding_processor import (
    ZonalEmbeddingProcessor,
    TemporalZonalDataset,
    temporal_collate_fn,
    load_zone_embeddings,
)

# Embedding artifacts, in order of preference: per-zone embeddings first, else
# multiscale embeddings that can be reduced to zone level.
ZONE_EMBEDDING_PATH = "outputs/encoders/hierarchical_temporal/embeddings_zone.pt"
MULTISCALE_EMBEDDING_PATH = "outputs/encoders/hierarchical_temporal/embeddings_multiscale_normalized.pt"
DATASET_INDEX_PATH = "outputs/graphs/hetero_temporal_v1/dataset_index.json"

zone_embeddings = None

if Path(ZONE_EMBEDDING_PATH).exists():
    # Per-zone embeddings already on disk: load them directly.
    print(f"‚úì Found zone embeddings at {ZONE_EMBEDDING_PATH}")
    zone_embeddings = load_zone_embeddings(ZONE_EMBEDDING_PATH)
    config.embeddings_path = ZONE_EMBEDDING_PATH
elif Path(MULTISCALE_EMBEDDING_PATH).exists():
    # Only multiscale embeddings exist: reduce them to zone level, passing
    # ZONE_EMBEDDING_PATH as the processor's output path.
    print(f"üîß Processing multiscale embeddings to zone level...")
    processor = ZonalEmbeddingProcessor()
    _index_path = Path(DATASET_INDEX_PATH)
    zone_embeddings = processor.load_and_process(
        embedding_path=MULTISCALE_EMBEDDING_PATH,
        dataset_index_path=DATASET_INDEX_PATH if _index_path.exists() else None,
        output_path=ZONE_EMBEDDING_PATH,
        level="zones",
    )
    config.embeddings_path = ZONE_EMBEDDING_PATH
else:
    # Nothing usable: the dataset cell falls back to placeholder embeddings.
    print("‚ö†Ô∏è No embeddings found - will use dummy embeddings")
    for _prefix, _candidate in (("Expected", ZONE_EMBEDDING_PATH), ("Or", MULTISCALE_EMBEDDING_PATH)):
        print(f"   {_prefix}: {_candidate}")

if zone_embeddings:
    # Inspect one scenario to report shapes and propagate them into the config.
    sample_key = next(iter(zone_embeddings.keys()))
    sample_emb = zone_embeddings[sample_key]
    _emb_dim = sample_emb.shape[-1]
    print(f"\n‚úì Zone embeddings loaded: {len(zone_embeddings)} scenarios")
    print(f"   Sample shape: {sample_emb.shape} (Z, T, D)")
    print(f"   Embed dim: {_emb_dim}")

    config.h_dim = _emb_dim
    _ndim = sample_emb.dim()
    if _ndim == 3:
        # (Z, T, D): the middle axis is time.
        config.n_timesteps = sample_emb.shape[1]
    elif _ndim >= 2:
        config.n_timesteps = 24

‚úì Found zone embeddings at outputs/encoders/hierarchical_temporal/embeddings_zone.pt
üìÇ Loading zone embeddings from outputs/encoders/hierarchical_temporal/embeddings_zone.pt...
‚úì Loaded 2960 scenario embeddings

‚úì Zone embeddings loaded: 2960 scenarios
   Sample shape: torch.Size([52, 24, 128]) (Z, T, D)
   Embed dim: 128


In [None]:
# Build the training dataset. With zone embeddings available, use the project's
# TemporalZonalDataset; otherwise fall back to SimplePreferenceDataset, which
# reads scenarios/MILP reports directly and substitutes placeholder embeddings.

if zone_embeddings:
    print("üîß Creating TemporalZonalDataset with zone embeddings...")

    temporal_dataset = TemporalZonalDataset(
        scenarios_dir=config.scenarios_dir,
        zone_embeddings=zone_embeddings,
        milp_reports_dir=config.milp_reports_dir,
        n_features=config.n_features,
        n_timesteps=config.n_timesteps,
    )

    if len(temporal_dataset) > 0:
        sample0 = temporal_dataset[0]
        print(f"\n‚úì Dataset created: {len(temporal_dataset)} samples")
        print(f"   Sample u_zt shape: {sample0['u_zt'].shape}")
        print(f"   Sample h_zt shape: {sample0['h_zt'].shape}")
        print(f"   n_zones: {sample0['n_zones']} (DYNAMIC)")

        config.h_dim = sample0['embed_dim']
else:
    print("üîß Creating SimplePreferenceDataset from scenarios...")

    class SimplePreferenceDataset(Dataset):
        """Dataset over scenario JSON files plus optional MILP reports, with per-scenario zone counts.

        Each item is a dict with ``scenario_id``, ``u_zt`` (MILP decisions,
        ``[n_zones, n_timesteps, n_features]``), ``h_zt`` (placeholder
        embedding, ``[n_zones, n_timesteps, h_dim]``), ``n_zones``,
        ``n_timesteps``, ``n_features``, ``embed_dim`` and ``milp_objective``
        (``inf`` when no report or objective is available).
        """

        def __init__(self, scenario_files, reports_dir, n_features=8, n_timesteps=24, h_dim=128):
            self.scenario_files = scenario_files
            self.reports_dir = Path(reports_dir)
            self.n_features = n_features
            self.n_timesteps = n_timesteps
            self.h_dim = h_dim

        def __len__(self):
            return len(self.scenario_files)

        def __getitem__(self, idx):
            scenario_path = self.scenario_files[idx]
            scenario_id = scenario_path.stem

            # JSON is UTF-8; pin the encoding so reads do not depend on the
            # platform default (the local setup in this notebook is Windows).
            with open(scenario_path, 'r', encoding='utf-8') as f:
                scenario = json.load(f)

            # Zone count; `or` also covers keys that are present but null.
            zones_per_region = (scenario.get("graph") or {}).get("zones_per_region") or []
            n_zones = sum(zones_per_region)

            report_path = self.reports_dir / f"{scenario_id}.json"
            milp_objective = float("inf")
            milp_decisions = torch.zeros(n_zones, self.n_timesteps, self.n_features)

            if report_path.exists():
                with open(report_path, 'r', encoding='utf-8') as f:
                    report = json.load(f)
                # A missing or null objective stays +inf so the collate step
                # can always build a float tensor of objectives.
                objective = (report.get("mip") or {}).get("objective")
                if objective is not None:
                    milp_objective = float(objective)

                # Extract binary decisions from dispatch
                dispatch = report.get("dispatch", {})
                if dispatch:
                    milp_decisions = self._extract_decisions(dispatch, n_zones)

            # Placeholder embedding until real HTE embeddings are available.
            # Seeded by index so each scenario keeps the same context across
            # epochs; fresh noise on every access would make conditioning useless.
            gen = torch.Generator().manual_seed(int(idx))
            h = torch.randn(n_zones, self.n_timesteps, self.h_dim, generator=gen)

            return {
                'scenario_id': scenario_id,
                'u_zt': milp_decisions.float(),
                'h_zt': h.float(),
                'n_zones': n_zones,
                'n_timesteps': self.n_timesteps,
                'n_features': self.n_features,
                'embed_dim': self.h_dim,
                'milp_objective': milp_objective,
            }

        def _extract_decisions(self, dispatch, n_zones):
            """Build a ``[n_zones, n_timesteps, n_features]`` tensor from a MILP dispatch dict.

            Only keys whose value is a 2-D list (zones x timesteps) are used,
            truncated to ``n_zones`` x ``n_timesteps``; everything else stays 0.
            """
            T = self.n_timesteps
            decisions = torch.zeros(n_zones, T, self.n_features)

            # Channel order along the last axis of the decision tensor.
            feature_keys = [
                'battery_charge', 'battery_discharge',
                'pumped_charge', 'pumped_discharge',
                'dr_active', 'nuclear', 'thermal', 'import_mode'
            ]

            for f_idx, key in enumerate(feature_keys):
                data = dispatch.get(key)
                if not isinstance(data, list):
                    continue
                arr = np.array(data)
                if arr.ndim != 2:
                    continue
                Z_use = min(arr.shape[0], n_zones)
                T_use = min(arr.shape[1], T)
                decisions[:Z_use, :T_use, f_idx] = torch.from_numpy(arr[:Z_use, :T_use])

            return decisions

    # Create dataset using all available scenarios
    temporal_dataset = SimplePreferenceDataset(
        scenario_files=scenario_files,
        reports_dir=config.milp_reports_dir,
        n_features=config.n_features,
        n_timesteps=config.n_timesteps,
        h_dim=config.h_dim,
    )

    if len(temporal_dataset) > 0:
        sample0 = temporal_dataset[0]
        print(f"\n‚úì Dataset created: {len(temporal_dataset)} samples")
        print(f"   Sample u_zt shape: {sample0['u_zt'].shape}")
        print(f"   Sample h_zt shape: {sample0['h_zt'].shape}")
        print(f"   n_zones: {sample0['n_zones']} (DYNAMIC)")

# Verify zone count distribution
if len(temporal_dataset) > 0:
    print("\nüìä Zone count distribution in dataset:")
    zone_counts = [temporal_dataset[i]['n_zones'] for i in range(min(50, len(temporal_dataset)))]
    print(f"   First 50: min={min(zone_counts)}, max={max(zone_counts)}, mean={np.mean(zone_counts):.1f}")

üîß Creating TemporalZonalDataset with zone embeddings...
‚úì TemporalZonalDataset: 2960 scenarios

‚úì Dataset created: 2960 samples
   Sample u_zt shape: torch.Size([52, 24, 8])
   Sample h_zt shape: torch.Size([52, 24, 128])
   n_zones: 52 (DYNAMIC)

üìä Zone count distribution in dataset:
   First 50: min=8, max=118, mean=61.4


In [6]:
# ============================================================================
# 4. MODEL INITIALIZATION (with dynamic zone support)
# ============================================================================

from src.preference.ebm import ConditionalEBMWithGRU, ConditionalEBM, build_ebm
from src.preference.sampler import LangevinSampler, SamplerConfig, binarize_decisions
from src.preference.losses import CombinedPreferenceLoss

# Size the model for the largest zone count found during data analysis; fall
# back to 100 when that analysis did not set config.max_zones.
max_zones_in_data = getattr(config, 'max_zones', 100)

# Architecture options forwarded to build_ebm.
ebm_config = dict(
    architecture=config.architecture,
    hidden_dim=config.hidden_dim,
    gru_layers=config.gru_layers,
    num_heads=config.num_heads,
    dropout=config.dropout,
    bidirectional=True,
    use_zone_attention=True,
)

# Flattened decision size at full capacity (Z_max * T * F). Individual batches
# are only padded to their own largest zone count.
max_decision_dim = max_zones_in_data * config.n_timesteps * config.n_features

print(f"üìê Model dimensions:")
for _label, _val in (
    ("max_zones", max_zones_in_data),
    ("n_timesteps", config.n_timesteps),
    ("n_features", config.n_features),
    ("max_decision_dim", max_decision_dim),
    ("h_dim", config.h_dim),
):
    print(f"   {_label}: {_val}")

ebm = build_ebm(
    config=ebm_config,
    h_dim=config.h_dim,
    decision_dim=max_decision_dim,
    n_zones=max_zones_in_data,
    n_timesteps=config.n_timesteps,
    n_features=config.n_features,
).to(config.device)

# Parameter counts: total vs. trainable.
_ebm_params = list(ebm.parameters())
n_params = sum(p.numel() for p in _ebm_params)
n_trainable = sum(p.numel() for p in _ebm_params if p.requires_grad)
print(f"\n‚úì EBM Model: {type(ebm).__name__}")
print(f"  Parameters: {n_params:,} ({n_trainable:,} trainable)")

# Langevin sampler driven by the EBM's energy.
sampler_config = SamplerConfig(
    num_steps=config.langevin_steps,
    step_size=config.langevin_step_size,
    noise_scale=config.langevin_noise,
    use_normalized=True,
    anneal_schedule="cosine",
)
sampler = LangevinSampler(ebm, sampler_config)
print(f"‚úì Langevin Sampler: {config.langevin_steps} steps")

# Preference loss (margin ranking + energy regularisation weights).
loss_fn = CombinedPreferenceLoss(
    margin=config.margin,
    alpha=config.alpha,
    w_max=config.w_max,
    energy_reg_weight=config.energy_reg,
)
print(f"‚úì Loss: CombinedPreferenceLoss (margin={config.margin}, Œ±={config.alpha})")

# AdamW with a cosine-annealed learning rate over all epochs, decaying to 1% of the base LR.
optimizer = torch.optim.AdamW(
    ebm.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config.epochs,
    eta_min=config.learning_rate * 0.01,
)
print(f"‚úì Optimizer: AdamW (lr={config.learning_rate})")
print(f"‚úì Scheduler: CosineAnnealingLR")

üìê Model dimensions:
   max_zones: 129
   n_timesteps: 24
   n_features: 8
   max_decision_dim: 24768
   h_dim: 128

‚úì EBM Model: ConditionalEBMWithGRU
  Parameters: 669,825 (669,825 trainable)
‚úì Langevin Sampler: 30 steps
‚úì Loss: CombinedPreferenceLoss (margin=1.0, Œ±=1.0)
‚úì Optimizer: AdamW (lr=0.0001)
‚úì Scheduler: CosineAnnealingLR


In [7]:
# ============================================================================
# 4b. LP ORACLE INITIALIZATION
# ============================================================================

from src.preference.lp_oracle import (
    LPOracle,
    PreferenceLPOracle,
    LPOracleConfig,
    create_lp_oracle,
)

# Initialize the LP oracle. Solver, cache size and pre-built model directory
# come from `config` so they cannot drift from the configuration cell.
if config.use_lp_worker:
    print("üîß Initializing LP Oracle with full LP Worker...")
    lp_oracle = create_lp_oracle(
        scenarios_dir=config.scenarios_dir,
        solver_name=config.lp_solver,
        use_full_lp=True,
        use_cache=True,
        max_cache_size=config.lp_cache_size,
        prebuilt_dir=config.prebuilt_lp_dir,  # Pre-built LP models
        # Per-stage time limits; note config.lp_time_limit is not forwarded.
        time_limit_hard_fix=1,
        time_limit_repair=0.5,
        time_limit_full_soft=1,
        verbose=False,
    )
else:
    print("‚ö†Ô∏è LP Worker disabled - using dummy costs")
    print("   Set config.use_lp_worker = True for real cost evaluation")
    lp_oracle = create_lp_oracle(
        scenarios_dir=config.scenarios_dir,
        use_full_lp=False,  # Will use dummy evaluation
    )

print(f"‚úì LP Oracle ready")
print(f"   Mode: {'Full LP' if config.use_lp_worker else 'Dummy costs'}")

üîß Initializing LP Oracle with full LP Worker...
‚úì LP Oracle ready
   Mode: Full LP


In [8]:
# ============================================================================
# 5. DATA LOADERS WITH DYNAMIC ZONE PADDING
# ============================================================================

import torch.nn.functional as F

def dynamic_collate_fn(batch):
    """
    Collate scenarios with different zone counts into one zero-padded batch.

    Each item's ``u_zt`` ([Z, T, F]) and ``h_zt`` ([Z, T, D]) are padded with
    zeros along the zone axis up to the largest ``n_zones`` in the batch.
    ``zone_mask`` marks real zones with 1 and padding with 0. Timestep,
    feature and embedding sizes are read from the first item, i.e. they are
    assumed identical across the batch.

    Returns a dict with ``scenario_ids``, ``milp_decisions`` [B, Z_max, T, F],
    ``embeddings`` [B, Z_max, T, D], ``zone_mask`` [B, Z_max],
    ``milp_objectives`` [B], the original per-item ``n_zones`` list, the
    shared ``n_timesteps`` / ``n_features`` / ``embed_dim`` and ``max_zones``.
    """
    z_max = max(sample["n_zones"] for sample in batch)
    head = batch[0]

    def _pad_zones(tensor, missing):
        # (0, 0, 0, 0, 0, missing): leave the last two dims untouched and
        # append `missing` zero rows to the leading (zone) dimension.
        return F.pad(tensor, (0, 0, 0, 0, 0, missing)) if missing > 0 else tensor

    padded_u, padded_h, masks = [], [], []
    for sample in batch:
        real = sample["n_zones"]
        missing = z_max - real
        padded_u.append(_pad_zones(sample["u_zt"], missing))
        padded_h.append(_pad_zones(sample["h_zt"], missing))
        masks.append(torch.cat([torch.ones(real), torch.zeros(missing)]))

    return {
        "scenario_ids": [sample["scenario_id"] for sample in batch],
        "milp_decisions": torch.stack(padded_u),
        "embeddings": torch.stack(padded_h),
        "zone_mask": torch.stack(masks),
        "milp_objectives": torch.tensor([sample["milp_objective"] for sample in batch]),
        "n_zones": [sample["n_zones"] for sample in batch],
        "n_timesteps": head["n_timesteps"],
        "n_features": head["n_features"],
        "embed_dim": head["embed_dim"],
        "max_zones": z_max,
    }

# Split dataset into train/eval by index: the first `train_split` fraction of
# the dataset's order is used for training, the remainder for evaluation.
# NOTE(review): a contiguous split can be biased if dataset order correlates
# with scenario properties — confirm, or consider a seeded random split.
n_total = len(temporal_dataset)
n_train = int(n_total * config.train_split)
n_eval = n_total - n_train

train_indices = list(range(n_train))
eval_indices = list(range(n_train, n_total))

train_dataset = torch.utils.data.Subset(temporal_dataset, train_indices)
eval_dataset = torch.utils.data.Subset(temporal_dataset, eval_indices)

print(f"üìä Dataset split:")
print(f"   Train: {len(train_dataset)} scenarios")
print(f"   Eval: {len(eval_dataset)} scenarios")

# Page-locked host memory speeds up host->GPU copies. Enable it for any CUDA
# device string ("cuda", "cuda:0", ...) and for both loaders alike.
use_pin_memory = str(config.device).startswith("cuda")

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=dynamic_collate_fn,
    num_workers=0,
    pin_memory=use_pin_memory,
)

eval_loader = DataLoader(
    eval_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=dynamic_collate_fn,
    num_workers=0,
    pin_memory=use_pin_memory,
)

print(f"‚úì Train loader: {len(train_loader)} batches")
print(f"‚úì Eval loader: {len(eval_loader)} batches")

# Sanity-check one collated batch.
test_batch = next(iter(train_loader))
print(f"\nüìã Test batch info:")
print(f"   milp_decisions shape: {test_batch['milp_decisions'].shape}")
print(f"   embeddings shape: {test_batch['embeddings'].shape}")
print(f"   zone_mask shape: {test_batch['zone_mask'].shape}")
print(f"   n_zones: {test_batch['n_zones']}")
print(f"   max_zones in batch: {test_batch['max_zones']}")

üìä Dataset split:
   Train: 2368 scenarios
   Eval: 592 scenarios
‚úì Train loader: 1184 batches
‚úì Eval loader: 296 batches

üìã Test batch info:
   milp_decisions shape: torch.Size([2, 73, 24, 8])
   embeddings shape: torch.Size([2, 73, 24, 128])
   zone_mask shape: torch.Size([2, 73])
   n_zones: [70, 73]
   max_zones in batch: 73


In [9]:
# Training step function with dynamic zone support and LP Oracle
def train_step_with_oracle(batch, ebm, sampler, loss_fn, optimizer, config, lp_oracle=None):
    """
    Run one preference-learning optimisation step on a zone-padded batch.

    The MILP solution of each scenario is the positive. K Langevin samples
    from the current EBM (binarised) are the negatives. The negatives are
    scored by the LP oracle when it is enabled and provided; otherwise a
    thermal-usage heuristic is used.

    Args:
        batch: dict from ``dynamic_collate_fn`` (tensors padded to the largest
            zone count in the batch, ``zone_mask`` marking real zones).
        ebm: energy model called as ``ebm(u_flat, h_context)``.
        sampler: Langevin sampler whose ``sample(...)`` returns
            ``(candidates, _)`` with candidates ``[B, K, Z_max, T, F]``.
        loss_fn: preference loss returning ``(loss, components_dict)``.
        optimizer: optimizer over the ``ebm`` parameters.
        config: run configuration (device, num_candidates, n_features,
            gradient_clip, use_lp_worker).
        lp_oracle: optional oracle; ``evaluate(scenario_id, u)`` must return
            an object with ``objective_value``.

    Returns:
        ``(loss_value, metrics)``, where ``metrics`` merges the loss components
        with energy and cost-gap diagnostics.
    """
    # evaluate_with_oracle() calls ebm.eval(); make sure training behaviour
    # (dropout etc.) is active again for this step.
    ebm.train()

    B = len(batch["scenario_ids"])
    device = config.device
    max_zones = batch["max_zones"]  # Max zones in this batch
    n_timesteps = batch["n_timesteps"]
    K = config.num_candidates

    # Move data to device
    h = batch["embeddings"].to(device)                    # [B, Z_max, T, D]
    u_positive = batch["milp_decisions"].to(device)       # [B, Z_max, T, F]
    costs_positive = batch["milp_objectives"].to(device)  # [B]
    zone_mask = batch["zone_mask"].to(device)             # [B, Z_max]

    # Conditioning context: mean embedding over real (unpadded) zones and all
    # timesteps.
    h_sum = (h * zone_mask[:, :, None, None]).sum(dim=(1, 2))   # [B, D]
    h_count = zone_mask.sum(dim=1, keepdim=True) * n_timesteps  # [B, 1]
    h_flat = h_sum / h_count.clamp(min=1)                       # [B, D]

    # Sample K candidate decisions per scenario with Langevin dynamics.
    u_candidates, _ = sampler.sample(
        h=h_flat,
        n_samples=K,
        n_zones=max_zones,
        n_timesteps=n_timesteps,
        n_features=config.n_features,
    )  # [B, K, Z_max, T, F]

    # Binarize and treat candidates as fixed negatives (no gradient into the
    # sampling chain). Zero the padded zones: the MILP positives are
    # zero-padded by the collate step, and without this the EBM could tell
    # positives from negatives by the padding alone.
    candidate_mask = zone_mask[:, None, :, None, None]  # [B, 1, Z_max, 1, 1]
    u_binary = binarize_decisions(u_candidates, method="threshold").detach() * candidate_mask

    if lp_oracle is not None and config.use_lp_worker:
        costs_negative = torch.zeros(B, K, device=device)
        # One device->host transfer for the whole batch instead of B*K.
        u_binary_cpu = u_binary.cpu()
        for b, (scenario_id, n_zones_b) in enumerate(zip(batch["scenario_ids"], batch["n_zones"])):
            for k in range(K):
                # Only the scenario's real zones are sent to the oracle.
                result = lp_oracle.evaluate(scenario_id, u_binary_cpu[b, k, :n_zones_b])
                costs_negative[b, k] = result.objective_value
    else:
        # Heuristic cost from feature 6 ('thermal' in the dataset's feature
        # order): summed over zones and timesteps only, giving one value per
        # candidate, [B, K]. Padded zones were zeroed above.
        thermal_usage = u_binary[..., 6].sum(dim=(-1, -2))  # [B, K]
        costs_negative = 1e6 + thermal_usage * 1e4 + torch.randn(B, K, device=device) * 1e4

    # Energies of the positive and of the K negatives per scenario.
    u_pos_flat = u_positive.reshape(B, -1)                  # [B, Z_max * T * F]
    energy_positive = ebm(u_pos_flat, h_flat)               # [B]

    u_neg_flat = u_binary.reshape(B * K, -1)                # [B*K, Z_max * T * F]
    h_expanded = h_flat.unsqueeze(1).expand(-1, K, -1).reshape(B * K, -1)
    energy_negatives = ebm(u_neg_flat, h_expanded).view(B, K)  # [B, K]

    # Compute loss
    loss, loss_components = loss_fn(
        energy_positive=energy_positive,
        energy_negatives=energy_negatives,
        cost_positive=costs_positive,
        costs_negative=costs_negative,
    )

    # Clear gradients right before backward, so anything the sampler may have
    # accumulated into parameter .grad during Langevin steps is discarded.
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping
    if config.gradient_clip > 0:
        torch.nn.utils.clip_grad_norm_(ebm.parameters(), config.gradient_clip)

    optimizer.step()

    # Diagnostics
    with torch.no_grad():
        cost_gap = (costs_negative - costs_positive.unsqueeze(1)).mean()
        energy_gap = energy_negatives.mean() - energy_positive.mean()

    metrics = {
        **loss_components,
        "loss": loss.item(),
        "mean_energy_pos": energy_positive.mean().item(),
        "mean_energy_neg": energy_negatives.mean().item(),
        "energy_gap": energy_gap.item(),
        "mean_cost_gap": cost_gap.item(),
    }

    return loss.item(), metrics


@torch.no_grad()
def evaluate_with_oracle(eval_loader, ebm, sampler, loss_fn, config, lp_oracle=None):
    """Evaluate on validation set with dynamic zone support.

    For each validation batch this pools the embeddings into one context vector
    per scenario (masked mean over valid zones and all timesteps), draws
    ``config.num_candidates`` decisions from ``sampler``, binarizes them, scores
    them with the LP oracle (or a thermal-usage heuristic when the oracle is not
    used), and computes ``loss_fn`` between the MILP decision and the candidates.

    Args:
        eval_loader: iterable of batch dicts with keys "scenario_ids", "max_zones",
            "n_timesteps", "embeddings", "milp_decisions", "milp_objectives",
            "zone_mask" and, on the oracle path, "n_zones".
        ebm: energy model called as ``ebm(u_flat, h)``.
        sampler: object whose ``sample(...)`` returns ``(candidates, _)``.
        loss_fn: called as ``loss_fn(E_pos, E_neg, cost_pos, costs_neg)`` and
            returning ``(loss, extras)``.
        config: reads ``device``, ``num_candidates``, ``n_features`` and
            ``use_lp_worker``.
        lp_oracle: optional oracle; ``evaluate(scenario_id, u)`` must return an
            object with ``objective_value``.

    Returns:
        dict with "val_loss", "val_energy_pos" and "val_energy_neg", each the
        mean of the per-batch values (0 when ``eval_loader`` yields nothing).

    Side effect: ``ebm`` is put in eval mode for the loop and ``ebm.train()`` is
    called before returning.
    """
    ebm.eval()

    total_loss = 0
    total_energy_pos = 0
    total_energy_neg = 0
    n_batches = 0

    for batch in eval_loader:
        B = len(batch["scenario_ids"])
        device = config.device
        K = config.num_candidates
        max_zones = batch["max_zones"]
        n_timesteps = batch["n_timesteps"]

        h = batch["embeddings"].to(device)
        u_positive = batch["milp_decisions"].to(device)
        costs_positive = batch["milp_objectives"].to(device)
        zone_mask = batch["zone_mask"].to(device)

        # Aggregate embeddings: masked mean over valid zones (dim 1) and
        # timesteps (dim 2); presumably h is [B, Z, T, D] and zone_mask [B, Z].
        zone_mask_expanded = zone_mask.unsqueeze(-1).unsqueeze(-1)
        h_masked = h * zone_mask_expanded
        h_sum = h_masked.sum(dim=(1, 2))
        h_count = zone_mask.sum(dim=1, keepdim=True) * n_timesteps
        h_flat = h_sum / h_count.clamp(min=1)

        # Sample candidates
        u_candidates, _ = sampler.sample(
            h=h_flat,
            n_samples=K,
            n_zones=max_zones,
            n_timesteps=n_timesteps,
            n_features=config.n_features,
        )
        u_binary = binarize_decisions(u_candidates)

        # Get costs
        if lp_oracle is not None and config.use_lp_worker:
            costs_negative = torch.zeros(B, K, device=device)
            for b in range(B):
                scenario_id = batch["scenario_ids"][b]
                n_zones_b = batch["n_zones"][b]
                for k in range(K):
                    # Only the scenario's real zones are sent to the oracle.
                    u_valid = u_binary[b, k, :n_zones_b].cpu()
                    result = lp_oracle.evaluate(scenario_id, u_valid)
                    costs_negative[b, k] = result.objective_value
        else:
            # Heuristic cost: feature 6 (thermal) summed over all axes after
            # [B, K], giving one value per candidate. Summing over (-1, -2, -3)
            # would also collapse K, yielding [B], which does not broadcast
            # against randn(B, K) unless B == K.
            thermal_usage = u_binary[..., 6].reshape(B, K, -1).sum(dim=-1)  # [B, K]
            costs_negative = 1e6 + thermal_usage * 1e4 + torch.randn(B, K, device=device) * 1e4

        # Compute energies
        u_pos_flat = u_positive.view(B, -1)
        energy_positive = ebm(u_pos_flat, h_flat)

        u_neg_flat = u_binary.view(B * K, -1)
        h_exp = h_flat.unsqueeze(1).expand(-1, K, -1).reshape(B * K, -1)
        energy_negatives = ebm(u_neg_flat, h_exp).view(B, K)

        loss, _ = loss_fn(
            energy_positive, energy_negatives,
            costs_positive, costs_negative
        )

        total_loss += loss.item()
        total_energy_pos += energy_positive.mean().item()
        total_energy_neg += energy_negatives.mean().item()
        n_batches += 1

    ebm.train()

    return {
        "val_loss": total_loss / max(1, n_batches),
        "val_energy_pos": total_energy_pos / max(1, n_batches),
        "val_energy_neg": total_energy_neg / max(1, n_batches),
    }

In [None]:
# ============================================================================
# OPTIONAL: LOAD CHECKPOINT TO RESUME TRAINING
# ============================================================================
# Set RESUME_TRAINING = True to continue from a saved checkpoint
# Set CHECKPOINT_PATH to the path of the checkpoint file
# ============================================================================
# On success this restores the EBM weights, restores HConditioner / CostProxy
# weights only if a `trainer` already exists in this session, replaces `history`
# with the checkpoint's history (when present), and auto-sets SKIP_PHASE1 /
# START_PHASE2_EPOCH / best_loss from the checkpoint's "phase" tag.
# When nothing is loaded, RESUME_TRAINING ends up False.

from collections import defaultdict

RESUME_TRAINING = False  # ‚Üê Change to True to resume
CHECKPOINT_PATH = None   # ‚Üê Set path, e.g., "outputs/preference_training/best_model.pt"
                         # or "outputs/preference_training/checkpoint_phase2_epoch_5.pt"
SKIP_PHASE1 = False      # ‚Üê Set True to skip Phase 1 entirely (if already done)
START_PHASE2_EPOCH = 1   # ‚Üê Set starting epoch for Phase 2 (if resuming mid-phase)

if RESUME_TRAINING and CHECKPOINT_PATH is not None:
    checkpoint_file = Path(CHECKPOINT_PATH)

    if checkpoint_file.exists():
        print(f"üìÇ Loading checkpoint from: {checkpoint_file}")
        # These checkpoints hold more than tensors (history lists, config dict,
        # trainer stats), which the weights-only unpickler -- torch.load's
        # default since PyTorch 2.6 -- rejects. weights_only=False runs full
        # pickle, so only point CHECKPOINT_PATH at files this notebook produced.
        checkpoint = torch.load(checkpoint_file, map_location=config.device, weights_only=False)

        # Load model weights
        if 'model_state_dict' in checkpoint:
            ebm.load_state_dict(checkpoint['model_state_dict'])
            print(f"   ‚úì EBM model weights loaded")

        # Conditioner / cost-proxy weights live on `trainer`, which the training
        # cell creates; on a fresh kernel it does not exist yet, so report and
        # skip instead of failing with NameError.
        _trainer = globals().get('trainer')
        for _attr, _key, _label in (
            ('conditioner', 'conditioner_state_dict', 'HConditioner'),
            ('cost_proxy', 'cost_proxy_state_dict', 'CostProxy'),
        ):
            if checkpoint.get(_key) is None:
                continue
            _module = getattr(_trainer, _attr, None)
            if _module is not None:
                _module.load_state_dict(checkpoint[_key])
                print(f"   ‚úì {_label} weights loaded")
            else:
                print(f"   ! {_label} weights not restored (trainer.{_attr} not available yet)")

        # Load history if available. Replace (rather than merge into) any
        # in-memory history so curves from an earlier run cannot mix with the
        # checkpoint's; defaultdict because the training loops append to new keys.
        if 'history' in checkpoint:
            history = defaultdict(list, {k: list(v) for k, v in checkpoint['history'].items()})
            print(f"   ‚úì Training history loaded ({len(history)} keys)")

        # Get checkpoint info
        ckpt_epoch = checkpoint.get('epoch', 0)
        ckpt_phase = checkpoint.get('phase', 'unknown')
        ckpt_loss = checkpoint.get('loss', float('inf'))

        print(f"\nüìä Checkpoint info:")
        print(f"   Phase: {ckpt_phase}")
        print(f"   Epoch: {ckpt_epoch}")
        print(f"   Loss: {ckpt_loss:.4f}")

        # Auto-configure resume settings
        if ckpt_phase == 'finetune':
            SKIP_PHASE1 = True
            START_PHASE2_EPOCH = ckpt_epoch + 1
            # min() of an empty list raises; use inf when no Phase 1 losses are known.
            _phase1_losses = globals().get('history', {}).get('phase1_loss')
            best_loss = min(_phase1_losses) if _phase1_losses else float('inf')
            print(f"\n‚öôÔ∏è Auto-configured:")
            print(f"   SKIP_PHASE1 = True (Phase 1 already done)")
            print(f"   START_PHASE2_EPOCH = {START_PHASE2_EPOCH}")
        elif ckpt_phase == 'pretrain':
            SKIP_PHASE1 = True  # Phase 1 was completed
            best_loss = ckpt_loss
            print(f"\n‚öôÔ∏è Auto-configured:")
            print(f"   SKIP_PHASE1 = True (will start Phase 2)")

        print(f"\n‚úÖ Checkpoint loaded successfully!")
    else:
        print(f"‚ö†Ô∏è Checkpoint file not found: {checkpoint_file}")
        print(f"   Training will start from scratch.")
        RESUME_TRAINING = False
else:
    if RESUME_TRAINING:
        print("‚ö†Ô∏è RESUME_TRAINING=True but CHECKPOINT_PATH is None")
        print("   Please set CHECKPOINT_PATH to resume training.")
        # Nothing was loaded, so later cells must not behave as if resuming.
        RESUME_TRAINING = False
    print("üÜï Starting fresh training (no checkpoint loaded)")

In [None]:
# ============================================================================
# TWO-PHASE TRAINING LOOP (Production-Grade v2)
# ============================================================================
# Phase 1: Contrastive (InfoNCE) without LP - stable energy geometry
# Phase 2: Ranked (WeightedMarginRankingLoss) with LP - cost-aware
# ============================================================================

import time
from collections import defaultdict
from src.preference.training_strategy import (
    TwoPhaseConfig,
    TwoPhaseTrainer,
    TrainingPhase,
    create_two_phase_trainer,
)
from src.preference.conditioning import (
    HConditioner,
    DecisionFeatureExtractor,
    FeatureBasedCostProxy,
)

# ============================================================================
# EARLY STOPPING HELPER
# ============================================================================
class EarlyStopping:
    """Latching stop signal for a loss that has stopped improving.

    A call to :meth:`step` counts as an improvement only if the new loss is
    lower than the best loss seen so far by more than ``min_delta``. After
    ``patience`` consecutive non-improving calls, ``should_stop`` becomes True
    and stays True for the rest of the object's life.
    """

    def __init__(self, patience: int = 3, min_delta: float = 5e-4):
        self.patience = patience
        self.min_delta = min_delta
        # Best (lowest) loss observed; +inf until the first step.
        self.best_loss = float('inf')
        # Consecutive steps without a sufficient improvement.
        self.counter = 0
        # Latched flag: set once, never cleared.
        self.should_stop = False

    def step(self, loss: float) -> bool:
        """Record ``loss`` and return whether training should stop."""
        improved = loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = loss
            self.counter = 0
            return self.should_stop

        self.counter += 1
        self.should_stop = self.should_stop or self.counter >= self.patience
        return self.should_stop

# ============================================================================
# PHASE 1 SCHEDULE: Adaptive K and steps
# ============================================================================
def get_phase1_schedule(epoch: int):
    """
    Map a 1-based Phase 1 epoch to ``(K, langevin_steps)``.

    epochs 1-3:  K=3, steps=10
    epochs 4-8:  K=6, steps=15
    epochs 9+:   K=8, steps=20   (Phase 1 runs at most 12 epochs)
    Epochs below 1 fall into the first stage.
    """
    # (last epoch of the stage, K, Langevin steps), checked in order.
    stages = ((3, 3, 10), (8, 6, 15))
    for last_epoch, n_candidates, n_steps in stages:
        if epoch <= last_epoch:
            return n_candidates, n_steps
    return 8, 20

# Create TwoPhaseConfig from our config
two_phase_config = TwoPhaseConfig(
    # Phase 1: Pretrain (contrastive) - will be overridden by schedule
    pretrain_epochs=12,  # max epochs (early stopping may end sooner)
    pretrain_lr=config.pretrain_lr,
    pretrain_langevin_steps=10,  # base, will be updated per epoch
    pretrain_num_candidates=3,   # base, will be updated per epoch

    # Phase 2: Finetune (ranked with LP)
    finetune_epochs=config.finetune_epochs,
    finetune_lr=config.finetune_lr,
    finetune_langevin_steps=config.finetune_langevin_steps,
    finetune_lp_ratio=config.finetune_lp_ratio,
    finetune_num_candidates_min=config.finetune_num_candidates_min,
    finetune_num_candidates_max=config.finetune_num_candidates_max,
    finetune_uncertainty_top=config.finetune_uncertainty_top,
    finetune_cost_proxy_top_k=config.finetune_cost_proxy_top_k,

    # General
    batch_size=config.batch_size,
    margin=config.margin,
    device=config.device,
)

# Create two-phase trainer with new architecture
trainer = create_two_phase_trainer(
    ebm=ebm,
    sampler=sampler,
    config=two_phase_config,
    lp_oracle=lp_oracle if config.use_lp_worker else None,
    embedding_dim=config.h_dim,
    use_cost_proxy=True,
    use_conditioner=True,
    use_decoder=True,  # FeasibilityDecoder for mutual exclusion
)

print("‚úì TwoPhaseTrainer initialized (v2)")
print(f"  HConditioner: {'Enabled' if trainer.conditioner else 'Disabled'}")
print(f"  FeatureBasedCostProxy: {'Enabled' if trainer.cost_proxy else 'Disabled'}")
print(f"  FeasibilityDecoder: {'Enabled' if trainer.decoder else 'Disabled'}")
print(f"  LP Oracle: {'Enabled' if config.use_lp_worker else 'Disabled'}")

# The trainer was just rebuilt, so its conditioner / cost proxy hold fresh
# weights; anything the checkpoint cell loaded onto a previous `trainer` is gone.
# Re-apply them from the loaded checkpoint when resuming.
if RESUME_TRAINING and CHECKPOINT_PATH is not None and 'checkpoint' in dir():
    for _attr, _key, _label in (
        ("conditioner", "conditioner_state_dict", "HConditioner"),
        ("cost_proxy", "cost_proxy_state_dict", "CostProxy"),
    ):
        _module = getattr(trainer, _attr, None)
        _state = checkpoint.get(_key)
        if _module is not None and _state is not None:
            _module.load_state_dict(_state)
            print(f"  Restored {_label} weights from checkpoint")

# Training history (may be pre-loaded from checkpoint).
# A fresh run (not resuming and not skipping Phase 1) starts from an empty
# history and best_loss; otherwise a re-run of this cell would append to the
# previous run's curves and compare against its best loss.
_fresh_run = not RESUME_TRAINING and not SKIP_PHASE1
if _fresh_run or 'history' not in dir() or not history:
    history = defaultdict(list)
elif not isinstance(history, defaultdict):
    # The training loops append to keys that may not exist yet.
    history = defaultdict(list, history)
if _fresh_run or 'best_loss' not in dir():
    best_loss = float("inf")

print("\n" + "=" * 70)
print("üöÄ TWO-PHASE PREFERENCE-BASED EBM TRAINING (v2)")
print("=" * 70)
print("   Phase 1: InfoNCE contrastive + hard negatives (early stopping)")
print("   Phase 2: WeightedMarginRankingLoss + FeasibilityDecoder (cost-aware)")
if RESUME_TRAINING:
    print(f"   üìÇ RESUMING FROM CHECKPOINT")
    print(f"      Skip Phase 1: {SKIP_PHASE1}")
    print(f"      Start Phase 2 epoch: {START_PHASE2_EPOCH}")
print("=" * 70)

start_time = time.time()

# ============================================================================
# PHASE 1: PRETRAIN - Contrastive with Early Stopping & Adaptive Schedule
# ============================================================================
# Trains the EBM (+ conditioner) for up to 12 epochs. K and the Langevin step
# count follow get_phase1_schedule. Training stops early after 3 epochs without
# a 5e-4 improvement in mean loss. best_model_phase1.pt is saved whenever the
# epoch loss beats `best_loss`.
if not SKIP_PHASE1:
    phase1_config = two_phase_config.get_phase_config(TrainingPhase.PRETRAIN)
    phase1_max_epochs = 12

    optimizer_phase1 = torch.optim.AdamW(
        list(ebm.parameters()) + (list(trainer.conditioner.parameters()) if trainer.conditioner else []),
        lr=phase1_config.learning_rate,
        weight_decay=config.weight_decay,
    )
    scheduler_phase1 = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_phase1, T_max=phase1_max_epochs, eta_min=phase1_config.learning_rate * 0.01
    )

    early_stopper = EarlyStopping(patience=3, min_delta=5e-4)

    print(f"\n{'='*70}")
    print(f"üì¶ PHASE 1: PRETRAIN (Contrastive + Hard Negatives)")
    print(f"{'='*70}")
    print(f"   Loss: InfoNCE + top-2 hard negatives")
    print(f"   Max epochs: {phase1_max_epochs} (early stopping: patience=3)")
    print(f"   LP Oracle: DISABLED")
    print(f"   FeasibilityDecoder: DISABLED (Phase 1)")
    print(f"   Schedule:")
    print(f"     epochs 1-3:  K=3, steps=10")
    print(f"     epochs 4-8:  K=6, steps=15")
    print(f"     epochs 9-12: K=8, steps=20")
    print(f"{'='*70}")

    ebm.train()
    if trainer.conditioner:
        trainer.conditioner.train()

    for epoch in range(1, phase1_max_epochs + 1):
        epoch_start = time.time()
        epoch_losses = []
        epoch_energy_pos = []
        epoch_energy_neg = []
        epoch_energy_gap = []

        # Get adaptive K and steps for this epoch (the config object is mutated
        # in place and passed to every train_step below).
        K_epoch, steps_epoch = get_phase1_schedule(epoch)
        phase1_config.num_candidates_min = K_epoch
        phase1_config.num_candidates_max = K_epoch
        phase1_config.langevin_steps = steps_epoch

        pbar = tqdm(train_loader, desc=f"[P1] Epoch {epoch}/{phase1_max_epochs} (K={K_epoch})", leave=False)
        for batch_idx, batch in enumerate(pbar):
            metrics = trainer.train_step(
                batch, optimizer_phase1, phase1_config, epoch, batch_idx
            )
            epoch_losses.append(metrics['loss'])
            epoch_energy_pos.append(metrics.get('energy_pos', 0))
            epoch_energy_neg.append(metrics.get('energy_neg', 0))
            epoch_energy_gap.append(metrics.get('energy_gap', 0))

            pbar.set_postfix({
                "loss": f"{metrics['loss']:.4f}",
                "E+": f"{metrics.get('energy_pos', 0):.2f}",
                "E-": f"{metrics.get('energy_neg', 0):.2f}",
                "gap": f"{metrics.get('energy_gap', 0):.2f}",
            })

        scheduler_phase1.step()

        avg_loss = np.mean(epoch_losses)
        avg_e_pos = np.mean(epoch_energy_pos)
        avg_e_neg = np.mean(epoch_energy_neg)
        avg_gap = np.mean(epoch_energy_gap)

        history["phase1_loss"].append(avg_loss)
        history["phase1_energy_pos"].append(avg_e_pos)
        history["phase1_energy_neg"].append(avg_e_neg)
        history["phase1_energy_gap"].append(avg_gap)

        epoch_time = time.time() - epoch_start

        # Early stopping check
        should_stop = early_stopper.step(avg_loss)
        stop_indicator = " ‚ö†Ô∏è EARLY STOP" if should_stop else ""

        print(f"[P1] Epoch {epoch}/{phase1_max_epochs} ({epoch_time:.1f}s) K={K_epoch} | "
              f"Loss: {avg_loss:.4f} | E+: {avg_e_pos:.3f} | E-: {avg_e_neg:.3f} | "
              f"Gap: {avg_gap:.3f} | LR: {scheduler_phase1.get_last_lr()[0]:.2e}{stop_indicator}")

        # Save best
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                "epoch": epoch, "phase": "pretrain",
                "model_state_dict": ebm.state_dict(),
                "conditioner_state_dict": trainer.conditioner.state_dict() if trainer.conditioner else None,
                "loss": best_loss,
                # The checkpoint cell restores 'history' when present, so a resume
                # from this file keeps the Phase 1 curves. Plain floats (not numpy
                # scalars) keep the file loadable by the weights-only unpickler.
                "history": {k: [float(v) for v in vals] for k, vals in history.items()},
            }, Path(config.output_dir) / "best_model_phase1.pt")

        if should_stop:
            print(f"   Early stopping triggered after {epoch} epochs (no improvement for {early_stopper.patience} epochs)")
            break

    phase1_time = time.time() - start_time
    print(f"\n‚úì Phase 1 complete: {phase1_time/60:.1f} min, Best loss: {best_loss:.4f}")
else:
    print(f"\n‚è≠Ô∏è SKIPPING PHASE 1 (already completed, loading from checkpoint)")
    phase1_time = 0

# ============================================================================
# PHASE 2: FINETUNE - Ranked with LP + FeasibilityDecoder (detailed logging)
# ============================================================================
# Trains EBM + conditioner + cost proxy with the Phase 2 config.
# - Validation (evaluate_with_oracle) runs every `config.eval_every` epochs.
# - best_model.pt is saved on a new best mean training loss.
# - checkpoint_phase2_epoch_{N}.pt is saved every `config.save_every` epochs.
# Histories are written as plain Python floats so the checkpoints also load
# under torch.load's weights-only mode.
phase2_config = two_phase_config.get_phase_config(TrainingPhase.FINETUNE)
optimizer_phase2 = torch.optim.AdamW(
    list(ebm.parameters()) +
    (list(trainer.conditioner.parameters()) if trainer.conditioner else []) +
    (list(trainer.cost_proxy.parameters()) if trainer.cost_proxy else []),
    lr=phase2_config.learning_rate,
    weight_decay=config.weight_decay,
)
scheduler_phase2 = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer_phase2, T_max=phase2_config.epochs, eta_min=phase2_config.learning_rate * 0.01
)

# If resuming, step scheduler to correct position
if RESUME_TRAINING and START_PHASE2_EPOCH > 1:
    for _ in range(START_PHASE2_EPOCH - 1):
        scheduler_phase2.step()
    print(f"   Scheduler stepped to epoch {START_PHASE2_EPOCH}")

print(f"\n{'='*70}")
print(f"üéØ PHASE 2: FINETUNE (Ranked with LP + FeasibilityDecoder)")
print(f"{'='*70}")
print(f"   Loss: WeightedMarginRankingLoss (cost-aware)")
print(f"   Epochs: {START_PHASE2_EPOCH} ‚Üí {phase2_config.epochs}")
print(f"   LP Oracle: {phase2_config.lp_batch_ratio:.0%} of batches")
print(f"   FeasibilityDecoder: ENABLED (mutual exclusion)")
print(f"   Uncertainty filter: top {phase2_config.uncertainty_threshold:.0%}")
print(f"   Candidates: K={phase2_config.num_candidates_min}‚Üí{phase2_config.num_candidates_max}")
print(f"   Cost proxy filter: top-{phase2_config.cost_proxy_top_k}")
print(f"   Langevin steps: {phase2_config.langevin_steps}")
print(f"{'='*70}")

# Phase 1 switches to train mode itself. When it is skipped, or an analysis cell
# has called ebm.eval() before a re-run, Phase 2 would otherwise train in eval mode.
ebm.train()
if trainer.conditioner:
    trainer.conditioner.train()
if trainer.cost_proxy:
    trainer.cost_proxy.train()

phase2_start = time.time()
best_loss_phase2 = float("inf")

# Load best_loss_phase2 from history if resuming
if RESUME_TRAINING and history.get('phase2_loss'):
    best_loss_phase2 = min(history['phase2_loss'])
    print(f"   Best Phase 2 loss from history: {best_loss_phase2:.4f}")

for epoch in range(START_PHASE2_EPOCH, phase2_config.epochs + 1):
    epoch_start = time.time()
    epoch_losses = []
    epoch_energy_pos = []
    epoch_energy_neg = []
    epoch_energy_gap = []
    epoch_lp_ratios = []
    epoch_loss_rank = []
    epoch_mean_weight = []
    epoch_cost_gap = []
    epoch_cost_gap_p90 = []

    pbar = tqdm(train_loader, desc=f"[P2] Epoch {epoch}/{phase2_config.epochs}", leave=False)
    for batch_idx, batch in enumerate(pbar):
        metrics = trainer.train_step(
            batch, optimizer_phase2, phase2_config, epoch, batch_idx
        )
        epoch_losses.append(metrics['loss'])
        epoch_energy_pos.append(metrics.get('energy_pos', 0))
        epoch_energy_neg.append(metrics.get('energy_neg', 0))
        epoch_energy_gap.append(metrics.get('energy_gap', 0))
        epoch_lp_ratios.append(metrics.get('lp_ratio', 0))
        epoch_loss_rank.append(metrics.get('loss_rank', 0))
        epoch_mean_weight.append(metrics.get('mean_weight', 0))
        epoch_cost_gap.append(metrics.get('mean_cost_gap', 0))
        epoch_cost_gap_p90.append(metrics.get('cost_gap_p90', 0))

        pbar.set_postfix({
            "loss": f"{metrics['loss']:.4f}",
            "L_rank": f"{metrics.get('loss_rank', 0):.3f}",
            "LP": f"{metrics.get('lp_ratio', 0):.0%}",
            "wÃÑ": f"{metrics.get('mean_weight', 0):.2f}",
        })

    scheduler_phase2.step()

    avg_loss = np.mean(epoch_losses)
    avg_e_pos = np.mean(epoch_energy_pos)
    avg_e_neg = np.mean(epoch_energy_neg)
    avg_gap = np.mean(epoch_energy_gap)
    avg_lp_ratio = np.mean(epoch_lp_ratios)
    avg_loss_rank = np.mean(epoch_loss_rank)
    avg_mean_weight = np.mean(epoch_mean_weight)
    avg_cost_gap = np.mean(epoch_cost_gap)
    avg_cost_gap_p90 = np.mean(epoch_cost_gap_p90)

    history["phase2_loss"].append(avg_loss)
    history["phase2_energy_pos"].append(avg_e_pos)
    history["phase2_energy_neg"].append(avg_e_neg)
    history["phase2_energy_gap"].append(avg_gap)
    history["phase2_lp_ratio"].append(avg_lp_ratio)
    history["phase2_loss_rank"].append(avg_loss_rank)
    history["phase2_mean_weight"].append(avg_mean_weight)
    history["phase2_cost_gap"].append(avg_cost_gap)
    history["phase2_cost_gap_p90"].append(avg_cost_gap_p90)

    epoch_time = time.time() - epoch_start

    if epoch % config.log_every == 0:
        print(f"[P2] Epoch {epoch}/{phase2_config.epochs} ({epoch_time:.1f}s)")
        print(f"     Loss: {avg_loss:.4f} | L_rank: {avg_loss_rank:.4f}")
        print(f"     E+: {avg_e_pos:.3f} | E-: {avg_e_neg:.3f} | Gap: {avg_gap:.3f}")
        print(f"     LP%: {avg_lp_ratio:.1%} | wÃÑ: {avg_mean_weight:.2f}")
        print(f"     Cost gap: {avg_cost_gap:.1f} | p90: {avg_cost_gap_p90:.1f}")

    # Evaluation
    if epoch % config.eval_every == 0 and len(eval_loader) > 0:
        val_metrics = evaluate_with_oracle(eval_loader, ebm, sampler, loss_fn, config, lp_oracle)
        history["val_loss"].append(val_metrics["val_loss"])
        print(f"     Val Loss: {val_metrics['val_loss']:.4f}")

    # Save best (the checkpoint cell restores 'history' when present, so a resume
    # from best_model.pt keeps the curves too)
    if avg_loss < best_loss_phase2:
        best_loss_phase2 = avg_loss
        torch.save({
            "epoch": epoch, "phase": "finetune",
            "model_state_dict": ebm.state_dict(),
            "conditioner_state_dict": trainer.conditioner.state_dict() if trainer.conditioner else None,
            "cost_proxy_state_dict": trainer.cost_proxy.state_dict() if trainer.cost_proxy else None,
            "optimizer_state_dict": optimizer_phase2.state_dict(),
            "loss": best_loss_phase2,
            "config": config.__dict__,
            "history": {k: [float(v) for v in vals] for k, vals in history.items()},
        }, Path(config.output_dir) / "best_model.pt")

    # Checkpoint every save_every epochs
    if epoch % config.save_every == 0:
        torch.save({
            "epoch": epoch, "phase": "finetune",
            "model_state_dict": ebm.state_dict(),
            "conditioner_state_dict": trainer.conditioner.state_dict() if trainer.conditioner else None,
            "cost_proxy_state_dict": trainer.cost_proxy.state_dict() if trainer.cost_proxy else None,
            "history": {k: [float(v) for v in vals] for k, vals in history.items()},
            "trainer_stats": trainer.get_statistics(),
        }, Path(config.output_dir) / f"checkpoint_phase2_epoch_{epoch}.pt")
        print(f"     üíæ Checkpoint saved: checkpoint_phase2_epoch_{epoch}.pt")

phase2_time = time.time() - phase2_start
total_time = time.time() - start_time

# Final statistics
stats = trainer.get_statistics()

print("\n" + "=" * 70)
print("‚úÖ TWO-PHASE TRAINING COMPLETE (v2)")
print("=" * 70)
print(f"üìä Total time: {total_time/60:.1f} minutes")
print(f"   Phase 1 (contrastive): {phase1_time/60:.1f} min ({len(history.get('phase1_loss', []))} epochs)")
print(f"   Phase 2 (ranked): {phase2_time/60:.1f} min")
print(f"\nüìà Best losses:")
print(f"   Phase 1 (InfoNCE): {best_loss:.4f}")
print(f"   Phase 2 (Ranked): {best_loss_phase2:.4f}")
print(f"\n‚ö° Final energies:")
if history.get('phase2_energy_pos'):
    print(f"   E+ (positive): {history['phase2_energy_pos'][-1]:.3f}")
    print(f"   E- (negative): {history['phase2_energy_neg'][-1]:.3f}")
    print(f"   Gap (E- - E+): {history['phase2_energy_gap'][-1]:.3f}")
print(f"\nüî¨ LP Oracle statistics:")
print(f"   Total LP calls: {stats['lp_calls']:,}")
print(f"   LP skipped: {stats['lp_skipped']:,}")
print(f"   Effective LP ratio: {stats['lp_ratio']:.1%}")
print(f"   Proxy filtered: {stats['proxy_filtered']:,}")
print(f"\nüíæ Model saved to: {config.output_dir}")
print("=" * 70)

PREFERENCE-BASED EBM TRAINING
Device: cuda
Epochs: 50
Architecture: gru
Candidates per scenario: 3
LP Oracle: Enabled


Epoch 1/50:   0%|          | 0/1184 [00:00<?, ?it/s]

✓ LPWorkerTwoStage initialized
  Solver: appsi_highs
  Slack tolerance: 1.0 MWh (dt=1.0h)
  Deviation penalty (λ): 10000.0
  Flip budgets: K=20→100, K=100→1000, full_soft→None
  Time limits: TL1=0.3s, TL2=0.2s, TL3=0.2s, TL4=0.5s
✓ LPOracle initialized with LPWorkerTwoStage


KeyboardInterrupt: 

In [None]:
# ============================================================================
# 6. EVALUATION & VISUALIZATION
# ============================================================================
# Plots the per-epoch curves stored in `history`. The two-phase loop records
# "phase1_*" / "phase2_*" keys (validation only during Phase 2); the older
# single-phase loop recorded "train_*" keys. Both layouts are handled: the
# phases share one 1-based epoch axis, with a dotted line at the Phase 1 -> 2
# boundary.

import matplotlib.pyplot as plt
import seaborn as sns

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")


def _history_series(*keys):
    """Concatenate history[key] for each key in order; missing keys add nothing."""
    series = []
    for key in keys:
        series.extend(history.get(key, []))
    return series


def _epochs(series, offset=0):
    """1-based epoch numbers for `series`, shifted right by `offset` epochs."""
    return list(range(offset + 1, offset + len(series) + 1))


def _cosine_lrs(base_lr, t_max, n_epochs):
    """Per-epoch LR of CosineAnnealingLR(T_max=t_max, eta_min=0.01 * base_lr).

    Epoch i (0-based) trains with the LR reached after i scheduler steps, matching
    the training cells, which call scheduler.step() once at the end of each epoch.
    """
    return [base_lr * (0.01 + 0.99 * (1 + np.cos(np.pi * i / t_max)) / 2)
            for i in range(n_epochs)]


def _mark_phase_boundary(ax):
    """Draw a dotted vertical line between the last Phase 1 and first Phase 2 epoch."""
    if n_phase1 and n_phase2:
        ax.axvline(x=n_phase1 + 0.5, color='gray', linestyle=':', alpha=0.7)


n_phase1 = len(history.get("phase1_loss", []))
n_phase2 = len(history.get("phase2_loss", []))
two_phase = (n_phase1 + n_phase2) > 0

if two_phase:
    train_loss = _history_series("phase1_loss", "phase2_loss")
    energy_gap = _history_series("phase1_energy_gap", "phase2_energy_gap")
    energy_pos = _history_series("phase1_energy_pos", "phase2_energy_pos")
    energy_neg = _history_series("phase1_energy_neg", "phase2_energy_neg")
else:
    train_loss = _history_series("train_loss")
    energy_gap = _history_series("train_energy_gap")
    energy_pos = _history_series("train_mean_energy_pos")
    energy_neg = _history_series("train_mean_energy_neg")

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Training Loss
ax1 = axes[0, 0]
ax1.plot(_epochs(train_loss), train_loss, label="Train Loss", linewidth=2)
val_loss = list(history.get("val_loss", []))
if val_loss:
    # Validation runs on epochs divisible by eval_every; in the two-phase loop
    # only during Phase 2, whose epochs follow Phase 1 on this axis.
    eval_offset = n_phase1 if two_phase else 0
    eval_span = n_phase2 if two_phase else len(train_loss)
    eval_epochs = [eval_offset + e for e in range(config.eval_every, eval_span + 1, config.eval_every)]
    # If the lengths disagree (e.g. after a resume), align the most recent points.
    n_val = min(len(eval_epochs), len(val_loss))
    if n_val:
        ax1.plot(eval_epochs[-n_val:], val_loss[-n_val:], label="Val Loss", linewidth=2, marker='o')
_mark_phase_boundary(ax1)
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Training & Validation Loss")
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Energy Gap (E- - E+)
ax2 = axes[0, 1]
ax2.plot(_epochs(energy_gap), energy_gap, label="Energy Gap (E- - E+)", linewidth=2, color="green")
ax2.axhline(y=0, color='red', linestyle='--', alpha=0.5, label="Zero gap")
_mark_phase_boundary(ax2)
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Energy Gap")
ax2.set_title("Energy Gap Evolution (should be positive)")
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Mean Energies
ax3 = axes[1, 0]
ax3.plot(_epochs(energy_pos), energy_pos, label="E(u+ | h) Positive", linewidth=2)
ax3.plot(_epochs(energy_neg), energy_neg, label="E(u- | h) Negative", linewidth=2)
_mark_phase_boundary(ax3)
ax3.set_xlabel("Epoch")
ax3.set_ylabel("Mean Energy")
ax3.set_title("Mean Energy by Decision Type")
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Learning Rate (reconstructed from the cosine schedules used in training)
ax4 = axes[1, 1]
if two_phase:
    # Prefer the phase configs built by the training cell; fall back to the
    # values they were built from. Phase 1's cosine horizon is phase1_max_epochs (12).
    _p1 = globals().get("phase1_config")
    _p2 = globals().get("phase2_config")
    p1_lr = _p1.learning_rate if _p1 is not None else config.pretrain_lr
    p2_lr = _p2.learning_rate if _p2 is not None else config.finetune_lr
    p2_tmax = _p2.epochs if _p2 is not None else config.finetune_epochs
    lrs = (_cosine_lrs(p1_lr, globals().get("phase1_max_epochs", 12), n_phase1)
           + _cosine_lrs(p2_lr, p2_tmax, n_phase2))
elif train_loss:
    lrs = _cosine_lrs(config.learning_rate, config.epochs, len(train_loss))
else:
    lrs = []
ax4.plot(_epochs(lrs), lrs, label="Learning Rate", linewidth=2, color="purple")
_mark_phase_boundary(ax4)
ax4.set_xlabel("Epoch")
ax4.set_ylabel("Learning Rate")
ax4.set_title("Learning Rate Schedule")
ax4.set_yscale('log')
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(Path(config.output_dir) / "training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\n‚úì Training curves saved to: {config.output_dir}/training_curves.png")

In [None]:
# ============================================================================
# 7. MODEL ANALYSIS & INFERENCE
# ============================================================================

@torch.no_grad()
def sample_and_rank(ebm, sampler, h, n_candidates, n_zones, n_timesteps, n_features=8):
    """
    Sample candidates and rank by energy.
    Returns candidates sorted by energy (lowest first = best).

    Args:
        ebm: energy model called as ``ebm(u_flat, h)``. It is switched to eval
            mode for the call and restored to its previous train/eval mode
            afterwards, even if sampling or scoring raises.
        sampler: object whose ``sample(h=..., n_samples=..., n_zones=...,
            n_timesteps=..., n_features=...)`` returns ``(candidates, _)`` with
            candidates shaped ``[B, n_candidates, ...]``.
        h: context tensor ``[B, D]`` (one vector per scenario).
        n_candidates, n_zones, n_timesteps, n_features: forwarded to the sampler.

    Returns:
        (sorted_candidates, sorted_energies): the candidates reordered along
        dim 1 by ascending energy, and the matching ``[B, n_candidates]`` energies.
    """
    was_training = ebm.training
    ebm.eval()
    try:
        B = h.shape[0]
        K = n_candidates

        # Sample candidates
        u_candidates, _ = sampler.sample(
            h=h,
            n_samples=K,
            n_zones=n_zones,
            n_timesteps=n_timesteps,
            n_features=n_features,
        )  # [B, K, Z, T, F]

        # Score every candidate against its own scenario's context vector.
        u_flat = u_candidates.reshape(B * K, -1)
        h_exp = h.unsqueeze(1).expand(-1, K, -1).reshape(B * K, -1)
        energies = ebm(u_flat, h_exp).view(B, K)  # [B, K]

        # Sort by energy (ascending) and apply the same permutation to the
        # candidates; advanced indexing works for any trailing decision shape.
        sorted_energies, sorted_indices = energies.sort(dim=1)
        batch_index = torch.arange(B, device=sorted_indices.device).unsqueeze(1)
        sorted_candidates = u_candidates[batch_index, sorted_indices]
    finally:
        ebm.train(was_training)

    return sorted_candidates, sorted_energies

# Test inference on a sample: rank 10 sampled candidates for the first eval
# scenario and print the five lowest-energy ones.
if len(eval_dataset) > 0:
    print("=" * 60)
    print("INFERENCE TEST")
    print("=" * 60)

    # Get a sample
    sample = eval_dataset[0]
    h_test = sample["embedding"].unsqueeze(0).to(config.device)
    # Training and validation drive the sampler/EBM with one pooled context
    # vector per scenario ([B, D]; the masked mean in evaluate_with_oracle).
    # If this item holds per-zone / per-timestep embeddings, pool them to [1, D].
    # NOTE(review): assumes a single dataset item is unpadded, so a plain mean
    # equals the masked mean -- confirm against the dataset.
    if h_test.dim() > 2:
        h_test = h_test.flatten(start_dim=1, end_dim=-2).mean(dim=1)

    # Sample and rank candidates; use the same feature count as training rather
    # than sample_and_rank's default of 8.
    candidates, energies = sample_and_rank(
        ebm, sampler, h_test,
        n_candidates=10,
        n_zones=config.n_zones,
        n_timesteps=config.n_timesteps,
        n_features=config.n_features,
    )

    print(f"Scenario: {sample['scenario_id']}")
    print(f"MILP Objective: {sample['milp_objective']:.0f}")
    print(f"\nGenerated {candidates.shape[1]} candidates, ranked by energy:")
    for k in range(min(5, candidates.shape[1])):
        u_k = candidates[0, k]  # [Z, T, F]
        # Feature 6 = thermal, 5 = nuclear. These are sums of the raw sampler
        # output (training binarizes with binarize_decisions before scoring).
        thermal_on = u_k[..., 6].sum().item()
        nuclear_on = u_k[..., 5].sum().item()
        print(f"  Rank {k+1}: E={energies[0, k].item():.3f}, "
              f"Thermal={thermal_on:.0f}, Nuclear={nuclear_on:.0f}")

    print("=" * 60)

In [None]:
# ============================================================================
# 8. SAVE RESULTS & NEXT STEPS
# ============================================================================
# Writes the training history (JSON) and a final checkpoint, then prints a
# summary. Handles both history layouts: "phase1_*"/"phase2_*" keys from the
# two-phase loop, or "train_*" keys from the older single-phase loop.

import json


def _first_global(*names):
    """Return the first notebook global in `names` that is bound and not None."""
    for name in names:
        value = globals().get(name)
        if value is not None:
            return value
    return None


# Plain Python floats: JSON-serialisable and loadable by torch.load's
# weights-only mode (numpy scalars are not).
history_floats = {k: [float(v) for v in vals] for k, vals in history.items()}

# Save training history
history_path = Path(config.output_dir) / "training_history.json"
with open(history_path, 'w') as f:
    json.dump(history_floats, f, indent=2)
print(f"‚úì Training history saved to: {history_path}")

n_phase1 = len(history.get("phase1_loss", []))
n_phase2 = len(history.get("phase2_loss", []))
final_losses = (history.get("phase2_loss") or history.get("phase1_loss")
                or history.get("train_loss") or [])

# Use the optimizer/scheduler that ran last: Phase 2, else Phase 1, else the
# legacy single-phase names (the two-phase cell never binds `optimizer`/`scheduler`).
final_optimizer = _first_global("optimizer_phase2", "optimizer_phase1", "optimizer")
final_scheduler = _first_global("scheduler_phase2", "scheduler_phase1", "scheduler")
final_trainer = _first_global("trainer")

# Save final model
final_path = Path(config.output_dir) / "final_model.pt"
final_checkpoint = {
    "epoch": n_phase2 or n_phase1 or config.epochs,
    "model_state_dict": ebm.state_dict(),
    "optimizer_state_dict": final_optimizer.state_dict() if final_optimizer is not None else None,
    "scheduler_state_dict": final_scheduler.state_dict() if final_scheduler is not None else None,
    "config": config.__dict__,
    "history": history_floats,
}
# Tag the phase so the checkpoint cell can auto-configure a resume from this file.
if n_phase2:
    final_checkpoint["phase"] = "finetune"
elif n_phase1:
    final_checkpoint["phase"] = "pretrain"
if final_trainer is not None:
    # Same keys as best_model.pt, so the final file also carries the
    # conditioner / cost-proxy weights trained alongside the EBM.
    final_checkpoint["conditioner_state_dict"] = (
        final_trainer.conditioner.state_dict() if final_trainer.conditioner else None)
    final_checkpoint["cost_proxy_state_dict"] = (
        final_trainer.cost_proxy.state_dict() if final_trainer.cost_proxy else None)
torch.save(final_checkpoint, final_path)
print(f"‚úì Final model saved to: {final_path}")

# Summary
print("\n" + "=" * 60)
print("TRAINING SUMMARY")
print("=" * 60)
print(f"Architecture: {config.architecture}")
if n_phase1 or n_phase2:
    print(f"Total epochs: {n_phase1 + n_phase2} (Phase 1: {n_phase1}, Phase 2: {n_phase2})")
else:
    print(f"Total epochs: {config.epochs}")
print(f"Best loss: {best_loss:.4f}")
best_loss_p2 = _first_global("best_loss_phase2")
if best_loss_p2 is not None:
    print(f"Best Phase 2 loss: {best_loss_p2:.4f}")
if final_losses:
    print(f"Final train loss: {final_losses[-1]:.4f}")
if "val_loss" in history and len(history["val_loss"]) > 0:
    print(f"Final val loss: {history['val_loss'][-1]:.4f}")
print(f"\nOutput directory: {config.output_dir}")
print("=" * 60)

print("\nüìã NEXT STEPS:")
print("1. Enable LP Worker (set config.use_lp_worker = True) for real cost evaluation")
print("2. Increase training scenarios and epochs")
print("3. Load HTE embeddings (set config.embeddings_path)")
print("4. Run inference on new scenarios and evaluate vs MILP")

In [None]:
# ============================================================================
# 9. VERIFY EBM LEARNED CORRECTLY: MILP should have LOWER energy
# ============================================================================
# For every eval scenario, compare the energy the trained EBM assigns to the
# MILP decision (u+) with the mean energy of K sampled, binarized candidates
# (u-). A positive gap E(u-) - E(u+) means the EBM prefers the MILP decision.

print("=" * 70)
print("VERIFICATION: Energy of MILP vs Generated Candidates")
print("=" * 70)

ebm.eval()

# Per-scenario results collected over the eval set
milp_energies = []       # E(u+) per scenario
candidate_energies = []  # mean over the K candidates of E(u-) per scenario
energy_gaps = []         # mean E(u-) - E(u+) per scenario

with torch.no_grad():
    for batch in tqdm(eval_loader, desc="Evaluating"):
        B = len(batch["scenario_ids"])

        # Get embeddings and MILP decisions
        h = batch["embeddings"].to(config.device)
        u_positive = batch["milp_decisions"].to(config.device)

        n_zones = batch["n_zones"]
        n_timesteps = batch["n_timesteps"]

        # The EBM is fed one context vector per scenario: mean-pool the extra
        # axes of 3-D (dim 1) or 4-D (dims 1 and 2) embeddings.
        if h.dim() == 3:
            h_flat = h.mean(dim=1)
        elif h.dim() == 4:
            h_flat = h.mean(dim=(1, 2))
        else:
            h_flat = h

        # Energy of the MILP decisions. Reshape to [B] explicitly: if the EBM
        # returned [B, 1], the gap subtraction below would silently broadcast
        # against the [B] candidate mean and produce a [B, B] tensor.
        u_pos_flat = u_positive.reshape(B, -1)
        e_milp = ebm(u_pos_flat, h_flat).reshape(B)  # [B]
        milp_energies.extend(e_milp.cpu().tolist())

        # Sample K candidates. The notebook header describes the sampler as
        # Langevin sampling, which presumably differentiates the energy w.r.t.
        # u; autograd cannot do that inside torch.no_grad(). Re-enable grad
        # just for this call (a no-op if the sampler already manages grad mode
        # itself) and detach the result so no graph is kept alive.
        K = config.num_candidates
        with torch.enable_grad():
            u_candidates, _ = sampler.sample(
                h=h,
                n_samples=K,
                n_zones=n_zones,
                n_timesteps=n_timesteps,
                n_features=config.n_features,
            )
        u_candidates = u_candidates.detach()
        u_binary = binarize_decisions(u_candidates)

        # Energy of every candidate, each scored against its scenario context.
        u_neg_flat = u_binary.reshape(B * K, -1)
        h_exp = h_flat.unsqueeze(1).expand(-1, K, -1).reshape(B * K, -1)
        e_candidates = ebm(u_neg_flat, h_exp).reshape(B, K)  # [B, K]

        # Mean candidate energy per scenario
        e_cand_mean = e_candidates.mean(dim=1)  # [B]
        candidate_energies.extend(e_cand_mean.cpu().tolist())

        # Energy gap per scenario (should be positive if EBM learned correctly)
        gaps = e_cand_mean - e_milp
        energy_gaps.extend(gaps.cpu().tolist())

# Compute statistics
milp_energies = np.array(milp_energies)
candidate_energies = np.array(candidate_energies)
energy_gaps = np.array(energy_gaps)

if energy_gaps.size == 0:
    # With no scenarios every statistic would be NaN and the min()/max() calls
    # used for the reference line below would raise ValueError.
    print("\n⚠️ WARNING: No evaluation scenarios found - skipping energy verification")
else:
    print("\n📊 Energy Statistics:")
    print(f"   MILP solutions (u+):      mean = {milp_energies.mean():.3f}, std = {milp_energies.std():.3f}")
    print(f"   Generated candidates (u-): mean = {candidate_energies.mean():.3f}, std = {candidate_energies.std():.3f}")
    print(f"   Energy gap (E- - E+):     mean = {energy_gaps.mean():.3f}, std = {energy_gaps.std():.3f}")

    # Share of scenarios where the MILP decision beats the mean candidate
    pct_correct = (energy_gaps > 0).mean() * 100
    print(f"\n✓ MILP has lower energy in {pct_correct:.1f}% of scenarios")

    if energy_gaps.mean() > 0:
        print("✅ EBM TRAINED CORRECTLY: MILP solutions have lower energy on average")
    else:
        print("⚠️ WARNING: EBM may need more training - candidates have lower energy than MILP")

    # Plot energy distribution
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # 1. Energy histogram
    ax1 = axes[0]
    ax1.hist(milp_energies, bins=30, alpha=0.7, label="MILP (u+)", color="green")
    ax1.hist(candidate_energies, bins=30, alpha=0.7, label="Candidates (u-)", color="red")
    ax1.set_xlabel("Energy")
    ax1.set_ylabel("Count")
    ax1.set_title("Energy Distribution")
    ax1.legend()
    ax1.axvline(milp_energies.mean(), color="green", linestyle="--", linewidth=2)
    ax1.axvline(candidate_energies.mean(), color="red", linestyle="--", linewidth=2)

    # 2. Energy gap histogram
    ax2 = axes[1]
    ax2.hist(energy_gaps, bins=30, alpha=0.7, color="blue")
    ax2.axvline(0, color="red", linestyle="--", linewidth=2, label="Zero (boundary)")
    ax2.axvline(energy_gaps.mean(), color="green", linestyle="--", linewidth=2, label=f"Mean: {energy_gaps.mean():.2f}")
    ax2.set_xlabel("Energy Gap (E- - E+)")
    ax2.set_ylabel("Count")
    ax2.set_title(f"Energy Gap Distribution\n{pct_correct:.0f}% scenarios with positive gap")
    ax2.legend()

    # 3. Scatter: MILP energy vs Candidate energy. The y=x reference line spans
    # the range of BOTH series so it is visible wherever the points fall.
    ax3 = axes[2]
    ax3.scatter(milp_energies, candidate_energies, alpha=0.5, s=20)
    ref_lo = min(milp_energies.min(), candidate_energies.min())
    ref_hi = max(milp_energies.max(), candidate_energies.max())
    ax3.plot([ref_lo, ref_hi],
             [ref_lo, ref_hi],
             'r--', linewidth=2, label="y=x (equal energy)")
    ax3.set_xlabel("E(MILP)")
    ax3.set_ylabel("E(Candidates)")
    ax3.set_title("MILP vs Candidate Energy\nPoints above line = EBM correct")
    ax3.legend()

    plt.tight_layout()
    # Ensure the destination exists before writing the figure.
    Path(config.output_dir).mkdir(parents=True, exist_ok=True)
    plt.savefig(Path(config.output_dir) / "energy_verification.png", dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n✓ Verification plot saved to: {config.output_dir}/energy_verification.png")

print("=" * 70)