In [None]:
import os

import numpy as np
import xarray as xr

# from scipy import interpolate

In [None]:
%%time
# Run configuration. Every value can be overridden through an environment
# variable of the same name; the literals below are the defaults.
tag_id = os.environ.get("tag_id", "SV_A11981")
model = os.environ.get("model", "merged")
healpy = os.environ.get("healpy", "4")

# Root directory holding the "diff/" input and "diff_healpix/" output stores.
# For a local run, export basepath=../data_local/ instead of editing this cell.
basepath = os.environ.get("basepath", "/home/datawork-lops-iaocea/work/fish/marc/")

# os.path.join keeps the paths valid whether or not `basepath` ends with "/".
input_filename = os.path.join(basepath, "diff", f"{tag_id}-{model}.zarr")

# healpy == "1" selects nside 4096; any other value selects 4096 * 2 * 2.
nside = 4096 if healpy == "1" else 4096 * 2 * 2

# The regridding parameter file is always read from the local data directory.
param_filename = "../data_local/healpixbase_nside_" + str(nside) + "_" + model + ".nc"

output_filename = os.path.join(
    basepath, "diff_healpix", f"{tag_id}-{model}-{nside}.zarr"
)

# Show the resolved configuration as the cell output.
tag_id, model, healpy, basepath, input_filename, nside, output_filename, param_filename

In [None]:
# Open the input Zarr store. Per the xarray documentation, `chunks={}` returns
# dask-backed (lazy) arrays chunked according to the backend's preferred chunks.
ds = xr.open_dataset(input_filename, engine="zarr", chunks={})
# Display the dataset summary as the cell output.
ds

In [None]:
%%time
# Open the precomputed regridding parameter file selected in the config cell.
param = xr.open_dataset(param_filename)
# Plain Python scalar; passed to `regrid` as `nside` and used as the x/y output size.
nside2 = param.nside2.item()
# `p2` and `w` are passed to `regrid` as its `pix` and `weights` arguments.
# presumably target pixel indices and interpolation weights — TODO confirm
p2 = param.p2.load()  # loaded eagerly into memory
w = param.w.load()  # loaded eagerly into memory

## Parallelised using dask to compute each time step

(Later, I need to optimise the size of the cluster with respect to memory:
the memory requirements should change depending on the size of nside.)

In [None]:
import dask_hpcconfig
from distributed import Client

# Build a cluster from the "datarmor" dask_hpcconfig profile, overriding the
# "cluster.processes" setting, then scale it to 56.
# NOTE(review): presumably 7 processes per job and 56 workers in total; both
# numbers are hard-coded and should follow the memory needed for `nside`.
# For a single-node run use dask_hpcconfig.cluster("datarmor-local") instead.
overrides = {"cluster.processes": 7}
cluster = dask_hpcconfig.cluster("datarmor", **overrides)
cluster.scale(56)

# Connect this notebook to the cluster and show the client summary.
client = Client(cluster)
client

In [None]:
def regrid(data, pix, weights, nside):
    """Scatter a source field onto an ``nside`` x ``nside`` grid by weighted averaging.

    Each source cell ``k`` (in flattened order) contributes
    ``weights[i, k] * data.flat[k]`` to target pixel ``pix[i, k]`` for
    ``i`` in 0..3. The accumulated value of every target pixel is then divided
    by its accumulated weight.

    Parameters
    ----------
    data : numpy.ndarray
        Source field; it is flattened, so only its total size ``N`` matters.
    pix : numpy.ndarray
        Non-negative integer flat indices into the target grid, shape
        ``(>=4, N)``; only the first four rows are used. Indices must be
        smaller than ``nside * nside``.
    weights : numpy.ndarray
        Weights matching ``pix`` row by row.
    nside : int
        Side length of the square target grid.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(nside, nside)``. Pixels whose accumulated
        weight is exactly zero are NaN; pixels with positive accumulated
        weight hold the weighted mean. NaNs in ``data`` propagate to every
        pixel they contribute to.

    Raises
    ------
    ValueError
        If ``nside`` is negative.
    """
    # The former ``nside == -1`` branch assigned nside to itself, so negative
    # values always failed further down with an unrelated ValueError.
    if nside < 0:
        raise ValueError(f"nside must be non-negative, got {nside}")

    n_pix = nside * nside
    # Flatten once instead of once per neighbour row.
    flat = data.flatten()

    b = np.zeros([n_pix])  # weighted sum of source values per target pixel
    bh = np.zeros([n_pix])  # sum of weights per target pixel
    for iii in range(4):
        b = b + np.bincount(
            pix[iii, :], weights=weights[iii, :] * flat, minlength=n_pix
        )
        bh = bh + np.bincount(pix[iii, :], weights=weights[iii, :], minlength=n_pix)

    # Normalise where weight was received; mark untouched pixels as missing.
    covered = bh > 0
    b[covered] /= bh[covered]
    b[bh == 0] = np.nan
    return b.reshape(nside, nside)

In [None]:
# Apply `regrid` to every 2-D (nj, ni) slice of `diff_`, producing an (x, y)
# field of size nside2 x nside2 per slice. The computation stays lazy (dask).
# To run without dask, materialise the input first with `ds = ds.compute()`.
data = xr.apply_ufunc(
    regrid,
    ds.diff_,
    p2,
    w,
    nside2,
    input_core_dims=[["nj", "ni"], ["a", "b"], ["a", "b"], []],
    output_core_dims=[["x", "y"]],
    exclude_dims={"nj", "ni"},
    vectorize=True,
    dask="parallelized",
    dask_gufunc_kwargs=dict(output_sizes=dict(x=nside2, y=nside2)),
    output_dtypes=[ds.diff_.dtype],
)

In [None]:
# Wrap the regridded field in a dataset and attach `H0` from the parameter file.
ds_healpy = data.to_dataset(name="diff_").assign(H0=param.H0)

# Copy the grid coordinates from the parameter file, one at a time.
ds_healpy = ds_healpy.assign_coords(longitude=param.longitude)
ds_healpy = ds_healpy.assign_coords(latitude=param.latitude)
ds_healpy = ds_healpy.assign_coords(x=param.x)
ds_healpy = ds_healpy.assign_coords(y=param.y)

# Record provenance, then chunk to one full (x, y) field per time step.
ds_healpy = ds_healpy.assign_attrs(
    tag_id=tag_id, grid_size=param.attrs["grid_size"]
)
ds_healpy = ds_healpy.chunk({"time": 1, "x": -1, "y": -1})

# Trim positions along x, then along y, where `H0` is entirely NaN.
ds_healpy = ds_healpy.dropna(dim="x", how="all", subset=["H0"])
ds_healpy = ds_healpy.dropna(dim="y", how="all", subset=["H0"])

ds_healpy

In [None]:
%%time
# Write the regridded dataset to the output Zarr store. Per the xarray
# documentation, mode="w" overwrites an existing store at that path.
ds_healpy.to_zarr(output_filename, mode="w")