# state probabilities

In [None]:
import json

import fsspec
import xarray as xr

from pangeo_fish.hmm.estimator import EagerScoreEstimator
from pangeo_fish.pdf import combine_emission_pdf

parameters of this notebook, injected by [papermill](https://papermill.readthedocs.io/en/latest/)

In [None]:
emission_path: str  # zarr store with the emission probabilities (opened with xarray)
parameter_path: str  # JSON file with the estimator parameters (opened through fsspec)

states_path: str  # zarr store the computed state probabilities are written to

scheduler_address: str | None = None  # dask scheduler to connect to; None starts a LocalCluster

connect to a dask cluster (a local one is started if no scheduler address is given)

In [None]:
from distributed import Client, LocalCluster

# reuse an already running scheduler when an address was given,
# otherwise spin up a cluster on this machine and connect to it
if scheduler_address is not None:
    client = Client(scheduler_address)
else:
    cluster = LocalCluster()
    client = cluster.get_client()
client

open the emission probabilities

In [None]:
# open the zarr store as dask-backed arrays, run it through
# `combine_emission_pdf`, then discard the "resolution" variable
emission_store = xr.open_dataset(
    emission_path,
    engine="zarr",
    chunks={},
    inline_array=True,
)
emission = combine_emission_pdf(emission_store).drop_vars("resolution")
emission

read the parameters

In [None]:
# load the JSON document found at `parameter_path`
with fsspec.open(parameter_path, mode="r") as handle:
    parameters = json.loads(handle.read())

# "tolerance" is not forwarded to the estimator; a missing key is fine
parameters.pop("tolerance", None)
parameters

create the estimator

In [None]:
# every remaining entry of `parameters` is passed as a keyword argument
estimator = EagerScoreEstimator(**parameters)
estimator

compute the state probabilities

In [None]:
%%time
# run the estimator on the combined emission probabilities; the result is
# kept in `states` so it can be written to disk afterwards
states = estimator.predict_proba(emission)
states

write to disk

In [None]:
%%time
# mode="w" replaces any store already present at `states_path`;
# consolidated=True asks for consolidated zarr metadata
states.to_zarr(states_path, mode="w", consolidated=True)

## visualization

In [None]:
import hvplot.xarray
import cmocean

In [None]:
# Reopen the store written by this notebook. The `states_path` papermill
# parameter is used as-is rather than being overwritten with a hard-coded,
# machine-specific path, so the plots always show the result of this run
# and the notebook stays runnable on other machines.
states_ = xr.open_dataset(states_path, engine="zarr", chunks={})
states_

In [None]:
# count the non-missing values over the "x" and "y" dimensions and plot
# the resulting counts
cell_counts = states_["states"].count(dim=["x", "y"])
cell_counts.plot()