In [None]:
import mass_flux_off as experiment_runner
from pathlib import Path
import numpy as np
import proplot as pplt
import xarray as xr
import pandas as pd
from utils.files import OIFSPreprocessor, NEMOPreprocessor

In [None]:
# Time offset handed to both the OIFS and the NEMO preprocessor. It is defined
# once so the two values cannot drift apart.
# NOTE(review): what the -7 h offset represents (e.g. a local-time shift) is
# defined by the preprocessors in utils.files — confirm there.
preprocessor_time_offset = np.timedelta64(-7, "h")

oifs_preprocessor = OIFSPreprocessor(
    experiment_runner.start_date, preprocessor_time_offset
)
nemo_preprocessor = NEMOPreprocessor(
    experiment_runner.start_date, preprocessor_time_offset
)

In [None]:
# All figures produced by this notebook are written here.
plotting_output_directory = Path("plots/mass_flux_off")
# parents=True: also create "plots/" on a fresh checkout. Without it, mkdir
# raises FileNotFoundError when the parent directory does not exist yet.
plotting_output_directory.mkdir(parents=True, exist_ok=True)

In [None]:
# Runs compared in this notebook:
# - MFN0, MFN1 and MFN2 are the naive runs. They are plotted, in this order,
#   as "parallel", "atmosphere-first" and "ocean-first".
# - The Schwarz run's output directory is suffixed with the iteration count
#   taken from the experiment runner.
# - `setup` is the top-level directory holding all run outputs.
naive_exp_ids = [f"MFN{run_index}" for run_index in range(3)]
schwarz_exp_id = "MFS0"
max_schwarz_iters = experiment_runner.max_iters
setup = "PAPA"

In [None]:
# Open the OIFS prognostic-variable output, applying the OIFS preprocessing to
# every file. Files are opened in this order: one per naive run (in
# `naive_exp_ids` order), then the Schwarz run at its final iteration count.
# The starred target collects the naive runs into a list; the last dataset is
# the converged Schwarz run.
*oifs_progvars_naive, oifs_progvar_cvg_swz = [
    xr.open_mfdataset(progvar_path, preprocess=oifs_preprocessor.preprocess)
    for progvar_path in (
        *(f"{setup}/{run_id}/progvar.nc" for run_id in naive_exp_ids),
        f"{setup}/{schwarz_exp_id}_{max_schwarz_iters}/progvar.nc",
    )
]

# Vertical Temperature Profiles at Four Points in Time

In [None]:
# All runs to compare: the three naive runs followed by the converged Schwarz run.
oifs_progvars = [*oifs_progvars_naive, oifs_progvar_cvg_swz]

# Plot styling, index-aligned with `oifs_progvars`.
labels = ["parallel", "atmosphere-first", "ocean-first", "converged SWR"]
colors = ["m", "c", "y", "k"]
linestyles = ["--", ":", "-.", "-"]

# One figure panel per timestamp, left to right.
timestamps = [
    pd.Timestamp("2014-07-02 12:00"),
    pd.Timestamp("2014-07-03 00:00"),
    pd.Timestamp("2014-07-03 12:00"),
    pd.Timestamp("2014-07-04 12:00"),
]

# Positional index of the first vertical level that is plotted; all levels from
# here onward are kept.
# NOTE(review): presumably this restricts the plot to the lowest model levels
# near the surface — confirm against the OIFS level ordering.
first_plotted_level = 45
# Converts temperature to °C for the x-axis label. Assumes `t` is stored in K.
kelvin_to_celsius_offset = 273.15
# Converts pressure to hPa for the y-axis label. Assumes `pressure_f` is in Pa.
pa_per_hpa = 100


def _with_pressure_coordinate(progvar):
    """Return `progvar` with its "nlev" dimension replaced by "air_pressure" [hPa].

    The pressure values come from the first entry of `pressure_f` (presumably
    the first time step) and are reused for every plotted timestamp.
    """
    progvar = progvar.assign_coords(
        air_pressure=("nlev", progvar.pressure_f[0].data / pa_per_hpa)
    )
    return progvar.swap_dims({"nlev": "air_pressure"})


# The coordinate swap does not depend on the timestamp, so it is done once per
# run instead of once per run and panel.
oifs_progvars_plev = [_with_pressure_coordinate(progvar) for progvar in oifs_progvars]

fig, axs = pplt.subplots(nrows=1, ncols=4, width="70em", height="30em")
axs.format(abc="a)")

for ax, timestamp in zip(axs, timestamps):
    for oifs_progvar, label, color, linestyle in zip(
        oifs_progvars_plev, labels, colors, linestyles
    ):
        t_for_plotting = (
            oifs_progvar.t.sel(time=timestamp)[first_plotted_level:]
            - kelvin_to_celsius_offset
        )
        ax.plot(
            t_for_plotting,
            t_for_plotting.air_pressure,
            label=label,
            color=color,
            ls=linestyle,
        )
    ax.format(title=timestamp)

# Shared formatting is applied after plotting, matching the original order, so
# the explicit xlim and y-axis reversal are not affected by later autoscaling.
for ax in axs:
    ax.format(
        xlabel="Temperature [°C]",
        ylabel="Air Pressure [hPa]",
        yreverse=True,
        xlim=[-2, 12],
    )

# Every panel draws the same runs with the same styles, so the first panel's
# handles describe the whole figure. A flat list of line handles is passed
# because proplot may treat a nested list (such as the raw return values of
# `ax.plot`) as one centered legend row per sublist, which would defeat ncols=4.
legend_handles, _ = axs[0].get_legend_handles_labels()
fig.legend(legend_handles, frame=False, ncols=4, loc="b")
fig.savefig(plotting_output_directory / "ce_air_temperature_stratification.pdf")