# Example: regrid a dataset to a healpix grid

The default geographic rectilinear grid has non-uniform cell sizes and distances between cell centers, which makes it unsuitable for this application. We therefore transform the dataset to a grid with equal distances between the cell centers. HEALPix cells satisfy these conditions, and can be used for spatial convolutions when reshaped to a 2D array according to the "nested" cell numbering.

As with any interpolation, the linear interpolation supported by the `healpy` library is split into the computation of weights and the application of those weights to the data.

In [None]:
import warnings

import fsspec  # noqa: F401
import hvplot.xarray  # noqa: F401
import intake
import xarray as xr  # noqa: F401
import zarr  # noqa: F401

warnings.filterwarnings("ignore")
from xhealpixify.grid import create_grid
from xhealpixify.regridder import HealpyRegridder

## Define the resolution for healpix

In [None]:
# Notebook parameters
# ------------------
# HEALPix resolution parameter. For the 0.5km example, double it
# (i.e. use nside = nside * 2).
nside = 4096
# Rotation handed to the target grid, keyed by latitude / longitude.
rot = dict(lat=0, lon=30)

## Read the data

Below, you can try out 1.5km (Copernicus Marine Service), 2.5km (MARC, Ifremer) and 0.5km (MARC, Ifremer) data, accessed through intake/kerchunk catalogs hosted on Ifremer's HPC center (the same data can be accessed both from the HPC center and from the cloud).

In [None]:
# Copernicus example (NORTHWESTSHELF_ANALYSIS_FORECAST_PHY_004_01).
catalog = "https://data-taos.ifremer.fr/kerchunk/ref-copernicus.yaml"
cat = intake.open_catalog(catalog)

# Other variables that can be loaded from the same catalog instead:
# ds = cat.data(type="TEM").to_dask().rename({"thetao": "TEMP"})[["TEMP"]]
# ds = cat.data(type="SSH").to_dask().zos.to_dataset(name='XE')

# Load the `deptho` variable, shorten the coordinate names to lat/lon and
# wrap it in a dataset under the name "H0".
deptho = cat.data_tmp(type="mdt").to_dask().deptho
ds = deptho.rename({"latitude": "lat", "longitude": "lon"}).to_dataset(name="H0")
reference_model_ = ds

# Turn the lat/lon coordinates into plain variables named latitude/longitude
# and broadcast them against each other.
lat_lon = reference_model_[["lat", "lon"]].reset_index(["lat", "lon"])
lat_lon = lat_lon.rename_vars({"lat": "latitude", "lon": "longitude"}).reset_coords()
(broadcasted,) = xr.broadcast(lat_lon)

# Attach the broadcast latitude/longitude as coordinates and load the result.
ds = reference_model_.merge(broadcasted)
ds = ds.set_coords(["latitude", "longitude"]).compute()
ds

In [None]:
# Example using MARC multi resolution. (2.5km / 0.5km)

catalog = "https://data-taos.ifremer.fr/kerchunk/ref-marc.yaml"
cat = intake.open_catalog(catalog)["marc"]

# Catalog parameters for each available MARC configuration. Pick one through
# `marc_configuration` below instead of relying on which assignment runs last.
marc_configurations: dict = {
    # 2.5km without zoom(=agrif)
    "2.5km": {"region": "f1_e2500", "year": "2017"},
    # 0.5km
    "0.5km_agrif": {
        "region": "f1_e2500_agrif/MARC_F1-MARS3D-FINIS",
        "year": "2017",
    },
    # 2.5km
    "2.5km_agrif": {
        "region": "f1_e2500_agrif/MARC_F1-MARS3D-MANGAE2500-AGRIF",
        "year": "2017",
    },
}
marc_configuration = "2.5km_agrif"
catalog_parameters: dict = marc_configurations[marc_configuration]

catalog_kwargs = {
    # One chunk spanning the whole ni/nj extent, one chunk per time step.
    "chunks": {"ni": -1, "nj": -1, "time": 1},
    "inline_array": True,
}
ds = (
    cat(**catalog_kwargs, **catalog_parameters)
    .to_dask()[
        [
            "H0",
            "XE",
        ]
    ]
    # Cast the time coordinate to nanosecond-precision datetimes.
    .assign_coords(time=lambda ds: ds.time.astype("datetime64[ns]"))
)
# Keep only H0, the variable regridded and plotted in the rest of the notebook.
# The selection is intentionally done after the time coordinate was converted.
ds = ds[["H0"]]
ds

In [None]:
## Fill missing values before the projection (useful for coastal areas).
# NOTE(review): this cell addresses the dimensions as "lon"/"lat"; the MARC
# cell chunks along "ni"/"nj" — confirm the dimension names of the dataset
# in use before running it.
max_gap = 2
limit = 1
method = "nearest"
fill_options = {"method": method, "limit": limit, "max_gap": max_gap}

# Interpolate along "lon" first, then along "lat", with identical settings.
# (Passing fill_value="extrapolate" is a possible extra option, left disabled.)
for dim in ("lon", "lat"):
    ds = ds.interpolate_na(dim=dim, **fill_options)
ds

## Define the target grid

In [None]:
# Build the target grid from the resolution (`nside`) and rotation (`rot`)
# chosen in the notebook parameters, then display it.
grid = create_grid(nside=nside, rot=rot)
grid

## Compute the weights

In [None]:
# Set up the regridder from the source dataset and the target grid.
# NOTE(review): per the section heading the interpolation weights are
# computed here — confirm against the HealpyRegridder implementation.
regridder = HealpyRegridder(ds, grid)

## Apply the weights

In [None]:
# Regrid the source dataset onto the target grid and load the result
# into memory, then display it.
regridded = regridder.regrid_ds(ds).compute()
regridded

In [None]:
# Quick look at the regridded H0, drawn against its longitude/latitude coordinates.
regridded.H0.plot(x="longitude", y="latitude")

In [None]:
# Same quick look for H0 on the source grid, for visual comparison.
ds.H0.plot(x="longitude", y="latitude")

## Plotting
We first select the area for the plot. Here, we use the area around Ouessant island.

In [None]:
# Latitude/longitude bounds of the plotted area, plus the depth range
# that the plots use as color limits.
subset: dict = dict(
    lat_min=48.35,
    lat_max=48.55,
    lon_min=-5.25,
    lon_max=-4.95,
    depth_min=-100,
    depth_max=0,
)
# Colormap name shared by both plots.
cmap = "ocean"

Next, we plot side-by-side the original grid and the new grid.

_If this is too slow, change the coastline resolution from "10m" to "110m"!_

In [None]:
# Settings common to both panels: same axes, same geographic window,
# same color limits and colormap, so the two grids are directly comparable.
panel_options = {
    "x": "longitude",
    "y": "latitude",
    "geo": True,
    "coastline": "10m",
    "xlim": (subset["lon_min"], subset["lon_max"]),
    "ylim": (subset["lat_min"], subset["lat_max"]),
    "clim": (subset["depth_min"], subset["depth_max"]),
    "cmap": cmap,
    "rasterize": True,
}

# H0 is negated in both panels so that it matches the negative depth range
# used for the color limits.
original_panel = (-ds.H0).hvplot.quadmesh(title="original grid", **panel_options)
healpix_panel = (-regridded.H0).hvplot.quadmesh(
    title="healpix projected grid", **panel_options
)

# Lay the two panels out next to each other.
original_panel + healpix_panel

## Compute and save to disk

In [None]:
# Write the regridded dataset to a local Zarr store. mode="w" replaces any
# existing store at that path, consolidated=True also writes consolidated
# metadata, and compute=True performs the write immediately.
regridded.to_zarr("./test.zarr", mode="w", consolidated=True, compute=True)

## Result Checking

In [None]:
# Re-open the Zarr store written above to check what was saved;
# chunks={} asks xarray for lazily loaded, dask-backed arrays.
regridded_ = xr.open_dataset("./test.zarr", engine="zarr", chunks={})
regridded_

In [None]:
# Result check: compare H0 as read back from disk with the in-memory result.
# The saved dataset only holds H0 (no "diff" variable is created anywhere in
# this notebook), so the difference is computed here; it should be zero
# everywhere if the round trip through Zarr is lossless.
diff = regridded_.H0 - regridded.H0
if "time" in diff.dims:
    # Only time-dependent fields need a time step selected before plotting.
    diff = diff.isel(time=0)
diff.plot(x="longitude", y="latitude")