# Program - Plot cn_2d overlaid with wind vectors

**Content**
- Read data
- Plot cn_2d
- Overlay cn_2d with wind vectors

**Reference program:**

**Author:**
Yi-Hsuan Chen

**Date:**
December 2023

In [1]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.util as cutil
import cartopy.mpl.ticker as cticker
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from matplotlib.ticker import MaxNLocator
import matplotlib.patches as mpatches
import numpy as np
import xarray as xr
import io, os, sys, types

import yhc_module as yhc

#--- dask: do not split large chunks when slicing; this also silences dask's large-chunk warning.
import dask
dask.config.set(**{'array.slicing.split_large_chunks': False})

xr.set_options(keep_attrs=True)  # keep attributes (e.g. long_name/units read by plot_cn_map) through xarray operations

<xarray.core.options.set_options at 0x7fd749d44b20>

## Read data 

### read data

In [3]:
# yhc.lib appears to print the source of the named helper from yhc_module (see the cell output).
# NOTE(review): confirm whether it also defines that helper in this namespace.
yhc.lib('read_data_era5')


    
######################
######################
######################

def read_dataset(choice, datapath00="/Users/yi-hsuanchen/Downloads/yihsuan/research/projects/Sc_diag/data/"):

    """
    read a data
    """
    func_name = "read_data"
    
    #--- datapath
    if (choice == "ERA5_single_level"):
        datapath = datapath00+"data.ERA5/"
        fnames = [datapath+"ERA5-2001July-single_level.nc",
                 ]
        da = xr.open_mfdataset(fnames)
        da = xr.open_mfdataset(fnames, decode_cf=False)  # In ERA5 data is stored in short format, set decode_cf=False otherwise the xarray will read wrong values
        da = yhc.wrap360(da, lon='longitude')            # # MERRA-2 has [-180,180' longitudes, change it to [0-360].

    else:
        error_msg = f"ERROR: function [{func_name}] does not support choice=[{choice}]."
        raise KeyError(error_msg)
    
    #--- return
    return da

######################
######################
######################

In [11]:
######################
######################
######################

import cftime

def read_dataset(choice="0001", datapath="../data_test/"):

    """
    Open an ERA5 dataset selected by a short case name.

    Input arguments:
        choice   : case name; only "0001" (file ERA5-DYCOMS_state.nc) is supported
        datapath : directory that contains the data file (a trailing "/" is optional)

    Return:
        xarray Dataset opened with decode_cf=False, passed through yhc.wrap360

    Raises:
        KeyError: if choice is not supported
    """
    func_name = "read_dataset"

    #--- datapath
    if (choice == "0001"):
        fnames = [os.path.join(datapath, "ERA5-DYCOMS_state.nc"),
                 ]
        # ERA5 data is stored in short format; decode_cf=False keeps the raw packed values
        # (they are unpacked later in read_era5_var) instead of letting xarray decode them.
        # Open the files only once.
        da = xr.open_mfdataset(fnames, decode_cf=False)
        # Change longitudes to [0, 360] (yhc.wrap360 on the 'longitude' coordinate).
        da = yhc.wrap360(da, lon='longitude')

    else:
        error_msg = f"ERROR: function [{func_name}] does not support choice=[{choice}]."
        raise KeyError(error_msg)

    #--- return
    return da

######################
######################
######################

def read_era5_var(da_era5, varname,
                  var_format="short"):

    """
    Read one variable from an ERA5 dataset and prepare it for analysis/plotting.

    Input arguments:
        da_era5    : xarray Dataset opened with decode_cf=False (see read_dataset), so
                     values are still packed and the time coordinate is numeric
        varname    : variable name in da_era5, e.g. "d"
        var_format : "short" -> unpack the values with
                         value = packed * scale_factor + add_offset
                     anything else -> keep the stored values unchanged

    Return:
        xarray DataArray with
          - (short format only) values equal to the _FillValue / missing_value attributes
            masked to NaN before unpacking, and the packing attributes removed afterwards
          - coordinates longitude/latitude renamed to lon/lat, latitude sorted ascending
          - time converted to strings '%Y-%m-%dT%H:%M:%S', using the "units" and
            (if present) "calendar" attributes of the time coordinate

    Raises:
        KeyError: if varname is not in da_era5
    """

    #--- read a variable (fail here with a clear message instead of getting None back)
    if varname not in da_era5:
        raise KeyError(f"ERROR: function [read_era5_var] cannot find variable [{varname}].")
    var_era5 = da_era5[varname]

    #--- convert short type to float type if needed
    if (var_format == "short"):

        #--- if scale_factor and add_offset are not read in, set them manually.
        #    assign_attrs returns a copy, so the attributes stored in da_era5 are not modified.
        if (varname == "d"):
            var_era5 = var_era5.assign_attrs(scale_factor=3.16711921762329e-09,
                                             add_offset=2.75280295481647e-06)

        scale_factor = var_era5.attrs['scale_factor']
        add_offset = var_era5.attrs['add_offset']

        #--- with decode_cf=False, fill values are ordinary integers: mask them so they are
        #    not unpacked into (and averaged as) physical values
        fill_values = [var_era5.attrs[key] for key in ('_FillValue', 'missing_value')
                       if key in var_era5.attrs]
        for fill_value in fill_values:
            var_era5 = var_era5.where(var_era5 != fill_value)

        #--- convert short type to float type
        var_era5 = var_era5*scale_factor + add_offset

        #--- values are unpacked now; drop the packing attributes so nothing applies them again
        packing_keys = ('scale_factor', 'add_offset', '_FillValue', 'missing_value')
        var_era5.attrs = {key: value for key, value in var_era5.attrs.items()
                          if key not in packing_keys}

    #--- reorganize ERA5 coordinates
    var_era5 = var_era5.rename({'longitude':'lon', 'latitude':'lat'})  # rename lat/lon coordinate names
    var_era5 = var_era5.sortby('lat', ascending=True)                  # latitude in ascending order, e.g. [28N-34N]

    #--- convert the numeric time coordinate to readable strings '%Y-%m-%dT%H:%M:%S'
    time_attrs = var_era5["time"].attrs
    date_coordinates = cftime.num2date(var_era5["time"].values, time_attrs["units"],
                                       calendar=time_attrs.get("calendar", "standard"))
    date_strings = [date.strftime('%Y-%m-%dT%H:%M:%S') for date in date_coordinates]
    var_era5["time"] = xr.DataArray(date_strings, dims="time")

    #---return
    return var_era5

#-----------
# do_test: read variable "d" from the ERA5 file as a quick check
#-----------

do_test = False   # switch to True to run the check

if do_test:
    da_era5 = read_dataset()
    varname = "d"
    var_era5 = read_era5_var(da_era5, varname=varname)

#var_era5

### select the variable for plotting

In [15]:
def select_var(var,
               region="NE_CA",
               time_slice="mean", 
               data_source=None, 
               RF=None, 
              ):
    """
    Subset a DataArray in time, latitude and longitude.

    var         : xarray DataArray with "time", "lat" and "lon" coordinates
    region      : "NE_CA" selects 28-35N, 235-245E (longitudes in degrees east);
                  any other value keeps everything (bounds of +/-1000 degrees)
    time_slice  : "mean" -> after the spatial subset, average over time AND lon;
                  anything else is passed on as .sel(time=time_slice)
    data_source : "MERRA2" or "ERA5"; combined with RF="RF01" it replaces
                  time_slice with that flight's time (window)
    RF          : research flight name, e.g. "RF01"
    """

    #--- a research flight overrides the requested time_slice
    if RF == "RF01":
        if data_source == "MERRA2":
            time_slice = "2001-07-10T10:30:00.000000000"
        elif data_source == "ERA5":
            time_slice = slice("2001-07-10T09:00:00.000000", "2001-07-10T12:00:00.000000")

    #--- spatial bounds as (lower, upper) pairs
    if region == "NE_CA":
        lon_bounds, lat_bounds = (235, 245), (28, 35)
    else:
        lon_bounds, lat_bounds = (-1000, 1000), (-1000, 1000)

    selection = {"lat": slice(*lat_bounds), "lon": slice(*lon_bounds)}

    #--- either average, or pick the requested time(s)
    if time_slice == "mean":
        return var.sel(**selection).mean(["time", "lon"])
    return var.sel(time=time_slice, **selection)

#-----------
# do_test
#-----------

do_test=True
#do_test=False

if (do_test):

    da_era5 = read_dataset()

    varname = "d"
    var_era5 = read_era5_var(da_era5, varname)

    # Pass the region by keyword: the 2nd positional parameter of select_var is `region`,
    # so select_var(var_era5, varname) would silently use region="d" (no spatial subset).
    var_2d = select_var(var_era5, region="NE_CA")

#var_2d

Unnamed: 0,Array,Chunk
Bytes,7.23 kiB,7.23 kiB
Shape,"(37, 25)","(37, 25)"
Count,9 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 7.23 kiB 7.23 kiB Shape (37, 25) (37, 25) Count 9 Tasks 1 Chunks Type float64 numpy.ndarray",25  37,

Unnamed: 0,Array,Chunk
Bytes,7.23 kiB,7.23 kiB
Shape,"(37, 25)","(37, 25)"
Count,9 Tasks,1 Chunks
Type,float64,numpy.ndarray


### read_data_var_2d

In [19]:
def read_data_var_2d(varname, region="NE_CA"):
    """
    Read an ERA5 variable and reduce it with select_var.

    Input arguments:
        varname : ERA5 variable name, e.g. "d"
        region  : region name forwarded to select_var (default "NE_CA")

    Return:
        the DataArray returned by select_var (with its default time_slice="mean",
        i.e. averaged over time and lon)
    """
    da_era5 = read_dataset()

    var_era5 = read_era5_var(da_era5, varname)

    # region must be passed by keyword: the 2nd positional parameter of select_var is
    # `region`, so passing varname there would disable the region subset.
    var_2d = select_var(var_era5, region=region)

    return var_2d

#-----------
# do_test: one-call read of variable "d"
#-----------

do_test = False   # switch to True to run read_data_var_2d once

if do_test:
    var_2d = read_data_var_2d(varname='d')

## Plot functions

### ax_def_cn_map

In [None]:
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter

def ax_def_cn_map (ax,
                   map_projection,
                   dict_cn_attrs=None,
                   lon_grid_lines=None,
                   extent=None):
    """    
    ----------------------
    Set attributes in cn_map plot using cartopy and matplotlib 

    Input arguments:
        ax             : an Axes class variable (cartopy GeoAxes)
        map_projection : ccrs map projection class variable (currently unused here;
                         kept for interface compatibility)
        dict_cn_attrs  : dict with 'name' and 'units', shown as left/right titles.
                         If None, no titles are set.
        lon_grid_lines : longitudes (degrees, -180..180) of the labelled grid lines.
                         Default: -124 to -116 every 1 degree.
        extent         : optional [lon_min, lon_max, lat_min, lat_max] in degrees
                         (PlateCarree), e.g. [-125, -115, 28, 35]. Note the longitudes are
                         given in -180..180 even when the data use 0-360.
                         If None, the extent is left unchanged.

    Return:
        update ax

    Example:
        map_projection = ccrs.PlateCarree(central_longitude=0)
        fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': map_projection})
        ax_def_cn_map(ax, map_projection, dict_cn_attrs={'name': 'TS', 'units': 'K'})

    References:
      - Cartopy Tick Labels: https://scitools.org.uk/cartopy/docs/latest/gallery/gridlines_and_labels/tick_labels.html

    Date created: 2023-10-15
    ----------------------
    """
    # FixedLocator is otherwise only imported in a later notebook cell; import it here so
    # this function does not depend on cell execution order.
    from matplotlib.ticker import FixedLocator

    #--- add coastline
    ax.add_feature(cfeature.COASTLINE)

    #--- add grid lines
    gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray', alpha = 0.5)  # add lat/lon grid lines
    gl.top_labels = False     # turn off labels on the top and right sides
    gl.right_labels = False

    #--- draw specific grid lines and let cartopy label them (labelling every line while
    #    drawing all of them has no easy option, so the lines are listed explicitly)
    if lon_grid_lines is None:
        lon_grid_lines = [-124, -123, -122, -121, -120, -119, -118, -117, -116]
    gl.xlocator = FixedLocator(lon_grid_lines)
    gl.xformatter = LongitudeFormatter(zero_direction_label=True)

    #--- add title
    if dict_cn_attrs is not None:
        fontsize=12
        ax.set_title(dict_cn_attrs['name'], loc='left', fontsize=fontsize)
        ax.set_title(dict_cn_attrs['units'], loc='right', fontsize=fontsize)

    #--- set lat/lon range (optional)
    if extent is not None:
        ax.set_extent(extent, crs=ccrs.PlateCarree())
    

### set_dict_cn_attrs

In [None]:
from matplotlib.colors import BoundaryNorm

def set_dict_cn_attrs (varname):
    """    
    ----------------------
    Set contour attributes for a variable name

    Input arguments:
        varname: variable name

    Return:
        a dictionary variable with keys
            'varname', 'cn_levels', 'cmap', 'label', 'name', 'units'
        cn_levels is passed to contourf(levels=...): either an int (number of levels,
        chosen automatically) or an array of level boundaries.

    Example:
        gg = set_dict_cn_attrs('swabs_toa')
        print(gg['units'])

    References:
        colormaps:  https://matplotlib.org/stable/tutorials/colors/colormaps.html

    Date created: 2023-10-15
    ----------------------
    """

    #--- set cn levels
    tdt_dyn_cnlevels = np.arange(-25., 27.5, 2.5)

    #---------------------
    # set cn attributes
    #---------------------
    if (varname == "skdfksjdkf"):    # placeholder/template entry
        cn_levels = tdt_dyn_cnlevels
        cmap="c1"
        label = "l1"
        name="n1"
        units="u1"

    elif (varname == "swabs_toa"):
        cn_levels = 15
        cmap="plasma"
        name = "TOA net downward SW flux"
        units = r"$W m^{-2}$"
        label = name+" ("+units+")"

    else:
        # 15 = let matplotlib choose 15 levels. A one-element array such as np.array([15])
        # would be a single level boundary, which filled contours do not accept.
        cn_levels = 15
        cmap = "viridis"
        name = "Var"
        units = "units"
        label = name+" ("+units+")"

    #---------------------------------
    # return a dictionary variable
    #---------------------------------

    dict_cn_attrs = {
        'varname':varname,
        'cn_levels':cn_levels,
        'cmap': cmap,
        'label': label,
        'name': name,
        'units': units,
                  }

    return dict_cn_attrs

#-----------
# do_test: print the attribute dictionary for one variable name
#-----------

do_test = False   # switch to True to print the dictionary

if do_test:
    var_tmp = xr.DataArray(1)
    var_tmp.attrs.update(standard_name="ggg", units="KK")

    dict_cn_attrs = set_dict_cn_attrs('swabs_toa_diff')
    for item in (dict_cn_attrs, dict_cn_attrs['cmap']):
        print(item)

### plot_box

In [None]:
def plot_box(ax, region="DYCOMS", transform=None):

    """
    Draw the outline of a predefined lat/lon region on a map.

    Input arguments:
        ax        : cartopy GeoAxes to draw on
        region    : region name; only "DYCOMS" (30-32.2N, 120-123.8W) is supported
        transform : CRS of the box corners. Default: ccrs.PlateCarree(), because the
                    corners are plain longitudes/latitudes. This does not rely on a
                    module-level `map_projection` global, which only exists if some
                    other cell happened to define one.

    Raises:
        ValueError: if region is not supported
    """
    func_name = "plot_box"

    if (region == "DYCOMS"): 
        region_name = "DYCOMS (30-32.2N, 120-123.8W)"
        lowerlat =  30.    # 30N
        upperlat =  32.2   # 32.2N
        lowerlon =  236.2  # 123.8W
        upperlon =  240.   # 120W
    else:
        error_msg = f"ERROR: function [{func_name}] does not support region=[{region}]"
        raise ValueError(error_msg)        

    if transform is None:
        transform = ccrs.PlateCarree()

    lon_range = upperlon - lowerlon
    lat_range = upperlat - lowerlat

    # label lets the box appear in a legend if one is requested
    rect = mpatches.Rectangle((lowerlon, lowerlat), lon_range, lat_range, facecolor='none',
                              edgecolor='cyan', linewidth=2, transform=transform, label=region_name)
    ax.add_patch(rect)
        

### plot_cn_map

In [None]:
from matplotlib.ticker import FixedLocator

def plot_cn_map(ax, map_projection,
                var, varname,
                do_set_cn_attrs = True, 
                lb_orientation='vertical', lb_shrink=0.9, lb_fontsize=12, 
                title="Title",
               ):

    """
    Make a contour over map plot

    Input arguments:
        ax             : an Axes class variable (cartopy GeoAxes)
        map_projection : ccrs map projection; also used as the transform of the data,
                         so var's lon/lat must be in that CRS (true for PlateCarree)
        var            : a 2D variable (lat, lon). lat & lon MUST be coordinate variables
        varname        : variable name used in the function set_dict_cn_attrs
        do_set_cn_attrs: True  -> levels/cmap/labels from set_dict_cn_attrs(varname)
                         False -> 20 automatic levels; labels from var.attrs 'long_name'
                                  and 'units' (falling back to varname and "")
        lb_orientation : colorbar orientation
        lb_shrink      : colorbar size factor
        lb_fontsize    : colorbar label font size
        title          : main title, placed above the left/right titles

    Return:
        update ax

    Example:
        map_projection = ccrs.PlateCarree(central_longitude=0)
        fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': map_projection})

        var_2d = ...   # a (lat, lon) DataArray
        plot_cn_map(ax, map_projection, var_2d, "TS", do_set_cn_attrs=False)
    """

    #-------------
    # plot cn_map
    #-------------

    #--- set contour properties
    if (do_set_cn_attrs):
        dict_cn_attrs = set_dict_cn_attrs(varname)  # set contour attributes
        cn_map_region = ax.contourf(var.lon, var.lat, var, extend='both', transform=map_projection,
                                    levels=dict_cn_attrs['cn_levels'], cmap=dict_cn_attrs['cmap'])

    else:
        # fall back when the attribute is missing instead of raising KeyError
        name = var.attrs.get('long_name', varname)
        units = var.attrs.get('units', "")
        label = f"{name} ({units})"

        dict_cn_attrs = {
            'varname':varname,
            'cmap': "viridis",
            'label': label,
            'name': name,
            'units': units,
        }

        cn_map_region = ax.contourf(var.lon, var.lat, var, 20, transform=map_projection, extend='both')

    #--- set cn_map attributes
    ax_def_cn_map(ax, map_projection, dict_cn_attrs=dict_cn_attrs)

    #--- plot the colorbar on the figure that owns ax (not pyplot's "current" figure)
    cbar = ax.figure.colorbar(cn_map_region, ax=ax, orientation=lb_orientation, shrink=lb_shrink)
    cbar.set_label(label=dict_cn_attrs['label'], fontsize=lb_fontsize)

    #-------------
    # plot DYCOMS region
    #-------------
    plot_box(ax)

    #------------
    # set title (y=1.1 lifts it above the left/right titles)
    #------------
    ax.set_title(title, y=1.1)
    
#-----------
# do_test
#-----------

# NOTE(review): disabled. This check calls read_data(), which is not defined anywhere in
# this notebook (the readers here are read_dataset / read_era5_var / read_data_var_2d),
# so with do_test=True the cell stops with a NameError. It also expects a dataset with a
# "TS" variable. Note that select_var's default time_slice="mean" averages over lon as
# well as time, so it does not return the (lat, lon) field plot_cn_map needs.
# Re-enable once such a reader exists.
#do_test=True
do_test=False

if (do_test):
    map_projection = ccrs.PlateCarree(central_longitude=0)
    fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': map_projection})

    da = read_data()
    varname = "TS"

    var_2d = select_var(da.get(varname))
    do_set_cn_attrs = False

    #print(var_2d)

    plot_cn_map(ax, map_projection, var_2d, varname, do_set_cn_attrs=do_set_cn_attrs)

### plot_vc_map

In [None]:
def plot_vc_map(ax, U_2d, V_2d, 
               lon_stride = 2, lat_stride = 2,
               scale=100, vc_color='red', width=0.005,
               key_lon=-117, key_lat=35.2, key_length=5,
               transform=None,
               ):

    """
    Overlay wind vectors on a map plot

    Input arguments:
        ax             : an Axes class variable (cartopy GeoAxes)
        U_2d           : zonal wind with dimensions "lat" and "lon" (in either order)
        V_2d           : meridional wind, same layout as U_2d
        lon_stride     : plot every lon_stride-th vector along lon
        lat_stride     : plot every lat_stride-th vector along lat
        scale          : quiver scale with scale_units='width'; larger -> shorter arrows
        vc_color       : wind vector color
        width          : wind vector shaft width (fraction of the plot width)
        key_lon        : x position of the reference vector, in the axes' data coordinates
        key_lat        : y position of the reference vector, in the axes' data coordinates
        key_length     : magnitude of the reference vector (labelled in m/s)
        transform      : optional CRS of the wind lon/lat, passed to ax.quiver(transform=...).
                         If None (default), no transform is passed and cartopy
                         falls back to the axes' own projection.

    Note on key_lon/key_lat:
        quiverkey(coordinates='data') places the key in the axes' data coordinates, i.e.
        the map projection's coordinates. With ccrs.PlateCarree(central_longitude=0)
        these run from -180 to 180. This is presumably why the key needs e.g. -117 even
        though the winds are on 0-360 longitudes.

    Return:
        update ax

    Example:
        map_projection = ccrs.PlateCarree(central_longitude=0)
        fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': map_projection})

        plot_cn_map(ax, map_projection, var_2d, "TS", do_set_cn_attrs=False)
        plot_vc_map(ax, U_2d, V_2d)     # U_2d, V_2d: (lat, lon) DataArrays
    """

    #-------------
    # plot vc_map
    #-------------

    #--- thin the vectors by dimension name and force (lat, lon) order, so the arrays line
    #    up with the 1-D lon (x) and lat (y) coordinates given to quiver
    thin = dict(lat=slice(None, None, lat_stride), lon=slice(None, None, lon_stride))
    U_sub = U_2d.isel(thin).transpose("lat", "lon")
    V_sub = V_2d.isel(thin).transpose("lat", "lon")

    quiver_kwargs = dict(scale_units='width', scale=scale, color=vc_color, width=width)
    if transform is not None:
        quiver_kwargs['transform'] = transform

    #--- vc_map
    quiver = ax.quiver(U_sub.lon, U_sub.lat, U_sub, V_sub, **quiver_kwargs)

    #--- vc ref
    key_title = f"{key_length} m/s"
    ax.quiverkey(quiver, key_lon, key_lat, key_length, key_title, coordinates='data', color=vc_color)

#-----------
# do_test
#-----------

# NOTE(review): disabled. This check calls read_data(), which is not defined anywhere in
# this notebook (the readers here are read_dataset / read_era5_var / read_data_var_2d),
# so with do_test=True the cell stops with a NameError. It also expects "TS", "U10M" and
# "V10M" variables. Note that select_var's default time_slice="mean" averages over lon as
# well as time, so it does not return the (lat, lon) fields the plots need.
# Re-enable once such a reader exists.
#do_test=True
do_test=False

if (do_test):
    map_projection = ccrs.PlateCarree(central_longitude=0)
    fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': map_projection})

    da = read_data()

    varname = "TS"   # set here instead of relying on a global left over from another cell
    var_2d = select_var(da.get(varname))
    U_2d = select_var(da.get("U10M"))
    V_2d = select_var(da.get("V10M"))

    do_set_cn_attrs = False

    plot_cn_map(ax, map_projection, var_2d, varname, do_set_cn_attrs=do_set_cn_attrs)

    plot_vc_map (ax, U_2d, V_2d)

### plot_all

In [None]:
#--- merge all relevant codes into a single function
def plot_all(varname, title="title", da=None, u_name="U10M", v_name="V10M", time_slice="mean"):
    """
    Contour one variable on a map and overlay wind vectors.

    Input arguments:
        varname    : variable to contour, e.g. "TS"
        title      : plot title
        da         : xarray Dataset holding varname, u_name and v_name. If None,
                     read_data() is called. NOTE(review): read_data is not defined in
                     this notebook, so pass da explicitly or define read_data first.
        u_name     : zonal wind variable name
        v_name     : meridional wind variable name
        time_slice : forwarded to select_var. The default "mean" averages over time AND
                     lon, which does not leave a (lat, lon) field; passing a single time
                     selects that time instead (see select_var).

    Raises:
        KeyError: if varname, u_name or v_name is not in da
    """
    if da is None:
        da = read_data()

    map_projection = ccrs.PlateCarree(central_longitude=0)
    fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': map_projection})

    # index with [] so a missing variable fails with a KeyError naming it
    var_2d = select_var(da[varname], time_slice=time_slice)
    U_2d = select_var(da[u_name], time_slice=time_slice)
    V_2d = select_var(da[v_name], time_slice=time_slice)

    plot_cn_map(ax, map_projection, var_2d, varname, do_set_cn_attrs=False, title=title)

    plot_vc_map(ax, U_2d, V_2d)

#-----------
# do_test: run the whole contour + vector chain once on TS
#-----------

do_test = False   # switch to True to try plot_all

if do_test:
    plot_all(title="TS", varname="TS")

## Plot

#### TS & U10M

In [None]:
# NOTE(review): plot_all reads its data through read_data(), which is not defined in this
# notebook; unless read_data is defined first, these calls stop with a NameError.
plot_all(varname="TS", title="TS")
plot_all(varname="U10M", title="U10M")