In [1]:
import os
from tqdm import tqdm
import xarray as xa
import pandas as pd
import numpy as np
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import matplotlib.colors as mpcrs
from matplotlib import ticker
from pylibs.custom_colormaps import white_gist_earth
from pylibs.plot_utils import set_size, setupax_2dmap
from pylibs.utils import setup_cmap, get_dates
from joblib import Parallel, delayed

In [2]:
# --- Analysis period --------------------------------------------------------
# Start/end stamps look like YYYYMMDDHH; `date_interval` is presumably in
# hours (TODO confirm against pylibs.utils.get_dates).
sdate, edate = 2024110100, 2024113018
date_interval = 6
dates = get_dates(sdate, edate, date_interval)

# --- What is plotted --------------------------------------------------------
bkg = 'MERRA-2'                    # background name shown in axis labels
plotvar = 'aerosolOpticalDepth'    # variable read from the ObsValue/hofx groups
plot_ch = 4                        # Channel value selected for plotting

# --- Switches ---------------------------------------------------------------
pltcyc = False          # per-cycle global scatter maps
pltstats = True         # obs-vs-HofX 2-D histograms over the whole period
all_omb_stats = False   # OmB histograms returned by each worker

# --- Figure output ----------------------------------------------------------
fsave = 1        # truthy -> figures are written to disk
quality = 600    # dpi passed to savefig

In [3]:
# control PDF, boxplot
binwidth = 0.1
halfbin = binwidth/2
binmax = 2.

# Bin edges for the obs-AOD stratification: 0, binwidth, ..., binmax.
bins = np.arange(0., binmax + halfbin, binwidth)
# Labels are the midpoints of consecutive edges. Deriving them from `bins`
# (instead of a second float-step arange) guarantees exactly one label per
# bin, which pd.cut(labels=...) requires.
bin_lb = [round(center, 2) for center in (bins[:-1] + bins[1:]) / 2]
bin_lb_str = [f'{center:.2f}' for center in bin_lb]

# Edges for the zero-centred (OmB) histogram; the middle bin straddles zero.
zero_c_bins = np.arange(-binmax + halfbin, binmax + halfbin, binwidth)
zero_c_centers = (zero_c_bins[:-1] + zero_c_bins[1:]) / 2
# Snap floating-point residue around zero to exactly 0 so the label is
# formatted as '0.00' rather than '-0.00'.
zero_c_centers = np.where(np.abs(zero_c_centers) < 1e-10, 0., zero_c_centers)
zero_c_bin_lb = [round(center, 2) for center in zero_c_centers]
zero_c_bin_lb_str = [f'{center:.2f}' for center in zero_c_bin_lb]

# Control hist2d
hist2d_in_log = 1
hist2d_xmax = 5.0
hist2d_xybins = 30
if hist2d_in_log:
    # Edges evenly spaced in log space between 0.01 and hist2d_xmax.
    h2d_axis = np.exp(
        np.linspace(np.log(0.01), np.log(hist2d_xmax), hist2d_xybins + 1)
    )
else:
    h2d_axis = np.linspace(0.01, hist2d_xmax, hist2d_xybins + 1)

In [16]:
# Root directory holding one sub-directory of hofx files per observation type.
hofx_path = '/glade/work/swei/Git/JEDI-METplus/output/aodobs_merra2/hofx/f00'

# Observation types to process. Every entry is currently commented out, so the
# list is empty; uncomment the ones to run.
obs_name_list = [
    # 'modis_terra_aod', 'modis_aqua_aod',
    # 'pace_aod',
    # 'viirs_aod_dt_npp', 'viirs_aod_dt_n20',
    # 'viirs_aod_db_npp', 'viirs_aod_db_n20',
]

# Short variable key -> display name used in plot titles.
vardict = dict(obs='Obs', omb='OmB', hfx='HofX')

# Wavelength table indexed by (plot_ch - 1) for observation names containing
# 'aeronet'. Presumably in nm, matching the 'nm' suffix used in file names.
aeronet_aod_wvl = [
    340., 380., 440., 500.,
    675, 870., 1020., 1640.,
]

In [5]:
# Output locations: period-wide statistics go to `stat_savedir`; per-cycle
# maps go to `cycs_savedir/<variable>/<obsname>`.
savedir = '/glade/work/swei/projects/mmm.pace_aod/plots'
cycs_savedir = f'{savedir}/cycles'
stat_savedir = f'{savedir}/stats'

# exist_ok=True creates missing parents and is a no-op when the directory is
# already there, so no separate (race-prone) existence check is needed.
os.makedirs(stat_savedir, exist_ok=True)
for varType in vardict:
    for obsname in obs_name_list:
        cycdir = f'{cycs_savedir}/{varType}/{obsname}'
        os.makedirs(cycdir, exist_ok=True)

In [11]:
def plt_hist2d(dataframe, x, y, axis, save, savename, **kwargs):
    """Plot a 2-D histogram of two columns with summary statistics.

    The joint distribution of ``dataframe[x]`` and ``dataframe[y]`` is binned
    with ``np.histogram2d`` and drawn as filled contours together with a
    dashed 1:1 line. Sample count, bias, relative bias, R and R² are
    annotated in the upper-left corner of the axes.

    Parameters
    ----------
    dataframe : object indexable by column name (a pandas DataFrame here)
    x, y : str
        Column names; ``x`` is drawn on the horizontal axis.
    axis : array-like
        Bin edges used for both dimensions.
    save : bool-like
        When truthy the figure is written to ``savename``.
    savename : str
        Output file path.
    **kwargs
        ``xlb`` / ``ylb``: axis labels (default: the column names).

    Notes
    -----
    Reads the module-level settings ``hist2d_xmax``, ``hist2d_in_log`` and
    ``quality`` (dpi). Bias is ``mean(y) - mean(x)`` and relative bias is
    that value divided by ``mean(x)``.
    """
    x_data = dataframe[x]
    y_data = dataframe[y]
    xlbstr = kwargs.get('xlb', x)
    ylbstr = kwargs.get('ylb', y)
    axis_min = 0.01  # lower limit of both axes

    hist2d, _, _ = np.histogram2d(x_data, y_data, bins=axis)

    # Contour levels must be strictly increasing. With no samples inside the
    # bins the peak count is 0, so floor it at 1 to keep the levels valid.
    peak_count = max(hist2d.max(), 1.0)
    cnlvs = np.linspace(0, peak_count, 256)
    clrnorm = mpcrs.BoundaryNorm(cnlvs, len(cnlvs), extend='max')

    fig, ax = plt.subplots()
    try:
        set_size(5, 5, b=0.1, l=0.1, r=0.95, t=0.95)
        # histogram2d returns counts as [x, y]; contourf expects [y, x].
        # NOTE(review): counts are positioned at the left bin edges
        # (axis[:-1]), not at bin centres -- confirm this is intended.
        cn = ax.contourf(axis[:-1], axis[:-1], hist2d.T,
                         levels=cnlvs, norm=clrnorm, cmap=white_gist_earth,
                         extend='max')
        # 1:1 reference line. It starts at the axis lower limit rather than
        # 0 because 0 cannot be represented on a log axis.
        ax.plot(
            [axis_min, hist2d_xmax],
            [axis_min, hist2d_xmax],
            color='gray',
            linewidth=2,
            linestyle='--',
        )
        ax.set_xlim(axis_min, hist2d_xmax)
        ax.set_ylim(axis_min, hist2d_xmax)

        if hist2d_in_log:
            ax.set_xscale('log')
            ax.set_yscale('log')
        ax.set_aspect('equal')

        ax.grid(alpha=0.5)
        ax.set_xlabel(xlbstr, fontsize=11)
        ax.set_ylabel(ylbstr, fontsize=11)
        ax.tick_params(axis='both', labelsize=10)

        correlation_xy = np.corrcoef(x_data, y_data)[0, 1]
        r_squared = correlation_xy ** 2
        bias = np.mean(y_data) - np.mean(x_data)
        rbias = bias / np.mean(x_data)
        ssize = len(x_data)

        stats_dict = {
            'Counts': f'{ssize:.0f}',
            'Absolute Bias': f'{bias:.3f}',
            'Relative Bias': f'{rbias:.3f}',
            'R': f'{correlation_xy:.3f}',
            'R\u00b2': f'{r_squared:.3f}',
        }
        # Stack the statistics downwards from the top-left corner.
        x_pos = 0.012
        y_pos = 1.02
        for key, value in stats_dict.items():
            y_pos = y_pos - 0.05
            ax.annotate(f'{key}= {value}', (x_pos, y_pos),
                        ha='left', va='center',
                        fontsize=12, xycoords='axes fraction')

        cb = fig.colorbar(cn, ax=ax, orientation='horizontal', fraction=0.03,
                          aspect=30, pad=0.12, extend='max',
                          ticks=cnlvs[::50])
        cb.ax.minorticks_off()
        cb.ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0),
                               useMathText=True)

        if save:
            fig.savefig(savename, dpi=quality)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

In [25]:
# for obsname in obs_name_list:
def process_obsname(obsname):
    """Collect obs / HofX / OmB for one observation type and plot its statistics.

    For every cycle in the module-level ``dates`` the file
    ``{hofx_path}/{obsname}/hofx.{obsname}.{YYYYMMDDHH}.nc4`` is read and the
    ``plotvar`` values at channel ``plot_ch`` are gathered along a ``locs``
    dimension. Depending on the module-level switches the function then

    * ``pltcyc``: draws one global scatter map per cycle and variable,
    * ``pltstats``: draws obs-vs-HofX 2-D histograms for each mask,
    * ``all_omb_stats``: computes density-normalised OmB histograms.

    Masks are ``all`` (non-missing obs) plus, for observation names that do
    not contain ``'aeronet'``, ``water`` (surfaceQualifier == 0) and ``land``
    (surfaceQualifier != 0).

    Parameters
    ----------
    obsname : str
        Observation type; also the sub-directory and file-name component.

    Returns
    -------
    dict
        ``{(mask_name, obsname): histogram}``; empty unless ``all_omb_stats``.

    Raises
    ------
    ValueError
        If ``dates`` yields no cycles.
    """
    print(f'Processing {obsname}')
    is_aeronet = 'aeronet' in obsname
    wvl = None
    cycle_datasets = []

    # Wrapping `dates` directly (not enumerate) lets tqdm show the total.
    for cdate in tqdm(dates):
        cdate_str = cdate.strftime('%Y%m%d%H')
        hofxfile = f'{hofx_path}/{obsname}/hofx.{obsname}.{cdate_str}.nc4'

        # Each group is opened in a `with` block and read into memory via
        # `.values` so that no file handle outlives the iteration.
        with xa.open_dataset(hofxfile) as dims_ds:
            channel = dims_ds.Channel.values.astype(np.int32)

        with xa.open_dataset(hofxfile, group='MetaData') as meta_ds:
            lats = meta_ds.latitude.values
            lons = meta_ds.longitude.values
            if is_aeronet:
                wvl = round(aeronet_aod_wvl[plot_ch - 1])
            else:
                # float() turns the selected element into a plain number so
                # `wvl` is an int that formats cleanly in file names/labels.
                wvl = round(
                    float(meta_ds.sensorCentralWavelength.sel(Channel=plot_ch))
                    * 1e3
                )
                lsfs = meta_ds.surfaceQualifier.values

        with xa.open_dataset(hofxfile, group='ObsValue') as obsv_ds:
            obsv_data = (obsv_ds.assign_coords(Channel=channel)[plotvar]
                         .sel(Channel=plot_ch).values)
        with xa.open_dataset(hofxfile, group='hofx') as hofx_ds:
            hofx_data = (hofx_ds.assign_coords(Channel=channel)[plotvar]
                         .sel(Channel=plot_ch).values)
        ombs_data = obsv_data - hofx_data

        data_dict = {
            'omb': (['locs'], ombs_data),
            'obs': (['locs'], obsv_data),
            'hfx': (['locs'], hofx_data),
            'lat': (['locs'], lats),
            'lon': (['locs'], lons),
        }
        if not is_aeronet:
            data_dict['lsf'] = (['locs'], lsfs)

        coord_dict = {'locs': range(ombs_data.size)}
        tmpds = xa.Dataset(data_dict, coords=coord_dict)
        if pltcyc:
            # NOTE(review): plt_glb_scatter is not defined in the visible
            # cells -- confirm it is available before enabling `pltcyc`.
            for var in ['obs', 'omb', 'hfx']:
                titlestr = f'{vardict[var]} of {obsname} at {cdate_str}'
                sc2dplt = f'{cycs_savedir}/{var}/{obsname}/{var}2d.{obsname}.{cdate_str}.png'
                plt_glb_scatter(tmpds, var, titlestr, fsave, sc2dplt)

        cycle_datasets.append(tmpds)

    if not cycle_datasets:
        raise ValueError(f'No cycles to process for {obsname}: `dates` is empty')

    # One concat over all cycles instead of re-copying the growing dataset
    # on every iteration.
    pltds = xa.concat(cycle_datasets, dim='locs')

    msking_dict = {
        'all': pltds.obs.notnull().data,
    }
    if not is_aeronet:
        # Plain boolean arrays: a trailing comma here would wrap each mask in
        # a 1-tuple and break the boolean selection below.
        # NOTE(review): unlike 'all', these masks do not drop missing obs --
        # confirm whether that is intended.
        msking_dict['water'] = pltds.lsf.data == 0
        msking_dict['land'] = pltds.lsf.data != 0

    # stratified based on obs aod
    if pltstats:
        for key, mask in tqdm(msking_dict.items()):
            df = pltds.sel(locs=mask).to_dataframe()
            df['obs_bin'] = pd.cut(df['obs'], bins=bins, labels=bin_lb, right=True)

            hist2dplt = f'{stat_savedir}/{key}.hist2d.{obsname}.{wvl}nm.{sdate}_{edate}.png'
            plt_hist2d(df, 'obs', 'hfx', h2d_axis, fsave, hist2dplt,
                       xlb=f'{obsname} {plotvar} {wvl}nm',
                       ylb=f'{plotvar} {wvl}nm based on {bkg}')

    out_dict = {}
    if all_omb_stats:
        # Calculate histogram and return the dataset for merge
        for key, mask in tqdm(msking_dict.items()):
            out_dict[(key, obsname)], _ = np.histogram(
                pltds['omb'].sel(locs=mask).data,
                density=True,
                bins=zero_c_bins,
            )

    print(f'{obsname} finished')
    return out_dict

In [26]:
# One task per observation type; n_jobs=-1 lets joblib use every available
# core. Each result is the dict returned by process_obsname.
obs_tasks = (delayed(process_obsname)(obs) for obs in obs_name_list)
ombhists = Parallel(n_jobs=-1)(obs_tasks)
print('Process Finished')

Processing aeronet_aod


120it [00:02, 46.27it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

Process Finished


100%|██████████| 1/1 [00:01<00:00,  1.05s/it]


aeronet_aod finished
aeronet_aod finished


In [16]:
if all_omb_stats:
    # Merge the per-worker {(mask, obsname): histogram} dicts into one
    # mapping; later workers overwrite earlier ones on a duplicate key.
    flat_dict = {}
    for worker_hists in ombhists:
        flat_dict.update(worker_hists)
    # Tuple keys become two-level columns; rows are labelled with the
    # zero-centred bin labels.
    df = pd.DataFrame(flat_dict)
    df.index = zero_c_bin_lb

In [53]:
# `df` only exists when the OmB histograms were requested, so this cell must
# be a no-op otherwise (it would raise NameError on a fresh Run All).
# NOTE(review): plt_all_omb_hist is not defined in the visible cells --
# confirm it is defined or imported before enabling `all_omb_stats`.
if all_omb_stats:
    # 'water' and 'land' are only produced for observation names that do not
    # contain 'aeronet'; skip any mask that has no column in `df`.
    available_masks = set(df.columns.get_level_values(0))
    for key in ['all', 'water', 'land']:
        if key not in available_masks:
            continue
        ombhistplt = f'{stat_savedir}/allombhist.{key}.{sdate}_{edate}.png'
        plt_all_omb_hist(key.capitalize(), df[key], fsave, ombhistplt)

In [47]:
# Scratch check: display the statistics output directory.
stat_savedir

'/glade/work/swei/projects/mmm.pace_aod/plots/stats'

In [49]:
# Scratch value used by the next cell to try out str.capitalize().
a = 'all'

In [51]:
# Scratch check: capitalize() upper-cases the first character ('all' -> 'All').
a.capitalize()

'All'