# State probabilities

In [None]:
import json

import fsspec
import xarray as xr

from pangeo_fish.hmm.estimator import EagerScoreEstimator
from pangeo_fish.pdf import combine_emission_pdf

Parameters (can be overridden when the notebook is executed with [papermill](https://papermill.readthedocs.io/en/latest/)).

In [None]:

# Parameters for papermill: every value below can be overridden when the
# notebook is executed.
# scheduler_address: str | None = None
tag_name: str = "A18832"
working_path: str = "/home/datawork-taos-s/public/fish/"
# Alternative for running on a local machine:
# working_path: str = "/Users/todaka/python/git/pangeo-fish/data_local/fish-intel/"
ref_model_name: str = "copernicus"

nside: int = 4096  # healpix resolution

# Whether to use the acoustic information: "/acoustic" to use it, "" not to.
# The value is appended to the result directory path.
# acoustic: str = ""
acoustic: str = "/acoustic"

# Keep cluster_size at 1 for this step; it is passed to `cluster.scale()`
# when the Dask cluster is set up on Datarmor.
cluster_size: int = 1

In [None]:

domainname=!domainname

if domainname == ["nisdatarmor"]:
    # Datarmor
    tag_base_path = "/home/datawork-lops-iaocea/data/fish-intel/"
    catalog = "/home/datawork-taos-s/intranet/kerchunk/ref-copernicus.yaml"
# This step, we use a local cluster.  
    cluster_name="datarmor"
    #cluster_name="datarmor-local"

else:
    # local PC
    tag_base_path: str = "/Users/todaka/python/git/pangeo-fish/data_local/fish-intel/"
    catalog = "https://data-taos.ifremer.fr/kerchunk/ref-copernicus.yaml"
    cluster_name="local"

tag_url = tag_base_path + "tag/nc/" + tag_name + ".nc"
tag_db_path = tag_base_path + "acoustic/FishIntel_tagging_France.csv"
detections_path = tag_base_path + "/acoustic/detections_recaptured_fishintel.csv"

input_path = working_path + tag_name + "/" + ref_model_name + acoustic + "/emission_"+ str(nside) +".zarr"
parameter_path = working_path + tag_name + "/" + ref_model_name + acoustic + "/sigma_"+ str(nside) +".json"
output_path = working_path + tag_name + "/" + ref_model_name + acoustic + "/state_"+ str(nside) +".json"

### Set up Dask



In [None]:
import dask_hpcconfig
from distributed import Client

# Use the cluster profile selected in the environment cell (`cluster_name`)
# in both branches, so changing it there actually takes effect.
if domainname == ["nisdatarmor"]:
    overrides = {}
    # overrides = {"cluster.cores": 28, "cluster.processes": 6}
    cluster = dask_hpcconfig.cluster(cluster_name, **overrides)
    # Only the Datarmor cluster is explicitly scaled (to `cluster_size`).
    cluster.scale(cluster_size)
else:
    cluster = dask_hpcconfig.cluster(cluster_name)

client = Client(cluster)
client

Open the emission probabilities.

In [None]:
# Open the emission store with dask (`chunks={}` keeps the engine's preferred
# chunking), combine the emission PDFs with `combine_emission_pdf`, and drop
# the `resolution` variable.
raw_emission = xr.open_dataset(
    input_path, engine="zarr", chunks={}, inline_array=True
)
emission = combine_emission_pdf(raw_emission).drop_vars("resolution")
emission

Read the estimator parameters.

In [None]:
# Load the estimator parameters saved as JSON.
with fsspec.open(parameter_path, mode="r") as f:
    parameters = json.load(f)

# `tolerance` is not passed on to the estimator constructor in the next cell
# (presumably it does not accept it — confirm). Drop it if present instead of
# failing with KeyError when the file does not contain it.
parameters.pop("tolerance", None)
parameters

Create the estimator.

In [None]:
# Instantiate the HMM estimator, passing the loaded parameters as keyword
# arguments.
estimator = EagerScoreEstimator(**parameters)
estimator

Compute the state probabilities.

In [None]:
%%time
# Compute the state probabilities from the emission probabilities;
# `%%time` reports how long the cell takes.
states = estimator.predict_proba(emission)
states

Write the state probabilities to disk.

In [None]:
%%time
# Write the state probabilities to a Zarr store at `output_path`.
# mode="w" overwrites an existing store; consolidated=True also writes
# consolidated metadata.
states.to_zarr(output_path, mode="w", consolidated=True)

## Visualization

In [None]:
# Re-open the store that was just written, to inspect what was saved.
states_ = xr.open_zarr(output_path)
states_

In [None]:
# Plot the `states` variable at time index 5 against longitude/latitude.
# NOTE(review): assumes `states` has `longitude`/`latitude` coordinates
# suitable for a 2-D plot and at least 6 time steps — confirm.
states_["states"].isel(time=5).plot(x="longitude", y="latitude")