In [None]:
import numpy as np
import scipy as sp
import dask
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import time

# Data Processing

## Read WWLLN (World Wide Lightning Location Network) stroke data

In [None]:
# Lazily open every WWLLN_20*.nc file as one dask-backed dataset;
# parallel=True performs the per-file open step concurrently via dask.delayed.
wwlln_files = '/home/disk/eos12/wycheng/data/WWLLN/Global/WWLLN_20*.nc'
wwlln_dataset = xr.open_mfdataset(wwlln_files, parallel=True)

In [None]:
# Display the combined WWLLN dataset (variables, coordinates, chunking) for inspection.
wwlln_dataset

In [None]:
# Sum the WWLLN `F` variable within 1-day bins along the Time axis.
Fdata = (
    wwlln_dataset['F']
    .resample(Time='1D')
    .sum()
)

In [None]:
# Display the daily-summed F array (still lazy).
Fdata

In [None]:
# Quick-look map of the time-mean daily F. The former `[:,:,:]` full slice was a no-op that
# only hard-coded a 3-D rank. The trailing ';' suppresses the returned artist's repr.
Fdata.mean(dim='Time').plot();

## Read GEFS (Global Ensemble Forecast System) data and convert longitudes to [-180, 180)

In [None]:
# Lazily open all GEFS files and rename the X/Y dimensions to lon/lat.
gefs_dataset = (
    xr.open_mfdataset('/home/disk/eos12/wycheng/data/GEFS/GEFS*.nc', parallel=True)
    .rename({'X': 'lon', 'Y': 'lat'})
)
# Wrap longitudes into [-180, 180). Wrapping a 0..360 axis this way leaves it
# non-monotonic (0..179 followed by -180..-1), which breaks label-based slicing and
# map plotting, so re-sort along lon. If lon is already in [-180, 180) both steps are no-ops.
gefs_dataset = (
    gefs_dataset
    .assign_coords(lon=(((gefs_dataset.lon + 180) % 360) - 180))
    .sortby('lon')
)

In [None]:
# Display the GEFS dataset after the lon/lat rename and longitude wrap.
gefs_dataset

# Create CP data (CAPE × precipitation)

In [None]:
def multiply(a, b):
    """Return the element-wise product of ``a`` and ``b`` via ``xr.apply_ufunc``.

    ``dask="parallelized"`` lets the product be applied blockwise when the
    inputs are dask-backed, so the result stays lazy.
    """
    def _product(x, y):
        return x * y

    return xr.apply_ufunc(_product, a, b, dask="parallelized")

In [None]:
# CP = CAPE x precipitation, evaluated element by element.
cp_data = multiply(a=gefs_dataset['cape'], b=gefs_dataset['pr'])

In [None]:
# Display the lazy CP array.
cp_data

# TK18: Fig 1

## Generate cp_data_SMLmean dataset (CP averaged over the S, M and L dimensions, then regridded to a 1° lat/lon grid)

In [None]:
# Average CP over the S, M and L dimensions
# (presumably start date, ensemble member and lead time — TODO confirm).
cp_data_SMLmean = cp_data.mean(dim=['S', 'M', 'L'])

In [None]:
# Display the lazy S/M/L mean.
cp_data_SMLmean

In [None]:
# Compute the S/M/L mean once and keep its chunks in memory for the interp/write cells.
# persist() returns a NEW object rather than modifying in place; the previous bare call
# discarded it, so the work was thrown away and recomputed downstream.
start = time.time()
cp_data_SMLmean = cp_data_SMLmean.persist()
end = time.time()
print(end - start)

In [None]:
# Re-display cp_data_SMLmean after the timing cell to inspect its state.
cp_data_SMLmean

In [None]:
# Target 1-degree grid: lon -179.5..179.5 (360 points), lat -89.5..89.5 (180 points).
lono = xr.DataArray(np.linspace(-179.5, 179.5, 360), dims='lon')
lato = xr.DataArray(np.linspace(-89.5, 89.5, 180), dims='lat')

# Allow dask to split oversized chunks created by slicing during interpolation.
# NOTE: this rebinds cp_data_SMLmean, so re-running the cell re-interpolates the regridded field.
with dask.config.set({'array.slicing.split_large_chunks': True}):
    cp_data_SMLmean = cp_data_SMLmean.interp(
        lon=lono,
        lat=lato,
        method='linear',
    )

In [None]:
# Write the regridded field to netCDF (overwriting any existing file) and print the
# elapsed wall-clock seconds.
cp_data_SMLmean_outfile = '/home/disk/eos12/wycheng/data/metadata/cp_data_SMLmean.nc'
start = time.time()
cp_data_SMLmean.to_netcdf(cp_data_SMLmean_outfile, mode='w')
end = time.time()
elapsed = end - start
print(elapsed)

## Read cp_data_SMLmean dataset

In [None]:
# Reload the cached S/M/L-mean field.
# NOTE(review): the write cell targets /home/disk/eos12/wycheng/data/metadata/...,
# but this reads from /home/wei/data/metadata/... — confirm both refer to the same file.
cp_data_SMLmean_infile = '/home/wei/data/metadata/cp_data_SMLmean.nc'
cp_data_SMLmean = xr.open_dataarray(cp_data_SMLmean_infile)

In [None]:
# Bring the reloaded field into memory before plotting. persist() returns a new object,
# so the former bare call's result was discarded. It also only acts on dask arrays, and
# open_dataarray without `chunks` does not create dask arrays. load() materialises the data
# in either case; its result is rebound for clarity.
cp_data_SMLmean = cp_data_SMLmean.load()
cp_data_SMLmean

In [None]:
# Time how long a quick-look plot of the reloaded field takes (wall-clock seconds).
start = time.time()
cp_data_SMLmean.plot()
end = time.time()
elapsed = end - start
print(elapsed)

## Plotting

In [None]:
import cartopy.crs as ccrs
import regionmask
import geopandas as gpd

import matplotlib.ticker as mticker
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
def plot_map(figsize,data,cmap,vmin=None,vmax=None,title=None,unit=None):
    """Draw a lon/lat pcolormesh of ``data`` cropped to a fixed CONUS window.

    Parameters
    ----------
    figsize : tuple
        Figure size forwarded to ``xr.plot.pcolormesh``.
    data : xarray.DataArray
        Field with ``lon`` and ``lat`` coordinates.
    cmap : matplotlib colormap
        Colormap for the mesh.
    vmin, vmax : float, optional
        Colour-scale limits.
    title : str, optional
        Title placed on the current axes.
    unit : str, optional
        Colour-bar label. When None or empty, xarray's default label is kept.

    Returns
    -------
    The artist returned by ``xr.plot.pcolormesh``.
    """
    # NOTE: global rcParams change; it persists for every later figure in the session.
    plt.rcParams.update({'font.size': 48})

    xlim = (-125, -65)
    ylim = (25, 50)

    # Only override the colour-bar label when a unit was actually given.
    cbar_kwargs = {'label': unit} if unit else None

    pcm = xr.plot.pcolormesh(
        data,
        x="lon",
        y="lat",
        figsize=figsize,
        xlim=xlim,
        ylim=ylim,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        add_colorbar=True,
        cbar_kwargs=cbar_kwargs,
    )

    plt.title(title)
    # Drop the automatic lon/lat axis labels.
    plt.xlabel('')
    plt.ylabel('')
    return pcm

In [None]:
def plot_timeseries(figsize,data,cmap,vmin=None,vmax=None,title=None,unit=None):
    """Draw ``data`` as a line plot on a new figure.

    Previously this was a stub that created an empty figure and ignored all
    other arguments.

    Parameters
    ----------
    figsize : tuple
        Figure size.
    data : xarray.DataArray
        Series to plot; presumably 1-D along Time — TODO confirm against callers.
    cmap : object
        Unused. Accepted only so the signature mirrors ``plot_map``.
    vmin, vmax : float, optional
        y-axis limits. A None bound keeps matplotlib's autoscaled value.
    title : str, optional
        Axes title. When None, the title xarray sets is kept.
    unit : str, optional
        y-axis label. When None or empty, xarray's default label is kept.

    Returns
    -------
    (fig, ax) : the new matplotlib Figure and Axes.
    """
    fig, ax = plt.subplots(figsize=figsize)
    data.plot.line(ax=ax, linewidth=4)
    if vmin is not None or vmax is not None:
        ax.set_ylim(vmin, vmax)
    if title is not None:
        ax.set_title(title)
    if unit:
        ax.set_ylabel(unit)
    return fig, ax


### Set CONUS filter

In [None]:
# Country boundary polygons; the CNTRY_NAME column supplies both region names and abbrevs.
PATH_TO_SHAPEFILE = '/home/wei/data/WorldCountriesBoundaries/99bfd9e7-bb42-4728-87b5-07f8c8ac631c2020328-1-1vef4ev.lu5nk.shp'
countries = gpd.read_file(PATH_TO_SHAPEFILE)

# Region number i is row i of the shapefile. Numbers/names used to come from a hard-coded
# range(250) while outlines came from every row, so the lists disagreed (or names raised
# KeyError) whenever the file did not have exactly 250 rows. The first 250 numbers are unchanged.
indexes = list(range(len(countries)))
# NOTE(review): newer regionmask releases replace Regions_cls with regionmask.Regions —
# confirm against the installed version.
countries_mask_poly = regionmask.Regions_cls(
    name='COUNTRY',
    numbers=indexes,
    names=countries.CNTRY_NAME[indexes],
    abbrevs=countries.CNTRY_NAME[indexes],
    outlines=list(countries.geometry.values),
)

In [None]:
# Rasterise the country regions onto the lat/lon grid of the first WWLLN time step.
mask = countries_mask_poly.mask(Fdata.isel(Time=0), lat_name='lat', lon_name='lon')

# Keep only region 232 (presumably the United States row of the shapefile — TODO confirm)
# inside the contiguous-US bounding box; every other cell becomes NaN.
in_conus_box = (
    (mask.lat > 24.74) & (mask.lat < 49.35)
    & (mask.lon > -124.78) & (mask.lon < -66.95)
)
mask = mask.where((mask == 232) & in_conus_box)

## Figure 1a

In [None]:
# Time-mean daily stroke count, NaN outside the CONUS mask.
data1a = (
    Fdata.mean(dim='Time')
    .where(mask.notnull())
    .persist()
)

In [None]:
# Figure 1a: map of the time-mean daily stroke count over CONUS.
figsize, cmap = (48, 16), plt.get_cmap('jet')
vmin, vmax = 0, 200
title = 'Daily Avg Number of Strokes'
unit = ''

plot_map(figsize, data=data1a, cmap=cmap, vmin=vmin, vmax=vmax, title=title, unit=unit)
plt.savefig('TK18_Fig1a.png')

## Figure 1b

In [None]:
# CP field restricted to the CONUS mask (assumes both share the same 1-degree grid — TODO confirm).
data1b = (
    cp_data_SMLmean
    .where(mask.notnull())
    .persist()
)

In [None]:
# Figure 1b: map of the averaged CP over CONUS.
figsize, cmap = (48, 16), plt.get_cmap('jet')
vmin, vmax = 0, 0.04
title = 'Daily Avg CP'
unit = ''

plot_map(figsize, data=data1b, cmap=cmap, vmin=vmin, vmax=vmax, title=title, unit=unit)
plt.savefig('TK18_Fig1b.png')

## Figure 1c

In [None]:
# Daily CONUS-total stroke count for Time indices 365-729 (presumably one year of the daily
# record — TODO confirm). Slicing Time before the spatial sum avoids reducing and persisting
# every other day only to discard it; the resulting values are unchanged.
data1c = (
    Fdata.isel(Time=slice(365, 730))
    .where(~np.isnan(mask))
    .sum(dim={'lat', 'lon'})
    .persist()
)

In [None]:
# Figure 1c: daily CONUS-total stroke count as a line plot.
figsize = (48, 16)

data1c.plot.line(figsize=figsize, linewidth=4)
plt.savefig('TK18_Fig1c.png')

## Figure 1d

In [None]:
# Per time step, the number of CONUS grid cells with at least one stroke.
data1d = (
    Fdata
    .where((Fdata > 0) & mask.notnull())
    .count(dim={'lat', 'lon'})
    .persist()
)

In [None]:
# Figure 1d: count of CONUS cells with strokes, for Time indices 365-729.
figsize = (48, 16)

data1d_window = data1d.isel(Time=slice(365, 730))
data1d_window.plot.line(figsize=figsize, linewidth=4)
plt.savefig('TK18_Fig1d.png')