This script plots 2-D spatial maps of the given variables: the multi-year mean as filled contours, with the standard deviation overlaid as dashed contour lines.

In [1]:
# import functions
# OS interaction and time
import os
import sys
import cftime
import datetime
import time
import glob
import dask
import dask.bag as db
import calendar
import importlib

# math and data
import numpy as np
import netCDF4 as nc
import xarray as xr
import scipy as sp
import scipy.linalg
from scipy.signal import detrend
import pandas as pd
import pickle as pickle
from sklearn import linear_model
import matplotlib.patches as mpatches
from shapely.geometry.polygon import LinearRing
import statsmodels.stats.multitest as multitest

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import ticker
import matplotlib.colors as mcolors
from matplotlib.gridspec import GridSpec
import matplotlib.image as mpimg
from matplotlib.colors import TwoSlopeNorm

from matplotlib.ticker import FormatStrFormatter
from mpl_toolkits.axes_grid1.axes_divider import HBoxDivider
import mpl_toolkits.axes_grid1.axes_size as Size
from mpl_toolkits.axes_grid1 import make_axes_locatable

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.util import add_cyclic_point

# random
from IPython.display import display
from IPython.display import HTML
import IPython.core.display as di # Example: di.display_html('<h3>%s:</h3>' % str, raw=True)

In [2]:
# ---- file-system locations ----------------------------------------------
# subset ERA5 data
my_era5_path = '/glade/u/home/zcleveland/scratch/ERA5/'
# miscellaneous supporting data
misc_data_path = '/glade/u/home/zcleveland/scratch/misc_data/'
# destination directory for generated figures
plot_out_path = '/glade/u/home/zcleveland/NAM_soil-moisture/ERA5_analysis/plots/'
# directory holding the helper modules (dicts, lists, functions)
scripts_main_path = '/glade/u/home/zcleveland/NAM_soil-moisture/scripts_main/'

In [3]:
# Make the helper modules importable. When this cell is re-run in a live
# kernel, reload any helper module already imported so on-disk edits are used.
if scripts_main_path not in sys.path:
    sys.path.insert(0, scripts_main_path)
for _module_name in ('get_var_data', 'my_functions', 'my_dictionaries', 'order_years'):
    if _module_name in sys.modules:
        importlib.reload(sys.modules[_module_name])
del _module_name

# ---- helper functions -----------------------------------------------------
from get_var_data import get_var_data, get_var_files, open_var_data, subset_var_data, time_to_year_month_avg
from my_functions import month_num_to_name, ensure_var_list
from order_years import *  # order_years(var, months, **kwargs)

# ---- shared lists and dictionaries -----------------------------------------
import my_dictionaries

# lists
# instantaneous surface variables
sfc_instan_list = my_dictionaries.sfc_instan_list
# accumulated surface variables
sfc_accumu_list = my_dictionaries.sfc_accumu_list
# pressure-level variables
pl_var_list = my_dictionaries.pl_var_list
# invariant variables
invar_var_list = my_dictionaries.invar_var_list
# NAM-based variables
NAM_var_list = my_dictionaries.NAM_var_list
# region IDs used for regional averages
region_avg_list = my_dictionaries.region_avg_list
# flux variables whose sign is flipped (e.g. sensible heat -> positive up)
flux_var_list = my_dictionaries.flux_var_list
# miscellaneous variables
misc_var_list = my_dictionaries.misc_var_list

# dictionaries
# variable key -> long name
var_dict = my_dictionaries.var_dict
# variable key -> units
var_units = my_dictionaries.var_units
# region ID -> region name
region_avg_dict = my_dictionaries.region_avg_dict
# region ID -> coordinates
region_avg_coords = my_dictionaries.region_avg_coords
# region ID -> plotting colour
region_colors_dict = my_dictionaries.region_colors_dict

In [4]:
def make_plot_title(var, var_month_names, var_region, **kwargs):
    """Build the two-line figure title for a spatial plot.

    The first line is the long name of ``var`` taken from ``var_dict``. The
    second line joins the optional level, the month label and the region with
    single spaces. Empty or ``None`` parts are omitted.

    Parameters
    ----------
    var : str
        Key into ``var_dict``. A KeyError is raised if it is missing.
    var_month_names : object
        Month label (as produced by ``month_num_to_name``); rendered with ``str``.
    var_region : object
        Region identifier; rendered with ``str``.
    **kwargs
        ``var_level`` (optional): level shown before the month label.

    Returns
    -------
    str
        The title text.
    """
    var_level = kwargs.get('var_level', '')

    # Skip missing parts. Filtering None as well as '' keeps an explicit
    # var_level=None from appearing as the literal word "None".
    parts = [
        str(part)
        for part in (var_level, var_month_names, var_region)
        if part is not None and str(part) != ''
    ]

    return f"{var_dict[var]}\n{' '.join(parts)}"

In [5]:
def get_out_fn(var, var_month_names, var_region, **kwargs):
    """Build the core name and the .png filename for a spatial plot.

    The core name joins ``var``, the optional ``var_level``, the month label and
    the region with underscores. Empty or ``None`` parts are dropped.

    Parameters
    ----------
    var : object
        Variable key.
    var_month_names : object
        Month label; rendered with ``str``.
    var_region : object
        Region identifier; rendered with ``str``.
    **kwargs
        ``var_level`` (optional): level to include after ``var``.

    Returns
    -------
    tuple of (str, str)
        ``(fn_core, out_fn)``, where ``out_fn`` is ``'spatial_<fn_core>.png'``.
    """
    var_level = kwargs.get('var_level', '')

    # Drop both '' and None so var_level=None does not leak "None" into filenames.
    parts = (var, var_level, var_month_names, var_region)
    fn_core = '_'.join(str(part) for part in parts if part is not None and str(part) != '')

    out_fn = f'spatial_{fn_core}.png'

    return fn_core, out_fn

In [6]:
def get_out_fp(out_fn, var_region, **kwargs):
    """Return the full path under ``plot_out_path`` where ``out_fn`` is saved.

    Files are sorted into sub-folders by region:
    'global' goes to ``spatial/global``, 'dsw' goes to ``spatial/dsw``, and any
    region ID found in ``region_avg_list`` goes to ``spatial/regions``.
    Any other region returns ``None``.

    ``**kwargs`` is accepted for interface compatibility and is unused.
    """
    if var_region == 'global':
        sub_dir = 'spatial/global'
    elif var_region == 'dsw':
        sub_dir = 'spatial/dsw'
    elif var_region in region_avg_list:
        sub_dir = 'spatial/regions'
    else:
        # no known destination for this region
        return None

    return os.path.join(plot_out_path, sub_dir, out_fn)

In [7]:
def main_spatial(var, var_months=None, var_region='dsw', **kwargs):
    """Plot (and optionally save) the multi-year mean/std spatial map of ``var``.

    When the target .png already exists and ``overwrite_flag`` is False:
    - with ``save_png`` True, the function prints a message and returns;
    - otherwise the existing image is displayed instead of recomputing it.

    Parameters
    ----------
    var : str
        Variable key.
    var_months : list of int, optional
        Months to include. Defaults to ``[1, ..., 12]``.
    var_region : str
        Region passed to ``get_var_data`` and used to choose the output folder.
    **kwargs
        ``var_years`` (default 1980-2019), ``save_png`` (default False),
        ``overwrite_flag`` (default False), ``var_level``, ``cmap``.
        All kwargs are also forwarded to the helper functions.
    """
    # Build a fresh list per call instead of sharing a mutable default.
    if var_months is None:
        var_months = list(range(1, 13))
    var_years = kwargs.get('var_years', list(range(1980, 2020)))
    save_png = kwargs.get('save_png', False)
    overwrite_flag = kwargs.get('overwrite_flag', False)

    # get var month names
    var_month_names = month_num_to_name(var, var_months, **kwargs)

    # get output file name and path
    _, out_fn = get_out_fn(var, var_month_names, var_region, **kwargs)
    out_fp = get_out_fp(out_fn, var_region, **kwargs)

    # get_out_fp returns None for unrecognized regions. os.path.exists(None)
    # would raise TypeError, so treat that case as "no output location".
    if out_fp is None and save_png:
        print(f'No output directory defined for region {var_region!r}; the plot will not be saved.')

    # check existence of file
    png_file = None
    if out_fp is not None and os.path.exists(out_fp):
        print(f'File already exists: {out_fn}')
        if not overwrite_flag:
            if save_png:
                print('overwrite_flag is False. Cannot save. Set to True to overwrite . . .')
                return
            print('overwrite_flag is False. Showing existing plot . . .')
            png_file = mpimg.imread(out_fp)
        else:
            print('overwrite_flag is True. Overwriting . . .')

    # png_file is a numpy array when loaded. Arrays have no unambiguous truth
    # value, so compare against None explicitly.
    if png_file is not None:
        show_plot(png_file)
    else:
        # get var data
        var_data = get_var_data(var, var_region, var_months, dim_means=[], level=kwargs.get('var_level', None), **kwargs)

        # convert data to year, month with monthly averages
        var_data = time_to_year_month_avg(var_data)

        # subset data by year and month
        var_data_sub = var_data.sel(year=var_years, month=var_months)

        # mean and std deviation across all selected years and months
        var_data_mean = var_data_sub.mean(['year', 'month'], skipna=True)
        var_data_std = var_data_sub.std(['year', 'month'], skipna=True)

        # NOTE(review): the full var_data (all years), not var_data_sub, sets the
        # colour range in plot_spatial. Confirm this is intended.
        fig, _ax = plot_spatial(var, var_month_names, var_region, var_data, var_data_mean, var_data_std, **kwargs)

        # save if specified (and a destination exists)
        if save_png and out_fp is not None:
            os.makedirs(os.path.dirname(out_fp), exist_ok=True)
            fig.savefig(out_fp, bbox_inches='tight', dpi=300)

    plt.show()
    plt.close()

In [8]:
def show_plot(png_file, **kwargs):
    """Display an already-rendered image on a new figure with axes hidden.

    Parameters
    ----------
    png_file : array-like
        Image data, e.g. as returned by ``mpimg.imread``.
    **kwargs
        Accepted for interface compatibility; unused.

    Returns
    -------
    tuple
        ``(fig, ax)`` from ``plt.subplots``.
    """
    fig, ax = plt.subplots(figsize=(12, 10))
    # pyplot's imshow draws on the just-created (current) axes
    plt.imshow(png_file)
    ax.axis('off')
    fig.tight_layout()
    return fig, ax

In [9]:
def _linear_levels(data, num):
    """Return ``num`` evenly spaced levels from the NaN-ignoring min to max of ``data``, or None if that range is non-finite or zero-width."""
    lo = float(np.nanmin(data))
    hi = float(np.nanmax(data))
    # Contour levels must be strictly increasing. A constant field (e.g. a std
    # of a single year, which is all zeros) or an all-NaN field cannot give that.
    if not (np.isfinite(lo) and np.isfinite(hi)) or hi <= lo:
        return None
    return np.linspace(lo, hi, num)


def plot_spatial(var, var_month_names, var_region, var_data, var_data_mean, var_data_std, **kwargs):
    """Draw a filled-contour map of the mean with dashed std-deviation contours.

    Parameters
    ----------
    var : str
        Variable key. Used for ``var_units`` (colour-bar label) and the title.
    var_month_names, var_region : object
        Passed through to ``make_plot_title``.
    var_data : array-like
        Data whose NaN-ignoring min/max define the 50 colour levels.
    var_data_mean, var_data_std : xarray.DataArray
        2-D fields with ``longitude``/``latitude`` coordinates.
    **kwargs
        ``cmap`` (default 'turbo'). All kwargs are also forwarded to ``make_plot_title``.

    Returns
    -------
    tuple
        ``(fig, ax)``. ``ax`` is a cartopy PlateCarree GeoAxes.

    Notes
    -----
    If a field is constant or all-NaN, matplotlib chooses the colour levels
    automatically and the std contour lines are skipped instead of raising.
    """
    # set figure characteristics
    projection = ccrs.PlateCarree()
    fig, ax = plt.subplots(figsize=(12, 10), subplot_kw=dict(projection=projection))

    # optional cmap kwarg. Default is turbo
    cmap = kwargs.get('cmap', 'turbo')

    # shaded contour plot of the mean. Levels span the range of var_data.
    # The lon/lat coordinates are PlateCarree, so pass that as the data transform.
    cf_kwargs = dict(cmap=cmap, extend='both', transform=projection)
    cf_levels = _linear_levels(var_data, 50)
    if cf_levels is not None:
        cf_kwargs.update(levels=cf_levels, norm=plt.Normalize(cf_levels[0], cf_levels[-1]))
    cf = ax.contourf(var_data_mean.longitude, var_data_mean.latitude, var_data_mean, **cf_kwargs)

    # lined contour plot of standard deviation (skipped when it is constant)
    cs_levels = _linear_levels(var_data_std, 10)
    cs = None
    if cs_levels is not None:
        cs = ax.contour(var_data_std.longitude, var_data_std.latitude, var_data_std,
                        levels=cs_levels, linewidths=0.5, linestyles='--', colors='black',
                        transform=projection)

    # add coastlines, state borders, and other features
    ax.coastlines(linewidth=0.5)
    ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=0.5)
    ax.add_feature(cfeature.STATES, linewidth=0.5)

    # add color bar and contour labels
    fig.colorbar(cf, ax=ax, label=var_units[var], pad=0.02)
    if cs is not None:
        ax.clabel(cs, inline=True, fontsize=8, fmt='%1.1f')
    fig.suptitle(make_plot_title(var, var_month_names, var_region, **kwargs))
    fig.tight_layout()

    return fig, ax

In [None]:
def _flatten_months(var_months):
    """Flatten month numbers and/or month groups (lists/tuples/ranges) into one de-duplicated list, keeping first-seen order."""
    flat = []
    for entry in var_months:
        group = entry if isinstance(entry, (list, tuple, range)) else [entry]
        for month in group:
            if month not in flat:
                flat.append(month)
    return flat


def main_spatial_multi_month(var, var_years=None, var_months=None, var_region='dsw', **kwargs):
    """Compute the per-month mean and standard deviation of ``var`` across years.

    Work in progress: nothing is plotted yet. The statistics are returned so a
    caller can plot them.

    Parameters
    ----------
    var : str
        Variable key.
    var_years : list of int, optional
        Years to keep. Defaults to 1980-2019.
    var_months : list, optional
        Months to keep, either flat (``[6, 7, 8]``) or grouped
        (``[[6], [7], [8]]``). Groups are flattened before selection.
        Defaults to ``[[1], [2], ..., [12]]``.
    var_region : str
        Region passed to ``get_var_data``.
    **kwargs
        Forwarded to ``get_var_data``.

    Returns
    -------
    tuple
        ``(var_data_mean, var_data_std)``, both reduced over ``year``.
    """
    # Build fresh defaults per call instead of sharing mutable default lists.
    if var_years is None:
        var_years = list(range(1980, 2020))
    if var_months is None:
        var_months = [[i] for i in range(1, 13)]

    # .sel(month=...) needs a flat list of month numbers, not nested groups
    months = _flatten_months(var_months)

    # get var_data
    # NOTE(review): main_spatial calls get_var_data(var, var_region, var_months, ...)
    # positionally. Confirm that 'region' is the correct keyword here.
    var_data = get_var_data(var, region=var_region, **kwargs)

    # convert to year, month with monthly averages, then subset
    var_data = time_to_year_month_avg(var_data).sel(year=var_years, month=months)

    # compute mean and std deviation for each month (along year dimension)
    var_data_mean = var_data.mean('year', skipna=True)
    var_data_std = var_data.std('year', skipna=True)

    return var_data_mean, var_data_std

In [None]:
# cell to plot data: one spatial map per calendar month
var = 'z_height'
var_region = 'dsw'
var_months_list = [[month] for month in range(1, 13)]

# settings that only apply to pressure-level variables
var_kwargs = {'var_level': 500}

# settings applied to every plot
main_kwargs = {
    'cmap': 'turbo',
    'save_png': False,
    'overwrite_flag': False,
}

for var_months in var_months_list:
    kwargs = dict(main_kwargs)
    if var in pl_var_list:
        kwargs['var_level'] = var_kwargs['var_level']
    main_spatial(var, var_months=var_months, var_region=var_region, **kwargs)

In [None]:
# cell to compute per-month statistics for all months in a single call
var = 'z_height'
var_region = 'dsw'
var_months = [[i] for i in range(1, 13)]

var_kwargs = {
    'var_level': 500,
}

main_kwargs = {
    'cmap': 'turbo',
    'save_png': False,
    'overwrite_flag': False,
}

kwargs = main_kwargs.copy()
if var in pl_var_list:
    kwargs.update({'var_level': var_kwargs['var_level']})

# No main_spatial_monthly is defined in this notebook (the old call raised
# NameError). The multi-month driver is main_spatial_multi_month, and it takes
# the whole month list defined above in one call. The old loop instead iterated
# the previous cell's var_months_list.
main_spatial_multi_month(var, var_months=var_months, var_region=var_region, **kwargs)