# optimization using `optuna`

imports

In [None]:
import json
import threading
import time
import warnings
from contextlib import contextmanager

import dask
import fsspec
import numpy as np
import optuna
import xarray as xr
from distributed import Client, LocalCluster
from toolz.functoolz import curry

In [None]:
from pangeo_fish.hmm.estimator import EagerScoreEstimator
from pangeo_fish.pdf import combine_emission_pdf

parametrize with [papermill](https://papermill.readthedocs.io/en/latest/)

In [None]:
# papermill parameters: papermill injects the actual values into this cell
input_path: str  # emission dataset, opened below as a zarr store
output_path: str  # destination of the JSON-encoded optimized parameters

tolerance: float = 1e-2  # NOTE(review): not referenced by any cell shown here
n_trials: int = 64  # total number of optuna trials
n_jobs: int = 8  # number of concurrent trial threads passed to `study.optimize`

# NOTE(review): not referenced by any cell shown here; each thread starts its own LocalCluster instead
scheduler_address: str | None = None

## open the data

In [None]:
# open the zarr store with x and y each held in a single chunk, then let
# `combine_emission_pdf` merge the emission pdfs into the dataset
data = combine_emission_pdf(
    xr.open_dataset(
        input_path,
        engine="zarr",
        chunks={"x": -1, "y": -1},
        inline_array=True,
    )
)
data

## verify the data

In [None]:
import hvplot.xarray

In [None]:
# number of non-missing pdf values over the x/y plane, for every remaining coordinate
data["pdf"].count(dim=["x", "y"]).hvplot(title="count of valid values")

## select the estimator

In [None]:
# default-constructed estimator; `objective` sets `sigma` for every trial
estimator = EagerScoreEstimator()

## prepare the optimization

In [None]:
import logging


def create_default_formatter() -> logging.Formatter:
    """Create the formatter used for this notebook's log messages.

    Each record is rendered as the first letter of its level name and the
    timestamp in square brackets, followed by the message.
    """
    header = "[%(levelname)1.1s %(asctime)s]"
    message = "%(message)s"
    return logging.Formatter(f"{header} {message}")


# Name used to recognize the console handler installed by `setup_logging`.
_CONSOLE_HANDLER_NAME = "optimization-console"


def setup_logging():
    """Attach a console handler to this module's logger and return the logger.

    The handler passes DEBUG and above, while the logger itself is set to
    INFO. Calling this again (e.g. when the notebook cell is re-run) replaces
    the handler installed by the previous call instead of adding a second
    one, so every message is emitted exactly once.
    """
    logger = logging.getLogger(__name__)

    # drop the handler from a previous call; otherwise each re-run duplicates output
    for handler in list(logger.handlers):
        if handler.name == _CONSOLE_HANDLER_NAME:
            logger.removeHandler(handler)
            handler.close()

    console_handler = logging.StreamHandler()
    console_handler.name = _CONSOLE_HANDLER_NAME
    console_handler.setFormatter(create_default_formatter())
    console_handler.setLevel(logging.DEBUG)
    logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)

    return logger


logger = setup_logging()

To find the optimal parameters, we follow the [Parallel hyper-parameter optimization of XGBoost with Optuna and Dask (multiple clusters)](https://github.com/coiled/dask-xgboost-nyctaxi/blob/main/Modeling%203%20-%20Parallel%20HPO%20of%20XGBoost%20with%20Optuna%20and%20Dask%20(multi%20cluster).ipynb) notebook.

`optuna` searches for the best parameters, running the trials in multiple threads; each thread gets its own `distributed` cluster.

In [None]:
# Per-thread registry: thread id -> `Client` created by that thread.
clients = {}

# `warnings.catch_warnings` replaces module-global warning state and is not
# thread-safe, while optuna calls `get_client` (via `objective`) from `n_jobs`
# threads at once; serialize cluster creation so the filters cannot be
# corrupted by concurrent threads.
_cluster_creation_lock = threading.Lock()


def get_client():
    """Return the ``distributed.Client`` belonging to the calling thread.

    The first call from a thread starts a single-worker ``LocalCluster``
    (``memory_limit="2GB"``), wraps it in a client that is not set as the
    default, and caches that client in ``clients`` under the thread id.
    Later calls from the same thread return the cached client.
    """
    thread_id = threading.get_ident()

    try:
        return clients[thread_id]
    except KeyError:
        pass

    with _cluster_creation_lock, warnings.catch_warnings():
        # several clusters in one process compete for the default dashboard port
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            module="distributed",
            message=".*Port 8787 is already in use.",
        )

        cluster = LocalCluster(n_workers=1, memory_limit="2GB")
        logger.info(f"opened cluster dashboard at: {cluster.dashboard_link}")
    client = Client(cluster, set_as_default=False)

    clients[thread_id] = client

    return client


def objective(trial, estimator, data):
    """Evaluate one optuna trial.

    Samples ``sigma`` from ``[1e-5, data.attrs["sigma_max"]]``, configures
    ``estimator`` with it, and scores ``data`` on the calling thread's own
    cluster. Returns the score as a Python scalar.
    """
    thread_client = get_client()

    upper_bound = data.attrs["sigma_max"]
    sigma = trial.suggest_float("sigma", 1e-5, upper_bound)

    # configuration, scoring and `.item()` all run with this thread's client
    # as the current one
    with thread_client.as_current():
        configured = estimator.set_params(sigma=sigma)
        return configured.score(data).item()

In [None]:
@contextmanager
def isolated_clients():
    """Run the enclosed code with a fresh, empty ``clients`` registry.

    Inside the ``with`` block, ``get_client`` stores new clients in an empty
    dict. On exit, whether normal or via an exception, every client in that
    dict is shut down and closed. Before shutting a client down, the manager
    waits until ``client.processing()`` reports no work. Finally, the
    registry that was active before entering is restored.
    """
    global clients

    backup = clients

    try:
        clients = {}
        yield
    finally:
        for client in clients.values():
            # make sure we don't cancel anything that is still running
            while any(client.processing().values()):
                time.sleep(2)
            client.shutdown()
            client.close()

        # restore the previous registry (this previously assigned to a
        # misspelled local, so the old registry was never restored)
        clients = backup

## execute the optimization

In [None]:
%%time
# NOTE(review): optuna minimizes by default; this assumes a lower `score`
# is better for the estimator — TODO confirm
study = optuna.create_study(study_name="parallel-pangeo-fish-tag_log")

# every trial thread lazily gets its own cluster through `get_client`;
# `isolated_clients` shuts all of them down once the optimization finishes
with isolated_clients():
    study.optimize(
        curry(objective)(estimator=estimator, data=data),
        n_trials=n_trials,
        n_jobs=n_jobs,
    )

study.best_params

In [None]:
# plot the objective value of every trial together with the best value so far
optuna.visualization.plot_optimization_history(study)

In [None]:
# configure the estimator with the parameters of the best trial and display it
# NOTE(review): whether `set_params` returns a copy or mutates `estimator` in
# place is not visible here — confirm (it also matters for the concurrent
# trials in `objective`)
optimized = estimator.set_params(**study.best_params)
optimized

## store the optimized parameters to disk

In [None]:
# serialize the optimized parameters as JSON to `output_path`, which may be any
# fsspec-supported location
params = optimized.to_dict()
with fsspec.open(output_path, mode="w") as f:
    f.write(json.dumps(params))