# Save time series of spatially collapsed diagnostics

In [None]:
import warnings

# Silence all Python warnings raised in this kernel from here on, to keep the
# notebook output uncluttered.
warnings.filterwarnings("ignore")  # noqa

In [None]:
# Data analysis and viz libraries
import dask
import numpy as np
import xarray as xr
from dask.distributed import Client

# Progress bar
from tqdm.notebook import tqdm

In [None]:
# Local modules
import mypaths
import names
from calc import (
    altitude_of_cloud_mmr_maximum,
    cloud_path_total,
    dayside_mean,
    get_time_rel_days,
    global_mean,
    meridional_mean,
    nightside_mean,
    open_ocean_frac,
    sfc_temp,
    spatial_mean,
    terminator_mean,
)
from commons import MODELS
from load_thai import LOAD_CONF
from model_exocam import calc_alt_exocam, calc_pres_exocam
from model_lmdg import calc_alt_lmdg

Start a local `dask` cluster.

In [None]:
# Local dask cluster made of 4 worker processes, each single-threaded.
client = Client(n_workers=4, threads_per_worker=1, processes=True)
client  # last expression of the cell: display the client/cluster summary

## Choose case

In [None]:
# THAI case to process. A name ending in "1" selects the `const_ben1_hab1`
# constants module; any other name selects `const_ben2_hab2`.
THAI_case: str = "Hab2"

In [None]:
# Pick the physical-constants module for the case family: names ending in
# "1" use `const_ben1_hab1`, every other name uses `const_ben2_hab2`.
if not THAI_case.endswith("1"):
    import const_ben2_hab2 as const
else:
    import const_ben1_hab1 as const

# Constants passed as keyword arguments to the altitude calculations.
KW_CONST = {
    "mw_ratio": const.mw_ratio,
    "dry_air_gas_constant": const.rgas,
    "condens_gas_constant": const.rvapor,
    "gravity": const.gplanet,
}

In [None]:
# Spatial aggregations applied to every diagnostic. Each key becomes the
# suffix of the output variable name (e.g. "t_sfc_g" for the global mean).
AGGR_DICT = {
    "g": global_mean,
    "d": dayside_mean,
    "n": nightside_mean,
    "t": terminator_mean,
}

In [None]:
# Diagnostics to compute. Each entry maps an output-name prefix to a dict
# holding the calculation function under "func".
DIAGS = {
    diag_name: {"func": diag_func}
    for diag_name, diag_func in (
        ("t_sfc", sfc_temp),
        ("ocean_frac", open_ocean_frac),
        ("cwp", cloud_path_total),
        ("alt_cld_mmr_max", altitude_of_cloud_mmr_maximum),
    )
}

In [None]:
# For every model: load the THAI case, compute each diagnostic, collapse it
# spatially with every aggregation in AGGR_DICT, and save all resulting time
# series (sorted by time) to one netCDF file per model.
for model_key in tqdm(MODELS):
    # Per-model variable-name mapping, e.g. `names.exocam` for "ExoCAM".
    model_names = getattr(names, model_key.lower())
    with LOAD_CONF[model_key]["loader"](THAI_case) as ds:
        # Add derived vertical-coordinate fields for these models before any
        # diagnostics are computed (ExoCAM: pressure and altitude; LMDG:
        # level height).
        if model_key == "ExoCAM":
            ds[model_names.pres] = calc_pres_exocam(ds)
            ds["z"] = calc_alt_exocam(ds, case=THAI_case, **KW_CONST)
        elif model_key == "LMDG":
            ds["level_height"] = calc_alt_lmdg(ds, case=THAI_case, **KW_CONST)

        ds_out = {}
        for diag_key, diag_dict in tqdm(DIAGS.items(), leave=False):
            # `sfc_temp` additionally takes the case-specific constants module.
            extra_args = (const,) if diag_key == "t_sfc" else ()
            _arr = diag_dict["func"](ds, model_key, *extra_args)

            # Build every aggregate lazily, then evaluate them all in a single
            # `dask.compute` call. Calling `.compute()` on each one separately
            # would re-run the shared upstream graph (reading `ds`, computing
            # `_arr`) once per aggregation.
            aggr_keys = list(AGGR_DICT)
            lazy_aggrs = [
                AGGR_DICT[aggr_key](_arr, model_names).sortby(model_names.t)
                for aggr_key in aggr_keys
            ]
            for aggr_key, result in zip(aggr_keys, dask.compute(*lazy_aggrs)):
                ds_out[f"{diag_key}_{aggr_key}"] = result

        # `to_netcdf` fails if the target directory is missing, so create it
        # first. Assumes `mypaths.datadir` is a `pathlib.Path`; its use with
        # `/` suggests so.
        out_dir = mypaths.datadir / model_key
        out_dir.mkdir(parents=True, exist_ok=True)
        xr.Dataset(ds_out).to_netcdf(
            out_dir / f"{THAI_case}_time_series_{model_key}.nc"
        )

In [None]:
# Close the dask client created at the top of the notebook now that all
# output files have been written.
client.close()