In [2]:
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

: 

In [3]:
import os
import xarray as xr
from sithom.time import timeit
import matplotlib.pyplot as plt
import xesmf as xe
from tcpips.constants import CMIP6_PATH
from sithom.plot import plot_defaults
# from tcpips.conver
from tcpips.constants import CONVERSION_NAMES
from tcpips.constants import (
    DATA_PATH,
    CMIP6_PATH,
    RAW_PATH,
    REGRIDDED_PATH,
    BIAS_CORRECTED_PATH,
    FIGURE_PATH,
)

In [5]:
# Model names iterated over by the raw-ocean listing loop in the next cell.
MODELS = ["CESM2"]

In [None]:
def regrid_ocean(in_path, out_path):
    """
    Print the input/output paths and a summary of the input dataset.

    No regridding is performed here and nothing is written to ``out_path``;
    the output path is only echoed.

    Args:
        in_path (str): path of the netCDF file to open and summarise.
        out_path (str): intended path of the regridded file (printed only).
    """
    print(in_path)
    print(out_path)
    # The context manager releases the file handle once the summary is printed,
    # so looping over many files does not accumulate open handles.
    with xr.open_dataset(in_path) as ds:
        print(ds)


# Walk every raw ssp585 ocean file of each model and hand the matching
# raw/regridded path pair to regrid_ocean.
for model in MODELS:
    sub_dirs = ("ssp585", "ocean", model)
    model_path = os.path.join(RAW_PATH, *sub_dirs)
    print(model_path)
    for file in os.listdir(model_path):
        print(file)
        raw_file = os.path.join(model_path, file)
        regridded_file = os.path.join(REGRIDDED_PATH, *sub_dirs, file)
        regrid_ocean(raw_file, regridded_file)

In [None]:
from tcpips.regrid import regrid_cmip6_part

# Positional arguments: "ssp585", "atmos", "CESM2", "r4i1p1f1".
regrid_cmip6_part("ssp585", "atmos", "CESM2", "r4i1p1f1")

In [7]:
@timeit
def regrid_1d(xesmf: bool = False, target_ds: "xr.Dataset | None" = None) -> None:
    """
    Regrid the 1d-coordinate ocean dataset onto the lon/lat of a target dataset.

    Opens ``ocean.nc`` under ``CMIP6_PATH``, regrids it with a nearest-neighbour
    method onto ``target_ds[["lon", "lat"]]``, writes the result to
    ``regrid1d_ocean_regridded.nc`` under ``CMIP6_PATH`` and saves a map of the
    first time step of ``tos`` to ``ocean_regridded_regrid1d.png`` under
    ``FIGURE_PATH``. Diagnostic figures of the source field and of the source
    vs. target lat/lon values are drawn on separate figures.

    Args:
        xesmf (bool, optional): If True regrid with ``xe.Regridder`` using
            "nearest_s2d"; otherwise use ``Dataset.interp`` with
            ``method="nearest"``. Defaults to False.
        target_ds (xr.Dataset, optional): Dataset providing the target ``lon``
            and ``lat``. Defaults to None, in which case the notebook-level
            variable ``atmos_ds`` is used (the original behaviour).

    Raises:
        NameError: if ``target_ds`` is None and no global ``atmos_ds`` exists.
    """

    def open_1d(path: str) -> xr.Dataset:
        """
        Open a dataset, drop auxiliary variables and expose x/y as lon/lat.

        Args:
            path (str): path to the netCDF file.

        Returns:
            xr.Dataset: dataset with the first and last ``y`` index removed and
                ``x``/``y`` exposed under the names ``lon``/``lat``.
        """
        ds = xr.open_dataset(path)
        # NOTE(review): the "CMIP6_PATH" entry is kept from the original list;
        # it looks like a search/replace artefact rather than a variable name
        # -- confirm. It is harmless because absent names are filtered out.
        unwanted = [
            "lon",
            "lat",
            "lat_verticies",
            "lon_verticies",
            "lon_bounds",
            "time_bounds",
            "lat_bounds",
            "CMIP6_PATH",
            "member_id",
        ]
        ds = ds.drop_vars([x for x in unwanted if x in ds]).isel(y=slice(1, -1))
        if xesmf:
            ds = ds.assign_coords({"lon": ds["x"], "lat": ds["y"]})
            return ds.drop_vars(["x", "y"])
        return ds.rename({"x": "lon", "y": "lat"})

    if target_ds is None:
        # Backward-compatible fallback: the original body read the global
        # `atmos_ds`, which no cell visible in this notebook defines. Pass
        # `target_ds` explicitly to avoid depending on hidden kernel state.
        target_ds = atmos_ds  # noqa: F821

    ocean_ds = open_1d(os.path.join(CMIP6_PATH, "ocean.nc"))
    print("ocean_ds", ocean_ds)
    # Every diagnostic gets its own figure; previously the map, the lat/lon
    # lines and the regridded map were all drawn onto the same current axes.
    plt.figure()
    ocean_ds.isel(time=0).tos.plot(x="lon", y="lat")

    new_coords = target_ds[["lon", "lat"]]
    print("new_coords", new_coords)

    plt.figure()
    plt.plot(new_coords.lat.values, label="new")
    plt.plot(ocean_ds.lat.values, label="ocean")
    plt.legend()
    plt.title("lat")

    plt.figure()
    plt.plot(new_coords.lon.values, label="new")
    plt.plot(ocean_ds.lon.values, label="ocean")
    plt.legend()
    plt.title("lon")

    if xesmf:
        regridder = xe.Regridder(ocean_ds, new_coords, "nearest_s2d", periodic=True)
        print(regridder)
        ocean_out = regridder(
            ocean_ds,
            keep_attrs=True,
            skipna=True,
        )
    else:
        ocean_out = ocean_ds.interp(
            {"lon": new_coords.lon.values, "lat": new_coords.lat.values},
            method="nearest",
        )
    print("ocean_out", ocean_out)
    ocean_out.to_netcdf(os.path.join(CMIP6_PATH, "regrid1d_ocean_regridded.nc"))

    # Fresh figure so the saved PNG contains only the regridded field.
    plt.figure()
    ocean_out.tos.isel(time=0).plot(x="lon", y="lat")
    plt.savefig(os.path.join(FIGURE_PATH, "ocean_regridded_regrid1d.png"))

In [2]:
from sithom.time import timeit
from sithom.plot import plot_defaults
from sithom.misc import in_notebook
from tcpips.constants import CONVERSION_NAMES, RAW_PATH, CMIP6_PATH, REGRIDDED_PATH
import xarray as xr
import xesmf as xe
import os


@timeit
def regrid_any(
    output_res: float = 1.0,
    time_chunk: int = 10,
    exp: str = "ssp585",
    typ: str = "ocean",
    model: str = "CESM2",
    member: str = "r4i1p1f1",
) -> None:
    """
    Regrid one raw dataset onto a regular global lat/lon grid and save it.

    Reads ``<RAW_PATH>/<exp>/<typ>/<model>/<member>.nc``, regrids it bilinearly
    with xESMF onto the grid from ``xe.util.grid_global(output_res, output_res)``
    and writes the result to ``<REGRIDDED_PATH>/<exp>/<typ>/<model>/<member>.nc``.
    When running inside a notebook, the first time step of ``tos`` (ocean) or
    ``tas`` (atmos) is plotted afterwards.

    Args:
        output_res (float, optional): Resolution of the output grid. Defaults to 1.0.
        time_chunk (int, optional): Chunk size for time. Defaults to 10.
        exp (str, optional): Experiment directory name. Defaults to "ssp585".
        typ (str, optional): Data type directory name; "ocean" and "atmos"
            trigger a plot. Defaults to "ocean".
        model (str, optional): Model directory name. Defaults to "CESM2".
        member (str, optional): Member id, used as the file stem. Defaults to "r4i1p1f1".
    """
    # Imported locally because no cell in this notebook imports ProgressBar;
    # without it the write step below raised NameError.
    from dask.diagnostics import ProgressBar

    plot_defaults()

    def open_ds(path: str) -> xr.Dataset:
        """
        Open dataset chunked along time and drop unneeded variables.

        Args:
            path (str): path to the dataset.

        Returns:
            xr.Dataset: xarray dataset.
        """
        # open netcdf4 file using dask backend
        ds = xr.open_dataset(path, chunks={"time": time_chunk})
        droppable = ("x", "y", "dcpp_init_year", "member_id")
        return ds.drop_vars([name for name in droppable if name in ds])

    # NOTE(review): only the first 10 time steps are regridded; this looks like
    # a debugging limit left in place -- confirm before using for full runs.
    in_ds = open_ds(os.path.join(RAW_PATH, exp, typ, model, member) + ".nc").isel(
        time=slice(0, 10)
    )

    # make regular lat/lon grid
    new_coords = xe.util.grid_global(output_res, output_res)

    def regrid_and_save(input_ds: xr.Dataset, output_path: str) -> xr.Dataset:
        """
        Regrid and save the input dataset to the output.

        Args:
            input_ds (xr.Dataset): dataset to regrid.
            output_path (str): full path of the output file.

        Returns:
            xr.Dataset: regridded dataset.
        """
        regridder = xe.Regridder(
            input_ds, new_coords, "bilinear", periodic=True, ignore_degenerate=True
        )
        print(regridder)
        out_ds = regridder(
            input_ds,
            keep_attrs=True,
            skipna=True,
        )
        # Write straight to `output_path`: it is already built from
        # REGRIDDED_PATH (whose folder is created below), so it must not be
        # re-joined onto CMIP6_PATH. `to_netcdf` has no `chunks` argument, and
        # the encoding key is "dtype" (the original "dtyp" is rejected).
        delayed_obj = out_ds.to_netcdf(
            output_path,
            format="NETCDF4",
            engine="h5netcdf",  # should be better at parallel writing/dask
            encoding={
                var: {"dtype": "float32", "zlib": True, "complevel": 6}
                for var in CONVERSION_NAMES.keys()
                if var in out_ds
            },
            compute=False,
        )
        with ProgressBar():
            delayed_obj.compute()
        return out_ds  # return for later plotting.

    folder = os.path.join(REGRIDDED_PATH, exp, typ, model)
    os.makedirs(folder, exist_ok=True)
    out_ds = regrid_and_save(in_ds, os.path.join(folder, member) + ".nc")
    print("out_ds", out_ds)
    if typ == "ocean" and in_notebook():
        out_ds.tos.isel(time=0).plot(x="lon", y="lat")
        plt.show()
        out_ds.tos.isel(time=0).plot()
        plt.show()
    elif typ == "atmos" and in_notebook():
        # p=0 selects the first index along the `p` dimension.
        out_ds.tas.isel(time=0, p=0).plot(x="lon", y="lat")
        plt.show()

In [None]:
# Show the directory entries under RAW_PATH/ssp585/atmos/CESM2.
os.listdir(os.path.join(RAW_PATH, "ssp585", "atmos", "CESM2"))

In [None]:
# Run regrid_any with a 0.5 output resolution and a time chunk size of 1;
# the remaining arguments keep their defaults.
regrid_any(output_res=0.5, time_chunk=1)

In [1]:
from tcpips.regrid import regrid_any

In [2]:
from tcpips.regrid import define_tasks

In [None]:
# Call define_tasks (imported from tcpips.regrid in the previous cell) with
# its default arguments; as the last expression, its return value is displayed.
define_tasks()

In [None]:
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

In [None]:
from tcpips.regrid import run_tasks

# Call run_tasks from tcpips.regrid with its default arguments.
run_tasks()

In [None]:
from tcpips.regrid import run_tasks

# Call run_tasks with force_regrid=True and output_res=0.5.
# NOTE(review): presumably force_regrid redoes regridding even when outputs
# already exist -- confirm against tcpips.regrid.run_tasks.
run_tasks(force_regrid=True, output_res=0.5)