In [None]:
import pathlib
import itertools
from collections import defaultdict

import numpy as np
import xarray as xr
import pyzome as pzm

import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import dates as mdates
from matplotlib import ticker as mticker

#%matplotlib inline # uncomment if you want plots in the notebook to appear

In [None]:
# from matplotlib import font_manager
# font_manager.findSystemFonts()

# Setup
## Define general globals
These will be used throughout the rest of the notebook, where needed

In [None]:
# Root directory the zonal-mean files are read from. The cells below glob
# recursively for "*zonalmeans.nc" under it and expect model files laid out
# as DATA_DIR/<model>/<experiment>/<init>/...
DATA_DIR = pathlib.Path("/gws/nopw/j04/snapsi/processed")

# Root directory the plots are written to (sub-directories are created
# on demand by the plotting loops).
PLOT_DIR = pathlib.Path("./plots/atlas")

# Overwrite old plot files? If False, this notebook will
# skip any files that have already been made. If True,
# it will overwrite (remake) them.
CLOBBER = True

# Experiment name -> sort rank. Used as the sort key that fixes the
# order of panels in each figure. If you prefer a different order,
# change the numeric values accordingly (lower val = higher panel).
PANEL_ORDER_PRECEDENCE = {
    "free": 0,
    "nudged": 1,
    "control": 2,
    "nudged-full": 3,
    "control-full": 4,
}

# Model name -> modelling centre, used to label plots and to prefix
# the output filenames. Every model found on disk must have an entry,
# since the plotting loops look the model up directly.
CENTRE_TRANSLATOR = {
    "BCC-CSM2-HR": "", # model name already includes the centre name, so the centre is left empty
    "CanESM5": "ECCC",
    "CESM2-CAM6": "NCAR",
    "CNRM-CM61": "Meteo-France",
    "GEM-NEMO": "ECCC",
    "GLOBO": "CNR-ISAC",
    "GRIMs": "SNU",
    "GloSea6": "UKMO",
    "GloSea6-GC32": "KMA",
    "IFS": "ECMWF",
    "NAVGEM": "NRL",
    "SPEAR": "NOAA-GFDL"
}

# Initialisation date -> hemisphere flag (1 = NH, -1 = SH), for
# diagnostics that should only be plotted for the relevant winter
# hemisphere. The flag is passed straight to the atlas functions.
HEMI_TRANSLATOR = {
    "s20180125": 1,
    "s20180208": 1,
    "s20181213": 1,
    "s20190108": 1,
    "s20190829": -1,
    "s20191001": -1,
}

## Define plot globals

In [None]:
# Shared date-axis helpers: major ticks on the 1st, 8th, 15th and 22nd of
# each month, a minor tick on every day, and two-digit YYMMDD tick labels.
XMAJOR_LOCATOR = mdates.DayLocator(bymonthday=[1, 8, 15, 22])
XMINOR_LOCATOR = mdates.DayLocator()
XTICK_FORMATTER = mdates.DateFormatter("%y%m%d")

# One fixed colour per experiment so that panels are comparable across
# figures (based on tab20 colors):
#   free         ~ blue
#   nudged       ~ orange       nudged-full  ~ light orange
#   control      ~ green        control-full ~ light green
EXPERIMENT_COLORS = {
    "free": "#004580",
    "nudged": "#D04A00",
    "nudged-full": "#D18847",
    "control": "#057004",
    "control-full": "#68B258",
}

# Matplotlib defaults applied notebook-wide. The x and y tick settings are
# identical, so they are generated for both axes rather than spelled out.
RC_PARAMS = {
    "font.family": "sans-serif",
    "font.sans-serif": "Ubuntu", # not many fonts available on the jasmin notebook system!
    "font.size": 12,
    **{
        f"{axis}.{setting}": value
        for axis in ("xtick", "ytick")
        for setting, value in (
            ("major.size", 7),
            ("minor.size", 4),
            ("minor.visible", True),
        )
    },
}
mpl.rcParams.update(RC_PARAMS)

## Define the plotting function

In [None]:
def make_plume_plots(
    das: list[xr.DataArray],
    obs: xr.DataArray,
    titles: list[str],
    suptitle: str,
    ylabel: str="",
    ylim: tuple[float,float]=None,
    colors: list[str]=None,
):
    """ Function for making quick & dirty plume plots of the SNAPSI data.

    Each panel shows every ensemble member as a thin line, the ensemble
    mean (over `member_id`) as a thick line in the same color, and the
    observations as a thick black line.

    Parameters
    ----------
    das : list of `xarray.DataArray`s
        The data to plot in each panel. Each DataArray goes to one panel.
        It's expected that every DataArray has the same time axis, and
        a `member_id` dimension.
    obs : `xarray.DataArray`
        The observational data to overplot on top, in each panel.
        It's expected that this data has exactly the same times
        as all the arrays in das. Its first and last times set the
        x-limits of every panel.
    titles : list of str
        The titles for each panel. Should have as many titles as
        the `DataArray`s in das
    suptitle : str
        The overall suptitle of the plot
    ylabel : str, optional
        The label for the y-axis. Defaults to an empty string
    ylim : tuple of two floats, optional
        The ylim that should be applied to each panel. Defaults
        to None, which means the limits are auto-determined, but
        consistent across the panels.
    colors : list of strings, optional
        The hex-string of the colors to plot for each `DataArray` in das.
        If provided, should have as many colors as the `DataArray`s in das.
        Defaults to None, in which case every panel is drawn in red.

    Returns
    -------
    fig :
        The matplotlib figure instance with the data plotted

    Raises
    ------
    ValueError
        If das is empty, or if the number of titles or colors does
        not match the number of DataArrays.
    """
    num_panels = len(das)
    if num_panels == 0:
        raise ValueError("At least one DataArray is required to make a plot")
    if len(titles) != num_panels:
        raise ValueError("Number of titles should equal number of provided DataArrays")

    # Resolve the default first, so that len() is never called on None, and
    # reject a bare string, which would otherwise be indexed per character.
    if colors is None:
        colors = ["red"]*num_panels
    elif isinstance(colors, str) or len(colors) != num_panels:
        raise ValueError("Number of colors should equal number of provided DataArrays")

    # Set up the figure. squeeze=False always returns a 2D array of axes,
    # so the single-panel case needs no special handling.
    fig_width = 8
    fig_height = 3.5*num_panels
    fig_size = (fig_width, fig_height)
    fig, axs = plt.subplots(nrows=num_panels, ncols=1, figsize=fig_size, squeeze=False)
    axs = axs[:, 0]

    # loop over the DataArrays
    lo_ylims, hi_ylims = [], []
    for ax, da, c in zip(axs, das, colors):
        da.plot.line(ax=ax, hue="member_id", alpha=0.25, linewidth=0.5, add_legend=False, color=c)
        da.mean("member_id").plot.line(ax=ax, linewidth=2.0, label="model", color=c)
        obs.plot.line(ax=ax, color="black", linewidth=2.0, label="ERA5")
        ax.minorticks_on()

        # keep track of the y-limits in case we need them later
        # to set consistent ylims
        ylo, yhi = ax.get_ylim()
        lo_ylims.append(ylo)
        hi_ylims.append(yhi)

    # are we setting ylims ourselves? if not, use the widest range
    # that any panel auto-selected
    if ylim is None:
        ylim = (min(lo_ylims), max(hi_ylims))

    # apply axis styling. we do this separate from the loop above
    # in case we need to use information from the data plotted in
    # all of the panels (such as the ylims)
    for i, ax in enumerate(axs):
        # x-axis (only the bottom panel gets an axis label)
        xlabel = "Date [YYMMDD]" if i == num_panels-1 else ""
        ax.set_xlim(obs.time[0], obs.time[-1])
        ax.xaxis.set_major_locator(XMAJOR_LOCATOR)
        ax.xaxis.set_major_formatter(XTICK_FORMATTER)
        ax.xaxis.set_minor_locator(XMINOR_LOCATOR)
        ax.set_xlabel(xlabel, fontsize=14)
        ax.tick_params(axis="x", rotation=0)
        for xtl in ax.get_xticklabels():
            xtl.set_ha("center")

        # y-axis
        ax.set_ylabel(ylabel, fontsize=14)
        ax.set_ylim(ylim)
        ax.yaxis.set_ticks_position("both")
        if ylim[0] < 0 < ylim[1]:
            # mark the zero line when the range straddles it
            ax.axhline(0, color="black", linewidth=0.5, linestyle="-")

        # add a legend on first panel
        if i == 0:
            ax.legend()
        ax.set_title(titles[i], fontsize=16)

    # Place the suptitle nicely, with respect to top axis. The offset is
    # given in inches and converted to figure-fraction units, so the gap
    # stays the same however many panels there are.
    dy = 0.8925 # offset from top ax, to top of suptitle, in inches
    fig.subplots_adjust(bottom=0.05, hspace=0.22)
    l, b, w, h = axs[0].get_position().bounds
    fig.suptitle(suptitle, fontsize=20, fontweight="semibold", y=b+h+(dy/fig_height))

    return fig

## Define the "Atlas" functions
The idea of these functions is that they will take a zonal mean dataset read in by xarray,
and return DataArrays for plume plots. You can define more, but the idea is that they should
pull out a DataArray from a Dataset, and transform it so that it has only (at most)
dimensions for `(member_id, time, level)`. If `level` is kept, the transformed data needs
to be limited to a single level before sending it to the above plotting function, which 
expects only `(member_id, time)`

In [None]:
def vt4575(ds: xr.Dataset, hemi: int, k: int):
    """ 45-75 lat avg of eddy heat flux, as daily means

    k = 0 uses the `vT` variable; k = 1, 2 or 3 uses `vT_k` at that
    zonal wavenumber. hemi is 1 for the NH band or -1 for the SH band.
    Raises ValueError for any other k or hemi.
    """
    # pick the flux variable first, so an unsupported wavenumber is
    # reported before the hemisphere is looked at
    if k != 0 and not (0 < k < 4):
        raise ValueError("can only take wavenums from 1-3")
    heat_flux = ds["vT"] if k == 0 else ds["vT_k"].sel(wavenum_lon=k)

    if hemi not in (1, -1):
        raise ValueError("hemi can only be 1 (NH) or -1 (SH)")
    lat_lo, lat_hi = (45, 75) if hemi == 1 else (-75, -45)

    band_mean = pzm.meridional_mean(heat_flux, lat_lo, lat_hi)
    return band_mean.resample(time="1D").mean("time")


def vt4575tot(ds: xr.Dataset, hemi: int):
    """ 45-75 lat avg of the total eddy heat flux (vt4575 with k fixed to 0) """
    return vt4575(ds, hemi, k=0)
    

def t6090(ds: xr.Dataset, hemi: int):
    """ 60-90 lat avg of temperature, as daily means

    The poleward bound is taken from the dataset's own latitude range
    rather than being fixed at +/-90. hemi is 1 (NH) or -1 (SH);
    anything else raises ValueError.
    """
    if hemi not in (1, -1):
        raise ValueError("hemi can only be 1 (NH) or -1 (SH)")

    grid_lats = ds.lat.values
    if hemi == 1:
        lat_lo, lat_hi = 60, grid_lats.max()
    else:
        lat_lo, lat_hi = grid_lats.min(), -60

    cap_mean = pzm.meridional_mean(ds["T"], lat_lo, lat_hi)
    return cap_mean.resample(time="1D").mean("time")


def u60(ds: xr.Dataset, hemi: int):
    """ zonal mean u interpolated to 60 degrees, as daily means

    hemi is 1 (60N) or -1 (60S); anything else raises ValueError.
    """
    if hemi not in (1, -1):
        raise ValueError("hemi can only be 1 (NH) or -1 (SH)")

    target_lat = 60 if hemi == 1 else -60
    u_at_lat = ds["u"].interp(lat=target_lat)
    return u_at_lat.resample(time="1D").mean("time")


def zamp60(ds: xr.Dataset, hemi: int, k: int):
    """ amplitude of geohgt waves at 60 degrees, as daily means

    Combines `Z_k_real` and `Z_k_imag` into complex coefficients, picks
    zonal wavenumber k (1, 2 or 3), interpolates to 60N (hemi = 1) or
    60S (hemi = -1), and scales the modulus by 2 / ds.nlons.
    Raises ValueError for any other hemi or k.
    """
    # the hemisphere is validated before the wavenumber
    if hemi not in (1, -1):
        raise ValueError("hemi can only be 1 (NH) or -1 (SH)")
    if not 0 < k < 4:
        raise ValueError("can only take wavenums from 1-3")

    target_lat = 60 if hemi == 1 else -60

    coeffs = ds["Z_k_real"] + 1j*ds["Z_k_imag"]
    coeffs_k = coeffs.sel(wavenum_lon=k)
    amplitude = 2*np.absolute(coeffs_k.interp(lat=target_lat))/ds.nlons

    return amplitude.resample(time="1D").mean("time")


def uqbo(ds: xr.Dataset, *args):
    """ zonal winds averaged from -5 to 5 for QBO, as daily means

    Extra positional arguments (such as the hemisphere flag) are
    accepted and ignored.
    """
    equatorial_u = pzm.meridional_mean(ds["u"], -5, 5)
    return equatorial_u.resample(time="1D").mean("time")


def tqbo(ds: xr.Dataset, *args):
    """ temperatures averaged from -5 to 5 for QBO, as daily means

    Extra positional arguments (such as the hemisphere flag) are
    accepted and ignored.
    """
    equatorial_t = pzm.meridional_mean(ds["T"], -5, 5)
    return equatorial_t.resample(time="1D").mean("time")

## Define the job parameters
The top-level key of each nested dictionary is the "batch" name. The dictionary underneath it
specifies which of the above functions should be called, the pressure levels (and, where relevant,
zonal wavenumbers) to call it for, and the strings that go into the eventual plot.

In [None]:
# Jobs whose callback is called as callback(dataset_at_level, hemi).
# For every batch:
#   key        -> batch name; used in the output directory and file names
#   "callback" -> atlas function that produces the DataArray to plot
#   "levels"   -> pressure levels to plot, in hPa
#   "suptitle" -> str.format template; the plotting loop fills in
#                 lev, hemi, centre, model and init
#   "ylabel"   -> y-axis label for every panel
regular_jobs = {
    "U60": {
        "callback": u60,
        "levels": (10, 100),
        "suptitle": "{lev} hPa, 60°{hemi} Zonal Mean U\n{centre} {model} {init}",
        "ylabel": "Zonal Wind [m/s]",
    },
    "T6090": {
        "callback": t6090,
        "levels": (10, 100),
        "suptitle": "{lev} hPa, 60-90°{hemi} Polar Cap T\n{centre} {model} {init}",
        "ylabel": "Temperature [K]",
    },
    "vT4575": {
        "callback": vt4575tot,
        "levels": (50, 100, 300),
        "suptitle": "{lev} hPa, 45-75°{hemi} v'T'\n{centre} {model} {init}",
        "ylabel": "Eddy Heat Flux [K m/s]",
    },
    "UQBO": {
        "callback": uqbo,
        "levels": (10, 30, 50),
        "suptitle": "{lev} hPa, 5°S-5°N QBO U\n{centre} {model} {init}",
        "ylabel": "Zonal Wind [m/s]"
    },
    "TQBO": {
        "callback": tqbo,
        "levels": (50, 70, 100),
        "suptitle": "{lev} hPa, 5°S-5°N QBO T\n{centre} {model} {init}",
        "ylabel": "Temperature [K]"
    },
}

# Jobs whose callback is called as callback(dataset_at_level, hemi, k).
# Same layout as regular_jobs, plus:
#   "wavenums" -> zonal wavenumbers to plot; one figure is made for every
#                 (level, wavenumber) pair, and k is also available to the
#                 suptitle template
wavenum_jobs = {
    "Z60-amp-k": {
        "callback": zamp60,
        "levels": (10, 100, 300),
        "wavenums": (1, 2, 3),
        "suptitle": "{lev} hPa, 60°{hemi} Wave-{k} Amplitude\n{centre} {model} {init}",
        "ylabel": "Amplitude [m]"
    },
    "vT4575-k": {
        "callback": vt4575,
        "levels": (50, 100, 300),
        "wavenums": (1, 2, 3),
        "suptitle": "{lev} hPa, 45-75°{hemi} Wave-{k} v'T'\n{centre} {model} {init}",
        "ylabel": "Eddy Heat Flux [K m/s]"
    },
}

# Begin the work
## Get zonal mean files for obs/models

In [None]:
# Find every zonal mean dataset file underneath DATA_DIR, in sorted order
zonal_mean_paths = sorted(DATA_DIR.glob("**/*zonalmeans.nc"))

# Split off the reanalysis: any path containing "ERA5" is treated as
# observations, and everything else as model (experiment) output
era5_files = [path for path in zonal_mean_paths if "ERA5" in str(path)]
zmd_files = [path for path in zonal_mean_paths if "ERA5" not in str(path)]

In [None]:
# To make "atlas" figures, we want to group the zonal mean files by
# the combination of the model and init date
def keygen(fi):
    """ Return the (model, init) grouping key for a zonal mean file path.

    The key is read from the path relative to DATA_DIR, which is laid out
    as <model>/<experiment>/<init>/..., so it does not depend on how many
    directories deep DATA_DIR itself happens to be.
    """
    rel_parts = fi.relative_to(DATA_DIR).parts
    return (rel_parts[0], rel_parts[2])

# groupby only merges adjacent items, so the files must be sorted by the same key first
grouped = {key: list(group) for key,group in itertools.groupby(sorted(zmd_files, key=keygen), keygen)}

In [None]:
# If you want to see what grouped now contains, try uncommenting/running

# for key,group in grouped.items():
#    print(f"{key} -> {group}")

# Essentially, we get (model, init) -> list of files having the same model & init.
# The key is the tuple, and the group is a generator (or it would be if we didn't 
# put the results into a list; if you don't know what these mean in the context of python, 
# don't worry about it). 

## Read in the reanalysis data
This will be used across all of the atlas plots

In [None]:
# Combine all of the ERA5 zonal mean files into a single dataset; the plotting
# loops below subset it to the time axis of each (model, init) group.
rean_data = xr.open_mfdataset(era5_files)

## Make the plots
Below we loop over the different `(model, init)` combinations, then the regular job batch. 

In [None]:
for (model, init), group in grouped.items():
    centre = CENTRE_TRANSLATOR[model]
    hemi = HEMI_TRANSLATOR[init]
    hs = "N" if hemi == 1 else "S"

    # sort the group of files in panel order. The experiment name is the
    # second path component below DATA_DIR (<model>/<experiment>/<init>/...)
    files_to_read = sorted(group, key=lambda fi: PANEL_ORDER_PRECEDENCE[fi.relative_to(DATA_DIR).parts[1]])
    experiments = [fi.relative_to(DATA_DIR).parts[1] for fi in files_to_read]
    colors = [EXPERIMENT_COLORS[exp] for exp in experiments]

    # Keep hold of the datasets exactly as opened, so that their files can be
    # closed once this (model, init) is finished, even if a plot fails.
    opened = []
    try:
        # read the data, and obtain the reanalysis data for same time period
        for fi in files_to_read:
            opened.append(xr.open_dataset(fi))
        ds_list = opened
        if model == "CESM2-CAM6": # handle the CESM2 noleap dates
            ds_list = [ds.assign_coords({"time": ds.time.convert_calendar("gregorian").time}) for ds in ds_list]
        obs_ds = rean_data.sel(time=ds_list[0].time, method="nearest") # method = nearest needed for SPEAR, which has times offset from typical 6 hourly?

        for batch, job_info in regular_jobs.items():
            get_product = job_info["callback"]
            print(f"Now working on {model} {init} {batch}")
            for lev in job_info["levels"]:
                # Set up output dirs/filenames
                outdir = PLOT_DIR / f"{init}/{batch}/{lev:03d}mb/"
                plot_file = f"{centre}_{model}_{init}_{batch}_{lev:03d}mb.png"
                outfile = outdir / plot_file
                if outfile.exists() and not CLOBBER:
                    print(f"  {outfile} exists; skipping!")
                    continue
                outdir.mkdir(exist_ok=True, parents=True)

                # Fill in the suptitle
                suptitle = job_info["suptitle"].format(lev=lev, hemi=hs, centre=centre, model=model, init=init)

                # Subset data using sel and the callback func. Job levels are
                # given in hPa; plev is presumably in Pa, hence the factor of 100
                da_list = [get_product(ds.sel(plev=lev*100), hemi) for ds in ds_list]
                obs_da = get_product(obs_ds.sel(plev=lev*100), hemi)

                # Make the plot, and close it straight away so that
                # figures do not pile up in memory
                fig = make_plume_plots(da_list, obs_da, experiments, suptitle, ylabel=job_info["ylabel"], colors=colors)
                fig.savefig(outfile, bbox_inches="tight")
                plt.close(fig)
    finally:
        for ds in opened:
            ds.close()

Now loop over the wavenumber job batch, which requires an additional inner loop to enumerate every combination of pressure level and wavenumber.

In [None]:
for (model, init), group in grouped.items():
    centre = CENTRE_TRANSLATOR[model]
    hemi = HEMI_TRANSLATOR[init]
    hs = "N" if hemi == 1 else "S"

    # sort the group of files in panel order. The experiment name is the
    # second path component below DATA_DIR (<model>/<experiment>/<init>/...)
    files_to_read = sorted(group, key=lambda fi: PANEL_ORDER_PRECEDENCE[fi.relative_to(DATA_DIR).parts[1]])
    experiments = [fi.relative_to(DATA_DIR).parts[1] for fi in files_to_read]
    colors = [EXPERIMENT_COLORS[exp] for exp in experiments]

    # Keep hold of the datasets exactly as opened, so that their files can be
    # closed once this (model, init) is finished, even if a plot fails.
    opened = []
    try:
        # read the data, and obtain the reanalysis data for same time period
        for fi in files_to_read:
            opened.append(xr.open_dataset(fi))
        ds_list = opened
        if model == "CESM2-CAM6": # handle the CESM2 noleap dates
            ds_list = [ds.assign_coords({"time": ds.time.convert_calendar("gregorian").time}) for ds in ds_list]
        obs_ds = rean_data.sel(time=ds_list[0].time, method="nearest") # method = nearest needed for SPEAR, which has times offset from typical 6 hourly?

        for batch, job_info in wavenum_jobs.items():
            get_product = job_info["callback"]
            print(f"Now working on {model} {init} {batch}")
            for lev in job_info["levels"]:
                for k in job_info["wavenums"]:
                    # Set up output dirs/filenames
                    outdir = PLOT_DIR / f"{init}/{batch}/{lev:03d}mb/k{k}"
                    plot_file = f"{centre}_{model}_{init}_{batch}{k}_{lev:03d}mb.png"
                    outfile = outdir / plot_file
                    if outfile.exists() and not CLOBBER:
                        print(f"  {outfile} exists; skipping!")
                        continue
                    outdir.mkdir(exist_ok=True, parents=True)

                    # Fill in the suptitle
                    suptitle = job_info["suptitle"].format(lev=lev, hemi=hs, k=k, centre=centre, model=model, init=init)

                    # Subset data using sel and the callback func. Job levels are
                    # given in hPa; plev is presumably in Pa, hence the factor of 100
                    da_list = [get_product(ds.sel(plev=lev*100), hemi, k) for ds in ds_list]
                    obs_da = get_product(obs_ds.sel(plev=lev*100), hemi, k)

                    # Make the plot, and close it straight away so that
                    # figures do not pile up in memory
                    fig = make_plume_plots(da_list, obs_da, experiments, suptitle, ylabel=job_info["ylabel"], colors=colors)
                    fig.savefig(outfile, bbox_inches="tight")
                    plt.close(fig)
    finally:
        for ds in opened:
            ds.close()