# Track decoding

With the model parameter $\sigma$ estimated, we can proceed to decoding the tracks.

There are several ways to decode a track:
- mean track
- mode track
- most probable track (Viterbi)

Of all of these, the most probable track is the most meaningful one.

In [None]:
import json

import fsspec
import xarray as xr

from pangeo_fish import tracks
from pangeo_fish.hmm.estimator import EagerScoreEstimator
from pangeo_fish.pdf import combine_emission_pdf

Parametrize with [papermill](https://papermill.readthedocs.io/en/latest/)

In [None]:
# Papermill parameters. Annotated-only names have no default and must be
# supplied; the others are defaults that papermill may override.

# Location of the emission probabilities (opened as a zarr store below).
emission_path: str
# JSON file holding the estimated model parameters.
parameter_path: str

# Decoding modes passed to `estimator.decode`, one output file per mode.
track_modes: str | list = ["mean", "mode", "viterbi"]
# Quantities passed to `tracks.additional_quantities` for every decoded track.
additional_track_quantities: str | list = ["speed", "distance"]

# NOTE(review): `states_path` is declared but not read anywhere in this notebook.
states_path: str | None
# Directory where `<mode>.parquet` files are written and read back.
tracks_root: str

# Address of an existing dask scheduler; `None` starts a local cluster.
scheduler_address: str | None = None

In [None]:
# Concrete inputs for A19230: emission probabilities and the fitted parameters.
# NOTE(review): if this is not the papermill-injected cell, it runs after the
# injected parameters and overrides them — confirm this is intended.
emission_path = "https://data-taos.ifremer.fr/fish_mid4096/A19230/copernicus/emission_4096.zarr"
parameter_path = "https://data-taos.ifremer.fr/fish_mid4096/A19230/copernicus/sigma_4096.json"
# Output directory for the decoded tracks (one parquet file per mode).
tracks_root = "/home/jmagin/work/data/fish-intel/tracks/A19230"
track_modes = ["viterbi", "mean"]

Create a dask cluster (or connect to an existing scheduler)

In [None]:
from distributed import Client, LocalCluster

# Connect to the given dask scheduler, or spin up a local cluster otherwise.
if scheduler_address is not None:
    client = Client(scheduler_address)
else:
    cluster = LocalCluster()
    client = cluster.get_client()
client

Open the emission probabilities

In [None]:
import operator
from functools import reduce

In [None]:
from pangeo_fish.pdf import combine_emission_pdf

In [None]:
# Open the zarr store lazily, drop `resolution`, combine the pdfs with
# `combine_emission_pdf`, put the dims in (time, y, x) order and load it all.
emission = combine_emission_pdf(
    xr.open_dataset(
        emission_path, engine="zarr", chunks={}, inline_array=True
    ).drop_vars("resolution")
)
emission = emission.transpose("time", "y", "x").compute()
emission

Open the state probabilities (not done in this notebook yet: `states_path` is currently unused)

Read the estimated parameters

In [None]:
# Load the fitted parameters. `tolerance` is removed before they are splatted
# into the estimator (presumably not an estimator argument — TODO confirm).
with fsspec.open(parameter_path, mode="r") as f:
    parameters = json.loads(f.read())
if "tolerance" in parameters:
    del parameters["tolerance"]
parameters

Create the estimator

In [None]:
# Build the score estimator from the fitted parameters read above.
estimator = EagerScoreEstimator(**parameters)
estimator

Compute the tracks

In [None]:
track_modes = ["viterbi", "viterbi2"]
# track_modes = ["mean"]
for mode in track_modes:
    %time raw_track = estimator.decode(emission, mode=mode, is_states="states" in emission)
    track = tracks.additional_quantities(raw_track, additional_track_quantities)
    track.df.to_parquet(f"{tracks_root}/{mode}.parquet")

In [None]:
import cmocean
import geopandas as gpd
import holoviews as hv
import hvplot.xarray
import movingpandas as mpd
import xarray as xr

In [None]:
# Plot the modes that were actually decoded above instead of a hard-coded list
# (which could name modes with no parquet file). A single mode given as a
# string is wrapped so the comprehensions below iterate over modes, not chars.
track_modes = [track_modes] if isinstance(track_modes, str) else list(track_modes)

In [None]:
# Read every decoded track back from `tracks_root` as a movingpandas
# trajectory, keyed and identified by its mode. The paths are built once and
# reused instead of being formatted a second time.
track_paths = [f"{tracks_root}/{mode}.parquet" for mode in track_modes]
all_tracks = {
    mode: mpd.Trajectory(gpd.read_parquet(path), traj_id=mode)
    for mode, path in zip(track_modes, track_paths)
}

In [None]:
# One map per decoded mode, coloured by speed, arranged two per row.
hv.Layout(
    [
        trajectory.hvplot(
            c="speed", tiles="CartoLight", title=mode, cmap="cmo.speed"
        )
        for mode, trajectory in all_tracks.items()
    ]
).cols(2)

In [None]:
import dask
import dask.array as da
import numpy as np

from pangeo_fish.distributions import gaussian_kernel
from pangeo_fish.hmm.decode import kernel_state_metric
from pangeo_fish.pdf import combine_emission_pdf

Mathieu's implementation of the most probable (Viterbi) track

In [None]:
from pangeo_fish.distributions import gaussian_kernel

In [None]:
from rich.progress import track

In [None]:
import numba

In [None]:
@numba.njit
def _propagate_timestep(M, kernel, emission, Tprevx, Tprevy):
    """Advance the max-score recursion by one time step.

    For every source cell ``(y, x)`` with a finite score ``M[y, x]``, the
    window of destination cells covered by ``kernel`` centred on ``(y, x)``
    (clipped at the grid edges) receives the candidate score
    ``M[y, x] + kernel + emission``. Each destination cell keeps the largest
    candidate and the source cell that produced it.

    Parameters
    ----------
    M : 2D array (row, col)
        Scores at the previous time step; cells equal to ``-inf`` are skipped.
    kernel : 2D array
        Log transition kernel. Only ``kernel.shape[0]`` is read and used for
        both axes, so the kernel is assumed square (and presumably of odd size
        so that it is centred on the source cell — TODO confirm).
    emission : 2D array
        Log emission at the current time step, indexed with the same windows
        as ``M``.
    Tprevx, Tprevy
        Not used by this function.

    Returns
    -------
    Mtemp : 2D float array (row, col)
        Best candidate score per destination cell, ``-inf`` if unreached.
    Ttempx, Ttempy : 2D int16 arrays (row, col)
        Column (``x``) and row (``y``) index of the best source per
        destination cell, ``-1`` if unreached.
    """
    row, col = M.shape
    ks = kernel.shape[0]
    Mtemp = np.full((row, col), fill_value=-np.inf)
    # NOTE(review): int16 source indices overflow for grids with more than
    # 32767 cells along an axis.
    Ttempx = np.full((row, col), fill_value=-1, dtype="int16")
    Ttempy = np.full((row, col), fill_value=-1, dtype="int16")

    for x in range(col):
        for y in range(row):
            # Unreachable source cells cannot contribute.
            if M[y, x] == -np.inf:
                continue

            # Part of the kernel that stays inside the grid when centred on (y, x).
            kminlat = max(ks // 2 - y, 0)
            kmaxlat = min((ks - 1) - (y + ks // 2 - (row - 1)), ks - 1)
            kminlong = max(ks // 2 - x, 0)
            kmaxlong = min((ks - 1) - (x + ks // 2 - (col - 1)), ks - 1)

            # Matching window of the grid, clipped to its edges.
            mminlat = max(y - ks // 2, 0)
            mmaxlat = min(y + ks // 2, row - 1)
            mminlong = max(x - ks // 2, 0)
            mmaxlong = min(x + ks // 2, col - 1)

            # Emission of each destination plus the log transition to it.
            B = (
                emission[mminlat : mmaxlat + 1, mminlong : mmaxlong + 1]
                + kernel[kminlat : kmaxlat + 1, kminlong : kmaxlong + 1]
            )

            # Candidate scores for the window when coming from (y, x).
            Msub = B + M[y, x]

            Mupdate = Mtemp[mminlat : mmaxlat + 1, mminlong : mmaxlong + 1]
            Txupdate = Ttempx[mminlat : mmaxlat + 1, mminlong : mmaxlong + 1]
            Tyupdate = Ttempy[mminlat : mmaxlat + 1, mminlong : mmaxlong + 1]

            # Strict `>`: on ties the first source visited (x outer, y inner
            # loop) keeps the destination cell.
            update = Msub > Mupdate

            Mtemp[mminlat : mmaxlat + 1, mminlong : mmaxlong + 1] = np.where(
                update, Msub, Mupdate
            )
            Ttempx[mminlat : mmaxlat + 1, mminlong : mmaxlong + 1] = np.where(
                update, x, Txupdate
            )
            Ttempy[mminlat : mmaxlat + 1, mminlong : mmaxlong + 1] = np.where(
                update, y, Tyupdate
            )

    return Mtemp, Ttempx, Ttempy

In [None]:
@numba.njit
def _reorder_track(Tprevx, Tprevy, Ttempx, Ttempy, index, M):
    """Rebuild the per-cell index histories after one propagation step.

    For every cell ``(y, x)`` with a finite score in ``M``, copy the history
    (time steps ``0 .. index - 1``) of its best source cell, as recorded in
    ``Ttempx``/``Ttempy``, from ``Tprevx``/``Tprevy``. Then append the cell
    itself at position ``index``.

    Parameters
    ----------
    Tprevx, Tprevy : 3D arrays (row, col, time)
        Column / row index histories up to the previous time step.
    Ttempx, Ttempy : 2D arrays (row, col)
        Column / row of the best source per cell, as produced by the
        propagation step.
    index : int
        Position of the current time step in the history axis.
    M : 2D array (row, col)
        Scores at the current time step; cells at ``-inf`` are skipped and
        keep ``-1`` everywhere in the new histories.

    Returns
    -------
    Tx, Ty : arrays with the same shape and dtype as ``Tprevx`` / ``Tprevy``
        The updated column / row index histories.
    """
    row, col = M.shape

    # Fresh output arrays, so every cell reads the *previous* histories and
    # never one that was already rewritten in this step.
    Tx = np.full_like(Tprevx, fill_value=-1)
    Ty = np.full_like(Tprevy, fill_value=-1)

    for x in range(col):
        for y in range(row):
            if M[y, x] == -np.inf:
                continue

            # NOTE(review): relies on every finite cell having a recorded
            # source; a `-1` here would silently index the last row/column.
            Tx[y, x, :index] = Tprevx[Ttempy[y, x], Ttempx[y, x], :index]
            Ty[y, x, :index] = Tprevy[Ttempy[y, x], Ttempx[y, x], :index]
            Tx[y, x, index] = x
            Ty[y, x, index] = y

    return Tx, Ty

In [None]:
# Log-likelihood cube in (time, y, x) order (set by the transpose when opening
# `emission`); missing values are filled with 0 and therefore become -inf.
lik = emission.pdf.fillna(0).pipe(np.log).data
# Replace the first time step by the log of a 0/1 indicator of
# `initial > 0.5`: only those cells are valid starting points (others: -inf).
lik[0, :, :] = (
    emission.initial.fillna(0)
    .pipe(lambda arr: (arr > 0.5).astype(float))
    .pipe(np.log)
    .data
)
if not np.isfinite(lik[0, ...]).any():
    # Without this check `np.argmax` would silently return cell 0 and every
    # score would stay -inf, so the decoding would produce nothing.
    raise ValueError(
        "no cell has an initial probability above 0.5; cannot pick a start cell"
    )
ocean_mask = emission.mask.compute().data
# Start cell: the first valid starting point in C order.
index = np.argmax(lik[0, ...])
y0, x0 = np.unravel_index(index, lik.shape[1:])
land = np.logical_not(ocean_mask)
sigma = parameters["sigma"]

In [None]:
import hvplot.pandas
import hvplot.xarray

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Log of the Gaussian transition kernel, using the same sigma along both axes.
kern = np.log(gaussian_kernel(sigma=np.array(2 * [sigma]), type="continuous"))

In [None]:
# Forward pass: propagate the scores through time while keeping, for every
# cell, the full index history of the best path ending there.
# `dask.compute` passes numpy input through unchanged, so without the copy `M`
# would be a view of `lik[0]`. With a single time step the loop never runs,
# and the final land masking would then write into `lik`.
M = np.array(dask.compute(lik[0, ...])[0], copy=True)
Tprevx = np.full(M.shape + (lik.shape[0],), fill_value=-1, dtype="int16")
Tprevy = np.full(M.shape + (lik.shape[0],), fill_value=-1, dtype="int16")
# Only the chosen start cell gets a history entry at t=0.
# NOTE(review): other finite cells of `lik[0]` start with a -1 history.
Tprevx[y0, x0, 0] = x0
Tprevy[y0, x0, 0] = y0

for index in track(range(1, lik.shape[0]), description="propagating..."):
    lik_ = dask.compute(lik[index, ...])[0]

    Mtemp, Ttempx, Ttempy = _propagate_timestep(M, kern, lik_, Tprevx, Tprevy)
    # Land cells are made unreachable.
    Mtemp[land] = -np.inf

    Tx, Ty = _reorder_track(Tprevx, Tprevy, Ttempx, Ttempy, index, Mtemp)

    M, Tprevx, Tprevy = Mtemp, Tx, Ty
M[land] = -np.inf

In [None]:
# Final score per cell after the forward pass (-inf on land / unreached cells).
plt.imshow(M)
plt.colorbar()

In [None]:
# Number of cells with a recorded column index at each time step.
print(np.sum(Tprevx != -1, axis=(0, 1)))

In [None]:
# Flatten the (row, col, time) histories to one row per end cell, and rank the
# end cells by final score (ascending, so the best ones come last).
reshaped_x, reshaped_y = (hist.reshape(-1, lik.shape[0]) for hist in (Tprevx, Tprevy))
reshaped_M = M.ravel()
sort_indices = reshaped_M.argsort()
sorted_M = reshaped_M[sort_indices]
sorted_M

TODO:
- build a function to ingest lat/lon datasets into `movingpandas`, possibly into `TrajectoryCollection`s
- compare this implementation with the reference

In [None]:
import shapely

In [None]:
import movingpandas as mpd

In [None]:
def combine_as_points(ds, x, y):
    """Combine two variables of ``ds`` into shapely points.

    ``shapely.points`` is applied element-wise to ``ds[x]`` and ``ds[y]``
    (no core dimensions), so the result has their broadcast dimensions.
    """
    xs = ds[x]
    ys = ds[y]
    return xr.apply_ufunc(
        shapely.points,
        xs,
        ys,
        input_core_dims=[[], []],
        output_core_dims=[[]],
    )

In [None]:
# Index histories of the 10 end cells with the highest final scores.
# `sort_indices` is ascending, so the best cells are its last entries. Only
# those 10 rows are gathered, instead of reordering the whole history array.
best = sort_indices[-10:]
x = xr.DataArray(reshaped_x[best, :], dims=["track_id", "time"])
y = xr.DataArray(reshaped_y[best, :], dims=["track_id", "time"])
# `x` holds column indices (stored as `x` by the decoder) and `y` row indices,
# so they index the `x` and `y` dimensions respectively.
selected = emission[["time", "longitude", "latitude"]].isel(x=x, y=y)
combined = selected.pipe(combine_as_points, "longitude", "latitude").drop_vars(
    ["longitude", "latitude", "cell_ids"]
)
df = combined.assign_coords(track_id=lambda ds: ds.track_id).to_dataframe(
    name="geometry"
)
df

In [None]:
import geopandas as gpd

In [None]:
# One trajectory per `track_id`, with the points indexed by time.
coll = mpd.TrajectoryCollection(
    gpd.GeoDataFrame(df.reset_index().set_index("time"), crs="epsg:4326"),
    traj_id_col="track_id",
)

In [None]:
# Display the trajectory collection.
coll

In [None]:
# Plot the selected paths on a basemap.
coll.hvplot(tiles="CartoLight")

In [None]:
# Cell with the highest final score: end point of the best path.
# `np.unravel_index` converts the flat index directly (same as at the start
# cell); the previous extra `pos // ncols` assignment was dead code.
pos = np.argmax(M)
posy, posx = np.unravel_index(pos, M.shape)

In [None]:
# Row / column index history of the best-scoring end cell.
y, x = Tprevy[posy, posx], Tprevx[posy, posx]
y, x

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Final score per cell.
plt.imshow(M)

In [None]:
# Log transition kernel.
plt.imshow(kern)

In [None]:
# First time step of the log-likelihood (log of the start-cell indicator).
plt.imshow(lik[0, ...])

In [None]:
# Final score per cell, with its colour scale.
plt.imshow(M)
plt.colorbar()

In [None]:
# (time, y, x) size of the log-likelihood cube.
lik.shape

In [None]:
# First time step of the log-likelihood (log of the start-cell indicator).
plt.imshow(lik[0, ...])

Visualization

In [None]:
# `x` holds the column indices and `y` the row indices of the best path (taken
# from `Tprevx` / `Tprevy`). Wrap each so it indexes its own dimension along
# `time`.
x_ = xr.DataArray(x, dims="time")
y_ = xr.DataArray(y, dims="time")

In [None]:
# Inspect the `x_` indexer built above.
x_

In [None]:
# Look up longitude/latitude of every visited cell and wrap them as a trajectory.
traj = mpd.Trajectory(
    emission[["time", "longitude", "latitude"]].isel(x=x_, y=y_).to_pandas(),
    traj_id="viterbi",
    x="longitude",
    y="latitude",
)
traj.hvplot(tiles="CartoLight")

In [None]:
# Decode with the estimator for comparison with the implementation above.
# `is_states` is passed the same way as in the track computation loop, so both
# decodings treat the emission dataset identically.
estimator.decode(emission, mode="viterbi", is_states="states" in emission)

In [None]:
import cmocean
import geopandas as gpd
import holoviews as hv
import hvplot.xarray
import movingpandas as mpd
import xarray as xr

In [None]:
# Read every decoded track back from `tracks_root` as a movingpandas
# trajectory, keyed and identified by its mode. The paths are built once and
# reused instead of being formatted a second time.
track_paths = [f"{tracks_root}/{mode}.parquet" for mode in track_modes]
all_tracks = {
    mode: mpd.Trajectory(gpd.read_parquet(path), traj_id=mode)
    for mode, path in zip(track_modes, track_paths)
}

In [None]:
# Speed-coloured map of each loaded track, two maps per row.
hv.Layout(
    [
        all_tracks[mode].hvplot(
            c="speed", tiles="CartoLight", title=mode, cmap="cmo.speed"
        )
        for mode in all_tracks
    ]
).cols(2)