# ParameterPredictor Hyperparameter Search

Search for the best-performing LSTM `ParameterPredictor` architecture for directly estimating the parameters (amplitude, frequency, phase) of simulated sine signals.

In [None]:
import torch
import numpy as np
from JHPY import generate_sine_data, ParameterPredictor, predictor_hyperparameter_search, load_predictor

In [None]:
# Simulate the sine dataset, then pull out the three DataLoaders it provides.
data = generate_sine_data(num_simulations=5000, num_points=1000)
train_loader, val_loader, test_loader = (
    data[split_key]
    for split_key in ('Train_Loader', 'Val_Loader', 'Test_Loader')
)

In [None]:
# Discrete search space. Each key maps to the candidate values for one
# hyper-parameter; key order is kept as-is in case the helper depends on it.
param_grid = dict(
    lstm_hidden_size=[128, 256],
    lstm_num_layers=[1, 2],
    fc_layer_sizes=[[64, 32], [128, 64, 32]],
    activation=['silu', 'relu'],
    dropout=[0.0, 0.1],
    learning_rate=[0.01, 0.005],
)

# Run settings for the search. `model_path` is where the helper is asked to
# save the best model; the evaluation cells reload from this same path.
search_settings = dict(
    n_epochs=15,
    n_trials=8,  # presumably sampled from the 64-point grid above -- verify in JHPY
    model_path='best_predictor_model.pt',
)
best_config, results = predictor_hyperparameter_search(
    param_grid, train_loader, val_loader, **search_settings
)

In [None]:
# The search already ran in the previous cell, which saved its best model to
# 'best_predictor_model.pt' (the file reloaded further down). Unconditionally
# re-running it here, and without `model_path`, would repeat every trial and
# replace `best_config` / `results` with a new run. Unless the helper's default
# path happens to be the same file, that run's best model would not be the one
# evaluated below. Re-running is therefore opt-in. When enabled, it writes to
# the same checkpoint so the printed results and the loaded model stay in sync.
RERUN_SEARCH = False
if RERUN_SEARCH:
    best_config, results = predictor_hyperparameter_search(
        param_grid,
        train_loader,
        val_loader,
        n_epochs=15,
        n_trials=8,
        model_path='best_predictor_model.pt',
    )

In [None]:
# Report the first five entries of `results` together with their validation
# metrics and hyper-parameters.
# NOTE(review): "Top 5" assumes the search helper returns `results` already
# sorted best-first; confirm in JHPY.
print("\nTop 5 Configurations:")
for rank, entry in enumerate(results[:5], start=1):
    print(f"\nRank {rank}:")
    print(f"  Best Val Loss: {entry['best_val_loss']:.6f}")
    print(f"  Best Val MAE: {entry['best_val_mae']:.6f}")
    print(f"  Best Val RMSE: {entry['best_val_rmse']:.6f}")
    print(f"  Best Val R²: {entry['best_val_r2']:.4f}")
    print("  Config:")
    config = entry['config']
    for name in config:
        print(f"    {name}: {config[name]}")

In [None]:
# Reload the model saved by the search; `load_predictor` returns the model and
# (presumably) the raw checkpoint contents -- verify in JHPY.
checkpoint_path = 'best_predictor_model.pt'
model, checkpoint = load_predictor(checkpoint_path)

In [None]:
# Score the reloaded model on the whole test split rather than only the first
# batch, so the numbers printed as "Test Metrics" really cover the test set.
from JHPY import calculate_metrics

# eval() turns off dropout (the search grid includes dropout=0.1); torch.no_grad()
# only stops gradient tracking and does not change train/eval behaviour.
model.eval()
# Send inputs to wherever the loaded model's weights live (CPU or GPU).
device = next(model.parameters()).device

input_batches, target_batches, prediction_batches = [], [], []
with torch.no_grad():
    for X_batch, y_batch in test_loader:
        # Bring outputs back to CPU so .numpy() below works for GPU models too.
        prediction_batches.append(model(X_batch.to(device)).cpu())
        input_batches.append(X_batch.cpu())
        target_batches.append(y_batch.cpu())

# Keep the names the plotting cell relies on, now spanning every test sample.
X_test = torch.cat(input_batches)
y_test = torch.cat(target_batches)
predictions = torch.cat(prediction_batches)

# NOTE(review): R² is not symmetric in its arguments; confirm that
# calculate_metrics expects (predictions, targets) in this order.
metrics = calculate_metrics(predictions.numpy(), y_test.numpy())

print("Test Metrics:")
print(f"  MAE: {metrics['mae']:.6f}")
print(f"  RMSE: {metrics['rmse']:.6f}")
print(f"  R²: {metrics['r2']:.4f}")

In [None]:
import matplotlib.pyplot as plt

# One predicted-vs-true scatter per target column. The dashed red diagonal is
# the ideal y = x line, spanning the range of the true values.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

param_names = ['Amplitude', 'Frequency', 'Phase']
for col, (ax, label) in enumerate(zip(axes, param_names)):
    true_col = y_test[:, col]
    ax.scatter(true_col.numpy(), predictions[:, col].numpy(), alpha=0.5)
    span = [true_col.min(), true_col.max()]
    ax.plot(span, span, 'r--', lw=2)
    ax.set_xlabel('True')
    ax.set_ylabel('Predicted')
    ax.set_title(f'{label} Predictions')
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()