# Yi-Hsuan Chen's Python module

**Author:** Yi-Hsuan Chen (yihsuan@umich.edu)

**Convert ipynb to py:**

jupyter nbconvert yhc_module.ipynb --to python

**import:**

import yhc_module as yhc


In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from datetime import date
import numpy as np
import xarray as xr
import io, os, sys, types
xr.set_options(keep_attrs=True)  # keep attributes after xarray operation

import yhc_module as yhc

In [None]:
#############################################
# Purpose of this code:
#
# 1. convert yhc_module.ipynb to yhc_module.py
# 2. copy yhc_module.py to a folder, dir_py, so it can be imported
#
# Both steps run only when convert_ipynb2py is True, and they use paths
# relative to the current working directory.
#############################################

convert_ipynb2py = False
#convert_ipynb2py = True

dir_py = "/Users/yihsuan/.ipython"  # (Mac) copy yhc_module to this folder
                                    #       so that other jupyter notebook can import yhc_module

# location of yhc_module in Mac
py_path = "/Users/yihsuan/Downloads/yihsuan/test/tool/python/yhc_module_and_notes"

if convert_ipynb2py:
    # imported here because only this maintenance step needs them
    import shutil
    import subprocess

    print("convert yhc_module.ipynb to yhc_module.py")
    # check=True stops here when nbconvert fails, so a stale yhc_module.py is never copied
    subprocess.run(["jupyter", "nbconvert", "yhc_module.ipynb", "--to", "python"], check=True)

    print("cp yhc_module.py " + dir_py)
    shutil.copy("yhc_module.py", dir_py)

    sys.path.append(py_path)  # add py_path to sys.path
    

## printv

In [None]:
def printv(var,
           text = "",
           color = "black",
          ):
    """
    ----------------------
    Description:
      print out a variable, beginning with a separator line, followed by a text, and ending with an empty line.
      This helps to view the output more easily.

    Input arguments:
      var  : any variable
      text : text shown on screen before the variable
      color: name of the text color, one of
               red, green, yellow, blue, pink, teal, grey, black, cyan, magenta
             or a short name
               r (red), g (green), b (blue), y (yellow), c (cyan), m (magenta), k (black)

    Return:
      nothing; the statements are printed on screen

    Raise:
      KeyError if color is not one of the names above

    Example:
      import yhc_module as yhc

      yhc.printv(var, "ggg")
      yhc.printv(var, text = "aaa")
      yhc.printv(var, "bbb", color = "r")

    Reference:
      Print with colors: https://predictivehacks.com/?all-tips=print-with-different-colors-in-jupyter-notebook
      Inserting values into strings, https://matthew-brett.github.io/teaching/string_formatting.html

    Date created: 2022/07/06
    ----------------------
    """

    func_name = "printv"

    #------------------
    # set text color
    #   colors: https://pkg.go.dev/github.com/whitedevops/colors
    #------------------

    #--- color codes dictionary (ANSI escape sequences)
    color_codes = {'red':'\033[91m',
                   'green':'\033[92m',
                   'yellow':'\033[93m',
                   'blue':'\033[94m',
                   'pink':'\033[95m',
                   'teal':'\033[96m',
                   'grey':'\033[97m',
                   'black':'\033[30m',
                   'cyan':'\033[36m',
                   'magenta':'\033[35m',
                   #------ short name  -----
                   'r':'\033[91m',
                   'g':'\033[92m',
                   'b':'\033[94m',
                   'y':'\033[93m',
                   'c':'\033[36m',
                   'm':'\033[35m',
                   'k':'\033[30m',   # black; 'b' stays blue (a second 'b' key used to overwrite it)
                  }

    #--- code that switches the terminal back to its default color
    reset_code = '\033[0m'

    #--- read color code
    try:
        color_code = color_codes[color]
    except KeyError:
        error_msg = f"ERROR: function [{func_name}] does not support color [{color}]. Available colors: {list(color_codes)}"
        raise KeyError(error_msg) from None

    #------------
    # print out
    #------------

    text_out = f"{color_code} ------------- \n {text}"  # format text string

    print(text_out)
    print(var)
    print(reset_code)   # still an empty line on screen, but the color no longer leaks into later output
        
#----------
# test of printv
#   set do_test = True to print three sample messages
#----------

do_test = False
#do_test = True

if do_test:
    printv(1, "ddd", 'r')
    printv(2, "eee", color = "g")
    printv(3)



## unit_convert

In [None]:
def unit_convert (var_in, units_in = "none", units_out = "none", 
                  var_type = "xarray.DataArray",
                  do_auto_conv = True, 
                 ):
    """
    ----------------------
    Convert the units of a variable.

    Each supported unit has a factor in the "conversion" table that scales it to
    a reference unit, and the result is
        var_in * conversion[units_in] / conversion[units_out]
    The table does not check that the two units describe the same physical
    quantity; passing a meaningful pair is up to the caller.

    Input arguments:
        var_in      : (xarray.DataArray, or any object supporting * and /) a variable
        units_in    : (string) original units of var_in.
                      "none" with an xarray.DataArray: read var_in.attrs['units']
        units_out   : (string) new units of var_in.
                      "none" with an xarray.DataArray: use the default target in units_dict
                      (do_auto_conv = True) or keep units_in (do_auto_conv = False)
        var_type    : (string) "xarray.DataArray" (default); any other value treats var_in as a plain value
        do_auto_conv: (logical) use the default targets set in units_dict, and convert
                      precipitation given in m/s to mm/day

    Return:
        variable values with units_out.
        For an xarray.DataArray, attrs['units'] is set to units_out. If either unit is not
        supported, a warning is printed and var_in is returned unchanged.

    Raise:
        KeyError if var_type is not "xarray.DataArray" and units_in or units_out
        is not in the conversion table

    Example
        import yhc_module as yhc
        var = yhc.unit_convert(var, units_in = "m", units_out = "km")
        x_m = yhc.unit_convert(2., units_in = "km", units_out = "m", var_type = "float")   # 2000.

    Date created: 2023/01/15
    -----------------------
    """

    func_name = "unit_convert"

    #------------- 
    # constants 
    #------------- 

    rho_water = 1000.          # water density [kg/m3]
    latent_heat_evap = 2.5e+6  # latent heat of vaporization for water, J/kg
    hr2sec = 3600.             # hour in seconds
    day2sec = 86400.           # day in seconds
    cp_air = 1005              # specific heat of air, units: J/kg/K

    #------------- 
    # conversion dictionary: factor that scales each unit to its reference unit
    #------------- 

    conversion = {'none':1.0,
                  'm':1.0, 'mm':0.001, 'cm':0.01, 'km':1000.,
                  'm/s':1.0, 'mm/day':1./(1000.*day2sec), 'kg/m2/s':1./rho_water, 'W/m2':1./latent_heat_evap/rho_water,
                  'kg/kg':1.0, 'kg kg-1':1.0, 'g/kg':1.e-3,
                  'fraction':1., "percent":0.01, "%":0.01,
                  'K/s':1., 'K/day':1./day2sec, 'deg_K/s':1., 'deg_K/sec':1., 'K s-1':1., 'W/kg':1/cp_air,
                  'kg/kg/s':1., 'kg kg-1/s':1., 'kg kg-1 s-1':1., 'g/kg/hour':1./(1000.*hr2sec), 'g/kg/day':1./(1000.*day2sec),
                  'Pa/s':1., 'Pa s-1':1, 'pa/sec':1, 'hPa/day': 100./day2sec,
                  'kg/m2':1., 'kg m-2':1, 'g/m2':1.e-3, 'g m^-2':1.e-3,
                  '1/s':1., '1/hour': 1./hr2sec, '10^-6/s':1.e-6,
                  'Pa':1.0, 'hPa':100., "mb":100.,
                  'K':1.0, 'deg_K':1.0,
                  'watts/m2':1.0, 'W m^-2':1,
                 }

    #--- set a default unit conversion, [units_in, units_out]
    units_dict = {'none':'none',
                  'K/s':'K/day', 'deg_K/s':'K/day', 'K s-1':'K/day', 'deg_K/sec':'K/day', 'W/kg':'K/day',
                  'Pa/s':'hPa/day','Pa s-1':'hPa/day', 'pa/sec':'hPa/day',
                  'deg_K':'K',
                  'kg/kg':'g/kg', 'kg kg-1':'g/kg',
                  'kg/kg/s':'g/kg/day',
                  'kg m-2':'g m^-2', 'kg/m2':'g m^-2',
                  'watts/m2':'W m^-2',
                 }

    #------------- 
    # check long_name. If a string of interest is in the long_name, set units_in and units_out
    #   getattr: plain values (float, numpy array, ...) have no attrs
    #------------- 
    attrs_in = getattr(var_in, "attrs", {})

    if do_auto_conv and 'long_name' in attrs_in and 'units' in attrs_in:
        if 'precip' in attrs_in['long_name'] and attrs_in['units'] == "m/s":
            units_in = "m/s"
            units_out = "mm/day"

    ##############################################
    # if input var is a Xarray DataArray
    ##############################################

    if var_type == "xarray.DataArray":

        #--------------------------------------------------------
        # get units_in and units_out if they are not inputed.
        #   A unit given by the caller is never overwritten.
        #--------------------------------------------------------

        #--- read units_in
        if units_in == "none":
            if "units" in var_in.attrs:
                units_in = var_in.attrs['units']
            else:
                warn_msg = f"WARNING: function [{func_name}], input variable has no [units] attribute. Set units_in to none"
                print(warn_msg)

        #--- get units_out
        if units_out == "none":
            if do_auto_conv and units_in in units_dict:
                units_out = units_dict[units_in]
            else:
                units_out = units_in
                if do_auto_conv:
                    warn_msg = f"WARNING: function [{func_name}], unit_in [{units_in}] is not supported for auto conversion. Set units_out=units_in"
                    print(warn_msg)

        #----------------------
        # do unit conversion
        #----------------------

        if units_in in conversion and units_out in conversion:
            var_out = var_in * conversion[units_in] / conversion[units_out]
            var_out.attrs['units'] = units_out
        else:
            warn_msg = f"WARNING: function [{func_name}], either units_in [{units_in}] or units_out [{units_out}] is not supported. Don't do conversion."
            print(warn_msg)
            var_out = var_in

    ##############################################
    #--- any other data types
    ##############################################
    else:
        var_out = var_in * conversion[units_in] / conversion[units_out]

        #--- record the new units when the result can carry attributes
        #    (float and numpy.ndarray cannot, so they are returned as plain values)
        try:
            var_out.units = units_out
        except AttributeError:
            pass

    #------------- 
    # return
    #------------- 

    return var_out

#--------
# test of unit_convert
#   set do_test = True to convert a one-element DataArray and print the result
#--------

do_test = False
#do_test = True

if do_test:

    do_auto_conv = False
    #do_auto_conv = True

    var_in = xr.DataArray([1.], dims=['time'])
    var_in.attrs.update({'units': "W/kg", 'long_ndame': "predcipa"})

    #--- other calls that were tried
    #var_out = unit_convert(var_in, do_auto_conv=do_auto_conv)
    #var_out = unit_convert(var_in, units_in = "1/s", units_out = "10^-5/s")

    var_out = unit_convert(var_in)

    #print(var_out.attrs['units'])
    print(var_out)


## get_region_latlon

In [None]:
def get_region_latlon(region = "N/A"):
    """
    ----------------------
    Given a region name, return slices of longitude and latitude.

    Input arguments:
        region = (string) name of the region, one of
            "Californian_Sc", "Peruvian_Sc", "Namibian_Sc", "DYCOMS", "global"

    Return:
        lon_slice & lat_slice of the given region.
        Longitudes are in degrees east (0-360). "global" uses bounds of -1000 and 1000
        for both latitude and longitude.

    Raise:
        ValueError if the region is not supported

    Example:
      import yhc_module as yhc
      region = "Californian_Sc"
      lon_slice, lat_slice = yhc.get_region_latlon(region)
      print(lon_slice)
      print(lat_slice)

    Date created: 2022/07/01
    ----------------------
    """

    func_name = "get_region_latlon"

    #------------------------
    #  region table
    #    name: (description, lowerlat, upperlat, lowerlon, upperlon)
    #    To support a new region, add one entry here.
    #------------------------

    regions = {
        "Californian_Sc": ("CA marine Sc (20-30N, 120-130W)",       20.,    30.,   230.,   240.),
        "Peruvian_Sc":    ("Peruvian marine Sc (10S-20S, 80W-90W)", -20.,   -10.,   270.,   280.),
        "Namibian_Sc":    ("Namibian marine Sc (10S-20S, 0E-10E)",  -20.,   -10.,     0.,    10.),
        #--- reference: Stevens et al. (2007, MWR)??
        "DYCOMS":         ("DYCOMS (30-32.2N, 120-123.8W)",          30.,    32.2,  236.2,  240.),
        "global":         ("global (0-360E, -90S-90N)",            -1000.,  1000., -1000.,  1000.),
    }

    #------------------------
    #  read region name
    #------------------------

    if region not in regions:
        # the list of available regions comes from the table, so it cannot go out of date
        error_msg = f"function *{func_name}*: input region [{region}] is not supported. STOP. \n" \
                    + "Available regions: "+ ', '.join(regions)
        raise ValueError(error_msg)

    _, lowerlat, upperlat, lowerlon, upperlon = regions[region]

    #------------------------
    #  compute lon and lat slices
    #    Python slice function: https://www.w3schools.com/PYTHON/ref_func_slice.asp
    #------------------------

    lon_slice = slice(lowerlon, upperlon)
    lat_slice = slice(lowerlat, upperlat)

    return lon_slice, lat_slice

#-----------
# test of get_region_latlon
#   set do_test = True to call it with the region name "aaDYCOMS"
#-----------

do_test = False
#do_test = True

if do_test:
    ddd = "aaDYCOMS"
    lon_slice, lat_slice = get_region_latlon(ddd)


## wgt_avg

In [None]:
def wgt_avg (xa,
             dims = ["lon","lat"],
             lat = "lat",
            ):
    """
    ----------------------
    Calculate the latitude-weighted mean of a [*, lon, lat] Xarray data

    Input arguments:
      xa  : data Array that has a latitude coordinate in degrees
      dims: dimension(s) to average, e.g. ["lon"], ["lon","lat"]
      lat : name of the latitude coordinate of xa. If it is not "lat",
            also pass dims with the matching dimension names.

    Return:
      average of the input data over dims, weighted by cos(latitude)

    Example
      var is a [*, lon, lat] Xarray data

      import yhc_module as yhc
      var_ijmean = yhc.wgt_avg(var)
      var_jmean  = yhc.wgt_avg(var, dims=["lat"])
      var_ijmean = yhc.wgt_avg(var, dims=["longitude","latitude"], lat="latitude")

    References
    1. https://docs.xarray.dev/en/stable/examples/area_weighted_temperature.html
    2. https://nordicesmhub.github.io/NEGI-Abisko-2019/training/Example_model_global_arctic_average.html

    Date created: 2022/07/01
    ----------------------
    """

    #--- compute latitudinal weights, cos(latitude)
    #    xa[lat] is the same coordinate as xa.lat when lat = "lat"
    weights = np.cos(np.deg2rad(xa[lat]))
    weights.name = "weights"

    #--- compute weighted mean
    #    NOTE: dims is only read, never modified, so the list default is safe to share
    xa_weighted = xa.weighted(weights)
    xa_weighted_mean = xa_weighted.mean(dims)

    return xa_weighted_mean

#-------
# test of wgt_avg
#   set do_test = True to average a 3x2 (lat, lon) array along 'lon'
#   and print it before and after
#-------
do_test = False
#do_test = True

if do_test:
    #gg = np.ones([3,2])
    gg = np.array(([1,1],[1,1],[100,100]))
    print(gg)

    #--- coordinates
    lat = np.array([0.,60.,90.])
    lon = np.array([100.,100.])

    data = xr.DataArray(gg, dims=['lat','lon'], coords=[lat, lon])

    pp = wgt_avg(data, 'lon')

    printv(data, 'original')
    printv(pp, 'weighted')
    

## mlevs_to_plevs

In [None]:
def mlevs_to_plevs (ps,
                    model, 
                    plevs, 
                   ):
    """
    ----------------------
    Description: 
        Compute the pressure at the vertical levels of a climate or weather model from its surface pressure

    Input arguments:
        ps   : (an xarray DataArray). Surface pressure (Pa)
        model: (a string) model name. Check out "model_list" variable in this function to see which model is supported
        plevs: (a string) return pressure levels. Check out "plevs_list" to see which are supported

        Currently available:
            model = "AM4_L33_native", GFDL AM4 with 33 levels 
                plevs = "phalf" or "pfull", pressure at half levels (34 values) or full levels (33 values)

    Return:
        an xarray DataArray with the dimensions of ps plus a new dimension "plev" that holds the pressure levels.
        Full-level pressure is the arithmetic mean of the two adjacent half-level pressures.

    Raise:
        ValueError if model or plevs is not supported

    Example:
        import yhc_module as yhc
        data = xr.open_dataset(ncfile)
        data.ps # (time, lat, lon)
        phalf = yhc.mlevs_to_plevs(data.ps, plevs="phalf", model = "AM4_L33_native")
        phalf # (time, lat, lon, plev)
     
    Date created: 2022/07/07
    ----------------------
    """

    func_name = "mlevs_to_plevs"

    model_list = ["AM4_L33_native"]

    #------------------------------------------
    # check if the model is supported
    #   raise (rather than sys.exit) so that a notebook kernel keeps running
    #   and the caller can catch the error, as for unsupported plevs below
    #------------------------------------------

    if model not in model_list:
        error_msg = f"function *{func_name}*: model [{model}] is not supported. STOP. \n" \
                    + "Available options: "+ ', '.join(model_list)
        raise ValueError(error_msg)

    #------------------------------------------
    # process model and plevs_out
    #------------------------------------------

    #@@@@@@@@@@@@@@@@@@@@@@@@@
    # In AM4, pressure(k) = pk(k) + bk(k)*ps, where ps is surface pressure (Pa).
    #@@@@@@@@@@@@@@@@@@@@@@@@@
    if model == "AM4_L33_native":

        plevs_list = ["phalf", "pfull"]  # AM4 only supports pressure at half levels (phalf) and at full levels (pfull)

        #--- check
        if plevs not in plevs_list:
            error_msg = f"function *{func_name}*: input plevs [{plevs}] is not supported. STOP. \n" \
                      + "Available options: "+ ', '.join(plevs_list)
            raise ValueError(error_msg)

        #--- pk values. Output directly from AM4 files
        pk_list = [100, 400, 818.6021, 1378.886, 2091.795, 2983.641, 4121.79, 5579.222, 
          6907.19, 7735.787, 8197.665, 8377.955, 8331.696, 8094.722, 7690.857, 
          7139.018, 6464.803, 5712.357, 4940.054, 4198.604, 3516.633, 2905.199, 
          2366.737, 1899.195, 1497.781, 1156.253, 867.792, 625.5933, 426.2132, 
          264.7661, 145.0665, 60, 15, 0]

        #--- bk values. Output directly from AM4 files
        bk_list = [0, 0, 0, 0, 0, 0, 0, 0, 0.00513, 0.01969, 0.04299, 0.07477, 0.11508, 
          0.16408, 0.22198, 0.28865, 0.36281, 0.44112, 0.51882, 0.59185, 0.6581, 
          0.71694, 0.76843, 0.81293, 0.851, 0.88331, 0.91055, 0.93331, 0.95214, 
          0.9675, 0.97968, 0.98908, 0.99575, 1]

        #--- make pk and bk as xarray.DataArray
        pk = xr.DataArray(pk_list, dims=['plev'])
        bk = xr.DataArray(bk_list, dims=['plev'])

        #--- compute pressure at half levels
        phalf = ps*bk + pk
        phalf.attrs['long_name'] = "Pressure at half levels"

        #--- return plevs
        if plevs == "phalf":
            plevs_return = phalf
        else:
            #--- pressure at full levels: mean of each pair of adjacent half levels.
            #    The two slices are paired by position, so a NaN in ps only gives NaN
            #    at that point instead of removing levels.
            p_upper = phalf.isel(plev=slice(None, -1))
            p_lower = phalf.isel(plev=slice(1, None))
            pfull = 0.5 * (p_upper + p_lower)
            pfull.attrs['long_name'] = "Pressure at full levels"
            plevs_return = pfull

        plevs_return.attrs['conversion_method'] = model
        if 'units' in ps.attrs:
            plevs_return.attrs['units'] = ps.attrs['units']

        return plevs_return

    
    
#@@@@@@@@@@@@@@@@@@@@@@@@@
#@@@@@@@@@@@@@@@@@@@@@@@@@
#elif (model == "other_model"):
#    """ 
#    """        
#    return 

#-----------
# test of mlevs_to_plevs
#   set do_test = True to print the full-level pressures for one surface pressure value
#-----------

do_test = False
#do_test = True

if do_test:
    model = "AM4_L33_native"

    Ps = xr.DataArray([102078.5], dims=['time'])
    Ps.attrs['units'] = "Pa"

    plev = mlevs_to_plevs(Ps, model, "pfull")
    print(plev)


## wrap360

In [None]:
def wrap360(ds, lon='lon'):
    """
    Source code: https://github.com/pydata/xarray/issues/577

    wrap longitude coordinates from -180..180 to 0..360

    The input object is left unchanged; a new object is returned.

    Parameters
    ----------
    ds : Dataset
        object with longitude coordinates
    lon : string
        name of the longitude ('lon', 'longitude', ...)

    Returns
    -------
    wrapped : Dataset
        Another dataset array wrapped around, sorted by increasing longitude.
    """

    # wrap -180..179 to 0..359 on a new object, so the caller's ds keeps its coordinates
    wrapped = ds.assign_coords({lon: np.mod(ds[lon], 360)})

    # sort the data; sortby also works when two longitudes wrap to the same value
    return wrapped.sortby(lon)

## get_area_avg

In [None]:
def get_area_avg (var, region, 
                  weighted = True,
                  lat='lat', lon='lon'):
    """
    ----------------------
    Description:
      compute area-mean of a xr.DataArray variable. The default is weighted by latitudes.

    Input arguments:
      var     : an xarray DataArray variable, preferably [*, lat, lon]
      region  : a string, region name used by get_region_latlon
      weighted: a logical variable. True: weighted by latitudes, False: normal mean
      lat     : coordinate name of latitude
      lon     : coordinate name of longitude

    Return:
      var_area_avg: an xarray DataArray variable, area-mean of the given variable.
                    The region name is saved in its attribute 'region'.

    Raise:
      ValueError (from get_region_latlon) if the region is not supported

    Example:
      import yhc_module as yhc

      var (lat: 3, lon: 2)
      var_area_mean = yhc.get_area_avg(var, "DYCOMS")
      var_area_mean = yhc.get_area_avg(var, "DYCOMS", weighted = False)
      var_area_mean = yhc.get_area_avg(var, "DYCOMS", lat = "latitude", lon = "longitude")

    Notes:
      1. No need to check if var is pressure or height. Although this can cause troubles in doing area average
         (e.g. on steep terrain, surface height from 0 to 1000m), adding a check is too complicated.

    Date created: 2022/07/23
    ----------------------
    """

    #--- get lat/lon of the region
    lon_slice, lat_slice = get_region_latlon(region)

    #--- select the region. A dictionary is used so that the coordinate names
    #    given by the lat and lon arguments are honoured.
    var_region = var.sel({lat: lat_slice, lon: lon_slice})

    #--- lat-weighted average
    #    same cos(latitude) weighting as wgt_avg, computed here from the
    #    coordinate named by the lat argument
    if weighted:
        weights = np.cos(np.deg2rad(var_region[lat]))
        weights.name = "weights"
        var_area_avg = var_region.weighted(weights).mean([lon, lat])

    #--- normal average
    else:
        var_area_avg = var_region.mean([lat, lon])

    #--- set attribute
    var_area_avg.attrs['region'] = region

    return var_area_avg

#-------
# test of get_area_avg
#   set do_test = True to print a 3x2 (lat, lon) array and its "DYCOMS" area average
#-------
do_test = False
#do_test = True

if do_test:
    gg = np.array(([1,1],[2,2],[100,100]))

    #--- coordinates
    lat = np.array([0.,60.,90.])
    lon = np.array([100.,100.])

    var = xr.DataArray(gg, dims=['lat','lon'], coords=[lat, lon])
    var.attrs['units'] = "K"
    printv(var, 'var')

    region = "DYCOMS"
    var_ijavg = get_area_avg(var, region)
    printv(var_ijavg, 'var_ijavg')
            

## modify_attrs

In [None]:
def modify_attrs(var, 
                 varname = "N/A", 
                 attrs_add = "N/A",
                 attrs_del = "N/A",
                ): 

    """
    ----------------------
    Description:
      Add/Delete/Modify attributes of a Xarray DataArray

    Input arguments:
      var: an Xarray DataArray (any object with a dictionary "attrs" works)
      varname: add predefined attributes to var, e.g. ["u","v"]
      attrs_add: attributes that will be added in var, e.g. ["longname=aaa","units=kkk"]
                 Everything after the first "=" is the value.
      attrs_del: attributes that will be deleted, e.g. ["att1","att2"]
                 Attributes that do not exist are ignored.

      Each of the three arguments may also be a single string, e.g. varname = "omega".
      The default "N/A" means "do nothing".

    Return:
      var with modified attributes (var itself is modified in place).
      The order is: predefined attributes, then deletions, then additions.

    Raise:
      KeyError if a varname is not predefined
      ValueError if an item of attrs_add has no "="

    Example:
      import yhc_module as yhc
      yhc.modify_attrs(var, varname = ["u"], attrs_add=["att1=0101","att2=0202"], attrs_del=["units"])

    Date created: 2022-08-17
    ----------------------
    """

    func_name = "modify_attrs"

    #-----------------------------------------
    # set predefine attributes of variables
    #   the naming follows CF Metadata Convention, 
    #.     https://cfconventions.org/
    #.  CF Standard Name Table, Version 79, 19 March 2022, 
    #.     https://cfconventions.org/Data/cf-standard-names/current/build/cf-standard-name-table.html
    #-----------------------------------------

    #--- set dictionary of variables
    varname_dict={
        "u": ["long_name=eastward_wind", "units=m s-1"],
        "v": ["long_name=northward_wind", "units=m s-1"],
        "t": ["long_name=air_temperature", "units=K"],
        "q": ["long_name=specific_humidity", "units=kg kg-1"],
        "omega": ["long_name=vertical_pressure_velocity", "units=Pa s-1"],  # not consistent with CF convention
        "ug": ["long_name=geostrophic_eastward_wind", "units=m s-1"], 
        "vg": ["long_name=geostrophic_northward_wind", "units=m s-1"],
        "ps": ["long_name=surface_air_pressure", "units=Pa"], # not consistent with CF convention
        "ts": ["long_name=surface_temperature", "units=K"],
        "shflx": ["long_name=sensible_heat_flux", "units=W m-2"], # not consistent with CF convention
        "lhflx": ["long_name=latent_heat_flux", "units=W m-2"], # not consistent with CF convention
        "divt": ["long_name=Horizontal large scale temp. forcing", "units=K s-1"], # not consistent with CF convention
        "divq": ["long_name=Horizontal large scale water vapor forcing", "units=kg kg-1 s-1"], # not consistent with CF convention
        "vertdivT": ["long_name=Vertical large scale temp. forcing", "units=K s-1"], # not consistent with CF convention
        "vertdivq": ["long_name=Vertical large scale water vapor forcing", "units=kg kg-1 s-1"], # not consistent with CF convention
        "divt3d": ["long_name=3d large scale temp. forcing", "units=K s-1"], # not consistent with CF convention
        "divq3d": ["long_name=3d large scale water vapor forcing", "units=kg kg-1 s-1"], # not consistent with CF convention
    }
        #"": ["long_name=", "units="], # not consistent with CF convention

    #--- turn an argument into a list: "N/A" -> nothing to do, a single string -> one item
    def _as_list(arg):
        if isinstance(arg, str):
            return [] if arg == "N/A" else [arg]
        return arg

    #--- set one attribute from a "name=value" string; only the first "=" separates name and value
    def _set_attr(att):
        att_name, separator, att_value = att.partition('=')
        if not separator:
            error_msg = f"ERROR: [{func_name}] attribute [{att}] is not in the form name=value"
            raise ValueError(error_msg)
        var.attrs[att_name] = att_value

    #------------------------
    #  set predefined attributes
    #------------------------
    for var1 in _as_list(varname):
        #--- get variable attributes
        if var1 not in varname_dict:
            varname_keys = list(varname_dict.keys())
            error_msg = f"ERROR: [{func_name}] does not support varname [{var1}]. Available varname: {varname_keys}"
            raise KeyError(error_msg)

        #--- set attributes
        for att1 in varname_dict[var1]:
            _set_attr(att1)

    #------------------------
    #  delete attributes
    #    look in var.attrs, not at the attributes of the Python object
    #------------------------
    for att1 in _as_list(attrs_del):
        var.attrs.pop(att1, None)

    #------------------------
    #  add attributes
    #------------------------
    for att1 in _as_list(attrs_add):
        _set_attr(att1)

    #---------
    # return
    #---------
    return var
    
#-----------
# test of modify_attrs
#   set do_test = True to print a DataArray before and after its attributes are modified
#-----------

do_test = False
#do_test = True

if do_test:
    var = xr.DataArray([1.])
    var.attrs.update({'long_name': "long1", 'units': "units", 'cell': "cell1", 'source': "ss1"})
    printv(var, 'before')

    #modify_attrs(var, varname = "U", attrs_del=["cell","source","ddd","long_name"], attrs_add=["time=1.2332","average=kkk"])
    var1 = modify_attrs(var, varname = ["omega"], attrs_add=["att1=0101","att2=0202"], attrs_del=["units"])
    printv(var, 'after')


## lib options

### lib_ax_def_xy

In [None]:
def lib_ax_def_xy():
    """
    Return a code snippet (a string) that defines ax_def_xy(ax, var), a template
    for the usual axis settings of XY plots: grids, inverted y axis, legend,
    titles, axis labels, x range and tick mark sizes.

    Example:
      import yhc_module as yhc
      print(yhc.lib_ax_def_xy())
    """

    snippet = """
    
    #------------------------ 
    # usual ax in XY plots
    #------------------------  

#==============================
def ax_def_xy (ax, var):

    #--- set grids
    ax.grid(True)
    ax.minorticks_on()
    ax.grid(False, axis = "x")  # turn off x grids
    
    #--- inverse axes
    ax.invert_yaxis()
    
    #--- legend
    legend_default = ["A","B"]
    ax.legend(legend_default)
       ### change legend size, ax.legend(legend_default, prop={'size': legend_size})

    #--- set title
    fontsize_title = 18
    ax.set_title(var.attrs['long_name'], loc='left')
    ax.set_title(var.attrs['units'], loc='right')
      ### set font size, ax.set_title(var.attrs['units'], loc='right', fontsize = fontsize_title)

    #--- set x or y labels
    fontsize_label = 18
    ax.set_xlabel(var.attrs['long_name']+" ("+var.attrs['units']+")")
    ax.set_ylabel("Pressure (hPa)")  
      ### set font size, ax.set_ylabel("A", fontsize = fontsize_label)
    
    #--- set x range
    ax.set_xlim([0,len(var)])
    
    #--- set x tickmark
    #xvalues = np.arange()
    #xlabels = np.arange()
    #ax.set_xticks(xvalues)        # tickmark values
    #ax.set_xticklabels(xlabels)  # tickmark labels

    #--- set tick mark sizes
    fontsize_tm = 12  # set tick mark size
    ax.tick_params(axis='x', labelsize = fontsize_tm)
    ax.tick_params(axis='y', labelsize = fontsize_tm)
#============================== 
    """

    return snippet

### lib_dict

In [None]:
def lib_dict():
    """
    Return a code snippet (a string) with notes on Python dictionaries:
    setting one, looking up a key, and testing whether a key exists.

    Example:
      import yhc_module as yhc
      print(yhc.lib_dict())
    """

    snippet = """

    #-------------------------
    # Python disctionary
    #-------------------------  $$
    
    #--- set a dictionary $$
    dict = {
        "u": ["long_name=eastward_wind", "units=m s-1"],
        "v": ["long_name=northward_wind", "units=m s-1"],
    }
    
    #--- look up the dictionary  $$
    print(dict['u'])
    
    #--- if statement for a dictionary  $$
    if var1 in dict:
        attrs_var = dict['u']
    else:
        dict_keys = list(dict.keys())   # get all keys in the dict
    """

    return snippet

### lib_fdef

In [None]:
def lib_fdef():
    """
    Return a code snippet (a string) that is a template for a new function:
    a def line, a docstring skeleton whose "Date created" is today's date,
    and a do_test section.

    Example:
      import yhc_module as yhc
      print(yhc.lib_fdef())
    """

    #--- today's date, written into the docstring skeleton
    created_on = date.today()

    snippet = f"""
    
    #-------------------------
    # define a new function
    #-------------------------  $$
    
def (): 
    
    \"""
    ----------------------
    Description:


    Input arguments:


    Return:


    Example:
      import yhc_module as yhc
       = yhc.()

    Date created: {created_on}
    ----------------------
    \"""


    func_name = ""

    #--- $$

    # error_msg = f"ERROR: function [func_name] does not support []. Available:"
    # raise KeyError(error_msg)
    # raise ValueError(error_msg)

    return
    
    
#-----------
# do_test
#-----------

do_test=True
#do_test=False

if (do_test):
        
    """

    return snippet

### lib_module

In [None]:
def lib_module():
    """
    Return a code snippet (a string) with useful commands for this module:
    how to import it and how to get help on it.

    Example:
      import yhc_module as yhc
      print(yhc.lib_module())
    """

    snippet = """
    #-------------------------
    # useful module command
    #-------------------------
    
    #--- import $$
    import yhc_module as yhc

    #--- help  $$
    help(yhc)          # see all functions and file path
    help(yhc.func1)    # see func1 
    dir(yhc)           # retun list of attributes and methods of an objecy
    """

    return snippet

### lib_np

In [None]:
def lib_np():
    """
    Return a code snippet (a string) with notes on Numpy functions that create arrays.

    Example:
      import yhc_module as yhc
      print(yhc.lib_np())
    """

    text_return = """
    #-------------
    # Numpy functions
    #-------------
    
    #--- create arrays  $$
    np.arange(0,3)             # create an array = [0,1,2] (the stop value is not included)
    np.linspace(-10., 10., 5)  # create an array of 5 elements ranging from -10 to 10

    """

    return text_return

### lib_plot_panel_string

In [None]:
def lib_plot_panel_string():
    """
    Return a code snippet (a string) that defines plot_panel_string(axes, ...),
    a function that writes a string such as "(a)", "(b)" on each ax,
    followed by a do_test section that tries it on two subplots.

    Example:
      import yhc_module as yhc
      print(yhc.lib_plot_panel_string())
    """

    snippet = """
def plot_panel_string(axes, 
                      varname = "none", 
                      xx = 1.5, yy = 825, fontsize = 30,
                      panel_strings = ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"], 
                     ):
    #----------------------------
    # plot a string on each ax
    #----------------------------
    
    func_name = "plot_panel_string"
    
    #--- modify xx and yy values given different variables
    if (varname == "temp"):
        xx = 282.5
    elif (varname == "sphum"):
        xx = 0.5
    elif (varname == "uv_div"):
        xx = -3
    elif (varname == "tdt_dyn" or varname == "tdt_hadv" or varname == "tdt_vadv"):
        xx = -22.
    elif (varname == "qdt_dyn" or varname == "qdt_hadv" or varname == "qdt_vadv"):
        xx = -28
    else:
        xx = 0
    
    #--- get the number of axes
    nax = len(axes)
        
    #--- loop for each ax and plot string
    for i in range(0,nax):
        ax0 = axes[i]
        string1 = panel_strings[i]
        
        #--- plot the string
        ax0.text(xx, yy, string1, fontsize = fontsize)

#-----------
# do_test
#-----------

do_test=True
#do_test=False

if (do_test):
    fig, (ax1,ax2) = plt.subplots(1,2, figsize=(6, 6))   # 1 row, 1 column
    
    yy_min = -500
    yy_max = 50
    xx_min = 0
    xx_max = 10
    ax1.plot( [xx_min,xx_max],[yy_min,yy_max],'r-')
    ax2.plot( [xx_min,xx_max],[yy_min,yy_max],'b-')

    axes = [ax1, ax2]
    
    varname = "swcre_toa"
    plot_panel_string(axes, varname = varname, panel_strings = ["gg","ff"]) 

"""

    return snippet

### lib_plt_basic

In [None]:
def lib_plt_basic():
    """
    ----------------------
    Description:
      Return general matplotlib.pyplot notes (as text): open fig and ax, titles and
      labels, add text, and save a figure.
      In the text, "#-" starts a comment and "$$" ends it; lib() uses these two
      markers to color the comments.

    Input arguments:
      None

    Return:
      text_return: a string holding the notes

    Example:
      import yhc_module as yhc
      yhc.lib("plt_basic")
    ----------------------
    """

    text_return = """
    #-------------
    #
    # matplotlib.pyplot general
    #
    #-------------  $$    

    #-------------------------- 
    #  open fig and ax
    #-------------------------- $$

    #      matplotlib.pyplot.subplots, https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html  $$
    fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(18, 6))   # 1 row, 3 columns
    fig, ((ax_1, ax_2, ax_3), (ax_4, ax_5, ax_6)) = plt.subplots(nrows=2, ncols=3, figsize=(18, 12))  # 2 rows, 3 columns

    fig, axs = plt.subplots(2, 3, figsize=(18, 6))  # 2 rows, 3 columns. axs[0, 0] refers to the top-left subplot
                                                    #                    axs[1, 2] refers to the bottom-right subplot
    
    fig.suptitle("title", fontsize = 20, y=0.95)   # add title

    #--- set spacing between subplots
    #      https://www.geeksforgeeks.org/how-to-set-the-spacing-between-subplots-in-matplotlib-in-python/ $$
    fig.tight_layout()
    #fig.tight_layout(pad=5.0)   # larger pad, larger space between plots

    #-------------
    #
    # Set title and labels
    #
    #-------------  $$

    #--- main titles  $$
    fig.suptitle("DYCOMS SCM initial profiles", fontsize=20, y=0.95)
    
    ax.set_title("main title")
    ax.set_title(var.attrs['long_name'], loc='left', fontsize = 18, y=0.9)
    ax.set_title(var.attrs['units'], loc='right')

    #--- x and y labels  $$
    ax.set_xlabel("X-Axis title")
    ax.set_ylabel("Y-Axis title")

    #---------------- 
    #  add text 
    #    Use LaTeX convention: https://matplotlib.org/stable/tutorials/text/mathtext.html
    #---------------- $$
    ax1.text(xx, yy, 'text', fontsize = 15)
    
    var_dum.attrs['units'] = r"$10^6 s^{-1}$"  # 10^6/s

    #-------------------------- 
    #  save figure
    #-------------------------- $$

    plt.savefig('fig.png', dpi=300)
    
    """
    return text_return

### lib_pltset

In [None]:
def lib_pltset():
    """
    ----------------------
    Description:
      Return notes (as text) on setting matplotlib.pyplot default values
      with plt.rc, for the font group and the lines group.

    Input arguments:
      None

    Return:
      a string holding the notes

    Example:
      import yhc_module as yhc
      yhc.lib("pltset")
    ----------------------
    """
    return """
#---------------------------
# set plt default settings
#    References: https://www.geeksforgeeks.org/matplotlib-pyplot-rc-in-python/
#---------------------------  $$

#### set default plt parameters & use rc function to update 

#--- set default values for font group  $$
font_default = {
    'size': 15,      # font size 
       }
plt.rc('font', **font_default)

       
#--- set default values for lines group  $$
lines_default = {
    'linewidth':3,
}
plt.rc('lines', **lines_default)

    """

### lib_pltcn_2d

In [None]:
def lib_pltcn_2d():
    """
    ----------------------
    Description:
      Return notes (as text) on matplotlib.pyplot 2-D filled contour plots:
      open subplots, contour levels, colormaps, contourf, and the label bar.
      A template function, ax_def_cn_2d, is appended at the end of the notes.

    Input arguments:
      None

    Return:
      text_return: a string holding the notes

    Example:
      import yhc_module as yhc
      yhc.lib("pltcn_2d")
    ----------------------
    """
    text_return = """
    #-------------
    #
    # matplotlib.pyplot for 2-D filled contour plots
    #
    #-------------  $$

    #--- open subplots $$
    fig, (ax1, ax2) = plt.subplots(2,1, figsize=(18, 12))   # 2 rows, 1 column
    fig.tight_layout(pad=5.0)

    #--- determine contour levels $$
    cn_levels = np.array([...])
    cn_levels = np.linspace(-100., 100., 11)
    
    #--- colormaps 
    #      https://matplotlib.org/stable/tutorials/colors/colormaps.html $$

    cmap="coolwarm"   # temperature tendencies
    cmap="BrBG"       # moisture tendencies
    cmap="PiYG"       # omega
    cmap="Spectral"   # divergence 

    #--- plot filled contours $$
    cn_attrs = set_dict_cn_attrs(varname)
    plot_cn = ax.contourf(X, Y, var, 20, cmap="OrRd", extend="both")  # automatically 20 levels  reversed cmap "OrRd_r"
    plot_cn = ax.contourf(X, Y, var, levels=cn_attrs['cn_levels'], cmap=cn_attrs['cmap'], extend="both")

    #--- set label bar $$
    lb_label = var.attrs['long_name']+" ("+var.attrs['units']+")"
    fig.colorbar(plot_cn, orientation='vertical', label=lb_label)

#*******************************
#*******************************
#*******************************

def ax_def_cn_2d (ax, var, yy, vert_coord):
            
    #--- x & y label $$
    fontsize_label = 15

    ax.set_xlabel("Time step", fontsize=fontsize_label)  

    ylabel = yy.attrs['long_name']+" ("+yy.attrs['units']+")"
    ax.set_ylabel(ylabel, fontsize=fontsize_label)

    #--- strings $$
    fontsize_string = 12
    ax.set_title(var.attrs['long_name'], loc='left', fontsize=fontsize_string)
    ax.set_title(var.attrs['units'], loc='right', fontsize=fontsize_string)
        
    #--- reverse coordinate if needed $$
    if (vert_coord == "p"):
        ax.invert_yaxis()
    
    """
    return text_return

### lib_pltcn_map

In [None]:
def lib_pltcn_map():
    """
    ----------------------
    Description:
      Return a template (as text) for filled contour maps using cartopy and matplotlib.
      The template defines three functions and a do_test section:
        ax_def_cn_map: add coastline, grid lines, and titles to an ax
        plot_box     : draw a rectangle around a region
        plot_cn_map  : plot the filled contours, the colorbar, and the region box
      The template calls set_dict_cn_attrs(), which is provided by the
      "set_cn_attrs_dict" template.

    Input arguments:
      None

    Return:
      text_return: a string holding the template code

    Example:
      import yhc_module as yhc
      yhc.lib("pltcn_map")
    ----------------------
    """
    text_return = """

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.patches as mpatches
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
from matplotlib.ticker import FixedLocator

########################
########################
########################

def ax_def_cn_map (ax,
                   map_projection,
                   dict_cn_attrs=None):
    \"""    
    ----------------------
    Set attributes in cn_map plot using cartopy and matplotlib 

    Input arguments:
        ax: an Axes class variable
        map_projection: ccrs map projection class variable
        dict_cn_attrs: a dictionary returned by set_dict_cn_attrs. 
                       If it is None, the left and right titles are not set.

    Return:
        update ax

    Example:
        map_projection = ccrs.PlateCarree(central_longitude=0)
        fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': map_projection})
        ax_def_cn_map(ax, map_projection)
      
    References:
      - Cartopy Tick Labels: https://scitools.org.uk/cartopy/docs/latest/gallery/gridlines_and_labels/tick_labels.html

    Date created: 2023-10-15
    ----------------------
    \"""
    
    #--- add coastline 
    ax.add_feature(cfeature.COASTLINE)

    #--- add grid lines  
    gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray', alpha = 1)  # add lat/lon grid lines
    gl.top_labels = False     # turn off labels on the top and right sides
    gl.right_labels = False
    
    #--- draw specific grid lines. I try to find a method to draw every grid line but not all labels. I couldn't find an easy way.
    #                              Perhaps the easiest one is to specify grid lines and let the program to determine the labels.  
    lon_grid_lines = [-124, -123, -122, -121, -120, -119, -118, -117, -116]
    gl.xlocator = FixedLocator(lon_grid_lines)
    gl.xformatter = LongitudeFormatter(zero_direction_label=True)

    #--- add title, only when the contour attributes are given
    if (dict_cn_attrs is not None):
        fontsize=10
        ax.set_title(dict_cn_attrs['name'], loc='left', fontsize=fontsize)
        ax.set_title(dict_cn_attrs['units'], loc='right', fontsize=fontsize)

########################
########################
########################

def plot_box(ax, region="DYCOMS", map_projection=None):
    
    func_name = "plot_box"
    
    if (region == "DYCOMS"): 
        region_name = "DYCOMS (30-32.2N, 120-123.8W)"
        lowerlat =  30.    # 30N
        upperlat =  32.2   # 32.2N
        lowerlon =  236.2  # 123.8W
        upperlon =  240.   # 120W

    else:
        error_msg = f"ERROR: function [{func_name}] does not support region=[{region}]."
        raise ValueError(error_msg)
    
    #--- if map_projection is not given, use the same projection as in do_test
    if (map_projection is None):
        map_projection = ccrs.PlateCarree(central_longitude=0)
    
    lon_range = upperlon - lowerlon
    lat_range = upperlat - lowerlat

    rect = mpatches.Rectangle((lowerlon, lowerlat), lon_range, lat_range, facecolor='none', edgecolor='cyan', linewidth=2, transform=map_projection)
    ax.add_patch(rect)

########################
########################
########################

def plot_cn_map(ax, map_projection,
                var, varname,
                do_set_cn_attrs = True, 
               ):

    #-------------
    # plot cn_map
    #-------------
    
    #--- set contour attributes. The name, units, and label are always needed by the titles and the colorbar
    dict_cn_attrs = set_dict_cn_attrs(varname)
    #print(dict_cn_attrs)
    
    if (do_set_cn_attrs):
        #--- plot cn_map using the contour levels and colormap in dict_cn_attrs
        cn_map_region = ax.contourf(var.lon, var.lat, var, extend='both', transform=map_projection, levels=dict_cn_attrs['cn_levels'], cmap=dict_cn_attrs['cmap']) 

    else:
        #--- plot cn_map using matplotlib default levels and colormap
        cn_map_region = ax.contourf(var.lon, var.lat, var, transform=map_projection) 

    #--- set cn_map attributes
    ax_def_cn_map(ax, map_projection, dict_cn_attrs=dict_cn_attrs)
    
    #--- plot the colorbar
    cbar = plt.colorbar(cn_map_region, ax=ax, orientation='vertical', shrink=0.45) #, label=dict_cn_attrs['label'])
    cbar.set_label(label=dict_cn_attrs['label'], fontsize=8)
    
    #-------------
    # plot a region
    #-------------    
    plot_box(ax, map_projection=map_projection)

#-----------
# do_test
#-----------

do_test=True
#do_test=False
    
if (do_test):
    map_projection = ccrs.PlateCarree(central_longitude=0)
    fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': map_projection})
    varname = 'swabs_toa_diff'

    varname = 'swabs_toa'
    do_set_cn_attrs=True
    plot_cn_map(ax, map_projection, var_era5, varname, do_set_cn_attrs=do_set_cn_attrs)
    """
    
    return text_return

### lib_pltxy

In [None]:
def lib_pltxy():
    """
    ----------------------
    Description:
      Return notes (as text) on matplotlib.pyplot XY plots: line styles, plot, legend,
      axis ranges, reversed axes, fill between lines, grid lines, and tick marks.

    Input arguments:
      None

    Return:
      text_return: a string holding the notes

    Example:
      import yhc_module as yhc
      yhc.lib("pltxy")
    ----------------------
    """
    text_return = """
    #-------------
    #
    # matplotlib.pyplot for XY plots
    #
    #-------------  $$

    #--- XY line styles  
    #    * Line colors, https://matplotlib.org/stable/gallery/color/named_colors.html
    #        red (r), blue (b), green (g), cyan (c), magenta (m), yellow (y), black (k), white (w)
    #
    #.   * Line dash pattern, https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html
    #.       {'-', '--', '-.', ':'} {solid, dashed, dashdot, dotted}
    #
    #.   * Markers, https://matplotlib.org/stable/api/markers_api.html
    #        {'.', 'o', '^', 's'} = {point, circle, triangle_up, square}  $$

    #--- plot  $$
    ax1.plot(xx1, yy1, 'r',
             xx2, yy2, 'b--',
             )
    ax1.plot(xx1, yy1, 'k-^', label="test")

    ax1.plot(xx3, yy3, label='1st data', c='orange', ls='-.', lw=5.)  # set label, color, linestyle, and linewidth

    #--- add legend 
    #.     https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html
    #-------------- $$
    legend_0 = ["",""]
    legend_size = 12
    ax1.legend(legend_0)
    ax1.legend(legend_0, prop={'size': legend_size}, loc='upper left')  # change legend size and location

    #--- set X and Y axis range  $$
    ax1.set_xlim([xmin, xmax])
    ax1.set_ylim([ymin, ymax])

    #--- reverse axes  $$
    ax.invert_xaxis()
    ax.invert_yaxis()

    #-------------
    #
    # Fill between lines
    #.  reference: https://www.geeksforgeeks.org/matplotlib-pyplot-fill_between-in-python/
    #              https://pythonguides.com/matplotlib-fill_between/
    #-------------  $$

    #--- fill between y_top & y_bottom at xx range $$
    
    xx = np.array([1,5])
    y_top = 10.
    y_bottom = 0.
    
    ax1.fill_between( 
            xx, y_top, y_bottom, color='gray', alpha=0.2,
                )

    #-------------
    #
    # Set grid lines and tick markers
    #
    #-------------  $$

    #--- set grid lines  $$
    #      https://www.pythoncharts.com/matplotlib/customizing-grid-matplotlib/
    #      https://www.tutorialspoint.com/matplotlib/matplotlib_setting_ticks_and_tick_labels.htm
    ax.grid(True)

    ax.grid(which='major', color='gray', linestyle='-.', linewidth=0.7)        # set both X and Y grids
    ax.grid(which='minor', color='gray', linestyle='-.', linewidth=0.7)

    ax.xaxis.grid(which="minor", color='gray', linestyle='-.', linewidth=0.7)  # set X grids
    ax.xaxis.grid(which="major", color='gray', linestyle='-.', linewidth=0.7)

    #--- minor tickers $$
    ax.minorticks_on()                              # turn on minor tick marks
    ax.xaxis.set_minor_locator(MultipleLocator(1))  # draw minor tickers every 1 units

    #--- setting ticks and tick labels
    #.     https://www.tutorialspoint.com/matplotlib/matplotlib_setting_ticks_and_tick_labels.htm  $$
    ax.set_xticks([0,2,4,6])
    ax.set_xticklabels(['zero','two','four','six'])
    
    #xvalues = np.arange()
    #xlabels = np.arange()
    #ax.set_xticks(xvalues)        # tickmark values
    #ax.set_xticklabels(xlabels)   # tickmark labels

    """
    return text_return

### lib_py_basic

In [None]:
def lib_py_basic():
    """
    ----------------------
    Description:
      Return notes (as text) on Python built-in features: stop the code
      with an error message, and copy a string.

    Input arguments:
      None

    Return:
      a string holding the notes

    Example:
      import yhc_module as yhc
      yhc.lib("py_basic")
    ----------------------
    """
    return """
    #-----------------------------
    # python built-in functions
    #-----------------------------

    #--- stop the code and return error message $$
    
    error_meg = f"ERROR: value [{value}] is not supported"
    raise ValueError(error_meg)
        
    #--- copy a string $$
    text2 = text1[:]   # text2 = text1 also works. To avoid confusion, adding [:] might be better
    """

### lib_xr

In [None]:
def lib_xr():
    """
    ----------------------
    Description:
      Return notes (as text) on xarray: create a DataArray, open netCDF files,
      get a variable from a Dataset, convert DataArrays into a Dataset,
      merge Datasets, and save a Dataset to a netCDF file.

    Input arguments:
      None

    Return:
      text_return: a string holding the notes

    Example:
      import yhc_module as yhc
      yhc.lib("xr")
    ----------------------
    """
    text_return = """
    #-------------
    # Xarray DataArray
    #-------------
    
      #--- create a DataArray & set attributes $$
      da = xr.DataArray(data, dims=['time', 'lat', 'lon'], coords=[time, lat, lon])
      
      da = xr.DataArray(np.zeros(dim0*dim1).reshape(dim0,dim1), dims=['time','plev'])

      da.attrs['units']="K"
    
      #--- reorder dimensions $$
      # da[time, lat, lon]
      da_jit = da.transpose("lat","lon","time")

    #-*************************************
    # Open netCDF files using Xarray
    #-*************************************  $$
      
      #--- Open a netCDF file  $$
      filename = "./test111.nc"
      ds1 = xr.open_dataset(filename)
      ds1
      
      #--- Open multiple netCDF files  
      #.     open_mfdataset sucks. I spent an hour searching on internet, but I still didn't know 
      #      how to use it. There was no (or useless) examples.  
      #. $$

      datapath = "../data/"
      filename = ["file1.nc", "file2.nc"]    # file has dimension (time, lat, lon, ...)
      filename = [ datapath + ff for ff in filename ]
      
      da = xr.open_mfdataset(filename, concat_dim=["lon"], combine='nested')  # this will return da (lon=2, ...)

      #--- get variable in the DataSet $$
      ds1.temp
      
      varname = "temp"
      var = ds1.get(varname)

    #-************************************
    # Xarray Dataset
    #.   - Save multiple DataArray to a DataSet
    #.   - Merge multiple DataSet
    #.   - Save a DataSet to a new netCDF file
    #-************************************ $$
    
      #--- Convert multiple DataArray into a DataSet  $$
      # var1[x,y], var2[x]
    
      ds = var1.to_dataset(name = "varA")
      ds.attrs["contact"] = "yihsuan"     # set global attribute
    
      ds['varB'] = var2
    
      #--- Merge multiple datasets
      #.     ds1 and ds2 are dataset  $$
      ds_merge = xr.merge([ds1, ds2]) 
    
      #--- Save a DataSet to a new netCDF file  $$
      new_filename = "./test111.nc"
  
      ds.to_netcdf(path=new_filename)
      ds.close()

      ds1 = xr.open_dataset(new_filename)
      ds1
    """
    return text_return

### lib_read_data

In [None]:
def lib_read_data():
    """
    ----------------------
    Description:
      Return a template (as text) for reading data files: set datapath by the
      machine name, define read_data_SS(), which opens the files of a given
      choice with xr.open_mfdataset, and a do_test section.

    Input arguments:
      None

    Return:
      text_return: a string holding the template code

    Example:
      import yhc_module as yhc
      yhc.lib("read_data")
    ----------------------
    """
    text_return = """

######################
######################
######################

## Set machine name

machine_name = "Mac_studio"

#--- datapath must end with "/" because file names are appended to it directly
if (machine_name == "Mac_studio"):
    datapath = "/Users/yi-hsuanchen/Downloads/yihsuan/research/projects/Sc_diag/data/"

elif (machine_name == "WD"):
    datapath = "/Volumes/My_WD_Passport/manuscript/Sc_diag/data/"

else:
    error_msg = f"ERROR: machine_name [{machine_name}] is not supported"
    raise ValueError(error_msg)

######################
######################
######################

def read_data_SS (choice, datapath=datapath+""): 

    func_name = "read_data_SS"
    
    if (choice == "A"):        
        fnames = ["A"]
        
    elif (choice == "B"):        
        fnames = ["B"]

    else:
        error_msg = f"ERROR: function [{func_name}] does not support [{choice}]."
        raise ValueError(error_msg)        
    
    #--- read files
    fnames = [datapath+fname1 for fname1 in fnames]
                
    #da_return = xr.open_mfdataset(fnames, decode_cf=False)  # ERA5 variables are in short format
    da_return = xr.open_mfdataset(fnames) 

    return da_return

#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    choice = "A"

    da1 = read_data_SS(choice)    
    
    """
    
    return text_return


### lib_read_data_merra2

In [None]:
def lib_read_data_merra2():
    """
    ----------------------
    Description:
      Return a template (as text) for reading MERRA-2 data. The template defines
      read_data(), which opens the MERRA-2 files of a given choice and wraps the
      longitudes with yhc.wrap360, read_merra2_var(), which gets one variable
      from the dataset, and a do_test section.

    Input arguments:
      None

    Return:
      text_return: a string holding the template code

    Example:
      import yhc_module as yhc
      yhc.lib("read_data_merra2")
    ----------------------
    """

    text_return = """
    
######################
######################
######################

def read_data(choice, datapath00="/Users/yi-hsuanchen/Downloads/yihsuan/research/projects/Sc_diag/data/"):

    \"""
    read a dataset
    \"""
    func_name = "read_data"
    
    #--- datapath
    if (choice == "MERRA2_slv"):
        datapath = datapath00+"data.MERRA-2/tavg1_2d_slv_Nx.200107/"
        fnames = [datapath+"MERRA2_300.tavg1_2d_slv_Nx.20010710.SUB.nc",
                  datapath+"MERRA2_300.tavg1_2d_slv_Nx.20010711.SUB.nc",
                 ]
        da = xr.open_mfdataset(fnames)
        da = yhc.wrap360(da)            # MERRA-2 has [-180,180] longitudes, change it to [0,360].

    else:
        error_msg = f"ERROR: function [{func_name}] does not support choice=[{choice}]."
        raise KeyError(error_msg)
    
    #--- return
    return da


######################
######################
######################
def read_merra2_var(da_merra2, varname):
    \"""
    read a variable in MERRA-2 dataset
    \"""
    var_merra2 = da_merra2.get(varname)
    return var_merra2

#-----------
# do_test
#-----------

#do_test=True
do_test=False
    
if (do_test):
    choice = "MERRA2_slv"
    da_merra2 = read_data(choice)
    varname = "TS"
    var_merra2 = read_merra2_var(da_merra2, varname)
    
#var_merra2
"""
    return text_return


### lib_read_data_era5

In [None]:
def lib_read_data_era5():
    """
    ----------------------
    Description:
      Return a template (as text) for reading ERA5 data. The template defines
      read_dataset(), which opens the ERA5 files of a given choice with
      decode_cf=False and wraps the longitudes with yhc.wrap360, and
      read_era5_var(), which gets one variable, applies scale_factor/add_offset,
      renames and sorts the lat/lon coordinates, and converts the time
      coordinate into date strings. A do_test section follows.

    Input arguments:
      None

    Return:
      text_return: a string holding the template code

    Example:
      import yhc_module as yhc
      yhc.lib("read_data_era5")
    ----------------------
    """

    text_return = """

import cftime

######################
######################
######################

def read_dataset(choice, datapath00="/Users/yi-hsuanchen/Downloads/yihsuan/research/projects/Sc_diag/data/"):

    \"""
    read a dataset
    \"""
    func_name = "read_dataset"
    
    #--- datapath
    if (choice == "ERA5_single_level"):
        datapath = datapath00+"data.ERA5/"
        fnames = [datapath+"ERA5-2001July-single_level.nc",
                 ]
        da = xr.open_mfdataset(fnames, decode_cf=False)  # In ERA5 data is stored in short format, set decode_cf=False otherwise the xarray will read wrong values
        da = yhc.wrap360(da, lon='longitude')            # change longitudes to [0,360]. The ERA5 longitude coordinate is named 'longitude'

    else:
        error_msg = f"ERROR: function [{func_name}] does not support choice=[{choice}]."
        raise KeyError(error_msg)
    
    #--- return
    return da

######################
######################
######################

def read_era5_var(da_era5, varname,
                  var_format="short"):

    \"""
    read a variable in ERA5 dataset
    \"""

    #--- read a variable
    var_era5 = da_era5.get(varname)
    
    #--- convert short type to float type if needed
    if (var_format == "short"):
    
        #--- if scale_factor and add_offset are not read in, set them manually
        if (varname == "aaa"):
            var_era5.attrs['scale_factor']=1
            var_era5.attrs['add_offset']=1

        #--- convert short type to float type
        var_era5 = (var_era5*var_era5.scale_factor + var_era5.add_offset)

    #--- reorganize ERA5 coordinates
    var_era5 = var_era5.rename({'longitude':'lon', 'latitude':'lat'})  # rename lat/lon coordinate names
    var_era5 = var_era5.sortby('lat', ascending=True)                  # change latitude to ascending order, e.g. [28N-34N]
        
    #--- modify time coordinate from "hours since" to readable format "'%Y-%m-%dT%H:%M:%S"
    #      the reference time is taken from the units attribute of the time coordinate
    time_in_hours = var_era5["time"].values
    date_coordinates = cftime.num2date(time_in_hours, var_era5["time"].attrs["units"], calendar="standard")
    date_strings = [date1.strftime('%Y-%m-%dT%H:%M:%S') for date1 in date_coordinates]
    var_era5["time"] = xr.DataArray(date_strings, dims="time")
    
    #---return
    return var_era5

#-----------
# do_test
#-----------

do_test=True
#do_test=False

if (do_test):
    
    choice = "ERA5_single_level"
    da_era5 = read_dataset(choice)
    
    varname = "u10"
    var_era5 = read_era5_var(da_era5, varname)

#var_era5
"""
    return text_return

### lib_select_var

In [None]:
def lib_select_var():
    """
    ----------------------
    Description:
      Return a template (as text) that defines select_var(), which selects a
      subsample of a variable in given (time, lat, lon) ranges, followed by
      a do_test section.

    Input arguments:
      None

    Return:
      text_return: a string holding the template code

    Example:
      import yhc_module as yhc
      yhc.lib("select_var")
    ----------------------
    """

    text_return = """
    
def select_var(var,
               region="NE_CA",
               time_slice="mean", 
               data_source=None, 
               RF=None, 
              ):
    \"""
    Select a subsample of variable in given (time, lat, lon) ranges
    var: a Xarray DataArray
    region: a region given by lat/lon ranges
    time_slice: "mean" for time mean, otherwise a time or a slice of time passed to var.sel
    data_source, RF: if both are recognized, time_slice is overwritten by a preset value
    \"""
    
    #--- set time slice
    if (data_source == "MERRA2"):
        if (RF == "RF01"): time_slice="2001-07-10T10:30:00.000000000"
    
    elif (data_source == "ERA5"):
        if (RF == "RF01"): time_slice=slice("2001-07-10T09:00:00.000000","2001-07-10T12:00:00.000000")

    #--- set region
    if (region == "NE_CA"):
        lowerlon=235; upperlon=245; lowerlat=28; upperlat=35
    else:
        lowerlon=-1000; upperlon=1000; lowerlat=-1000; upperlat=1000
    
    #--- select and return 
    lon_slice = slice(lowerlon, upperlon)
    lat_slice = slice(lowerlat, upperlat)
    
    #--- select time
    if (time_slice == "mean"):
        var_region = var.sel(lat=lat_slice, lon=lon_slice).mean("time")
    else:
        var_region = var.sel(lat=lat_slice, lon=lon_slice, time=time_slice)

    #--- return variable
    return var_region

#-----------
# do_test
#-----------

#do_test=True
do_test=False
    
if (do_test):
    
    da = read_data("MERRA2_slv")
    varname = "TS"
    
    var = da.get(varname)
    var_2d = select_var(var) #, region=region)

#var_2d
"""
    return text_return

### lib_str

In [None]:
def lib_str():
    """
    ----------------------
    Description:
      Return notes (as text) on string operations: append a prefix to strings
      in a list, formatted (f) and raw (r) strings, and title properties.

    Input arguments:
      None

    Return:
      text_return: a string holding the notes

    Example:
      import yhc_module as yhc
      yhc.lib("str")
    ----------------------
    """
    # NOTE: the backslash in the raw-string example below is written as "\\" so that
    #       this (non-raw) literal emits a single backslash without an invalid escape.
    text_return = """
    #---------------------
    # string operation
    #---------------------

    #--- Append suffix / prefix to strings in list $$
    filenames_scm = ["A", "B", "C"]
    filenames_scm = [datapath+file1 for file1 in filenames_scm]
    
    #-------------------
    # formatted (f) and raw (r) strings
    #------------------- $$
    
    text1 = f"ERROR: time coord [{time_coord}] is not supported"  # 'f' means formatted string, {} use the variable value
    
    text2 = r"$W \\ m^{-2}$"  # 'r' means raw string, backslashes are kept as is. Useful for LaTeX format

    #-------------------
    # set title properties
    #------------------- $$
    
    ax.set_title('Custom Title', fontweight='bold', fontsize=16, loc='left')
    
    title_properties = {'fontweight': 'bold', 'fontsize': 16, 'loc': 'left'}
    ax.set_title('Custom Title', **title_properties)

    """
    return text_return


### lib_set_cn_attrs_dict

In [None]:
def lib_set_cn_attrs_dict():
    """
    ----------------------
    Description:
      Return a template (as text) that defines set_dict_cn_attrs(), which returns
      a dictionary of contour attributes (levels, colormap, label, name, units)
      for a given variable name, followed by a do_test section.

    Input arguments:
      None

    Return:
      text_return: a string holding the template code

    Example:
      import yhc_module as yhc
      yhc.lib("set_cn_attrs_dict")
    ----------------------
    """
    text_return = """
from matplotlib.colors import BoundaryNorm

def set_dict_cn_attrs (varname):
    \"""    
    ----------------------
    Set contour attributes

    Input arguments:
        varname: variable name

    Return:
        1. return a dictionary variable, dict_cn_attrs, with the keys
           'varname', 'cn_levels', 'cmap', 'label', 'name', 'units'.
           'cn_levels' is either an array of contour levels or an integer (number of levels).
 
    Example:
        gg = set_dict_cn_attrs('tdt_dyn')
        print(gg['units'])

    References:
        colormaps:  https://matplotlib.org/stable/tutorials/colors/colormaps.html

    Date created: 2023-10-15
    ----------------------
    \"""
    
    #--- set cn levels
    tdt_dyn_cnlevels = np.arange(-25., 27.5, 2.5)
    qdt_dyn_cnlevels = np.arange(-10., 11, 1.)
    
    #---------------------
    # set cn attributes
    #---------------------
    if (varname == "skdfksjdkf"):
        cn_levels = tdt_dyn_cnlevels
        cmap="c1"
        label = "l1"
        name="n1"
        units="u1"
    
    elif (varname == "tdt_dyn"):
        cn_levels = tdt_dyn_cnlevels
        cmap="coolwarm"
        label = "3D dynamical T tendencies (K/day)"
        name="tdt_dyn"
        units="K/day"
        
    elif (varname == "qdt_dyn"):
        cn_levels = qdt_dyn_cnlevels
        cmap="BrBG"
        label = "3D dynamical Q tendencies (g/kg/day)"
        name="qdt_dyn"
        units="g/kg/day"
        
    elif (varname == "omega"):
        cn_levels = np.arange(-120., 130., 10.)
        cmap="PiYG"
        label = "Omega (hPa/day)"
        name="omega"
        units="hPa/day"

    elif (varname == "uv_div"):
        cn_levels = np.arange(-14., 15., 1.)   # -14, -13, ..., 14. np.arange excludes the end value
        cmap="Spectral"
        label=r"Divergence ($10^6 s^{-1}$)"
        name="Divergence"
        units=r"$10^6 s^{-1}$"

    elif (varname == "swabs_toa"):
        cn_levels = 15
        cmap="plasma"
        name = "TOA net downward SW flux"
        units = r"$W m^{-2}$"
        label = name+" ("+units+")"
        
    elif (varname == "swabs_toa_diff"):
        #cn_levels = np.array([-50, -40, -30, -20, -15, -10, -5, 0, 
        #                         5, 10, 15, 20, 30, 40, 50])
        #cn_levels = np.linspace(-10, 30, 41)
        
        cn_levels = np.arange(-120,140,10)
        
        cmap="Spectral_r"
        name = "TOA net downward SW flux diff"
        units = r"$W m^{-2}$"
        label = name+" ("+units+")"
        norm = BoundaryNorm(cn_levels, ncolors=len(cn_levels))

    else:
        cn_levels = 15      # number of levels. Filled contours need at least two levels, so a one-element array does not work
        cmap = "viridis"
        name = "Var"
        units = "units"
        label = name+" ("+units+")"

    #---------------------------------
    # return a dictionary variable
    #---------------------------------

    dict_cn_attrs = {
        'varname':varname,
        'cn_levels':cn_levels,
        'cmap': cmap,
        'label': label,
        'name': name,
        'units': units,
                  }
    
    return dict_cn_attrs

#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    var_tmp = xr.DataArray(1)
    var_tmp.attrs['standard_name']="ggg"
    var_tmp.attrs['units']="KK"

    dict_cn_attrs = set_dict_cn_attrs('swabs_toa_diff')
    print(dict_cn_attrs)
    print(dict_cn_attrs['cmap'])
    """
    
    return text_return

## lib

In [None]:
def lib(*keywords, color = True):
    """    
    ----------------------
    My library of Python codes

    Input arguments:
      keywords: string variables. Each keyword is either
                  - a group name (plt, py, pk, tmp, func): print the keywords in that group, or
                  - a note/template keyword, e.g. "pltxy": print the text returned by
                    the module-level function lib_<keyword>(), e.g. lib_pltxy().
                An unsupported keyword prints an error message with all supported keywords.
      color   : logical variable. Turn on/off colored texts in the notes (default: on).
                When on, "#-" starts the comment color and "$$" resets the color.
                When off, the notes are printed without color codes and "$$" is removed.

    Return:
      print out on screen

    Example:
      import yhc_module as yhc
      yhc.lib("plt")                 
      yhc.lib("pltxy", "str", color = False)


    Date created: 2022/07/23
    ----------------------
    """
    
    color_comment = '\033[36m'  # cyan color
    color_reset   = '\033[0m'   # reset to the default color
    
    #------------------
    # keyword list
    #------------------
    
    #--- plot-related notes
    keywords_plt = ['plt_basic','pltset',
                    'pltcn_2d','pltcn_map',
                    'pltxy','ax_def_xy']
    
    #--- python built-in notes
    keywords_py = ['py_basic','str','dict','module']
    
    #--- packages
    keywords_pk = ['np','xr']

    #--- basic template
    keywords_tmp = ['fdef','read_data','read_data_merra2','read_data_era5', 'select_var',
                    'plot_panel_string','set_cn_attrs_dict']

    #--- function template
    keywords_func = ['dddd']
    
    #--- groups: group name -> (keywords in the group, short notes printed before the keywords)
    keyword_groups = {
        'plt' : (keywords_plt , "[plt] related notes"),
        'py'  : (keywords_py  , "[py (python)] related notes"),
        'pk'  : (keywords_pk  , "[pk (package)] related notes"),
        'tmp' : (keywords_tmp , "[tmp (template)] templates"),
        'func': (keywords_func, "[func (function)] templates"),
    }
    
    #--- combined all lists
    keywords_class = list(keyword_groups)
    keywords_notes = keywords_plt + keywords_py + keywords_pk + keywords_tmp + keywords_func
    keywords_list = keywords_class + keywords_notes
     
    #------------------
    # print out
    #------------------
    
    for key1 in keywords:
        
        #--- lists: print the short notes and the keywords in the group
        if key1 in keyword_groups:
            group_keywords, short_notes = keyword_groups[key1]
            print(short_notes)
            print(group_keywords)
            continue
        
        #--- notes & templates: keyword "xxx" is provided by the module-level function lib_xxx().
        #    The function is looked up only when it is requested, so a listed keyword
        #    without a function (yet) is reported as not supported.
        note_func = globals().get("lib_"+key1) if key1 in keywords_notes else None
        
        if callable(note_func):
            text = note_func()
        
        else:
            text = ""
            print("ERROR: ["+ key1+"] is not supported yet.")
            print("  view groups: ["+', '.join(keywords_class)+"]")
            
            #--- print every five keywords in the same line
            keywords_list_sorted = sorted(keywords_list)
            print("  all supported keywords:")
            print("")
            for i in range(0, len(keywords_list_sorted), 5):
                chunk = keywords_list_sorted[i:i+5]  # Get the next five keywords
                print("    "+",  ".join(chunk))      # Join and print them on the same line
                        
        #--- print out
        if color:
            print(text.replace("#-", color_comment+"#-").replace("$$", color_reset))
        else:
            print(text.replace("$$", ""))

#lib('select_var')


## lib_func options (out-of-date)

### add_lon_cyclic

In [None]:
def lib_func_add_lon_cyclic():
    """
    ----------------------
    Description:
      Return a template (as text) that defines add_lon_cyclic(), which calls
      cutil.add_cyclic_point on a variable along its lon coordinate, followed
      by a do_test section.

    Input arguments:
      None

    Return:
      a string holding the template code
    ----------------------
    """
    return """

#--- add a cyclic point in lon, otherwise, there will be a white line at the edge
def add_lon_cyclic(var): 
    var_return, lon_cyclic = cutil.add_cyclic_point(var, coord=var.lon)  

    return var_return, lon_cyclic

#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    sw_all_CERES_mean_plot, lon_CERES = add_lon_cyclic(sw_all_CERES_mean)
    """

### plot_scm_time_series

In [None]:
def lib_func_plot_scm_time_series():
    """
    ----------------------
    Description:
      Legacy template used by lib_func. Returns the source text of *plot_scm_time_series*,
      which plots one variable from several datasets as time series on a single axis
      and optionally prints its time average, followed by a do_test example.

      The template relies on names the target notebook must provide:
      plt, np and "import yhc_module as yhc".

    Return:
      a string holding the template code

    Notes:
      - unit1 is now set in both branches of the unit handling; previously it was
        undefined when do_unit_convert=True.
      - check_var_in_da is called through yhc, like unit_convert.
    ----------------------
    """
    text_return = """def plot_scm_time_series(datasets, time_coord, 
                         ax, names_styles, varname, 
                         do_unit_convert = False, var_factor = 1, 
                         do_var_stats = True, tt_start = "none", tt_end = "none", 
                         plot_les = "none", 
                        ):
    
    func_name = "plot_scm_time_series"
    
    #------------------------
    # set shared parameters
    #------------------------
    legend_size = 16
    
    #--------------
    # check parts
    #--------------
    
    #--- make sure # of dataset is not greater than # of names_styles
    nda = len(datasets)        # number of dataset
    nname = len(names_styles) # number of names
    
    if (nda > nname):
        error_msg = f"ERROR: function [{func_name}], # of dataset [{nda}] > # of names_styles [{nname}]"
        raise ValueError(error_msg)
      
    #--------------------------------------
    # plot
    #--------------------------------------
    
    #--- get names of all datasets
    da_names = list(names_styles.keys())  
    
    #--- loop for all datesets
    i=0
    for da1 in datasets:

        #--------------------------------------
        # plot the variable in each dataset
        #--------------------------------------
        
        #--- get dataset custom name and color style
        name1 = da_names[i]
        style1 = names_styles[name1]
        
        #--- check whether da1 has varname
        yhc.check_var_in_da(da1, varname)
        
        #--- read variable
        var1 = da1.get(varname)[:,0,0]
        if (do_unit_convert): 
            var1 = yhc.unit_convert(var1)   # change units if needed
            unit1 = var1.attrs['units']
        else:
            var1 = var1.copy()*var_factor
            if (var_factor == 1):
                unit1 = var1.attrs['units']
            else:
                unit1 = var1.attrs['units']+" / "+str(var_factor)
        
        #--- set time coordinate
        #ntime = len(da1.time)
        #tt = np.arange(0., ntime, 1)
        tt = time_coord
        
        #--- plot the time series
        ax.plot(tt, var1, style1) 

        #--- set labels
        ax.set_title(varname, loc='left')
        ax.set_title(unit1, loc='right')
        ax.set_xlabel("Hours")
        ylabel = var1.attrs['long_name']+" ("+unit1+")"
        ax.set_ylabel(ylabel)
        
        #--- plot legend
        ax.legend(da_names, prop={'size': legend_size})

        #--------------------------------------
        # do some statistics of the variables
        #--------------------------------------
        
        if (do_var_stats):
        
            #--- get time range
            if (tt_start != "none" and tt_end != "none"):
                tt1 = tt_start
                tt2 = tt_end
            else:
                tt1 = 0
                tt2 = len(var1)
                
            #--- print time coordinates
            if (i==0): print(var1.time[tt1:tt2])

            var1_avg = var1[tt1:tt2].mean("time")
            var1_avg_str = '{:.1f}'.format(np.array(var1_avg))

            string= f"Exp [{name1}], avg [{varname}] over time step [{tt1}, {tt2}] = {var1_avg_str}"
            #print("")
            print(string)

        #----------------------------- 
        # loop for another dataset
        #-----------------------------
        #print(i)
        i=i+1

#-----------
# do_test
#-----------
do_test=True
#do_test=False

if (do_test):
    fig, ax1 = plt.subplots(1,1, figsize=(25, 10))
    
    names_styles = {
        "da1": "r--",
        "da2": "b--",
        "da3": "g--",
        "da4": "m--",
    }
    
    varname = "LWP"
    tt_start = 0
    tt_end = 10
    
    das = [da_rf02_am4RAD_ICStv_FStv, da_rf02_am4RAD_ICStv_2Xdiv, da_rf02_scmRAD_ICStv_FAM4n, da_rf02_am4RAD_ICStv_L07div]

    plot_scm_time_series( das, 
                         time_coord=tt, 
                         tt_start = tt_start, tt_end = tt_end, 
                         plot_les = "George_RF02",
                         names_styles = names_styles, varname = varname, ax=ax1, var_factor = 1000)
    """ 
    
    return text_return

### plot_cn_pt

In [None]:
def lib_func_plot_cn_pt():
    """
    ----------------------
    Description:
      Legacy template used by lib_func. Returns the source text of *plot_cn_pt*,
      which draws two stacked filled-contour panels (MERRA-2 on top, AM4 below)
      that share contour levels, colormap and colorbar label.

      The template calls plt, np and ax_def_cn, which the target notebook must provide.

    Return:
      a string holding the template code
    ----------------------
    """
    return """
def plot_cn_pt(time_merra2, plev_merra2, var_merra2,
               time_am4, plev_am4, var_am4,
               varname,  
               mlev_merra2 = "Np", 
              ):

    fig, (ax1, ax2) = plt.subplots(2,1, figsize=(18, 12))   # 2 row, 1 column
    fig.tight_layout(pad=6.0)

    if (varname == "tdt_dyn"):
        #cn_levels = np.linspace(-10., 10., 21)
        cn_levels = np.arange(-10., 11., 1.)
        cmap="coolwarm"
        label = "3D dynamical T tendencies (K/day)"
        name="tdt_dyn"
        units="K/day"
        
    else:
        cn_levels = 15
        cmap = "viridis"
        label = "Var (units)"

    #--- ax1, var_merra2
    if (mlev_merra2 == "Nv"): 
        plot_var_merra2 = ax1.contourf(time_merra2.transpose("lev","time"), plev_merra2.transpose("lev","time"), var_merra2.transpose("lev","time"), levels = cn_levels, cmap=cmap)
    else:
        plot_var_merra2 = ax1.contourf(time_merra2, plev_merra2, var_merra2.transpose("lev","time"), levels = cn_levels, cmap=cmap)
    
    fig.colorbar(plot_var_merra2, ax=ax1, orientation='vertical', label=label)
    ax_def_cn(ax1)

    #--- ax2, var_am4
    plot_var_am4 = ax2.contourf(time_am4.transpose("plev","time"), plev_am4.transpose("plev","time"), var_am4.transpose("pfull","time"), levels = cn_levels, cmap=cmap)

    fig.colorbar(plot_var_am4, ax=ax2, orientation='vertical', label=label)
    ax_def_cn(ax2)

    """

### read_am4_pt_ijavg

In [None]:
def lib_func_read_am4_pt_ijavg():
    """
    ----------------------
    Description:
      Legacy template used by lib_func. Returns the source text of *read_am4_pt_ijavg*,
      which reads an AM4 variable through read_data.read_am4_data and reduces it to a
      region-averaged profile with yhc.get_region_profile, followed by a disabled
      do_test example.

      The "def" line of the template starts at column 0 so the text can be pasted and
      run as is (it used to be indented like its own body, which is a syntax error).

    Return:
      a string holding the template code
    ----------------------
    """
    text_return = """

def read_am4_pt_ijavg(choice, varname, region): 
    
    \"""
    ----------------------
    Description:
       Read AM4 (time, pressure/level, lat, lon) data on native levels, and then
       convert it to a two-dimensional pressure-time data averaged over a region (lat, lon).

    Input arguments:
      choice: a string used in read_data. 
      varname: variable name in the AM4 data
      region: a string used in yhc.get_region_profile

    Return:
      var_ijavg  : 2-d pressure-time data
      plevs_ijavg: pressure levels 

    Example:
      See "do_test" section 

    Date created: 2023-01-11
    ----------------------
    \"""

    func_name = "read_am4_pt_ijavg"
    model = "AM4_L33_native"
    
    #--- 
    da_am4 = read_data.read_am4_data(choice)
    var = da_am4.get(varname)
    
    ps = da_am4.ps  # surface pressure
    
    var_ijavg, plevs_ijavg = yhc.get_region_profile(var, region, model = model, ps = ps, plevs = "pfull")

    return var_ijavg, plevs_ijavg

#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    choice = "nudgeAM4_July10_12"
    region = "DYCOMS"
    varname = "tdt_dyn"
    #time_step=slice(0,23)
    
    var_ijavg, plevs_ijavg = read_am4_pt_ijavg(choice, varname, region)
    
    #print(var_ijavg)
    #print(plevs_ijavg)

    
    """
    return text_return

### read_merra2_pt_ijavg

In [None]:
def lib_func_read_merra2_pt_ijavg():
    """
    ----------------------
    Description:
      Legacy template used by lib_func. Returns the source text of *read_merra2_pt_ijavg*,
      which reads a MERRA-2 variable through read_data.read_merra2_data, wraps the
      longitudes with yhc.wrap360 and averages over a region with yhc.get_area_avg,
      followed by a disabled do_test example.

    Return:
      a string holding the template code
    ----------------------
    """
    return """
    
def read_merra2_pt_ijavg(choice, varname, region,
                         mlev = "Np"): 
    
    \"""
    ----------------------
    Description:
       Read MERRA-2 (time, pressure/level, lat, lon) data, and then
       convert it to a two-dimensional pressure-time data averaged over a region (lat, lon).

    Input arguments:
      choice: a string used in read_data. 
      varname: variable name in the MERRA-2 data
      region: a string used in yhc.get_area_avg
      mlev: data vertical levels. "Np" is on pressure levels, "Nv" is on model hybrid levels.

    Return:
      var_ijavg  : 2-d pressure-time data
      plevs_ijavg: pressure levels 

    Example:
      See "do_test" section 

    Date created: 2023-01-04
    ----------------------
    \"""

    func_name = "read_merra2_pt_ijavg"

    #--- 
    da_merra2 = read_data.read_merra2_data(choice)
    da_merra2 = yhc.wrap360(da_merra2)
    var = da_merra2.get(varname)
    
    #--- get domain-avg profile
    var_ijavg = yhc.get_area_avg(var, region)
    
    #--- get pressure levels
    if (mlev == "Np"): 
        plevs_ijavg = da_merra2.lev
        
    elif (mlev == "Nv"):
        PL_merra2 = da_merra2.get("PL")
        PL_ijavg = yhc.get_area_avg(PL_merra2, region)
        plevs_ijavg = PL_ijavg

    elif (mlev == "Nv_none"):
        plevs_ijavg = 0.        
        
    else:
        error_msg = f"ERROR: function [{func_name}] does not support mlev=[{mlev}]."
        raise ValueError(error_msg)
    
    #print(plevs_ijavg)
    #print(da_merra2)
    
    return var_ijavg, plevs_ijavg

#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    choice = "qdt_July10_12"
    region = "DYCOMS"
    varname = "DQVDTDYN"
    
    var_ijavg, plev_merra2 = read_merra2_pt_ijavg(choice, varname, region)
    
    #print(da_merra2)
    #print(var_ijavg)
    print(plev_merra2)
    
    """

### set_cn_attrs (combined with lib_set_cn_attrs)

In [None]:
def lib_func_set_cn_attrs():
    """
    ----------------------
    Description:
      Legacy template used by lib_func. Returns the source text of *set_cn_attrs*,
      which maps a variable name to contour-plot attributes
      (cn_levels, cmap, label, name, units) and returns them as a dictionary,
      followed by a disabled do_test example.

      The fallback branch of the template now also sets name and units; before, any
      unknown varname (including the default "none") raised a NameError when the
      return dictionary was built.

    Return:
      a string holding the template code
    ----------------------
    """
    text_return = """
    
def set_cn_attrs (var, 
                  varname = "none"):
    
    #--- set cn levels
    tdt_dyn_cnlevels = np.arange(-25., 27.5, 2.5)
    qdt_dyn_cnlevels = np.arange(-10., 11, 1.)
    
    #---------------------
    # set cn attributes
    #---------------------
    if (varname == "tdt_dyn"):
        cn_levels = tdt_dyn_cnlevels
        cmap="coolwarm"
        label = "3D dynamical T tendencies (K/day)"
        name="tdt_dyn"
        units="K/day"
        
    elif (varname == "qdt_dyn"):
        cn_levels = qdt_dyn_cnlevels
        cmap="BrBG"
        label = "3D dynamical Q tendencies (g/kg/day)"
        name="qdt_dyn"
        units="g/kg/day"
        
    elif (varname == "omega"):
        cn_levels = np.arange(-120., 130., 10.)
        cmap="PiYG"
        label = "Omega (hPa/day)"
        name="omega"
        units="hPa/day"

    elif (varname == "uv_div"):
        cn_levels = np.arange(-14., 14., 1.)
        cmap="Spectral"
        label=r"Divergence ($10^6 s^{-1}$)"
        name="Divergence"
        units=r"$10^6 s^{-1}$"

    elif (varname == "TOA_SW_diff"):
        cn_levels = np.array([-50, -40, -30, -20, -15, -10, -5, 0, 
                                 5, 10, 15, 20, 30, 40, 50])
        cmap="seismic_r"
        name = "TOA outgoing SW flux diff"
        units = var.units
        label = name+" ("+units+")"

    else:
        cn_levels = 15
        cmap = "viridis"
        name = "Var"
        units = "units"
        label = "Var (units)"

    #---------------------------------
    # return a dictionary variable
    #---------------------------------

    return_dict = {
        'cn_levels':cn_levels,
        'cmap': cmap,
        'label': label,
        'name': name,
        'units': units,
                  }
    
    return return_dict

#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    var_tmp = xr.DataArray(1)
    var_tmp.attrs['standard_name']="ggg"
    var_tmp.attrs['units']="KK"

    gg = set_cn_attrs(var_tmp, 'tdt_dyn')
    
    print(gg['units'])
    
    """
    
    return text_return

## lib_func (out-of-date)

In [None]:
def lib_func(*keywords, color = False):
    """
    ----------------------
    Description:
      My (legacy) library of Python function templates. For every keyword the matching
      template, or the keyword list of a group, is printed so that it can be copied
      into an ipynb program.

    Input arguments:
      keywords: string variables. Either a group name ("plt", "template"), which prints
                the keywords collected under that group, or the name of a legacy
                function (e.g. "add_lon_cyclic"), which prints its template.
      color   : kept for backward compatibility; the value is not used. Templates are
                always printed with the "#-" comment markers switched to cyan.

    Return:
      None. Everything is printed on screen. An unknown keyword prints an error
      message together with the keywords this function can serve.

    Example:
      import yhc_module as yhc
      yhc.lib_func("template")
      yhc.lib_func("add_lon_cyclic", "set_cn_attrs")

    Date created: 2023/01/11
    ----------------------
    """

    color_comment = '\033[36m'  # cyan color
    color_black   = '\033[0m'   # back to the default color

    #------------------
    # keyword list
    #------------------

    #--- plot-related functions
    keywords_plt = ['ax_def_xy','plt_basic','pltcn_2d','pltset','pltxy']

    #--- basic template
    keywords_template = ['fdef']

    #--- groups: group name -> (keywords of the group, short note printed above them)
    keyword_groups = {
        'plt'     : (keywords_plt,      "related notes"),
        'template': (keywords_template, "templates"),
    }

    #--- legacy functions that have a template in this module
    keywords_list_legacy = [
        'add_lon_cyclic',
        'plot_cn_pt',
        'plot_scm_time_series',
        'read_am4_pt_ijavg',
        'read_merra2_pt_ijavg',
        'set_cn_attrs',
    ]

    #--- keywords this function can actually serve (shown in the error message)
    keywords_list = list(keyword_groups) + keywords_list_legacy

    #------------------
    # print out
    #------------------

    for key1 in keywords:

        #--- groups: print the short note, then the keyword list itself
        if key1 in keyword_groups:
            group_keywords, group_note = keyword_groups[key1]
            print(f"[{key1}] {group_note}")
            print(group_keywords)
            continue

        #--- templates (looked up one by one so only the requested one is built)
        if   (key1 == "sdfkjskdfh")           : text = "test"   # debugging keyword
        elif (key1 == "add_lon_cyclic")       : text = lib_func_add_lon_cyclic()
        elif (key1 == "plot_cn_pt")           : text = lib_func_plot_cn_pt()
        elif (key1 == "plot_scm_time_series") : text = lib_func_plot_scm_time_series()
        elif (key1 == "read_am4_pt_ijavg")    : text = lib_func_read_am4_pt_ijavg()
        elif (key1 == "read_merra2_pt_ijavg") : text = lib_func_read_merra2_pt_ijavg()
        elif (key1 == "set_cn_attrs")         : text = lib_func_set_cn_attrs()
        else:
            text = ""
            print("ERROR: ["+ key1+"] is not supported yet.")
            print("supported keywords: ["+', '.join(keywords_list)+"]")

        #--- "#-" starts a colored comment, "$$" switches back to the default color
        print(text.replace("#-", color_comment+"#-").replace("$$", color_black))
    
            
#lib_func('read_am4_pt_ijavg')
#lib('pltcn_2d')


## get_region_profile

In [None]:
def get_region_profile(var, region, 
                       check_ps_range = True, 
                       model = None, ps = None, plevs = None): 
    
    """
    ----------------------
    Description:
      Get regional-averaged profile. The pressure levels are computed using function *mlevs_to_plevs*.

    Input arguments:
      var: a xarray DataArray
      region: a string, region name used by get_region_latlon
      check_ps_range: if True, raise an error when the surface pressure inside the region
                      varies by more than 10 hPa (1000 Pa)
      model, ps, plevs: input arguments for function *mlevs_to_plevs*.
                        If model is None, no pressure levels are computed.

    Return:
      var_region_ijavg & plevs_region_ijavg
      area-averaged variable and pressure levels.
      plevs_region_ijavg is None when model is None.

    Raises:
      ValueError: model is given without ps, or the surface pressure range check fails

    Example:
      import yhc_module as yhc
      
      var = da.temp
      ps = da.ps
      model = "AM4_L33_native"
      region = "DYCOMS"
    
      var1, plevs1 = yhc.get_region_profile(var, region, model = model, ps = ps, plevs = "pfull")

    Date created: 2022-08-31
    ----------------------
    """

    func_name = "get_region_profile"

    #--- get lat/lon of the region
    lon_slice, lat_slice = get_region_latlon(region)

    #--- get regional-average of the var
    var_region_ijavg = get_area_avg(var, region)

    #--- no model, no pressure levels: return the averaged variable only
    if model is None:
        return var_region_ijavg, None

    #--- pressure levels are built from surface pressure, so ps is mandatory here
    if ps is None:
        error_msg = f"ERROR: [{func_name}]: ps is required when model=[{model}] is given"
        raise ValueError(error_msg)

    ps_region = ps.sel(lat=lat_slice, lon=lon_slice)

    #--- check whether surface pressure differ too much in the region
    if (check_ps_range):
        ps_range = 10e+2  # 10 hPa, expressed in Pa

        ps_max = ps_region.max().to_numpy()
        ps_min = ps_region.min().to_numpy()
        ps_diff = abs(ps_max-ps_min)

        if (ps_diff > ps_range):
            error_msg = f"""
                ERROR: [{func_name}]: surface pressure exceeds the range 
                [ps_max, ps_min, ps_diff, ps_range] = [{ps_max}, {ps_min}, {ps_diff}, {ps_range}] Pa
                Set check_ps_range=False to skip the check
                """
            raise ValueError(error_msg)

    #--- get pressure levels in the region, then average them like the variable
    plevs_region = mlevs_to_plevs(ps_region, model, plevs)
    plevs_region_ijavg = get_area_avg(plevs_region, region)

    return var_region_ijavg, plevs_region_ijavg

#-----------
# do_test
#-----------

#do_test=True
do_test=False

# Manual check of get_region_profile, disabled by default.
# When enabled it opens a local AM4 netCDF file, takes the first time step of
# temp and ps, and prints the DYCOMS region-averaged profile and pressure levels.
if (do_test):
    file = "../data/data-am4_20010725_8xdaily-temp.nc"
    da = xr.open_dataset(file)
    
    var = da.temp[0,:,:,:]
    ps = da.ps[0,:,:]
    
    model = "AM4_L33_native"
    region = "DYCOMS"
    
    var1, plevs1 = get_region_profile(var, region, model = model, ps = ps, plevs = "pfull")
    #get_region_profile(var, region)

    printv(var1, 'var','r')
    printv(plevs1, 'plev','g')

## read_var_nc

In [None]:
def read_var_nc(filename, varname, 
                do_unit_convert = True,
               ): 
    
    """
    ----------------------
    Description:
      Given filename and varname, use Xarray to read the data and read the variable,
      and then return respective DataSet and DataArray.

    Input arguments:
      filename: (a string) path of a netCDF file, e.g. ../data/test111.nc
      varname: (a string) name of a data variable in the file, e.g. temp
      do_unit_convert: (a logical) whether calling unit_convert or not

    Return:
      xarray DataSet and DataArray.
      The DataSet is returned open; the caller decides when to close it.

    Raises:
      ValueError: varname is not a data variable of the file
                  (the file is closed again before the error is raised)

    Example:
      import yhc_module as yhc
      
      filename = ""
      varname =""
      da_scm, var_scm = yhc.read_var_nc (filename, varname)

    Date created: 2022-09-02
    ----------------------
    """

    func_name = "read_var_nc"

    #--- open the netCDF file 
    da = xr.open_dataset(filename)

    #--- make sure the variable is there; do not leave the file open if it is not
    if varname not in da.data_vars:
        da.close()
        error_msg = f"ERROR [{func_name}]: variable [{varname}] is not in the file [{filename}]"
        raise ValueError(error_msg)

    #--- read the variable 
    var_in = da.get(varname)

    #--- change units, or hand back an independent copy
    var_out = unit_convert(var_in) if do_unit_convert else var_in.copy()

    #--- return
    return da, var_out
    
    
#-----------
# do_test
#-----------

#do_test=True
do_test=False

# Manual check of read_var_nc, disabled by default.
# When enabled it reads qdt_vdif from a local SCM output file and prints the
# dataset, its ucomp variable and the (unit-converted) variable.
if (do_test):
    filename = "../data/SCM_am4_xanadu_edmf_mynn.v01_RF01-00cc-am4p0_aerT_clr_am4RAD_sw.1x0m5d_1x1a.atmos_edmf_mynn.nc"
    varname = "qdt_vdif"
    
    da_scm, var_scm = read_var_nc (filename, varname)
    #da_scm, var_scm = read_var_nc (filename, varname, do_unit_convert=False)
    
    printv(da_scm, 'da_scm', 'r')
    printv(da_scm.ucomp, 'u_scm', 'b')
    printv(var_scm, 'var_scm', 'g')
    
    

## Meteorology functions

In [None]:
def es(T_degreeC, Ps_hPa):
    """
    Saturation vapor pressure (hPa) from the Bolton (1980) formula.

    T_degreeC: temperature in degree Celsius
    Ps_hPa   : pressure in hPa; part of the signature but not used by the formula

    Ref: Metpy documentation: https://unidata.github.io/MetPy/v0.3/api/thermo.html
    """
    exponent = 17.67*T_degreeC / (T_degreeC +243.5)
    return 6.112 * np.exp(exponent)  # hPa

def qs(T_degreeC, Ps_hPa):
    """
    Saturation water vapor mixing ratio (kg/kg).

    T_degreeC: temperature in degree Celsius
    Ps_hPa   : pressure in hPa

    Reference: https://www.weather.gov/media/epz/wxcalc/mixingRatio.pdf
    """
    sat_vapor_pres = es (T_degreeC, Ps_hPa)   # hPa

    # 621.97 * es/(p-es) is in g/kg; dividing by 1000 gives kg/kg
    return 621.97 * ( sat_vapor_pres / (Ps_hPa-sat_vapor_pres) ) / 1000.

def des_dT(T_degreeC):
    """
    Temperature derivative of the saturation vapor pressure, d(es)/dT, in hPa/C.

    This is the analytic derivative of the expression used in es().

    T_degreeC: temperature in degree Celsius
    """
    shifted_T = T_degreeC+243.5
    return 6.112 * np.exp ( 17.67*T_degreeC / shifted_T ) * (17.67*243.5) / shifted_T**2

def dqs_dT(T_degreeC, Ps_hPa):
    """
    Temperature derivative of the saturation mixing ratio, d(qs)/dT, in kg/kg per degree C.

    Built from es() and des_dT(): 0.62197 * p * d(es)/dT / (p - es)**2

    T_degreeC: temperature in degree Celsius
    Ps_hPa   : pressure in hPa
    """
    sat_vapor_pres = es (T_degreeC, Ps_hPa)
    slope_es = des_dT(T_degreeC)

    return 621.97 / 1000. * Ps_hPa*slope_es / (Ps_hPa-sat_vapor_pres)**2

def rh(T_degreeC, Ps_hPa, q_kg_kg):
    """
    Relative humidity as the ratio q / qs (unitless, not in percent).

    T_degreeC: temperature in degree Celsius
    Ps_hPa   : pressure in hPa
    q_kg_kg  : water vapor amount in kg/kg, compared against qs(T_degreeC, Ps_hPa)
    """
    return q_kg_kg / qs(T_degreeC, Ps_hPa)

## var_stats

In [None]:
def var_stats(*da, dim, 
             opt):
    
    """
    ----------------------
    Description:
      Return the selected statistic of the input arrays along the specified dimension

    Input arguments:
      da : one or more xarray DataArray. Each one must reduce to a single value
           when the statistic is taken along *dim*.
      dim: (string) dimension to operate, e.g. time
      opt: stats option, Available: "avg"

    Return:
      a list with one entry per da: the statistic formatted as a string
      with two decimals, e.g. ['2.50', '4.17']

    Raises:
      ValueError: opt is not one of the available options

    Example:
      import yhc_module as yhc
      dim = "time"
      opt = "avg"
      stats = yhc.var_stats(da1, da2, dim = dim, opt = opt)

    Date created: 2022-10-03
    ----------------------
    """

    func_name = "var_stats"

    #--- available stats 
    stats_avail = ["avg"]

    if (opt not in stats_avail):
        error_msg = f"ERROR: func [{func_name}], [opt={opt}] does not support. Available: {stats_avail}"
        raise ValueError(error_msg)

    #--- compute average of the input arrays along the dimension asked by the caller
    return ['{:.2f}'.format(float(da1.mean(dim))) for da1 in da]
    
    
#-----------
# do_test
#-----------

#do_test=True
do_test=False

if (do_test):
    gg1 = np.array([1.,2.,3.,4.])
    gg2 = np.array([4.14232,4.232,4.12321])
    
    da1 = xr.DataArray(gg1, dims=['time'])
    da2 = xr.DataArray(gg2, dims=['time'])
    
    #print(da1.mean("time"))
    
    dim = "time"
    opt = "avg"
    #print(da1.mean(dim))

    stats = var_stats(da1, da2, dim = dim, opt = opt)
    
    print(stats)
    

## read_am4_profile_region

In [None]:
def read_am4_profile_region (filename, varname, time_step,
                             model, region, plevs = "pfull"): 
    
    """
    ----------------------
    Description:
      Read a variable (time, lev, lat, lon) in AM4 output data, 
      and get the profile averaged over a specified domain.

    Input arguments:
      filename : a string, the full path of AM4 data (in netCDF)
      varname  : a variable with (time, lev, lat, lon) dimensions
      time_step: time step selected, passed to .sel(time=...)
      
      model, region, plevs: defined in yhc.get_region_profile

    Return:
      the Xarray Dataset read from the file, the domain-average 1-D variable,
      and the 1-D pressure levels.
      The two profiles are the first entry along the leading dimension of what
      get_region_profile returns.
      NOTE(review): this assumes the time selection keeps a leading dimension -- confirm
      for time_step values that select a single scalar time.

    Example:
      import yhc_module as yhc

      filename = datapath_am4+"/"+fname_am4_rf02
      varname = "temp"
      time_step = "2001-07-11 10:30:00"
      region = "DYCOMS"
      model = "AM4_L33_native"
      
      da_am4_1, temp_am4_1, plev_am4_1 = yhc.read_am4_profile_region (filename, varname, time_step, model, region)

    Date created: 2022-11-09
    ----------------------
    """

    #--- use Xarray to read variable 
    da1 = xr.open_dataset(filename)
    var = da1.get(varname).sel(time=time_step)
    
    #--- read surface pressure
    ps  = da1.ps.sel(time=time_step)
    
    #--- use yhc.get_region_profile to get the domain-avg profile
    #    (plevs is forwarded as given by the caller; it used to be fixed to "pfull")
    var1, plevs1 = get_region_profile(var, region, model = model, ps = ps, plevs = plevs)
    
    #--- return Xarray Dataset, 1-D variable, 1-D pressure levels
    return da1, var1[0,:], plevs1[0,:] 

    #-----------
# do_test
#-----------

# Placeholder: the test body below is commented out, so nothing is run here.
do_test=True
#do_test=False

#if (do_test):

## read_merra2_profile_region

In [None]:
def read_merra2_profile_region (filename, varname, time_step,
                                region): 
    
    """
    ----------------------
    Description:
      Open a MERRA-2 netCDF file, pick one variable (time, lev, lat, lon) at the
      requested time, and average it over a named domain.

      Note that lev could be on model hybrid levels (Nv) or on pressure levels (Np).
      Variable "PL" is the mid_level_pressure in *Nv* files.

    Input arguments:
      filename : a string, the full path of MERRA-2 data (in netCDF)
      varname  : a variable with (time, lev, lat, lon) dimensions
      time_step: time step selected, passed to .sel(time=...)
      
      region: defined in yhc.get_region_latlon

    Return:
      the Xarray Dataset (after wrap360 has been applied), and the domain-average variable

    Example:
      import yhc_module as yhc

      filename = datapath_merra2+"/"+fname_merra2_rf01
      varname = "T"
      time_step = "2001-07-10T10:30:00.000000000"
      region = "DYCOMS"
      
      da_merra2_1, T_merra2_1 = yhc.read_merra2_profile_region (filename, varname, time_step, region)

    Date created: 2022-11-09
    ----------------------
    """

    func_name = "read_merra2_profile_region"

    #--- open the file; flip MERRA-2 longitudes from [-180,180] to [0,360]
    dataset = wrap360(xr.open_dataset(filename))

    #--- variable at the requested time, averaged over the domain
    var_selected = dataset.get(varname).sel(time=time_step)
    profile = get_area_avg(var_selected, region)

    return dataset, profile

#-----------
# do_test
#-----------

# Placeholder: the test body below is commented out, so nothing is run here.
do_test=True
#do_test=False

#if (do_test):

## diagnose_var

In [None]:
def diagnose_var(da, varnames,
                 model = "TaiESM"): 
    
    """
    ----------------------
    Description:
      Given a xarray dataset, diagnose variables and save to the same dataset

    Input arguments:
      da: a xarray dataset. It is modified in place.
      varnames: diagnosed variables, a list of names or a single name
      model: model name. Currently support: ["AM4", "TaiESM", "CESM"]

    Return:
      da (the same object that was passed in, with the diagnosed variables added)

    Raises:
      ValueError: unsupported model, unsupported variable name, or (TaiESM/CESM)
                  an input variable needed for the diagnosis is not in the dataset

    Notes:
      - TaiESM/CESM: a requested variable that already exists in da is left untouched.
      - AM4: the requested variables are always (re)computed.

    Example:
      import yhc_module as yhc

      ### Example 1 ###
      model = "AM4"
      varnames = ["swabs",'swcre']
      fname = "../data/SCM_am4_xanadu_edmf_mynn.v01_RF01-00cc-am4p0_aerT_clr_am4RAD_sw.1x0m5d_1x1a.atmos_edmf_mynn.nc"
      da = xr.open_dataset(fname)
      da = yhc.diagnose_var(da, model = model, varnames = varnames)

      ### Example 2 ###
      fname = "../data_test/scam_tw606_taiphys.camrun.cam.h0.2006-01-17-10800.nc"
      da = xr.open_dataset(fname)
      varnames=["DQV_SUM_PHYS",'DT_SUM_PHYS',"DQL_SUM_PHYS","DQI_SUM_PHYS"]

      da = yhc.diagnose_var(da, varnames = varnames)    
      print(da.DQL_SUM_PHYS)

    Latest modified date: 2024-03-20
    ----------------------
    """

    func_name = "diagnose_var"

    #--- a single name would otherwise be looped over character by character
    if isinstance(varnames, str):
        varnames = [varnames]

    if (model == "TaiESM" or model == "CESM"):
        _diagnose_var_cesm(da, varnames, func_name)

    elif (model == "AM4"):
        _diagnose_var_am4(da, varnames, func_name)

    else:
        error_msg = f"ERROR: function [{func_name}] does not support model [{model}]."
        raise ValueError(error_msg)

    return da


#--- TaiESM/CESM diagnostics: diagnosed variable -> input variables that are summed up,
#    plus the long_name and units written to the result
_DIAGNOSE_VAR_CESM = {
    ########
    ## DQV
    ########
    'DQV_deep': {
        'varnames' : ["ZMDQ", "EVAPQZM"],
        'long_name': 'Q tendency - deep convection (ZMDQ+EVAPQZM)',
        'units'    : 'kg/kg/s',
    },
    'DQV_SUM_PHYS': {
        'varnames' : ["ZMDQ", "EVAPQZM", "CMFDQ", "MACPDQ", "MPDQ", "VD01"],
        'long_name': 'Q tendency - sum of all physics',
        'units'    : 'kg/kg/s',
    },
    'DQV_DYN_PHYS': {
        'varnames' : ["DQVCORE","PTEQ"],
        'long_name': 'Q tendency - sum of dynamics and physics',
        'units'    : 'kg/kg/s',
    },
    ########
    ## DQL
    ########
    'DQL_SUM_PHYS': {
        'varnames' : ["ZMDLIQ", "CMFDLIQ", "DPDLFLIQ", "SHDLFLIQ", "MACPDLIQ", "MPDLIQ", "VDCLDLIQ"],
        'long_name': 'Cloud liq tendency - sum of all physics',
        'units'    : 'kg/kg/s',
    },
    'DQL_DYN_PHYS': {
        'varnames' : ["DQLCORE","PTECLDLIQ"],
        'long_name': 'Cloud liq tendency - sum of dynamics and physics',
        'units'    : 'kg/kg/s',
    },
    ########
    ## DQI
    ########
    'DQI_SUM_PHYS': {
        'varnames' : ["ZMDICE", "CMFDICE", "DPDLFICE", "SHDLFICE", "MACPDICE", "MPDICE", "VDCLDICE"],
        'long_name': 'Cloud ice tendency - sum of all physics',
        'units'    : 'kg/kg/s',
    },
    'DQI_DYN_PHYS': {
        'varnames' : ["DQICORE","PTECLDICE"],
        'long_name': 'Cloud ice tendency - sum of dynamics and physics',
        'units'    : 'kg/kg/s',
    },
    ########
    ## DT
    ########
    'DT_SUM_PHYS': {
        'varnames' : ["ZMDT", "ZMMTT", "EVAPTZM", "CMFDT", "DPDLFT", "SHDLFT", "MACPDT", "MPDT","QRL","QRS", "DTV", "TTGWORO"],
        'long_name': 'T tendency - sum of all physics',
        'units'    : 'K/s',
    },
    'DT_DYN_PHYS': {
        'varnames' : ["DTCORE","PTTEND"],
        'long_name': 'T tendency - sum of dynamics and physics',
        'units'    : 'K/s',
    },
    'DT_deep': {
        'varnames' : ["ZMDT", "ZMMTT", "EVAPTZM"],
        'long_name': 'T tendency - deep convection (ZMDT+ZMMTT+EVAPTZM)',
        'units'    : 'K/s',
    },
}


def _diagnose_var_cesm(da, varnames, func_name):
    # Add the requested TaiESM/CESM diagnostics to da; names already in da are skipped.

    supported_vars = list(_DIAGNOSE_VAR_CESM)

    for var1 in varnames:

        #--- If var1 is already in the dataset, do nothing
        if var1 in da:
            continue

        #--- check whether var1 is supported. If not, stop the program.
        if var1 not in _DIAGNOSE_VAR_CESM:
            raise ValueError(f"ERROR: function [{func_name}] does not support '{var1}'. Supported variables: {supported_vars}")

        recipe = _DIAGNOSE_VAR_CESM[var1]
        varnames_processed = recipe['varnames']

        #--- check whether the dataset contains the necessary variables
        for var_tmp in varnames_processed:
            if var_tmp not in da:
                raise ValueError(f"ERROR: function [{func_name}] Compute '{var1}', but the required variable '{var_tmp}' does not exist in the dataset.")

        #--- diagnose variables
        if (var1 == "DT_SUM_PHYS"):
            # MACPDT and MPDT go through unit_convert (W/kg -> K/s) before they are added;
            # every other input is summed as it is.
            vars_in_W_per_kg = ("MACPDT", "MPDT")
            var_diag = sum(da[name] for name in varnames_processed if name not in vars_in_W_per_kg)
            for name in vars_in_W_per_kg:
                var_diag = var_diag + unit_convert(da[name], units_in="W/kg", units_out="K/s")

        else:  # sum all variables in varnames_processed
            var_diag = sum(da[name] for name in varnames_processed)

        #--- save to the dataset
        var_diag.attrs['long_name'] = recipe['long_name']
        var_diag.attrs['units'] = recipe['units']
        da[var1] = var_diag


def _diagnose_var_am4(da, varnames, func_name):
    # Add the requested AM4 TOA shortwave diagnostics (swabs, swcre) to da.

    for var1 in varnames:
        if (var1 == "swabs"):
            swabs = da.swdn_toa - da.swup_toa
            swabs.attrs['long_name'] = "SWABS (TOA net dw SW flux)"
            da['swabs'] = swabs

        elif (var1 == "swcre"):
            swcre = da.swup_toa - da.swup_toa_clr
            swcre.attrs['long_name'] = "TOA Cloud SW radiative effect"
            da['swcre'] = swcre

        else:
            error_msg = f"ERROR: function [{func_name}] does not support variable [{var1}]."
            raise ValueError(error_msg)
    
#-----------
# do_test
#-----------

# Manual checks of diagnose_var. do_test selects the case by name ("TaiESM" or "AM4");
# the value "TaiESMxxx" matches neither branch, so nothing is run.
do_test="TaiESMxxx"
#do_test=False

# TaiESM case: open a local SCAM output file and diagnose DQV_DYN_PHYS
if (do_test == "TaiESM"):
    #fname = "../data_test/scam_tw606_taiphys.camrun.cam.h0.2006-01-17-10800.nc"
    fname = "/lfs/home/yihsuanc/scripts/python/data_test/scam_tw606_taiphys.camrun.cam.h0.2006-01-17-10800.nc"
    da = xr.open_dataset(fname)
    #varnames=["DQV_SUM_PHYS",'DT_SUM_PHYS',"DQL_SUM_PHYS","DQI_SUM_PHYS", "ZMDQ", ""]
    varnames = ["DQV_DYN_PHYS"]
    
    da = diagnose_var(da, varnames = varnames)    
    print(da)    
#print(da.DQL_SUM_PHYS)

# AM4 case: open a local SCM output file and diagnose swabs and swcre
elif (do_test == "AM4"):
    
    model = "AM4"
    varnames = ["swabs",'swcre']
    fname = "../data/SCM_am4_xanadu_edmf_mynn.v01_RF01-00cc-am4p0_aerT_clr_am4RAD_sw.1x0m5d_1x1a.atmos_edmf_mynn.nc"
    da = xr.open_dataset(fname)
    da = diagnose_var(da, model = model, varnames = varnames)
    print(da.swcre)
    

## check_var_in_da

In [None]:
def check_var_in_da(da, varname):
    """
    ----------------------
    Description:
      Make sure a variable name is one of the keys of the given dataset.
      The call is silent when the name is found and stops with an error otherwise.

    Input arguments:
      da     : an object providing .keys(), e.g. a xarray Dataset
      varname: a string, the variable name to look for

    Return:
      None

    Raises:
      ValueError: varname is not among da.keys()

    Example:
      import yhc_module as yhc

      varname = "LWP1"
      yhc.check_var_in_da(da1, varname)

    Date created: 2023-01-08
    ----------------------
    """

    func_name = "check_var_in_da"

    #--- variable is present: nothing to report
    if varname in list(da.keys()):
        return

    #--- variable is missing: stop and name it in the message
    raise ValueError(
        f"ERROR: function [{func_name}], the input dataset does not have variable [{varname}]"
    )
    
#-----------
# do_test
#-----------

# Manual test switch for check_var_in_da; left False so the check below is skipped.
#do_test=True
do_test=False

if (do_test):
    varname = "LWP1"
    # NOTE(review): da_rf02_am4RAD_ICStv_FStv is not defined in this part of the
    # file -- it must exist in the session before enabling this test.
    check_var_in_da (da_rf02_am4RAD_ICStv_FStv, varname)

## read_ds_nfiles_nvars

In [None]:
########################
########################
def check_list(input_variable, error_msg):
    """
    Require input_variable to be a list.

    Input arguments:
      input_variable: the object to check
      error_msg     : text of the error raised when the check fails

    Return:
      None

    Raises:
      ValueError(error_msg): input_variable is not a list instance
    """
    #--- a list (or list subclass) passes silently
    if isinstance(input_variable, list):
        return

    raise ValueError(error_msg)
        
########################
########################
def read_ds_nfiles_nvars(file_paths, variables): 
    
    """
    ----------------------
    Description:
      Given file paths and variable names, 
      return a Xarray DataSet (files, [variable dimensions, e.g. time,lat,lon])

      Each file is opened, the requested variables are read into memory, and the
      file is closed again before the next one is opened. The per-file results are
      then concatenated along a new dimension 'files' whose coordinate values are
      the file paths themselves.
      
    Input arguments:
      file_paths: (a list) file paths
      variables : (a list) variable names to read from every file

    Return:
      a Xarray DataSet (files, [variable dimensions, e.g. time,lat,lon])

    Raises:
      ValueError: file_paths is not a list

    Example:
      import yhc_module as yhc

      variables = ['var1','var2','var3']
      file_paths = ['file1.nc', 'file2.nc']
      
      merged_ds = yhc.read_ds_nfiles_nvars (file_paths, variables)

      merged_ds['var1']                        # get var1 in all files (an DataArray)
      merged_ds['var1'].sel(files='file1.nc')  # get var1 in file1.nc (an DataArray)

    Date created: 2024-02-19
    ----------------------
    """

    func_name = "read_ds_nfiles_nvars"

    #--- file_paths must be a list; a bare string would be looped over character by character
    check_list (file_paths, error_msg=f"ERROR: function [{func_name}], file_paths must be a list!")
    
    #--- read the selected variables of each file
    datasets = []

    for file_path in file_paths:
        # The with-block closes the dataset returned by open_dataset itself.
        # The selection is loaded first so its values stay available after the close.
        with xr.open_dataset(file_path) as ds_file:
            ds = ds_file[variables].load()
    
        # Add a new dimension 'files', labelled by the file path
        datasets.append(ds.expand_dims(files=[file_path]))

    #--- concatenate all files along the new dimension 'files'
    merged_ds = xr.concat(datasets, dim='files')

    return merged_ds
    
#-----------
# do_test
#-----------

# Manual test switch for read_ds_nfiles_nvars; left False so the block below is skipped.
#do_test=True
do_test=False

if (do_test):
    variables = ['cldarea_total_daynight_mon', 'solar_mon', 'toa_lw_clr_c_mon']
    #variables = 'solar_mon'
    # the same test file is listed twice to exercise the multi-file path
    file_paths = ["../data_test/CERES_EBAF-TOA_Ed4.2-test.nc", "../data_test/CERES_EBAF-TOA_Ed4.2-test.nc"]  # Add all your file names

    merged_ds = read_ds_nfiles_nvars(file_paths, variables)

    # select one variable for one file by its 'files' label (result is not stored or printed)
    merged_ds['cldarea_total_daynight_mon'].sel(files="../data_test/CERES_EBAF-TOA_Ed4.2-test.nc")

## (Not tested) get_nc_vars_files

In [None]:
########################
########################
def check_dim(input_variable, error_msg):
    """
    Require input_variable to have exactly one element.

    Input arguments:
      input_variable: any object supporting len()
      error_msg     : text of the error raised when the check fails

    Return:
      None

    Raises:
      ValueError(error_msg): len(input_variable) is not 1
    """
    #--- exactly one element: accepted
    if len(input_variable) == 1:
        return

    raise ValueError(error_msg)

########################
########################
def get_nc_vars_files(files, varnames, opt): 
    
    """
    ----------------------
    Description:
      Read variables from netCDF files and combine them into a single xarray DataArray.

      opt = "1var_nfile": read the one variable listed in varnames from every file in
                          files and concatenate along a new dimension 'files'.
                          The result is squeezed, so every length-1 dimension is dropped.
      opt = "nvar_1file": read every variable in varnames from the one file in files
                          and concatenate along a new dimension 'vars'.

    Input arguments:
      files   : a list of file paths    (must have length 1 when opt = "nvar_1file")
      varnames: a list of variable names (must have length 1 when opt = "1var_nfile")
      opt     : a string, "1var_nfile" or "nvar_1file"

    Return:
      var_all: a xarray DataArray. The inputs are recorded in its attributes:
               'files' and 'varname' (1var_nfile) or 'files' and 'varnames' (nvar_1file)

    Raises:
      ValueError: opt is not supported, the length requirement above is not met,
                  or varnames is empty

    Example:
      import yhc_module as yhc

      files = ['file1.nc', 'file2.nc']
      var_1v_nf = yhc.get_nc_vars_files(files, ['var1'], "1var_nfile")

      varnames = ['var1', 'var2', 'var3']
      var_nv_1f = yhc.get_nc_vars_files(['file1.nc'], varnames, "nvar_1file")

    Date created: 2024-02-19
    ----------------------
    """

    func_name = "get_nc_vars_files"

    #-------------------------
    #--- Read a variable from multiple files, and combine into a single xarray, var (files, ...)
    #-------------------------
    if (opt == "1var_nfile"):

        #--- check whether varnames has length=1
        check_dim (varnames, error_msg=f"opt=[{opt}], varnames must be a list have length 1")

        #--- read; load the values while each file is open so it can be closed right away
        dataarrays = []

        for file in files:
            with xr.open_dataset(file) as ds:
                dataarrays.append(ds[varnames].load())

        # Concatenate along a new dimension 'files', then turn the Dataset into a DataArray
        var_all_ds = xr.concat(dataarrays, dim='files')
        var_all = var_all_ds.to_array(dim='variables').squeeze()
        var_all.attrs['files']=files
        var_all.attrs['varname']=varnames

    #-------------------------
    #--- Read multiple variables from a file, and combine into a single xarray, var (vars, ...)
    #-------------------------
    elif (opt == "nvar_1file"):

        #--- check whether files has length=1
        check_dim (files, error_msg=f"opt=[{opt}], files must be a list and have length 1")

        #--- without any variable name there is nothing to concatenate
        if len(varnames) == 0:
            raise ValueError(f"opt=[{opt}], varnames must contain at least one variable name")

        #--- read every variable while the file is open
        with xr.open_dataset(files[0]) as ds:
            dataarrays = [ds[varname].load() for varname in varnames]

        #--- concatenate once, after all variables have been collected
        var_all = xr.concat(dataarrays, dim='vars')
        var_all.attrs['varnames']=varnames
        var_all.attrs['files']=files

    else:
        opt_avail = ['1var_nfile','nvar_1file']
        error_msg = f"ERROR: function [{func_name}] does not support [{opt}]. Available: {opt_avail}"
        raise ValueError(error_msg)

    #--- return variables
    return var_all
    
#-----------
# do_test
#-----------

# Manual test switches for get_nc_vars_files. Only do_test is checked below;
# do_test1 is assigned but not used in this cell.
do_test=False
do_test1=False

if (do_test):
    # the same test file is listed twice to exercise the multi-file option
    files = ["../data_test/CERES_EBAF-TOA_Ed4.2-test.nc",
             "../data_test/CERES_EBAF-TOA_Ed4.2-test.nc",
            ]
    varname1 = ['cldarea_total_daynight_mon']
    
    # one variable, several files
    opt = "1var_nfile"
    var_1v_nf = get_nc_vars_files(files, varname1, opt)

    #----------
    # several variables, one file
    opt = "nvar_1file"
    file1 = ["../data_test/CERES_EBAF-TOA_Ed4.2-test.nc"]
    varnames = ['cldarea_total_daynight_mon', 'solar_mon', 'toa_lw_clr_c_mon']
    var_nv_1f = get_nc_vars_files(file1, varnames, opt)

#var_nv_1f
#var_1v_nf


#tt=1 ; jj=123 ; ii=67
#print(var_1v_nf[0, tt, jj, ii].values)
#print(var_1v_nf[1, tt, jj, ii].values)


## print_1d_arrays

In [None]:
def print_1d_arrays(*arrays, 
                   coord = None, do_print=False):
    """
    ----------------------
    Description:
      Print several 1-D arrays side by side as a tab-separated table:
      a running index, the coordinate value, then one column per array.

      This code is written by ChatGPT.

    Input arguments:
      *arrays : multiple 1-D arrays, e.g. (data1, data2, ...)
                All of them MUST have the same coordinate.
      coord   : a string, coordinate name. If not given, the first coordinate of
                each array is used and must be the same for all arrays.
      do_print: a logical variable. True: only print a code snippet that helps to
                build the *arrays arguments, then return without printing a table.

    Return:
      print out array values on screen (returns None)

    Raises:
      ValueError: the arrays do not share the same first coordinate (coord not given),
                  or the coordinate is not a dimension of the first array

    Example:
      import yhc_module as yhc
      yhc.print_1d_arrays(data1, data2, ..., coord="lev")

        Index   time    Array1  Array2
        1       11      10      a
        2       21      20      b
        3       31      30      c

    Date created: 2024-04-17
    ----------------------
    """

    func_name = "print_1d_arrays"

    #--- helper mode: show the snippet and stop
    if do_print:
        text = """
#--- use this code to get variables
ds_name="ds"
varnames = ['a','b','c']
dim_name = "[0,kk,0,0]"
text_all = ', '.join([f"{ds_name}.{var}{dim_name}" for var in varnames])
print(text_all)
       """
        print(text)
        return

    #--- decide which coordinate labels the rows
    if coord is not None:
        ### coord is given: use it as is
        coord_name = coord
    else:
        ### coord is not given: take the first coordinate and require all arrays to agree
        coord_name = list(arrays[0].coords.keys())[0]
        if any(list(other.coords.keys())[0] != coord_name for other in arrays[1:]):
            raise ValueError("All arrays must have the same coordinate name.")

    #--- header line; unnamed arrays get a positional label
    column_titles = [arr.name or f"Data{pos+1}" for pos, arr in enumerate(arrays)]
    print(f"Index\t{coord_name}\t" + "\t".join(column_titles))

    #--- the coordinate must be a dimension of the first array
    if coord_name not in arrays[0].dims:
        raise ValueError(f"ERROR: func [{func_name}], coordinate [{coord_name}] is not in the input arrays. set coord as an input argument")
    coordinates = arrays[0].coords[coord_name].values

    #--- one line per coordinate value
    for row, coordinate_value in enumerate(coordinates):
        label = "" if coordinate_value is None else str(coordinate_value)
        cells = [f"{row+1}\t{label:<5}"]
        cells.extend(str(arr[row].item()) for arr in arrays)
        print("\t".join(cells))

#-----------
# do_test
#-----------

# Manual test switch for print_1d_arrays; left False so the example below is skipped.
do_test=False
#do_test=True

if (do_test):
    # Example usage
    time = [11, 21, 31, 41, 51]
    data1 = [10, 20, 30, 40, 50]
    data2 = ['a', 'b', 'c', 'd', 'e']

    # two named 1-D arrays that share the same 'time' coordinate
    array1 = xr.DataArray(data1, dims=('time',), coords={'time': time}, name="Array1")
    array2 = xr.DataArray(data2, dims=('time',), coords={'time': time}, name="Array2")

    # coord is not given, so the first coordinate of the arrays is used
    print_1d_arrays(array1, array2)
    #print_1d_arrays(array1, array2, coord='vdd', do_print=True)

    # raw values, for comparison with the table printed above
    print(array1.values)
    print(array2.values)
    

## plot_ds_var_profiles

In [None]:
#####################
#####################
#####################
def set_profile_varnames (opt):
    """
    ----------------------
    Description:
       Return a dictionary variable that can be used in the function, plot_ds_var_profiles

       Each key is a variable name and each value is a style dictionary with
       'color' and 'label' (the "sum of all phys tend" entries also carry 'linestyle').
       The entries are returned in the order they should be plotted.

    Input arguments:
       opt (str): option. Currently supported: [DT_budget, DQV_budget, DQL_budget, DQI_budget]

    Return:
       varnames (a dict)

    Raises:
       ValueError: opt is not one of the supported options

    Example:
      import yhc_module as yhc
      varnames = set_profile_varnames('DT_budget')

    Date created: 2024-03-20
    ----------------------
    """

    func_name = "set_profile_varnames"

    #--- colors for each process (the same process keeps the same color in every budget)
    color_tot = 'black'
    color_dyn = 'darkgray'
    color_PT = 'red'
    color_sum = color_PT        # the diagnosed sum shares the color of the model's total physics tendency
    color_ZM = 'blue'
    color_CMF = 'cyan'
    color_DPDLF = 'royalblue'
    color_SHDLF = 'skyblue'
    color_MACP = 'green'
    color_MP = 'limegreen'
    color_QRL = 'orange'
    color_QRS = 'pink'
    color_vdiff = 'purple'
    color_gwd = 'yellow'

    #--- one entry per supported option
    budgets = {
        #@@@@@@@@@@@@@@@@@
        "DT_budget": {
            'TTEND_TOT': {'color':color_tot, 'label': 'Total=dyn+phys [TTEND_TOT]'},
            'DTCORE': {'color':color_dyn, 'label': 'Dynamics [DTCORE]'},
            'PTTEND': {'color':color_PT, 'label': 'All phys [PTTEND]'},
            'DT_deep': {'color': color_ZM, 'label':'Deep convection [ZMDT+ZMMTT+EVAPTZM]'},
            'CMFDT': {'color': color_CMF, 'label': 'Shallow convection [CMFDT]'},
            'DPDLFT': {'color':color_DPDLF, 'label': 'Deep conv detrain [DPDLFT]'},
            'SHDLFT': {'color':color_SHDLF, 'label':'Shallow conv detrain [SHDLFT]'},
            'MACPDT': {'color':color_MACP, 'label':'Macrophysics [MACPDT]'},
            # label used to read [MPCT]; every label quotes its own variable name
            'MPDT': {'color':color_MP, 'label':'Microphysics [MPDT]'},
            'QRL': {'color':color_QRL, 'label': 'LW radiation [QRL]'},
            'QRS': {'color':color_QRS, 'label':'SW radiation [QRS]'},
            'DTV': {'color':color_vdiff, 'label': 'PBL & turbulence [DTV]'},
            'TTGWORO': {'color':color_gwd, 'label':'grav wave drag [TTGWORO]'},
            'DT_SUM_PHYS': {'color':color_sum, 'label': 'Sum of all phys tend', 'linestyle':'--'},
        },

        #@@@@@@@@@@@@@@@@@
        "DQV_budget": {
            'PTEQ': {'color':color_PT, 'label': 'All phys [PTEQ]'},
            'DQV_deep': {'color': color_ZM, 'label':'Deep convection [ZMDQ+EVAPQZM]'},
            'CMFDQ': {'color': color_CMF, 'label': 'Shallow convection [CMFDQ]'},
            'MACPDQ': {'color':color_MACP, 'label':'Macrophysics [MACPDQ]'},
            'MPDQ': {'color':color_MP, 'label':'Microphysics [MPDQ]'},
            'VD01': {'color':color_vdiff, 'label': 'PBL & turbulence [VD01]'},
            'DQV_SUM_PHYS': {'color':color_sum, 'label': 'Sum of all phys tend', 'linestyle':'--'},
        },

        #@@@@@@@@@@@@@@@@@
        "DQL_budget": {
            'PTECLDLIQ': {'color':color_PT, 'label': 'All phys [PTECLDLIQ]'},
            'ZMDLIQ': {'color': color_ZM, 'label':'Deep convection [ZMDLIQ]'},
            'CMFDLIQ': {'color': color_CMF, 'label': 'Shallow convection [CMFDLIQ]'},
            'DPDLFLIQ': {'color':color_DPDLF, 'label': 'Deep conv detrain [DPDLFLIQ]'},
            'SHDLFLIQ': {'color':color_SHDLF, 'label':'Shallow conv detrain [SHDLFLIQ]'},
            'MACPDLIQ': {'color':color_MACP, 'label':'Macrophysics [MACPDLIQ]'},
            'MPDLIQ': {'color':color_MP, 'label':'Microphysics [MPDLIQ]'},
            'VDCLDLIQ': {'color':color_vdiff, 'label': 'PBL & turbulence [VDCLDLIQ]'},
            'DQL_SUM_PHYS': {'color':color_sum, 'label': 'Sum of all phys tend', 'linestyle':'--'},
        },

        #@@@@@@@@@@@@@@@@@
        "DQI_budget": {
            'PTECLDICE': {'color':color_PT, 'label': 'All phys [PTECLDICE]'},
            'ZMDICE': {'color': color_ZM, 'label':'Deep convection [ZMDICE]'},
            'CMFDICE': {'color': color_CMF, 'label': 'Shallow convection [CMFDICE]'},
            'DPDLFICE': {'color':color_DPDLF, 'label': 'Deep conv detrain [DPDLFICE]'},
            'SHDLFICE': {'color':color_SHDLF, 'label':'Shallow conv detrain [SHDLFICE]'},
            'MACPDICE': {'color':color_MACP, 'label':'Macrophysics [MACPDICE]'},
            'MPDICE': {'color':color_MP, 'label':'Microphysics [MPDICE]'},
            'VDCLDICE': {'color':color_vdiff, 'label': 'PBL & turbulence [VDCLDICE]'},
            'DQI_SUM_PHYS': {'color':color_sum, 'label': 'Sum of all phys tend', 'linestyle':'--'},
        },
    }

    #@@@@@@@@@@@@@@@@@
    # the isinstance test keeps non-string (e.g. unhashable) input on the ValueError path
    if not isinstance(opt, str) or opt not in budgets:
        error_msg = f"ERROR: function [{func_name}] does not support [{opt}]."
        raise ValueError(error_msg)

    varnames = budgets[opt]

    return varnames

#####################
#####################
#####################
def plot_ds_var_profiles(ds, varnames,
                      tt=5, jj=0, ii=0, 
                      title = "TaiESM1 SCM TWPICE", xlabel = "VAR", yvar = 'lev', ylabel = "Nominal Pressure (hPa)",
                      do_units=True):
    """
    ----------------------

    Plot variables from an xarray dataset with yvar on the Y axis.

    Every variable in varnames is obtained through yhc.diagnose_var, sampled at one
    (time, lat, lon) index and drawn as one line in the same figure. The Y axis is
    inverted. The units shown in the labels are those of the last variable plotted.

    Parameters:
    - ds (xarray.Dataset): The dataset containing the variables.
    - varnames (dict): A dictionary mapping variable names to styles, including colors, linestyles, and labels.
                       Missing style keys fall back to: label = variable name, color = 'blue', linestyle = '-'.
    - tt, jj, ii: indexes for time, lat, lon, respectively
    - title (str): main title; the time step is appended to it
    - xlabel (str): X-axis label; also shown as the left sub-title
    - yvar (str): name of the variable in ds used for the Y axis
    - ylabel (str): Y-axis label
    - do_units (bool): True: pass each variable through yhc.unit_convert before plotting
    
    Returns:
    - A plot

    Raises:
    - ValueError: varnames is empty

    Example:
      import yhc_module as yhc
      
      ds_scm = xr.open_dataset(file_scm)
      varnames = set_profile_varnames('DT_budget')
      tt=8
      plot_ds_var_profiles (ds_scm, varnames, tt=tt, xlabel="Temperature tendency")

      jj=0
      ii=0
      yhc.print_1d_arrays(
          ds_scm.ZMDT[tt,:,jj,ii],
          ds_scm.CMFDT[tt,:,jj,ii],
                         )
    Date created: 2024-03-21
    ----------------------
    """

    func_name = "plot_ds_var_profiles"

    #--- without any variable there is no line to draw and no units for the labels
    if not varnames:
        raise ValueError(f"ERROR: function [{func_name}], varnames is empty, nothing to plot.")

    #--- defaults shared by all lines
    marker_default = 'o'
    linestyle_default = '-'

    # Plot each variable
    for var_name, style in varnames.items():

        #--- get variable
        ds = yhc.diagnose_var(ds, [var_name])  # pass var_name as a list, otherwise it will read the first character
        var_data = ds[var_name].isel(time=tt, lat=jj, lon=ii)
        if (do_units): var_data = yhc.unit_convert(var_data)
        units = var_data.attrs['units']
        
        #--- lev
        lev = ds[yvar]

        # If 'label' is not provided in styles, set label as var_name
        label = style.get('label', var_name)
        
        # Plot the variable with style
        plt.plot(var_data, lev, label=label, color=style.get('color', 'blue'), linestyle=style.get('linestyle', linestyle_default), marker=marker_default, markersize=4)

    # Add labels and legend

    #--- title and x & y labels
    plt.xlabel(f"{xlabel} ({units})")
    plt.ylabel(ylabel)
      #lat_formatted = "{:.2f}".format(ds['lat'][jj].values)
      #lon_formatted = "{:.2f}".format(ds['lon'][ii].values)
      #plt.title(f"{title}, time step = {tt}, lat={lat_formatted}, lon={lon_formatted}")
    #--- three titles: main one in the centre (raised), xlabel on the left, units on the right
    plt.title(f"{title}, time step = {tt}", y=1.1)
    plt.title(f"{xlabel}", loc='left', y=1.001, fontsize = 10)
    plt.title(f"{units}", loc='right', y=1.001, fontsize = 10)

    #--- legend, placed outside the axes on the right
    plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1))

    #--- grids and axises
    plt.grid(True, color='gray', linestyle='--')
    plt.gca().invert_yaxis()
    plt.minorticks_on()

    plt.show()

#----------
# do test
#----------

# Manual test switch for plot_ds_var_profiles. The block below runs only when
# do_test is exactly "test1_scm"; the current value "xxx" skips it.
#do_test="test1_scm"
do_test="xxx"

if (do_test == "test1_scm"):

    # open the test file
    datapath = "/lfs/home/yihsuanc/test/c1-scam-tendencies/"
    filename_scm = "camrun.cam.h0.2006-01-17-10800.nc"
    file_scm = datapath+"/"+filename_scm
    ds_scm = xr.open_dataset(file_scm)

    # plot the temperature budget terms at time index tt
    varnames = set_profile_varnames('DT_budget')
    tt=8
    plot_ds_var_profiles (ds_scm, varnames, tt=tt, xlabel="Temperature tendency")

    # print two of the plotted profiles at the same time/lat/lon indexes
    jj=0
    ii=0
    yhc.print_1d_arrays(
        ds_scm.ZMDT[tt,:,jj,ii],
        ds_scm.CMFDT[tt,:,jj,ii],
                       )