In [None]:
import sys
print(sys.executable)

In [None]:
import xarray as xr
import dask
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from eofs.xarray import Eof
import time

In [None]:
import sys
import warnings

# Silence every warning for the rest of the session, unless the interpreter was started
# with explicit warning options (sys.warnoptions non-empty), in which case those are kept.
if not sys.warnoptions:
    warnings.simplefilter("ignore")

# Read data

In [None]:
# Open all WWLLN files matching the glob as one lazily-loaded, dask-chunked dataset and keep
# only the lat 32..42 / lon -120..-110 box.
WWLLN_dataset = xr.open_mfdataset('/home/disk/eos12/wycheng/data/US/WWLLN/WWLLN_*_F_cg_1deg3hr_US.nc',
                                  parallel=True,
                                  chunks={'Time':'auto', 'lat':'auto', 'lon':'auto'}).sel(lat=slice(32,42), lon=slice(-120,-110))
# Rescale F by (365.25*8) / 111.19492664455873**2.
# NOTE(review): presumably converts strokes per 1-degree cell per 3 h into strokes km^-2 yr^-1
# (111.19 km per degree, 8 three-hour steps per day); no cos(lat) cell-area correction is
# applied here -- confirm this is intended.
WWLLN_dataset['F'] = (1/((111.19492664455873)**2)) * (365.25*8) * WWLLN_dataset['F']
display(WWLLN_dataset)

In [None]:
# Open all TRMM precipitation files matching the glob as one lazily-loaded, dask-chunked
# dataset, restricted to the same lat/lon box as the lightning data.
TRMM_dataset = xr.open_mfdataset('/home/disk/eos12/wycheng/data/US/TRMM/TRMM_*_pcp_cg_1deg3hr_US.nc',
                                  parallel=True,
                                  chunks={'Time':'auto', 'lat':'auto', 'lon':'auto'}).sel(lat=slice(32,42), lon=slice(-120,-110))
# Keep strictly positive precipitation; everything else (negative values, zeros and NaN,
# because NaN > 0 is False) is replaced by 0.
TRMM_dataset['pcp'] = TRMM_dataset['pcp'].where(TRMM_dataset['pcp']>0,0)
display(TRMM_dataset)

In [None]:
# Single-file ERA5 CAPE dataset, cut to the same lat/lon box as the other two sources.
ERA5_cape_dataset = xr.open_dataset('/home/disk/eos12/wycheng/data/US/ERA5/dataset/ERA5_cape_dataset.nc').sel(lat=slice(32,42), lon=slice(-120,-110))
display(ERA5_cape_dataset)

In [None]:
# Combine lightning (F), precipitation (pcp) and CAPE (cape) on their shared coordinates and
# keep 2010-01-01 .. 2019-12-31.
dataset_raw = xr.merge([WWLLN_dataset, TRMM_dataset, ERA5_cape_dataset]).sel(Time=slice("2010-01-01", "2019-12-31"))
# CP = CAPE x precipitation, the predictor product used by the R14 model below.
dataset_raw['CP'] = dataset_raw['cape'] * dataset_raw['pcp']
# persist() evaluates the dask graph once and keeps the results in memory, so later cells
# do not re-read the source files.
dataset_raw = dataset_raw.persist()
display(dataset_raw)

In [None]:
#dataset.to_netcdf(path='/home/disk/eos12/wycheng/data/US/dataset/dataset_test.nc', mode='w')
#dataset = xr.open_dataset('/home/disk/eos12/wycheng/data/US/dataset/dataset_test.nc')

### Normalize data

- remove annual cycle (remove monthly average)
- normalize the input/output data

In [None]:
def normalize(x, m, s):
    """Standardise ``x``: subtract the offset ``m``, then divide by the scale ``s``."""
    anomaly = x - m
    return anomaly / s

In [None]:
def unnormalize(x, m, s):
    """Undo a standardisation: multiply ``x`` by the scale ``s``, then add the offset ``m``."""
    rescaled = x * s
    return rescaled + m

In [None]:
# Standardise every variable per calendar month: each month-group of dataset_raw is passed to
# normalize() together with that month's mean and standard deviation over Time.
# NOTE(review): the statistics are computed over the whole 2010-2019 record, which includes
# the time steps later selected as the test set (idx_test) -- confirm this is acceptable.
dataset_norm = xr.apply_ufunc(normalize,
    dataset_raw.groupby('Time.month'),
    dataset_raw.groupby('Time.month').mean('Time'),
    dataset_raw.groupby('Time.month').std('Time'),
    dask='allowed'
                        )

In [None]:
display(dataset_norm)

# Set country borders

In [None]:
import regionmask
import geopandas as gpd

In [None]:
# Country-boundary shapefile, read into a GeoDataFrame (one row per polygon record).
# NOTE(review): absolute, machine-specific path -- the notebook only runs where it exists.
PATH_TO_SHAPEFILE = '/home/disk/eos10/wycheng/LightningMachineLearning/data/WorldCountriesBoundaries/99bfd9e7-bb42-4728-87b5-07f8c8ac631c2020328-1-1vef4ev.lu5nk.shp'
countries = gpd.read_file(PATH_TO_SHAPEFILE)

In [None]:
# Region number kept by the filter below.
# NOTE(review): presumably the United States row of this shapefile (the lat/lon bounds used
# with it match the contiguous US) -- confirm against countries.CNTRY_NAME.
US_REGION_NUMBER = 232

# One region per shapefile row.  The count is taken from the shapefile itself rather than a
# hard-coded 250, so numbers / names / abbrevs always have the same length as the outlines.
indexes = np.arange(countries.shape[0]).tolist()
countries_mask_poly = regionmask.Regions(name = 'COUNTRY',
                                         numbers = indexes,
                                         names = countries.CNTRY_NAME[indexes],
                                         abbrevs = countries.CNTRY_NAME[indexes],
                                         outlines = list(countries.geometry.values))

# Rasterise the regions onto the lat/lon grid of the data (a single time slice supplies the grid).
mask = countries_mask_poly.mask(dataset_raw['F'].isel(Time = 0), lat_name='lat', lon_name='lon')

# Keep only grid cells of the selected region that also fall inside the bounding box;
# every other cell becomes NaN.
mask = mask.where( (mask==US_REGION_NUMBER) & (mask.lat<49.35) & (mask.lat>24.74)  & (mask.lon>-124.78) & (mask.lon<-66.95) )

In [None]:
# Blank out (NaN) every grid cell that the region mask does not cover.
dataset_norm = dataset_norm.where(mask.notnull())

# ML Setup

In [None]:
# Predictor variables and the single target variable, by name in dataset_norm.
feature_name   = ['pcp','cape']
output_name    = ['F']

In [None]:
# Positional Time indices of the chronological train/test split.
# 23376 = 2922 days x 8 steps/day and 29216 = 3652 days x 8 steps/day, i.e. train = 2010-2017
# and test = 2018-2019, assuming the Time axis is a gap-free 3-hourly series starting
# 2010-01-01 -- TODO confirm against dataset_raw.Time.
idx_train = np.arange(0,23376)
idx_test = np.arange(23376,29216)
#idx_train, idx_val, _, _ = train_test_split(idx_train, idx_train, test_size=0.25)

In [None]:
# Split the normalised, masked data into feature (X) and target (y) datasets for the
# training and test time ranges.
train_dataset_X = dataset_norm[feature_name].isel(Time=idx_train)
test_dataset_X  = dataset_norm[feature_name].isel(Time=idx_test)
train_dataset_y = dataset_norm[output_name].isel(Time=idx_train)
test_dataset_y  = dataset_norm[output_name].isel(Time=idx_test)

In [None]:
# Flatten each dataset into a table with one row per grid-point/time sample and drop rows
# that contain NaN (masked-out cells, undefined normalised values).
X_train = train_dataset_X.to_dataframe().dropna(axis=0)
X_test  = test_dataset_X.to_dataframe().dropna(axis=0)
y_train = train_dataset_y.to_dataframe().dropna(axis=0)
y_test  = test_dataset_y.to_dataframe().dropna(axis=0)

# dropna() above runs on features and targets independently, so a sample that is NaN on one
# side only would survive in just one of the two tables.  The models are fitted on
# (X, y) row pairs and predictions made from X_test are written back into y_test by position,
# so both tables must contain exactly the same samples: keep only the shared rows.
# (The right-hand sides are evaluated before either name is rebound.)
X_train, y_train = (X_train[X_train.index.isin(y_train.index)],
                    y_train[y_train.index.isin(X_train.index)])
X_test, y_test = (X_test[X_test.index.isin(y_test.index)],
                  y_test[y_test.index.isin(X_test.index)])

In [None]:
# Normalised observed target on the test samples, converted from the table back to xarray.
y_predict_truth_norm = y_test[output_name].to_xarray()

In [None]:
# Convert the normalised observations back to the units of dataset_raw['F'] by reversing the
# per-calendar-month standardisation with the raw F monthly mean and standard deviation.
y_predict_truth_unnorm = xr.apply_ufunc(unnormalize,
    y_predict_truth_norm.groupby('Time.month'),
    dataset_raw['F'].groupby('Time.month').mean('Time'),
    dataset_raw['F'].groupby('Time.month').std('Time'),
    dask='allowed'
                                   )

## R14

In [None]:
class R14:
    """Single-coefficient model ``y = coef * CAPE * pcp``.

    The class holds no state; ``fit`` and ``predict`` are static methods so they can be
    called on the class (``R14.fit(...)``) as well as on an instance without the instance
    being bound to the ``CAPE`` argument.
    """

    @staticmethod
    def fit(CAPE, pcp, y):
        """Return the scaling coefficient ``sum(y) / sum(CAPE * pcp)``.

        ``CAPE``, ``pcp`` and ``y`` may be scalars or array-likes that support ``*`` and
        ``np.sum``.  The two sums are taken independently, so the caller is responsible
        for passing inputs that cover the same samples.
        """
        coef = np.sum(y) / np.sum(CAPE * pcp)
        return coef

    @staticmethod
    def predict(CAPE, pcp, coef):
        """Return the prediction ``coef * CAPE * pcp`` (element-wise for arrays)."""
        y_predict = coef * CAPE * pcp
        return y_predict

In [None]:
# Fit the R14 coefficient on the training period only.
# The target F must receive the same region mask and the same training-time selection as
# the predictors: summing the unmasked, full-period F would mix the test years into the
# coefficient and scale it by a different domain than the CAPE x pcp sum.
r14coef = R14.fit(
    dataset_raw['cape'].where( ~np.isnan(mask) ).isel(Time=idx_train), 
    dataset_raw['pcp'].where( ~np.isnan(mask) ).isel(Time=idx_train), 
    dataset_raw['F'].where( ~np.isnan(mask) ).isel(Time=idx_train)
)

In [None]:
# R14 prediction on the test period, computed from the raw (un-normalised) masked CAPE and
# precipitation, so the result is in the same units as dataset_raw['F'].
y_predict_r14_unnorm = R14.predict(
    dataset_raw['cape'].where( ~np.isnan(mask) ).isel(Time=idx_test), 
    dataset_raw['pcp'].where( ~np.isnan(mask) ).isel(Time=idx_test), 
    r14coef
)

In [None]:
# Express the R14 prediction in normalised units, using the raw F monthly mean and standard
# deviation, so it is comparable with the normalised outputs of the other models.
y_predict_r14_norm = xr.apply_ufunc(normalize,
    y_predict_r14_unnorm.groupby('Time.month'),
    dataset_raw['F'].groupby('Time.month').mean('Time'),
    dataset_raw['F'].groupby('Time.month').std('Time'),
    dask='allowed'
                                   )

## Random Forest Regressor

In [None]:
from sklearn.ensemble import RandomForestRegressor

In [None]:
# Small random forest: 10 trees of depth <= 4; random_state fixed so reruns give the same model.
rfreg = RandomForestRegressor(n_estimators=10, 
                              max_depth=4, 
                              random_state=0)

In [None]:
# Fit on the feature columns only; the single target column is flattened to 1-D.
rfreg.fit(X_train[feature_name], y_train[output_name].values.ravel())

In [None]:
# Random-forest prediction (normalised units) for every test row.
y_predict_rfreg_norm = rfreg.predict(X_test[feature_name])
# Attach the predictions to y_test so they inherit its row index and can be turned back
# into an xarray object.  This positional assignment assumes X_test and y_test contain the
# same rows in the same order -- NOTE(review): confirm.
y_test['y_predict_rfreg_norm'] = y_predict_rfreg_norm
y_predict_rfreg_norm = y_test['y_predict_rfreg_norm'].to_xarray()
# Back to the units of dataset_raw['F'] via the per-calendar-month mean and standard deviation.
y_predict_rfreg_unnorm = xr.apply_ufunc(unnormalize,
    y_predict_rfreg_norm.groupby('Time.month'),
    dataset_raw['F'].groupby('Time.month').mean('Time'),
    dataset_raw['F'].groupby('Time.month').std('Time'),
    dask='allowed'
                                   )

## Neural Network Regressor

In [None]:
from sklearn.neural_network import MLPRegressor

In [None]:
# Multi-layer perceptron regressor with library-default architecture; seed fixed, up to 800 iterations.
mlpreg = MLPRegressor(random_state=0, max_iter=800)

In [None]:
# Fit on the feature columns only; the single target column is flattened to 1-D.
mlpreg.fit(X_train[feature_name], y_train[output_name].values.ravel())

In [None]:
# Neural-network prediction (normalised units) for every test row.
y_predict_mlpreg_norm = mlpreg.predict(X_test[feature_name])
# Attach the predictions to y_test so they inherit its row index; assumes X_test and y_test
# contain the same rows in the same order -- NOTE(review): confirm.
y_test['y_predict_mlpreg_norm'] = y_predict_mlpreg_norm
y_predict_mlpreg_norm = y_test['y_predict_mlpreg_norm'].to_xarray()
# Back to the units of dataset_raw['F'] via the per-calendar-month mean and standard deviation.
y_predict_mlpreg_unnorm = xr.apply_ufunc(unnormalize,
    y_predict_mlpreg_norm.groupby('Time.month'),
    dataset_raw['F'].groupby('Time.month').mean('Time'),
    dataset_raw['F'].groupby('Time.month').std('Time'),
    dask='allowed'
                                   )

## Evaluating Model Performance

In [None]:
import numpy.ma as ma

In [None]:
#models  = ['r14','rfreg','mlpreg']
#model_names = ['R14','RF','NN']
#colors  = ['b','orange','g']
#markers = ['.','s','p']

In [None]:
def plot_4panels(da1, da2, da3, da4, names, vmax):
    """Draw a 2x2 grid of maps: a reference field and three fields compared against it.

    Parameters
    ----------
    da1 : xarray.DataArray
        Reference field, drawn top-left.  The callers in this notebook pass time-mean
        lat/lon fields.
    da2, da3, da4 : xarray.DataArray
        Fields drawn in the remaining panels (row-major order).  Each panel title carries
        the correlation between that field and ``da1`` over all grid cells.
    names : sequence of str
        Four panel titles, in panel order.
    vmax : number
        Upper colour limit shared by all panels; the lower limit is 0.

    Notes
    -----
    Uses the module-level ``countries`` GeoDataFrame for the border overlay.  The global
    matplotlib font size is set to 20 while drawing and to 10 before returning.

    The panels are built with ordinary loops instead of ``exec`` on generated source:
    variables created inside one ``exec`` call in a function are not guaranteed to be
    visible to the next one (they are not from Python 3.13 on), and the generated strings
    hid plain attribute and index access.
    """
    fields = (da1, da2, da3, da4)

    plt.rcParams.update({'font.size': 20})

    # Correlation of each comparison field with the reference; NaN cells are masked so they
    # do not turn the coefficient into NaN.
    reference = ma.masked_invalid(np.array(da1).ravel())
    pcorr = np.zeros((3))
    for imodel, field in enumerate(fields[1:]):
        candidate = ma.masked_invalid(np.array(field)).ravel()
        pcorr[imodel] = ma.corrcoef(reference, candidate)[0, 1]

    fig, axes = plt.subplots(2, 2, figsize=(16, 16))

    meshes = []
    for ipanel, (ax, field) in enumerate(zip(axes.ravel(), fields)):
        # Transparent country polygons with black outlines, drawn above the data (zorder 2 > 1).
        base = countries.plot(ax=ax, color=(1, 1, 1, 0.0), edgecolor='black', zorder=2)
        meshes.append(field.plot(ax=base, cmap='rainbow', zorder=1,
                                 vmin=0, vmax=vmax, add_colorbar=False))
        if ipanel == 0:
            ax.set_title(names[ipanel])
        else:
            ax.set_title(names[ipanel] + '(pcorr=' + str(np.round(pcorr[ipanel - 1], 2)) + ')')
        ax.set_xlabel('lon')
        ax.set_ylabel('lat')

    # One shared colour bar on the right, taken from the reference panel.
    meshes[0].set_clim([0, vmax])
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.03, 0.7])
    fig.colorbar(meshes[0], cax=cbar_ax, extend='max')
    fig.text(0.82, 0.86, '[# yr$^{-1}$ km$^{-2}$]')

    plt.rcParams.update({'font.size': 10})

    return

In [None]:
# Test-period time-mean maps: observations and the three model predictions, 0..1 colour range.
plot_4panels(y_predict_truth_unnorm['F'].mean('Time'), 
             y_predict_r14_unnorm.mean('Time'),
             y_predict_rfreg_unnorm.mean('Time'),
             y_predict_mlpreg_unnorm.mean('Time'),
             ['OBS','R14','RF','NN'],
             vmax=1)

In [None]:
# Daily time axis of the area-mean observed series; used as the x values of the plot below.
x = y_predict_truth_unnorm['F'].mean(dim={'lat','lon'}).resample(Time='1D').mean().Time

In [None]:
# Area-mean (over lat and lon), daily-mean series for the observations and each model,
# as plain numpy arrays.
obs_unnorm    = y_predict_truth_unnorm['F'].mean(dim={'lat','lon'}).resample(Time='1D').mean().values
r14_unnorm    = y_predict_r14_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values
rfreg_unnorm  = y_predict_rfreg_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values
mlpreg_unnorm = y_predict_mlpreg_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values

In [None]:
# Correlation and RMSE of each model's daily series against the observations, rounded to
# two decimals.  NaN days are masked before the statistics are taken: plain np.corrcoef /
# np.mean return NaN as soon as a single day is NaN.  This is the same masked formulation
# the notebook uses for its final printed scores.
r14_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_unnorm), ma.masked_invalid(r14_unnorm))[0,1], 2)
r14_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_unnorm) - ma.masked_invalid(r14_unnorm)) ** 2)), 2)

rfreg_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_unnorm), ma.masked_invalid(rfreg_unnorm))[0,1], 2)
rfreg_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_unnorm) - ma.masked_invalid(rfreg_unnorm)) ** 2)), 2)

mlpreg_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_unnorm), ma.masked_invalid(mlpreg_unnorm))[0,1], 2)
mlpreg_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_unnorm) - ma.masked_invalid(mlpreg_unnorm)) ** 2)), 2)

In [None]:
fig = plt.figure(figsize=[32,8])

# One curve per series, drawn in legend order: observations, R14, random forest, neural net.
line1 = plt.plot(x, obs_unnorm, 'black')
line2 = plt.plot(x, r14_unnorm, 'blue')
line3 = plt.plot(x, rfreg_unnorm, 'orangered')
line4 = plt.plot(x, mlpreg_unnorm, 'green')

# Identical line width and markers for all four curves.
plt.setp([*line1, *line2, *line3, *line4], linewidth=2.0, marker='o', markersize=6.0)

# Tick-label sizes
plt.xticks(fontsize=24)
plt.yticks(fontsize=24)

# Horizontal reference line at zero
plt.axhline(0, color='black', linewidth=1.0)

# Axis labels and (empty) title
plt.xlabel('Time', fontsize=24)
plt.ylabel('Lightning stroke density', fontsize=24)
plt.title('', y=0.99, fontsize=32)

# Legend entries follow the plotting order; model entries carry their skill scores.
plt.legend(
    ['OBS',
     'R14 (Cor={!s}, RMSE={!s})'.format(r14_cor, r14_rmse),
     'RF (Cor={!s}, RMSE={!s})'.format(rfreg_cor, rfreg_rmse),
     'NN (Cor={!s}, RMSE={!s})'.format(mlpreg_cor, mlpreg_rmse)],
    loc='upper left',
    prop={'size':24},
    ncol=2
    )

# Margins of the plot area inside the figure
plt.subplots_adjust(bottom=0.1, top=0.93, left=0.1, right=0.96)

# save
# plt.savefig(main_dir+'fig/result_cnn.png', dpi=300)
# plt.close()


# Examine the performance for dry thunderstorms

In [None]:
# Precipitation threshold below which a sample counts as a "dry" case.
pcp_thrs = 0.01
# Test rows whose pcp column is below the threshold (other rows become NaN and are dropped).
# NOTE(review): X_test is built from dataset_norm, so this compares *normalised*
# precipitation with the threshold, whereas the cells below compare dataset_raw['pcp']
# with the same value -- confirm which is intended.
Xdt_test = X_test.where(X_test['pcp']<pcp_thrs).dropna()

In [None]:
# Observed F on the test period, kept only where raw precipitation is below the threshold
# and the grid cell is inside the region mask; everything else becomes NaN.
ydt_predict_truth_unnorm = dataset_raw[output_name].isel(Time=idx_test).where( (dataset_raw['pcp']<pcp_thrs) & (~np.isnan(mask)))
# Same field in normalised units (per-calendar-month mean / standard deviation of raw F).
ydt_predict_truth_norm = xr.apply_ufunc(normalize,
    ydt_predict_truth_unnorm.groupby('Time.month'),
    dataset_raw['F'].groupby('Time.month').mean('Time'),
    dataset_raw['F'].groupby('Time.month').std('Time'),
    dask='allowed'
                                   )

In [None]:
# R14 predictions restricted to test samples with raw precipitation below the threshold.
ydt_predict_r14_unnorm = y_predict_r14_unnorm.where(dataset_raw['pcp'].isel(Time=idx_test)<pcp_thrs)
ydt_predict_r14_norm   = y_predict_r14_norm.where(dataset_raw['pcp'].isel(Time=idx_test)<pcp_thrs)

In [None]:
# Random-forest predictions restricted to test samples with raw precipitation below the threshold.
ydt_predict_rfreg_unnorm = y_predict_rfreg_unnorm.where(dataset_raw['pcp'].isel(Time=idx_test)<pcp_thrs)
ydt_predict_rfreg_norm   = y_predict_rfreg_norm.where(dataset_raw['pcp'].isel(Time=idx_test)<pcp_thrs)

In [None]:
# Neural-network predictions restricted to test samples with raw precipitation below the threshold.
ydt_predict_mlpreg_unnorm = y_predict_mlpreg_unnorm.where(dataset_raw['pcp'].isel(Time=idx_test)<pcp_thrs)
ydt_predict_mlpreg_norm   = y_predict_mlpreg_norm.where(dataset_raw['pcp'].isel(Time=idx_test)<pcp_thrs)

In [None]:
def plot_4panels(da1, da2, da3, da4, names, vmax, da2_vmax_divisor=200):
    """Draw a 2x2 grid of maps in which the second panel has its own, finer colour scale.

    Parameters
    ----------
    da1 : xarray.DataArray
        Reference field, drawn top-left.  The callers in this notebook pass time-mean
        lat/lon fields.
    da2 : xarray.DataArray
        Field drawn top-right with colour limits ``0 .. vmax / da2_vmax_divisor`` and its
        own horizontal colour bar.
    da3, da4 : xarray.DataArray
        Fields drawn in the bottom row with the shared ``0 .. vmax`` limits.
    names : sequence of str
        Four panel titles, in panel order.  Titles of panels 2-4 carry the correlation
        between that field and ``da1`` over all grid cells.
    vmax : number
        Upper colour limit of panels 1, 3 and 4; the lower limit is 0.
    da2_vmax_divisor : number, optional
        Factor by which ``vmax`` is divided for panel 2.  Defaults to 200, the value that
        was previously hard-coded.

    Notes
    -----
    Uses the module-level ``countries`` GeoDataFrame for the border overlay.  The global
    matplotlib font size is set to 20 while drawing and to 10 before returning.

    The panels are built with ordinary loops instead of ``exec`` on generated source:
    variables created inside one ``exec`` call in a function are not guaranteed to be
    visible to the next one (they are not from Python 3.13 on), and the generated strings
    hid plain attribute and index access.
    """
    fields = (da1, da2, da3, da4)

    plt.rcParams.update({'font.size': 20})

    # Correlation of each comparison field with the reference; NaN cells are masked so they
    # do not turn the coefficient into NaN.
    reference = ma.masked_invalid(np.array(da1).ravel())
    pcorr = np.zeros((3))
    for imodel, field in enumerate(fields[1:]):
        candidate = ma.masked_invalid(np.array(field)).ravel()
        pcorr[imodel] = ma.corrcoef(reference, candidate)[0, 1]

    fig, axes = plt.subplots(2, 2, figsize=(16, 16))

    meshes = []
    for ipanel, (ax, field) in enumerate(zip(axes.ravel(), fields)):
        # Transparent country polygons with black outlines, drawn above the data (zorder 2 > 1).
        base = countries.plot(ax=ax, color=(1, 1, 1, 0.0), edgecolor='black', zorder=2)
        if ipanel == 1:
            # Second panel: reduced colour range and a dedicated horizontal colour bar.
            mesh = field.plot(ax=base, cmap='rainbow', zorder=1,
                              vmin=0, vmax=vmax / da2_vmax_divisor,
                              add_colorbar=True, extend='max',
                              cbar_kwargs={'orientation': 'horizontal'})
        else:
            mesh = field.plot(ax=base, cmap='rainbow', zorder=1,
                              vmin=0, vmax=vmax, add_colorbar=False)
        meshes.append(mesh)
        if ipanel == 0:
            ax.set_title(names[ipanel])
        else:
            ax.set_title(names[ipanel] + '(pcorr=' + str(np.round(pcorr[ipanel - 1], 2)) + ')')
        ax.set_xlabel('lon')
        ax.set_ylabel('lat')

    # Shared vertical colour bar on the right, taken from the reference panel.
    meshes[0].set_clim([0, vmax])
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.03, 0.7])
    fig.colorbar(meshes[0], cax=cbar_ax, extend='max')
    fig.text(0.82, 0.86, '[# yr$^{-1}$ km$^{-2}$]')

    plt.rcParams.update({'font.size': 10})

In [None]:
# Time-mean maps restricted to the low-precipitation ("dry") test samples.
plot_4panels(ydt_predict_truth_unnorm['F'].mean('Time'), 
             ydt_predict_r14_unnorm.mean('Time'),
             ydt_predict_rfreg_unnorm.mean('Time'),
             ydt_predict_mlpreg_unnorm.mean('Time'),
             ['OBS','R14','RF','NN'],
             vmax=1)

In [None]:
# Positional index of the season to plot within the season-grouped means.
# NOTE(review): presumably selects 'JJA' if the grouped seasons are ordered alphabetically
# (DJF, JJA, MAM, SON) -- confirm, or select by label instead.
ss = 1
plot_4panels(ydt_predict_truth_unnorm['F'].groupby('Time.season').mean('Time').isel(season=ss), 
             ydt_predict_r14_unnorm.groupby('Time.season').mean('Time').isel(season=ss),
             ydt_predict_rfreg_unnorm.groupby('Time.season').mean('Time').isel(season=ss),
             ydt_predict_mlpreg_unnorm.groupby('Time.season').mean('Time').isel(season=ss),
             names=['OBS','R14','RF','NN'],
             vmax=1)

In [None]:
# Daily time axis of the dry-sample observed series (rebinds x used by the earlier plot).
x = ydt_predict_truth_unnorm['F'].mean(dim={'lat','lon'}).resample(Time='1D').mean().Time

In [None]:
# Area-mean, daily-mean series over the dry samples only, as plain numpy arrays.
obs_dt_unnorm    = ydt_predict_truth_unnorm['F'].mean(dim={'lat','lon'}).resample(Time='1D').mean().values
r14_dt_unnorm    = ydt_predict_r14_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values
rfreg_dt_unnorm  = ydt_predict_rfreg_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values
mlpreg_dt_unnorm = ydt_predict_mlpreg_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values

In [None]:
# Correlation and RMSE of each model's dry-sample daily series against the observations,
# rounded to two decimals.  The dry-sample series are NaN wherever no sample passed the
# precipitation filter, and plain np.corrcoef / np.mean return NaN as soon as a single day
# is NaN, so NaN days are masked first.  This is the same masked formulation the notebook
# uses for its final printed scores.
r14_dt_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_dt_unnorm), ma.masked_invalid(r14_dt_unnorm))[0,1], 2)
r14_dt_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_dt_unnorm) - ma.masked_invalid(r14_dt_unnorm)) ** 2)), 2)

rfreg_dt_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_dt_unnorm), ma.masked_invalid(rfreg_dt_unnorm))[0,1], 2)
rfreg_dt_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_dt_unnorm) - ma.masked_invalid(rfreg_dt_unnorm)) ** 2)), 2)

mlpreg_dt_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_dt_unnorm), ma.masked_invalid(mlpreg_dt_unnorm))[0,1], 2)
mlpreg_dt_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_dt_unnorm) - ma.masked_invalid(mlpreg_dt_unnorm)) ** 2)), 2)

In [None]:
fig = plt.figure(figsize=[32,8])

# One curve per series, drawn in legend order.  The R14 series is multiplied by 200 for
# display only, as stated in its legend entry.
line1 = plt.plot(x, obs_dt_unnorm, 'black')
line2 = plt.plot(x, r14_dt_unnorm*200, 'blue')
line3 = plt.plot(x, rfreg_dt_unnorm, 'orangered')
line4 = plt.plot(x, mlpreg_dt_unnorm, 'green')

# Identical line width and markers for all four curves.
plt.setp([*line1, *line2, *line3, *line4], linewidth=2.0, marker='o', markersize=6.0)

# Tick-label sizes
plt.xticks(fontsize=24)
plt.yticks(fontsize=24)

# Horizontal reference line at zero
plt.axhline(0, color='black', linewidth=1.0)

# Axis labels and (empty) title
plt.xlabel('Time', fontsize=24)
plt.ylabel('Lightning stroke density', fontsize=24)
plt.title('', y=0.99, fontsize=32)

# Legend entries follow the plotting order; model entries carry their skill scores.
plt.legend(
    ['OBS',
     '200x R14 (Cor={!s}, RMSE={!s})'.format(r14_dt_cor, r14_dt_rmse),
     'RF (Cor={!s}, RMSE={!s})'.format(rfreg_dt_cor, rfreg_dt_rmse),
     'NN (Cor={!s}, RMSE={!s})'.format(mlpreg_dt_cor, mlpreg_dt_rmse)],
    loc='upper left',
    prop={'size':24},
    ncol=2
    )

# Margins of the plot area inside the figure
plt.subplots_adjust(bottom=0.1, top=0.93, left=0.1, right=0.96)

# save
# plt.savefig(main_dir+'fig/result_cnn.png', dpi=300)
# plt.close()


In [None]:
# R14 prediction on the training period (raw units), used by the binned-mean plots below.
y_train_r14_unnorm = R14.predict(
    dataset_raw['cape'].where( ~np.isnan(mask) ).isel(Time=idx_train), 
    dataset_raw['pcp'].where( ~np.isnan(mask) ).isel(Time=idx_train), 
    r14coef
)

In [None]:
# Random-forest prediction on the training rows (normalised units).
y_train_rfreg_norm = rfreg.predict(X_train[feature_name])
# Attach to y_train so the predictions inherit its row index; assumes X_train and y_train
# contain the same rows in the same order -- NOTE(review): confirm.
y_train['y_train_rfreg_norm'] = y_train_rfreg_norm
y_train_rfreg_norm = y_train['y_train_rfreg_norm'].to_xarray()
# Back to the units of dataset_raw['F'] via the per-calendar-month mean and standard deviation.
y_train_rfreg_unnorm = xr.apply_ufunc(unnormalize,
    y_train_rfreg_norm.groupby('Time.month'),
    dataset_raw['F'].groupby('Time.month').mean('Time'),
    dataset_raw['F'].groupby('Time.month').std('Time'),
    dask='allowed'
                                   )

In [None]:
# Neural-network prediction on the training rows (normalised units).
y_train_mlpreg_norm = mlpreg.predict(X_train[feature_name])
# Attach to y_train so the predictions inherit its row index; assumes X_train and y_train
# contain the same rows in the same order -- NOTE(review): confirm.
y_train['y_train_mlpreg_norm'] = y_train_mlpreg_norm
y_train_mlpreg_norm = y_train['y_train_mlpreg_norm'].to_xarray()
# Back to the units of dataset_raw['F'] via the per-calendar-month mean and standard deviation.
y_train_mlpreg_unnorm = xr.apply_ufunc(unnormalize,
    y_train_mlpreg_norm.groupby('Time.month'),
    dataset_raw['F'].groupby('Time.month').mean('Time'),
    dataset_raw['F'].groupby('Time.month').std('Time'),
    dask='allowed'
                                   )

In [None]:
def plot_line_graph(xvarstr,yvarstr,xvar_1,yvar_1,xvar_2,yvar_2,xvar_3,yvar_3,xvar_4,yvar_4,bins,xyminmax,legends):
    """Plot the bin-averaged value of four y series against log-spaced bins of x.

    For each pair ``(xvar_i, yvar_i)`` the x values are sorted into logarithmically spaced
    bins; the curve is the sum of ``yvar_i`` in each bin divided by the number of samples in
    that bin, drawn at the left edge of the bin on log-log axes.

    Parameters
    ----------
    xvarstr, yvarstr : str
        Axis labels.
    xvar_1 .. xvar_4 : 1-D array-like
        Values that are binned.
    yvar_1 .. yvar_4 : 1-D array-like
        Values that are averaged per bin; each must have the same shape as its ``xvar_i``.
    bins : int
        Number of bin *edges* between ``xyminmax[0]`` and ``xyminmax[1]`` (so ``bins - 1`` bins).
    xyminmax : sequence of 4 numbers
        ``[xmin, xmax, ymin, ymax]``; the x pair also defines the bin range.
    legends : sequence of str
        Legend entries in the order of the four series.

    Notes
    -----
    The binning is done with ``np.histogram``.  Using ``plt.hist`` for it drew eight
    throw-away histograms into an implicit figure that ``plt.clf()`` then left behind as an
    empty figure next to the real plot.  Bins without samples give NaN (0/0) and are simply
    not drawn.  The global matplotlib font size is set to 36 while drawing and to 10 before
    returning.
    """
    plt.rcParams.update({'font.size': 36})

    logbins = np.logspace(np.log10(xyminmax[0]), np.log10(xyminmax[1]), bins)

    fig0 = plt.figure(figsize=(10,10))
    ax0  = fig0.add_subplot(111)

    series = ((xvar_1, yvar_1, 'k-'),
              (xvar_2, yvar_2, 'b-'),
              (xvar_3, yvar_3, 'r-'),
              (xvar_4, yvar_4, 'g-'))

    for xvar, yvar, line_format in series:
        weighted_sum, edges = np.histogram(xvar, bins=logbins, weights=yvar)
        sample_count, _ = np.histogram(xvar, bins=logbins)
        # Empty bins divide by zero; the resulting NaN/inf is intended, so keep numpy quiet.
        with np.errstate(divide='ignore', invalid='ignore'):
            bin_mean = weighted_sum / sample_count
        # edges has one more entry than there are bins; plot against the left edges.
        ax0.plot(edges[:-1], bin_mean, line_format, linewidth=3)

    ax0.legend(legends,
               loc='upper left')
    ax0.set_xlim(xyminmax[0], xyminmax[1])
    ax0.set_ylim(xyminmax[2], xyminmax[3])
    ax0.set_xlabel(xvarstr)
    ax0.set_ylabel(yvarstr)
    ax0.set_xscale('log')
    ax0.set_yscale('log')
    ax0.set_title('')

    plt.rcParams.update({'font.size': 10})
#plt.show()

In [None]:
# x values: masked training-period CAPE, flattened.  All four curves are binned against the
# same predictor values, so the array is computed once and shared.
xvar_1 = dataset_raw['cape'].where(mask.notnull()).isel(Time=idx_train).values.ravel()
xvar_2 = xvar_1
xvar_3 = xvar_1
xvar_4 = xvar_1

# y values: observed F, then the R14, random-forest and neural-network estimates.
yvar_1 = dataset_raw['F'].where(mask.notnull()).isel(Time=idx_train).values.ravel()
yvar_2 = y_train_r14_unnorm.values.ravel()
yvar_3 = y_train_rfreg_unnorm.values.ravel()
yvar_4 = y_train_mlpreg_unnorm.values.ravel()

In [None]:
# Binned-mean F versus CAPE on the training period.
bins         = 20                      # number of log-spaced bin edges
xyminmax     = [0.1,2000,0.001,100]    # [xmin, xmax, ymin, ymax]

# NOTE(review): the x label says CAPE^{1/2} [m s^-1], but the xvar_* arrays passed below are
# built from dataset_raw['cape'] without a square root -- confirm the label.
xvarstr      = 'CAPE$^{1/2}$ [m s$^{-1}$]'
yvarstr      = 'F [km$^{-2}$ yr$^{-1}$]'

legends      = ['OBS','R14','RF','NN']

# Not passed to plot_line_graph in this cell.
labelformat    = '%3.1f'

plot_line_graph(xvarstr,yvarstr,
                xvar_1,yvar_1,
                xvar_2,yvar_2,
                xvar_3,yvar_3,
                xvar_4,yvar_4,
                bins,
                xyminmax,
                legends)


In [None]:
# x values: masked training-period precipitation, flattened.  All four curves are binned
# against the same predictor values, so the array is computed once and shared.
xvar_1 = dataset_raw['pcp'].where(mask.notnull()).isel(Time=idx_train).values.ravel()
xvar_2 = xvar_1
xvar_3 = xvar_1
xvar_4 = xvar_1

# y values: observed F, then the R14, random-forest and neural-network estimates.
yvar_1 = dataset_raw['F'].where(mask.notnull()).isel(Time=idx_train).values.ravel()
yvar_2 = y_train_r14_unnorm.values.ravel()
yvar_3 = y_train_rfreg_unnorm.values.ravel()
yvar_4 = y_train_mlpreg_unnorm.values.ravel()

In [None]:
# Binned-mean F versus precipitation on the training period.
# plot_line_graph sets the same font size itself, so this call only affects anything drawn
# before it runs.
plt.rcParams.update({'font.size': 36})
bins         = 20                       # number of log-spaced bin edges
xyminmax     = [0.0001,3,0.001,100]     # [xmin, xmax, ymin, ymax]

xvarstr      = 'Precip [mm hr$^{-1}$]'
yvarstr      = 'F [km$^{-2}$ yr$^{-1}$]'

legends      = ['OBS','R14','RF','NN']

# Not passed to plot_line_graph in this cell.
labelformat    = '%3.1f'

plot_line_graph(xvarstr,yvarstr,
                xvar_1,yvar_1,
                xvar_2,yvar_2,
                xvar_3,yvar_3,
                xvar_4,yvar_4,
                bins,
                xyminmax,
                legends)


In [None]:
# x values: masked test-period CAPE, flattened.  All four curves are binned against the
# same predictor values, so the array is computed once and shared.
xvar_1 = dataset_raw['cape'].where(mask.notnull()).isel(Time=idx_test).values.ravel()
xvar_2 = xvar_1
xvar_3 = xvar_1
xvar_4 = xvar_1

# y values: observed F, then the R14, random-forest and neural-network predictions.
yvar_1 = dataset_raw['F'].where(mask.notnull()).isel(Time=idx_test).values.ravel()
yvar_2 = y_predict_r14_unnorm.values.ravel()
yvar_3 = y_predict_rfreg_unnorm.values.ravel()
yvar_4 = y_predict_mlpreg_unnorm.values.ravel()

In [None]:
# Binned-mean F versus CAPE on the test period.
bins         = 20                      # number of log-spaced bin edges
xyminmax     = [0.1,2000,0.001,100]    # [xmin, xmax, ymin, ymax]

# NOTE(review): the x label says CAPE^{1/2} [m s^-1], but the xvar_* arrays passed below are
# built from dataset_raw['cape'] without a square root -- confirm the label.
xvarstr      = 'CAPE$^{1/2}$ [m s$^{-1}$]'
yvarstr      = 'F [km$^{-2}$ yr$^{-1}$]'

legends      = ['OBS','R14','RF','NN']

# Not passed to plot_line_graph in this cell.
labelformat    = '%3.1f'

plot_line_graph(xvarstr,yvarstr,
                xvar_1,yvar_1,
                xvar_2,yvar_2,
                xvar_3,yvar_3,
                xvar_4,yvar_4,
                bins,
                xyminmax,
                legends)


In [None]:
# x values: masked test-period precipitation, flattened.  All four curves are binned
# against the same predictor values, so the array is computed once and shared.
xvar_1 = dataset_raw['pcp'].where(mask.notnull()).isel(Time=idx_test).values.ravel()
xvar_2 = xvar_1
xvar_3 = xvar_1
xvar_4 = xvar_1

# y values: observed F, then the R14, random-forest and neural-network predictions.
yvar_1 = dataset_raw['F'].where(mask.notnull()).isel(Time=idx_test).values.ravel()
yvar_2 = y_predict_r14_unnorm.values.ravel()
yvar_3 = y_predict_rfreg_unnorm.values.ravel()
yvar_4 = y_predict_mlpreg_unnorm.values.ravel()

In [None]:
# Binned-mean F versus precipitation on the test period.
# plot_line_graph sets the same font size itself, so this call only affects anything drawn
# before it runs.
plt.rcParams.update({'font.size': 36})
bins         = 20                       # number of log-spaced bin edges
xyminmax     = [0.0001,3,0.001,100]     # [xmin, xmax, ymin, ymax]

xvarstr      = 'Precip [mm hr$^{-1}$]'
yvarstr      = 'F [km$^{-2}$ yr$^{-1}$]'

legends      = ['OBS','R14','RF','NN']

# Not passed to plot_line_graph in this cell.
labelformat    = '%3.1f'

plot_line_graph(xvarstr,yvarstr,
                xvar_1,yvar_1,
                xvar_2,yvar_2,
                xvar_3,yvar_3,
                xvar_4,yvar_4,
                bins,
                xyminmax,
                legends)


In [None]:
# Recompute the area-mean, daily-mean series (all test samples) as flat numpy arrays for the
# final scores below.
obs_unnorm    = y_predict_truth_unnorm['F'].mean(dim={'lat','lon'}).resample(Time='1D').mean().values.ravel()
r14_unnorm    = y_predict_r14_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values.ravel()
rfreg_unnorm  = y_predict_rfreg_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values.ravel()
mlpreg_unnorm = y_predict_mlpreg_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values.ravel()

In [None]:
# Correlation and RMSE of each model against the observations, rounded to two decimals.
# NaN entries are masked first so they are excluded from the statistics instead of turning
# them into NaN.
r14_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_unnorm), ma.masked_invalid(r14_unnorm))[0,1], 2)
r14_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_unnorm) - ma.masked_invalid(r14_unnorm)) ** 2)), 2)

rfreg_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_unnorm), ma.masked_invalid(rfreg_unnorm))[0,1], 2)
rfreg_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_unnorm) - ma.masked_invalid(rfreg_unnorm)) ** 2)), 2)

mlpreg_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_unnorm), ma.masked_invalid(mlpreg_unnorm))[0,1], 2)
mlpreg_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_unnorm) - ma.masked_invalid(mlpreg_unnorm)) ** 2)), 2)

In [None]:
# One line per model (R14, RF, NN): correlation, then RMSE.
print(r14_cor, r14_rmse)
print(rfreg_cor, rfreg_rmse)
print(mlpreg_cor, mlpreg_rmse)

In [None]:
# Recompute the area-mean, daily-mean series over the dry samples as flat numpy arrays for
# the final scores below.
obs_dt_unnorm    = ydt_predict_truth_unnorm['F'].mean(dim={'lat','lon'}).resample(Time='1D').mean().values.ravel()
r14_dt_unnorm    = ydt_predict_r14_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values.ravel()
rfreg_dt_unnorm  = ydt_predict_rfreg_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values.ravel()
mlpreg_dt_unnorm = ydt_predict_mlpreg_unnorm.mean(dim={'lat','lon'}).resample(Time='1D').mean().values.ravel()

In [None]:
# Dry-sample correlation and RMSE of each model against the observations, rounded to two
# decimals.  NaN entries are masked first so they are excluded from the statistics.
r14_dt_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_dt_unnorm), ma.masked_invalid(r14_dt_unnorm))[0,1], 2)
r14_dt_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_dt_unnorm) - ma.masked_invalid(r14_dt_unnorm)) ** 2)), 2)

rfreg_dt_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_dt_unnorm), ma.masked_invalid(rfreg_dt_unnorm))[0,1], 2)
rfreg_dt_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_dt_unnorm) - ma.masked_invalid(rfreg_dt_unnorm)) ** 2)), 2)

mlpreg_dt_cor  = np.round(ma.corrcoef(ma.masked_invalid(obs_dt_unnorm), ma.masked_invalid(mlpreg_dt_unnorm))[0,1], 2)
mlpreg_dt_rmse = np.round(np.sqrt(ma.mean((ma.masked_invalid(obs_dt_unnorm) - ma.masked_invalid(mlpreg_dt_unnorm)) ** 2)), 2)

In [None]:
# One line per model (R14, RF, NN) for the dry samples: correlation, then RMSE.
print(r14_dt_cor, r14_dt_rmse)
print(rfreg_dt_cor, rfreg_dt_rmse)
print(mlpreg_dt_cor, mlpreg_dt_rmse)