# Save time-mean and time-standard-deviation fields

In [None]:
import warnings

warnings.simplefilter("ignore")  # noqa  # silence every warning category for this notebook

In [None]:
# Standard library
import multiprocessing.popen_spawn_posix

# Data analysis and viz libraries
import dask
import numpy as np
import xarray as xr
from dask.distributed import Client

# Progress bar
from tqdm.notebook import tqdm

In [None]:
# Local modules
import mypaths
import names
from calc import time_mean, time_std
from load_thai import LOAD_CONF
from model_exocam import calc_alt_exocam, calc_pres_exocam
from model_lmdg import calc_alt_lmdg

Start a local `dask` cluster.

In [None]:
# Local dask cluster: 4 single-threaded worker processes.
client = Client(n_workers=4, threads_per_worker=1, processes=True)
client

## Choose case

In [None]:
# Name of the THAI case to process. The trailing digit selects the constants
# module (ending in "1" -> const_ben1_hab1, otherwise const_ben2_hab2).
# Presumably one of "Ben1", "Ben2", "Hab1", "Hab2" -- TODO confirm against LOAD_CONF loaders.
THAI_case = "Hab1"

In [None]:
# Pick the physical-constants module that matches the case number, i.e. the
# trailing digit of THAI_case (e.g. "Hab1" -> const_ben1_hab1).
if THAI_case.endswith("1"):
    import const_ben1_hab1 as const
elif THAI_case.endswith("2"):
    import const_ben2_hab2 as const
else:
    # Fail loudly instead of silently applying the Ben2/Hab2 constants to an
    # unrecognised (e.g. mistyped) case name.
    raise ValueError(
        f"Unknown THAI case {THAI_case!r}: expected a name ending in '1' or '2'"
    )

# Keyword arguments forwarded to the model-specific altitude calculations
# (calc_alt_exocam / calc_alt_lmdg) in the main loop.
KW_CONST = dict(
    mw_ratio=const.mw_ratio,
    dry_air_gas_constant=const.rgas,
    condens_gas_constant=const.rvapor,
    gravity=const.gplanet,
)

In [None]:
# For every model in LOAD_CONF: load the THAI case, add derived pressure /
# altitude fields where needed, compute the time mean and time standard
# deviation of each data variable, and write both to netCDF files in
# mypaths.datadir / <model_key>.
for model_key in tqdm(LOAD_CONF.keys()):
    model_names = getattr(names, model_key.lower())
    with LOAD_CONF[model_key]["loader"](THAI_case) as ds:
        # if model_key in ["ExoCAM", "LMDG", "ROCKE3D"]:
        #     # Regrid ExoCAM and ROCKE3D data to be compatible with `windspharm`:
        #     # if latitudes are equally-spaced and even-numbered, they should not include poles.
        #     nlat = 50  # new number of latitudes: 50
        #     delta_lat = 180 / nlat
        #     new_lats = np.linspace(90 - 0.5 * delta_lat, -90 + 0.5 * delta_lat, nlat)
        # else:
        #     new_lats = None

        # Model-specific derived fields, added as data variables so they are
        # time-averaged and saved along with the rest.
        if model_key == "ExoCAM":
            ds[model_names.pres] = calc_pres_exocam(ds)
            ds["z"] = calc_alt_exocam(ds, case=THAI_case, **KW_CONST)
        elif model_key == "LMDG":
            ds["level_height"] = calc_alt_lmdg(ds, case=THAI_case, **KW_CONST)

        ds_mean = {}
        ds_std = {}
        skipped = []
        for d in ds.data_vars:
            vrbl = ds[d]
            # if (model_names.y in vrbl.dims) and (new_lats is not None):
            #     vrbl = vrbl.isel(**{model_names.y: slice(1, -1)}).interp(
            #         **{
            #             model_names.y: new_lats,
            #             "kwargs": {"fill_value": "extrapolate"},
            #         },
            #     )
            if model_names.t in vrbl.dims:
                try:
                    vrbl_mean = time_mean(vrbl, model_names.t)
                    vrbl_std = time_std(vrbl, model_names.t)
                except TypeError:
                    # Skip this variable. Falling through here would store the
                    # previous variable's statistics under this name (or raise
                    # NameError if it is the first variable).
                    skipped.append(d)
                    continue
            else:
                # Time-independent variables are copied unchanged into both outputs.
                vrbl_mean = vrbl
                vrbl_std = vrbl
            ds_mean[d] = vrbl_mean
            ds_std[d] = vrbl_std
        if skipped:
            print(f"{model_key}: time_mean/time_std raised TypeError, skipped {skipped}")

        # Write inside the `with` block: the data may still be read lazily from `ds`.
        for stat, fields in (("mean", ds_mean), ("std", ds_std)):
            xr.Dataset(fields).to_netcdf(
                mypaths.datadir / model_key / f"{THAI_case}_time_{stat}_{model_key}.nc"
            )

In [None]:
# Disconnect the dask client created at the start of the notebook.
client.close()

In [None]:
# check if the pole points have to be excluded before interpolation
# model_key = "ExoCAM"
# with LOAD_CONF[model_key]["loader"](THAI_case) as ds:
#     model_names = getattr(names, model_key.lower())
#     nlat = 50  # new number of latitudes: 50
#     delta_lat = 180 / nlat
#     new_lats = np.linspace(90 - 0.5 * delta_lat, -90 + 0.5 * delta_lat, nlat)
#     a = ds[model_names.temp][-1, ...].interp(
#         **{
#             model_names.y: new_lats,
#             "kwargs": {"fill_value": "extrapolate"},
#         },
#     )

#     b = ds[model_names.temp][-1, ...].isel(**{model_names.y: slice(1, -1)}).interp(
#         **{
#             model_names.y: new_lats,
#             "kwargs": {"fill_value": "extrapolate"},
#         },
#     )
#     a[15, :, 10].plot()
#     b[15, :, 10].plot()

# Conclusion:
# it seems better to exclude the points at the poles, especially for LMDG and ROCKE3D...