# DINGOModel Hyperparameter Search

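Searches a small grid of hyperparameters for the complete DINGO-style neural posterior estimation (NPE) model trained on simulated sine-wave data, then evaluates the best configuration on newly simulated observations.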

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from JHPY import generate_sine_data, DINGOModel, npe_hyperparameter_search, load_npe, infer_NPE

In [None]:
# Simulate the sine-wave dataset and pull out the train/validation/test loaders
# from the returned dict.
data = generate_sine_data(num_simulations=5000, num_points=1000)
train_loader, val_loader, test_loader = (
    data[split] for split in ('Train_Loader', 'Val_Loader', 'Test_Loader')
)

In [None]:
# Search space: every key maps to the list of candidate values for that
# hyperparameter. data_dim and param_dim have a single value each; data_dim
# equals the num_points used when generating the data (presumably it must match).
param_grid = dict(
    data_dim=[1000],
    param_dim=[3],
    context_dim=[64, 128],
    num_flow_layers=[6, 8],
    hidden_dim=[128, 256],
    learning_rate=[0.001, 0.0005],
)

In [None]:
# Run the hyperparameter search over param_grid with the train/validation loaders.
# Returns the best configuration and a list of per-trial results.
# NOTE(review): n_trials=None presumably means "evaluate every grid combination"
# rather than sampling a subset — confirm in npe_hyperparameter_search.
search_options = dict(model_class=DINGOModel, n_epochs=15, n_trials=None)
best_config, results = npe_hyperparameter_search(
    param_grid, train_loader, val_loader, **search_options
)

In [None]:
# Report the first five entries of `results`, each with its best validation
# log-probability and the hyperparameters that produced it.
# NOTE(review): assumes `results` is already sorted best-first — confirm.
print("\nTop 5 Configurations:")
for rank, entry in enumerate(results[:5], start=1):
    print(f"\nRank {rank}:")
    print(f"  Log Prob: {entry['best_val_log_prob']:.4f}")
    print("  Config:")
    for name, val in entry['config'].items():
        print(f"    {name}: {val}")

In [None]:
# Reload the saved checkpoint (presumably the best trial from the search) and
# plot its per-epoch train and validation log-probability curves.
model, checkpoint = load_npe('best_npe_model.pt', model_class=DINGOModel)

history_fig, history_ax = plt.subplots(figsize=(10, 5))
curves = (
    ('train_log_probs', 'Train Log Probability'),
    ('val_log_probs', 'Validation Log Probability'),
)
for history_key, curve_label in curves:
    history_ax.plot(checkpoint[history_key], linewidth=2, label=curve_label)
history_ax.legend(fontsize=12)
history_ax.set_title('Best Model Training Progress', fontsize=14)
history_ax.set_xlabel('Epoch')
history_ax.set_ylabel('Log Probability')
history_ax.grid(True, alpha=0.3)
plt.show()

In [None]:
from matplotlib.patches import Rectangle
from JHPY import simulate_sine_wave  # imported once, not on every loop iteration

# Evaluate the loaded model on freshly simulated signals with known parameters
# and compare the sampled posterior with the ground truth.
# Figure layout: one column per test case; row 0 = observed signal,
# the next len(parameters) rows = marginal posteriors, the last
# len(parameters) rows = pairwise joint posteriors.

test_frequencies = [1.0, 2.5, 4.0]
test_phases = [0.5, -0.3, 1.2]
test_amplitudes = [1.0, 0.8, 2.4]

# Column order assumed for posterior_samples: 0 = amplitude, 1 = frequency, 2 = phase.
# NOTE(review): must match how JHPY packs the parameter vector in training — confirm.
parameters = ['amplitude', 'frequency', 'phase']
ps = ['A', 'f', 'φ']
n_params = len(parameters)


def _posterior_summary(samples):
    """Return mean, median, standard deviation and 5th/95th percentiles of ``samples``."""
    return {
        'mean': np.mean(samples),
        'median': np.median(samples),
        'std': np.std(samples),
        'q05': np.percentile(samples, 5),
        'q95': np.percentile(samples, 95),
    }


# Time axis for overlaying the noise-free true curve; identical for every test case.
# NOTE(review): assumes simulate_sine_wave samples 1000 points on [0, 6π] — confirm in JHPY.
t = np.linspace(0, 6 * np.pi, 1000)

fig, axes = plt.subplots(1 + 2 * n_params, len(test_frequencies),
                         figsize=(8 * len(test_frequencies), 40))

print("Testing best model on new observations:\n")

for idx, (true_freq, true_phase, true_amp) in enumerate(
        zip(test_frequencies, test_phases, test_amplitudes)):
    print(f"\nTest {idx+1}: True Frequency = {true_freq}, True Phase = {true_phase:.2f}, True Amplitude = {true_amp:.2f}")

    observed_data = simulate_sine_wave(true_freq, phase=true_phase, amplitude=true_amp)

    # `stats` from infer_NPE is not used below; summaries are recomputed per parameter.
    posterior_samples, stats = infer_NPE(model, observed_data, num_samples=5000)

    # Per-parameter columns are required. Previously a 1-D result was aliased to
    # all three parameters, which silently plotted one marginal under three labels.
    if len(posterior_samples.shape) == 1:
        raise ValueError(
            "Expected posterior samples of shape [num_samples, param_dim] "
            f"(amplitude, frequency, phase); got shape {tuple(posterior_samples.shape)}"
        )
    amp_samples = posterior_samples[:, 0]
    freq_samples = posterior_samples[:, 1]
    phase_samples = posterior_samples[:, 2]

    amp_stats = _posterior_summary(amp_samples)
    freq_stats = _posterior_summary(freq_samples)
    phase_stats = _posterior_summary(phase_samples)

    # Parallel lists, in the same order as `parameters`.
    samples_list = [amp_samples, freq_samples, phase_samples]
    stats_list = [amp_stats, freq_stats, phase_stats]
    true_list = [true_amp, true_freq, true_phase]

    print(f"  Frequency posterior: mean={freq_stats['mean']:.3f} ± {freq_stats['std']:.3f}")
    print(f"  Phase posterior:     mean={phase_stats['mean']:.3f} ± {phase_stats['std']:.3f}")
    print(f"  Amplitude posterior: mean={amp_stats['mean']:.3f} ± {amp_stats['std']:.3f}")

    # Row 0: observed signal with the noise-free true curve overlaid.
    ax = axes[0, idx]
    ax.plot(t, observed_data, 'b-', alpha=0.7, linewidth=1.5, label='Observed')
    ax.plot(t, true_amp * np.sin(2 * np.pi * true_freq * t + true_phase), 'r--',
            label=f'True (f={true_freq}, φ={true_phase:.2f}, A={true_amp:.2f})', linewidth=2)
    ax.set_title(f'Test {idx+1}: f={true_freq}, φ={true_phase:.2f}, A={true_amp:.2f}', fontsize=12, fontweight='bold')
    ax.set_xlabel('Time')
    ax.set_ylabel('Value')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # Marginal posteriors: histogram, true value, posterior mean, 5th–95th percentile band.
    for j, name in enumerate(parameters):
        ax = axes[j + 1, idx]
        s = stats_list[j]
        ax.hist(samples_list[j], bins=60, density=True,
                alpha=0.6, edgecolor='black', label='Posterior')
        ax.axvline(true_list[j], color='red', linestyle='--',
                   linewidth=2.5, label=f'True: {true_list[j]:.2f}', zorder=10)
        ax.axvline(s['mean'], color='green', linestyle='-',
                   linewidth=2.5, label=f"Mean: {s['mean']:.2f}", zorder=10)
        ax.axvspan(s['q05'], s['q95'], alpha=0.2, color='gray', label='90% CI')
        ax.set_title(f'p({name} | data)', fontsize=12, fontweight='bold')
        ax.set_xlabel(name.capitalize())
        ax.set_ylabel('Density')
        ax.legend(loc='upper right', fontsize=8)
        ax.grid(True, alpha=0.3)

    # Joint posteriors for the cyclic pairs (A, f), (f, φ), (φ, A).
    for j, name in enumerate(parameters):
        k = (j + 1) % n_params
        ax = axes[n_params + j + 1, idx]
        sx, sy = stats_list[j], stats_list[k]

        h = ax.hist2d(samples_list[j], samples_list[k], bins=60, cmap='plasma', density=True)
        plt.colorbar(h[3], ax=ax, label='Probability Density')
        # 'x' is an unfilled marker: matplotlib ignores edgecolors for it (and warns),
        # so only the face colour is given.
        ax.scatter(true_list[j], true_list[k], color='cyan', s=200, marker='x',
                   linewidth=2, label='True values', zorder=10)
        ax.scatter(sx['mean'], sy['mean'], color='lime', s=100, marker='o',
                   linewidth=3, label='Posterior mean', zorder=10)

        # Axis-aligned box of mean ± 1 std per coordinate; it ignores correlation,
        # so it is only a rough stand-in for a true 1σ contour.
        rect = Rectangle((sx['mean'] - sx['std'], sy['mean'] - sy['std']),
                         width=2 * sx['std'], height=2 * sy['std'],
                         facecolor='yellow', edgecolor='yellow', linewidth=2,
                         alpha=0.3, label='1σ region', zorder=5)
        ax.add_patch(rect)

        ax.set_xlabel(name, fontsize=12, fontweight='bold')
        ax.set_ylabel(parameters[k], fontsize=12, fontweight='bold')
        ax.set_title(f'Joint Posterior p({ps[j]}, {ps[k]} | data)\nTrue: {ps[j]}={true_list[j]}, {ps[k]}={true_list[k]}', fontsize=12, fontweight='bold')
        ax.legend(loc='upper right', fontsize=8)
        ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()