# Time series of key diagnostics of the THAI simulations


In [None]:
import warnings

warnings.simplefilter("ignore")  # noqa  # silence every warning category in this notebook

In [None]:
# Standard library
import multiprocessing.popen_spawn_posix

# Data analysis and viz libraries
import dask
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from dask.distributed import Client

# Progress bar
from tqdm.notebook import tqdm

In [None]:
import aeolus.plot as aplt

In [None]:
# Local modules
import mypaths
import names
from calc import (
    dayside_mean,
    get_time_rel_days,
    global_mean,
    meridional_mean,
    nightside_mean,
    spatial_mean,
    terminator_mean,
    time_mean,
    zonal_mean,
)
from commons import MODELS
from load_thai import LOAD_CONF
from model_exocam import calc_alt_exocam
from model_lmdg import calc_alt_lmdg, calc_virtual_temp_lmdg
from plot_func import KW_AUX_TTL, KW_MAIN_TTL, KW_SBPLT_LABEL, figsave

In [None]:
plt.style.use(["paper.mplstyle"])  # apply the project's paper figure style sheet

Optionally start a local `dask` cluster (the cell below is commented out; uncomment it to use one).

In [None]:
# client = Client(processes=True, n_workers=4, threads_per_worker=1)
# client

In [None]:
# Length of the simulations in Earth days; also the x-axis extent of the time-series plots
ndays: int = 610

## Choose case

In [None]:
THAI_case: str = "Hab1"  # THAI case to analyse; names ending in "1" load const_ben1_hab1

In [None]:
# Select the constants module matching the chosen case: cases whose name ends
# in "1" (e.g. Hab1) use const_ben1_hab1, every other case uses const_ben2_hab2.
if not THAI_case.endswith("1"):
    import const_ben2_hab2 as const
else:
    import const_ben1_hab1 as const

In [None]:
def extract_sfc_temp(ds, model_key, case, const):
    """Extract surface temperature from a THAI dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Model output of one THAI simulation.
    model_key : str
        Model name, e.g. "ExoCAM", "LMDG", "ROCKE3D" or "UM".
    case : str
        THAI case name (unused; kept so all extractors share one signature).
    const
        Constants module; ``const.t_melt`` is used to convert ROCKE3D output.

    Returns
    -------
    xarray.DataArray
        Surface temperature in K (ROCKE3D is converted from degC; the other
        models are assumed to already be in K).
    """
    model_names = getattr(names, model_key.lower())
    out = ds[model_names.t_sfc]
    if model_key == "ROCKE3D":
        # Convert from degC to K. Deliberately NOT ``+=``: ``ds[...]`` shares its
        # underlying variable with ``ds``, so an in-place add would also modify
        # the dataset passed in by the caller.
        out = out + const.t_melt
    return out


def extract_open_ocean_frac(ds, model_key, case, const):
    """Extract open (ice-free) ocean fraction from a THAI dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Model output of one THAI simulation.
    model_key : str
        Model name: "ExoCAM", "LMDG", "ROCKE3D" or "UM".
    case : str
        THAI case name (unused; kept so all extractors share one signature).
    const
        Constants module (unused; kept so all extractors share one signature).

    Returns
    -------
    xarray.DataArray
        Open ocean fraction.

    Raises
    ------
    ValueError
        If ``model_key`` has no extraction rule here.
    """
    model_names = getattr(names, model_key.lower())
    if model_key == "ExoCAM":
        # The ExoCAM field is complemented (presumably a sea-ice fraction).
        out = 1 - ds[model_names.ocean_frac]
    elif model_key == "LMDG":
        # NOTE(review): the LMDG field is complemented on a 0-1000 scale and
        # rescaled to a fraction; confirm the units of the raw LMDG variable.
        out = (1000 - ds[model_names.ocean_frac]) / 1000
    elif model_key in ("ROCKE3D", "UM"):
        # These models' fields are used as-is.
        out = ds[model_names.ocean_frac]
    else:
        # Previously fell through and failed with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported model for ocean fraction: {model_key!r}")
    return out


def extract_cwp(ds, model_key, case, const):
    """Extract cloud water path from a THAI dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Model output of one THAI simulation.
    model_key : str
        Model name: "ExoCAM", "LMDG", "ROCKE3D" or "UM".
    case : str
        THAI case name (unused; kept so all extractors share one signature).
    const
        Constants module (unused; kept so all extractors share one signature).

    Returns
    -------
    xarray.DataArray
        Total (liquid + ice) cloud water path, intended to be in kg m-2.

    Raises
    ------
    ValueError
        If ``model_key`` has no extraction rule here.
    """
    model_names = getattr(names, model_key.lower())
    if model_key in ("ExoCAM", "ROCKE3D"):
        # Liquid and ice water paths are output in [g m-2]; convert to [kg m-2].
        out = (ds[model_names.lwp] + ds[model_names.iwp]) / 1000
    elif model_key == "LMDG":
        # LMDG provides the total cloud water path as a single variable.
        out = ds[model_names.cwp]
    elif model_key == "UM":
        # Summed without unit conversion (presumably already in kg m-2).
        out = ds[model_names.lwp] + ds[model_names.iwp]
    else:
        # Previously fell through and failed with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported model for cloud water path: {model_key!r}")
    return out


def max_cloud_mmr_altitude(ds, model_key, case, const):
    """Calculate the altitude of the maximum in cloud MMR from a THAI dataset.

    For every column, find the vertical level where the cloud mass mixing
    ratio (MMR) peaks and return the altitude of that level. Cloud MMR is
    ice + liquid for ExoCAM, ROCKE3D and UM, and ice only for LMDG.

    Parameters
    ----------
    ds : xarray.Dataset
        Model output of one THAI simulation.
    model_key : str
        Model name: "ExoCAM", "LMDG", "ROCKE3D" or "UM".
    case : str
        THAI case name, passed to the ExoCAM/LMDG altitude calculations.
    const
        Constants module providing ``mw_ratio``, ``rgas``, ``rvapor`` and
        ``gplanet``.

    Returns
    -------
    xarray.DataArray
        Altitude of the cloud MMR maximum [km].

    Raises
    ------
    ValueError
        If ``model_key`` has no rule for locating the maximum.
    """
    model_names = getattr(names, model_key.lower())
    # Planetary constants used to reconstruct altitude for ExoCAM and LMDG
    _kw_const = dict(
        mw_ratio=const.mw_ratio,
        dry_air_gas_constant=const.rgas,
        condens_gas_constant=const.rvapor,
        gravity=const.gplanet,
    )
    if model_key == "UM":
        # The UM vertical coordinate is used directly as altitude (presumably
        # in m), so the coordinate value at the maximum is the answer.
        cld_mmr = ds[model_names.cld_ice_mf] + ds[model_names.cld_liq_mf]
        out = cld_mmr.idxmax(dim=model_names.z)
    else:
        # Other models: pick an altitude field and the vertical dimension to
        # search along, then select the altitude at the MMR peak.
        if model_key == "ExoCAM":
            alt = calc_alt_exocam(ds, case=case, **_kw_const)
            cld_mmr = ds[model_names.cld_ice_mf] + ds[model_names.cld_liq_mf]
            lev_dim = model_names.lev
        elif model_key == "LMDG":
            alt = calc_alt_lmdg(ds, case=case, **_kw_const)
            cld_mmr = ds[model_names.cld_ice_mf]
            lev_dim = model_names.z
        elif model_key == "ROCKE3D":
            alt = ds[model_names.z]
            cld_mmr = ds[model_names.cld_ice_mf] + ds[model_names.cld_liq_mf]
            lev_dim = model_names.lev
        else:
            # Previously fell through and failed with an UnboundLocalError.
            raise ValueError(f"Unsupported model for cloud MMR altitude: {model_key!r}")
        # The peak indices are computed eagerly so ``isel`` gets concrete integers.
        peak_idx = cld_mmr.argmax(dim=lev_dim).compute()
        out = alt.isel({lev_dim: peak_idx})
    # Convert to km; a plain division avoids mutating any array shared with ``ds``.
    out = out / 1000
    return out

In [None]:
# Spatial aggregation functions, keyed by a one-letter aggregation code
AGGR_DICT = {
    "g": global_mean,
    "d": dayside_mean,
    "n": nightside_mean,
    "t": terminator_mean,
}

# Panel titles for each aggregation code in AGGR_DICT
AGGR_TITLES = dict(
    g="Global",
    d="Day side",
    n="Night side",
    t="Terminators",
)

In [None]:
DATA_DICT = {}
for model_key in tqdm(LOAD_CONF.keys()):
    model_names = getattr(names, model_key.lower())
    DATA_DICT[model_key] = {}
    with LOAD_CONF[model_key]["loader"](THAI_case) as ds:
        DATA_DICT[model_key]["time"] = get_time_rel_days(ds[model_names.t])
        for var_key, var_dict in tqdm(VAR_PLOT.items(), leave=False):
            vrbl = var_dict["func"](ds, model_key, THAI_case, const).isel(
                **{model_names.y: slice(1, -1)}
            )
            DATA_DICT[model_key][var_key] = {}
            for aggr_key, aggr_func in tqdm(AGGR_DICT.items(), leave=False):
                DATA_DICT[model_key][var_key][aggr_key] = (
                    aggr_func(vrbl, model_names).sortby(model_names.t).compute()
                )

In [None]:
VAR_PLOT = {
    "t_sfc": {
        "func": extract_sfc_temp,
        "title": "Surface temperature [$K$]",
        "lim": dict(g=[230, 255], d=[258, 272], n=[200, 240]),
    },
    "ocean_frac": {
        "func": extract_open_ocean_frac,
        "title": "Ice-frea ocean area fraction",
        "lim": dict(g=[0.18, 0.25]),
    },
    "cwp": {
        "func": extract_cwp,
        "title": "Cloud water path [$kg$ $m^{-2}$]",
        "lim": dict(g=[0.02, 0.18], d=[0.04, 0.30], n=[0, 0.12], t=[0, 0.18]),
    },
    "max_cld_mmr_alr": {
        "func": max_cloud_mmr_altitude,
        "title": "Maximum cloud MMR altitude [$km$]",
        "lim": dict(g=[0, 12], d=[0, 12], n=[0, 12], t=[0, 12]),
    },
}

In [None]:
ncols = len(AGGR_DICT)
# Row 0: surface temperature (first ncols-1 panels) + ocean fraction (last panel);
# row 1: cloud water path; row 2: altitude of the cloud MMR maximum.
nrows = 3


def _plot_all_models(ax, var_key, aggr_key):
    """Plot one aggregated diagnostic for every model on ``ax`` and set its y-limits."""
    for model_key in LOAD_CONF.keys():
        model_dict = MODELS[model_key.lower()]
        # Sample index -> Earth days, assuming 4 output samples per day (TODO confirm)
        time = np.arange(DATA_DICT[model_key]["time"].shape[0]) / 4
        data = DATA_DICT[model_key][var_key][aggr_key]
        ax.plot(time, data, color=model_dict["color"], label=model_dict["title"])
    ax.set_ylim(VAR_PLOT[var_key]["lim"][aggr_key])


fig, axs = plt.subplots(
    ncols=ncols, nrows=nrows, figsize=(7 * ncols, 4 * nrows), constrained_layout=False
)
iletters = aplt.subplot_label_generator()

# Surface temperature (one panel per aggregation, except the last column)
for ax, aggr_key in zip(axs[0, :-1], AGGR_DICT.keys()):
    _plot_all_models(ax, "t_sfc", aggr_key)
    # Axes.is_first_col() was removed in Matplotlib 3.6; query the SubplotSpec
    if ax.get_subplotspec().is_first_col():
        ax.set_ylabel(VAR_PLOT["t_sfc"]["title"])
    ax.set_title(AGGR_TITLES[aggr_key], **KW_AUX_TTL)

# Ocean fraction (global only), in the top-right panel
ax = axs[0, -1]
_plot_all_models(ax, "ocean_frac", "g")
ax.set_title(VAR_PLOT["ocean_frac"]["title"], **KW_MAIN_TTL)
ax.set_title(AGGR_TITLES["g"], **KW_AUX_TTL)

# The rest: CWP and cloud height, one row each
for ax_row, var_key in zip(axs[1:, :], ["cwp", "max_cld_mmr_alr"]):
    for ax, aggr_key in zip(ax_row, AGGR_DICT.keys()):
        _plot_all_models(ax, var_key, aggr_key)
        if ax.get_subplotspec().is_first_col():
            ax.set_ylabel(VAR_PLOT[var_key]["title"])
        ax.set_title(AGGR_TITLES[aggr_key], **KW_AUX_TTL)

# Common labels: panel letters and a time axis in days labelled in orbits
xticks = np.linspace(0, ndays, 11)
for ax in axs.flat:
    ax.set_title(f"({next(iletters)})", **KW_SBPLT_LABEL)
    ax.set_xlim(0, ndays)
    ax.set_xticks(xticks)
    ax.set_xticklabels(np.round(xticks / const.period).astype(int))
    # Axes.is_last_row() was removed in Matplotlib 3.6; query the SubplotSpec
    if ax.get_subplotspec().is_last_row():
        ax.set_xlabel("Orbits")

axs[0, 0].legend(
    title=THAI_case, ncol=len(MODELS), loc="lower left", bbox_to_anchor=(-0.2, 1.1)
)
fig.align_ylabels(axs[:, 0])