In [None]:
import os

import cf_xarray
import numpy as np
import pint_xarray
import xarray as xr

In [None]:
import dask_hpcconfig

from distributed import Client

# Create the cluster named "datarmor-local" through dask_hpcconfig.
# Alternative left here for reference (distributed "datarmor" cluster):
#   overrides = {"cluster.processes": 7}
#   cluster = dask_hpcconfig.cluster("datarmor", **overrides)
#   cluster.scale(12)
cluster = dask_hpcconfig.cluster("datarmor-local")

# Connect a client to the cluster; the bare name below shows it as the cell output.
client = Client(cluster)
client

## Load raw 'difference' data

In [None]:
%%time

# Run configuration; every value can be overridden through an environment variable.
tag_id = os.environ.get("tag_id", "SV_A11981")
model = os.environ.get("model", "merged")
healpy = os.environ.get("healpy", "healpy")

# Single default for the data root. A second, earlier default ("../data_local/")
# used to be assigned first but was always overwritten, so it never took effect;
# export basepath=../data_local/ to work on a local copy instead.
basepath = os.environ.get("basepath", "/home/datawork-lops-iaocea/work/fish/marc/")

engine = "zarr"

# healpy == "1" selects nside 4096; any other value uses 4096 * 2 * 2.
nside = 4096 if healpy == "1" else 4096 * 2 * 2

if healpy == "none":
    # Input/output stores without an nside suffix, dimensions named ni/nj.
    print("non")
    input_filename = f"{basepath}diff/{tag_id}-{model}.zarr"
    output_filename = f"{basepath}emission/{tag_id}-{model}.zarr"
    coord = ["ni", "nj"]
else:
    # Healpix stores carry the nside in their file name, dimensions named x/y.
    print("heal")
    input_filename = f"{basepath}diff_healpix/{tag_id}-{model}-{nside}.zarr"
    output_filename = f"{basepath}emission_healpix/{tag_id}-{model}-{nside}.zarr"
    coord = ["x", "y"]

tag_id, model, healpy, basepath, input_filename, nside, output_filename

In [None]:
# input_filename='/home/datawork-lops-iaocea/work/fish/marc/emission_healpix/A18832-f1_e2500-4096.zarr'
# input_filename='/home/datawork-lops-iaocea/work/fish/marc/emission_healpix/SV_A11930-merged.zarr'
# input_filename='/home/datawork-lops-iaocea/work/fish/marc/emission_healpix/SV_A11981-merged.zarr'


# Chunk layout: each of the two spatial dimensions in one piece, one time step per chunk.
spatial_chunks = dict.fromkeys(coord[:2], -1)
raw_pdf = xr.open_dataset(
    input_filename,
    engine=engine,
    chunks={**spatial_chunks, "time": 1},
)
raw_pdf

## Load tag data

In [None]:
# The tag file is read from the "tag_nc/" folder under basepath, named after the tag id.
tag_url = f"{basepath}tag_nc/{tag_id}.nc"
tag = xr.open_dataset(tag_url, engine="h5netcdf")  # no .compute() here

# Computing emission_probability from the temperature difference map



In [None]:
def convert_pdf(samples):
    """Evaluate a zero-mean normal probability density at every value of ``samples``.

    Parameters
    ----------
    samples : array-like
        Values at which the density is evaluated.

    Returns
    -------
    numpy.ndarray
        ``scipy.stats.norm.pdf(samples, loc=0, scale=0.75**2)``; same shape as
        ``samples``.
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee
    # that `scipy.stats` is loaded on older SciPy releases.
    from scipy import stats

    fill_value = 0.75
    # NOTE(review): `scale` in scipy.stats.norm is the standard deviation, while
    # the value passed here is fill_value**2. The numeric value is kept as it
    # was -- confirm whether 0.75 or 0.75**2 is the intended standard deviation.
    scale = fill_value**2

    # Scalars broadcast against `samples`. This avoids allocating two full-size
    # helper arrays and keeps the scale from being truncated to 0 (-> NaN
    # output) when `samples` has an integer dtype.
    return stats.norm.pdf(samples, loc=0.0, scale=scale)


# Run convert_pdf over the difference maps, treating the two spatial
# dimensions as core dimensions; every remaining dimension is looped over
# (vectorize=True) and dask arrays are processed block-wise.
spatial_dims = [coord[0], coord[1]]
diff_map = raw_pdf.diff_  # .isel(time=slice(0,2))
raw_pdf["normal_pdf"] = xr.apply_ufunc(
    convert_pdf,
    diff_map,
    input_core_dims=[list(spatial_dims)],
    output_core_dims=[list(spatial_dims)],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[diff_map.dtype],
)

## Compute Mask

In [None]:
# Mask is False wherever H0 is NaN and True everywhere else.
h0_is_missing = np.isnan(raw_pdf.H0)
ocean_mask = xr.where(h0_is_missing, False, True)

## Compute grid


In [None]:
# Load latitude/longitude (looked up through cf_xarray) into memory, then
# reset the index on the two spatial dimensions.
grid = raw_pdf.cf[["latitude", "longitude"]].compute().reset_index(coord)

## Compute initial and final probability

In [None]:
from pangeo_fish import distributions

In [None]:
# TODO: apply np.nan to ground cells later, to save time on the dask workers.
#
# Position of the "release" event; passed as `pos` to `normal_at` below.
initial_pos = tag[["longitude", "latitude", "times"]].sel(events="release")
# NOTE(review): this covariance is labelled "X"/"Y", while the recapture
# covariance is labelled "latitude"/"longitude" and `axes` below is
# ["latitude", "longitude"] -- confirm whether `normal_at` relies on these labels.
cov = xr.DataArray(
    [[1e-6, 0], [0, 1e-6]], dims=["i", "j"], coords={"i": ["X", "Y"], "j": ["X", "Y"]}
)
initial_probability = distributions.normal_at(
    grid, pos=initial_pos, cov=cov, normalize=True, axes=["latitude", "longitude"]
)
# initial_probability

In [None]:
# Position of the "recapture" event and the covariance handed to `normal_at`.
final_pos = tag[["longitude", "latitude", "times"]].sel(events="recapture")
cov_labels = ["latitude", "longitude"]
cov = xr.DataArray(
    [[1e-4, 0], [0, 3e-4]],
    dims=["i", "j"],
    coords={"i": list(cov_labels), "j": list(cov_labels)},
)
final_probability = distributions.normal_at(
    grid,
    pos=final_pos,
    cov=cov,
    normalize=True,
    axes=list(cov_labels),
)
# final_probability

## Combine data and save

In [None]:
# Add the initial/final distributions and the mask as variables, attach the
# grid as coordinates, then drop the inputs that are no longer needed.
full_extent = {coord[0]: -1, coord[1]: -1}
extra_vars = {
    "initial": initial_probability.normal_pdf.chunk(full_extent),
    "final": final_probability.normal_pdf.chunk(full_extent),
    "mask": ocean_mask,
}
raw_pdf = raw_pdf.assign(extra_vars).assign_coords(grid)
raw_pdf = raw_pdf.drop_vars(["diff_", "H0"])

In [None]:
from pangeo_fish import pdf

# Combine the variables into the emission pdf and unify its dask chunks.
emission_pdf = pdf.combine_emission_pdf(raw_pdf).unify_chunks()

### Below: compute sigma max

In [None]:
ureg = pint_xarray.unit_registry

# Time step: difference between the first two time stamps, converted to days.
timedelta = emission_pdf.time.isel(time=1) - emission_pdf.time.isel(time=0)
timestep = (
    timedelta.astype("m8[ns]")
    .assign_attrs({"units": "ns"})
    .astype("float")
    .pint.quantify()
    .pint.to("day")
    .item()
)

# Grid spacing in km: taken from the "grid_size" attribute when it is set,
# otherwise derived from the largest latitude step along coord[1].
grid_size = raw_pdf.attrs.get("grid_size")
# Identity test on purpose: `== None` is an equality comparison, which
# array-like attribute values may evaluate element-wise.
if grid_size is None:
    grid_spacing = (
        raw_pdf.latitude.diff(dim=coord[1]).max().compute().pint.quantify().item()
    )
    # 40008.6 / 360 is presumably km per degree of latitude (meridional
    # circumference over 360) -- TODO confirm.
    grid_spacing = ureg.Quantity(40008.6 / 360 * grid_spacing, "km")
else:
    grid_spacing = ureg.Quantity(grid_size, "km")
print(grid_spacing)

max_speed = ureg.Quantity(60, "km / day")
adjustment = 10  # arbitrary factor to avoid restricting the fish too much, the max_speed is a daily average
max_distance = max_speed * timestep * adjustment  # .to("km")
# Maximum distance expressed as a number of grid cells (dimensionless magnitude).
max_pixels = ((max_distance) / grid_spacing).to("dimensionless").m

With the distance in pixels above, we can define the maximum $\sigma$: $\sigma_{\mathrm{max}} = \frac{d_{\mathrm{max}}}{t}$ with $t$ the truncation factor for gaussian kernels.

In [None]:
# Truncation factor t of the gaussian kernels (see the formula above).
truncate = 4.0
# sigma_max = max_pixels / t
sigma_max = max_pixels / truncate
sigma_max

In [None]:
%%time
# Record sigma_max as an attribute of the dataset.
emission_pdf = emission_pdf.assign_attrs({"sigma_max": sigma_max})

In [None]:
%%time
# Call .persist() on the dataset and keep working with the persisted result.
emission_pdf = emission_pdf.persist()

In [None]:
%%time
# Write the dataset to output_filename as zarr; mode="w" replaces an existing store.
emission_pdf.to_zarr(output_filename, mode="w")