## 11. Plot Training Curves (Graph EBM)

In [None]:
import matplotlib.pyplot as plt

# Four-panel overview of the run: loss, energy gap, and the two raw energies.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Top row: metrics recorded for training and (optionally) validation.
# Each entry: (axis, train key, val key, title, y-label, train style, val style).
paired_panels = [
    (axes[0, 0], 'train_loss', 'val_loss', 'Training Loss', 'Loss', {}, {}),
    (axes[0, 1], 'train_gap', 'val_gap', 'Energy Gap (E_pos - E_neg)', 'Gap',
     {'color': 'green'}, {'color': 'darkgreen'}),
]
for ax, train_key, val_key, title, y_label, train_style, val_style in paired_panels:
    ax.plot(history[train_key], label='Train', linewidth=2, **train_style)

    # The validation curve is drawn only when at least one value was recorded.
    val_series = history.get(val_key, [])
    if len(val_series) > 0:
        # The k-th validation value is placed at x = k * CONFIG['validate_every'].
        val_epochs = [idx * CONFIG['validate_every'] for idx in range(len(val_series))]
        ax.plot(val_epochs, val_series, label='Val', marker='o', linewidth=2, **val_style)

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Bottom row: training-only energies of positive and negative samples.
# Each entry: (axis, history key, line colour, title, y-label).
energy_panels = [
    (axes[1, 0], 'train_e_pos', 'blue', 'Positive Sample Energy', 'E(u_pos | graph)'),
    (axes[1, 1], 'train_e_neg', 'red', 'Negative Sample Energy', 'E(u_neg | graph)'),
]
for ax, energy_key, line_colour, title, y_label in energy_panels:
    ax.plot(history[energy_key], linewidth=2, color=line_colour)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.grid(True, alpha=0.3)

plt.tight_layout()

# Write the figure to the output directory before displaying it.
curves_path = os.path.join(CONFIG['output_dir'], 'graph_ebm_training_curves.png')
plt.savefig(curves_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✓ Training curves saved to {curves_path}")

## 12. Save Training History (Graph EBM)

In [None]:
import json

# Persist the training history and the run configuration as JSON files in
# CONFIG['output_dir'], then print a short summary of the run.

# Save history as JSON
history_path = os.path.join(CONFIG['output_dir'], 'graph_ebm_training_history.json')

with open(history_path, 'w') as f:
    json.dump(history, f, indent=2)

print(f"✓ Training history saved to {history_path}")

# Save config
config_path = os.path.join(CONFIG['output_dir'], 'graph_ebm_config.json')

# Top-level values of JSON-native types are written as-is; anything else is
# stored as its str() form.
json_native_types = (int, float, str, bool, list, dict, type(None))
config_save = {
    k: v if isinstance(v, json_native_types) else str(v)
    for k, v in CONFIG.items()
}
with open(config_path, 'w') as f:
    # default=str extends the same stringification to non-serializable values
    # nested inside lists/dicts, which the top-level check above cannot see.
    json.dump(config_save, f, indent=2, default=str)

print(f"✓ Config saved to {config_path}")

# max() raises ValueError on an empty sequence, which happens when 'val_gap'
# exists but no validation pass was recorded. default=0 keeps the same
# fallback value that was already used when the key is missing.
best_val_gap = max(history.get('val_gap', []), default=0)

# Print summary statistics
print("\n" + "=" * 60)
print("TRAINING STATISTICS")
print("=" * 60)
print(f"Total epochs: {len(history['train_loss'])}")
print(f"Best validation gap: {best_val_gap:.4f}")
print(f"Final training loss: {history['train_loss'][-1]:.4f}")
print(f"Final training gap: {history['train_gap'][-1]:.4f}")
print("=" * 60)

## 13. Evaluate Best Model (Graph EBM)

In [None]:
# Evaluate the Graph EBM on the validation loader: restore the best checkpoint
# if one exists, compare energies of ground-truth and sampled configurations
# batch by batch, then print and save aggregate statistics.
print("Evaluating best Graph EBM model...\n")

# Load best checkpoint
best_path = os.path.join(CONFIG['output_dir'], 'graph_ebm_best.pt')
if os.path.exists(best_path):
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint written on a GPU still loads on a CPU-only machine.
    best_state = torch.load(best_path, map_location=CONFIG['device'])
    graph_model.load_state_dict(best_state)
    print(f"✓ Loaded best model from {best_path}")
else:
    print(f"⚠️  Best model not found, using current model")

# Detailed validation
graph_model.eval()

# One entry per validation batch (batch-level means, not per-graph values).
all_energy_pos = []
all_energy_neg = []
all_gaps = []
all_graph_sizes = []

print("Running evaluation on validation set...")
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        batch = batch.to(CONFIG['device'])

        # Positive energy (ground truth)
        E_pos = graph_model(batch)

        # Sample negatives: evaluate the same batch with its `x` replaced by
        # the sampler's output. The clone leaves the original batch untouched.
        u_neg = graph_sampler.sample(batch)
        batch_neg = batch.clone()
        batch_neg.x = u_neg
        E_neg = graph_model(batch_neg)

        # Compute metrics
        energy_gap = (E_pos.mean() - E_neg.mean()).item()

        # Store results
        all_energy_pos.append(E_pos.mean().item())
        all_energy_neg.append(E_neg.mean().item())
        all_gaps.append(energy_gap)

        # Track graph sizes for analysis
        # NOTE(review): assumes batch.h has one row per graph and batch.x one
        # row per node — confirm against the dataset/collate code.
        batch_size = batch.h.shape[0]
        avg_nodes = batch.x.shape[0] / batch_size
        all_graph_sizes.append(avg_nodes)

# Summary statistics (mean ± std taken across batches)
print("\n" + "=" * 60)
print("EVALUATION RESULTS (GRAPH EBM)")
print("=" * 60)
print(f"Energy (positive):    {np.mean(all_energy_pos):.4f} ± {np.std(all_energy_pos):.4f}")
print(f"Energy (negative):    {np.mean(all_energy_neg):.4f} ± {np.std(all_energy_neg):.4f}")
print(f"Energy gap:           {np.mean(all_gaps):.4f} ± {np.std(all_gaps):.4f}")
print(f"\nGraph Statistics:")
print(f"  Avg nodes per graph: {np.mean(all_graph_sizes):.1f}")
print(f"  Total batches:       {len(all_gaps)}")
print("=" * 60)

# Save evaluation results
# float() converts NumPy scalars to plain Python floats for json.dump.
eval_results = {
    'energy_pos_mean': float(np.mean(all_energy_pos)),
    'energy_pos_std': float(np.std(all_energy_pos)),
    'energy_neg_mean': float(np.mean(all_energy_neg)),
    'energy_neg_std': float(np.std(all_energy_neg)),
    'energy_gap_mean': float(np.mean(all_gaps)),
    'energy_gap_std': float(np.std(all_gaps)),
    'avg_graph_size': float(np.mean(all_graph_sizes)),
    'num_batches': len(all_gaps),
}

eval_path = os.path.join(CONFIG['output_dir'], 'graph_ebm_evaluation_results.json')
with open(eval_path, 'w') as f:
    json.dump(eval_results, f, indent=2)

print(f"\n✓ Evaluation results saved to {eval_path}")

## 14. Sample and Analyze Configurations (Graph EBM)

In [None]:
print("Sampling configurations from Graph EBM...\n")

# Take the first validation batch as the fixed input for repeated sampling.
test_batch = next(iter(val_loader)).to(CONFIG['device'])
n_graphs = test_batch.h.shape[0]
n_nodes = test_batch.x.shape[0]
print(f"Test batch size: {n_graphs} graphs")
print(f"Total nodes: {n_nodes}")
print(f"Nodes per graph (avg): {n_nodes / n_graphs:.1f}")

# Draw several configurations for that one batch.
num_samples = 10
print(f"\nGenerating {num_samples} samples...")

sampled_energies = []
sampled_configs = []

with torch.no_grad():
    # Reference point: energy of the batch as it comes from the loader.
    E_true = graph_model(test_batch)
    print(f"Ground truth energy: {E_true.mean().item():.4f}")

    # Each draw replaces `x` on a fresh clone and is scored by the model.
    for _ in range(num_samples):
        u_sample = graph_sampler.sample(test_batch)
        batch_sample = test_batch.clone()
        batch_sample.x = u_sample
        E_sample = graph_model(batch_sample)

        sampled_energies.append(E_sample.mean().item())
        sampled_configs.append(u_sample.cpu().numpy())

# Energy statistics across the draws.
lowest_energy = min(sampled_energies)
highest_energy = max(sampled_energies)
print(f"\n{num_samples} samples generated:")
print(f"Energy range: [{lowest_energy:.4f}, {highest_energy:.4f}]")
print(f"Energy mean: {np.mean(sampled_energies):.4f}")
print(f"Energy std: {np.std(sampled_energies):.4f}")

# Diversity: fraction of differing entries, averaged over all unordered pairs.
if num_samples > 1:
    hamming_distances = [
        np.mean(sampled_configs[first] != sampled_configs[second])
        for first in range(num_samples)
        for second in range(first + 1, num_samples)
    ]

    avg_diversity = np.mean(hamming_distances)
    print(f"\nSample diversity (avg Hamming): {avg_diversity:.4f}")

# Visualize the first draw as a flat vector.
sample_u = sampled_configs[0].flatten()
sample_fig, (config_ax, hist_ax) = plt.subplots(1, 2, figsize=(15, 4))

# Left panel: raw values of at most the first 500 variables.
n_plot = min(500, len(sample_u))
config_ax.plot(sample_u[:n_plot], 'o-', markersize=2, linewidth=0.5)
config_ax.set_title(f'Sample Binary Configuration (first {n_plot} variables)', fontsize=12, fontweight='bold')
config_ax.set_xlabel('Variable Index')
config_ax.set_ylabel('Value')
config_ax.set_ylim([-0.1, 1.1])
config_ax.grid(True, alpha=0.3)

# Right panel: how often each distinct value occurs.
unique, counts = np.unique(sample_u, return_counts=True)
hist_ax.bar(unique, counts)
hist_ax.set_title('Distribution of Binary Values', fontsize=12, fontweight='bold')
hist_ax.set_xlabel('Value')
hist_ax.set_ylabel('Count')
hist_ax.set_xticks([0, 1])
hist_ax.grid(True, alpha=0.3, axis='y')

# Annotate the bar chart with the share of entries equal to zero.
sparsity = np.mean(sample_u == 0)
hist_ax.text(0.5, max(counts)*0.9, f'Sparsity: {sparsity:.1%}',
             ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
sample_plot_path = os.path.join(CONFIG['output_dir'], 'graph_ebm_sample_configuration.png')
plt.savefig(sample_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✓ Sample visualization saved")
print(f"\nSample statistics:")
print(f"  Total variables: {len(sample_u)}")
print(f"  Sparsity (% zeros): {sparsity:.1%}")
print(f"  Active variables: {np.sum(sample_u == 1)}")

## 15. Summary and Next Steps (Graph EBM)

In [None]:
# Final report for the Graph EBM run: model size, training configuration,
# headline metrics, saved artifact paths, and suggested follow-ups.

# Count model parameters
num_params = sum(p.numel() for p in graph_model.parameters())
num_trainable = sum(p.numel() for p in graph_model.parameters() if p.requires_grad)

# Get dataset info (best effort: fall back to "Unknown" when unavailable)
try:
    dataset_size = len(train_loader.dataset) + len(val_loader.dataset)
except (NameError, AttributeError, TypeError):
    # Loader not defined in this session, loader without a `.dataset`, or a
    # dataset without a length. A bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit, so only these cases are caught.
    dataset_size = "Unknown"

# max() raises ValueError on an empty sequence, which happens when 'val_gap'
# exists but no validation pass was recorded. default=0 keeps the same
# fallback value that was already used when the key is missing.
best_val_gap = max(history.get('val_gap', []), default=0)

print("\n" + "=" * 60)
print("GRAPH EBM TRAINING SUMMARY")
print("=" * 60)

print("\nModel Architecture:")
print("  Model Type: Graph Energy Model (Deep Sets)")
print(f"  Parameters: {num_params:,}")
print(f"  Trainable: {num_trainable:,}")
print(f"  Hidden Dims: {CONFIG['hidden_dims']}")
print(f"  Activation: {CONFIG['activation']}")
print(f"  Dropout: {CONFIG['dropout']}")

print("\nTraining Configuration:")
print(f"  Epochs: {len(history['train_loss'])}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Learning rate: {CONFIG['learning_rate']}")
print(f"  Dataset size: {dataset_size}")
print(f"  Device: {CONFIG['device']}")

print("\nTraining Results:")
print(f"  Best validation gap: {best_val_gap:.4f}")
print(f"  Final training loss: {history['train_loss'][-1]:.4f}")
print(f"  Final training gap: {history['train_gap'][-1]:.4f}")
print(f"  Final E_pos: {history['train_e_pos'][-1]:.4f}")
print(f"  Final E_neg: {history['train_e_neg'][-1]:.4f}")

# eval_results is produced by the evaluation cell; the metrics are shown only
# when it contains the expected key.
if 'energy_gap_mean' in eval_results:
    print("\nEvaluation Metrics:")
    print(f"  Energy gap (mean): {eval_results['energy_gap_mean']:.4f} ± {eval_results['energy_gap_std']:.4f}")
    print(f"  Avg graph size: {eval_results['avg_graph_size']:.1f} nodes")

# (label, file name) pairs; every file name is resolved under CONFIG['output_dir'].
# This only prints the expected locations; it does not check that they exist.
artifacts = [
    ('Best model', 'graph_ebm_best.pt'),
    ('Final model', 'graph_ebm_final.pt'),
    ('Training history', 'graph_ebm_training_history.json'),
    ('Evaluation results', 'graph_ebm_evaluation_results.json'),
    ('Training curves', 'graph_ebm_training_curves.png'),
    ('Sample visualization', 'graph_ebm_sample_configuration.png'),
]
print("\nSaved Artifacts:")
for artifact_label, artifact_file in artifacts:
    print(f"  {artifact_label}: {os.path.join(CONFIG['output_dir'], artifact_file)}")

print("\nNext Steps:")
print("  1. Analyze training curves and energy gaps")
print("  2. Test on different graph sizes and scenarios")
print("  3. Compare with flat EBM baseline")
print("  4. Tune sampling strategy (exact vs approximate)")
print("  5. Use sampled configurations for MILP warmstart")
print("  6. Evaluate on unseen scenarios")

print("\n" + "=" * 60)
print("✓ GRAPH EBM TRAINING COMPLETE!")
print("=" * 60)

# Optional: Show model summary
print("\nModel Summary:")
print(graph_model)