# Emission probability from acoustic receiver ranges

Parametrize this notebook with [papermill](https://papermill.readthedocs.io/en/latest/).

In [None]:
# Notebook parameters (intended to be set via papermill).
# Address of an existing dask scheduler; `None` starts a LocalCluster instead.
scheduler_address: str | None = None

# Semicolon-separated CSV with the tag metadata.
tag_db_path: str
# CSV with the acoustic receiver detections.
detections_path: str

receiver_buffer: float = 1000.0  # in [m]; buffer size around each receiver position

# Zarr store holding the base grid / emission dataset to extend.
emission_path: str
# Zarr store the extended emission dataset is written to (overwritten if present).
output_path: str

In [None]:
import pathlib

# Root data directory; only used to build the default paths in this cell.
root = pathlib.Path.home() / "work/data/fish-intel"

# Default input/output locations.
tag_db_path: str = f"{root}/acoustic/FishIntel_tagging_France.csv"  # tag database
detections_path: str = (
    f"{root}/acoustic/detections_recaptured_fishintel.csv"  # receiver detections
)
emission_path: str = (
    f"{root}/emission/A18832-f1_e2500-hp4096.zarr"  # base grid / emission dataset
)
output_path: str = (
    f"{root}/emission/A18832-f1_e2500-hp4096-acoustic.zarr"  # extended emission dataset
)

Create a local Dask cluster, or connect to an existing scheduler.

In [None]:
from distributed import Client, LocalCluster

# Prefer an existing scheduler when an address was given; otherwise start a
# local cluster and take its client.
if scheduler_address is not None:
    client = Client(scheduler_address)
else:
    cluster = LocalCluster()
    client = cluster.get_client()
client

Imports

In [None]:
import io
import pathlib

import cf_xarray
import dask
import flox.xarray
import numpy as np
import pandas as pd
import xarray as xr

from pangeo_fish import utils
from pangeo_fish.acoustic import (
    count_detections,
    extract_receivers,
    search_acoustic_tag_id,
)
from pangeo_fish.healpy import (
    astronomic_to_cartesian,
    astronomic_to_cell_ids,
    buffer_points,
    geographic_to_astronomic,
)

In [None]:
from pangeo_fish.distributions import normal_at

In [None]:
import cmocean
import holoviews as hv
import hvplot.pandas
import hvplot.xarray
import matplotlib

# Compatibility shim: when matplotlib does not expose the private
# `cm._cmap_registry`, alias it to `cm._colormaps`.
# NOTE(review): presumably required by a plotting dependency that still reads
# `_cmap_registry` — confirm which one.
try:
    matplotlib.cm._cmap_registry
except AttributeError:
    matplotlib.cm._cmap_registry = matplotlib.cm._colormaps

hv.output(widget_location="top")

Tag database

In [None]:
# The tag metadata is a semicolon-delimited CSV; show its first two rows.
tag_database = pd.read_csv(filepath_or_buffer=tag_db_path, delimiter=";")
tag_database.iloc[:2]

Detections

In [None]:
# work around the weird quoting
with open(detections_path, mode="r") as f:
    lines = (line.replace('"', "") for line in f)
    data = "\n".join(lines)
content = io.StringIO(data)

detection_database = (
    pd.read_csv(content, parse_dates=[1])
    .rename(columns={"date_time": "time"})
    .set_index("time")
)
detection_database.head(2)

Base grid

In [None]:
# Open the base grid lazily from zarr, keeping x and y each in a single chunk.
ds = xr.open_dataset(emission_path, chunks=dict(x=-1, y=-1), engine="zarr")
ds

Extract receiver locations

In [None]:
# Receiver metadata taken from the detections table, converted to xarray.
receivers = extract_receivers(detection_database)
receivers = receivers.to_xarray()
receivers

In [None]:
# Look up the acoustic tag id matching the tag id stored on the base grid.
acoustic_tag_id = search_acoustic_tag_id(
    tag_database,
    ds.attrs["tag_id"],
)
acoustic_tag_id

In [None]:
# Detections of this fish only: filter on its acoustic tag id, then index by time.
# Selecting with a list (`.loc[[...]]`) always yields a DataFrame. A scalar
# `.loc[...]` yields a Series when the tag has a single detection, which
# would make the following `.set_index("time")` fail.
detections = (
    detection_database[["receiver_id", "acoustic_tag_id"]]
    .reset_index()
    .set_index("acoustic_tag_id")
    .loc[[acoustic_tag_id]]
    .set_index("time")
    .to_xarray()
)
detections

Count detections per time step

In [None]:
# Turn the dataset's time coordinate into contiguous intervals: add CF bounds,
# collapse them to vertices, and use those as the interval break points.
time_bounds = ds[["time"]].cf.add_bounds(keys="time")["time_bounds"]
time_vertices = cf_xarray.bounds_to_vertices(time_bounds, bounds_dim="bounds")
time_intervals = pd.IntervalIndex.from_breaks(time_vertices)
time_intervals

In [None]:
# Detection counts per time interval, re-indexed onto the dataset's time axis.
counts = (
    count_detections(detections, by=time_intervals)
    .swap_dims({"time_bins": "time"})
    .assign_coords(time=ds.time)
)
# Attach the metadata of every receiver that appears in the counts.
counts = counts.merge(receivers.sel(receiver_id=counts["receiver_id"]))
# Normalize along receiver_id, fill missing values with 0, and keep the
# result as the `weights` variable.
weights = (
    utils.normalize(counts, dim="receiver_id")
    .fillna(0)
    .rename_vars({"count": "weights"})["weights"]
)
weights

#### `query_disc`

In [None]:
# Rotation parameters are stored in the grid attributes as "rot_<name>".
rot = {}
for attr_name, attr_value in ds.attrs.items():
    if attr_name.startswith("rot_"):
        rot[attr_name.removeprefix("rot_")] = attr_value

# Receiver deployment positions -> astronomic angles -> cartesian coordinates.
phi, theta = geographic_to_astronomic(
    lat=receivers.deploy_latitude, lon=receivers.deploy_longitude, rot=rot
)
cartesian_positions = astronomic_to_cartesian(
    phi=phi, theta=theta, dim="receiver_id"
)
cartesian_positions

In [None]:
# Cell id of every grid point, derived from its rotated coordinates and the
# grid's `nside` attribute.
phi, theta = geographic_to_astronomic(
    lon=ds.longitude,
    lat=ds.latitude,
    rot=rot,
)
cell_ids = astronomic_to_cell_ids(theta=theta, phi=phi, nside=ds.attrs["nside"])
cell_ids

In [None]:
# Materialize and display the `cell_ids` variable stored in the dataset.
ds["cell_ids"].compute()

In [None]:
# One mask per receiver covering the grid cells within `receiver_buffer` of
# its position. NOTE(review): exact semantics of `factor` and `intersect`
# are defined by `buffer_points` — confirm there.
masks = buffer_points(
    cell_ids,
    cartesian_positions,
    buffer_size=receiver_buffer,
    nside=ds.attrs["nside"],
    intersect=True,
    factor=65536,  # 2**16
)
masks

In [None]:
# Union of all receiver masks: True wherever the masks summed over receivers
# are non-zero.
combined_mask = masks.sum("receiver_id").astype("bool")
combined_mask

In [None]:
# Latitude/longitude of the grid, selected through the cf_xarray accessor.
coordinate_keys = ["latitude", "longitude"]
grid = ds.cf[coordinate_keys]
grid

#### Apply weights

In [None]:
# Align the weights with the dataset's time axis (missing steps get 0) and
# use one chunk per time step.
reindexed = weights.reindex(time=ds.time, fill_value=0)
reindexed = reindexed.chunk({"time": 1})
reindexed

In [None]:
# Complement of the total receiver weight at each time step; used outside
# the receiver masks.
fill_values = 1 - reindexed.sum(dim="receiver_id")
fill_values

In [None]:
# Inside the union of receiver masks: weighted sum of the per-receiver masks.
# Outside: the complement of the total receiver weight (`fill_values`).
weighted_masks = reindexed * masks.astype(float)
acoustic_pdfs = (
    weighted_masks.sum(dim="receiver_id")
    .where(combined_mask, fill_values)
    .chunk()
)
acoustic_pdfs

In [None]:
# Attach the acoustic emission probabilities to the base dataset.
combined = ds.assign({"acoustic": acoustic_pdfs})
combined

Write to disk

In [None]:
# Drop `time_bins` (presumably left over from the detection counts), then
# overwrite any existing store at `output_path` with consolidated metadata.
to_write = combined.drop_vars("time_bins")
to_write.to_zarr(output_path, mode="w", consolidated=True)

Plotting