### Tutorial: how to use `pangeo-fish`


**Overview.**

This Jupyter notebook demonstrates how to use `pangeo-fish`.

Specifically, we will fit the geolocation on the data from the study conducted by M. Gonze et al. titled "Combining acoustic telemetry with archival tagging to investigate the spatial dynamics of the understudied pollack *Pollachius pollachius*", accepted for publication in the Journal of Fish Biology.

We will use the biologging tag "A19124", which was attached to pollack fish.

As for the reference Earth Observation (EO) data, we consider the European Union Copernicus Marine Service Information (CMEMS) product "NORTHWESTSHELF_ANALYSIS_FORECAST_PHY_004_013".

_NB: In addition to the Data Storage Tag (DST), the biologging data includes **acoustic detections**, as well as the release and recapture/death information of the fish._

Both the reference EO and the biologging data are publicly available, and the computations should be tractable for most standard laptops.

**Workflow.**

Let's first summarize the key steps for running the geolocation:

1. **Define the configuration:** define the required parameters for the analysis.
2. **Compare the reference data with the DST information:** compare the data from the reference model with the biologging data. 
3. **Regrid the comparison to HEALPix:** translate the comparison into a HEALPix grid to avoid spatial distortion.
4. **Construct the temporal emission matrix:** create a temporal emission probability distribution (_pdf_) from the transformed grid.
5. **Construct another emission matrix with the acoustic detections:** build a matrix similar to the previous one, this time from the acoustic detections.
6. **Combine and normalize the matrices:** merge and normalize the two _pdfs_.
7. **Estimate (or _fit_) the geolocation model:** determine the parameters of the model based on the normalized emission matrix.
8. **Compute the state probabilities and generate trajectories:** compute the fish's location probability distribution and generate subsequent trajectories.
9. **Visualization:** visualize the evolution of the spatial probabilities over time and export the video.

Throughout this tutorial, you will gain practical experience in setting up and running a typical `pangeo-fish` workflow, so that you can then apply the tool to your own use case.

## 1. Initialization and configuration definition

In this step, we prepare the execution of the analysis.
It includes:
- Installing the necessary packages.
- Importing the required libraries.
- Defining the parameters for the next stages of the workflow.
- Configuring the cluster for distributed computing.
    

In [None]:
tag_name = "208992_argos"

In [None]:
import sys

import hvplot.xarray
import xarray as xr
from pint_xarray import unit_registry as ureg

sys.path.append("../")
import pangeo_fish

In [None]:
# tag_name is the name of the biologging tag (DST identification number).
# It is also used as the path under which all the data for that tagged fish is stored.

# tag_save = "228035"
# tag_root specifies the root URL for tag data used for this computation.
tag_root = "s3://gfts-ifremer/fra_sodeika/tags/formatted"

# ref_model_file is the path to the reference model; it depends on whether the tag belongs to the 2021 list.

# List of the 2021 tags
tags_2021 = [
    "203226_argos",
    "208991_argos",
    "208992_argos",
    "208992_archival",
    "208993_argos",
]

if tag_name in tags_2021:
    ref_model_file = "~/ECAP_FORK/light_pdf/ecap_jmarc_try/okuyama/tags/copernicus_ref_model/copernicus_jpn_daily_feb_mar_2021.zarr/"
else:
    ref_model_file = "~/ECAP_FORK/light_pdf/ecap_jmarc_try/okuyama/tags/copernicus_ref_model/copernicus_jpn_daily_20220105.zarr/"
# scratch_root specifies the root directory for storing output files.
# storage_options specifies options for the filesystem storing output files.

## example for remote storage
# scratch_root = "s3://gfts-ifremer/fra_sodeika/run/capetienne/papermill_copernicus/same_bbox"
scratch_root = "s3://gfts-ifremer/fra_sodeika/run/capetienne/papermill_copernicus/same_bbox/foscat_daily/variable_std"
storage_options = {
    "anon": False,
    "profile": "gfts",
    "client_kwargs": {
        "endpoint_url": "https://s3.gra.perf.cloud.ovh.net",
        "region_name": "gra",
    },
}
## example for using your local file system instead
# scratch_root = "."
# storage_options = None

# Default chunk value for the time dimension. This value depends on the configuration of your dask cluster.
chunk_time = 1

# Whether to use a HEALPix grid (["cells"]) or a 2D grid (["x", "y"])
dims = ["cells"]

# bbox, bounding box, defines the latitude and longitude range for the analysis area.
# Other candidate areas (not used):
# bbox = {"latitude": [16, 25], "longitude": [117, 130]}
# bbox = {"latitude": [20, 30], "longitude": [117, 135]}
bbox = {"latitude": [16, 35], "longitude": [120, 145]}

# bbox = {"latitude": [20, 35], "longitude": [125, 140]}

# relative_depth_threshold defines the acceptable fish depth relative to the maximum tag depth.
# It determines whether the fish can be considered to be in a certain location based on depth.
relative_depth_threshold = 0

# optional rotation for the HEALPix grid
rot = {"lat": 0, "lon": 0}
# nside defines the resolution of the healpix grid used for regridding.
nside = 1024
refinement_level = 10
# min_vertices sets the minimum number of vertices for a valid transcription for regridding.
min_vertices = 1

# differences_std sets the standard deviation for scipy.stats.norm.pdf.
# It expresses the estimated certainty of the field of difference.
differences_std = 1.0
# initial_std sets the covariance for initial event.
# It shows the certainty of the initial area.
initial_std = 1e-4
# recapture_std sets the covariance for recapture event.
# It shows the certainty of the final recapture area if it is known.
recapture_std = 1e-4
# earth_radius defines the radius of the Earth used for distance calculations.
earth_radius = ureg.Quantity(6371, "km")
# maximum_speed sets the maximum allowable speed for the tagged fish.
maximum_speed = ureg.Quantity(120, "km / day")
# adjustment_factor adjusts parameters for a more fuzzy search.
# It will factor the allowed maximum displacement of the fish.
adjustment_factor = 5
# truncate sets the truncating factor for computed maximum allowed sigma for convolution process.
truncate = 4

# tolerance describes the tolerance level of the search during the fitting/optimization of the geolocation.
# Smaller values will make the optimization iterate more
tolerance = 1e-3 if dims == ["x", "y"] else 1e-4

# track_modes defines the modes for generating fish's trajectories.
track_modes = ["mean", "mode"]

# additional_track_quantities sets the quantities to compute for tracks using movingpandas.
additional_track_quantities = ["speed", "distance"]


# time_step defines the time interval between each frame of the visualization
time_step = 3

In [None]:
# Define target root directories for storing analysis results.
target_root = f"{scratch_root}/{tag_name}"

# Defines default chunk size for optimization.
default_chunk = {"time": chunk_time, "lat": -1, "lon": -1}
default_chunk_dims = {"time": chunk_time}
default_chunk_dims.update({d: -1 for d in dims})

In [None]:
# Set up a local cluster for distributed computing.
from distributed import LocalCluster

cluster = LocalCluster()
client = cluster.get_client()
client

Now that everything is set up, we can start by loading the biologging data (or _tag_).

In [None]:
from pangeo_fish.helpers import load_tag

tag, tag_log, time_slice = load_tag(
    tag_root=tag_root, tag_name=tag_name, storage_options=storage_options
)
tag

You can plot the time series of the DST with the function `plot_tag()`:

In [None]:
from pangeo_fish.helpers import plot_tag

plot = plot_tag(
    tag=tag,
    tag_log=tag_log,
    # you can directly save the plot if you want
    save_html=True,
    storage_options=storage_options,
    target_root=target_root,
)
plot

## 2. Compare the reference data with the DST logs

In this step, we compare the reference model data with Data Storage Tag information.
The process involves reading and cleaning the reference model, aligning time, converting depth units and subtracting the tag data from the model.
We also illustrate how to plot and save the result.
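To make the subtraction concrete, here is a minimal sketch in plain Python. It is not how `pangeo-fish` implements the comparison: the helper `interp_profile` and all sample values are hypothetical. It interpolates a model temperature profile at the depth recorded by the tag, then subtracts the tag temperature from the model value.

```python
from bisect import bisect_left


def interp_profile(depths, temps, depth):
    """Linearly interpolate a temperature profile at `depth` (hypothetical helper).

    `depths` must be sorted in increasing order; depths outside the profile
    are clamped to the nearest end.
    """
    if depth <= depths[0]:
        return temps[0]
    if depth >= depths[-1]:
        return temps[-1]
    i = bisect_left(depths, depth)
    d0, d1 = depths[i - 1], depths[i]
    t0, t1 = temps[i - 1], temps[i]
    return t0 + (t1 - t0) * (depth - d0) / (d1 - d0)


# Made-up model profile for one grid cell and one time step.
model_depths = [0.0, 10.0, 50.0, 100.0]  # metres
model_temps = [20.0, 19.0, 15.0, 12.0]  # degrees Celsius

# Made-up tag observation at the same time step.
tag_depth, tag_temp = 30.0, 16.5

model_at_tag_depth = interp_profile(model_depths, model_temps, tag_depth)
difference = model_at_tag_depth - tag_temp  # model minus tag
print(model_at_tag_depth, difference)
```

A difference close to 0 means the cell is compatible with what the tag measured; the notebook computes this for every cell and time step.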

In [None]:
import xarray as xr
from pangeo_fish.helpers import _open_parquet_model, load_model, prepare_dataset

ref_ds = xr.open_dataset(
    ref_model_file,
    engine="zarr",
    chunks={},
    storage_options=None,
)


model = prepare_dataset(ref_ds)
model

In [None]:
from pangeo_fish.cf import bounds_to_bins
from pangeo_fish.tags import adapt_model_time, reshape_by_bins, to_time_slice

reference_model = (
    model.sel(time=adapt_model_time(time_slice))
    .sel(lat=slice(*bbox["latitude"]), lon=slice(*bbox["longitude"]))
    .pipe(
        lambda ds: ds.sel(
            depth=slice(None, (tag_log["pressure"].max() - ds["XE"].min()).compute())
        )
    )
).chunk({"time": chunk_time, "lat": -1, "lon": -1, "depth": -1})
reference_model

In [None]:
import pandas as pd
from pangeo_fish.cf import bounds_to_bins
from pangeo_fish.tags import adapt_model_time, reshape_by_bins, to_time_slice

start = pd.Timestamp(tag_log["time"].min().item())
end = pd.Timestamp(tag_log["time"].max().item())

extended_time_slice = slice(start - pd.Timedelta("1D"), end)

reference_model = (
    model.sel(time=adapt_model_time(extended_time_slice))
    .sel(lat=slice(*bbox["latitude"]), lon=slice(*bbox["longitude"]))
    .pipe(
        lambda ds: ds.sel(
            depth=slice(None, (tag_log["pressure"].max() - ds["XE"].min()).compute())
        )
    )
).chunk({"time": chunk_time, "lat": -1, "lon": -1, "depth": -1})

reference_model

In [None]:
%%time
# Reshape the tag log, so that it bins to the time step of reference_model
reshaped_tag = reshape_by_bins(
    tag_log,
    dim="time",
    bins=(
        reference_model.cf.add_bounds(["time"], output_dim="bounds")
        .pipe(bounds_to_bins, bounds_dim="bounds")
        .get("time_bins")
    ),
    other_dim="obs",
).chunk({"time": chunk_time})

reshaped_tag
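The idea behind the binning can be sketched without `xarray`. The snippet below is an illustration only: `bin_observations` is a hypothetical helper, not part of `pangeo-fish`, and the half-open interval convention is an assumption. It assigns each tag observation to the model time step whose interval contains it.

```python
from bisect import bisect_right
from datetime import datetime, timedelta


def bin_observations(bin_edges, obs_times, obs_values):
    """Group observations by the half-open interval [edge_i, edge_i+1) they fall in."""
    bins = [[] for _ in range(len(bin_edges) - 1)]
    for t, value in zip(obs_times, obs_values):
        i = bisect_right(bin_edges, t) - 1
        if 0 <= i < len(bins):  # observations outside all bins are dropped
            bins[i].append(value)
    return bins


# Daily bin edges (made-up dates) and six-hourly tag temperatures.
start = datetime(2021, 2, 1)
edges = [start + timedelta(days=d) for d in range(3)]  # two daily bins
times = [start + timedelta(hours=6 * k) for k in range(8)]
temps = [15.0, 15.2, 15.1, 14.9, 14.0, 14.2, 14.1, 13.9]

binned = bin_observations(edges, times, temps)
print([len(b) for b in binned])
```

Each model time step then carries all the tag observations recorded during it, which is what the `obs` dimension holds in `reshaped_tag`.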

#### VARIANCE

In [None]:
import numpy as np
import xarray as xr


def compute_hist2d(results, nbins_diff=50, nbins_depth=1000, max_depth=None):
    all_diff = np.concatenate([r["diff"] for r in results]) if results else np.array([])
    all_pres = np.concatenate([r["pres"] for r in results]) if results else np.array([])

    mask = np.isfinite(all_diff) & np.isfinite(all_pres)
    all_diff, all_pres = all_diff[mask], all_pres[mask]

    if all_diff.size == 0:
        return None, None, None

    # Depth limit
    if max_depth is not None:
        depth_mask = all_pres <= max_depth
        all_diff, all_pres = all_diff[depth_mask], all_pres[depth_mask]

    dmin, dmax = np.nanpercentile(all_diff, [1, 99])
    dmax_abs = max(abs(dmin), abs(dmax))
    diff_bins = np.linspace(-dmax_abs, dmax_abs, nbins_diff + 1)
    depth_bins = np.linspace(
        0,
        max_depth if max_depth is not None else np.nanpercentile(all_pres, 98),
        nbins_depth + 1,
    )

    counts, xedges, yedges = np.histogram2d(
        all_diff, all_pres, bins=[diff_bins, depth_bins]
    )
    return counts, xedges, yedges


def variance_by_depth(counts, xedges):
    """
    Computes the weighted variance of ΔT for each row (depth)
    from the 2D histogram.
    """
    xcenters = 0.5 * (xedges[:-1] + xedges[1:])

    weights_sum = counts.sum(axis=0)

    weights_sum[weights_sum == 0] = np.nan

    mean = np.nansum(counts * xcenters[:, None], axis=0) / weights_sum

    var = np.nansum(counts * (xcenters[:, None] - mean) ** 2, axis=0) / weights_sum

    var_norm = var / np.nanmax(var)

    return var, var_norm
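As a small worked example of what `variance_by_depth` computes for a single depth bin, the sketch below uses plain Python and made-up counts. The helper `weighted_variance` is hypothetical and only mirrors the formula: the bin centers of the temperature difference are weighted by the histogram counts.

```python
def weighted_variance(counts, centers):
    """Variance of `centers` weighted by `counts`; None if the depth bin is empty."""
    total = sum(counts)
    if total == 0:
        return None
    mean = sum(c * x for c, x in zip(counts, centers)) / total
    return sum(c * (x - mean) ** 2 for c, x in zip(counts, centers)) / total


# Made-up temperature-difference bin centers (degrees Celsius).
centers = [-1.0, 0.0, 1.0]

# One column of counts per depth bin.
counts_shallow = [1, 2, 1]  # differences spread around 0
counts_deep = [0, 4, 0]  # every difference falls in the central bin

print(weighted_variance(counts_shallow, centers))
print(weighted_variance(counts_deep, centers))
```

The shallow bin has the larger variance, so the model is trusted less there than in the deep bin.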

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt


with open("../comp_argos_model_files/results_copernicus.pkl", "rb") as f:
    results_copernicus = pickle.load(f)


counts_j, xedges, yedges = compute_hist2d(results_copernicus, max_depth=2000)

var_j, var_j_norm = variance_by_depth(counts_j, xedges)

In [None]:
from scipy.interpolate import interp1d
import numpy as np


def variance_interp_function(yedges, var):
    depth_centers = 0.5 * (yedges[:-1] + yedges[1:])
    f_var = interp1d(
        depth_centers, var, bounds_error=False, fill_value=(var[0], var[-1])
    )
    return f_var


def std_interp_function(yedges, var):
    depth_centers = 0.5 * (yedges[:-1] + yedges[1:])

    std = np.sqrt(var)

    f_std = interp1d(
        depth_centers, std, bounds_error=False, fill_value=(std[0], std[-1])
    )

    return f_std


def temperature_interp_function(model_depth, model_temp):

    f_temp = interp1d(
        model_depth,
        model_temp,
        bounds_error=False,
        fill_value=(model_temp[0], model_temp[-1]),
    )
    return f_temp

In [None]:
yedges

In [None]:
f_var = variance_interp_function(yedges, var_j)
f_std = std_interp_function(yedges, var_j)

In [None]:
len(var_j)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# f_var is already defined via variance_interp_function(yedges, var_j)
depths = np.linspace(0, 2000, 500)
var_profile = f_var(depths)

plt.figure(figsize=(6, 4))
plt.plot(var_profile, depths, color="tab:blue")
plt.gca().invert_yaxis()
plt.xlabel("Interpolated variance")
plt.ylabel("Depth (m)")
plt.title("Interpolated variance profile by depth")
plt.grid(True, linestyle="--", alpha=0.4)
plt.tight_layout()
plt.show()

In [None]:
from pangeo_fish.diff import diff_z_var

diff = diff_z_var(
    reference_model,
    reshaped_tag,
    var_depth=(0.5 * (yedges[:-1] + yedges[1:])),
    var_values=var_j,
).assign(
    {
        "H0": reference_model["H0"],
        "XE": reference_model["XE"],
        "ocean_mask": reference_model["H0"].notnull(),
    }
)

In [None]:
%%time
diff = diff.compute()

In [None]:
diff

### end test 

In [None]:
diff["diff"].count(["lat", "lon"]).plot()
diff

In [None]:
diff.compute().to_zarr(
    f"{target_root}/diff.zarr",
    mode="w",
    storage_options=storage_options,
    zarr_version=2,
)

del diff

## 3. HEALPix regridding

In this step, we regrid the data from above to HEALPix coordinates. 

This is a complex process made of several steps, such as defining the HEALPix grid, creating the target grid and computing the interpolation weights.

Fortunately, `pangeo-fish` provides high-level functions that do the work for us!
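Before running the regridding, it can help to see how the two resolution parameters of the configuration relate. The sketch below uses the standard HEALPix relations (`nside = 2**level` and `12 * nside**2` cells on the sphere); the function names are hypothetical and the cell size is only a rough estimate taken as the square root of the cell area.

```python
import math


def healpix_nside(level):
    """nside of a HEALPix grid at the given refinement level."""
    return 2**level


def healpix_npix(nside):
    """Total number of equal-area cells on the sphere."""
    return 12 * nside**2


def approx_cell_size_km(nside, earth_radius_km=6371.0):
    """Square root of the cell area, as a rough cell width in kilometres."""
    sphere_area = 4.0 * math.pi * earth_radius_km**2
    return math.sqrt(sphere_area / healpix_npix(nside))


level = 10
nside = healpix_nside(level)
print(nside, healpix_npix(nside), round(approx_cell_size_km(nside), 2))
```

With `level = 10` this gives `nside = 1024`, which matches the `refinement_level` and `nside` values set in the configuration.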

In [None]:
from pangeo_fish.helpers import open_diff_dataset, regrid_dataset

# Open the previous dataset (only necessary if you resume the notebook from here)
diff = open_diff_dataset(target_root=target_root, storage_options=storage_options)
diff

In [None]:
reshaped = regrid_dataset(
    ds=diff, refinement_level=10, min_vertices=min_vertices, rot=rot, dims=dims
)[0]
reshaped

Let's plot the same chart as before to check that the HEALPix regridding hasn't changed the data.

In [None]:
reshaped["diff"].count(dims).plot()

In [None]:
reshaped["diff"].compute().dggs.decode(
    {"grid_name": "healpix", "level": 10, "indexing_scheme": "nested"}
).dggs.explore(alpha=0.8)

In [None]:
# Saves the result
reshaped.chunk(default_chunk_dims).to_zarr(
    f"{target_root}/diff-regridded.zarr",
    mode="w",
    consolidated=True,
    compute=True,
    storage_options=storage_options,
    zarr_version=2,
)
del reshaped

## 4. Compute the emission probability distribution

In this step, we use the comparison result from the step above to construct the emission probability matrix.

This comparison is essentially the difference between the temperature measured by the tag and the reference sea temperature.

The emission probability matrix represents the likelihood of observing a specific temperature difference given the model parameters and configurations.
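As an illustration of the principle, the sketch below turns temperature differences into probabilities with a zero-mean normal density whose standard deviation plays the role of `differences_std`. It is a plain-Python stand-in, not the `pangeo-fish` implementation: the function names are hypothetical, and the normalization over cells is done here only to keep the example self-contained.

```python
import math


def normal_pdf(x, std):
    """Density of a zero-mean normal distribution with standard deviation `std`."""
    return math.exp(-0.5 * (x / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


def emission_from_differences(differences, std):
    """Turn per-cell temperature differences into probabilities that sum to 1.

    Cells with no data (None, e.g. land) get probability 0.
    """
    weights = [0.0 if d is None else normal_pdf(d, std) for d in differences]
    total = sum(weights)
    if total == 0.0:
        raise ValueError("no valid cell at this time step")
    return [w / total for w in weights]


# Made-up differences (degrees Celsius) for five cells at one time step.
differences = [0.1, -0.5, 2.0, None, 0.0]
pdf = emission_from_differences(differences, std=1.0)
print([round(p, 3) for p in pdf])
```

The cell with a difference of 0 gets the highest probability, and a smaller `std` would concentrate the probability even more on the best-matching cells.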

In [None]:
from pangeo_fish.helpers import compute_emission_pdf

In [None]:
# Open the previous dataset (only necessary if you resume the notebook from here)
differences = xr.open_dataset(
    f"{target_root}/diff-regridded.zarr",
    engine="zarr",
    chunks={},
    storage_options=storage_options,
).pipe(lambda ds: ds.merge(ds[["latitude", "longitude"]].compute()))
# ... and compute the emission matrices
emission_pdf = compute_emission_pdf(
    diff_ds=differences,
    events_ds=tag["tagging_events"].ds,
    differences_std=differences_std,
    initial_std=initial_std,
    recapture_std=recapture_std,
    dims=dims,
    chunk_time=chunk_time,
)[0]
emission_pdf

In [None]:
emission_pdf.pdf.compute().dggs.decode(
    {"grid_name": "healpix", "level": 10, "indexing_scheme": "nested"}
).dggs.explore(alpha=0.8)

Whatever the temporal distributions look like, they must **never** sum to 0 at _any time step_.

How can we check that visually? You may have guessed it by now: the same way as before!

In [None]:
emission_pdf = emission_pdf.chunk(default_chunk_dims).persist()
emission_pdf["pdf"].count(dims).plot()
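A numerical counterpart to this visual check could look like the sketch below. It works on nested lists instead of `xarray` objects, and `invalid_time_steps` is a hypothetical helper: it returns the indices of the time steps whose probabilities sum to (almost) zero once NaN cells are ignored.

```python
def invalid_time_steps(pdf, tol=1e-12):
    """Indices of time steps whose probabilities sum to at most `tol`, ignoring NaN cells."""
    bad = []
    for t, row in enumerate(pdf):
        total = sum(p for p in row if p == p)  # `p == p` is False for NaN
        if total <= tol:
            bad.append(t)
    return bad


nan = float("nan")
# Made-up distribution: 3 time steps x 4 cells, the third cell being land.
pdf = [
    [0.2, 0.8, nan, 0.0],
    [0.0, 0.0, nan, 0.0],  # sums to 0: no cell is compatible with the tag
    [0.5, 0.25, nan, 0.25],
]
print(invalid_time_steps(pdf))
```

An empty list means every time step has at least some probability mass.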

In [None]:
# Save the dataset
emission_pdf.compute().to_zarr(
    f"{target_root}/emission.zarr",
    mode="w",
    consolidated=True,
    storage_options=storage_options,
    zarr_version=2,
)

## 5. Compute and add bathy pdf

In [None]:
import fsspec
from pangeo_fish.bathy import batch_compute_pdf_bathy
from pangeo_fish.cf import bounds_to_bins
from pangeo_fish.tags import adapt_model_time, reshape_by_bins, to_time_slice

# Target path
zarr_path = f"s3://gfts-ifremer/fra_sodeika/run/capetienne/papermill_copernicus/same_bbox/foscat_daily/bathy_pdf_{tag_name}.zarr"

# Open a filesystem compatible with storage_options (e.g. S3)
fs = fsspec.filesystem("s3", **storage_options)
# Check whether the Zarr store already exists
bool_bathy = fs.exists(zarr_path)
if bool_bathy:
    print(f"⚠️ {zarr_path} already exists; skipping the computation.")
else:
    print("✅ No file found, starting the computation...")


if bool_bathy:
    bathy_pdf = xr.open_dataset(
        f"s3://gfts-ifremer/fra_sodeika/run/capetienne/papermill_copernicus/same_bbox/foscat_daily/bathy_pdf_{tag_name}.zarr",
        engine="zarr",
        chunks={},
        storage_options=storage_options,
    )
else:
    import healpy as hp
    from pangeo_fish.bathy import (
        compute_fish_histogram_bin_size,
        compute_healpix_histogram_region_bin_size,
    )

    # Open the previous dataset (only necessary if you resume the notebook from here)
    full_bathy = xr.open_dataset(
        "s3://gfts-reference-data/gebco_2024_new.zarr",
        engine="zarr",
        chunks={},
        storage_options=storage_options,
    ).rename({"lat": "latitude", "lon": "longitude"})

    subset_bathy = full_bathy.sel(
        {dim: slice(bounds[0], bounds[1]) for dim, bounds in bbox.items()}
    )
    subset_bathy
    import numpy as np

    refinement_level = refinement_level

    ds_histo = compute_healpix_histogram_region_bin_size(
        subset_bathy,
        nside=nside,
        max_depth_m=1800,  # maximum depth to consider, in metres
        depth_bin_size=16,  # width of a depth bin, in metres
    )

    hist_ids = ds_histo.cell_ids.values  # cell ids in ds_histo
    pdf_ids = emission_pdf.cell_ids.values  # cell ids in emission_pdf

    common = np.intersect1d(hist_ids, pdf_ids)
    # Keep only the cells shared with emission_pdf
    mask = np.isin(hist_ids, common)
    ds_histo = ds_histo.isel(cells=np.where(mask)[0])

    # Reshape the tag log, so that it bins to the time step of reference_model
    reshaped_tag = reshape_by_bins(
        tag_log,
        dim="time",
        bins=(
            reference_model.cf.add_bounds(["time"], output_dim="bounds")
            .pipe(bounds_to_bins, bounds_dim="bounds")
            .get("time_bins")
        ),
        other_dim="obs",
    ).chunk({"time": chunk_time})

    fish_hist = compute_fish_histogram_bin_size(
        reshaped_tag, depth_max=1800, depth_bin_size=16
    )

    pdf_da_func = batch_compute_pdf_bathy(
        ds_histo,
        reshaped_tag,
        target_root,
        batch_size=5000,
    )
    sum_over_cells = pdf_da_func.sum(dim="cells", skipna=True)  # shape (time,)
    # Normalising
    bathy_pdf = pdf_da_func / sum_over_cells

    bathy_pdf.compute().to_zarr(
        f"s3://gfts-ifremer/fra_sodeika/run/capetienne/papermill_copernicus/same_bbox/foscat_daily/bathy_pdf_{tag_name}.zarr",
        compute=True,
        mode="w",
        consolidated=True,
        zarr_version=2,
        storage_options=storage_options,
    )
    bathy_pdf = xr.open_dataset(
        f"s3://gfts-ifremer/fra_sodeika/run/capetienne/papermill_copernicus/same_bbox/foscat_daily/bathy_pdf_{tag_name}.zarr",
        engine="zarr",
        chunks={},
        storage_options=storage_options,
    )

In [None]:
import numpy as np
import xarray as xr


def normalize_pdf_by_mask(ds, mask_var="mask", pdf_var="pdf", tol=1e-12):
    mask_cells = ds[mask_var].astype(bool)

    mask_time = mask_cells.expand_dims(time=ds["time"]).transpose("time", "cells")

    n_valid_cells = int(mask_cells.sum().compute().item())
    if n_valid_cells == 0:
        raise ValueError(
            "The mask indicates 0 valid (ocean) cells. Cannot normalize."
        )
    sums_by_time = ds[pdf_var].where(mask_time).fillna(0).sum(dim="cells")
    sums_vals = sums_by_time.compute()
    to_fix = (sums_vals <= tol) | np.isnan(sums_vals)
    idxs = np.where(to_fix.values)[0]
    times = [str(t) for t in ds["time"].isel(time=idxs).values] if idxs.size else []
    print(
        f"{int(to_fix.sum().item())} invalid time steps. Indices: {idxs.tolist()}. Times: {times}"
    )
    if idxs.size == 0:
        return ds
    fill_per_time = xr.where(to_fix, 1.0 / float(n_valid_cells), np.nan)
    fill_per_time = xr.DataArray(
        fill_per_time, coords={"time": ds["time"]}, dims=["time"]
    )
    replacement = xr.where(mask_time, fill_per_time, np.nan)
    ds_fixed = ds.copy()
    ds_fixed[pdf_var] = xr.where(to_fix, replacement, ds[pdf_var])
    return ds_fixed

In [None]:
bathy_pdf["mask"] = emission_pdf["mask"]
bathy_pdf_corrected = normalize_pdf_by_mask(bathy_pdf, pdf_var="pdf_bathy")

In [None]:
emission_pdf_corrected = normalize_pdf_by_mask(emission_pdf)

In [None]:
emission_with_bathy = emission_pdf_corrected.merge(
    bathy_pdf_corrected, compat="override"
)
emission_with_bathy

In [None]:
emission_with_bathy["pdf"].count(dims).plot()

In [None]:
from pangeo_fish.helpers import normalize_pdf

combined_diff_bathy = normalize_pdf(
    ds=emission_with_bathy,
    chunks=default_chunk_dims,
    dims=dims,
)[0]
combined_diff_bathy

In [None]:
combined_diff_bathy["pdf"].count(dims).plot()

In [None]:
combined_diff_bathy.compute().to_zarr(
    f"{target_root}/emission_w_bathy_pdf_{tag_name}.zarr",
    compute=True,
    mode="w",
    consolidated=True,
    zarr_version=2,
    storage_options=storage_options,
)
del emission_pdf, combined_diff_bathy

## 7. Estimate the model's parameters

It is now time to determine the parameters of the model based on the normalized emission matrix.

Specifically, it consists of finding the best `sigma`, which corresponds to the standard deviation of the Brownian motion that models the fish's movement between time steps.

To do so, in the following we:
1. Define the lower and upper bounds for `sigma`.  
2. Search for the best `sigma` with `optimize_pdf()`.
3. Save the results of the search (i.e., `sigma`), along with any additional parameters used during the optimization, to a human-readable `.json` file.
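To get an intuition for what `sigma` controls, the sketch below applies one Brownian-motion step on a 1D grid by convolving a position distribution with a truncated Gaussian kernel. This is a simplification under stated assumptions: the real computation runs on the HEALPix grid, the functions `gaussian_kernel` and `diffuse` are hypothetical, and `truncate` is assumed to play the same role as the configuration parameter of that name.

```python
import math


def gaussian_kernel(sigma, truncate=4.0):
    """Discrete Gaussian kernel, cut off at `truncate * sigma` cells and normalized to sum to 1."""
    radius = int(truncate * sigma + 0.5)
    weights = [math.exp(-0.5 * (k / sigma) ** 2) for k in range(-radius, radius + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def diffuse(prob, kernel):
    """One Brownian-motion step on a 1D grid: convolve `prob` with `kernel`.

    Mass that would leave the grid is dropped, so keep the mass away from the edges.
    """
    radius = len(kernel) // 2
    out = [0.0] * len(prob)
    for i, p in enumerate(prob):
        for k, w in enumerate(kernel):
            j = i + k - radius
            if 0 <= j < len(out):
                out[j] += p * w
    return out


# The fish is known to be in cell 10 of a 21-cell grid.
prior = [0.0] * 21
prior[10] = 1.0

narrow = diffuse(prior, gaussian_kernel(sigma=0.5))
wide = diffuse(prior, gaussian_kernel(sigma=2.0))
print(round(narrow[10], 3), round(wide[10], 3))
```

A small `sigma` keeps the fish close to its previous position, while a large one lets it spread over many cells in a single time step. The optimization searches for the value that best explains the emission probabilities.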

In [None]:
from pangeo_fish.helpers import optimize_pdf

In [None]:
# Open the distributions
emission = xr.open_dataset(
    f"{target_root}/emission_w_bathy_pdf_{tag_name}.zarr",
    engine="zarr",
    chunks=default_chunk_dims,
    inline_array=True,
    storage_options=storage_options,
)
emission

In [None]:
emission.pdf.compute().dggs.decode(
    {"grid_name": "healpix", "level": 10, "indexing_scheme": "nested"}
).dggs.explore(alpha=0.8)

In [None]:
emission = emission.dggs.decode(
    {"grid_name": "healpix", "level": 10, "indexing_scheme": "nested"}
)

In [None]:
%%time
# Define the parameter's bounds and search for the best value
params = optimize_pdf(
    ds=emission.compute(),
    earth_radius=earth_radius,
    adjustment_factor=adjustment_factor,
    truncate=truncate,
    maximum_speed=maximum_speed,
    tolerance=tolerance,
    dims=dims,
    # the results can be directly saved
    save_parameters=True,
    storage_options=storage_options,
    target_root=target_root,
)
params

In [None]:
target_root

## 8. State probabilities and Trajectories

In this second to last step, we calculate the spatial probability distribution (based on the `sigma` found earlier) and further compute trajectories.

_NB: the computation precisely relies on `sigma` and the combined emission pdf._
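The principle can be sketched as the forward pass of a hidden Markov model on a tiny 1D grid. This is an illustration under assumptions, not the code behind `predict_positions`: the function names and all numbers are made up, a fixed three-cell kernel stands in for the Gaussian kernel of width `sigma`, and only the filtering pass is shown. Each step spreads the previous distribution (movement), multiplies it by the emission probabilities (observation) and renormalizes.

```python
def predict(prob, kernel):
    """Spread probability to neighbouring cells (movement model), then renormalize."""
    radius = len(kernel) // 2
    out = [0.0] * len(prob)
    for i, p in enumerate(prob):
        for k, w in enumerate(kernel):
            j = i + k - radius
            if 0 <= j < len(out):
                out[j] += p * w
    total = sum(out)
    return [v / total for v in out]


def update(prob, emission):
    """Weight the prediction by the emission probabilities and renormalize."""
    posterior = [p * e for p, e in zip(prob, emission)]
    total = sum(posterior)
    if total == 0.0:
        raise ValueError("emission and prediction do not overlap")
    return [v / total for v in posterior]


def forward_filter(initial, emissions, kernel):
    """Return the filtered state probabilities, starting with `initial`."""
    states = [initial]
    for emission in emissions:
        states.append(update(predict(states[-1], kernel), emission))
    return states


kernel = [0.25, 0.5, 0.25]  # stand-in for a Gaussian kernel of width sigma
initial = [1.0, 0.0, 0.0, 0.0, 0.0]  # release position known exactly
emissions = [
    [0.1, 0.6, 0.3, 0.0, 0.0],
    [0.0, 0.1, 0.6, 0.3, 0.0],
]
states = forward_filter(initial, emissions, kernel)
# A "mode" track keeps the most probable cell at each time step.
track_mode = [max(range(len(s)), key=s.__getitem__) for s in states]
print(track_mode)
```

The `ValueError` raised in `update` is the reason why the emission distributions must never sum to 0 where the fish can be: the state probabilities would no longer be defined.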

In [None]:
from pangeo_fish.helpers import predict_positions

In [None]:
emission = xr.open_dataset(
    f"{target_root}/emission_w_bathy_pdf_{tag_name}.zarr",
    engine="zarr",
    chunks=default_chunk_dims,
    inline_array=True,
    storage_options=storage_options,
)
emission

In [None]:
%%time
states, trajectories = predict_positions(
    ds=emission.dggs.decode(
        {"grid_name": "healpix", "level": 10, "indexing_scheme": "nested"}
    ),
    target_root=target_root,
    storage_options=storage_options,
    chunks=default_chunk_dims,
    track_modes=track_modes,
    additional_track_quantities=additional_track_quantities,
    save=True,
    tag_name=tag_name,
)

In [None]:
states.states.compute().dggs.decode(
    {"grid_name": "healpix", "level": 10, "indexing_scheme": "nested"}
).dggs.explore(alpha=0.8)

In [None]:
states.compute().to_zarr(
    f"{target_root}/states.zarr",
    compute=True,
    mode="w",
    consolidated=True,
    zarr_version=2,
    storage_options=storage_options,
)