In [None]:
# Reload edited local modules (e.g. algorithm1, models) automatically before each
# cell executes, so changes to the EA code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import json

import optuna
from optuna.visualization import (
    plot_optimization_history,
    plot_param_importances,
    plot_parallel_coordinate,
    plot_slice,
    plot_contour,
)

from circuit_tracer.utils.create_graph_files import load_graph_data
from algorithm1 import run_ea_optimization
from models import EAHyperparameters

## 1. Load Graph

In [None]:
# Name of the saved graph to tune on; it also names the Optuna study and the
# results file written later in the notebook.
base_name = "war"

graph_dir = Path("graphs")
graph_path = graph_dir.joinpath(f"{base_name}.pt")

# Load once; every trial in the study reuses this same graph object.
graph = load_graph_data(graph_path)
print(f"Graph loaded: {graph.cfg.n_layers} layers")

## 2. Define Objective Function

In [None]:
def objective(
    trial: optuna.Trial,
    n_generations: int = 30,
    complexity_penalty: float = 0.1,
) -> float:
    """Optuna objective: run the EA with sampled hyperparameters and score the result.

    Samples population/variation hyperparameters and objective weights from
    ``trial``, runs ``run_ea_optimization`` on the module-level ``graph``, and
    scores the returned ``'balanced'`` solution as
    ``quality - complexity_penalty * complexity`` (higher is better; the study
    is created with ``direction="maximize"``).

    Args:
        trial: Optuna trial used to sample hyperparameters and to record the
            balanced solution's metrics as user attributes.
        n_generations: EA generations per trial. Not tuned; kept small so each
            trial is cheap. Defaults to 30.
        complexity_penalty: Weight applied to ``complexity`` in the returned
            score. Defaults to 0.1.

    Returns:
        The composite score as a built-in ``float``.

    NOTE(review): if ``run_ea_optimization`` uses the sampled ``w_*`` weights
    when computing ``quality``/``complexity``, the score partly depends on the
    weights being tuned and trials may not be comparable — confirm.
    """
    # Sample hyperparameters
    hp = EAHyperparameters(
        population_size=trial.suggest_int("population_size", 20, 100, step=10),
        n_generations=n_generations,  # Fixed (not tuned) for faster trials
        mutation_rate=trial.suggest_float("mutation_rate", 0.05, 0.5),
        crossover_rate=trial.suggest_float("crossover_rate", 0.5, 0.95),
        mutation_sigma=trial.suggest_float("mutation_sigma", 0.05, 0.3),

        # Objective weights
        w_completeness=trial.suggest_float("w_completeness", 0.1, 2.0),
        w_replacement=trial.suggest_float("w_replacement", 0.1, 2.0),
        w_complexity_node=trial.suggest_float("w_complexity_node", 1.0, 15.0),
        w_complexity_edge=trial.suggest_float("w_complexity_edge", 0.5, 10.0),
    )

    # Run EA optimization
    result = run_ea_optimization(
        graph=graph,
        verbose=False,
        use_batch=True,
        max_batch_per_gpu=8,
        hp=hp,
    )

    # Extract balanced solution
    balanced = result['balanced']

    # Cast to built-in Python types: these values are later written with
    # json.dump and printed behind isinstance(value, float) checks, both of
    # which mis-handle numpy/torch scalars. (Assumes each metric is a scalar —
    # TODO confirm against run_ea_optimization.)
    quality = float(balanced['quality'])
    complexity = float(balanced['complexity'])
    metrics = {
        "quality": quality,
        "complexity": complexity,
        "completeness": float(balanced['completeness']),
        "replacement": float(balanced['replacement']),
        "n_nodes": int(balanced['n_nodes']),
        "n_edges": int(balanced['n_edges']),
        "pareto_front_size": int(result['pareto_front_size']),
    }

    # Store metrics for analysis
    for key, value in metrics.items():
        trial.set_user_attr(key, value)

    # Composite score: maximize quality, minimize complexity
    return quality - complexity_penalty * complexity

## 3. Run Optimization

In [None]:
# Create study
study = optuna.create_study(
    study_name=f"ea_hp_{base_name}",
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=42),
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5),
)

# Run optimization
n_trials = 30
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)

## 4. Analyze Results

In [None]:
def _print_section(title, mapping):
    """Print ``title``, then one indented ``key: value`` line per entry (floats to 4 d.p.)."""
    print(title)
    for name, val in mapping.items():
        rendered = f"{val:.4f}" if isinstance(val, float) else f"{val}"
        print(f"  {name}: {rendered}")


print("Best trial:")
print(f"  Trial: {study.best_trial.number}")
print(f"  Score: {study.best_value:.4f}")
_print_section("\nBest hyperparameters:", study.best_params)
_print_section("\nBalanced solution metrics:", study.best_trial.user_attrs)

## 5. Visualizations

In [None]:
# Objective value of every trial, plus the best value found so far.
plot_optimization_history(study=study)

In [None]:
# Which sampled hyperparameters the objective value is most sensitive to.
plot_param_importances(study=study)

In [None]:
# One line per trial across all sampled hyperparameters, coloured by objective value.
plot_parallel_coordinate(study=study)

In [None]:
# Objective value against each selected weight, one panel per parameter.
plot_slice(study=study, params=["w_completeness", "w_replacement", "w_complexity_node"])

In [None]:
# Joint effect of the two quality weights on the objective value.
plot_contour(study=study, params=["w_completeness", "w_replacement"])

## 6. Save Best Hyperparameters

In [None]:
# Save results to JSON
output_dir = Path("optuna_results")
output_dir.mkdir(exist_ok=True)

results_file = output_dir / f"{base_name}_best_hyperparameters.json"
with open(results_file, 'w') as f:
    json.dump({
        "best_trial": study.best_trial.number,
        "best_score": study.best_value,
        "best_params": study.best_params,
        "best_metrics": study.best_trial.user_attrs,
    }, f, indent=2)

print(f"Results saved to {results_file}")

## 7. Re-run the EA with the Best Hyperparameters (More Generations)

In [None]:
# Create hyperparameters from best trial
# Create hyperparameters from the best trial. Unpacking best_params keeps this
# cell in sync with the search space defined in `objective`; only the
# generation budget is overridden.
# NOTE: trials were tuned with 30 generations, so this longer run's metrics are
# not directly comparable to the tuned best score.
best_hp = EAHyperparameters(
    **study.best_params,
    n_generations=50,  # Use more generations for final run
)

# Run with best hyperparameters
print("Running EA with best hyperparameters...")
final_result = run_ea_optimization(
    graph=graph,
    verbose=True,
    use_batch=True,
    max_batch_per_gpu=8,
    hp=best_hp,
)

print("\nFinal Results:")
print(f"Pareto Front Size: {final_result['pareto_front_size']}")
print("\nBalanced Solution:")
bal = final_result['balanced']
for label, key in (
    ("Quality", "quality"),
    ("Complexity", "complexity"),
    ("Completeness", "completeness"),
    ("Replacement", "replacement"),
):
    print(f"  {label}: {bal[key]:.4f}")
print(f"  Nodes: {bal['n_nodes']}, Edges: {bal['n_edges']}")