# optimization using `scipy`

In [None]:
import fsspec
import xarray as xr

In [None]:
import cmocean
import holoviews as hv
import hvplot.pandas
import hvplot.xarray
import matplotlib

# Place holoviews widgets (e.g. sliders for non-spatial dimensions) below the plots.
hv.output(widget_location="bottom")
# Compatibility shim: alias the private `matplotlib.cm._cmap_registry` to
# `matplotlib.cm._colormaps` when it is missing, presumably because newer
# matplotlib releases removed it while code imported above (likely cmocean)
# still reads it — TODO confirm which package needs this.
if not hasattr(matplotlib.cm, "_cmap_registry"):
    matplotlib.cm._cmap_registry = matplotlib.cm._colormaps

In [None]:
# Base URL of the data store: emission inputs are read from `{root}/emission/`
# and state outputs are written to `{root}/state/` in the cells below.
# NOTE(review): machine-specific absolute path — adjust for your environment.
root = "file:///home/jmagin/work/data/fish-intel"

# Identifier used to build every input/output store name below
# (presumably the tag id plus processing parameters — confirm).
name = "A18832-f1_e2500-4096"

## the actual optimization

For the optimization, we use `EagerBoundsSearch`, which searches for the best `sigma` within the given bounds:

In [None]:
import dask
from distributed import LocalCluster

In [None]:
# Start a local dask cluster with a single worker and connect a client to it.
# Both objects are kept as globals so the cluster stays alive for the session;
# the bare `client` expression displays the client summary in the notebook.
cluster = LocalCluster(n_workers=1)
client = cluster.get_client()
client

In [None]:
def fix_encoding(ds):
    """Return a copy of *ds* with the chunk-related encoding entries removed.

    The ``"preferred_chunks"`` and ``"chunks"`` keys are dropped from the
    ``encoding`` of every entry in ``ds.variables``. Variables that do not
    carry one or both keys are left as they are instead of raising.

    Presumably this is done so chunk sizes recorded from the source store do
    not constrain later rechunking or writing — TODO confirm.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to clean. It is copied first (``ds.copy()``); whether the
        original's encodings are fully isolated relies on ``Dataset.copy``
        copying each variable's encoding dict — verify for your xarray version.

    Returns
    -------
    xarray.Dataset
        The copy with the chunk encodings stripped.
    """
    out = ds.copy()

    for var in out.variables.values():
        # Use a default so variables without these keys (e.g. not read from a
        # chunked zarr store) don't raise KeyError.
        var.encoding.pop("preferred_chunks", None)
        var.encoding.pop("chunks", None)

    return out

In [None]:
from pangeo_fish.hmm.estimator import EagerScoreEstimator
from pangeo_fish.hmm.optimize import EagerBoundsSearch
from pangeo_fish.pdf import combine_emission_pdf

## just the tag log

In [None]:
%%time
path = f"{root}/emission/{name}.zarr"

data = (
    xr.open_dataset(path, engine="zarr", chunks={"x": -1, "y": -1}, inline_array=True)
    .pipe(fix_encoding)
    .pipe(combine_emission_pdf)
)
data

In [None]:
# Search sigma between 1e-4 and the `max_sigma` attribute of the emission data.
# `disp`/`xtol` are forwarded to the underlying optimizer (verbosity and
# tolerance on sigma, presumably scipy's bounded search options).
estimator = EagerScoreEstimator()
optimizer = EagerBoundsSearch(
    estimator,
    (1e-4, data.attrs["max_sigma"]),
    optimizer_kwargs=dict(disp=3, xtol=1e-2),
)

optimized = optimizer.fit(data)
optimized

In [None]:
%%time
state_probabilities = (
    optimized.predict_proba(data)
    .pipe(lambda ds: dask.optimize(ds)[0])
    .to_dataset(name="states")
    .assign_attrs(sigma=optimized.sigma)
)
state_probabilities

In [None]:
%%time
outpath = f"{root}/state/{name}-scipy.zarr"
state_probabilities.to_zarr(outpath, mode="w", consolidated=True, compute=True)

## tag log and acoustic detections

In [None]:
%%time
path = f"{root}/emission/{name}-acoustic.zarr"

data = (
    xr.open_dataset(path, engine="zarr", chunks={"x": -1, "y": -1}, inline_array=True)
    .pipe(fix_encoding)
    .pipe(combine_emission_pdf)
)

estimator = EagerScoreEstimator()
optimizer = EagerBoundsSearch(
    estimator,
    (1e-4, data.attrs["sigma_max"]),
    optimizer_kwargs={"disp": 3, "xtol": 1e-2},
)
optimized = optimizer.fit(data)
optimized

In [None]:
%%time
state_probabilities = (
    optimized.predict_proba(data)
    .pipe(lambda ds: dask.optimize(ds)[0])
    .to_dataset(name="states")
    .assign_attrs(sigma=optimized.sigma)
)
state_probabilities

In [None]:
%%time
outpath = f"{root}/state/{name}-acoustic-scipy.zarr"
state_probabilities.to_zarr(outpath, mode="w", consolidated=True, compute=True)

## plot the result

In [None]:
# Reopen both written state stores lazily (dask-backed via `chunks={}`):
# first the tag-log-only result, then the tag log + acoustic result.
tag_log, acoustic = (
    xr.open_dataset(f"{root}/state/{name}{suffix}.zarr", engine="zarr", chunks={})
    for suffix in ("-scipy", "-acoustic-scipy")
)

In [None]:
# Both maps must use identical plotting options to be comparable side by side,
# so the shared options are defined once and only the titles differ.
quadmesh_opts = dict(
    x="longitude",
    y="latitude",
    rasterize=True,
    coastline="10m",
    geo=True,
    cmap="cmo.amp",
)
plot1 = tag_log.states.hvplot.quadmesh(
    title=f"tag log – sigma = {tag_log.attrs['sigma']:.4f}",
    **quadmesh_opts,
)
plot2 = acoustic.states.hvplot.quadmesh(
    title=f"tag log + acoustic detections – sigma = {acoustic.attrs['sigma']:.4f}",
    **quadmesh_opts,
)
plot1 + plot2