In [None]:
# Report which Python interpreter this kernel is running on.
import sys
print(sys.executable)

In [None]:
import numpy as np
import scipy as sp
import dask
import pandas as pd
import xarray as xr
import matplotlib
from cycler import cycler
import matplotlib.pyplot as plt
import time
from datetime import datetime, timedelta

# Data Processing

### Set CONUS filter

In [None]:
import regionmask
import geopandas as gpd

In [None]:
# World-country boundary polygons used to build the CONUS land mask.
PATH_TO_SHAPEFILE = '/home/disk/eos12/wycheng/data/WorldCountriesBoundaries/99bfd9e7-bb42-4728-87b5-07f8c8ac631c2020328-1-1vef4ev.lu5nk.shp'
countries = gpd.read_file(PATH_TO_SHAPEFILE)

# One region number per shapefile row.  The count is taken from the file itself
# (it used to be hard-coded to 250) so numbers, names and outlines always have
# the same length.
indexes = list(range(len(countries)))

# Region set in which every country is identified by its row number; both the
# long name and the abbreviation are filled from the CNTRY_NAME column.
countries_mask_poly = regionmask.Regions(
    name='COUNTRY',
    numbers=indexes,
    names=countries.CNTRY_NAME[indexes],
    abbrevs=countries.CNTRY_NAME[indexes],
    outlines=list(countries.geometry.values),
)

## Read WWLLN data

- Variable(s):
    - F (Lightning flash rate): The number of lightning strokes observed by WWLLN in each grid cell (# of strokes / grid cell / 3 hr).

In [None]:
# Open every WWLLN file as a single dataset, letting dask pick the chunk size
# along each dimension; parallel=True opens the files concurrently.
wwlln_chunks = dict(Time='auto', lat='auto', lon='auto')
WWLLN_dataset = xr.open_mfdataset(
    '/home/disk/eos12/wycheng/data/WWLLN/Global/WWLLN_20*.nc',
    chunks=wwlln_chunks,
    parallel=True,
)

Select the CONUS area

In [None]:
# Keep only the longitude/latitude box that encloses the contiguous US.
conus_box = {'lon': slice(-125, -65), 'lat': slice(20, 50)}
WWLLN_dataset = WWLLN_dataset.sel(conus_box)

Change the temporal resolution from 3 hours to 1 day 
- Method: Sum up all observed strokes in 1 day

In [None]:
# Collapse the stroke counts into one total per calendar day.
F_data = WWLLN_dataset['F'].resample(Time='1D').sum()

US country code = 232

In [None]:
# Rasterise the country polygons onto the WWLLN grid (the first time step only
# supplies the lat/lon coordinates).
mask = countries_mask_poly.mask(F_data.isel(Time=0), lat_name='lat', lon_name='lon')

# Keep only the cells numbered 232 (US) that also fall inside the CONUS
# bounding box; every other cell becomes NaN.
is_us        = mask == 232
in_lat_range = (mask.lat < 49.35) & (mask.lat > 24.74)
in_lon_range = (mask.lon > -124.78) & (mask.lon < -66.95)
mask = mask.where(is_us & in_lat_range & in_lon_range)

Persist F_data array for later use

In [None]:
# Blank out every grid cell outside the CONUS mask and keep the result in
# memory so later cells do not recompute it.
F_data = F_data.where(mask.notnull()).persist()
F_data

In [None]:
# Quick-look map of the time-mean daily stroke count.
F_data.mean('Time').plot()

## Read GEFS data

Read in the hindcast dataset from the GEFS model and rename the coordinates from ('X', 'Y') to ('lon', 'lat').
- Variables:
    - CAPE: convective available potential energy (J/kg)
    - PR: precipitation (mm)
- The dimensions:
    - S: Start Time (forecast_reference_time): ordered from (0000 6 Jan 2010) to (0000 28 Dec 2016) by 7 (days)
    - M: Ensemble Member (realization): ordered from (0) to (10) by 1.0
    - L: Forecast Lead Time (forecast_period): ordered from (0.5 days) to (34.5 days) by 1.0 (days)
    - lon: The longitude; Notice that the range of this coordinate is from (0) to (360)
    - lat: The latitude; Notice that the order of this coordinate is from (90) to (-90)

In [None]:
# Open all GEFS hindcast files as one dataset (dask chooses the chunking) and
# give the horizontal dimensions the same names the WWLLN data uses.
gefs_chunks = dict(S='auto', M='auto', L='auto', X='auto', Y='auto')
GEFS_dataset = (
    xr.open_mfdataset(
        '/home/disk/eos12/wycheng/data/GEFS/GEFS*.nc',
        chunks=gefs_chunks,
        parallel=True,
    )
    .rename({'X': 'lon', 'Y': 'lat'})
)

Select the CONUS area

In [None]:
# Regional subset.  Longitude is still on the 0..360 convention here and the
# latitude slice runs from 60 down to 20 to follow the coordinate's order.
gefs_box = dict(lon=slice(225, 300), lat=slice(60, 20))
GEFS_dataset = GEFS_dataset.sel(gefs_box)

In [None]:
# Inspect the dataset after the regional subset.
GEFS_dataset

- Reassign the longitude coordinate from (0, 360) to (-180, 180)
- Reverse the latitude coordinate from (60, 20) to (20, 60)

In [None]:
# Allow dask to split oversized chunks while the coordinates are rewritten.
with dask.config.set({'array.slicing.split_large_chunks': True}):
    # Map longitudes from 0..360 onto -180..180.
    lon_wrapped = ((GEFS_dataset.lon + 180) % 360) - 180
    # Latitude coordinate in reversed (ascending) order.
    lat_ascending = GEFS_dataset.lat[::-1]
    GEFS_dataset = GEFS_dataset.assign_coords(lon=lon_wrapped).reindex(lat=lat_ascending)

In [None]:
# Inspect the dataset after the longitude/latitude coordinates were rewritten.
GEFS_dataset

Interpolate the data from integer-degree grid points to half-degree grid points to match the WWLLN F data.

In [None]:
# Target grid: cell centres at half-degree offsets, one degree apart.
half_degree_grid = dict(
    lon=xr.DataArray(np.linspace(-134.5, -60.5, 75), dims='lon'),
    lat=xr.DataArray(np.linspace(20.5, 59.5, 40), dims='lat'),
)
lono, lato = half_degree_grid['lon'], half_degree_grid['lat']

# Linear interpolation of every GEFS variable onto the target grid.
with dask.config.set({'array.slicing.split_large_chunks': True}):
    GEFS_dataset = GEFS_dataset.interp(half_degree_grid, method='linear')

In [None]:
# Inspect the dataset after interpolation onto the half-degree grid.
GEFS_dataset

In [None]:
# Export the regridded fields to NetCDF.  Only the precipitation export is
# active; mode='w' overwrites an existing file at that path.
#GEFS_dataset['cape'].to_netcdf(path='/home/disk/eos12/wycheng/data/US/GEFS/GEFS_cape_dataset.nc', mode='w')
GEFS_dataset['pr'].to_netcdf(path='/home/disk/eos12/wycheng/data/US/GEFS/GEFS_pr_dataset.nc', mode='w')

#GEFS_dataset = xr.open_mfdataset('/home/disk/eos12/wycheng/data/US/GEFS/GEFS_*_dataset.nc')

# Create CP data

In [None]:
# CP proxy: CAPE multiplied by precipitation, kept only where the CONUS mask
# is defined.
CP_data_raw = (GEFS_dataset['cape'] * GEFS_dataset['pr']).where(mask.notnull())
CP_data_raw = CP_data_raw.rename('CP_raw')
CP_data_raw

In [None]:
# Quick-look map: raw CP averaged over start time, member and lead time.
CP_data_raw.mean(dim=['S', 'M', 'L']).plot()

# TK18: Fig 1

Turn the Start time and Lead time coordinates ('S', 'L') into forecast time coordinate (FCT) for later use

In [None]:
# Valid (forecast) time of every (start, lead) pair, flattened onto one 'FCT'
# dimension.
#
# The dimensions to stack are given as an ordered tuple: S is the outer index
# and L varies fastest.  A set literal ({'S','L'}) has no defined iteration
# order, so the flattening order could differ between kernel sessions, while
# later cells rely on this exact layout (every 35th element belongs to the same
# lead time, and FCT_coor is attached positionally to data stacked with
# ('S', 'L')).
fct_values = (CP_data_raw.S + CP_data_raw.L).stack(FCT=('S', 'L')).values

FCT_coor = xr.DataArray(
    data=fct_values,
    dims=['FCT'],
    coords=dict(
        FCT=(['FCT'], fct_values),
    ),
    attrs=None,
)
FCT_coor.name = 'FCT'
FCT_coor

## Generate CP forecast data

### Calculate the conversion coefficient (Ccpf) that converts CP to lightning flash rates.

In TK18, the conversion coefficient is only a function of lead time and has no temporal and spatial variation. So for each lead time, we sum up all lightning strokes observed over different S (Start time), M (Ensemble members), X (longitude), and Y (latitude), and repeat the same for CP, and the ratio between the two is the conversion coefficient.

In [None]:
# Mean observed stroke count over all days and grid cells (a single number).
F_data_mean = F_data.mean(dim=['Time', 'lat', 'lon']).values
F_data_mean

In [None]:
# Mean raw CP over start time, member and space; what remains varies with the
# lead time L only.
CP_data_raw_SMXYmean = CP_data_raw.mean(dim=['S', 'M', 'lat', 'lon']).values
CP_data_raw_SMXYmean

In [None]:
# Conversion coefficient per lead time: mean observed strokes divided by the
# mean raw CP, labelled with the GEFS lead-time coordinate and its attributes.
lead_times = CP_data_raw.L
Ccpf = xr.DataArray(
    F_data_mean / CP_data_raw_SMXYmean,
    coords={'L': lead_times.values},
    dims=['L'],
    attrs=lead_times.attrs,
)

In [None]:
Ccpf.name = 'Ccpf'

# Coefficient against lead index (values are plotted against their position).
ax = plt.gca()
ax.plot(Ccpf.values)
ax.set_ylabel('Conversion coefficient')
ax.set_xlabel('Lead time (days)')

In [None]:
# Scale raw CP by the lead-dependent coefficient so it is expressed in the same
# units as the observed stroke counts.
CP_data = Ccpf*CP_data_raw
CP_data.name='CP'
# persist() returns a new object and does not modify CP_data in place, so the
# result has to be assigned back; otherwise the computed data is thrown away
# and every later use recomputes it.
CP_data = CP_data.persist()
# Quick-look map: CP averaged over start time, member and lead time.
CP_data.mean(dim=['S','M','L']).plot()

### Calculate the CP threshold values that determine whether there is lightning event

Similar to finding the conversion coefficient, for each lead time, we calculate the total lightning events over different S, M, X, and Y, and then try to find the corresponding CP value that will have the same number of events (where CP is greater than this value). 

We first create a binary classification dataset that stores whether a lightning event is observed by WWLLN. When isL=1, at least one lightning stroke is observed by WWLLN; when isL=0, no lightning stroke is observed by WWLLN.

In [None]:
# Binary lightning-event flag.  where() keeps the original value wherever the
# condition holds (count below 1, or missing) and writes 1 everywhere else, so
# every cell with at least one stroke becomes 1 and NaNs stay NaN.
no_event_or_missing = (F_data < 1) | F_data.isnull()
isL_data = F_data.where(no_event_or_missing, 1).persist()
isL_data.name = 'isL'
# The array is already persisted above; just display it (a second persist()
# call here was redundant).
isL_data

In [None]:
# Quick-look map of the time-mean event flag.
isL_data.mean('Time').plot()

Calculate the daily average of number of lightning events for normalization

In [None]:
# Total number of observed lightning events over all days and grid cells.
isL_data_TXYsum = isL_data.sum(dim=['Time', 'lat', 'lon']).values
isL_data_TXYsum

This means that, on average, there are approximately 455367/2557 ≈ 178 lightning events per day across CONUS. This will be used as the basis for normalizing the binary CP classification.

In [None]:
"""
# Should've used 'DataArray.rank' instead of 'np.sort'
thrs = np.zeros((35))
for iL in range(35):
    print(iL)
    CP_data_sorted = np.sort(CP_data.isel(L=iL), axis=None)
    thrs[iL] = CP_data_sorted[::-1][np.sum(np.isnan(CP_data_sorted))+int(np.floor(isL_data_TXYsum)*(11*365/2557))]
"""

In [None]:
# Load the per-lead-time CP thresholds from the cached .npy file.  The matching
# np.save call is kept commented out so a run does not overwrite the cache.
#np.save('/home/disk/eos12/wycheng/data/metadata/thrs.npy', thrs)
thrs = np.load('/home/disk/eos12/wycheng/data/metadata/thrs.npy')

In [None]:
# Threshold against lead index (values are plotted against their position).
ax = plt.gca()
ax.plot(thrs)
ax.set_ylabel('CP threshold')
ax.set_xlabel('Lead time (days)')

In [None]:
# Wrap the threshold vector in a DataArray indexed by the GEFS lead-time
# coordinate so it broadcasts against CP_data along L.
CP_thrs = xr.DataArray(
    thrs,
    coords={'L': CP_data_raw.L.values},
    dims=['L'],
    attrs=CP_data_raw.L.attrs,
    name='cp_thrs',
)
CP_thrs

In [None]:
#isCP_data = CP_data.where( (CP_data>40) | (np.isnan(CP_data)) , 0).where( (CP_data==0) | (np.isnan(CP_data)), 1)
# Binary CP forecast: 1 where CP exceeds the threshold of its lead time, 0
# otherwise; cells where CP itself is NaN are set back to NaN.
isCP_data = xr.where(CP_data>CP_thrs, 1, 0).where(~np.isnan(CP_data))
isCP_data.name = 'isCP'
# persist() returns a new object, so assign it back; a bare call would compute
# the data for display only and leave isCP_data lazy.
isCP_data = isCP_data.persist()
isCP_data

Just to double-check: this should be pretty close to 178.

In [None]:
# For the first seven lead times: average the event flag over start time and
# member, then add up all grid cells (one value per lead time).
isCP_week1 = isCP_data.isel(L=slice(0, 7))
isCP_data_SMmean_XYsum = isCP_week1.mean(dim=['S', 'M']).sum(dim=['lat', 'lon']).values
isCP_data_SMmean_XYsum

In [None]:
# Quick-look map of the mean binary CP forecast over the first seven leads.
isCP_data.isel(L=slice(0, 7)).mean(dim=['S', 'M', 'L']).plot()

## Plotting

In [None]:
import cartopy.crs as ccrs
import matplotlib.ticker as mticker
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
def plot_map(figsize,data,cmap,vmin=None,vmax=None,title=None,unit=None):
    """Draw ``data`` as a lon/lat pseudocolour map over a fixed CONUS window.

    Parameters
    ----------
    figsize : tuple
        Figure size forwarded to xarray's plotting routine.
    data : xarray.DataArray
        Field to draw; it is plotted with ``lon`` on x and ``lat`` on y.
    cmap : matplotlib colormap
        Colormap used for the mesh.
    vmin, vmax : float, optional
        Colour-scale limits.
    title : str, optional
        Axes title.
    unit : str, optional
        When non-empty, used as the colourbar label.  When ``None`` or an
        empty string, xarray's default colourbar label is kept.

    Notes
    -----
    The large font size is applied through ``plt.rc_context`` so the previous
    rcParams are restored on exit, even if plotting raises, instead of being
    overwritten with a hard-coded value.
    """
    xlim = (-125,-65)
    ylim = (25,50)

    # Only override the colourbar label when a unit was actually supplied.
    cbar_kwargs = {'label': unit} if unit else None

    with plt.rc_context({'font.size': 48}):
        # x/y are passed by keyword; positional use is not accepted by every
        # xarray version.
        xr.plot.pcolormesh(data,
                           x="lon",
                           y="lat",
                           figsize=figsize,
                           xlim=xlim,
                           ylim=ylim,
                           cmap=cmap,
                           vmin=vmin,
                           vmax=vmax,
                           add_colorbar=True,
                           cbar_kwargs=cbar_kwargs,
                          )

        plt.title(title)
        # Drop the automatic coordinate-name axis labels.
        plt.xlabel('')
        plt.ylabel('')

## Figure 1

### Fig 1a

In [None]:
# Fig 1a field: time-mean daily stroke count inside the CONUS mask.
data1a = F_data.mean('Time').where(mask.notnull()).persist()

In [None]:
# Plot settings for Fig 1a.
figsize    = (48, 16)
cmap       = plt.get_cmap('jet')
vmin, vmax = 0, 200
title      = 'Daily Avg Number of Strokes'
unit       = ''

plot_map(figsize=figsize, data=data1a, cmap=cmap, vmin=vmin, vmax=vmax, title=title, unit=unit)
#plt.savefig('TK18_Fig1a.png')

### Fig 1b

In [None]:
# Fig 1b field: CP averaged over start time, member and the first seven lead
# times, inside the CONUS mask.
cp_week1_mean = CP_data.isel(L=slice(0, 7)).mean(dim=['S', 'M', 'L'])
data1b = cp_week1_mean.where(mask.notnull()).persist()

In [None]:
# Plot settings for Fig 1b.
figsize    = (48, 16)
cmap       = plt.get_cmap('jet')
vmin, vmax = 0, 200
title      = 'Daily Avg CP'
unit       = ''

plot_map(figsize=figsize, data=data1b, cmap=cmap, vmin=vmin, vmax=vmax, title=title, unit=unit)
#plt.savefig('TK18_Fig1b.png')

### Fig 1c

In [None]:
 _, index = np.unique(FCT_coor['FCT'], return_index=True)
fig1cd_coor = FCT_coor.isel(FCT=index).sel(FCT=slice("2011-01-01", "2011-12-31"))
fig1cd_coor

In [None]:
# Observed daily stroke total over the whole domain for 2011.
year_2011 = slice("2011-01-01", "2011-12-31")
F_data_1c = F_data.sel(Time=year_2011).sum(dim=['lat', 'lon']).persist()
F_data_1c.name = 'F_data_1c'
F_data_1c

In [None]:
# Domain-total CP for the first seven lead times, averaged over members and
# re-indexed by forecast valid time (start time + lead time), limited to 2011.
cp_week1 = CP_data.isel(L=slice(0, 7))
CP_data_1c = xr.DataArray(
    data=cp_week1.stack(FCT=('S', 'L')).sum(dim=['lat', 'lon']).mean('M'),
    dims=['FCT'],
    coords=dict(
        FCT=(['FCT'], (cp_week1.S + cp_week1.L).stack(FCT=('S', 'L'))),
    ),
    attrs=None,
).sel(FCT=slice("2011-01-01", "2011-12-31")).persist()
CP_data_1c.name = 'CP_data_1c'
CP_data_1c

In [None]:
# Fig 1c: domain-total observed strokes vs. GEFS CP through 2011.
figsize = (48,16)

# rc_context restores the previous rcParams on exit (also when plotting raises)
# instead of overwriting them with hard-coded values afterwards.
with plt.rc_context({'axes.linewidth': 4, 'font.size': 48}):
    fig = plt.figure(figsize=figsize)
    ax  = fig.add_subplot(111)

    ax.plot(fig1cd_coor, F_data_1c, label='WWLLN', color='gray', linewidth=4)
    ax.plot(fig1cd_coor, CP_data_1c, label='GEFS', color='blue', linewidth=4)

    # Explicit dates instead of the raw ordinals 14975/15340 (days since
    # 1970-01-01), which only mean 2011 under matplotlib's default date epoch.
    ax.set_xlim(np.datetime64('2011-01-01'), np.datetime64('2012-01-01'))
    ax.set_ylim(0,200000)

    plt.legend()
    #plt.savefig('TK18_Fig1c.png')

### Fig 1d

In [None]:
# Observed number of lightning-event cells per day over the domain for 2011.
year_2011 = slice("2011-01-01", "2011-12-31")
isL_data_1d = isL_data.sel(Time=year_2011).sum(dim=['lat', 'lon']).persist()
isL_data_1d

In [None]:
# Domain-total binary CP forecast for the first seven lead times, averaged over
# members and re-indexed by forecast valid time, limited to 2011.
iscp_week1 = isCP_data.isel(L=slice(0, 7))
isCP_data_1d = xr.DataArray(
    data=iscp_week1.stack(FCT=('S', 'L')).sum(dim=['lat', 'lon']).mean('M'),
    dims=['FCT'],
    coords=dict(
        FCT=(['FCT'], (iscp_week1.S + iscp_week1.L).stack(FCT=('S', 'L'))),
    ),
    attrs=None,
).sel(FCT=slice("2011-01-01", "2011-12-31")).persist()
isCP_data_1d.name = 'isCP_data_1d'
isCP_data_1d

In [None]:
# Fig 1d: domain-total observed vs. forecast lightning-event counts in 2011.
figsize = (48,16)

# rc_context restores the previous rcParams on exit (also when plotting raises)
# instead of overwriting them with hard-coded values afterwards.
with plt.rc_context({'font.size': 48, 'axes.linewidth': 4}):
    fig = plt.figure(figsize=figsize)
    ax  = fig.add_subplot(111)

    ax.plot(fig1cd_coor, isL_data_1d, label='WWLLN', color='gray', linewidth=4)
    ax.plot(fig1cd_coor, isCP_data_1d, label='GEFS', color='blue', linewidth=4)

    # Explicit dates instead of the raw ordinals 14975/15340 (days since
    # 1970-01-01), which only mean 2011 under matplotlib's default date epoch.
    ax.set_xlim(np.datetime64('2011-01-01'), np.datetime64('2012-01-01'))
    ax.set_ylim(0,600)

    plt.legend()
    #plt.savefig('TK18_Fig1d.png')

## Figure 2

In [None]:
# x-axis for the Fig 2 panels: lead-time numbers 1.0, 2.0, ..., 35.0.
fig2_coor = np.arange(1, 36, dtype=float)

### Fig 2a

In [None]:
"""
F_data_FCT = F_data.interp(Time=FCT_coor - np.timedelta64(12,'h'))
F_data_FCT = F_data_FCT.compute()
F_data_FCT
"""

In [None]:
"""
CP_data_FCT = xr.DataArray(
                          data=CP_data.stack(FCT=('S', 'L')).mean('M'),
                          dims=['lat','lon','FCT'],
                          coords=dict(
                                      lat=CP_data['lat'],
                                      lon=CP_data['lon'],
                                      FCT=(['FCT'], FCT_coor),
                                     ),
                          attrs=None,
                         ).compute()
CP_data_FCT
"""

In [None]:
"""
fig2a_TOT = np.zeros((35))
for iL in range(35):
    fig2a_TOT[iL] = xr.corr(
                            F_data_FCT.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).mean('FCT').compute().rank('grid'), 
                            CP_data_FCT.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).mean('FCT').compute().rank('grid'),  
                            dim='grid'
                           ).values
"""

In [None]:
"""
F_data_FCT_Manom  = F_data_FCT.groupby('FCT.month') - F_data_FCT.groupby('FCT.month').mean('FCT')
CP_data_FCT_Manom = CP_data_FCT.groupby('FCT.month') - CP_data_FCT.groupby('FCT.month').mean('FCT')
                                                                                                                                                    
fig2a_AMA = np.zeros((35))
for iL in range(35):
    fig2a_AMA[iL] = xr.corr(
                            F_data_FCT_Manom.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).compute().rank('grid'), 
                            CP_data_FCT_Manom.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).compute().rank('grid'),  
                            dim='grid'
                           ).mean('FCT').values
"""

In [None]:
"""
F_data_FCT_Danom  = F_data_FCT.groupby('FCT.day') - F_data_FCT.groupby('FCT.day').mean('FCT')
CP_data_FCT_Danom = CP_data_FCT.groupby('FCT.day') - CP_data_FCT.groupby('FCT.day').mean('FCT')
                                                                                                                                                    
fig2a_ADA = np.zeros((35))
for iL in range(35):
    fig2a_ADA[iL] = xr.corr(
                            F_data_FCT_Danom.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).compute().rank('grid'), 
                            CP_data_FCT_Danom.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).compute().rank('grid'),  
                            dim='grid'
                           ).mean('FCT').values
"""

In [None]:
#np.save('/home/disk/eos12/wycheng/data/metadata/fig2a_TOT.npy', fig2a_TOT)
#np.save('/home/disk/eos12/wycheng/data/metadata/fig2a_AMA.npy', fig2a_AMA)
#np.save('/home/disk/eos12/wycheng/data/metadata/fig2a_ADA.npy', fig2a_ADA)

# Load the three cached Fig 2a correlation curves (total, monthly-anomaly and
# daily-anomaly) from their .npy files.
fig2a_TOT, fig2a_AMA, fig2a_ADA = map(np.load, (
    '/home/disk/eos12/wycheng/data/metadata/fig2a_TOT.npy',
    '/home/disk/eos12/wycheng/data/metadata/fig2a_AMA.npy',
    '/home/disk/eos12/wycheng/data/metadata/fig2a_ADA.npy',
))

In [None]:
# Fig 2a: rank correlation between observed and forecast stroke counts as a
# function of lead time, for the total field and the two anomaly variants.
figsize = (24,16)

# rc_context restores the previous rcParams on exit (also when plotting raises)
# instead of overwriting them with hard-coded values afterwards.
with plt.rc_context({'font.size': 48, 'axes.linewidth': 4}):
    fig = plt.figure(figsize=figsize)
    ax  = fig.add_subplot(111)

    ax.plot(fig2_coor, fig2a_TOT, label='Total', color='blue', linewidth=4)
    ax.plot(fig2_coor, fig2a_AMA, label='Anomalies from Monthly Average', color='red', linewidth=4)
    ax.plot(fig2_coor, fig2a_ADA, label='Anomalies from Daily Average', color='orange', linewidth=4)

    ax.set_xlim(0,35)
    ax.set_ylim(-0.2,1)
    ax.set_xlabel('Forecast time')
    ax.set_ylabel('R')
    ax.set_title('Forecast rank correlation (Number of strokes)')

    plt.legend(loc='lower left',fontsize=36)
    # File name corrected: the disabled call used to point at the Fig 1d file.
    #plt.savefig('TK18_Fig2a.png')

### Fig 2b

In [None]:
#"""
# Observed event flag sampled at every forecast valid time shifted back by
# 12 hours, loaded into memory.
isL_data_FCT = isL_data.interp(Time=FCT_coor - np.timedelta64(12, 'h')).compute()
isL_data_FCT
#"""

In [None]:
#"""
# Ensemble-mean binary CP forecast with (S, L) flattened onto one FCT axis.
# FCT_coor is attached positionally, so it has to be ordered the same way as
# this ('S', 'L') stack.
ens_mean_by_fct = isCP_data.stack(FCT=('S', 'L')).mean('M')
isCP_data_FCT = xr.DataArray(
    data=ens_mean_by_fct,
    dims=['lat', 'lon', 'FCT'],
    coords=dict(
        lat=isCP_data['lat'],
        lon=isCP_data['lon'],
        FCT=(['FCT'], FCT_coor),
    ),
    attrs=None,
).compute()
isCP_data_FCT
#"""

In [None]:
#"""
# For each of the 35 lead times: rank both fields across the flattened grid,
# correlate the ranks over the grid for every forecast, then average over the
# forecasts.
fig2b_TOT = np.zeros((35))
for iL in range(35):
    # Every 35th element along FCT, starting at iL.
    same_lead = slice(iL, 12775, 35)
    obs_rank = isL_data_FCT.isel(FCT=same_lead).stack(grid=('lat','lon')).compute().rank('grid')
    fc_rank  = isCP_data_FCT.isel(FCT=same_lead).stack(grid=('lat','lon')).compute().rank('grid')
    fig2b_TOT[iL] = xr.corr(obs_rank, fc_rank, dim='grid').mean('FCT').values
#"""

In [None]:
"""
isL_data_FCT_Manom  = isL_data_FCT.groupby('FCT.month') - isL_data_FCT.groupby('FCT.month').mean('FCT')
isCP_data_FCT_Manom = isCP_data_FCT.groupby('FCT.month') - isCP_data_FCT.groupby('FCT.month').mean('FCT')
                                                                                                                                                    
fig2b_AMA = np.zeros((35))
for iL in range(35):
    fig2b_AMA[iL] = xr.corr(
                            isL_data_FCT_Manom.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).compute().rank('grid'), 
                            isCP_data_FCT_Manom.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).compute().rank('grid'),  
                            dim='grid'
                           ).mean('FCT').values
"""

In [None]:
"""
isL_data_FCT_Danom  = isL_data_FCT.groupby('FCT.day') - isL_data_FCT.groupby('FCT.day').mean('FCT')
isCP_data_FCT_Danom = isCP_data_FCT.groupby('FCT.day') - isCP_data_FCT.groupby('FCT.day').mean('FCT')
                                                                                                                                                    
fig2b_ADA = np.zeros((35))
for iL in range(35):
    fig2b_ADA[iL] = xr.corr(
                            isL_data_FCT_Danom.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).compute().rank('grid'), 
                            isCP_data_FCT_Danom.isel(FCT=slice(iL,12775,35)).stack(grid=('lat','lon')).compute().rank('grid'),  
                            dim='grid'
                           ).mean('FCT').values
"""

In [None]:
#np.save('/home/disk/eos12/wycheng/data/metadata/fig2b_TOT.npy', fig2b_TOT)
#np.save('/home/disk/eos12/wycheng/data/metadata/fig2b_AMA.npy', fig2b_AMA)
#np.save('/home/disk/eos12/wycheng/data/metadata/fig2b_ADA.npy', fig2b_ADA)

#fig2b_TOT = np.load('/home/disk/eos12/wycheng/data/metadata/fig2b_TOT.npy')
#fig2b_AMA = np.load('/home/disk/eos12/wycheng/data/metadata/fig2b_AMA.npy')
#fig2b_ADA = np.load('/home/disk/eos12/wycheng/data/metadata/fig2b_ADA.npy')

In [None]:
# Fig 2b: rank correlation between observed and forecast lightning events as a
# function of lead time.
figsize = (24,16)

# (notebook variable, legend label, line colour) for each curve.  The cells
# that compute or load fig2b_AMA / fig2b_ADA are currently disabled, so a curve
# is drawn only if its variable exists; a missing one is reported instead of
# aborting the cell with a NameError.
fig2b_series = [
    ('fig2b_TOT', 'Total', 'blue'),
    ('fig2b_AMA', 'Anomalies from Monthly Average', 'red'),
    ('fig2b_ADA', 'Anomalies from Daily Average', 'orange'),
]

# rc_context restores the previous rcParams on exit (also when plotting raises)
# instead of overwriting them with hard-coded values afterwards.
with plt.rc_context({'font.size': 48, 'axes.linewidth': 4}):
    fig = plt.figure(figsize=figsize)
    ax  = fig.add_subplot(111)

    for var_name, label, color in fig2b_series:
        if var_name not in globals():
            print(f'{var_name} is not defined; skipping this curve')
            continue
        ax.plot(fig2_coor, globals()[var_name], label=label, color=color, linewidth=4)

    ax.set_xlim(0,35)
    ax.set_ylim(-0.2,1)
    ax.set_xlabel('Forecast time')
    ax.set_ylabel('R')
    ax.set_title('Forecast rank correlation (Binary-class)')

    plt.legend(loc='lower left',fontsize=36)
    # File name corrected: the disabled call used to point at the Fig 1d file.
    #plt.savefig('TK18_Fig2b.png')

# Random Forest Classifier

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.model_selection import validation_curve
from sklearn.model_selection import learning_curve
from sklearn.model_selection import train_test_split
from imblearn.under_sampling import RandomUnderSampler

from sklearn.metrics import r2_score
from sklearn.metrics import accuracy_score, precision_score, average_precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.metrics import plot_roc_curve, plot_precision_recall_curve
from sklearn.metrics import roc_curve, precision_recall_curve
from sklearn.metrics import auc

In [None]:
# Compute the GEFS dataset once and keep it in memory so the per-lead-time
# loops below do not re-evaluate it on every iteration.
GEFS_dataset = GEFS_dataset.persist()
display(GEFS_dataset)

In [None]:
# Predictors and target used by the classifier.
feature_name  = ['cape','pr']
output_name   = ['isL']
# Random under-sampling of the majority class.  random_state is fixed so the
# sampled training set, and therefore every score derived from it, is
# reproducible between runs, matching the seeded classifier and train/test
# split.
undersample   = RandomUnderSampler(sampling_strategy=0.33, random_state=0)
# Small, seeded random forest (4 trees of depth 4, fitted on 4 workers).
rfclf = RandomForestClassifier(n_estimators=4,
                               max_depth=4,
                               random_state=0,
                               n_jobs=4,
                               verbose=0)

In [None]:
# Area under the ROC curve and under the precision-recall curve of the random
# forest, one value per lead time.
AUROCC_RFC = np.zeros((35))
AUPRC_RFC  = np.zeros((35))

for iL in range(35):

    print(iL)

    # Valid time of every forecast start at this lead; the observations are
    # sampled 12 hours before it.
    gefs_at_lead = GEFS_dataset.isel(L=iL)
    valid_time   = GEFS_dataset['S'] + gefs_at_lead['L']
    obs_time     = valid_time - np.timedelta64(12,'h')

    F_data_interp   = F_data.interp(Time=obs_time)
    isL_data_interp = isL_data.interp(Time=obs_time)

    # NOTE(review): the dimension names below are attached by position, so the
    # arrays are assumed to be ordered (start, [member,] lat, lon) and to share
    # one lat/lon grid -- confirm.
    ML_dataset = xr.Dataset(
        data_vars={
            'F':    (['FCT', 'lat', 'lon'], F_data_interp),
            'isL':  (['FCT', 'lat', 'lon'], isL_data_interp),
            'cape': (['FCT', 'M', 'lat', 'lon'], gefs_at_lead['cape']),
            'pr':   (['FCT', 'M', 'lat', 'lon'], gefs_at_lead['pr']),
        },
        coords={
            'FCT': (['FCT'], valid_time),
            'M':   (['M'], GEFS_dataset['M']),
            'lat': (['lat'], GEFS_dataset['lat']),
            'lon': (['lon'], GEFS_dataset['lon']),
        },
        attrs=None,
    )

    # Give every variable the full (FCT, M, lat, lon) shape, flatten to one row
    # per sample and drop the rows that contain a NaN.
    (ML_dataset,) = xr.broadcast(ML_dataset)
    dataframe = ML_dataset.to_dataframe().dropna(axis=0)

    X = dataframe[feature_name]
    y = dataframe[output_name]

    # Hold out a third of the samples; only the training part is under-sampled.
    X_train_raw, X_test, y_train_raw, y_test = train_test_split(
        X, y, test_size=0.33, random_state=0
    )
    X_train, y_train = undersample.fit_resample(X_train_raw, y_train_raw)

    y_predict_truth = y_test[output_name].values.ravel()

    rfclf.fit(X_train[feature_name], y_train[output_name].values.ravel())

    y_predict_rfclf = rfclf.predict(X_test[feature_name])

    # Probability of the positive class is the ranking score for both curves.
    y_score = rfclf.predict_proba(X_test[feature_name])[:,1]
    precision, recall, thresholds = precision_recall_curve(y_predict_truth, y_score)

    AUROCC_RFC[iL] = roc_auc_score(y_predict_truth, y_score)
    AUPRC_RFC[iL]  = auc(recall, precision)

In [None]:
# Same two skill scores for the CP baseline: the scaled CP value itself is used
# as the ranking score, one value per lead time.
AUROCC_R14 = np.zeros((35))
AUPRC_R14  = np.zeros((35))

for iL in range(35):

    print(iL)

    # Valid time of every forecast start at this lead; the observations are
    # sampled 12 hours before it.
    gefs_at_lead = GEFS_dataset.isel(L=iL)
    valid_time   = GEFS_dataset['S'] + gefs_at_lead['L']
    obs_time     = valid_time - np.timedelta64(12,'h')

    F_data_interp   = F_data.interp(Time=obs_time)
    isL_data_interp = isL_data.interp(Time=obs_time)

    # NOTE(review): the dimension names below are attached by position, so the
    # arrays are assumed to be ordered (start, [member,] lat, lon) and to share
    # one lat/lon grid -- confirm.
    R14_dataset = xr.Dataset(
        data_vars={
            'F':    (['FCT', 'lat', 'lon'], F_data_interp),
            'isL':  (['FCT', 'lat', 'lon'], isL_data_interp),
            'cape': (['FCT', 'M', 'lat', 'lon'], gefs_at_lead['cape']),
            'pr':   (['FCT', 'M', 'lat', 'lon'], gefs_at_lead['pr']),
            'CP':   (['FCT', 'M', 'lat', 'lon'], CP_data.isel(L=iL)),
            'isCP': (['FCT', 'M', 'lat', 'lon'], isCP_data.isel(L=iL)),
        },
        coords={
            'FCT': (['FCT'], valid_time),
            'M':   (['M'], GEFS_dataset['M']),
            'lat': (['lat'], GEFS_dataset['lat']),
            'lon': (['lon'], GEFS_dataset['lon']),
        },
        attrs=None,
    )

    # Give every variable the full (FCT, M, lat, lon) shape, flatten to one row
    # per sample and drop the rows that contain a NaN.
    (R14_dataset,) = xr.broadcast(R14_dataset)
    dataframe = R14_dataset.to_dataframe().dropna(axis=0)

    X = dataframe[:]
    y = dataframe[output_name]

    # Same split settings as the random-forest loop; nothing is fitted here, so
    # only the held-out part is used.
    X_train_raw, X_test, y_train_raw, y_test = train_test_split(
        X, y, test_size=0.33, random_state=0
    )

    y_predict_truth = y_test[output_name].values.ravel()
    y_predict_r14   = X_test['isCP']

    y_score = X_test['CP']
    precision, recall, thresholds = precision_recall_curve(y_predict_truth, y_score)

    AUROCC_R14[iL] = roc_auc_score(y_predict_truth, y_score)
    AUPRC_R14[iL]  = auc(recall, precision)

In [None]:
#markers = ['.','v','s','p','*','x','d']
#colors  = 
# Skill scatter: one point per lead time, coloured by lead number (1..35).
# Circles are the random forest, crosses the CP baseline.
t = np.arange(1,36,1)
fig, ax = plt.subplots()

shared_style = dict(s=20, c=t, cmap='jet')
sc1 = ax.scatter(AUPRC_RFC, AUROCC_RFC, marker='o', **shared_style)
sc2 = ax.scatter(AUPRC_R14, AUROCC_R14, marker='x', **shared_style)

plt.colorbar(sc1)

ax.set_title('Model skill')
ax.set_xlabel('Area under PR curve')
ax.set_ylabel('Area under ROC curve')
#ax.set_xlim([0.40,0.55])
#ax.set_ylim([0.75,0.95])
#ax.legend(loc='best')

#divider = make_axes_locatable(ax)
#cax = divider.append_axes('right', size='5%', pad=0.05)
#fig.colorbar(sc, cax=cax, orientation='vertical')