In [None]:
import xarray as xr
import numpy as np
from utils.files import OIFSPreprocessor, NEMOPreprocessor, OASISPreprocessor
import convergence_checker as cc
import control_experiment_1 as ce1
import user_context as context
import proplot as pplt

In [None]:
# Experiment parameters shared with the control experiment module.
max_iters = ce1.max_iters
start_date = ce1.start_date
exp_id = "C1SP"

# Figures from this notebook go to their own subdirectory. Create it together
# with any missing parent directories so the later savefig calls cannot fail
# on a fresh setup where `context.plotting_dir` does not exist yet.
plotting_output_dir = context.plotting_dir / "convergence_criteria"
plotting_output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Iterate outputs live in `{exp_id}_{n}` directories. When none exist yet, the
# experiment is run and its un-suffixed `{exp_id}` output directory is renamed
# to the reference slot `{exp_id}_{max_iters + 1}`.
found_output_dirs = list(context.output_dir.glob(f"{exp_id}_*"))
schwarz_dir_ref = context.output_dir / f"{exp_id}_{max_iters + 1}"
if not found_output_dirs:
    ce1.run_parallel_schwarz_without_cleanup()
    (context.output_dir / exp_id).rename(schwarz_dir_ref)

In [None]:
# Check every stored iterate against the reference run. Directory
# `{exp_id}_{n}` is reported as iteration n - 1. The loop variable is named
# `dir_index` rather than `iter` so the builtin `iter` is not shadowed for the
# rest of the notebook.
conv_checker = cc.ConvergenceChecker()
locals_final = []
amplitudes_final = []
for dir_index in range(2, max_iters + 1):
    schwarz_dir_iter = context.output_dir / f"{exp_id}_{dir_index}"
    local, amplitude = conv_checker.check_convergence(schwarz_dir_iter, schwarz_dir_ref)
    locals_final.append(local)
    amplitudes_final.append(amplitude)
    print(f"Iter {dir_index - 1}: {local=}, {amplitude=}")

In [None]:
# Check every stored iterate against its successor. For the last one the
# successor is `{exp_id}_{max_iters + 1}`, i.e. the reference directory.
# The loop variable is `dir_index` (not `iter`) to avoid shadowing the builtin.
conv_checker = cc.ConvergenceChecker()
locals_subs = []
amplitudes_subs = []
for dir_index in range(2, max_iters + 1):
    schwarz_dir_iter = context.output_dir / f"{exp_id}_{dir_index}"
    schwarz_dir_next_iter = context.output_dir / f"{exp_id}_{dir_index + 1}"
    local, amplitude = conv_checker.check_convergence(
        schwarz_dir_iter, schwarz_dir_next_iter
    )
    locals_subs.append(local)
    amplitudes_subs.append(amplitude)
    print(f"Iter {dir_index - 1}: {local=}, {amplitude=}")

In [None]:
# Spot check on the NEMO side: largest absolute `sosstsst` difference
# (presumably sea surface temperature) between two stored iterates.
nemo_preproc = NEMOPreprocessor(start_date)

# NOTE(review): despite the name, this is seven directories before
# `final_schwarz_dir`, not one — confirm which iterate is intended.
penultimate_schwarz_dir = context.output_dir / f"{exp_id}_{max_iters - 7}"
final_schwarz_dir = context.output_dir / f"{exp_id}_{max_iters}"

nemo_penultimate_file = next(penultimate_schwarz_dir.glob("*_grid_T.nc"))
nemo_final_file = next(final_schwarz_dir.glob("*_grid_T.nc"))

penultimate_output, final_output = (
    xr.open_mfdataset(nemo_penultimate_file, preprocess=nemo_preproc.preprocess),
    xr.open_mfdataset(nemo_final_file, preprocess=nemo_preproc.preprocess),
)

np.max(np.abs(final_output.sosstsst - penultimate_output.sosstsst)).load()

In [None]:
# Spot check on the OIFS side: largest absolute difference of variable `t`
# (presumably temperature) in `progvar.nc` between two stored iterates.
oifs_preproc = OIFSPreprocessor(start_date)

# NOTE(review): this uses `max_iters - 8`, while the NEMO cell uses
# `max_iters - 7` for its "penultimate" iterate — confirm the intended offset.
penultimate_schwarz_dir = context.output_dir / f"{exp_id}_{max_iters - 8}"
# Defined here as well so this cell does not rely on the NEMO cell having run.
final_schwarz_dir = context.output_dir / f"{exp_id}_{max_iters}"
oifs_penultimate_file = penultimate_schwarz_dir / "progvar.nc"
oifs_final_file = final_schwarz_dir / "progvar.nc"

penultimate_output = xr.open_mfdataset(
    oifs_penultimate_file, preprocess=oifs_preproc.preprocess
)
final_output = xr.open_mfdataset(oifs_final_file, preprocess=oifs_preproc.preprocess)

np.max(np.abs(final_output.t - penultimate_output.t)).load()

In [None]:
# Float32 resolution (distance to the next representable value) at a magnitude
# of 280, for comparison with the differences above — presumably 280 stands for
# a typical temperature in K.
np.spacing(np.array(280, dtype=np.float32))

In [None]:
def create_conv_criterion_plot(conv_variable: str, conv_variable_name: str):
    """Plot the convergence of one OASIS coupling field over the Schwarz iterates.

    For every stored iterate (directories ``{exp_id}_2`` to
    ``{exp_id}_{max_iters}``, plotted as iterations ``1`` to ``max_iters - 1``)
    two maximum absolute differences are computed:

    * against the reference run in ``schwarz_dir_ref`` (labelled iteration
      ``max_iters``), and
    * against the next stored iterate.

    Both are drawn on a log scale together with the threshold
    ``1e-3 * (max(c_ref) - min(c_ref))`` and a shaded band between
    ``1e-3 * min|c_ref|`` and ``1e-3 * max|c_ref|``. The figure is saved to
    ``plotting_output_dir / f"convergence_{conv_variable}.pdf"``.

    Uses the notebook globals ``max_iters``, ``exp_id``, ``context``,
    ``schwarz_dir_ref`` and ``plotting_output_dir``.

    Args:
        conv_variable: Name of the coupling field. Also the stem of its
            ``.nc`` file inside each Schwarz output directory.
        conv_variable_name: Human-readable name used in the plot title.

    Raises:
        FileNotFoundError: If the coupling file is missing from one of the
            output directories.
    """
    preprocessor = OASISPreprocessor()
    opened_datasets = []

    def open_coupling_field(schwarz_dir):
        # Open `conv_variable` from one Schwarz output directory and remember
        # the dataset so it is closed once all plot data has been computed.
        coupling_file = schwarz_dir / f"{conv_variable}.nc"
        if not coupling_file.exists():
            raise FileNotFoundError(f"Coupling file not found: {coupling_file}")
        dataset = xr.open_mfdataset(coupling_file, preprocess=preprocessor.preprocess)
        opened_datasets.append(dataset)
        return dataset[conv_variable]

    try:
        reference = open_coupling_field(schwarz_dir_ref)
        # Reduce the lazily loaded reference to plain floats once, instead of
        # handing matplotlib lists of unevaluated 0-d arrays.
        abs_reference = np.abs(reference)
        amplitude = float(reference.max() - reference.min())
        max_local_threshold = 1e-3 * float(abs_reference.max())
        min_local_threshold = 1e-3 * float(abs_reference.min())
        amplitude_threshold = 1e-3 * abs(amplitude)

        values_wrt_ref = []
        values_wrt_next = []
        iterate = open_coupling_field(context.output_dir / f"{exp_id}_2")
        for dir_index in range(2, max_iters + 1):
            # The "next" iterate of this step becomes the current iterate of
            # the following step, so every file is opened only once.
            next_iterate = open_coupling_field(
                context.output_dir / f"{exp_id}_{dir_index + 1}"
            )
            values_wrt_ref.append(float(np.abs(reference - iterate).max()))
            values_wrt_next.append(float(np.abs(next_iterate - iterate).max()))
            iterate = next_iterate
    finally:
        for dataset in opened_datasets:
            dataset.close()

    fig, ax = pplt.subplots(height="40em", width="70em")
    # Directory `{exp_id}_{n}` is plotted as iteration n - 1.
    xvalues = np.arange(1, max_iters)
    min_threshold_line = np.full(xvalues.shape, min_local_threshold)
    max_threshold_line = np.full(xvalues.shape, max_local_threshold)
    amplitude_threshold_line = np.full(xvalues.shape, amplitude_threshold)

    ax.area(
        xvalues,
        min_threshold_line,
        max_threshold_line,
        color="green",
        alpha=0.3,
    )
    ax.semilogy(
        xvalues,
        values_wrt_ref,
        label=rf"$|| c^{{{max_iters}}} - c^{{k}} ||_\infty$",
        marker=".",
        ls="none",
        color="blue9",
    )
    ax.semilogy(
        xvalues,
        values_wrt_next,
        label=r"$|| c^{{k+1}} - c^{{k}} ||_\infty$",
        marker="1",
        ls="none",
        color="orange9",
    )
    ax.semilogy(
        xvalues,
        amplitude_threshold_line,
        label=rf"$10^{{-3}} \times A(c^{{{max_iters}}})$",
        color="k",
        ls="--",
    )
    # The band is scaled by 1e-3 like the amplitude threshold; say so in the label.
    ax.semilogy(
        xvalues,
        min_threshold_line,
        color="green",
        label=rf"$10^{{-3}} \times$ Value Range of $| c^{{{max_iters}}} |$",
    )
    ax.semilogy(xvalues, max_threshold_line, color="green")

    ax.format(
        yformatter="sci",
        xlabel="Iteration",
        ylabel="Maximum Absolute Difference",
        xminorticks=[0, *xvalues, max_iters],
        title=f"Convergence of {conv_variable_name}",
    )
    ax.legend(loc="ll", ncol=1, framealpha=1)
    fig.savefig(plotting_output_dir / f"convergence_{conv_variable}.pdf")

In [None]:
# OASIS coupling fields to plot, mapped to the titles used in their figures.
# NOTE(review): "OTotSnow" (Total Snow) is excluded; the reason is not recorded here.
coupling_vars_with_names = dict(
    O_OTaux1="Zonal Wind Stress",
    O_OTauy1="Meridional Wind Stress",
    O_QsrMix="Solar Heat Flux",
    O_QnsMix="Nonsolar Heat Flux",
    OTotEvap="Total Evaporation",
    OTotRain="Total Rain",
    # OTotSnow="Total Snow",
    A_SST="Sea Surface Temperature",
)
for variable, variable_name in coupling_vars_with_names.items():
    create_conv_criterion_plot(
        conv_variable=variable, conv_variable_name=variable_name
    )