# Yi-Hsuan Chen's Python module

**import yhc_module as yhc**

**available functions**
1. unit_convert

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from datetime import date
import numpy as np
import xarray as xr
import io, os, sys, types
xr.set_options(keep_attrs=True)  # keep attributes after xarray operation

#import yhc_module as yhc

In [None]:
#############################################
# Notebook-to-module export
#
# When convert_ipynb2py is True, this cell:
#   1. converts yhc_module.ipynb to yhc_module.py
#   2. copies yhc_module.py to a folder, dir_py, so it can be imported
#############################################

# switch: set to True to run the export
convert_ipynb2py = False

# (Mac) yhc_module.py is copied to this folder so that
# other jupyter notebooks can import yhc_module
dir_py = "/Users/yihsuan/.ipython"

# location of yhc_module in Mac
py_path = "/Users/yihsuan/Downloads/yihsuan/test/tool/python/yhc_module_and_notes"

if convert_ipynb2py:
    #--- step 1: notebook -> python script
    print("convert yhc_module.ipynb to yhc_module.py")
    command = "jupyter nbconvert yhc_module.ipynb --to python"
    os.system(command)

    #--- step 2: copy the script to dir_py (the command is echoed first)
    command = f"cp yhc_module.py {dir_py}"
    print(command)
    os.system(command)

    #--- add py_path to sys.path
    sys.path.append(py_path)
    

## printv

In [None]:
def printv(var,
           text = "", 
           color = "black",
          ):
    """
    ----------------------
    Description:
      Print a variable, beginning with a separator line, followed by a text label,
      and ending with an empty line. The separator, label and variable are printed in
      the requested color, and the color is reset afterwards so that later output is
      not affected. This helps to view the output more easily.

    Input arguments:
      var  : any variable
      text : text that will be shown on screen before the variable
      color: color of the output. Full names: red, green, yellow, blue, pink, teal,
             grey, black, cyan, magenta. Short names: r, g, b (blue), y, c, m, k (black)

    Return:
      None. Print statement on screen

    Raises:
      KeyError if color is not one of the supported names

    Example:
      import yhc_module as yhc
      
      yhc.printv(var, "ggg")
      yhc.printv(var, text = "aaa")
      yhc.printv(var, "bbb", color = "r")

    Reference:
      Print with colors: https://predictivehacks.com/?all-tips=print-with-different-colors-in-jupyter-notebook
      Inserting values into strings, https://matthew-brett.github.io/teaching/string_formatting.html

    Date created: 2022/07/06
    ----------------------
    """    

    func_name = "printv"

    #------------------
    # set text color
    #   colors: https://pkg.go.dev/github.com/whitedevops/colors
    #------------------
    
    #--- color codes dictionary (ANSI escape sequences)
    color_codes = {'red':'\033[91m',
                   'green':'\033[92m',
                   'yellow':'\033[93m',
                   'blue':'\033[94m',
                   'pink':'\033[95m',
                   'teal':'\033[96m',
                   'grey':'\033[97m',
                   'black':'\033[30m',
                   'cyan':'\033[36m',
                   'magenta':'\033[35m',
                   #------ short name  -----
                   'r':'\033[91m',
                   'g':'\033[92m',
                   'b':'\033[94m',
                   'y':'\033[93m',
                   'c':'\033[36m',
                   'm':'\033[35m',
                   'k':'\033[30m',   # black; 'b' is reserved for blue
                  }

    #--- code that restores the default text color
    color_reset = '\033[0m'
    
    #--- read color code
    if color not in color_codes:
        error_msg = f"ERROR: [{func_name}] does not support color [{color}]. Available colors: {list(color_codes)}"
        raise KeyError(error_msg)

    color_code = color_codes[color]
    
    #------------
    # print out
    #------------
    
    text_out = f"{color_code} ------------- \n {text}"  # format text string
    
    print(text_out)
    print(var)
    print(color_reset)   # shows as an empty line and stops the color from leaking into later output
        
#----------
# test of printv
#   set do_test to True to print three sample calls
#----------

do_test = False

if do_test:
    printv(1, "ddd", 'r')           # positional color
    printv(2, "eee", color="g")     # keyword color
    printv(3)                       # defaults only



## unit_convert

In [None]:
def unit_convert (var_in, units_in = "none", units_out = "none", 
                  var_type = "xarray.DataArray"
                 ):
    """
    ----------------------
    Description:
      Convert the units of a variable.

      Every supported unit has a factor in the conversion dictionary, and the result is
          var_out = var_in * conversion[units_in] / conversion[units_out]
      The function does not check that units_in and units_out describe the same quantity.

    Input arguments:
        var_in   : the variable to convert. An xarray.DataArray by default (see var_type)
        units_in : (string) original units of var_in. If "none" and var_type is
                   "xarray.DataArray", var_in.attrs['units'] is used when present
        units_out: (string) new units of var_in. If "none" and var_type is
                   "xarray.DataArray", the default target units of units_in are used
                   (e.g. 'K/s' -> 'K/day')
        var_type : (string) "xarray.DataArray" (default); any other value treats var_in
                   as a plain number/array and skips the attribute handling
    
    Return:
        variable values with units_out. For an xarray.DataArray, attrs['units'] is set
        to units_out.

    Raises:
        KeyError if the input/output units cannot be determined or are not supported
    
    Example
        import yhc_module as yhc
        var = yhc.unit_convert(var, "m", "km")
        var = yhc.unit_convert(var)    # use var.attrs['units'] and the default output units
        
    Date created: 2022/06/29
    -----------------------
    """
    
    func_name = "unit_convert"
    
    #------------- 
    # constants 
    #------------- 

    rho_water = 1000.  # water density [kg/m3]
    latent_heat_evap = 2.5e+6          # latent heat of vaporization for water, J/kg  
    hr2sec = 3600.    # hour in seconds
    day2sec = 86400.  # day in seconds

    #------------- 
    # conversion dictionary
    #   each factor converts the unit to the unit with factor 1. in the same row
    #------------- 

    conversion = {'m':1.0, 'mm':0.001, 'cm':0.01, 'km':1000.,
                  'm/s':1.0, 'mm/day':1./(1000.*day2sec), 'kg/m2/s':1./rho_water, 'W/m2':1./latent_heat_evap/rho_water,
                  'kg/kg':1.0, 'g/kg':1.e-3,
                  'fraction':1., "percent":0.01, "%":0.01,
                  'K/s':1., 'K/day':1./day2sec, 'deg_K/s':1., 
                  'kg/kg/s':1., 'kg kg-1/s':1., 'kg kg-1 s-1':1., 'g/kg/hour':1./(1000.*hr2sec), 'g/kg/day':1./(1000.*day2sec),
                  'Pa/s':1., 'Pa s-1':1., 'hPa/day': 100./day2sec, 
                  'kg/m2':1., 'g/m2':1.e-3,
                  '1/s':1., '1/hour': 1./hr2sec, 
                  'Pa':1.0, 'hPa':100., "mb":100., 
                  'none':1.0,
                  'K':1.0, 'deg_K':1.0,
                 }

    #--- set a default unit convertion, [units_in, units_out]
    units_dict = {'K/s':'K/day', 'deg_K/s':'K/day', 
                  'Pa/s':'hPa/day','Pa s-1':'hPa/day',
                  'deg_K':'K',
                  'kg/kg/s':'g/kg/day',
                 }
    
    #------------- 
    # determine units (xarray.DataArray only)
    #------------- 

    is_xarray = (var_type == "xarray.DataArray")

    if is_xarray:

        #--- units_in is not given: read it from the variable attributes
        if units_in == "none" and "units" in var_in.attrs:
            units_in = var_in.attrs['units']

        #--- units_out is not given: use the default target units of units_in
        if units_out == "none":
            if units_in not in units_dict:
                error_msg = f"ERROR: [{func_name}] does not contain units [{units_in}]. Please modify {func_name}"
                raise KeyError(error_msg)
            units_out = units_dict[units_in]

        #--- units_out is given but the input units are unknown
        elif units_in == "none":
            error_msg = f"ERROR: [{func_name}] cannot determine the input units. " \
                        + "Please give units_in or set the 'units' attribute of the variable"
            raise KeyError(error_msg)

    #------------- 
    # conversion 
    #------------- 

    for units in (units_in, units_out):
        if units not in conversion:
            error_msg = f"ERROR: [{func_name}] does not contain units [{units}]. Please modify {func_name}"
            raise KeyError(error_msg)

    var_out = var_in * conversion[units_in] / conversion[units_out]

    if is_xarray:
        var_out.attrs['units'] = units_out

    #--- any other data types
    else:
        try:
            var_out.units = units_out
        except AttributeError:
            pass   # e.g. float and numpy.ndarray cannot carry a new attribute

    #------------- 
    # return
    #------------- 
    
    return var_out

#--------
# test of unit_convert
#   set do_test to True to run
#--------

do_test = False

if do_test:
    #--- surface pressure, Pa -> hPa
    Ps_scm = xr.DataArray([101780.], dims=['time'], attrs={'long_name': "Ps"})
    printv(Ps_scm, "Pa")

    Ps_scm = unit_convert(Ps_scm, "Pa", "hPa")
    printv(Ps_scm, "hPa")

    #--- tendency, kg/kg/s -> g/kg/day, with explicit units_in / units_out
    T_scm = xr.DataArray([0.5], dims=['time'], attrs={'units': "deg_ggg"})
    T_scm1 = unit_convert(T_scm, "kg/kg/s", "g/kg/day")

    printv(T_scm1, 'T', 'g')



## get_region_latlon

In [None]:
def get_region_latlon(region = "N/A"):
    """
    ----------------------
    Description:
      Given a region name, return slices of longitude and latitude.
    
    Input arguments:
        region = (string) name of the region
            supported: "Californian_Sc", "Peruvian_Sc", "Namibian_Sc", "DYCOMS", "global"

    Return:
        lon_slice & lat_slice of the given region, i.e.
        slice(lowerlon, upperlon) and slice(lowerlat, upperlat).
        Longitudes are in degrees east (0-360), latitudes in degrees north.
        The "global" region uses -1000..1000 for both slices.

    Raises:
        ValueError if region is not supported

    Example:
      import yhc_module as yhc
      region = "Californian_Sc"
      lon_slice, lat_slice = yhc.get_region_latlon(region)
      print(lon_slice)
      print(lat_slice)
      
    Date created: 2022/07/01
    ----------------------
    """

    func_name = "get_region_latlon"
    
    #------------------------
    #  region table
    #    region: (region_name, lowerlat, upperlat, lowerlon, upperlon)
    #
    #  To add a region, add one entry here; the error message below
    #  lists the keys of this table, so it cannot get out of sync.
    #------------------------

    regions = {
        "Californian_Sc": ("CA marine Sc (20-30N, 120-130W)",
                           20., 30., 230., 240.),        # 20N, 30N, 230E, 240E
        "Peruvian_Sc":    ("Peruvian marine Sc (10S-20S, 80W-90W)",
                           -20., -10., 270., 280.),      # 20S, 10S, 90W, 80W
        "Namibian_Sc":    ("Namibian marine Sc (10S-20S, 0E-10E)",
                           -20., -10., 0., 10.),         # 20S, 10S, 0E, 10E
        #--- reference: Stevens et al. (2007, MWR)??
        "DYCOMS":         ("DYCOMS (30-32.2N, 120-123.8W)",
                           30., 32.2, 236.2, 240.),      # 30N, 32.2N, 123.8W, 120W
        "global":         ("global (0-360E, -90S-90N)",
                           -1000., 1000., -1000., 1000.),
    }

    #------------------------
    #  read region name
    #------------------------

    if not isinstance(region, str) or region not in regions:
        error_msg = f"function *{func_name}*: input region [{region}] is not supported. STOP. \n" \
                    + "Available regions: " + ', '.join(regions)
        raise ValueError(error_msg)

    _region_name, lowerlat, upperlat, lowerlon, upperlon = regions[region]
    
    #------------------------
    #  compute lon and lat slices
    #    Python slice function: https://www.w3schools.com/PYTHON/ref_func_slice.asp
    #------------------------    

    lon_slice = slice(lowerlon, upperlon)
    lat_slice = slice(lowerlat, upperlat)
    
    return lon_slice, lat_slice

#-----------
# test of get_region_latlon
#   set do_test to True to run; "aaDYCOMS" is not a supported region name
#-----------

do_test = False

if do_test:
    ddd = "aaDYCOMS"
    lon_slice, lat_slice = get_region_latlon(ddd)


## wgt_avg

In [None]:
def wgt_avg (xa,
             dims = ["lon","lat"],
            ):
    """
    ----------------------
    Description:
      Latitude-weighted mean of an Xarray DataArray, e.g. [*, lon, lat].
      The weights are cos(latitude), read from the coordinate xa.lat (degrees).
    
    Input arguments:
      xa  : Xarray DataArray with a "lat" coordinate
      dims: dimension(s) to average over, e.g. ["lon"], ["lon","lat"]  
      
    Return:
      latitude-weighted average of the input data over dims

    Example
      var is a [*, lon, lat] Xarray data
      
      import yhc_module as yhc
      var_ijmean = yhc.wgt_avg(var)
      var_jmean  = yhc.wgt_avg(var, dims="lat")
        
    References
    1. https://docs.xarray.dev/en/stable/examples/area_weighted_temperature.html
    2. https://nordicesmhub.github.io/NEGI-Abisko-2019/training/Example_model_global_arctic_average.html
 
    Date created: 2022/07/01
    ----------------------
    """
    
    #--- cos(lat) weights; lat is converted from degrees to radians first
    lat_in_radians = np.deg2rad(xa.lat)
    lat_weights = np.cos(lat_in_radians)
    lat_weights.name = "weights"
    
    #--- weighted mean over the requested dimension(s)
    return xa.weighted(lat_weights).mean(dims)

#-------
# test of wgt_avg
#   set do_test to True to run
#-------
do_test = False

if do_test:
    #--- 3 latitudes x 2 longitudes
    gg = np.array(([1,1],[1,1],[100,100]))
    print(gg)
    lat = np.array([0.,60.,90.])
    lon = np.array([100.,100.])

    data = xr.DataArray(gg, dims=['lat','lon'], coords=[lat, lon])

    #--- average over longitude only
    pp = wgt_avg(data, 'lon')

    printv(data, 'original')
    printv(pp, 'weighted')
    

## mlevs_to_plevs

In [None]:
def mlevs_to_plevs (ps,
                    model, 
                    plevs, 
                   ):
    """
    ----------------------
    Description: 
        Compute pressure levels from a climate or weather model, given necessary information

    Input arguments:
        ps   : (an xarray DataArray). Surface pressure (Pa)
        model: (a string) model name. Check out "model_list" variable in this function to see which model is supported
        plevs: (a string) return pressure levels. Check out "plevs_list" to see which are supported

        Currently available:
            model = "AM4_L33_native", GFDL AM4 with 33 levels 
                plevs = "phalf" or "pfull", pressure at half levels or full levels

    Return:
        An xarray DataArray of pressure with the dimensions of ps plus a new dimension "plev".
          phalf: pk + bk*ps at the 34 half levels
          pfull: mean of each pair of adjacent half levels
        attrs['conversion_method'] is set to model, and attrs['units'] is copied from ps
        when ps has a 'units' attribute.

    Raises:
        ValueError if model or plevs is not supported

    Example:
        import yhc_module as yhc
        data = xr.open_dataset(ncfile)
        data.ps # (time, lat, lon)
        phalf = yhc.mlevs_to_plevs(data.ps, model = "AM4_L33_native", plevs = "phalf")
        phalf # (time, lat, lon, plev)
     
    Date created: 2022/07/07
    ----------------------
    """

    func_name = "mlevs_to_plevs"
    
    model_list = ["AM4_L33_native"]

    #------------------------------------------
    # check if the model is supported
    #------------------------------------------

    if model not in model_list:
        error_msg = f"function *{func_name}*: model [{model}] is not supported. STOP. \n" \
                    + "Available options: " + ', '.join(model_list)
        raise ValueError(error_msg)
    
    #------------------------------------------
    # process model and plevs_out
    #------------------------------------------

    #@@@@@@@@@@@@@@@@@@@@@@@@@
    #@@@@@@@@@@@@@@@@@@@@@@@@@
    if (model == "AM4_L33_native"):
        # In AM4, pressure(k) = pk(k) + bk(k)*ps, where ps is surface pressure (Pa).
    
        plevs_list=["phalf", "pfull"]  # AM4 only supports pressure at half levels (phalf) and at full levels (pfull)
    
        #--- check
        if plevs not in plevs_list:        
            error_msg = f"function *{func_name}*: input plevs [{plevs}] is not supported. STOP. \n" \
                      + "Available options: " + ', '.join(plevs_list)
            raise ValueError(error_msg)
            
        #--- pk values. Output directly from AM4 files
        pk_list = [100, 400, 818.6021, 1378.886, 2091.795, 2983.641, 4121.79, 5579.222, 
          6907.19, 7735.787, 8197.665, 8377.955, 8331.696, 8094.722, 7690.857, 
          7139.018, 6464.803, 5712.357, 4940.054, 4198.604, 3516.633, 2905.199, 
          2366.737, 1899.195, 1497.781, 1156.253, 867.792, 625.5933, 426.2132, 
          264.7661, 145.0665, 60, 15, 0]
 
        #--- bk values. Output directly from AM4 files
        bk_list = [0, 0, 0, 0, 0, 0, 0, 0, 0.00513, 0.01969, 0.04299, 0.07477, 0.11508, 
          0.16408, 0.22198, 0.28865, 0.36281, 0.44112, 0.51882, 0.59185, 0.6581, 
          0.71694, 0.76843, 0.81293, 0.851, 0.88331, 0.91055, 0.93331, 0.95214, 
          0.9675, 0.97968, 0.98908, 0.99575, 1]      

        #--- make pk and bk as xarray.DataArray
        pk = xr.DataArray(pk_list, dims=['plev'])
        bk = xr.DataArray(bk_list, dims=['plev'])    
    
        #--- compute pressure at half levels
        phalf = ps*bk + pk
        phalf.attrs['long_name'] = "Pressure at half levels"
        
        #--- compute pressure at full levels: mean of two adjacent half levels;
        #    dropna removes the incomplete rolling window
        pfull = phalf.rolling(plev=2, center=True).mean().dropna("plev")
        pfull.attrs['long_name'] = "Pressure at full levels"
    
        #--- return plevs (plevs was validated above, so it is "phalf" or "pfull")
        plevs_return = phalf if plevs == "phalf" else pfull
        
        plevs_return.attrs['conversion_method'] = model
        if 'units' in ps.attrs:
            plevs_return.attrs['units'] = ps.attrs['units']
            
        return plevs_return
        
#@@@@@@@@@@@@@@@@@@@@@@@@@
#@@@@@@@@@@@@@@@@@@@@@@@@@
#elif (model == "other_model"):
#    """ 
#    """        
#    return 

#-----------
# test of mlevs_to_plevs
#   set do_test to True to run
#-----------

do_test = False

if do_test:
    model = "AM4_L33_native"

    #--- one surface pressure value in Pa
    Ps = xr.DataArray([102078.5], dims=['time'], attrs={'units': "Pa"})

    plev = mlevs_to_plevs(Ps, model, "pfull")
    print(plev)


## wrap360

In [None]:
def wrap360(ds, lon='lon'):
    """
    Source code: https://github.com/pydata/xarray/issues/577
    
    wrap longitude coordinates from -180..180 to 0..360

    The input object is not modified; the wrapped longitudes are assigned
    on a new object, which is then re-ordered by increasing longitude.

    Parameters
    ----------
    ds : Dataset
        object with longitude coordinates
    lon : string
        name of the longitude ('lon', 'longitude', ...)

    Returns
    -------
    wrapped : Dataset
        Another dataset array wrapped around.
    """

    # wrap -180..179 to 0..359 on a new object (assign_coords does not change ds)
    wrapped = ds.assign_coords({lon: np.mod(ds[lon], 360)})

    # sort the data by the wrapped longitudes
    return wrapped.reindex({lon: np.sort(wrapped[lon])})

## get_area_avg

In [None]:
def get_area_avg (var, region, 
                  weighted = True,
                  lat='lat', lon='lon'):    
    """    
    ----------------------
    Description:
      compute area-mean of a xr.DataArray variable. The default is weighted by latitudes.

    Input arguments:
      var     : an xarray DataArray variable, preferably [*, lat, lon]
      region  : a string, region name used by get_region_latlon
      weighted: a logical variable. True: weighted by cos(latitude), False: normal mean
      lat     : coordinate name of latitude
      lon     : coordinate name of longitude

    Return:
      var_area_avg: an xarray DataArray variable, area-mean of the given variable.
                    attrs['region'] is set to the region name.

    Example:
      import yhc_module as yhc

      var (lat: 3, lon: 2)
      var_area_mean = yhc.get_area_avg(var, "DYCOMS")
      var_area_mean = yhc.get_area_avg(var, "global", weighted = False)

    Notes:
      1. No need to check if var is pressure or height. Although this can cause troubles in doing area average
         (e.g. on steep terrain, surface height from 0 to 1000m), adding a check is too complicate.
      2. NOTE(review): the region is selected with label slices (lower, upper), which presumably
         requires the lat/lon coordinates to be in ascending order -- confirm for your data.

    Date created: 2022/07/23
    ----------------------
    """

    #--- get lat/lon of the region
    lon_slice, lat_slice = get_region_latlon(region)

    #--- select the region, using the coordinate names given by the caller
    var_region = var.sel({lat: lat_slice, lon: lon_slice})
    
    #--- lat-weighted average
    #    The cos(lat) weights are computed here because wgt_avg reads the
    #    coordinate that is literally named "lat".
    if (weighted):
        weights = np.cos(np.deg2rad(var_region[lat]))
        weights.name = "weights"
        var_area_avg = var_region.weighted(weights).mean([lon, lat])

    #--- normal average
    else:
        var_area_avg = var_region.mean([lat, lon])

    #--- set attribute
    var_area_avg.attrs['region'] = region
    
    return var_area_avg

#-------
# test of get_area_avg
#   set do_test to True to run
#-------
do_test = False

if do_test:
    #--- 3 latitudes x 2 longitudes, units K
    gg = np.array(([1,1],[2,2],[100,100]))
    lat = np.array([0.,60.,90.])
    lon = np.array([100.,100.])

    var = xr.DataArray(gg, dims=['lat','lon'], coords=[lat, lon])
    var.attrs['units'] = "K"
    printv(var, 'var')

    #--- area average over the DYCOMS region
    region = "DYCOMS"
    var_ijavg = get_area_avg(var, region)
    printv(var_ijavg, 'var_ijavg')
            

## modify_attrs

In [None]:
def modify_attrs(var, 
                 varname = "N/A", 
                 attrs_add = "N/A",
                 attrs_del = "N/A",
                ): 
    
    """
    ----------------------
    Description:
      Add/Delete/Modify attributes of a Xarray DataArray. The variable is modified in place.
      Order of operations: (1) predefined attributes of varname, (2) attrs_del, (3) attrs_add.

    Input arguments:
      var: an Xarray DataArray (any object with a dictionary "attrs" works)
      varname: add predefined attributes to var, e.g. ["u","v"] or "omega"
      attrs_add: attributes that will be added in var, e.g. ["longname=aaa","units=kkk"].
                 Each item is "name=value"; the text after the first "=" is the value.
      attrs_del: attributes that will be deleted, e.g. ["att1","att2"].
                 Names that are not attributes of var are ignored.

      Each of the three arguments may be a list of strings or a single string.
      The default "N/A" means "do nothing".

    Return:
      var with modified attributes (the same object as the input)

    Raises:
      KeyError   if a name in varname has no predefined attributes
      ValueError if an item of attrs_add does not contain "="

    Example:
      import yhc_module as yhc
      yhc.modify_attrs(var, varname = ["u"], attrs_add=["att1=0101","att2=0202"], attrs_del=["units"])
      
    Date created: 2022-08-17
    ----------------------
    """
    
    func_name = "modify_attrs"
    
    #-----------------------------------------
    # set predefine attributes of variables
    #   the naming follows CF Metadata Convention, 
    #.     https://cfconventions.org/
    #.  CF Standard Name Table, Version 79, 19 March 2022, 
    #.     https://cfconventions.org/Data/cf-standard-names/current/build/cf-standard-name-table.html
    #-----------------------------------------
    
    #--- set dictionary of variables
    varname_dict={
        "u": ["long_name=eastward_wind", "units=m s-1"],
        "v": ["long_name=northward_wind", "units=m s-1"],
        "t": ["long_name=air_temperature", "units=K"],
        "q": ["long_name=specific_humidity", "units=kg kg-1"],
        "omega": ["long_name=vertical_pressure_velocity", "units=Pa s-1"],  # not consistent with CF convention
        "ug": ["long_name=geostrophic_eastward_wind", "units=m s-1"], 
        "vg": ["long_name=geostrophic_northward_wind", "units=m s-1"],
        "ps": ["long_name=surface_air_pressure", "units=Pa"], # not consistent with CF convention
        "ts": ["long_name=surface_temperature", "units=K"],
        "shflx": ["long_name=sensible_heat_flux", "units=W m-2"], # not consistent with CF convention
        "lhflx": ["long_name=latent_heat_flux", "units=W m-2"], # not consistent with CF convention
        "divt": ["long_name=Horizontal large scale temp. forcing", "units=K s-1"], # not consistent with CF convention
        "divq": ["long_name=Horizontal large scale water vapor forcing", "units=kg kg-1 s-1"], # not consistent with CF convention
        "vertdivT": ["long_name=Vertical large scale temp. forcing", "units=K s-1"], # not consistent with CF convention
        "vertdivq": ["long_name=Vertical large scale water vapor forcing", "units=kg kg-1 s-1"], # not consistent with CF convention
        "divt3d": ["long_name=3d large scale temp. forcing", "units=K s-1"], # not consistent with CF convention
        "divq3d": ["long_name=3d large scale water vapor forcing", "units=kg kg-1 s-1"], # not consistent with CF convention
    }
        #"": ["long_name=", "units="], # not consistent with CF convention

    #--- a single string is treated as a one-item list, so "omega" is not read as "o","m",...
    def _as_list(arg):
        return [arg] if isinstance(arg, str) else arg

    #--- set var.attrs from a list of "name=value" strings
    def _set_attrs(specs):
        for spec in specs:
            att_name, sep, att_value = spec.partition('=')   # split at the first "=" only
            if not sep:
                error_msg = f"ERROR: [{func_name}] attribute [{spec}] is not in the form 'name=value'"
                raise ValueError(error_msg)
            var.attrs[att_name] = att_value

    #------------------------
    #  predefined attributes
    #------------------------
    if (not varname == "N/A"):
        for var1 in _as_list(varname):
            if var1 not in varname_dict:
                varname_keys = list(varname_dict.keys())
                error_msg = f"ERROR: [{func_name}] does not support varname [{var1}]. Available varname: {varname_keys}"
                raise KeyError(error_msg)

            _set_attrs(varname_dict[var1])
    
    #------------------------
    #  delete attributes
    #------------------------
    if (not attrs_del == "N/A"):
        for att1 in _as_list(attrs_del):
            var.attrs.pop(att1, None)   # ignore names that are not attributes of var
    
    #------------------------
    #  add attributes
    #------------------------
    if (not attrs_add == "N/A"):
        _set_attrs(_as_list(attrs_add))
    
    #---------
    # return
    #---------
    return var
    
#-----------
# test of modify_attrs
#   set do_test to True to run
#-----------

do_test = False

if do_test:
    #--- a variable with four attributes
    var = xr.DataArray([1.], attrs={'long_name': "long1",
                                    'units': "units",
                                    'cell': "cell1",
                                    'source': "ss1"})
    printv(var, 'before')

    #--- set the predefined "omega" attributes, delete "units", then add two attributes
    var1 = modify_attrs(var,
                        varname=["omega"],
                        attrs_add=["att1=0101", "att2=0202"],
                        attrs_del=["units"])
    printv(var, 'after')


## lib

In [None]:
def lib(*keywords, color = True):
    """    
    ----------------------
    My library of Python codes

    Input arguments:
      keywords: string variables. 
      color   : logical variable. Turn on/off colored texts

    Return:
      print out on screen

    Example:
      import yhc_module as yhc
      yhc.lib("plt","")                 
      yhc.lib("plt","pltxy", color = False)


    Date created: 2022/07/23
    ----------------------
    """
    
    color_comment = '\033[36m'  # cyan color
    color_black   = '\033[0m'   # black color
    
    #------------------
    # library of codes
    #------------------
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, usual ax in XY plots
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- ax for XY plots
    ax_def_xy = """
    
    #------------------------ 
    # usual ax in XY plots
    #------------------------  

#==============================
def ax_def_xy (ax, var):

    #--- set grids
    ax.grid(True)
    ax.minorticks_on()
    
    #--- inverse axes
    ax.invert_yaxis()
    
    #--- legend
    legend_default = ["A","B"]
    ax.legend(legend_default)
       ### change legend size, ax.legend(legend_default, prop={'size': legend_size})

    #--- set title
    fontsize_title = 18
    ax.set_title(var.attrs['long_name'], loc='left')
    ax.set_title(var.attrs['units'], loc='right')
      ### set font size, ax.set_title(var.attrs['units'], loc='right', fontsize = fontsize_title)

    #--- set x or y labels
    fontsize_label = 18
    ax.set_xlabel(var.attrs['long_name']+" ("+var.attrs['units']+")")
    ax.set_ylabel("Pressure (hPa)")  
      ### set font size, ax.set_ylabel("A", fontsize = fontsize_label)
    
    #--- set x range
    ax.set_xlim([0,len(var)])
    
    #--- set x tickmark
    #xvalues = np.arange()
    #xlabels = np.arange()
    #ax.set_xticks(xvalues)        # tickmark values
    #ax.set_xticklabels(xlabels)  # tickmark labels

    #--- set tick mark sizes
    fontsize_tm = 12  # set tick mark size
    ax.tick_params(axis='x', labelsize = fontsize_tm)
    ax.tick_params(axis='y', labelsize = fontsize_tm)
#============================== 
    """

    #@@@@@@@@@@@@@@@@@@@@@@@@@@, dictionary
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- python dictionary
    dict = """

    #-------------------------
    # Python disctionary
    #-------------------------  $$
    
    #--- set a dictionary $$
    dict = {
        "u": ["long_name=eastward_wind", "units=m s-1"],
        "v": ["long_name=northward_wind", "units=m s-1"],
    }
    
    #--- look up the dictionary  $$
    print(dict['u'])
    
    #--- if statement for a dictionary  $$
    if var1 in dict:
        attrs_var = dict['u']
    else:
        dict_keys = list(dict.keys())   # get all keys in the dict
    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, function
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- function define statments (fdef)
    
    today = date.today()    
    fdef =  f"""
    
    #-------------------------
    # define a new function
    #-------------------------  $$
    
def (): 
    
    \"""
    ----------------------
    Description:


    Input arguments:


    Return:


    Example:
      import yhc_module as yhc
       = yhc.()

    Date created: {today}
    ----------------------
    \"""


    func_name = ""

    #--- $$

    # error_msg = f"ERROR: function [func_name] does not support []. Available:"
    # raise KeyError(error_msg)
    # raise ValueError(error_msg)

    return
    
    
#-----------
# do_test
#-----------

do_test=True
#do_test=False

if (do_test):
        
    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, module
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- module
    module = """
    #-------------------------
    # useful module command
    #-------------------------
    
    #--- import $$
    import yhc_module as yhc

    #--- help  $$
    help(yhc)          # see all functions and file path
    help(yhc.func1)    # see func1 
    dir(yhc)           # retun list of attributes and methods of an objecy
    """
    
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, numpy
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- numpy
    np = """
    #-------------
    # Numpy functions
    #-------------
    
    #--- create arrays  $$
    np.arange(0,3)             # create an array = [0,1,2,3]
    np.linspace(-10., 10., 5)  # create an array of 5 elements ranging from -10 to 10

    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, matplotlib.pyplot 
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- matplotlib.pyplot general
    plt = """
    #-------------
    #
    # matplotlib.pyplot general
    #
    #-------------  $$    

    #-------------------------- 
    #  open fig and ax
    #-------------------------- $$

    #      matplotlib.pyplot.subplot, https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplot.html  $$
    fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(18, 6))   # 1 row, 3 columns
    fig, ((ax_1, ax_2, ax_3), (ax_4, ax_5, ax_6)) = plt.subplots(nrows=2, ncols=3, figsize=(18, 12))  # 2 rows, 3 columns
    fig.suptitle("title", fontsize = 20, y=0.95)   # add title

    #--- set spacing betweeon subplots
    #      https://www.geeksforgeeks.org/how-to-set-the-spacing-between-subplots-in-matplotlib-in-python/ $$
    fig.tight_layout()
    #fig.tight_layout(pad=5.0)   # larger pad, larger space betweeon plots

    #-------------
    #
    # Set title and labels
    #
    #-------------  $$

    #--- main titles  $$
    fig.suptitle("DYCOMS SCM initial profiles", fontsize=20, y=95)
    
    ax.set_title("main titile")
    ax.set_title(var.attrs['long_name'], loc='left', fontsize = 18, y=0.9)
    ax.set_title(var.attrs['units'], loc='right')

    #--- x and y labels  $$
    ax.set_xlabel("X-Axis title")
    ax.set_ylabel("Y-Axis title")

    #---------------- 
    #  add text 
    #    Use LaTex convention: https://matplotlib.org/stable/tutorials/text/mathtext.html
    #---------------- $$
    ax1.text(xx, yy, 'text', fontsize = 15)
    
    var_dum.attrs['units'] = r"$10^6 s^{-1}$"  # 10^6/s

    #-------------------------- 
    #  save figure
    #-------------------------- $$

    plt.savefig('fig.png', dpi=300)
    
    """

    #@@@@@@@@@@@@@@@@@@@@@@@@@@, matplotlib.pyplot 
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- matplotlib.pyplot default settings
    pltset = """
#---------------------------
# set plt default settings
#    References: https://www.geeksforgeeks.org/matplotlib-pyplot-rc-in-python/
#---------------------------  $$

#### set default plt parameters & use rc function to update 

#--- set default values for font group  $$
font_default = {
    'size': 15,      # font size 
       }
plt.rc('font', **font_default)

       
#--- set default values for lines group  $$
lines_default = {
    'linewidth':3,
}
plt.rc('lines', **lines_default)

    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, matplotlib.pyplot 
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- matplotlib.pyplot for 2-D filled contour plot
    pltcn_2d = """
    #-------------
    #
    # matplotlib.pyplot for 2-D filled contour plots
    #
    #-------------  $$

    #--- open subplots $$
    fig, (ax1, ax2) = plt.subplots(2,1, figsize=(18, 12))   # 2 row, 1 column
    fig.tight_layout(pad=5.0)

    #--- determine contour levels $$
    cn_levels = np.array([...])
    cn_levels = np.linspace(-100., 100., 11)
    
    #--- colormaps $$
    https://matplotlib.org/stable/tutorials/colors/colormaps.html

    cmap="coolwarm"   # temperature tendencies
    cmap="BrBG"       # moisture tendencies
    cmap="YlGnBu"     # omega
    camp="Spectral"   # divergence 

    #--- ax1 $$
    plot_cn1 = ax1.contourf(X, Y, Z, 15, cmap="OrRd")  # automatically 15 levels  reversed cmap "OrRd_r"
    plot_cn1 = ax1.contourf(X, Y, Z, levels = cn_levels, cmap="OrRd") 
    fig.colorbar(plot_cn1, ax=ax1, orientation='vertical', label='Some units')
 
    #--- ax2 $$
    plot_cn2 = ax2.contourf(X, Y, Z, 10, cmap="BrBG")  # automatically 10 levels  reversed cmap "OrRd_r"
    fig.colorbar(plot_cn2, ax=ax2, orientation='vertical', label='Some units')

    #--- set X & Y labels $$
    ax.set_title("title")
    ax.set_xlabel("xx")
    ax.set_ylabel("yy")
    
    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, matplotlib.pyplot 
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- matplotlib.pyplot for XY plots
    pltxy = """
    #-------------
    #
    # matplotlib.pyplot for XY plots
    #
    #-------------  $$

    #--- XY line styles  
    #    * Line colors, https://matplotlib.org/stable/gallery/color/named_colors.html
    #        red (r), blue (b), green (g), cyan (c), magenta (m), yellow (y), black (k), white (w)
    #.   * Line dash pattern, https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html
    #.       {'-', '--', '-.', ':'} {solid, dashed, dashdot, dotted}
    #.   * Markers, https://matplotlib.org/stable/api/markers_api.html
    #        {'.', 'o', '^', 's'} = {point, circle, triangle_up, square}  $$

    style_r1 = 'r-'; style_r2 = 'r--'; style_r3 = 'r:'
    style_b1 = 'b-'; style_b2 = 'b--'; style_b3 = 'b:'
    style_y1 = 'y-'; style_y2 = 'y--'; style_y3 = 'y:'
    style_c1 = 'c-'; style_c2 = 'c--'; style_c2 = 'c:'
    style_g1 = 'g-'; style_g2 = 'g--'; style_g3 = 'g:'
    style_m1 = 'm-'; style_m2 = 'm--'; style_m3 = 'm:'
    style_k1 = 'k-'; style_k2 = 'k--'; style_k3 = 'k:'

    #--- plot  $$
    ax1.plot(xx1, yy1, 'r',
             xx2, yy2, 'b--',
             )

    #--- add legend  $$
    legend_0 = ["",""]
    legend_size = 12
    ax1.legend(legend_0)
       ### change legend size, ax1.legend(legend_0, prop={'size': legend_size})

    #--- set X and Y axis range  $$
    ax1.set_xlim([xmin, xmax])
    ax1.set_ylim([ymin, ymax])

    #--- reverse axes  $$
    ax.invert_xaxis()
    ax.invert_yaxis()

    #-------------
    #
    # Fill between lines
    #.  reference: https://www.geeksforgeeks.org/matplotlib-pyplot-fill_between-in-python/
    #              https://pythonguides.com/matplotlib-fill_between/
    #-------------  $$

    #--- fill between y1 & y2 $$
    ax1.fill_between( 
            xx, y1, y2, color='gray', alpha=0.2,
                )

    #-------------
    #
    # Set grid lines and tick markers
    #
    #-------------  $$

    #--- set grid lines  $$
    #      https://www.pythoncharts.com/matplotlib/customizing-grid-matplotlib/
    #      https://www.tutorialspoint.com/matplotlib/matplotlib_setting_ticks_and_tick_labels.htm
    ax.grid(True)

    ax.grid(which='major', color='gray', linestyle='-.', linewidth=0.7)        # set both X and Y grids
    ax.grid(which='minor', color='gray', linestyle='-.'. linewidth=0.7)

    ax.xaxis.grid(which="minor", color='gray', linestyle='-.', linewidth=0.7)  # set X grids
    ax.xaxis.grid(which="major", color='gray', linestyle='-.', linewidth=0.7)

    #--- minor tickers $$
    ax.minorticks_on()                              # turn on minor tick marks
    ax.xaxis.set_minor_locator(MultipleLocator(1))  # draw minor tickers every 1 units

    #--- setting ticks and tick labels
    #.     https://www.tutorialspoint.com/matplotlib/matplotlib_setting_ticks_and_tick_labels.htm  $$
    ax.set_xticks([0,2,4,6])
    ax.set_xticklabels(['zero','two','four','six'])
    
    #xvalues = np.arange()
    #xlabels = np.arange()
    #ax.set_xticks(xvalues)        # tickmark values
    #ax.set_xticklabels(xlabels)   # tickmark labels

    """

    #@@@@@@@@@@@@@@@@@@@@@@@@@@, python built-in functions
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    py = """
    #-----------------------------
    # python built-in functions
    #-----------------------------

    #--- stop the code and return error message $$
    
    error_meg = f"ERROR: value [{value}] is not supported"
    raise ValueError(error_meg)
        
    #--- copy a string $$
    text2 = text1[:]   # text2 = text1 also works. To avoid confusion, adding [:] might be better
    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, xarray
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- xarray
    xr = """
    #-------------
    # Xarray DataArray
    #-------------
    
      #--- create a DataArray & set attributes $$
      da = xr.DataArray(data, dims=['time', 'lat', 'lon'], coords=[time, lat, lon])
      
      da = xr.DataArray(np.zeros(dim0*dim1).reshape(dim0,dim1), dims=['time','plev'])

      da.attrs['units']="K"
    
      #--- reorder dimensions $$
      # da[time, lat, lon]
      da_jit = da.transpose("lat","lon","time")

    #-*************************************
    # Open netCDF files using Xarray
    #-*************************************  $$
      
      #--- Open a netCDF file  $$
      filename = "./test111.nc"
      ds1 = xr.open_dataset(filename)
      ds1
      
      #--- Open multiple netCDF file  
      #.     open_mfdataset sucks. I spent an hour searching on internet, but I still didn't know 
      #      how to use it. There was no (or useless) examples.  
      #. $$

      datapath = "../data/"
      filename = [file1.nc, file2.nc]    # file has dimension (time, lat, lon, ...)
      filename = [ datapath + ff for ff in filename ]
      
      da = xr.open_mfdataset(filename, concat_dim=["lon"], combine='nested')  # this will return da (lon=2, ...)

      #--- get variable in the DataSet $$
      ds1.temp
      
      varname = "temp"
      var = ds1.get(varname)

    #-************************************
    # Xarray Dataset
    #.   - Save multiple DataArray to a DataSet
    #.   - Merge multiple DataSet
    #.   - Save a DataSet to a new netCDF file
    #-************************************ $$
    
      #--- Convert multiple DataArray into a DataSet  $$
      # var1[x,y], var2[x]
    
      ds = va1.to_dataset(name = "varA")
      ds.attrs["contact"] = "yihsuan"     # set global attribute
    
      ds['varB'] = var2
    
      #--- Merge multiple datasets
      #.     ds1 and ds2 are dataset  $$
      ds_merge = xr.merge([ds1, ds2]) 
    
      #--- Save a DataSet to a new netCDF file  $$
      new_filename = "./test111.nc"
  
      ds.to_netcdf(path=new_filename)
      ds.close()

      ds1 = xr.open_dataset(new_filename)
      ds1
    """

    #@@@@@@@@@@@@@@@@@@@@@@@@@@, string
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- string
    str = """
    #---------------------
    # string operation
    #---------------------

    #--- Append suffix / prefix to strings in list $$
    filenames_scm = [A, B, C]
    filenames_scm = [datapath+file1 for file1 in filenames_scm]
    """
    
    #------------------
    # print out
    #------------------
    
    keywords_list=['ax_def_xy',
                   'dict',
                   'fdef',
                   'np',
                   'module',
                   'plt',
                   'pltcn_2d',
                   'pltset',
                   'pltxy',
                   'py',
                   'str',
                   'xr'
                  ]
    
    for key1 in keywords:
        if (key1 == "fdef"):
            text = fdef[:]
    
        elif (key1 == "ax_def_xy"):
            text = ax_def_xy[:]

        elif (key1 == "dict"):
            text = dict[:]

        elif (key1 == "np"):
            text = np[:]

        elif (key1 == "module"):
            text = module[:]
            
        elif (key1 == "plt"):
            text = plt[:]        
            
        elif (key1 == "pltcn_2d"):
            text = pltcn_2d[:]        

        elif (key1 == "pltset"):
            text = pltset[:]    
            
        elif (key1 == "pltxy"):
            text = pltxy[:]
        
        elif (key1 == "py"):
            text = py[:]
            
        elif (key1 == "xr"):
            text = xr[:]

        elif (key1 == "str"):
            text = str[:]
            
        else:
            text = ""
            print("ERROR: ["+ key1+"] is not supported yet.")
            print("supported keywords: ["+', '.join(keywords_list)+"]")

            
    #--- print out
    if (color):
        print(text.replace("#-", color_comment+"#-").replace("$$", color_black))
    else:
        print(text)

#lib('pltset')


## lib_func

In [None]:
def lib_func(*keywords, color = False):
    """
    ----------------------
    Description:
      My library of Python functions. These functions can be copied to ipynb program.
      Every requested keyword is printed, in the order the keywords are given.

    Input arguments:
      keywords: string variables. Available: "plot_cn_pt", "plot_scm_time_series",
                "read_am4_pt_ijavg", "read_merra2_pt_ijavg", "text1".
                An unsupported keyword prints an error message and is skipped.
                If no keyword is given, the list of supported keywords is printed.
      color   : logical variable. Turn on/off colored texts

    Return:
      print out on screen

    Example:
      import yhc_module as yhc
      yhc.lib_func("plot_cn_pt")
      yhc.lib_func("plot_cn_pt", "read_am4_pt_ijavg", color = False)


    Date created: 2023/01/11
    ----------------------
    """
    
    color_comment = '\033[36m'  # cyan color
    color_black   = '\033[0m'   # reset to the default color
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@,
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- 
    text1 = """
    #---------------------
    # 
    #---------------------

    #--- test text1 $$
    filenames_scm = [A, B, C]
    filenames_scm = [datapath+file1 for file1 in filenames_scm]
    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, plot_scm_time_series
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- plot_scm_time_series
    plot_scm_time_series = """
def plot_scm_time_series(datasets, time_coord, 
                         ax, names_styles, varname, 
                         do_unit_convert = False, var_factor = 1, 
                         do_var_stats = True, tt_start = "none", tt_end = "none", 
                         plot_les = "none", 
                        ):
    
    func_name = "plot_scm_time_series"
    
    #------------------------
    # set shared parameters
    #------------------------
    legend_size = 16
    
    #--------------
    # check parts
    #--------------
    
    #--- make sure # of dataset is not greater than # of names_styles
    nda = len(datasets)        # number of dataset
    nname = len(names_styles) # number of names
    
    if (nda > nname):
        error_msg = f"ERROR: function [{func_name}], # of dataset [{nda}] > # of names_styles [{nname}]"
        raise ValueError(error_msg)
      
    #--------------------------------------
    # plot
    #--------------------------------------
    
    #--- get names of all datasets
    da_names = list(names_styles.keys())  
    
    #--- loop for all datesets
    i=0
    for da1 in datasets:

        #--------------------------------------
        # plot the variable in each dataset
        #--------------------------------------
        
        #--- get dataset custom name and color style
        name1 = da_names[i]
        style1 = names_styles[name1]
        
        #--- check whether da1 has varname
        check_var_in_da(da1, varname)
        
        #--- read variable
        var1 = da1.get(varname)[:,0,0]
        if (do_unit_convert): 
            var1 = yhc.unit_convert(var1)   # change units if needed
        else:
            var1 = var1.copy()*var_factor
            if (var_factor == 1):
                unit1 = var1.attrs['units']
            else:
                unit1 = var1.attrs['units']+" / "+str(var_factor)
        
        #--- set time coordinate
        #ntime = len(da1.time)
        #tt = np.arange(0., ntime, 1)
        tt = time_coord
        
        #--- plot the time series
        ax.plot(tt, var1, style1) 

        #--- set labels
        ax.set_title(varname, loc='left')
        ax.set_title(unit1, loc='right')
        ax.set_xlabel("Hours")
        ylabel = var1.attrs['long_name']+" ("+unit1+")"
        ax.set_ylabel(ylabel)
        
        #--- plot legend
        ax.legend(da_names, prop={'size': legend_size})

        #--------------------------------------
        # do some statistics of the variables
        #--------------------------------------
        
        if (do_var_stats):
        
            #--- get time range
            if (tt_start != "none" and tt_end != "none"):
                tt1 = tt_start
                tt2 = tt_end
            else:
                tt1 = 0
                tt2 = len(var1)
                
            #--- print time coordinates
            if (i==0): print(var1.time[tt1:tt2])

            var1_avg = var1[tt1:tt2].mean("time")
            var1_avg_str = '{:.1f}'.format(np.array(var1_avg))

            string= f"Exp [{name1}], avg [{varname}] over time step [{tt1}, {tt2}] = {var1_avg_str}"
            #print("")
            print(string)

        #----------------------------- 
        # loop for another dataset
        #-----------------------------
        #print(i)
        i=i+1

#-----------
# do_test
#-----------
do_test=True
#do_test=False

if (do_test):
    fig, ax1 = plt.subplots(1,1, figsize=(25, 10))
    
    names_styles = {
        "da1": "r--",
        "da2": "b--",
        "da3": "g--",
        "da4": "m--",
    }
    
    varname = "LWP"
    tt_start = 0
    tt_end = 10
    
    das = [da_rf02_am4RAD_ICStv_FStv, da_rf02_am4RAD_ICStv_2Xdiv, da_rf02_scmRAD_ICStv_FAM4n, da_rf02_am4RAD_ICStv_L07div]

    plot_scm_time_series( das, 
                         time_coord=tt, 
                         tt_start = tt_start, tt_end = tt_end, 
                         plot_les = "George_RF02",
                         names_styles = names_styles, varname = varname, ax=ax1, var_factor = 1000)
    """ 
    
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, plot_cn_pt
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- plot_cn_pt
    plot_cn_pt = """
def plot_cn_pt(time_merra2, plev_merra2, var_merra2,
               time_am4, plev_am4, var_am4,
               varname,  
               mlev_merra2 = "Np", 
              ):

    fig, (ax1, ax2) = plt.subplots(2,1, figsize=(18, 12))   # 2 row, 1 column
    fig.tight_layout(pad=6.0)

    if (varname == "tdt_dyn"):
        #cn_levels = np.linspace(-10., 10., 21)
        cn_levels = np.arange(-10., 11., 1.)
        cmap="coolwarm"
        label = "3D dynamical T tendencies (K/day)"
        name="tdt_dyn"
        units="K/day"
        
    else:
        cn_levels = 15
        cmap = "viridis"
        label = "Var (units)"

    #--- ax1, var_merra2
    if (mlev_merra2 == "Nv"): 
        plot_var_merra2 = ax1.contourf(time_merra2.transpose("lev","time"), plev_merra2.transpose("lev","time"), var_merra2.transpose("lev","time"), levels = cn_levels, cmap=cmap)
    else:
        plot_var_merra2 = ax1.contourf(time_merra2, plev_merra2, var_merra2.transpose("lev","time"), levels = cn_levels, cmap=cmap)
    
    fig.colorbar(plot_var_merra2, ax=ax1, orientation='vertical', label=label)
    ax_def_cn(ax1)

    #--- ax2, var_am4
    plot_var_am4 = ax2.contourf(time_am4.transpose("plev","time"), plev_am4.transpose("plev","time"), var_am4.transpose("pfull","time"), levels = cn_levels, cmap=cmap)

    fig.colorbar(plot_var_am4, ax=ax2, orientation='vertical', label=label)
    ax_def_cn(ax2)

    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, read_am4_pt_ijavg
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- read_am4_pt_ijavg
    read_am4_pt_ijavg = """

def read_am4_pt_ijavg(choice, varname, region): 
    
    \"""
    ----------------------
    Description:
       Read AM4 (time, pressure/level, lat, lon) data on native levels, and then
       convert it to a two-dimensional pressure-time data averaged over a region (lat, lon).

    Input arguments:
      choice: a string used in read_data. 
      varname: variable name in the MERRA-2 data
      region: a string used in yhc.get_area_avg

    Return:
      var_ijavg  : 2-d pressure-time data
      plevs_ijavg: pressure levels 

    Example:
      See "do_test" section 

    Date created: 2023-01-11
    ----------------------
    \"""

    func_name = "read_am4_pt_ijavg"
    model = "AM4_L33_native"
    
    #--- 
    da_am4 = read_data.read_am4_data(choice)
    var = da_am4.get(varname)
    
    ps = da_am4.ps  # surface pressure
    
    var_ijavg, plevs_ijavg = yhc.get_region_profile(var, region, model = model, ps = ps, plevs = "pfull")

    return var_ijavg, plevs_ijavg

#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    choice = "nudgeAM4_July10_12"
    region = "DYCOMS"
    varname = "tdt_dyn"
    #time_step=slice(0,23)
    
    var_ijavg, plevs_ijavg = read_am4_pt_ijavg(choice, varname, region)
    
    #print(var_ijavg)
    #print(plevs_ijavg)

    
    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@, read_merra2_pt_ijavg
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #--- read_merra2_pt_ijavg
    read_merra2_pt_ijavg = """
    
def read_merra2_pt_ijavg(choice, varname, region,
                         mlev = "Np"): 
    
    \"""
    ----------------------
    Description:
       Read MERRA-2 (time, pressure/level, lat, lon) data, and then
       convert it to a two-dimensional pressure-time data averaged over a region (lat, lon).

    Input arguments:
      choice: a string used in read_data. 
      varname: variable name in the MERRA-2 data
      region: a string used in yhc.get_area_avg
      mlev: data vertical levels. "Np" is on pressure levels, "Nv" is on model hybrid levels.

    Return:
      var_ijavg  : 2-d pressure-time data
      plevs_ijavg: pressure levels 

    Example:
      See "do_test" section 

    Date created: 2023-01-04
    ----------------------
    \"""

    func_name = "read_merra2_pt_ijavg"

    #--- 
    da_merra2 = read_data.read_merra2_data(choice)
    da_merra2 = yhc.wrap360(da_merra2)
    var = da_merra2.get(varname)
    
    #--- get domain-avg profile
    var_ijavg = yhc.get_area_avg(var, region)
    
    #--- get pressure levels
    if (mlev == "Np"): 
        plevs_ijavg = da_merra2.lev
        
    elif (mlev == "Nv"):
        PL_merra2 = da_merra2.get("PL")
        PL_ijavg = yhc.get_area_avg(PL_merra2, region)
        plevs_ijavg = PL_ijavg

    elif (mlev == "Nv_none"):
        plevs_ijavg = 0.        
        
    else:
        error_msg = f"ERROR: function [{func_name}] does not support mlev=[{mlev}]."
        raise ValueError(error_msg)
    
    #print(plevs_ijavg)
    #print(da_merra2)
    
    return var_ijavg, plevs_ijavg

#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    choice = "qdt_July10_12"
    region = "DYCOMS"
    varname = "DQVDTDYN"
    
    var_ijavg, plev_merra2 = read_merra2_pt_ijavg(choice, varname, region)
    
    #print(da_merra2)
    #print(var_ijavg)
    print(plev_merra2)
    
    """
    
    #@@@@@@@@@@@@@@@@@@@@@@@@@@,
    #@@@@@@@@@@@@@@@@@@@@@@@@@@
    #---
    
    #------------------
    # print out
    #------------------
    
    #--- keyword -> snippet text. The insertion order is the order shown in messages.
    snippets = {
        'plot_cn_pt': plot_cn_pt,
        'plot_scm_time_series': plot_scm_time_series,
        'read_am4_pt_ijavg': read_am4_pt_ijavg,
        'read_merra2_pt_ijavg': read_merra2_pt_ijavg,
        'text1': text1,
    }
    keywords_list = list(snippets)
    supported_msg = "supported keywords: ["+', '.join(keywords_list)+"]"
    
    #--- no keyword given: show what is available instead of failing on an unset text
    if not keywords:
        print(supported_msg)
        return
    
    for key1 in keywords:
        
        #--- unsupported keyword: report it and go on with the remaining keywords
        if key1 not in snippets:
            print(f"ERROR: [{key1}] is not supported yet.")
            print(supported_msg)
            continue
        
        text = snippets[key1]
        
        #--- print out. This is done inside the loop so that every keyword is shown,
        #    not only the last one.
        if (color):
            # the trailing reset keeps the color from leaking into the next output
            print(text.replace("#-", color_comment+"#-").replace("$$", color_black) + color_black)
        else:
            print(text)
        
#lib_func('plot_scm_time_series')


## get_region_profile

In [None]:
def get_region_profile(var, region, 
                       check_ps_range = True, 
                       model = None, ps = None, plevs = None): 
    
    """
    ----------------------
    Description:
      Get regional-averaged profile. The pressure levels are computed using function *mlevs_to_plevs*.

    Input arguments:
      var: a xarray DataArray
      region: a string, region name used by get_region_latlon
      check_ps_range: Check whether surface pressure exceeds the given range (default is 10hPa)
      model, ps, plevs: input arguments for function *mlevs_to_plevs*.
                        If model is None, the pressure levels are not computed.
                        If model is given, ps must be given too.

    Return:
      var_region_ijavg & plevs_region_ijavg
      area-averaged variable and pressure levels.
      plevs_region_ijavg is None when model is None.

    Raises:
      ValueError: model is given without ps, or check_ps_range is True and
                  the surface pressure range inside the region exceeds 10 hPa.

    Example:
      import yhc_module as yhc
      
      var = da.temp
      ps = da.ps
      model = "AM4_L33_native"
      region = "DYCOMS"
    
      var1, plevs1 = yhc.get_region_profile(var, region, model = model, ps = ps, plevs = "pfull")

    Date created: 2022-08-31
    ----------------------
    """

    func_name = "get_region_profile"

    #--- get lat/lon of the region
    lon_slice, lat_slice = get_region_latlon(region)
    
    #--- get regional-average of the var
    var_region_ijavg = get_area_avg(var, region)
    
    #--- without a model the pressure levels cannot be computed
    if model is None:
        return var_region_ijavg, None

    #--- surface pressure is needed to compute the pressure levels
    if ps is None:
        error_msg = f"ERROR: [{func_name}]: ps must be given when model=[{model}] is set"
        raise ValueError(error_msg)

    ps_region = ps.sel(lat=lat_slice, lon=lon_slice)
        
    #--- check whether surface pressure differ too much in the region
    if (check_ps_range):
        ps_range = 1000.  # 10 hPa; the error message below reports these values in Pa
            
        ps_max = ps_region.max().to_numpy()
        ps_min = ps_region.min().to_numpy()
        ps_diff = abs(ps_max-ps_min)
            
        if (ps_diff > ps_range):
            error_msg = f"""
                ERROR: [{func_name}]: surface pressure exceeds the range 
                [ps_max, ps_min, ps_diff, ps_range] = [{ps_max}, {ps_min}, {ps_diff}, {ps_range}] Pa
                Set check_ps_range=False to skip the check
                """
            raise ValueError(error_msg)
        
    #--- get pressure levels in the region, and then their area average
    plevs_region = mlevs_to_plevs(ps_region, model, plevs)
    plevs_region_ijavg = get_area_avg(plevs_region, region)

    return var_region_ijavg, plevs_region_ijavg

#-----------
# do_test: manual check of get_region_profile (switched off by default)
#-----------

do_test = False   # set to True to run the check below

if do_test:
    #--- test settings
    model = "AM4_L33_native"
    region = "DYCOMS"
    file = "../data/data-am4_20010725_8xdaily-temp.nc"

    #--- read the first time step of temperature and surface pressure
    da = xr.open_dataset(file)
    var = da.temp[0, :, :, :]
    ps = da.ps[0, :, :]

    #--- regional-mean profile and its pressure levels
    var1, plevs1 = get_region_profile(
        var, region,
        model = model, ps = ps, plevs = "pfull",
    )

    #--- show the results
    printv(var1, 'var', 'r')
    printv(plevs1, 'plev', 'g')

## read_var_nc

In [None]:
def read_var_nc(filename, varname, 
                do_unit_convert = True,
               ): 
    
    """
    ----------------------
    Description:
      Given filename and varname, use Xarray to read the data and read the variable,
      and then return respective DataSet and DataArray.
      The returned DataSet is left open; the caller closes it when done.
      If reading the variable fails, the DataSet is closed before the error is raised.

    Input arguments:
      filename: (a string) path of a netCDF file, e.g. ../data/test111.nc
      varname: (a string) variable name in the file, e.g. temp
      do_unit_convert: (a logical) whether calling unit_convert or not

    Return:
      xarray DataSet and DataArray

    Raises:
      ValueError: varname is not a data variable of the file

    Example:
      import yhc_module as yhc
      
      filename = ""
      varname =""
      da_scm, var_scm = yhc.read_var_nc (filename, varname)

    Date created: 2022-09-02
    ----------------------
    """

    func_name = "read_var_nc"

    #--- open the netCDF file 
    da = xr.open_dataset(filename)

    try:
        #--- read the variable 
        if varname not in da.data_vars:
            error_msg = f"ERROR [{func_name}]: variable [{varname}] is not in the file [{filename}]"
            raise ValueError(error_msg)
        var_in = da.get(varname)

        #--- change units, or return an independent copy of the variable
        if (do_unit_convert):
            var_out = unit_convert(var_in)
        else:
            var_out = var_in.copy()

    except Exception:
        # do not leave the file open when nothing is returned to the caller
        da.close()
        raise
    
    #--- return
    return da, var_out
    
    
#-----------
# do_test: manual check of read_var_nc (switched off by default)
#-----------

do_test = False   # set to True to run the check below

if do_test:
    #--- test settings
    varname = "qdt_vdif"
    filename = "../data/SCM_am4_xanadu_edmf_mynn.v01_RF01-00cc-am4p0_aerT_clr_am4RAD_sw.1x0m5d_1x1a.atmos_edmf_mynn.nc"

    #--- read the variable with the default unit conversion
    #    (pass do_unit_convert=False to keep the units in the file)
    da_scm, var_scm = read_var_nc(filename, varname)

    #--- show the DataSet, one of its variables, and the variable that was read
    printv(da_scm, 'da_scm', 'r')
    printv(da_scm.ucomp, 'u_scm', 'b')
    printv(var_scm, 'var_scm', 'g')
    
    

## Meteorology functions

In [None]:
def es(T_degreeC, Ps_hPa):
    """
    ----------------------
    Description:
      Saturation vapor pressure using Bolton 1980.
      Ref: Metpy documentation: https://unidata.github.io/MetPy/v0.3/api/thermo.html

    Input arguments:
      T_degreeC: temperature (degree C)
      Ps_hPa   : pressure (hPa). Not used in the formula.

    Return:
      saturation vapor pressure (hPa)
    ----------------------
    """
    exponent = 17.67*T_degreeC / (T_degreeC + 243.5)
    return 6.112 * np.exp(exponent)

def qs(T_degreeC, Ps_hPa):
    """
    ----------------------
    Description:
      Saturation water mixing ratio.
      Reference: https://www.weather.gov/media/epz/wxcalc/mixingRatio.pdf

    Input arguments:
      T_degreeC: temperature (degree C)
      Ps_hPa   : pressure (hPa)

    Return:
      saturation mixing ratio (kg/kg)
    ----------------------
    """
    e_sat = es(T_degreeC, Ps_hPa)       # saturation vapor pressure (hPa)
    ratio = e_sat / (Ps_hPa - e_sat)

    # 621.97 * ratio is in g/kg; dividing by 1000 changes it to kg/kg
    return 621.97 * ratio / 1000.

def des_dT(T_degreeC):
    """
    ----------------------
    Description:
      d(es)/dT, the temperature derivative of the saturation vapor pressure
      expression 6.112 * exp( 17.67*T / (T+243.5) ).

    Input arguments:
      T_degreeC: temperature (degree C)

    Return:
      d(es)/dT (hPa/C)
    ----------------------
    """
    shifted = T_degreeC + 243.5
    e_sat = 6.112 * np.exp(17.67*T_degreeC / shifted)

    return e_sat * (17.67*243.5) / shifted**2

def dqs_dT(T_degreeC, Ps_hPa):
    """
    ----------------------
    Description:
      d(qs)/dT, the temperature derivative of the saturation mixing ratio,
      computed from es and des_dT.

    Input arguments:
      T_degreeC: temperature (degree C)
      Ps_hPa   : pressure (hPa)

    Return:
      d(qs)/dT (kg/kg per degree C)
    ----------------------
    """
    e_sat = es(T_degreeC, Ps_hPa)    # saturation vapor pressure (hPa)
    slope = des_dT(T_degreeC)        # d(es)/dT (hPa/C)

    return 621.97 / 1000. * Ps_hPa * slope / (Ps_hPa - e_sat)**2

def rh(T_degreeC, Ps_hPa, q_kg_kg):
    """
    ----------------------
    Description:
      Relative humidity, the ratio of q_kg_kg to the saturation value given by qs.

    Input arguments:
      T_degreeC: temperature (degree C)
      Ps_hPa   : pressure (hPa)
      q_kg_kg  : water vapor amount (kg/kg)

    Return:
      relative humidity (unitless)
    ----------------------
    """
    return q_kg_kg / qs(T_degreeC, Ps_hPa)

## var_stats

In [None]:
def var_stats(*da, dim, 
             opt):
    
    """
    ----------------------
    Description:
      Return selected statistic of the input arrays along specified dimension

    Input arguments:
      da : xarray DataArray (one or more). Each array must reduce to a single
           value along *dim*, because the result is formatted as one number.
      dim: (string) dimension to operate, e.g. time
      opt: stats option, Available: "avg"

    Return:
      a list of strings, the stats value of each da formatted with two decimals

    Raises:
      ValueError: opt is not an available option

    Example:
      import yhc_module as yhc
      dim = "time"
      opt = "avg"
      stats = yhc.var_stats(da1, da2, dim = dim, opt = opt)

    Date created: 2022-10-03
    ----------------------
    """

    func_name = "var_stats"

    # available stats 
    stats_avail = ["avg"]

    #--- check the option before any computation
    if opt not in stats_avail:
        error_msg = f"ERROR: func [{func_name}], [opt={opt}] does not support. Available: {stats_avail}"
        raise ValueError(error_msg)
    
    #--- compute average of the input arrays along the requested dimension
    stats = []
    for da1 in da:
        mean1 = da1.mean(dim)
        mean1_format = '{:.2f}'.format(np.array(mean1))    # format output 
        stats.append(mean1_format)
    
    return stats
    
    
#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    gg1 = np.array([1.,2.,3.,4.])
    gg2 = np.array([4.14232,4.232,4.12321])
    
    da1 = xr.DataArray(gg1, dims=['time'])
    da2 = xr.DataArray(gg2, dims=['time'])
    
    dim = "time"   # must be a dimension of da1 and da2
    opt = "avg"

    stats = var_stats(da1, da2, dim = dim, opt = opt)
    
    print(stats)   # ['2.50', '4.17']
    

## read_am4_profile_region

In [None]:
def read_am4_profile_region (filename, varname, time_step,
                             model, region, plevs = "pfull"): 
    
    """
    ----------------------
    Description:
      Read a variable (time, lev, lat, lon) in AM4 output data, 
      and get the profile averaged over a specified domain.

    Input arguments:
      filename : a string, the full path of AM4 data (in netCDF)
      varname  : name of a variable with (time, lev, lat, lon) dimensions
      time_step: time step selected (passed to .sel(time=...))
      
      model, region, plevs: defined in yhc.get_region_profile.
                            plevs is passed through unchanged (default: "pfull").

    Return:
      three values:
        1. the xarray Dataset opened from filename
        2. the domain-average 1-D variable
        3. the 1-D pressure levels
      Items 2 and 3 are the first row ([0,:]) of the arrays returned by
      yhc.get_region_profile.

    Raise:
      KeyError if varname is not found in the file

    Example:
      import yhc_module as yhc

      filename = datapath_am4+"/"+fname_am4_rf02
      varname = "temp"
      time_step = "2001-07-11 10:30:00"
      region = "DYCOMS"
      model = "AM4_L33_native"
      
      da_am4_1, temp_am4_1, plev_am4_1 = yhc.read_am4_profile_region (filename, varname, time_step, model, region)

    Date created: 2022-11-09
    ----------------------
    """

    func_name = "read_am4_profile_region"

    #--- use Xarray to read the file
    ds = xr.open_dataset(filename)

    #--- read the variable; Dataset.get returns None when the name is absent
    var_all = ds.get(varname)
    if var_all is None:
        error_msg = f"ERROR: function [{func_name}] cannot find variable [{varname}] in [{filename}]."
        raise KeyError(error_msg)
    var = var_all.sel(time=time_step)
    
    #--- read surface pressure
    ps  = ds.ps.sel(time=time_step)
    
    #--- use yhc.get_region_profile to get the domain-avg profile
    #    (pass the caller's plevs instead of a hard-coded "pfull")
    var1, plevs1 = get_region_profile(var, region, model = model, ps = ps, plevs = plevs)
    
    #--- return Xarray Dataset, 1-D variable, 1-D pressure levels
    return ds, var1[0,:], plevs1[0,:] 
    
    # error_msg = f"ERROR: function [func_name] does not support []. Available:"
    # raise KeyError(error_msg)
    # raise ValueError(error_msg)

#-----------
# do_test
#-----------

# switch for the test of read_am4_profile_region (the test body is not written yet)
do_test = True
#do_test = False

#if (do_test):

## read_merra2_profile_region

In [None]:
def read_merra2_profile_region (filename, varname, time_step,
                                region): 
    
    """
    ----------------------
    Description:
      Read a variable (time, lev, lat, lon) in MERRA-2 data, 
      and get the profile averaged over a specified domain.

      Note that lev could be on model hybrid levels (Nv) or on pressure levels (Np).
      Variable "PL" is the mid_level_pressure in *Nv* files.

    Input arguments:
      filename : a string, the full path of MERRA-2 data (in netCDF)
      varname  : name of a variable with (time, lev, lat, lon) dimensions
      time_step: time step selected (passed to .sel(time=...))
      
      region: defined in yhc.get_region_latlon

    Return:
      two values:
        1. the xarray Dataset opened from filename, after yhc.wrap360
           (longitudes flipped from [-180,180] to [0,360])
        2. the domain-average variable returned by yhc.get_area_avg

    Raise:
      KeyError if varname is not found in the file

    Example:
      import yhc_module as yhc

      filename = datapath_merra2+"/"+fname_merra2_rf01
      varname = "T"
      time_step = "2001-07-10T10:30:00.000000000"
      region = "DYCOMS"
      
      da_merra2_1, T_merra2_1 = yhc.read_merra2_profile_region (filename, varname, time_step, region)

    Date created: 2022-11-09
    ----------------------
    """

    func_name = "read_merra2_profile_region"

    #--- use Xarray to read the file
    ds = xr.open_dataset(filename)
    ds = wrap360(ds)  # flip MERRA-2 longitudes from [-180,180] to [0,360]

    #--- read the variable; Dataset.get returns None when the name is absent
    var_all = ds.get(varname)
    if var_all is None:
        error_msg = f"ERROR: function [{func_name}] cannot find variable [{varname}] in [{filename}]."
        raise KeyError(error_msg)
    var2 = var_all.sel(time=time_step)
    
    #--- get domain-avg profile
    var2_ijavg = get_area_avg(var2, region)

    #--- return Xarray Dataset, domain-avg variable
    return ds, var2_ijavg 
    
    # error_msg = f"ERROR: function [func_name] does not support []. Available:"
    # raise KeyError(error_msg)
    # raise ValueError(error_msg)

#-----------
# do_test
#-----------

# switch for the test of read_merra2_profile_region (the test body is not written yet)
do_test = True
#do_test = False

#if (do_test):

## diagnose_var

In [None]:
def diagnose_var(da, model, varnames): 
    
    """
    ----------------------
    Description:
      Given a xarray dataset, diagnose variables and save them to the same dataset.
      The input dataset is modified in place and also returned.

    Input arguments:
      da: a xarray dataset
      model: model name. Currently support: ["AM4"]
      varnames: diagnosed variables; a list of names, or a single name as a string.
                Currently support: ["swabs", "swcre"]
                  swabs = swdn_toa - swup_toa
                  swcre = swup_toa - swup_toa_clr

    Return:
      da, with the diagnosed variables added

    Raise:
      ValueError if the model or any of the variable names is not supported.
      All names are checked before any variable is added, so da is left
      untouched when an error is raised.

    Example:
      import yhc_module as yhc

      model = "AM4"
      varnames = ["swabs",'swcre']
      fname = "../data/SCM_am4_xanadu_edmf_mynn.v01_RF01-00cc-am4p0_aerT_clr_am4RAD_sw.1x0m5d_1x1a.atmos_edmf_mynn.nc"
      da = xr.open_dataset(fname)
      
      da = yhc.diagnose_var(da, model = model, varnames = varnames)

    Date created: 2022-12-09
    ----------------------
    """

    func_name = "diagnose_var"

    #--- a single name would otherwise be iterated character by character
    if isinstance(varnames, str):
        varnames = [varnames]
    else:
        varnames = list(varnames)   # also allows the names to be read twice below

    #--- check the model
    if (model != "AM4"):
        error_msg = f"ERROR: function [{func_name}] does not support model [{model}]."
        raise ValueError(error_msg)

    #--- check every requested variable before modifying da
    varnames_avail = ("swabs", "swcre")
    for var1 in varnames:
        if var1 not in varnames_avail:
            error_msg = f"ERROR: function [{func_name}] does not support variable [{var1}]."
            raise ValueError(error_msg)

    #--- diagnose the variables (AM4)
    for var1 in varnames:
        if (var1 == "swabs"):
            swabs = da.swdn_toa - da.swup_toa
            swabs.attrs['long_name'] = "SWABS (TOA net dw SW flux)"
            da['swabs'] = swabs

        elif (var1 == "swcre"):
            # NOTE(review): this is all-sky minus clear-sky upward SW at TOA;
            #               confirm that this sign convention is the intended one.
            swcre = da.swup_toa - da.swup_toa_clr
            swcre.attrs['long_name'] = "TOA Cloud SW radiative effect"
            da['swcre'] = swcre

    #-----------
    # return
    #-----------
    return da

    # error_msg = f"ERROR: function [func_name] does not support []. Available:"
    # raise KeyError(error_msg)
    # raise ValueError(error_msg)
    
#-----------
# do_test: quick check of diagnose_var on a sample AM4 SCM file
#-----------

model = "AM4"
varnames = ["swabs",'swcre']
da1 = 1
da2 = 2
da3 = 3

fname = "../data/SCM_am4_xanadu_edmf_mynn.v01_RF01-00cc-am4p0_aerT_clr_am4RAD_sw.1x0m5d_1x1a.atmos_edmf_mynn.nc"

#do_test=True
do_test=False

if (do_test):
    # Open the sample file only when the test is enabled, so that importing
    # this module does not require the data file to exist.
    da = xr.open_dataset(fname)
    #print(da)

    da = diagnose_var(da, model = model, varnames = varnames)
    print(da.swcre)
    #printv(da.swabs, 'swabs')
    #printv(da.swdn_toa,'swdn_toa','r')