# **Example Usage of Pangeo-Fish Software**


**Overview:**
This Jupyter notebook demonstrates the usage of the Pangeo-Fish software, a tool designed for analyzing biologging data in reference to Earth Observation (EO) data.

The biologging data consist of Data Storage Tag (DST) records, along with the release and recapture times and locations of the species in question. Both the biologging data and the reference EO data are accessible over HTTPS, and the access methods are incorporated in this notebook.

**Purpose:**
By executing this notebook, users will learn how to set up a workflow for utilizing the Pangeo-Fish software. The workflow consists of 9 steps which are described below:

1. **Configure the Notebook:** Prepare the notebook environment for analysis.
2. **Compare Reference Model with DST Information:** Analyze and compare data from the reference model with information from the biologging data of the species in question. 
3. **Regrid the Grid from Reference Model Grid to Healpix Grid:** Transform the grid from the reference model to the Healpix grid for further analysis.
4. **Construct Emission Matrix:** Create an emission matrix based on the transformed grid.
5. **Replace Emission for Flagged Tags:** If a tag is flagged for warm-plume exposure, use its associated detection file to replace the emission matrix at the flagged timestamps.
6. **Combine and Normalize Emission Matrix:** Merge the emission matrix and normalize it for further processing.
7. **Estimate Model Parameters:** Determine the parameters of the model based on the normalized emission matrix.
8. **Compute State Probabilities and Tracks:** Calculate the probability distribution of the species in question and compute the tracks.
9. **Visualization:** Visualize the results of the analysis for interpretation and insight.

Throughout this notebook, users will gain practical experience in setting up and executing a workflow using Pangeo-Fish, enabling them to apply similar methodologies to their own biologging data analysis tasks.



## 1. **Configure the Notebook:** Prepare the notebook environment for analysis.

In this step, we set up the notebook environment for analysis. This includes installing the necessary packages, importing the required libraries, setting up parameters, and configuring the cluster for distributed computing. It also retrieves the tag data needed for analysis.

    

In [None]:
# Import necessary libraries and modules.
import hvplot.xarray
import intake
import pandas as pd
import xarray as xr
from pint_xarray import unit_registry as ureg

from pangeo_fish.io import open_tag

In [None]:
#
# Set up execution parameters for the analysis.
#
# Note: This cell is tagged as `parameters`, so its values can be overridden automatically when the notebook is executed with papermill.

# tag_name is the biologging tag identifier (DST identification number).
# It is also used as the directory name for storing all outputs for the fish tagged with tag_name.
# tag_name = "AD_A11849"
# tag_name = "SV_A11957"


tag_list = [
    "NO_A12710",
    "CB_A11036",
    "LT_A11385",
    "SQ_A10684",
    "AD_A11177",
    "PB_A12063",
    "NO_A12742",
    "DK_A10642",
    "CB_A11071",
]
# Select a tag from tag_list; the following line overrides it with a specific tag.
tag_name = tag_list[8]
tag_name = "LT_A11338"

cloud_root = "s3://gfts-ifremer/tags/bargip"

# tag_root specifies the root URL for tag data used for this computation.
tag_root = f"{cloud_root}/cleaned"

# catalog_url specifies the URL for the catalog for reference data used.
catalog_url = "s3://gfts-ifremer/copernicus_catalogs/master.yml"

# scratch_root specifies the root directory for storing output files.
scratch_root = f"{cloud_root}/tracks"


# storage_options specifies options for the filesystem storing output files.
storage_options = {
    "anon": False,
    # 'profile' : "gfts",
    "client_kwargs": {
        "endpoint_url": "https://s3.gra.perf.cloud.ovh.net",
        "region_name": "gra",
    },
}

# If you are using a local file system, keep the following three lines;
# comment them out to write the outputs to the S3 bucket defined above instead.
folder_name = "../toto"
storage_options = None
scratch_root = f"/home/jovyan/notebooks/papermill/{folder_name}"

# Default chunk size along the time dimension. This value depends on the configuration of your Dask cluster.
chunk_time = 24

#
# Parameters for step 2. **Compare Reference Model with DST Information:**
#
# bbox, bounding box, defines the latitude and longitude range for the analysis area.
bbox = {"latitude": [40, 56], "longitude": [-13, 5]}

# relative_depth_threshold defines the acceptable fish depth relative to the maximum tag depth.
# It determines whether the fish can be considered to be in a certain location based on depth.
relative_depth_threshold = 0.8

#
# Parameters for step 3. **Regrid the Grid from Reference Model Grid to Healpix Grid:**
#
# distance_filepath is the path to the distance-to-coast file.
distance_filepath = "s3://gfts-ifremer/tags/distance2coast.zarr"

# distance_scale_factor scales the squared distance in the exponential decay function.
distance_scale_factor = 0.01

# nside defines the resolution of the HEALPix grid used for regridding; it must be a power of 2.
nside = 4096  # *2

# rot defines the rotation angles for the healpix grid.
rot = {"lat": 0, "lon": 30}

# min_vertices sets the minimum number of valid source vertices required to interpolate a target cell during regridding.
min_vertices = 1

#
# Parameters for step 4. **Construct Emission Matrix:**
#
# differences_std sets the standard deviation for scipy.stats.norm.pdf.
# It expresses the estimated uncertainty of the temperature difference field.
differences_std = 0.75

# recapture_std sets the standard deviation of the recapture position (its square is used as the covariance).
# It expresses the certainty of the final recapture area, if it is known.
recapture_std = 1e-2

# earth_radius defines the radius of the Earth used for distance calculations.
earth_radius = ureg.Quantity(6371, "km")

# maximum_speed sets the maximum allowable speed for the tagged fish.
maximum_speed = ureg.Quantity(60, "km / day")

# adjustment_factor makes the search fuzzier by multiplying
# the maximum allowed displacement of the fish.
adjustment_factor = 5

# truncate is the factor by which the maximum grid displacement is divided
# to obtain the maximum allowed sigma of the convolution kernel.
truncate = 4

#
# Parameters for step 5. **Replace Emission for Flagged Tags:**
#


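# buffer_size sets the size of the buffer around each power plant that represents its warm plume.
buffer_size = ureg.Quantity(1000, "m")
# powerplant_flag is a boolean stating whether the fish has swum in a warm plume;
# it is set further below from the warm-plume flag file.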


#
# Parameters for step 7. **Estimate Model Parameters:**
#
# tolerance sets the tolerance level for the optimized parameter search.
tolerance = 1e-3

#
# Parameters for step 8. **Compute State Probabilities and Tracks:**
#
# track_modes defines the modes for track calculation.
track_modes = ["mean", "mode"]

# additional_track_quantities sets the quantities to compute for the tracks using MovingPandas.
additional_track_quantities = ["speed", "distance"]


#
# Parameters for step 9. **Visualization:**
#
# time_step defines the interval, in time steps, at which the state and emission matrices are visualized.
time_step = 3


# Define target root directories for storing analysis results.
target_root = f"{scratch_root}/{tag_name}"

# Defines default chunk size for optimisation.
default_chunk = {"time": chunk_time, "lat": -1, "lon": -1}
default_chunk_xy = {"time": chunk_time, "x": -1, "y": -1}

In [None]:
# Recompute target_root so that it reflects any parameters injected by papermill after the parameters cell.
target_root = f"{scratch_root}/{tag_name}"

In [None]:
target_root

In [None]:
tag_root

In [None]:
warm_plume = pd.read_csv(
    "s3://gfts-ifremer/tags/bargip/bar_flag_warm_plume.txt", sep="\t"
)
warm_list = list(warm_plume[warm_plume["warm_plume"] == True]["tag_name"])

# Flag the tag if it is listed as having been exposed to a warm plume
powerplant_flag = tag_name in warm_list

In [None]:
if powerplant_flag:
    detection_file = f"{tag_root}/{tag_name}/detection.csv"
    powerplant_file = f"{cloud_root}/nuclear_plant_loc.csv"

In [None]:
# Set up a local cluster for distributed computing.
from distributed import LocalCluster

cluster = LocalCluster()
client = cluster.get_client()
client

In [None]:
# Open and retrieve the tag data required for the analysis
tag = open_tag(tag_root, tag_name)
tag

## 2. **Compare Reference Model with DST Tag Information:** Analyze and compare data from the reference model with information from the biologging data of the species in question. 

In this step, we compare the reference model data with Data Storage Tag information.
The process involves reading and cleaning the reference model, aligning time, converting depth units, subtracting tag data from the model, and saving the results.
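
The comparison can be illustrated with a simplified, self-contained sketch. It is a hypothetical stand-in for `reshape_by_bins` and `diff_z`, not their actual implementation: it assumes that each tag observation is assigned to the model time bin that contains it, that the model temperature profile is linearly interpolated to the observed depths, and that a grid cell is rejected when its bathymetry is shallower than `relative_depth_threshold` times the deepest observation in the bin. The helper names and all values are illustrative.

```python
import numpy as np


def bin_observations(obs_time, bin_edges):
    """Return, for each bin [edge_i, edge_i+1), the indices of the observations inside it."""
    bin_index = np.digitize(obs_time, bin_edges) - 1
    return [np.flatnonzero(bin_index == i) for i in range(len(bin_edges) - 1)]


def mean_difference(model_depth, model_temp, bathymetry, obs_depth, obs_temp,
                    relative_depth_threshold=0.8):
    """Mean (model - observed) temperature difference for one grid cell and one time bin.

    Returns NaN when the cell is too shallow for the observed depths
    (hypothetical rule based on relative_depth_threshold).
    """
    if obs_depth.size == 0 or bathymetry < relative_depth_threshold * obs_depth.max():
        return np.nan
    # Linearly interpolate the model profile at the depths recorded by the tag
    model_at_obs = np.interp(obs_depth, model_depth, model_temp)
    return float(np.mean(model_at_obs - obs_temp))


# Hypothetical tag records every 6 h over two daily bins (time in hours)
obs_time = np.arange(0, 48, 6.0)
obs_depth = np.array([5, 10, 20, 10, 5, 15, 30, 10.0])
obs_temp = np.array([14, 13, 12, 13, 14, 12.5, 11, 13.0])
bins = bin_observations(obs_time, bin_edges=np.array([0.0, 24.0, 48.0]))

# Hypothetical model profile for one grid cell (depth in m, temperature in °C)
model_depth = np.array([0.0, 10.0, 20.0, 50.0])
model_temp = np.array([15.0, 14.0, 12.0, 10.0])

diffs = [
    mean_difference(model_depth, model_temp, 40.0, obs_depth[i], obs_temp[i])
    for i in bins
]
```

Here both daily bins give a small positive difference (the model is slightly warmer than the tag), while a 20 m deep cell would be rejected for the second bin because the tag went down to 30 m.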

In [None]:
# Import necessary libraries
import intake

from pangeo_fish.cf import bounds_to_bins
from pangeo_fish.diff import diff_z
from pangeo_fish.io import open_copernicus_catalog
from pangeo_fish.tags import adapt_model_time, reshape_by_bins, to_time_slice

# Drop data outside the reference interval
time_slice = to_time_slice(tag["tagging_events/time"])
time = tag["dst"].ds.time
cond = (time <= time_slice.stop) & (time >= time_slice.start)

tag_log = tag["dst"].ds.where(cond, drop=True)

min_ = tag_log.time[0]
max_ = tag_log.time[-1]

time_slice = slice(min_.data, max_.data)

In [None]:
from pangeo_fish.io import broadcast_variables


def get_copernicus_zarr(product_id="IBI_MULTIYEAR_PHY_005_002"):
    """Open the Copernicus reference model and format it for comparison with the tag data."""
    if product_id != "IBI_MULTIYEAR_PHY_005_002":
        raise ValueError(f"unsupported product_id: {product_id!r}")

    # Open necessary datasets: daily 3D fields and static bathymetry
    master_cat = intake.open_catalog(catalog_url)
    sub_cat = master_cat[product_id]
    daily = sub_cat["cmems_mod_ibi_phy_my_0.083deg-3D_P1D-m"](chunk="time").to_dask()
    thetao = daily[["thetao"]]
    zos = daily.zos
    deptho = sub_cat["cmems_mod_ibi_phy_my_0.083deg-3D_static"].to_dask().deptho

    # Use thetao's latitude and longitude for deptho so that both grids align exactly
    deptho["latitude"] = thetao["latitude"]
    deptho["longitude"] = thetao["longitude"]

    # Create mask for deptho (True where deptho is missing, i.e. on land)
    mask = deptho.isnull()

    # Merge datasets and assign relevant variables
    ds = (
        thetao.rename({"thetao": "TEMP"}).assign(
            {
                "XE": zos,
                "H0": deptho,
                "mask": mask,
            }
        )
    ).rename({"latitude": "lat", "longitude": "lon", "elevation": "depth"})

    # Ensure depth is positive
    ds["depth"] = abs(ds["depth"])

    # Reverse the depth order and assign dynamic depth and bathymetry
    ds = (
        ds.isel(depth=slice(None, None, -1))
        .assign(
            {
                "dynamic_depth": lambda ds: (ds["depth"] + ds["XE"]).assign_attrs(
                    {"units": "m", "positive": "down"}
                ),
                "dynamic_bathymetry": lambda ds: (ds["H0"] + ds["XE"]).assign_attrs(
                    {"units": "m", "positive": "down"}
                ),
            }
        )
        .pipe(broadcast_variables, {"lat": "latitude", "lon": "longitude"})
    )
    return ds

In [None]:
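# Verify the data: plot the full DST record (blue) and the records kept within the tagging interval (red), then save the figure as HTML.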
import cmocean
import hvplot.xarray

from pangeo_fish.io import save_html_hvplot

plot = (
    (-tag["dst"].pressure).hvplot(width=1000, height=500, color="blue")
    * (-tag_log).hvplot.scatter(
        x="time", y="pressure", color="red", size=5, width=1000, height=500
    )
    * (
        (tag["dst"].temperature).hvplot(width=1000, height=500, color="blue")
        * (tag_log).hvplot.scatter(
            x="time", y="temperature", color="red", size=5, width=1000, height=500
        )
    )
)
filepath = f"{target_root}/tags.html"

save_html_hvplot(plot, filepath, storage_options)

# plot

In [None]:
from pangeo_fish.io import broadcast_variables

In [None]:
model = get_copernicus_zarr()

In [None]:
# Subset the reference_model by
# - aligning the model time with the time of tag_log,
# - restricting latitude and longitude to bbox, and
# - dropping depth layers that are unlikely given the maximum observed pressure in tag_log.
#
reference_model = (
    model.sel(time=adapt_model_time(time_slice))
    .sel(lat=slice(*bbox["latitude"]), lon=slice(*bbox["longitude"]))
    .pipe(
        lambda ds: ds.sel(
            depth=slice(None, (tag_log["pressure"].max() - ds["XE"].min()).compute())
        )
    )
)

In [None]:
%%time
# Reshape the tag log, so that it bins to the time step of reference_model
reshaped_tag = reshape_by_bins(
    tag_log,
    dim="time",
    bins=(
        reference_model.cf.add_bounds(["time"], output_dim="bounds")
        .pipe(bounds_to_bins, bounds_dim="bounds")
        .get("time_bins")
    ),
    bin_dim="bincount",
    other_dim="obs",
).chunk({"time": chunk_time})

In [None]:
# Subtract the time-binned tag_log from the reference_model.
# For each time bin, each observed value is compared with the corresponding depth of reference_model using the diff_z function.
#

diff = (
    diff_z(
        reference_model.chunk(dict(depth=-1)),
        reshaped_tag,
        depth_threshold=relative_depth_threshold,
    )
    .assign_attrs({"tag_id": tag_name})
    .assign(
        {
            "H0": reference_model["H0"],
            "ocean_mask": reference_model["H0"].notnull(),
        }
    )
)

# Persist the diff data
diff = diff.chunk(default_chunk).persist()
# diff

In [None]:
%%time
# Verify the data
# diff["diff"].count(["lat","lon"]).plot()

In [None]:
target_lat = diff["lat"]
target_lon = diff["lon"]

In [None]:
%%time
# Save snapshot to disk
diff.to_zarr(f"{target_root}/diff.zarr", mode="w", storage_options=storage_options)

# Cleanup
del tag_log, model, reference_model, reshaped_tag, diff

## 3. **Regrid the Grid from Reference Model Grid to Healpix Grid:** Transform the grid from the reference model to the Healpix grid for further analysis.

In this step, we regrid the data from the reference model grid to a Healpix grid. This process involves defining the Healpix grid, creating the target grid, computing interpolation weights, performing the regridding, and saving the regridded data.
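
For orientation, the HEALPix `nside` parameter fixes the size of the target grid: the sphere is divided into 12 × nside² equal-area cells, `nside` must be a power of two for the `level = log2(nside)` conversion used below, and the typical cell width is about √(4π / N_pix) radians, the same quantity that `healpy.nside2resol` reports. The sketch below computes these numbers with NumPy only; the helper name is illustrative and not part of `xhealpixify`.

```python
import numpy as np


def healpix_summary(nside, earth_radius_km=6371.0):
    """Return (level, number of pixels, approximate cell width in km) for a HEALPix nside."""
    level = int(np.log2(nside))
    if 2**level != nside:
        raise ValueError("nside must be a power of two")
    npix = 12 * nside**2
    # Each equal-area cell covers 4*pi/npix steradians; its square root is a typical width
    resolution_rad = np.sqrt(4 * np.pi / npix)
    return level, npix, resolution_rad * earth_radius_km


level, npix, width_km = healpix_summary(4096)
```

With the notebook's `nside = 4096`, this gives level 12, about 2 × 10⁸ cells worldwide, and cells roughly 1.6 km wide.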


In [None]:
# Import necessary libraries
import numpy as np
import s3fs
from xhealpixify import HealpyGridInfo, HealpyRegridder

from pangeo_fish.grid import center_longitude

In [None]:
%%time

# Open the diff data and perform cleaning operations to prepare it for regridding.

ds = (
    xr.open_dataset(
        f"{target_root}/diff.zarr",
        engine="zarr",
        chunks={},
        storage_options=storage_options,
    )
    .pipe(lambda ds: ds.merge(ds[["latitude", "longitude"]].compute()))
    .swap_dims({"lat": "yi", "lon": "xi"})
)
ds

In [None]:
s3 = s3fs.S3FileSystem(
    anon=False,
    client_kwargs={
        "endpoint_url": "https://s3.gra.perf.cloud.ovh.net",
    },
)

In [None]:
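# Open the distance-to-coast data over bbox (latitude is stored in decreasing order)
coastal_distance = xr.open_zarr(distance_filepath).sel(
    lat=slice(*bbox["latitude"][::-1]), lon=slice(*bbox["longitude"])
)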

In [None]:
coastal_distance = coastal_distance.sortby("lat")

In [None]:
coastal_distance = coastal_distance.interp(
    lat=target_lat, lon=target_lon, method="linear"
)

In [None]:
coastal_distance["dist"] = 1 + np.exp(
    -(coastal_distance.dist * coastal_distance.dist) * distance_scale_factor
)
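
The cell above turns the distance to the coast d into a factor f(d) = 1 + exp(−k d²), with k = `distance_scale_factor`. The factor equals 2 at the coast and decays toward 1 offshore, so the later division `diff / dist` halves the temperature difference next to the coast and leaves open-water cells almost unchanged; this is how the notebook adds uncertainty near the coast. The sketch below reproduces only this arithmetic; the distances are made up and are assumed to be in the unit of the distance file (taken here to be km).

```python
import numpy as np

distance_scale_factor = 0.01
distance = np.array([0.0, 5.0, 10.0, 30.0])  # hypothetical distances to the coast
factor = 1 + np.exp(-(distance * distance) * distance_scale_factor)

diff = np.array([2.0, 2.0, 2.0, 2.0])  # same raw temperature difference everywhere
diff_with_uncertainty = diff / factor   # smaller differences near the coast
```

A smaller difference means a higher Gaussian emission probability in Step 4, so coastal cells are penalized less for a model-tag mismatch.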

In [None]:
coastal_distance = coastal_distance.swap_dims({"lat": "yi", "lon": "xi"}).drop_vars(
    ["lat", "lon"]
)

In [None]:
%%time
# Define the target Healpix grid information
grid = HealpyGridInfo(level=int(np.log2(nside)), rot=rot)
target_grid = grid.target_grid(ds).pipe(center_longitude, 0)
target_grid

In [None]:
%%time
# Compute the interpolation weights for regridding the diff data
regridder = HealpyRegridder(
    ds[["longitude", "latitude", "ocean_mask"]],
    target_grid,
    method="bilinear",
    interpolation_kwargs={"mask": "ocean_mask", "min_vertices": min_vertices},
)
regridder
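
Conceptually, masked bilinear interpolation computes each target value as a weighted average of the four surrounding source points, drops the vertices that fall on land (`ocean_mask` is False), and renormalizes the remaining weights. The sketch below assumes that `min_vertices` is the minimum number of valid vertices required before a target cell receives a value; it is a simplified illustration of the idea, not `HealpyRegridder`'s actual code, and the helper name is hypothetical.

```python
import numpy as np


def masked_bilinear(corner_values, corner_mask, fx, fy, min_vertices=1):
    """Interpolate inside one source cell from its 4 corners.

    corner_values, corner_mask: arrays of shape (2, 2) ordered [[v00, v01], [v10, v11]],
    where the first index runs along y and the second along x.
    fx, fy: fractional position of the target point inside the cell (0..1).
    """
    weights = np.array([[(1 - fy) * (1 - fx), (1 - fy) * fx],
                        [fy * (1 - fx), fy * fx]])
    weights = np.where(corner_mask, weights, 0.0)  # ignore land vertices
    n_valid = int(np.count_nonzero(corner_mask))
    if n_valid < min_vertices or weights.sum() == 0:
        return np.nan
    # Replace land values by 0 (they have zero weight) and renormalize the weights
    values = np.where(corner_mask, corner_values, 0.0)
    return float(np.sum(weights * values) / weights.sum())


values = np.array([[10.0, 12.0], [14.0, 16.0]])
all_ocean = np.ones((2, 2), dtype=bool)
one_land = np.array([[True, True], [True, False]])
```

At the cell centre, `masked_bilinear(values, all_ocean, 0.5, 0.5)` averages all four corners (13.0), while with one land corner the remaining three are averaged (12.0); requiring `min_vertices=4` leaves that cell empty.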

In [None]:
%%time
# Perform the regridding operation using the computed interpolation weights.
regridded = regridder.regrid_ds(ds)
regridded

In [None]:
regridded_coastal = regridder.regrid_ds(coastal_distance)

In [None]:
%%time
# Reshape the regridded data to 2D
reshaped = grid.to_2d(regridded).pipe(center_longitude, 0)
reshaped = reshaped.persist()
reshaped

In [None]:
reshaped_coastal = grid.to_2d(regridded_coastal).pipe(center_longitude, 0)

In [None]:
# This cell verifies the regridded data by plotting the count of non-NaN values.
# reshaped["diff"].count(["x", "y"]).plot()

In [None]:
coastal_chunk = {"x": default_chunk_xy["x"], "y": default_chunk_xy["y"]}

In [None]:
reshaped["diff"].isel(time=0).hvplot.quadmesh(
    title="Map of differences before adding the uncertainty",
    x="longitude",
    y="latitude",
    cmap="cool",
    coastline="10m",
    xlim=bbox["longitude"],
    ylim=bbox["latitude"],
)

In [None]:
reshaped["diff"] = reshaped["diff"] / reshaped_coastal["dist"]

In [None]:
reshaped["diff"].isel(time=0).hvplot.quadmesh(
    title="Map of differences after adding the uncertainty",
    x="longitude",
    y="latitude",
    cmap="cool",
    coastline="10m",
    xlim=bbox["longitude"],
    ylim=bbox["latitude"],
)

In [None]:
%%time
# This cell saves the regridded data to Zarr format, then cleans up unnecessary variables to free up memory after the regridding process.
reshaped.chunk(default_chunk_xy).to_zarr(
    f"{target_root}/diff-regridded.zarr",
    mode="w",
    consolidated=True,
    compute=True,
    storage_options=storage_options,
)

reshaped_coastal.chunk(coastal_chunk).to_zarr(
    f"{target_root}/coastal.zarr",
    mode="w",
    consolidated=True,
    compute=True,
    storage_options=storage_options,
)
# Cleanup unnecessary variables to free up memory
del ds, grid, target_grid, regridder, regridded, reshaped, reshaped_coastal

## 4. **Construct Emission Matrix:** Create an emission matrix based on the transformed grid.

In this step, we construct the emission probability matrix based on the differences between the observed tag temperature and the reference sea temperature computed in Step 2 and regridded in Step 3. The emission probability matrix represents the likelihood of observing a specific temperature difference given the model parameters and configurations.
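
The call `normal(differences["diff"], mean=0, std=differences_std, dims=["y", "x"])` evaluates a Gaussian density of the temperature difference at every grid cell, so a cell whose model temperatures match the tag (difference near 0) gets the highest likelihood. The NumPy sketch below shows that per-cell computation followed by a normalization so that each time step's map sums to one; treating `dims` as the axes to normalize over is an assumption about `pangeo_fish.pdf.normal`, and the difference values are made up.

```python
import numpy as np


def gaussian_emission(diff, std, axes=(-2, -1)):
    """Gaussian likelihood of temperature differences, normalized over the spatial axes.

    NaN differences (land or missing data) get zero probability.
    """
    pdf = np.exp(-0.5 * (diff / std) ** 2) / (std * np.sqrt(2 * np.pi))
    pdf = np.nan_to_num(pdf, nan=0.0)
    return pdf / pdf.sum(axis=axes, keepdims=True)


# Hypothetical differences (°C) for 2 time steps on a 2 x 3 grid
diff = np.array([
    [[0.0, 0.5, 1.5], [np.nan, 3.0, 0.2]],
    [[1.0, 1.0, 1.0], [1.0, 1.0, np.nan]],
])
emission = gaussian_emission(diff, std=0.75)
```

At the first time step the cell with zero difference is the most likely; at the second, all valid cells share the same difference and therefore the same probability.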


In [None]:
# Import necessary libraries
from toolz.dicttoolz import valfilter

from pangeo_fish.distributions import create_covariances, normal_at
from pangeo_fish.pdf import normal

In [None]:
%%time
# Open the regridded diff data
differences = xr.open_dataset(
    f"{target_root}/diff-regridded.zarr",
    engine="zarr",
    chunks={},
    storage_options=storage_options,
).pipe(lambda ds: ds.merge(ds[["latitude", "longitude"]].compute()))
differences

In [None]:
%%time
# Compute the probability maps of the initial (release) and final (recapture) positions
grid = differences[["latitude", "longitude"]].compute()

initial_position = tag["tagging_events"].ds.sel(event_name="release")
cov = create_covariances(1e-6, coord_names=["latitude", "longitude"])
initial_probability = normal_at(
    grid, pos=initial_position, cov=cov, normalize=True, axes=["latitude", "longitude"]
)

final_position = tag["tagging_events"].ds.sel(event_name="fish_death")
if final_position[["longitude", "latitude"]].to_dataarray().isnull().all():
    final_probability = None
else:
    cov = create_covariances(recapture_std**2, coord_names=["latitude", "longitude"])
    final_probability = normal_at(
        grid,
        pos=final_position,
        cov=cov,
        normalize=True,
        axes=["latitude", "longitude"],
    )
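
The release position (and the recapture position, when it is known) enters the model as a bivariate normal density centred on the reported coordinates. With a variance of 1e-6 the release density is essentially a point mass, while `recapture_std` controls how spread out the recapture density is. The sketch below evaluates an isotropic Gaussian on a small, hypothetical latitude/longitude grid with 0.01° spacing and normalizes it; it assumes the covariance is expressed in degrees² and ignores that a degree of longitude is shorter than a degree of latitude. It mirrors the idea of `normal_at`, not its implementation.

```python
import numpy as np


def isotropic_normal_at(lat, lon, pos_lat, pos_lon, variance):
    """Normalized isotropic 2-D Gaussian evaluated on grid coordinates (degrees)."""
    d2 = (lat - pos_lat) ** 2 + (lon - pos_lon) ** 2
    pdf = np.exp(-0.5 * d2 / variance)
    return pdf / pdf.sum()


lat, lon = np.meshgrid(
    np.linspace(47.0, 48.0, 101), np.linspace(-5.0, -4.0, 101), indexing="ij"
)
release = isotropic_normal_at(lat, lon, 47.5, -4.5, variance=1e-6)
recapture_std = 1e-2
recapture = isotropic_normal_at(lat, lon, 47.8, -4.2, variance=recapture_std**2)
```

On this grid, the release density puts almost all of its mass in a single cell, whereas the recapture density, whose standard deviation equals one grid spacing, spreads over several neighbouring cells.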

In [None]:
%%time
# compute emission probability matrix

emission_pdf = (
    normal(differences["diff"], mean=0, std=differences_std, dims=["y", "x"])
    .to_dataset(name="pdf")
    .assign(
        valfilter(
            lambda x: x is not None,
            {
                "initial": initial_probability,
                "final": final_probability,
                "mask": differences["ocean_mask"],
            },
        )
    )
    .assign_attrs(differences.attrs)  # | {"max_sigma": max_sigma})
)

emission_pdf = emission_pdf.chunk(default_chunk_xy).persist()
emission_pdf

In [None]:
# Verify the data
# emission_pdf["pdf"].count(["x", "y"]).plot()

In [None]:
# This cell saves the emission data to Zarr format, then cleans up unnecessary variables to free up memory.

emission_pdf.to_zarr(
    f"{target_root}/emission.zarr",
    mode="w",
    consolidated=True,
    storage_options=storage_options,
)

del differences, grid, initial_probability, final_probability, emission_pdf

## 5. **Replace Emission for Flagged Tags:** Replace the emission matrix for the tags where warm-plume spikes are detected.

In [None]:
%%time
# Import necessary libraries and open data and perform initial setup
import hvplot.xarray
import pandas as pd

from pangeo_fish import acoustic, utils
from pangeo_fish.heat import heat_regulation, powerpalnt_emission_map

emission = xr.open_dataset(
    f"{target_root}/emission.zarr",
    engine="zarr",
    chunks={},  # "x": -1, "y": -1},
    storage_options=storage_options,
)
emission

In [None]:
# Uncomment the next line to skip the warm-plume correction regardless of the flag file.
# powerplant_flag = False

In [None]:
if powerplant_flag:
    # Loading detections, formatting and reducing observation window
    detections = pd.read_csv(detection_file).set_index("time").to_xarray()
    detections["time"] = detections["time"].astype("datetime64[ns]")
    detections = detections.sel(
        time=emission["time"]
    )  # Narrowing the data to the observed days only

    pp_map = (
        pd.read_csv(powerplant_file, sep=";").drop("Country", axis=1).to_xarray()
    )  # Loading powerplant locations data

    # Combining and replacing the emission map at the given timestamps for the days where warm plume are detected
    combined_masks = powerpalnt_emission_map(pp_map, emission, buffer_size, rot)
    emission = heat_regulation(emission, detections, combined_masks)
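
The warm-plume correction relies on a spatial mask of the grid cells lying within `buffer_size` of a power plant. As a rough, hypothetical illustration of that idea (the actual logic lives in `powerpalnt_emission_map` and `heat_regulation`, whose internals are not shown here), the sketch below flags the cells whose great-circle distance to any plant is within the buffer and, for each flagged time step, replaces the emission map with a uniform distribution over those cells. The plant location and the grid are made up; placing them near the equator keeps degrees of latitude and longitude the same length, which makes the arithmetic easy to check.

```python
import numpy as np

EARTH_RADIUS_M = 6_371_000.0


def haversine_m(lat1, lon1, lat2, lon2):
    """Great-circle distance in metres between points given in degrees."""
    lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
    a = (
        np.sin((lat2 - lat1) / 2) ** 2
        + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
    )
    return 2 * EARTH_RADIUS_M * np.arcsin(np.sqrt(a))


def plume_mask(lat, lon, plants, buffer_m):
    """True where a grid cell lies within buffer_m of at least one plant (lat, lon)."""
    mask = np.zeros(lat.shape, dtype=bool)
    for plant_lat, plant_lon in plants:
        mask |= haversine_m(lat, lon, plant_lat, plant_lon) <= buffer_m
    return mask


def replace_flagged(emission, flagged, mask):
    """For flagged time steps, replace the map by a uniform pdf over the plume cells."""
    if not mask.any():
        raise ValueError("no grid cell lies within the plume buffer")
    corrected = emission.copy()
    corrected[flagged] = mask / mask.sum()
    return corrected


# Made-up 5 x 5 grid with ~556 m spacing around one plant at (0, 0)
lat, lon = np.meshgrid(
    np.linspace(-0.01, 0.01, 5), np.linspace(-0.01, 0.01, 5), indexing="ij"
)
mask = plume_mask(lat, lon, plants=[(0.0, 0.0)], buffer_m=1000.0)
emission = np.full((3,) + lat.shape, 1.0 / lat.size)
corrected = replace_flagged(emission, flagged=np.array([False, True, False]), mask=mask)
```

With a 1000 m buffer, the plant's own cell and its eight neighbours (at about 556 m and 786 m) fall inside the plume, so the flagged time step becomes a uniform distribution over those nine cells.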

## 6. **Combine and Normalize Emission Matrix:** Merge the emission matrix and normalize it for further processing.

In this step, we combine the emission probability matrices constructed in Steps 4 and 5 and then normalize the result so that the probabilities at each time step sum to one. This step prepares the combined emission matrix for further analysis and interpretation.
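
One plausible reading of "combine and normalize" is sketched below under explicit assumptions; it is not the code of `combine_emission_pdf`. The per-time-step emission maps are multiplied by the ocean mask, the first time step is additionally multiplied by the release density, the last one by the recapture density when it exists, and every time step is then rescaled so that its probabilities sum to one. The arrays are random or hand-made placeholders.

```python
import numpy as np


def combine_and_normalize(pdf, mask, initial, final=None):
    """Combine emission maps (time, y, x) with position constraints and renormalize."""
    combined = pdf * mask  # land cells get zero probability
    combined[0] = combined[0] * initial
    if final is not None:
        combined[-1] = combined[-1] * final
    totals = combined.sum(axis=(1, 2), keepdims=True)
    return combined / totals


rng = np.random.default_rng(0)
pdf = rng.random((4, 3, 3))
mask = np.ones((3, 3), dtype=bool)
mask[0, 0] = False                 # one land cell
initial = np.zeros((3, 3))
initial[1, 1] = 1.0                # release exactly at the centre cell
final = np.full((3, 3), 1.0 / 9)   # uninformative recapture density
combined = combine_and_normalize(pdf, mask, initial, final)
```

After combination, the first time step is concentrated on the release cell, the land cell has zero probability at every time step, and each map sums to one.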


In [None]:
# Import necessary libraries
import hvplot.xarray

from pangeo_fish.pdf import combine_emission_pdf

In [None]:
# Open and combine the emission probability matrix

combined = (
    emission.pipe(combine_emission_pdf)
    .chunk(default_chunk_xy)
    .persist()  # convert to comment if the emission matrix does *not* fit in memory
)
combined

In [None]:
# Verify the data and visualize the sum of probabilities
# combined["pdf"].sum(["x", "y"]).hvplot(width=400)

In [None]:
# Save the combined and normalized emission matrix
combined.to_zarr(
    f"{target_root}/combined.zarr",
    mode="w",
    consolidated=True,
    storage_options=storage_options,
)
del combined

## 7. **Estimate Model Parameters:** Determine the parameters of the model based on the normalized emission matrix.

This step first estimates the maximum allowed value of the model parameter sigma, `max_sigma`. Then we create an optimizer with an expected parameter range and fit the model to the normalized emission matrix.
The resulting optimized parameters are saved to a JSON file.
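
The optimizer searches for the value of sigma, the width of the Gaussian movement kernel, inside the bounds `(1e-4, max_sigma)`. The sketch below illustrates the idea on a toy one-dimensional problem: the score is the negative log-likelihood returned by a hidden Markov model forward filter whose movement step is a Gaussian kernel of width sigma, and a bounded golden-section search minimizes it with a tolerance on sigma. Everything here (the 1-D grid, the synthetic emissions, the pure-Python search) is an assumption for illustration; `EagerScoreEstimator` and `EagerBoundsSearch` operate on the 2-D HEALPix data and may use a different score and search routine.

```python
import math

import numpy as np


def transition_matrix(sigma, n_cells):
    """Column-stochastic Gaussian movement kernel: entry [i, j] is p(move j -> i)."""
    cells = np.arange(n_cells)
    t = np.exp(-0.5 * ((cells[:, None] - cells[None, :]) / sigma) ** 2)
    return t / t.sum(axis=0, keepdims=True)


def negative_log_likelihood(sigma, emission, initial):
    """Forward filter on a 1-D grid; returns -log p(observations | sigma)."""
    T = transition_matrix(sigma, emission.shape[1])
    state = initial * emission[0]
    loglik = math.log(state.sum())
    state = state / state.sum()
    for e in emission[1:]:
        state = (T @ state) * e  # movement step, then observation step
        loglik += math.log(state.sum())
        state = state / state.sum()
    return -loglik


def golden_section_search(f, lower, upper, xtol=1e-3):
    """Minimize a unimodal scalar function on [lower, upper]."""
    ratio = (math.sqrt(5) - 1) / 2
    a, b = lower, upper
    c, d = b - ratio * (b - a), a + ratio * (b - a)
    fc, fd = f(c), f(d)
    while b - a > xtol:
        if fc < fd:  # the minimum lies in [a, d]
            b, d, fd = d, c, fc
            c = b - ratio * (b - a)
            fc = f(c)
        else:  # the minimum lies in [c, b]
            a, c, fc = c, d, fd
            d = a + ratio * (b - a)
            fd = f(d)
    return (a + b) / 2


# Synthetic data: the fish moves 2 cells per step; emissions are Gaussian around it
cells = np.arange(60)
true_pos = 10 + 2 * np.arange(15)
emission = np.exp(-0.5 * ((cells[None, :] - true_pos[:, None]) / 1.5) ** 2)
initial = np.zeros(cells.size)
initial[10] = 1.0


def score(sigma):
    return negative_log_likelihood(sigma, emission, initial)


best_sigma = golden_section_search(score, 0.1, 5.0, xtol=1e-3)
```

A very small sigma cannot follow the moving fish and a very large one wastes probability on distant cells, so the score has an interior minimum that the bounded search locates.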

In [None]:
# Import necessary libraries and modules for data analysis.
import pandas as pd
import xarray as xr

from pangeo_fish.hmm.estimator import EagerScoreEstimator
from pangeo_fish.hmm.optimize import EagerBoundsSearch
from pangeo_fish.utils import temporal_resolution

# Open the data
emission = xr.open_dataset(
    f"{target_root}/combined.zarr",
    engine="zarr",
    chunks={},
    inline_array=True,
    storage_options=storage_options,
)
emission

In [None]:
# Compute maximum displacement for each reference model time step
# and estimate maximum sigma value for limiting the optimisation step

earth_radius_ = xr.DataArray(earth_radius, dims=None)

timedelta = temporal_resolution(emission["time"]).pint.quantify().pint.to("h")
grid_resolution = earth_radius_ * emission["resolution"].pint.quantify()

maximum_speed_ = xr.DataArray(maximum_speed, dims=None).pint.to("km / h")
max_grid_displacement = maximum_speed_ * timedelta * adjustment_factor / grid_resolution
max_sigma = max_grid_displacement.pint.to("dimensionless").pint.magnitude / truncate
emission.attrs["max_sigma"] = max_sigma
max_sigma
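
As a worked example of the cell above, the same arithmetic can be reproduced with plain Python floats. It assumes a daily reference model (a 24 h time step, as for the `P1D` product used in Step 2) and takes the grid resolution as the approximate HEALPix cell width √(4π / (12 · nside²)) radians; the actual `emission["resolution"]` value may differ slightly, so treat the result as an order of magnitude.

```python
import math

earth_radius_km = 6371.0
maximum_speed_km_per_h = 60.0 / 24.0  # 60 km / day
timedelta_h = 24.0                    # assumed daily model time step
adjustment_factor = 5
truncate = 4
nside = 4096

# Assumed grid resolution: typical HEALPix cell width, converted to km
resolution_rad = math.sqrt(4 * math.pi / (12 * nside**2))
grid_resolution_km = earth_radius_km * resolution_rad

# Maximum displacement per time step, expressed in grid cells
max_grid_displacement = (
    maximum_speed_km_per_h * timedelta_h * adjustment_factor / grid_resolution_km
)
max_sigma = max_grid_displacement / truncate
```

The fish may then move up to 300 km (about 190 cells of roughly 1.6 km) per day, which gives a `max_sigma` of about 47 grid cells as the upper bound for the optimizer.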

In [None]:
# Create and configure estimator and optimizer
emission = (
    emission.compute()
)  # Convert to comment if the emission matrix does *not* fit in memory
estimator = EagerScoreEstimator()
optimizer = EagerBoundsSearch(
    estimator,
    (1e-4, emission.attrs["max_sigma"]),
    optimizer_kwargs={"disp": 3, "xtol": tolerance},
)

In [None]:
%%time
# Fit the model parameter to the data
optimized = optimizer.fit(emission)

In [None]:
# Save the optimized parameters
params = optimized.to_dict()
pd.DataFrame.from_dict(params, orient="index").to_json(
    f"{target_root}/parameters.json", storage_options=storage_options
)

In [None]:
# Cleanup
del optimized, emission

## 8. **Compute State Probabilities and Tracks:** Calculate the probability distribution of the species in question and compute the tracks.

In this step, we predict the state probabilities using the optimized parameter sigma computed in Step 7 together with the normalized emission matrix, and then decode the tracks.
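
State probabilities in a hidden Markov model like this one are usually obtained with a forward-backward pass: a forward filter propagates the probability map through the movement kernel and multiplies it by each emission map, and a backward pass folds in the later observations. The 1-D NumPy sketch below shows that computation and derives a "mean" track (probability-weighted position) and a "mode" track (most probable cell) from the smoothed probabilities. It is an illustrative assumption about what `predict_proba` and the `mean`/`mode` track modes compute, not their implementation, and the data are synthetic.

```python
import numpy as np


def transition_matrix(sigma, n_cells):
    """Column-stochastic Gaussian movement kernel: entry [i, j] is p(move j -> i)."""
    cells = np.arange(n_cells)
    t = np.exp(-0.5 * ((cells[:, None] - cells[None, :]) / sigma) ** 2)
    return t / t.sum(axis=0, keepdims=True)


def forward_backward(emission, initial, sigma):
    """Smoothed state probabilities p(x_t | all observations), shape (time, cells)."""
    n_steps, n_cells = emission.shape
    T = transition_matrix(sigma, n_cells)
    # Forward pass: normalized filtered probabilities
    alpha = np.empty_like(emission)
    a = initial * emission[0]
    alpha[0] = a / a.sum()
    for t in range(1, n_steps):
        a = (T @ alpha[t - 1]) * emission[t]
        alpha[t] = a / a.sum()
    # Backward pass: rescaled likelihood of future observations
    beta = np.ones_like(emission)
    for t in range(n_steps - 2, -1, -1):
        b = T.T @ (emission[t + 1] * beta[t + 1])
        beta[t] = b / b.sum()
    posterior = alpha * beta
    return posterior / posterior.sum(axis=1, keepdims=True)


cells = np.arange(40)
true_pos = np.array([5, 7, 9, 11, 13, 15])
emission = np.exp(-0.5 * ((cells[None, :] - true_pos[:, None]) / 1.5) ** 2)
initial = np.zeros(cells.size)
initial[5] = 1.0

states = forward_backward(emission, initial, sigma=2.5)
mean_track = states @ cells        # probability-weighted position per time step
mode_track = states.argmax(axis=1)  # most probable cell per time step
```

Both tracks start at the release cell and then follow the synthetic fish, with the mean track lagging slightly behind the true position at the last step, where no future observation is available.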

In [None]:
# Import necessary libraries and modules for data analysis.
import hvplot.xarray
import pandas as pd
import xarray as xr

from pangeo_fish.hmm.estimator import EagerScoreEstimator
from pangeo_fish.io import save_trajectories

# Recreate the Estimator
params = pd.read_json(
    f"{target_root}/parameters.json", storage_options=storage_options
).to_dict()[0]
optimized = EagerScoreEstimator(**params)
optimized

In [None]:
%%time
# Load the Data
emission = xr.open_dataset(
    f"{target_root}/combined.zarr",
    engine="zarr",
    chunks=default_chunk_xy,
    inline_array=True,
    storage_options=storage_options,
).compute()

# Predict the State Probabilities

states = optimized.predict_proba(emission)
states = states.to_dataset().chunk(default_chunk_xy).persist()
states

In [None]:
# Verify the data and visualize the sum of probabilities
# states.sum(["x", "y"]).hvplot() +states.count(["x", "y"]).hvplot()

In [None]:
%%time
# Save the probability distribution (state matrix).
states.chunk(default_chunk_xy).to_zarr(
    f"{target_root}/states.zarr",
    mode="w",
    consolidated=True,
    storage_options=storage_options,
)

In [None]:
%%time
# decode tracks

trajectories = optimized.decode(
    emission,
    states.fillna(0),
    mode=track_modes,
    progress=False,
    additional_quantities=additional_track_quantities,
)
trajectories
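
The option `additional_quantities=["speed", "distance"]` asks MovingPandas to derive, for each decoded track, the distance between consecutive positions and the corresponding speed. The sketch below computes the same two quantities by hand with the haversine formula on a made-up track, filling the first point with 0; MovingPandas' own implementation (geodesic versus spherical distance, units, how the first point is filled) may differ.

```python
import math
from datetime import datetime, timedelta

EARTH_RADIUS_M = 6_371_000.0


def haversine_m(lat1, lon1, lat2, lon2):
    """Great-circle distance in metres between two points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * EARTH_RADIUS_M * math.asin(math.sqrt(a))


def distance_and_speed(times, lats, lons):
    """Distance (m) from the previous point and speed (m/s); the first point gets 0."""
    distances, speeds = [0.0], [0.0]
    for i in range(1, len(times)):
        d = haversine_m(lats[i - 1], lons[i - 1], lats[i], lons[i])
        dt = (times[i] - times[i - 1]).total_seconds()
        distances.append(d)
        speeds.append(d / dt)
    return distances, speeds


# Hypothetical daily track: first 0.5° east, then 0.5° north
start = datetime(2020, 1, 1)
times = [start + timedelta(days=k) for k in range(3)]
lats = [47.0, 47.0, 47.5]
lons = [-5.0, -4.5, -4.5]
distances, speeds = distance_and_speed(times, lats, lons)
```

The northward leg covers about 55.6 km in one day, while the eastward leg of the same angular size is shorter (about 37.9 km) because meridians converge at 47° N.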

In [None]:
# Save trajectories.
# Here we can choose the format 'parquet' for loading the files from R,
# or the format 'geoparquet' for further analysis of the tracks using
# GeoPandas.

save_trajectories(trajectories, target_root, storage_options, format="parquet")

In [None]:
# Cleanup
del optimized, emission, states, trajectories

In [None]:
# Close the client before shutting down the cluster
client.close()
cluster.close()

## 9. **Visualization:** Visualize the results of the analysis for interpretation and insight.


In this step, we visualize various aspects of the analysis results to gain insights and interpret the model outcomes. We plot the emission matrix, which represents the likelihood of observing a specific temperature difference given the model parameters and configurations. Additionally, we visualize the state probabilities, showing the likelihood of the system being in different states at each time step. We also plot each of the tracks of the tagged fish, displaying their movement patterns over time. Finally, we create a movie that combines the emission matrix and state probabilities to provide a comprehensive visualization of the analysis results.
