# **This notebook computes the difference between the temperature measured by the fish tag and the temperature from the mars model**
___
### **Summary:**
> **I:** Opening the data from mars model and data from fish tags.   
> **II:** Conversion from sigma level to depth.   
> **III:** Resampling the tag data to match the time of the model and the time of observation.   
> **IV:** Definition and application of the ufunc with xr.apply_ufunc().  
> **V:** Running of the computation with dask.   
> **VI:** Saving the diff dataset to a netcdf file

In [None]:
import os

import dask
import intake
import numpy as np
import pandas as pd
import xarray as xr

___
## **I: Opening the data from mars model and data from fish tags.**

In [None]:
import os

# Run parameters. Each one can be overridden through an environment variable
# of the same name; the second argument is the default used otherwise.
tag_id = os.environ.get("tag_id", "SV_A11981")
model = os.environ.get("model", "f1_e2500")
# Root directory of the inputs/outputs. Export basepath="../data_local/" to
# work on a local copy instead of the default location below.
basepath = os.environ.get("basepath", "/home/datawork-lops-iaocea/work/fish/marc/")
year = os.environ.get("year", "2022")

# Intake catalogue used to open the model output. An alternative location is
# "https://data-taos.ifremer.fr/kerchunk/ref-marc.yaml".
catalogue = "/home/datawork-taos-s/intranet/kerchunk/ref-marc.yaml"

# Output format switch: it selects the extension of the output file name.
outzarr = True
# os.path.join gives the same path whether or not basepath ends with "/".
output_filename = os.path.join(
    basepath, "diff", tag_id + "-" + model + (".zarr" if outzarr else ".nc")
)

In [None]:
# Display the effective run parameters so they appear in the cell output.
tag_id, model, basepath, year, catalogue, output_filename

In [None]:
# The tag file is looked up as <basepath>tag_nc/<tag_id>.nc.
tag_url = f"{basepath}tag_nc/{tag_id}.nc"
fish = xr.open_dataset(tag_url, engine="h5netcdf")
fish  # .compute()

In [None]:
# Choose the catalogue entry and the `region` argument for the requested model.
if model == "rang0":
    # This model is read from a separate catalogue file.
    catalogue = "/home/datawork-lops-iaocea/catalog/intake/agrif_archive.yaml"
    cat = intake.open_catalog(catalogue)["agrif_archive"]
    region = "rejeu_agrif_2016"
else:
    cat = intake.open_catalog(catalogue)["marc"]
    if model == "f1_e2500":
        region = model
    else:
        region = "f1_e2500_agrif/MARC_F1-MARS3D-" + model.upper()

# Keep only the variables used by the sigma-to-depth conversion, plus TEMP.
ds = cat(region=region, year=year).to_dask()[
    ["H0", "level", "XE", "theta", "b", "hc", "TEMP"]
]
ds

## Set up dask environment


In [None]:
import dask_hpcconfig
from dask.distributed import Client, LocalCluster

# "rang0" and "f1_e2500" run on the "datarmor" cluster configuration, scaled
# to 49 with 7 cores configured; every other model uses "datarmor-local".
if model in ("rang0", "f1_e2500"):
    overrides = {"cluster.cores": 7}
    cluster = dask_hpcconfig.cluster("datarmor", **overrides)
    cluster.scale(49)
else:
    cluster = dask_hpcconfig.cluster("datarmor-local")

# Connect a client to the cluster and show its summary.
client = Client(cluster)
client

In [None]:
### Select the time span used to align the tag data and the model in time.

half_hour = np.timedelta64(30, "m")  # margin added on each side of the span

if "times" not in fish.coords:
    # No `times` coordinate: span the whole tag record.
    model_time_span = slice(fish.time[0] - half_hour, fish.time[-1] + half_hour)
else:
    # presumably `times` holds the (start, end) of the period to keep — TODO confirm
    fish_time_span = slice(fish.times.data[0], fish.times.data[1])
    # Restrict the tag data to that period before deriving the model span.
    fish = fish.sel(time=fish_time_span)
    model_time_span = slice(
        fish.times.data[0] - half_hour,
        fish.times.data[1] + half_hour,
    )

In [None]:
### Restrict the model to the selected time span, then chunk it:
### one chunk per time step, not split along ni, nj or level.

ds = ds.sel(time=model_time_span)
ds = ds.chunk(chunks={"ni": -1, "nj": -1, "time": 1, "level": -1})
ds = ds.unify_chunks()

### The data are now opened, but to make the coordinates match we need to convert the "level" coordinate to a depth using a dedicated formula.

___

## **II: Conversion from sigma level to depth.**   


In [None]:
def compute_depth(marc_data):
    """Convert the sigma levels of a model dataset into depths.

    Parameters
    ----------
    marc_data : dataset-like
        Must provide ``level``, ``XE``, ``H0``, ``theta``, ``b``, ``hc`` and
        ``TEMP``. It is not modified: the work is done on a copy.

    Returns
    -------
    dataset-like
        A dataset restricted to ``TEMP``, ``depth`` (``z - XE``), ``bottom``
        (``XE + H0``) and ``H0``.
    """
    ####TODO: Find why the dims are not in the same order for TEMP and z-XE, find which order is the best for optimize in memory access
    ####TODO:  transpose is slowing down the compute, verify that rechunking should be 'before' or 'after' the transpose to make the computation faster

    # Work on a copy so the intermediate variables added below do not leak
    # into the dataset owned by the caller.
    out = marc_data.copy()

    s = out.level
    eta = out.XE
    depth = out.H0
    a = out.theta
    b = out.b
    depth_c = out.hc

    # Stretching function C(s), built from the theta / b parameters.
    C = (1.0 - b) * np.sinh(a * s) / np.sinh(a) + b * (
        np.tanh(a * (s + 0.5)) - np.tanh(0.5 * a)
    ) / (2.0 * np.tanh(0.5 * a))

    # Height of each level, stored as float32.
    out["z"] = (eta * (1.0 + s) + depth_c * s + (depth - depth_c) * C).astype(
        "float32"
    )

    # Level height relative to XE.
    out["depth"] = out["z"] - eta
    out["bottom"] = eta + depth
    return out[["TEMP", "depth", "bottom", "H0"]]

In [None]:
### Apply the sigma-to-depth conversion to the selected model data.
data_model = compute_depth(ds)

___
## **III: Resampling the data to match the time of the model and the time of observation.** 

### Now that the model and fish data are loaded and the sigma levels have been converted to depths, we need to regroup the fish data and build the right dataset (temp(time, obs), depth(time, obs)).

In [None]:
%%time

### Time axis of the model; the tag data is regrouped onto it.

model_time = data_model.time.data

### Creating the bins: one per model time step, from 30 minutes before it to
### 30 minutes after it (x:30:00 to x+1:30:00 for hourly output).

half_bin = np.timedelta64(30, "m")
time_bins = np.append(model_time - half_bin, model_time[-1] + half_bin)

######################################################################

### groupby_bins gives the indexes of the tag samples inside each time bin

time_groups = list(fish.groupby_bins(group="time", bins=time_bins).groups.values())

######################################################################
### Reducing fish data to water temperature and pressure

fish = fish[["water_temperature", "pressure"]]

######################################################################
### Creating arrays of values for temperature and depth per time group

fish_temp = [fish.water_temperature.isel(time=t).data for t in time_groups]

fish_depth = [fish.pressure.isel(time=t).data for t in time_groups]

### Filling the edges of depth and temp with nans

n_obs = 40  # number of samples a time bin must hold in the final dataset


def _pad_with_nan(values, size, side):
    """Return `values` padded with NaN on `side` ("left"/"right") up to `size`."""
    missing = size - len(values)
    if missing < 0:
        raise ValueError(
            f"time bin holds {len(values)} samples, more than the expected {size}"
        )
    if missing == 0:
        # Already full: nothing to add.
        return values
    filler = np.full(shape=missing, fill_value=np.nan)
    if side == "left":
        return np.append(filler, values)
    return np.append(values, filler)


# The first bin is padded on the left and the last one on the right. Each
# length is re-read just before padding, so a single bin is not padded twice.
if len(fish_temp[0]) != n_obs:
    print("filling left edge with nans")
fish_temp[0] = _pad_with_nan(fish_temp[0], n_obs, "left")
fish_depth[0] = _pad_with_nan(fish_depth[0], n_obs, "left")

if len(fish_temp[-1]) != n_obs:
    print("filling right edge with nans")
fish_temp[-1] = _pad_with_nan(fish_temp[-1], n_obs, "right")
fish_depth[-1] = _pad_with_nan(fish_depth[-1], n_obs, "right")

In [None]:
### Cast the model time axis to nanosecond precision, otherwise the data cannot be saved.
time = model_time.astype("datetime64[ns]")

In [None]:
### Build the tag dataset: one row per time group, `obs` indexing the samples
### inside the group.

data_fish = xr.Dataset(
    data_vars={
        "temp": (["time", "obs"], fish_temp),
        "depth": (["time", "obs"], fish_depth),
    },
    coords={"time": time, "obs": np.arange(40)},
)

In [None]:
### Give the model dataset the same nanosecond-precision time coordinate.

data_model = data_model.assign_coords(time=time)

___
## **IV: Definition and application of the ufunc with xr.apply_ufunc().**

In [None]:
def marc_pdf_z(
    model_temp, model_depth, bottom, fish_temp, fish_depth, depth_tolerance=0.90
):
    """Mean absolute temperature difference between tag samples and one model profile.

    Each tag sample with a valid depth is matched to the model level whose
    absolute depth is closest to it. The function returns the mean of
    ``|fish_temp - model_temp|`` over those samples.

    Parameters
    ----------
    model_temp, model_depth : 1-D arrays over the model levels.
    bottom : scalar
        Bottom depth of the model cell.
    fish_temp, fish_depth : 1-D arrays over the tag samples of one time bin.
        NaN depths mark padding and are ignored.
    depth_tolerance : float, default 0.90
        The cell is rejected when ``bottom < max(fish_depth) * depth_tolerance``.

    Returns
    -------
    float
        The mean difference, or NaN when the bin has no valid sample or when
        the cell is shallower than the deepest sample allows.
    """
    model_temp = np.asarray(model_temp)
    fish_temp = np.asarray(fish_temp)
    fish_depth = np.asarray(fish_depth)

    valid = ~np.isnan(fish_depth)
    if not valid.any():
        # Nothing to compare against (bin made only of padding).
        return np.nan

    sample_depth = fish_depth[valid]
    # Take the maximum over valid samples only: a plain max() over a
    # NaN-padded bin is NaN, which would silently disable this check.
    if bottom - sample_depth.max() * depth_tolerance < 0:
        return np.nan

    # Nearest model level for every valid sample, in one (sample, level) table.
    level_depth = np.absolute(np.asarray(model_depth))
    nearest = np.absolute(
        level_depth[np.newaxis, :] - sample_depth[:, np.newaxis]
    ).argmin(axis=1)

    return np.mean(np.absolute(fish_temp[valid] - model_temp[nearest]))

In [None]:
# One chunk per time step for the tag data, then persist it.
data_fish = data_fish.chunk(chunks={"time": 1}).unify_chunks().persist()

___
## **V: Running of the computation with dask.**

In [None]:
# Apply marc_pdf_z lazily. `level` is the core dimension of the model inputs
# and `obs` that of the tag inputs; `bottom` has no core dimension.
diff = xr.apply_ufunc(
    marc_pdf_z,
    data_model["TEMP"],
    data_model["depth"],
    data_model["bottom"],
    data_fish["temp"],
    data_fish["depth"],
    input_core_dims=[["level"], ["level"], [], ["obs"], ["obs"]],
    exclude_dims={"level", "obs"},
    vectorize=True,
    dask="parallelized",
    output_dtypes=[data_model["TEMP"].dtype],
)

___
## **VI: Saving the diff dataset to a netcdf file**

In [None]:
%%time
# Wrap the result in a dataset, attach H0 and persist it. The `tag` attribute
# records the tag actually processed (tag_id), not a fixed identifier.
diff = (
    diff.to_dataset(name="diff_")
    .assign_attrs({"tag": tag_id})
    .assign({"H0": data_model.H0})
    .unify_chunks()
    .persist()
)

In [None]:
def optimize_dataset(ds):
    """Run ``dask.optimize`` on the array behind every data variable of ``ds``.

    The optimized arrays are written back into ``ds`` itself, and the same
    object is returned.
    """
    import dask

    for name in ds.data_vars:
        # dask.optimize returns a tuple, one entry per collection passed in.
        (optimized,) = dask.optimize(ds[name].data)
        ds[name].data = optimized
    return ds


# Optimize the dask graphs of the result, then display it.
diff = optimize_dataset(diff)
diff

In [None]:
%%time
# Use the writer that matches the `outzarr` switch: output_filename ends in
# ".zarr" when it is True and in ".nc" otherwise.
if outzarr:
    diff.to_zarr(output_filename, mode="w")
else:
    diff.to_netcdf(output_filename)