In [None]:
import numpy as np
import scipy as sp
import dask
import pandas as pd
import xarray as xr
import matplotlib
from cycler import cycler
import matplotlib.pyplot as plt
import time

# Data Processing

## Read WWLLN data

- Read in the WWLLN data 
- Variable(s):
    - F (Lightning flash rate): The number of lightning strokes observed by WWLLN in each grid cell(# of strokes / grid / 3 hr).

In [None]:
# Open every WWLLN file matching the glob as a single chunked dataset.
wwlln_dataset = xr.open_mfdataset(
    '/home/disk/eos12/wycheng/data/WWLLN/Global/WWLLN_20*.nc',
    chunks=dict(Time='auto', lat='auto', lon='auto'),
    parallel=True,
)

- Select the CONUS area

In [None]:
# Keep only the CONUS bounding box (lon -125..-65, lat 20..50).
wwlln_dataset = wwlln_dataset.sel(
    lon=slice(-125, -65),
    lat=slice(20, 50),
)

- Change the temporal resolution from 3 hours to 1 day 
    - Method: Sum up all observed strokes in 1 day

In [None]:
# Sum the stroke counts of each calendar day into one daily value.
Fdata = wwlln_dataset['F'].resample(Time='1D').sum()

- Persist the daily F data (as a dask-backed array) so later cells can reuse it without recomputing

In [None]:
# `persist()` returns a NEW object backed by the computed chunks; it does not
# modify `Fdata` in place. Rebind the name so later cells reuse the persisted
# data instead of re-running the whole resample graph each time.
Fdata = Fdata.persist()
Fdata

In [None]:
# Quick-look map of the time-mean daily stroke count.
Fdata.mean('Time').plot()

## Read GEFS data

- Read in the Hindcast dataset from GEFS model - CAPE: convective available potential energy (J/kg); PR: precipitation (mm)
- Rename the coordinates from ('X', 'Y') to ('lon', 'lat')
- The dimensions:
    - S: Start Time (forecast_reference_time): ordered from (0000 6 Jan 2010) to (0000 28 Dec 2016) by 7 (days)
    - M: Ensemble Member (realization): ordered from (0) to (10) by 1.0
    - L: Forecast Lead Time (forecast_period): ordered from (0.5 days) to (34.5 days) by 1.0 (days)
    - lon: The longitude; Notice that the range of this coordinate is from (0) to (360)
    - lat: The latitude; Notice that the order of this coordinate is from (90) to (-90)

In [None]:
# Open all GEFS hindcast files as one chunked dataset, then rename the
# horizontal dimensions to the names used by the WWLLN data.
gefs_dataset = xr.open_mfdataset(
    '/home/disk/eos12/wycheng/data/GEFS/GEFS*.nc',
    chunks=dict(S='auto', M='auto', L='auto', X='auto', Y='auto'),
    parallel=True,
)
gefs_dataset = gefs_dataset.rename({'X': 'lon', 'Y': 'lat'})

- Select the CONUS area

In [None]:
# Longitude is still on the 0..360 convention and latitude is ordered
# north-to-south at this point, hence 235..295 and the descending lat slice.
gefs_dataset = gefs_dataset.sel(
    lon=slice(235, 295),
    lat=slice(50, 20),
)

In [None]:
gefs_dataset

- Reassign the longitude coordinate from (0, 360) to (-180, 180)
- Reverse the latitude coordinate from (50, 20) to (20, 50)

In [None]:
with dask.config.set({'array.slicing.split_large_chunks': True}):
    gefs_dataset = (
        gefs_dataset
        # Wrap longitudes from the 0..360 convention to -180..180.
        .assign_coords(lon=((gefs_dataset.lon + 180) % 360) - 180)
        # Re-order latitude so it runs south-to-north.
        .reindex(lat=gefs_dataset.lat[::-1])
    )

In [None]:
gefs_dataset

- Interpolate the data from integer grid points to half-degree grid points to match the F data from WWLLN

In [None]:
# Target grid: half-degree cell centres spaced 1 degree apart
# (60 longitudes from -124.5 to -65.5, 30 latitudes from 20.5 to 49.5).
lono = xr.DataArray(np.linspace(-124.5, -65.5, 60), dims='lon')
lato = xr.DataArray(np.linspace(20.5, 49.5, 30), dims='lat')

with dask.config.set({'array.slicing.split_large_chunks': True}):
    gefs_dataset = gefs_dataset.interp(
        lon=lono,
        lat=lato,
        method='linear',
    )

In [None]:
gefs_dataset

# Create CP data

In [None]:
# CP proxy: element-wise product of CAPE and precipitation.
cp_data = gefs_dataset['cape'] * gefs_dataset['pr']

In [None]:
cp_data

# TK18: Fig 1

## Generate daily avg CP map

In [None]:
# Average over start time, ensemble member and lead time, leaving a (lat, lon) map.
cp_data_SMLmean = cp_data.mean(dim=['S', 'M', 'L'])

In [None]:
cp_data_SMLmean

In [None]:
# Time how long it takes to materialise the S/M/L mean.
# `persist()` returns a new object and leaves the original lazy, so the result
# must be rebound; otherwise the work measured here is thrown away and the
# following `to_netcdf` / plotting cells recompute the whole graph.
start = time.time()
cp_data_SMLmean = cp_data_SMLmean.persist()
end = time.time()
print(end - start)

In [None]:
# Cache the S/M/L-mean CP map on disk (overwriting any existing file) and
# report the wall-clock seconds the write took.
start = time.time()
cp_data_SMLmean.to_netcdf('/home/disk/eos12/wycheng/data/metadata/cp_data_SMLmean.nc',mode='w')
end = time.time()
print(end - start)

## Generate CP forecast data

- Turn the Start time and Lead time coordinates ('S', 'L') to new Forecast time coordinate (fct)

In [None]:
# Forecast valid time = start time + lead time, flattened onto one `fct` axis
# (the stacked MultiIndex is dropped so `fct` becomes a plain dimension).
new_coor = (
    (cp_data['S'] + cp_data['L'])
    .stack(fct=('S', 'L'))
    .reset_index('fct', drop=True)
)

In [None]:
new_coor

In [None]:
# Flatten (S, L) into a single `fct` dimension, dropping the stacked index.
cp_forecast = (
    cp_data
    .stack(fct=('S', 'L'))
    .reset_index('fct', drop=True)
)

In [None]:
cp_forecast

In [None]:
# Label the flattened axis with the forecast valid times computed above.
cp_forecast = cp_forecast.assign_coords({'fct': new_coor})

In [None]:
cp_forecast

In [None]:
# Average over ensemble members and the whole spatial domain -> 1-D series in `fct`.
cp_forecast_MXYmean = cp_forecast.mean(dim=['M', 'lat', 'lon'])

In [None]:
cp_forecast_MXYmean

In [None]:
# Several (start, lead) pairs map to the same valid time; average those
# duplicates so each `fct` value appears once.
# The original line referenced `cp_forecast_mean` (never defined) and
# `test_mean` (only created in the Testing section further down); the series
# built in the previous cell is `cp_forecast_MXYmean`.
cp_forecast_MXYTmean = cp_forecast_MXYmean.groupby('fct').mean(dim='fct')

In [None]:
cp_forecast_MXYTmean

In [None]:
# Restrict the CP series to valid times within calendar year 2011.
cp_data1c = cp_forecast_MXYTmean.sel(
    fct=slice("2011-01-01", "2011-12-31"),
)

In [None]:
# `persist()` returns a new object; rebind so the computed 2011 series is
# actually reused by later cells instead of being recomputed.
cp_data1c = cp_data1c.persist()
cp_data1c

In [None]:
# Daily stroke total over the masked region, restricted to calendar year 2011.
# NOTE(review): `mask` is only created in the "Set CONUS filter" cell further
# down, so this cell raises NameError on a fresh Restart & Run All unless that
# cell has been run first.
F_data1c = Fdata.where( ~np.isnan(mask) ).sum(dim={'lat','lon'}).sel(Time=slice("2011-01-01", "2011-12-31")).persist()

In [None]:
F_data1c

In [None]:
# Add a length-1 leading `y` dimension to the WWLLN series.
da1 = F_data1c.expand_dims(dim='y')

In [None]:
# Add a length-1 leading `y` dimension to the GEFS CP series.
da2 = cp_data1c.expand_dims(dim='y')

## Read cp_data_SMLmean dataset

In [None]:
#cp_data_SMLmean = xr.open_dataarray('/home/wei/data/metadata/cp_data_SMLmean.nc').persist()

## Plotting

In [None]:
import cartopy.crs as ccrs
import regionmask
import geopandas as gpd

import matplotlib.ticker as mticker
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
def plot_map(figsize,data,cmap,vmin=None,vmax=None,title=None,unit=None):
    """Draw a lon/lat pseudocolor map of ``data`` on a new figure.

    Parameters
    ----------
    figsize : tuple
        Figure size forwarded to ``xr.plot.pcolormesh``.
    data : xarray.DataArray
        2-D field with ``lon`` and ``lat`` coordinates.
    cmap
        Colormap forwarded to ``xr.plot.pcolormesh``.
    vmin, vmax : optional
        Colour-scale limits forwarded to ``xr.plot.pcolormesh``.
    title : str, optional
        Axes title (passed straight to ``plt.title``).
    unit : str, optional
        Colorbar label. When ``None`` or empty, xarray's default colorbar
        label is left in place, which is what happened before this argument
        was honoured.

    Returns
    -------
    The object returned by ``xr.plot.pcolormesh``, so callers can customise
    the plot further. Existing callers that ignore the return value are
    unaffected.
    """
    # NOTE: this changes the global matplotlib font size for every later plot
    # in the session, not only this figure.
    plt.rcParams.update({'font.size': 48})

    # Fixed display window: the axes are cropped to this lon/lat box.
    xlim = (-125, -65)
    ylim = (25, 50)

    # Only override the colorbar label when a non-empty unit was supplied.
    cbar_kwargs = {'label': unit} if unit else None

    pcm = xr.plot.pcolormesh(data, "lon", "lat",
                             figsize=figsize,
                             xlim=xlim,
                             ylim=ylim,
                             cmap=cmap,
                             vmin=vmin,
                             vmax=vmax,
                             add_colorbar=True,
                             cbar_kwargs=cbar_kwargs,
                            )

    plt.title(title)
    # Blank out the axis labels that xarray adds automatically.
    plt.xlabel('')
    plt.ylabel('')

    return pcm

### Set CONUS filter

In [None]:
PATH_TO_SHAPEFILE = '/home/disk/eos12/wycheng/data/WorldCountriesBoundaries/99bfd9e7-bb42-4728-87b5-07f8c8ac631c2020328-1-1vef4ev.lu5nk.shp'
countries = gpd.read_file(PATH_TO_SHAPEFILE)

# Region numbers 0..249.
# NOTE(review): 250 is hard-coded, presumably the number of features in this
# shapefile; confirm it matches `countries.shape[0]`.
indexes = list(range(250))

# Build one regionmask region per country polygon, using the country name as
# both the long name and the abbreviation.
countries_mask_poly = regionmask.Regions(
    name='COUNTRY',
    numbers=indexes,
    names=countries.CNTRY_NAME[indexes],
    abbrevs=countries.CNTRY_NAME[indexes],
    outlines=[countries.geometry.values[i] for i in range(countries.shape[0])],
)

In [None]:
# Rasterise the country polygons onto the WWLLN grid, using one time step of
# Fdata only for its lat/lon coordinates.
mask = countries_mask_poly.mask(
    Fdata.isel(Time=0),
    lat_name='lat',
    lon_name='lon',
)

# Keep only cells of region 232 that fall inside the lat/lon bounding box;
# everything else becomes NaN.
# NOTE(review): 232 is presumably the United States in this shapefile; confirm.
mask = mask.where(
    (mask == 232)
    & (mask.lat > 24.74) & (mask.lat < 49.35)
    & (mask.lon > -124.78) & (mask.lon < -66.95)
)

In [None]:
mask

## Figure 1a

In [None]:
# Time-mean daily stroke count, blanked (NaN) outside the mask.
data1a = (
    Fdata
    .mean(dim='Time')
    .where(mask.notnull())
    .persist()
)

In [None]:
# Figure 1a: map of the mean daily stroke count.
figsize = (48, 16)
cmap = plt.get_cmap('jet')
vmin, vmax = 0, 200
title = 'Daily Avg Number of Strokes'
unit = ''

plot_map(figsize, data1a, cmap, vmin=vmin, vmax=vmax, title=title, unit=unit)
plt.savefig('TK18_Fig1a.png')

## Figure 1b

In [None]:
# Mean CP map, blanked (NaN) outside the mask.
data1b = cp_data_SMLmean.where(mask.notnull()).persist()

In [None]:
# Figure 1b: map of the mean CP.
figsize = (48, 16)
cmap = plt.get_cmap('jet')
vmin, vmax = 0, 0.04
title = 'Daily Avg CP'
unit = ''

plot_map(
    figsize,
    data=data1b,
    cmap=cmap,
    vmin=vmin,
    vmax=vmax,
    title=title,
    unit=unit,
)
plt.savefig('TK18_Fig1b.png')

## Figure 1c

In [None]:
# Domain-total daily strokes inside the mask for time indices 365..729.
# NOTE(review): presumably calendar year 2011 if the record starts on
# 2010-01-01; confirm against the Time coordinate.
data1c = (
    Fdata
    .where(mask.notnull())
    .sum(dim=['lat', 'lon'])
    .isel(Time=slice(365, 730))
    .persist()
)

In [None]:
# Figure 1c: time series of the domain-total daily stroke count.
figsize = (48, 16)

data1c.plot.line(figsize=figsize, linewidth=4, color='gray')
plt.savefig('TK18_Fig1c.png')

## Figure 1d

In [None]:
# Number of masked grid cells with at least one stroke per day, for time
# indices 365..729 (the same window as data1c).
data1d = (
    Fdata
    .where((Fdata > 0) & mask.notnull())
    .count(dim=['lat', 'lon'])
    .isel(Time=slice(365, 730))
    .persist()
)

In [None]:
# Figure 1d: time series of the daily count of cells with lightning.
figsize = (48, 16)

data1d.plot.line(figsize=figsize, linewidth=4, color='gray')
plt.savefig('TK18_Fig1d.png')

# Testing

In [None]:
# Take the Time coordinate of the 1c series and reset its index, dropping it.
new_time_coor = data1c.Time.reset_index('Time', drop=True)

In [None]:
new_time_coor

In [None]:
# Give both series a length-1 `y` dimension so they can be concatenated along it.
da1 = data1c.expand_dims('y')
# The original line ended in a dangling `test_1c.` (a syntax error); it is
# completed here to mirror `da1`.
# NOTE(review): `test_1c` is only created in a cell further down, so that cell
# must be run first. Its dimension is `fct` rather than `Time`, so it may
# still need renaming/re-labelling before the concat below lines up; confirm.
da2 = test_1c.expand_dims('y')

In [None]:
test_1c.reset_index('fct',drop=True)

In [None]:
# Stack the two series along `y`.
test = xr.concat([da1, da2], dim='y')

In [None]:
test

In [None]:
# Line colours for this and all later plots: grey first, then blue, then green.
matplotlib.rcParams['axes.prop_cycle'] = matplotlib.cycler(color=["grey", "b", "g"])

# One line per entry along `y`.
test.plot.line(hue='y', figsize=figsize, linewidth=4)
plt.legend(['WWLLN', 'GEFS'])

In [None]:
cp_data

In [None]:
# Forecast valid time (start + lead) flattened onto a single `fct` axis.
new_coor = (
    (cp_data['S'] + cp_data['L'])
    .stack(fct=('S', 'L'))
    .reset_index('fct', drop=True)
)

In [None]:
new_coor

In [None]:
# Flatten (S, L) into one `fct` dimension, dropping the stacked index.
test = (
    cp_data
    .stack(fct=('S', 'L'))
    .reset_index('fct', drop=True)
)

In [None]:
test

In [None]:
test.assign_coords(fct=new_coor)

In [None]:
# Average over ensemble members and space, then label `fct` with the valid times.
test_mean = (
    test
    .mean(dim=['M', 'lat', 'lon'])
    .assign_coords(fct=new_coor)
)

In [None]:
test_mean

In [None]:
# Average all forecasts that share the same valid time.
test_1c = test_mean.groupby('fct').mean(dim='fct')

In [None]:
# Keep positions 360..724 along `fct`.
# NOTE(review): presumably chosen to line up with the 2011 window of data1c; confirm.
test_1c = test_1c.isel(
    fct=slice(360, 725),
)

In [None]:
data1c