# **This notebook is meant to upscale the spatial resolution of the AGRIF data with the child regions**

In [None]:
import xarray as xr
import numpy as np
import os

In [None]:
def common_area(ds1, ds2):
    """Return the coordinate values that ``ds1`` and ``ds2`` have in common.

    Both arguments must expose ``lon`` and ``lat`` attributes whose ``.data``
    is array-like.  The result is a dict with the keys ``"lon"`` and ``"lat"``
    (in that order); each value is the sorted, unique intersection computed by
    ``np.intersect1d``.  Values are compared exactly.
    """
    return {
        axis: np.intersect1d(getattr(ds1, axis).data, getattr(ds2, axis).data)
        for axis in ("lon", "lat")
    }


def replace_diff(rang0, rang1):
    """Overwrite ``H0`` and ``diff_`` of ``rang0`` with the values of ``rang1``.

    The replacement happens on the grid points (exact ``lat``/``lon`` matches)
    shared by ``rang0`` and the part of ``rang1`` that remains after dropping
    the rows/columns where ``H0`` is entirely NaN.

    Parameters
    ----------
    rang0 : dataset indexed by ``lat``/``lon`` holding ``H0`` and ``diff_``;
        it is modified in place.
    rang1 : dataset indexed by ``lat``/``lon`` holding ``H0`` and ``diff_``.

    Returns
    -------
    The same ``rang0`` object, with the two variables updated.
    """
    rang1_ = remove_nan(rang1, "H0")
    coords = common_area(rang0, rang1_)
    if coords["lon"].size == 0 or coords["lat"].size == 0:
        # Nothing in common: leave rang0 untouched.
        return rang0

    # 2-D (lat, lon) mask that is True on the shared grid points.
    inside = rang0["lat"].isin(coords["lat"]) & rang0["lon"].isin(coords["lon"])

    for name in ("H0", "diff_"):
        # Assigning to ``rang0[name].loc[coords].data`` only rebinds the data
        # of the temporary object returned by ``.loc[...]`` and never reaches
        # ``rang0``.  Build the replaced variable explicitly instead; this
        # also works when the variables are backed by dask arrays.
        patch = (
            rang1[name]
            .sel(coords)
            # Keep only the index coordinates so that auxiliary coordinates
            # of rang1 cannot clash with those of rang0.
            .reset_coords(drop=True)
            # Expand the patch to the grid of rang0 (NaN outside the patch).
            .reindex_like(rang0[name])
        )
        rang0[name] = rang0[name].where(~inside, patch)
    return rang0


def mean_areas(ds1, ds2):
    """Return the element-wise mean of ``ds1`` and ``ds2`` on their shared grid.

    Both inputs are first restricted to the ``lon``/``lat`` values returned by
    ``common_area``; the two selections are then added and divided by two.
    """
    overlap = common_area(ds1, ds2)
    selection = {"lon": overlap["lon"], "lat": overlap["lat"]}
    first = ds1.sel(**selection)
    second = ds2.sel(**selection)
    return (first + second) / 2


def remove_nan(ds, var):
    """Return ``ds[var]`` without its all-NaN ``lat`` and ``lon`` labels.

    ``dropna(..., how="all")`` is applied along ``lat`` first and then along
    ``lon``.  Note that the selected variable is returned, not the container
    ``ds`` it was taken from.
    """
    trimmed = ds[var]
    for dim in ("lat", "lon"):
        trimmed = trimmed.dropna(dim=dim, how="all")
    return trimmed


def model_path(model, engine="zarr"):
    """Build the path of the store holding ``model`` for the current tag.

    The path is ``<basepath>diff/<tag_id>-<model><suffix>`` where ``basepath``
    and ``tag_id`` are module-level names.  The suffix is ``.nc`` when
    ``engine`` is ``"netcdf4"`` and ``.zarr`` for any other value.
    """
    suffix = ".nc" if engine == "netcdf4" else ".zarr"
    # ``basepath`` is concatenated as is: no path separator is inserted.
    return "".join((basepath, "diff/", tag_id, "-", model, suffix))


def load_model(model, engine="zarr"):
    """Open the store of ``model`` and index it by 1-D ``lat``/``lon``.

    The dataset is opened with ``chunks={}``.  ``lat`` is read from
    ``latitude`` at ``ni=0`` and ``lon`` from ``longitude`` at ``nj=0``; both
    are computed, cast to ``float32`` and attached as coordinates.  The
    dimensions named by the module-level ``x`` and ``y`` are then swapped for
    ``lon`` and ``lat``.
    """
    dataset = xr.open_dataset(model_path(model, engine), engine=engine, chunks={})

    # NOTE(review): taking one column/row presumably relies on the grid being
    # rectilinear - confirm.  The literal "ni"/"nj" here are independent of
    # the module-level ``x``/``y`` used in ``swap_dims`` below.
    lat = dataset.latitude.isel(ni=0, drop=True).compute().astype("float32")
    lon = dataset.longitude.isel(nj=0, drop=True).compute().astype("float32")

    with_coords = dataset.assign_coords(lat=lat, lon=lon)
    return with_coords.swap_dims({x: "lon", y: "lat"})

In [None]:
import dask_hpcconfig
from distributed import Client

# Create the cluster registered under the name "datarmor-local" and connect a
# client to it; the trailing expression displays the client in the notebook.
cluster = dask_hpcconfig.cluster("datarmor-local")
client = Client(cluster)
client

___
## **Opening all the files**

In [None]:
# Run settings; each one can be overridden through the environment variable
# of the same name.
tag_id = os.getenv("tag_id", "SV_A11930")
method = os.getenv("method", "linear")
basepath = os.getenv("basepath", "/home/datawork-lops-iaocea/work/fish/marc/")

# Names of the two grid-index dimensions and the storage backend.
x, y = "ni", "nj"
engine = "zarr"

# Alternative setup for local files:
# engine='netcdf4'
# basepath=os.environ.get('basepath','../data_local/')

# Destination of the merged dataset (built with the default "zarr" engine).
output_filename = model_path("merged")

# Child regions, loaded one by one and keyed by their name.
rang1s = ["pdc", "seine", "armor", "finis", "loire", "gironde", "adour"]
rang1 = {}
for child in rang1s:
    rang1[child] = load_model(child, engine=engine)
output_filename

___
### **Upscaling the rang0 data**

In [None]:
# Open the parent ("rang0") store with the same engine as the child regions.
# It keeps its original ni/nj dimensions here.
model = "rang0"
rang0 = xr.open_dataset(
    model_path(model, engine=engine),
    engine=engine,
    chunks={},
)
rang0

In [None]:
# Target positions along ni/nj, sampled every 0.2 index units.
# NOTE(review): np.arange excludes its stop value, so the largest ni/nj of
# rang0 is not part of the target grid - confirm this is intended.
rang1_ni = np.arange(min(rang0.ni.data), max(rang0.ni.data), 0.2)
rang1_nj = np.arange(min(rang0.nj.data), max(rang0.nj.data), 0.2)

# Interpolate every variable with the configured method ...
rang0_upscale = rang0.interp(nj=rang1_nj, ni=rang1_ni, method=method)

# ... but always interpolate the geographic positions linearly.
rang0_upscale["latitude"] = rang0["latitude"].interp(
    nj=rang1_nj, ni=rang1_ni, method="linear"
)
rang0_upscale["longitude"] = rang0["longitude"].interp(
    nj=rang1_nj, ni=rang1_ni, method="linear"
)

# Index the upscaled grid by 1-D float32 lat/lon instead of ni/nj.
rang0_upscale = rang0_upscale.assign_coords(
    lat=rang0_upscale["latitude"].isel(ni=0, drop=True).compute().astype("float32"),
    lon=rang0_upscale["longitude"].isel(nj=0, drop=True).compute().astype("float32"),
).swap_dims({x: "lon", y: "lat"})
rang0_upscale

# Replacing the upscaled rang0 values with the child-region (rang1) values

In [None]:
# Apply every child region in turn, in the order in which they were loaded.
for child_name, child_ds in rang1.items():
    print(child_name)
    rang0_upscale = replace_diff(rang0_upscale, child_ds)

In [None]:
# For each pair of consecutive child regions, average the two on their shared
# grid points and apply that mean to the upscaled parent.  Pairing through
# zip follows the length of ``rang1s`` instead of a hard-coded index bound.
for first, second in zip(rang1s, rang1s[1:]):
    print(first, second)
    mean = mean_areas(rang1[first], rang1[second])
    rang0_upscale = replace_diff(rang0_upscale, mean)

___
### **Removing the lat and lon dimensions introduced at the beginning**

In [None]:
# Go back to the ni/nj dimensions, remove the helper lat/lon coordinates and
# rechunk with a chunk size of 1 along time before writing.
rang0_upscale = (
    rang0_upscale.swap_dims({"lon": "ni", "lat": "nj"})
    # ``Dataset.drop`` is deprecated for dropping variables; ``drop_vars`` is
    # its replacement.
    .drop_vars(["lat", "lon"])
    .chunk({"time": 1})
)
rang0_upscale

In [None]:
%%time
# Write the merged dataset to the store at ``output_filename``; mode="w"
# replaces a store that already exists at that path.
rang0_upscale.to_zarr(output_filename, mode="w")