# optimization using `optuna`

imports

In [None]:
import json

import fsspec
import optuna
import xarray as xr

from toolz.functoolz import curry

In [None]:
from pangeo_fish.hmm.estimator import EagerScoreEstimator
from pangeo_fish.pdf import combine_emission_pdf
from pangeo_fish.hmm.optimize.optuna import get_client, isolated_clients

parametrize with [papermill](https://papermill.readthedocs.io/en/latest/)

In [None]:
# Notebook parameters; papermill overrides these when the notebook is executed.
input_path: str  # location of the zarr store opened as the input dataset
output_path: str  # where the optimized parameters are written as JSON

tolerance: float = 1e-2  # NOTE(review): not referenced by the cells shown here; confirm it is still needed
n_trials: int = 64  # number of optuna trials to run
n_jobs: int = 8  # number of parallel jobs passed to `study.optimize`

## open the data

In [None]:
# Open the zarr store with dask-backed arrays, keeping `x` and `y` in a single
# chunk each, then hand the dataset to `combine_emission_pdf`
# (presumably merging the individual emission PDFs; verify against pangeo_fish).
data = combine_emission_pdf(
    xr.open_dataset(
        input_path,
        engine="zarr",
        chunks={"x": -1, "y": -1},
        inline_array=True,
    )
)
data

## select the estimator

In [None]:
# Start from an estimator with default settings; `sigma` is set per trial in `objective`.
estimator = EagerScoreEstimator()

## prepare the optimization

To find the optimal parameter, we follow the [Parallel hyper-parameter optimization of XGBoost with Optuna and Dask (multiple clusters)](https://github.com/coiled/dask-xgboost-nyctaxi/blob/main/Modeling%203%20-%20Parallel%20HPO%20of%20XGBoost%20with%20Optuna%20and%20Dask%20(multi%20cluster).ipynb) notebook.

This uses `optuna` to search for the parameter, running the trials in multiple threads where each thread gets its own `distributed` cluster.

In [None]:
def objective(trial, estimator, data):
    """Evaluate a single optuna trial and return its score.

    Parameters
    ----------
    trial
        The optuna trial; used to sample ``sigma`` between ``1e-5`` and
        ``data.attrs["max_sigma"]``.
    estimator
        Estimator providing ``set_params`` and ``score``.
    data
        Dataset passed to ``score``; must carry a ``"max_sigma"`` attribute.

    Returns
    -------
    The result of ``score(data)`` converted to a Python scalar via ``.item()``.
    """
    # presumably the `distributed` client dedicated to the calling thread;
    # verify against `pangeo_fish.hmm.optimize.optuna.get_client`
    trial_client = get_client()

    sigma_upper = data.attrs["max_sigma"]
    sampled = {"sigma": trial.suggest_float("sigma", 1e-5, sigma_upper)}

    # run the scoring with this trial's client as the current one
    with trial_client.as_current():
        candidate = estimator.set_params(**sampled)
        score = candidate.score(data)
        return score.item()

## execute the optimization

In [None]:
%%time
# Minimization is optuna's default direction; it is spelled out here for clarity.
study = optuna.create_study(
    study_name="parallel-pangeo-fish",
    direction="minimize",
)

# The whole search runs inside `isolated_clients()`; presumably this manages
# the per-thread `distributed` clients that `objective` retrieves (verify
# against pangeo_fish).
with isolated_clients():
    # bind everything except `trial`, which optuna supplies for each evaluation
    study.optimize(
        curry(objective, estimator=estimator, data=data),
        n_jobs=n_jobs,
        n_trials=n_trials,
    )

study.best_params

In [None]:
# Plot the optimization history of all trials to check how the search progressed.
optuna.visualization.plot_optimization_history(study)

In [None]:
# Apply the best parameters found by the study to the estimator and display the result.
optimized = estimator.set_params(**study.best_params)
optimized

## store the optimized parameters to disk

In [None]:
# Serialize the optimized estimator's parameters and write them as JSON to
# `output_path` (opened through fsspec, so any supported filesystem works).
params = optimized.to_dict()
with fsspec.open(output_path, mode="w") as f:
    f.write(json.dumps(params))