# Parameter optimization using `optuna`

In [None]:
import fsspec
import xarray as xr

In [None]:
import cmocean
import holoviews as hv
import hvplot.pandas
import hvplot.xarray
import matplotlib

# Place holoviews widgets (e.g. sliders over extra plot dimensions) below the figure.
hv.output(widget_location="bottom")
# Compatibility shim: when this matplotlib version no longer provides the
# private `matplotlib.cm._cmap_registry`, alias it to `matplotlib.cm._colormaps`
# so code that still reads the old name keeps working (presumably `cmocean`'s
# colormap registration — TODO confirm).
if not hasattr(matplotlib.cm, "_cmap_registry"):
    matplotlib.cm._cmap_registry = matplotlib.cm._colormaps

In [None]:
# Base URL of the data directory: emission inputs are read from `{root}/emission`
# and state probabilities are written to `{root}/state`.
# NOTE: machine-specific path; adjust before running elsewhere.
root = "file:///home/jmagin/work/data/fish-intel"

# Tag identifier, used as the base name of the zarr stores read and written below.
name = "A18832-f1_e2500-4096"

In [None]:
import logging


def create_default_formatter() -> logging.Formatter:
    """Create the default formatter for log messages.

    Each message is rendered as ``[<first letter of the level> <timestamp>] <message>``.
    This function is not supposed to be directly accessed by library users.
    """
    header = "[%(levelname)1.1s %(asctime)s]"
    message = "%(message)s"
    return logging.Formatter(f"{header} {message}")


# Name given to the console handler so that it can be recognized on later calls.
_CONSOLE_HANDLER_NAME = "console"


def setup_logging():
    """Return this module's logger, printing INFO-and-above messages to stderr.

    Safe to call repeatedly (e.g. when the notebook cell is re-executed): the
    console handler is attached only once, so messages are not duplicated.
    """
    logger = logging.getLogger(__name__)

    already_configured = any(
        handler.get_name() == _CONSOLE_HANDLER_NAME for handler in logger.handlers
    )
    if not already_configured:
        console_handler = logging.StreamHandler()
        console_handler.set_name(_CONSOLE_HANDLER_NAME)
        console_handler.setFormatter(create_default_formatter())
        # the handler passes everything; the logger level below does the filtering
        console_handler.setLevel(logging.DEBUG)
        logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)

    return logger


logger = setup_logging()

## The actual optimization

To find the optimal parameter value, we follow the [Parallel hyper-parameter optimization of XGBoost with Optuna and Dask (multiple clusters)](https://github.com/coiled/dask-xgboost-nyctaxi/blob/main/Modeling%203%20-%20Parallel%20HPO%20of%20XGBoost%20with%20Optuna%20and%20Dask%20(multi%20cluster).ipynb) notebook.

This uses `optuna` to search for the best parameter value. Trials run in multiple threads, and each thread gets its own `distributed` cluster.

In [None]:
import threading
import warnings

import dask
import numpy as np
import optuna
from distributed import Client, LocalCluster
from toolz.functoolz import curry

from pangeo_fish.pdf import combine_emission_pdf
from pangeo_fish.hmm.estimator import EagerScoreEstimator

In [None]:
# Single-worker local cluster for the work outside the parallel optimization.
# NOTE(review): `Cluster.get_client()` presumably registers this client as the
# default, so dask computations not run under another client's
# `as_current()` (e.g. the `to_zarr` cells below) would execute here — confirm.
cluster = LocalCluster(n_workers=1)
client = cluster.get_client()
client

In [None]:
def fix_encoding(ds):
    """Return a copy of ``ds`` without chunk hints in any variable's encoding.

    The ``"preferred_chunks"`` and ``"chunks"`` entries are removed from the
    ``encoding`` of every variable (data and coordinates) of the copy.
    Presumably this keeps chunk sizes recorded when the store was read from
    conflicting with the new dask chunking later on — TODO confirm.

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to clean. It is not modified; the copy's variables have
        their own ``encoding`` dicts.

    Returns
    -------
    xarray.Dataset
        The cleaned copy.
    """
    out = ds.copy()

    for var in out.variables.values():
        # Not every variable necessarily carries these keys (e.g. when the
        # dataset was not opened from a chunked zarr store), so tolerate their
        # absence instead of raising KeyError.
        var.encoding.pop("preferred_chunks", None)
        var.encoding.pop("chunks", None)

    return out

In [None]:
# Registry of dask clients, one per optuna worker thread, keyed by
# `threading.get_ident()`. `isolated_clients` swaps in a fresh dict for the
# duration of a study.
clients = {}

# `warnings.catch_warnings` replaces process-global state and is not
# thread-safe, while optuna calls `objective` from several threads at once
# (`n_jobs`). Creating clusters under this lock keeps one thread's temporary
# warning filter from being clobbered or leaked by another thread.
_cluster_creation_lock = threading.Lock()


def get_client():
    """Return the dask client of the calling thread, creating it on first use.

    Each thread gets its own single-worker ``LocalCluster`` with a 2 GB
    memory limit, so concurrently running trials do not share workers. The
    client is created with ``set_as_default=False``: callers must activate it
    explicitly, e.g. with ``client.as_current()``.
    """
    thread_id = threading.get_ident()

    existing = clients.get(thread_id)
    if existing is not None:
        return existing

    with _cluster_creation_lock, warnings.catch_warnings():
        # Several clusters are started; silence the warnings about the default
        # dashboard port 8787 already being taken.
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            module="distributed",
            message=".*Port 8787 is already in use.",
        )

        cluster = LocalCluster(n_workers=1, memory_limit="2GB")
        logger.info(f"opened cluster dashboard at: {cluster.dashboard_link}")

    client = Client(cluster, set_as_default=False)
    clients[thread_id] = client

    return client


def objective(trial, path):
    """Evaluate one optuna trial: score ``EagerScoreEstimator`` for a suggested sigma.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to draw ``sigma`` from ``[1e-5, data.attrs["sigma_max"]]``.
    path : str
        URL of the emission zarr store to open.

    Returns
    -------
    The estimator's score for the suggested parameters, converted with
    ``.item()`` (presumably a Python float).
    """
    client = get_client()

    # Everything, including any computation triggered while preparing the
    # data, runs on this thread's private cluster instead of the default client.
    with client.as_current():
        data = (
            xr.open_dataset(
                path,
                engine="zarr",
                chunks={"x": -1, "y": -1, "time": 1},
                inline_array=True,
            )
            .pipe(fix_encoding)
            .pipe(combine_emission_pdf)
        )

        params = {"sigma": trial.suggest_float("sigma", 1e-5, data.attrs["sigma_max"])}

        estimator = EagerScoreEstimator()
        return estimator.set_params(**params).score(data).item()

In [None]:
import time
from contextlib import contextmanager


@contextmanager
def isolated_clients():
    """Give the enclosed code a fresh, private registry of per-thread clients.

    While the ``with`` block runs, the module-level ``clients`` registry used by
    ``get_client`` starts out empty, so worker threads create new clusters. On
    exit, each client created inside the block is shut down once none of its
    workers is processing tasks. The previous registry is then restored, even
    if a shutdown fails.
    """
    global clients

    backup = clients

    try:
        clients = {}
        yield
    finally:
        try:
            for client in clients.values():
                # make sure we don't cancel anything: wait until no worker of
                # this cluster is still processing tasks
                while any(client.processing().values()):
                    time.sleep(2)
                client.shutdown()
                client.close()
        finally:
            # restore the registry that was active before entering the block
            clients = backup

Optimization configuration: the number of trials and how many of them run in parallel.

In [None]:
# Total number of optuna trials per study.
N_TRIALS = 64
# Number of trials evaluated concurrently. Each worker thread starts its own
# single-worker LocalCluster (see `get_client`, 2 GB memory limit each), so this
# also bounds the number of clusters alive at once.
N_JOBS = 8

## Tag log only

In [None]:
%%time
# Emission data built from the tag log alone (no acoustic detections).
path = f"{root}/emission/{name}.zarr"

study = optuna.create_study(study_name="parallel-pangeo-fish-tag_log")

# `curry` binds `path`, leaving a callable that takes only the optuna trial.
# Trials run in N_JOBS threads, each scoring on its own cluster; those clusters
# are shut down when the `with` block exits.
with isolated_clients():
    study.optimize(curry(objective)(path=path), n_trials=N_TRIALS, n_jobs=N_JOBS)

study.best_params

In [None]:
# Objective value of every trial, together with the best value found so far.
optuna.visualization.plot_optimization_history(study)

In [None]:
%%time
# Compute the state probabilities with the best parameters found by the study.
estimator = EagerScoreEstimator(**study.best_params)

# Same preprocessing as in `objective`, except that `time` is not explicitly
# chunked here (only `x` and `y` are forced into single chunks).
data = (
    xr.open_dataset(path, engine="zarr", chunks={"x": -1, "y": -1}, inline_array=True)
    .pipe(fix_encoding)
    .pipe(combine_emission_pdf)
)

state_probabilities = (
    estimator.predict_proba(data)
    # optimize the dask task graph before it is computed / written
    .pipe(lambda ds: dask.optimize(ds)[0])
    .to_dataset(name="states")
    # keep the chosen parameters (e.g. `sigma`) with the result; the plot
    # titles further down read them back from the attributes
    .assign_attrs(study.best_params)
)
state_probabilities

In [None]:
%%time
# Persist the result; `mode="w"` overwrites any existing store at `outpath`.
outpath = f"{root}/state/{name}.zarr"
state_probabilities.to_zarr(outpath, mode="w", consolidated=True, compute=True)

## Tag log and acoustic detections

In [None]:
%%time
# Emission data combining the tag log with the acoustic detections.
path = f"{root}/emission/{name}-acoustic.zarr"

study = optuna.create_study(study_name="parallel-pangeo-fish-tag_log+acoustic")

# Same parallel setup as for the tag-log-only study: one cluster per worker
# thread, all shut down when the `with` block exits.
with isolated_clients():
    study.optimize(curry(objective)(path=path), n_trials=N_TRIALS, n_jobs=N_JOBS)

study.best_params

In [None]:
# Objective value of every trial, together with the best value found so far.
optuna.visualization.plot_optimization_history(study)

In [None]:
%%time
# Compute the state probabilities with the best parameters found by the
# tag log + acoustic study.
estimator = EagerScoreEstimator(**study.best_params)

# Same preprocessing as in `objective`, except that `time` is not explicitly
# chunked here (only `x` and `y` are forced into single chunks).
data = (
    xr.open_dataset(path, engine="zarr", chunks={"x": -1, "y": -1}, inline_array=True)
    .pipe(fix_encoding)
    .pipe(combine_emission_pdf)
)

state_probabilities = (
    estimator.predict_proba(data)
    # optimize the dask task graph before it is computed / written
    .pipe(lambda ds: dask.optimize(ds)[0])
    .to_dataset(name="states")
    # keep the chosen parameters (e.g. `sigma`) with the result; the plot
    # titles further down read them back from the attributes
    .assign_attrs(study.best_params)
)
state_probabilities

In [None]:
%%time
# Persist the result; `mode="w"` overwrites any existing store at `outpath`.
outpath = f"{root}/state/{name}-acoustic.zarr"
state_probabilities.to_zarr(outpath, mode="w", consolidated=True, compute=True)

## Plot the results

In [None]:
# Re-open both stored results lazily (`chunks={}` gives dask-backed variables).
tag_log = xr.open_dataset(f"{root}/state/{name}.zarr", engine="zarr", chunks={})
acoustic = xr.open_dataset(
    f"{root}/state/{name}-acoustic.zarr", engine="zarr", chunks={}
)

In [None]:
# Inspect the stored tag log + acoustic result.
acoustic

In [None]:
# Side-by-side maps of the state probabilities of both runs. Each title shows the
# `sigma` selected by its study, read back from the stored dataset attributes.
# `rasterize=True` renders the mesh as an aggregated image instead of drawing
# every cell (presumably via datashader).
plot1 = tag_log.states.hvplot.quadmesh(
    x="longitude",
    y="latitude",
    rasterize=True,
    coastline="10m",
    geo=True,
    cmap="cmo.amp",
    title=f"tag log – sigma = {tag_log.attrs['sigma']:.4f}",
).opts(frame_width=500)
plot2 = acoustic.states.hvplot.quadmesh(
    x="longitude",
    y="latitude",
    rasterize=True,
    coastline="10m",
    geo=True,
    cmap="cmo.amp",
    title=f"tag log + acoustic detections – sigma = {acoustic.attrs['sigma']:.4f}",
).opts(frame_width=500)
plot1 + plot2