In [1]:
import sys,os,importlib,gc, re, string
import pickle
import xarray as xr
import numpy as np
import pandas as pd

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# All relative paths used below ('./', '../scripts', 'tracks/...') are resolved against this directory.
# NOTE(review): hard-coded absolute path - the notebook only runs on a machine with this layout.
os.chdir('/home/peter/Projects/tc_emulator/results')

# Make the project's script folder importable and reload the module so that
# edits to it are picked up without restarting the kernel.
sys.path.append('../scripts')
import _weather_pattern_class; importlib.reload(_weather_pattern_class)

# Build the weather-pattern object from the 'ERA5_VWS-MSLP_noTC3' input and restrict it to 1982-2020.
atl = _weather_pattern_class.weather_patterns(source='ERA5', working_directory='./')
atl.load_input('ERA5_VWS-MSLP_noTC3')
years = np.array(range(1982,2021))
atl.set_split(years=years)
# Grid of 5 x 4 nodes; the clustering tag evaluates to 'SOM_pcaInit5x4_v1'.
nrows,ncols = 5,4
tag = 'SOM_pcaInit%sx%s_v1' % (nrows,ncols)
atl.define_plot_environment(pre_mapping='mapping_raw', clustering=tag, post_mapping='mapping_sammon_1982-2020', nrows=nrows, ncols=ncols)
# Presumably loads (or reuses, since overwrite=False) the TC statistics that fill atl._tracks - confirm in the class.
atl.stats_TC(file='tracks/tracks_ibtracks.csv', overwrite=False)

{'SOM': <minisom.MiniSom object at 0x7fb35768dfd0>}
.//ERA5_VWS-MSLP_noTC3/mapping_raw_1982-2020/SOM_pcaInit5x4_v1/mapping_sammon_1982-2020/grid_5x4


In [2]:
years = np.arange(1982,2021,1)
# get sst
# NOTE(review): the file name advertises 1981-2019 while `years` runs to 2020 - confirm
# that the file really covers 2020, otherwise the date lookups below cannot succeed.
sst_MDR = xr.load_dataset('/home/peter/Projects/data/SST/OISST_sst_MDR_1981-2019_daily.nc')['sst']
sst_MDR = sst_MDR[np.isin(sst_MDR.time.dt.year,years)]
sst_MDR = sst_MDR[np.isin(sst_MDR.time.dt.month,atl._months['mon'])]
# truncate the time stamps to whole days so that they can be matched with track / weather dates
sst_MDR = sst_MDR.assign_coords(time=np.array([str(d)[:10] for d in sst_MDR.time.values], np.datetime64))

# prepare tracks:
# here ssts are added. this will be needed in the wind component
# .copy() gives an independent frame, so the column assignments below do not
# operate on a slice of the original table (avoids SettingWithCopyWarning).
atl._tracks = atl._tracks.loc[np.isin(atl._tracks.year,years)].copy()
# day-truncated time stamps of the track points; computed once and reused
times = np.array([str(d)[:10] for d in atl._tracks.time.values], np.datetime64)
atl._tracks['time'] = times
atl._tracks['sst'] = sst_MDR.loc[times].values
atl._tracks['weather_0'] = atl._tracks['label_lag0']
# keep only track points with a weather label and a positive `distance`
tracks = atl._tracks.loc[np.isfinite(atl._tracks.weather_0)]
tracks = tracks.loc[tracks.distance > 0, ['weather_0','sst','wind','genesis','storm','ACE','year','storm_day','wind_before','month']]

# prepare genesis input
# this is a dataframe with an entry for each day
# this is required to get genesis probabilities
weather_sst = pd.DataFrame()
weather_sst['time'] =  np.array([str(d)[:10] for d in  atl._vector_time.values], np.datetime64)
weather_sst['year'] = atl._vector_time.dt.year
weather_sst['weather_0'] = atl._clust_labels
# lagged labels: np.roll wraps around at the start of the record; the first
# three days of every season are dropped at the end of this cell.
weather_sst['weather_1'] = np.roll(atl._clust_labels,1)
weather_sst['weather_2'] = np.roll(atl._clust_labels,2)
weather_sst['weather_3'] = np.roll(atl._clust_labels,3)
weather_sst = weather_sst.loc[np.isin(atl._vector_time.dt.year,years)].copy()

genesis = weather_sst.copy()
# number of storm formations per day (sum of the `genesis` column over all track points of that day)
genesis['genesis'] = [atl._tracks.loc[atl._tracks.time==np.datetime64(tt),'genesis'].sum() for tt in genesis.time]
genesis['sst'] = sst_MDR.sel(time=weather_sst.time.values)

weather_sst['sst'] = sst_MDR.sel(time=weather_sst.time.values)

# running day counter that restarts at 0 in every year
genesis['day_in_season'] = 0
weather_sst['day_in_season'] = 0
for year in np.unique(weather_sst.time.dt.year):
    in_year = weather_sst.time.dt.year==year
    day_index = np.arange(in_year.sum())
    weather_sst.loc[in_year,'day_in_season'] = day_index
    genesis.loc[(genesis.time.dt.year==year),'day_in_season'] = day_index

# drop the first three days of each season: their lagged weather labels are not valid
weather_sst = weather_sst.loc[(weather_sst.day_in_season>=3) & np.isin(weather_sst.year,years)]
genesis = genesis.loc[(genesis.day_in_season>=3) & np.isin(genesis.year,years)]

In [28]:
# Number of distinct values in the `storm` column of the filtered track table.
np.unique(tracks.storm).shape
# tracks.genesis.sum()

(454,)

In [3]:
# Train/test split by decades: one column per hold-out period, in which the
# years of that period are marked 'test' and all remaining years 'train'.
train_test = pd.DataFrame()
train_test['year'] = list(range(1982,2021))
for first_year, last_year in [(1982,1990), (1991,2000), (2001,2010), (2011,2020)]:
    period = '%s-%s' % (first_year, last_year)
    train_test[period] = 'train'
    held_out = train_test.year.between(first_year, last_year)
    train_test.loc[held_out, period] = 'test'

In [4]:
# Human-readable labels used when naming emulator variants in figures.
# Keys are either component identifiers or the three component kinds.
comp_names = {
    # formation components (an empty label is used by the main configuration)
    'gWeaLag2Weight' : '',
    'gWeaLag2' : 'equal weight',
    'gWea' : 'no lag',
    'gnnWeaSST' : 'NN weather + SST',
    # duration components
    'sLWeaNeigh' : '',
    'sLWea' : 'no neighbors',
    'sLAll' : 'random',
    # intensification components
    'wS100nnQrSST' : '',
    'wS100nn' : '100 nn',
    'wS50nn' : '50 nn',
    'wS20nn' : '20 nn',
    'wS100nnNoSST' : 'no SST',
    'wS100nnQrSSTnoHist' : 'no history',
    'wS100nnQrSSTnoWeather' : 'no weather',
    # component kinds
    'g' : 'formation',
    'sL' : 'duration',
    'wS': 'intensification'
}

def siggi(s):
    """Return the significance marker for a p-value.

    Parameters
    ----------
    s : float
        p-value of a statistical test; may be NaN.

    Returns
    -------
    str
        '*' if `s` is a number below 0.1, otherwise the empty string
        (this includes NaN).
    """
    is_significant = (not np.isnan(s)) and s < 0.1
    return '*' if is_significant else ''

def nicer_plot(fig, ax, out_file, ylim=None, upper_left='', upper_right='', edgeC='w', text=''):
    """Annotate a finished figure, fix its y-range and save it.

    Parameters
    ----------
    fig : figure object
        Figure that is framed and written to `out_file`.
    ax : axes object
        Axis that receives the annotations and the y-limits.
    out_file : str
        Path of the image file to write.
    ylim : tuple, optional
        (lower, upper) y-limits. If None the current limits of `ax` are kept
        (they are set explicitly, which freezes them).
    upper_left : str
        Bold label placed in the upper left corner of the figure (e.g. a panel letter).
    upper_right : str
        Bold label placed in the upper right corner of the figure, drawn in colour `edgeC`.
    edgeC : color
        Colour of the upper-right label and of the figure frame in the saved file.
    text : str
        Free text placed in the upper left corner of the axis.
    """
    ax.annotate(upper_left, xy=(0.03,0.95), xycoords='figure fraction', ha='left', va='top', fontweight='bold', fontsize=12, backgroundcolor='w')
    ax.annotate(upper_right, xy=(0.97,0.97), xycoords='figure fraction', ha='right', va='top', fontweight='bold', fontsize=12, color=edgeC, backgroundcolor='w')
    ax.annotate(text, xy=(0.03,0.97), xycoords='axes fraction', ha='left', va='top', backgroundcolor='none')
    ax.set_ylim(ax.get_ylim() if ylim is None else ylim)
    # Work on the figure that was passed in rather than on pyplot's "current"
    # figure: the two differ e.g. when `fig` was restored from a pickle or when
    # another figure has been created in the meantime.
    fig.patch.set_linewidth(3)
    fig.savefig(out_file, bbox_inches='tight', dpi=200, edgecolor=edgeC)


In [7]:
# choose components
# Independent letter counters: `alphabet` numbers all variants, the other two
# are consumed separately when duration / wind figures are labelled.
alphabet = iter(string.ascii_uppercase)
alphabet_sL = iter(string.ascii_uppercase)
alphabet_wS = iter(string.ascii_uppercase)
version = iter(range(1,100))

# Each entry describes one emulator variant: the component used for formation
# ('g'), duration ('sL') and intensification ('wS'), plus labels and a colour.
# The first entry is the main configuration; every other entry swaps exactly
# one component.
comps_todo = [
    {
        'g': 'gWeaLag2Weight',
        'sL': 'sLWeaNeigh',
        'wS': 'wS100nnQrSST',
        'Emu': 'Emu0',
        'name': 'main',
        'l': next(alphabet),
        'c': 'c',
        'v': 'main',
        'vc': '',
    },
]
# alternative formation components
for i,g in enumerate(['gWea', 'gWeaLag2', 'gnnWeaSST']):
    comps_todo.append({
        'g': g,
        'sL': 'sLWeaNeigh',
        'wS': 'wS100nnQrSST',
        'Emu': 'Emu0',
        'name': 'formation: ' + comp_names[g],
        'l': next(alphabet),
        'c': 'm',
        'v': 'v' + str(next(version)),
        'vc': 'vG' + str(i),
    })
# alternative duration components
for i,sL in enumerate(['sLAll','sLWea']):
    comps_todo.append({
        'g': 'gWeaLag2Weight',
        'sL': sL,
        'wS': 'wS100nnQrSST',
        'Emu': 'Emu0',
        'name': 'duration: ' + comp_names[sL],
        'l': next(alphabet),
        'c': 'orange',
        'v': 'v' + str(next(version)),
        'vc': 'vD' + str(i),
    })
# alternative intensification components
# (other list that was tried: ['wS20nn','wS50nn','wS100nn'][::-1])
for i,wS in enumerate(['wS100nn','wS20nn','wS100nnNoSST','wS100nnQrSSTnoWeather','wS100nnQrSSTnoHist']):
    comps_todo.append({
        'g': 'gWeaLag2Weight',
        'sL': 'sLWeaNeigh',
        'wS': wS,
        'Emu': 'Emu0',
        'name': 'intensification: ' + comp_names[wS],
        'l': next(alphabet),
        'c': 'r',
        'v': 'v' + str(next(version)),
        'vc': 'vI' + str(i),
    })

# number of simulations per season, cache policy and result container
N = 1000
overwrite = False
validations = {}
# For every emulator variant: fit (or load) the three components for each
# hold-out decade, emulate the held-out years, then collect all emulated
# seasons and run the validation plots / statistics.
for dt in comps_todo:
    # identifier of the variant, e.g. 'gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0'
    tag = '_'.join([dt[k] for k in ['g','sL','wS','Emu']])
    print(tag)
    # NOTE(review): the star-imports below presumably provide genesis_pred,
    # storm_length_estimator, wind_estimator, storm_emulator and
    # emulate_season_function - none of them is defined in this notebook.
    # The import order (g, sL, wS, Emu) decides which definition wins.
    import _emulator; importlib.reload(_emulator); from _emulator import *
    for k,v in {k:v for k,v in dt.items() if k in ['g','sL','wS','Emu']}.items():
        exec("import %s; importlib.reload(%s); from %s import *" % tuple(['components.'+k+'.'+v]*3))
    for test_period in [tt for tt in train_test.columns if tt != 'year']:
        train_years = train_test.loc[train_test[test_period]=='train', 'year'].values
        test_years = train_test.loc[train_test[test_period]=='test', 'year'].values
        train_folder = atl._dir_lvl4 + '/emulator/' + str(test_period)+'/'
        # genesis
        # (the pickles read below are caches written by this notebook; do not point these paths at untrusted files)
        comp_file = train_folder+'/_comp_g_'+dt['g']+'/genesis_obj.pkl'
        if os.path.isfile(comp_file) and not overwrite:
            with open(comp_file, 'rb') as infile:
                genesis_obj = pickle.load(infile)
        else:
            genesis_obj = genesis_pred(dir=train_folder+'/_comp_g_'+dt['g']+'/', df=genesis.loc[np.isin(genesis.time.dt.year,train_years)])
            genesis_obj.fit(atl)
            genesis_obj.save()
            # print(genesis_obj._probs)
        # stormLength
        comp_file = train_folder+'/_comp_sL_'+dt['sL']+'/end_obj.pkl'
        if os.path.isfile(comp_file) and not overwrite:
            with open(comp_file, 'rb') as infile:
                stormL_obj = pickle.load(infile)
        else:
            stormL_obj = storm_length_estimator(dir=train_folder+'/_comp_sL_'+dt['sL']+'/', atl=atl, tracks=tracks.loc[np.isin(tracks.year,train_years)])
            stormL_obj.save()
            stormL_obj.plot_simulated_storm_length(atl=atl, tracks=tracks.loc[np.isin(tracks.year,train_years)])
        # windSpeed
        comp_file = train_folder+'/_comp_wS_'+dt['wS']+'/wind_obj.pkl'
        if os.path.isfile(comp_file) and not overwrite:
            with open(comp_file, 'rb') as infile:
                wind_obj = pickle.load(infile)
        else:
            wind_obj = wind_estimator(dir=train_folder+'/_comp_wS_'+dt['wS']+'/', df=tracks.loc[np.isin(tracks.year,train_years)])
            wind_obj.get_analogue_pdfs(atl=atl)
            wind_obj.load_pdfs()
            wind_obj.save()
        exec("import %s; importlib.reload(%s); from %s import *" % tuple(['components.wS._helping_functions']*3))
        quantiles, wind_quR_params = sst_vs_wind_quantile_regression(tracks.loc[np.isin(tracks.year,train_years)], plot_dir=train_folder+'/_comp_wS_'+dt['wS']+'/', sst_var='sst')
        # print(wind_obj._lr)
        # wind_obj.plot_pdfs()
        emu = storm_emulator(dir=train_folder, tag=tag, emulate_season_function=emulate_season_function)
        # day-truncated time stamps, consistent with weather_sst / sst_MDR
        atl._vector_time.values = np.array([str(d)[:10] for d in atl._vector_time.values], np.datetime64)
        # emu.prepare_input(atl, sst_tropics, sst_MDR_rel, years, fileName = atl._dir_lvl4 + '/emulator/weather_sst_input.csv', overwrite=overwrite)
        emu._weather_sst = weather_sst
        emu.emulate_seasons_serial(genesis_obj, wind_obj, stormL_obj, years=test_years, N=N, overwrite=overwrite)

    # collect the held-out simulations of all decades in one cross-validation emulator
    emu = storm_emulator(dir=atl._dir_lvl4 + '/emulator/xValid/', tag=tag, emulate_season_function=None)
    emu._seasons = {}
    for test_period in [tt for tt in train_test.columns if tt != 'year']:
        test_years = train_test.loc[train_test[test_period]=='test', 'year'].values
        for test_year in test_years:
            with open(atl._dir_lvl4 + '/emulator/'+test_period+'/'+tag+'/sim/'+tag+'_'+str(test_year)+'_N'+str(N)+'.pkl', 'rb') as infile:
                emu._seasons[test_year] = pickle.load(infile)
    emu._N = N
    emu._weather_sst = weather_sst
    # emu.get_simu_tracks(overwrite=True)
    # emu.get_other_stats_for_tracks(tracks)
    emu.get_stats_seasonal_simu(overwrite=False)
    emu.get_stats_seasonal_obs(tracks, train_test.year.values)
    emu._sets = [{'years':[years], 'label':'xValid', 'color':'c'}]
    emu._indicator_dict = {
        'genesis' : 'storm formations',
        'storm_days' : 'storm days in season',
        # 'wind' used to be listed twice; the first entry ('max. daily wind speed')
        # was silently overwritten, so only the effective entry is kept.
        'wind' : 'acc. daily max. wind speeds',
        'ACE' : 'ACE',
        'Hur' : 'hurricanes',
        'MajHur' : 'major hurricanes',
        'stoMaxWind' : 'max. intensity of storm [kts]',
        'stoLen' : 'storm duration',
        'stoD' : 'day of storm',
        'dWind' : 'change in storm intensity [kts]',
        'wind_before' : 'intensity on the day before',
        'sst' : 'SST',
    }
    if tag == 'gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0':
        # fig 3
        axes_ = []
        fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8,5), sharex=True)
        # NOTE(review): five indicators are zipped with four axes, so the last one ('Hur', 'E') is never drawn.
        # NOTE(review): with ax= given the method is used as if it returned the axis only, while further
        # below (no ax=) its result is unpacked as (fig, ax, out_file) - confirm in the emulator class.
        for indicator,letter,ylim,ax in zip(['genesis','storm_days','MajHur','ACE', 'Hur'],['A','B','C','D','E'],[(0,30),(0,200),(0,10),(0,350),(0,16)], axes.flatten()):
            ax = emu.vali_year_to_year_variability(indicator, show_legend=False, ax=ax)
            ax.annotate(letter, xy=(0.05,0.95), xycoords='axes fraction', ha='left', va='top', fontweight='bold', fontsize=12, backgroundcolor='w')
            corr_ = 'corr: %.2f%s' %(emu._validation[indicator]['pearson_median']['coef'],siggi(emu._validation[indicator]['pearson_median']['pval']))
            ax.annotate(corr_, xy=(0.95,0.95), xycoords='axes fraction', ha='right', va='top', fontweight='bold', fontsize=12, backgroundcolor='w')
        plt.tight_layout()
        plt.savefig(emu._dir_plot+'fig3.png', dpi=300)


    if dt['sL'] != 'sLWeaNeigh' or tag == 'gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0':
        letter_ = next(alphabet_sL)

        # one list of daily wind speeds per observed storm; the builtin float replaces
        # the np.float alias, which no longer exists in NumPy >= 1.24
        storms = [[float(w) for w in tracks.loc[(tracks.storm==storm), 'wind']] for storm in np.unique(tracks.loc[np.isin(tracks.year,years),'storm'])]
        # one entry per simulated storm, flattened over all years and simulations
        seasons = [winds for season in emu._seasons.values() for storms in season for winds in storms.values()]

        # HIST #
        out_file = emu._dir_plot+'hist_duration_N'+str(emu._N)+'.png'
        # `and False` disables the cached figure on purpose: the histogram is always recomputed
        if os.path.isfile(out_file.replace('.png','.pkl')) and False:
            with open(out_file.replace('.png','.pkl'),'rb') as infile:
                fig, ax = pickle.load(infile)
        else:
            obs = np.array([len(winds) for winds in storms])
            simu = np.array([len(winds) for winds in seasons])
            fig, ax, out_file = emu.vali_distr_tests(obs=obs, simu=simu, out_file=out_file, indicator='duration', bins=np.arange(0,30,1))
            with open(out_file.replace('.png','.pkl'), 'wb') as outfile:
                pickle.dump((fig,ax), outfile)
        ax.set_xlabel('storm duration [days]')
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), upper_left=letter_, upper_right=dt['name'], edgeC=dt['c'])


    version_text = '\n'.join(['%s: %s' %(comp_names[k],comp_names[dt[k]]) for k in ['g','sL','wS']])
    for indicator,ylim in zip(['genesis','storm_days','Hur','MajHur','ACE'],[(0,35),(0,220),(0,18),(0,12),(0,360)]):
        # Year to Year and correlation #
        fig,ax,out_file = emu.vali_year_to_year_variability(indicator, show_legend=False)
        vali = emu._validation[indicator]
        text = 'pearson corr.: %s%s' %(vali['pearson_mean']['coef'].round(2), siggi(vali['pearson_mean']['pval']))
        text += '\nspearman corr.: %s%s' %(vali['spearman_mean']['coef'].round(2), siggi(vali['spearman_mean']['pval']))
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, text=text)
    for indicator,ylim in zip(['genesis','storm_days','MajHur','ACE', 'Hur'],[(-16,16),(-120,120),(-7,7),(-240,240),(-10,10)]):
        # trend in residuals #
        fig,ax,out_file = emu.vali_residuals_and_long_term_trend(indicator)
        vali = emu._validation[indicator]
        text = 'MannKendall: %s%s' %(vali['MK_mean']['trend'], siggi(vali['MK_mean']['pval']))
        text += '\nLinear trend: %s%s' %(vali['trend_mean']['slope'].round(2), siggi(vali['trend_mean']['pval']))
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, text=text)

        # RMSD
        emu.vali_RMSD(indicator)

    # residuals vs SST
    for indicator,ylim in zip(['genesis','storm_days','MajHur','ACE', 'Hur'],[(-16,16),(-120,120),(-7,7),(-240,240),(-10,10)]):
        fig,ax,out_file = emu.vali_residuals_IND_vs_SST(indicator, sst_MDR.groupby('time.year').mean('time'))
        vali = emu._validation[indicator]
        text = 'MannKendall: %s%s' %(vali['MK_vs_SST_median']['trend'], siggi(vali['MK_vs_SST_median']['pval']))
        text += '\nLinear regression: %s%s' %(vali['trend_vs_SST_median']['slope'].round(2), siggi(vali['trend_vs_SST_median']['pval']))
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, text=text)

    if dt['wS'] != 'wS100nnQrSST' or tag == 'gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0':
        #####################
        # deviations in KNN #
        #####################
        # NOTE(review): `train_folder` still holds the value of the last hold-out period of the
        # loop above, i.e. only that period's distances are plotted - confirm this is intended.
        for var,ylim,xlim in zip(['wind_before','sst'],[(-100,40),(-0.5,0.5)],[(0,160),(27,29)]):
            plt.close('all')
            fig,ax = plt.subplots(nrows=1, figsize=(4,3))
            dists = xr.open_dataset(train_folder+'/_comp_wS_'+dt['wS']+'/distances.nc')['distances']
            if var in dists.dims:
                for weather in [15,6,1,12]:
                    y = dists.sel({'q':[17,50,83],'d':var}).squeeze()
                    if 'weather_0' in dists.dims:
                        y = y.sel({'weather_0':weather})
                    # average over every remaining dimension except `var` and the quantile
                    for k in [k for k in y.dims if k not in [var,'q']]:
                        y = y.mean(k)
                    ax.fill_between(dists[var], y.loc[:,17], y.loc[:,83], alpha=0.5)
                    ax.plot(dists[var], y.loc[:,50], label='w%s' %(weather))
                ax.set_ylabel('bias in \n'+emu._indicator_dict[var])
                ax.set_xlabel(emu._indicator_dict[var])
                ax.legend()
                ax.set_xlim(xlim)
                out_file = train_folder+'/_comp_wS_'+dt['wS']+'/dist_weather_'+var+'.png'
                nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim)
                # nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, upper_left=letter_, upper_right=dt['vc'], edgeC=dt['c'])

    # keep the validation statistics of this variant for later inspection
    validations[tag] = emu._validation

In [None]:
# Validation statistics of the main configuration for the 'MajHur' (major hurricanes) indicator.
validations['gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0']['MajHur']

In [None]:
# Validation statistics of the variant with the 'wS100nnNoSST' wind component for the 'ACE' indicator.
validations['gWeaLag2Weight_sLWeaNeigh_wS100nnNoSST_Emu0']['ACE']

## old


In [None]:
    # # residuals vs ACE 
    # fig,ax,out_file = emu.vali_residuals_IND_vs_X('ACE', emu._obs['ACE'], 'ACE')
    # nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), upper_left=dt['l'], upper_right=dt['name'], edgeC=dt['c'])
    # if dt['wS'] != 'wS100nnQrSST' or tag == 'gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0':
    #     #####################
    #     # deviations in KNN #
    #     #####################
    #     for var,ylim,xlim in zip(['wind_before','sst'],[(-100,40),(-0.5,0.5)],[(0,160),(27,29)]):
    #         plt.close('all')
    #         fig,ax = plt.subplots(nrows=1, figsize=(4,3))
    #         dists = xr.open_dataset(train_folder+'/_comp_wS_'+dt['wS']+'/distances.nc')['distances']
    #         if var in dists.dims:
    #             for weather in [0,12,6,10,15]:
    #                 y = dists.sel({'q':[17,50,83],'d':var}).squeeze()
    #                 if 'weather_0' in dists.dims:
    #                     y = y.sel({'weather_0':weather})
    #                 for k in [k for k in y.dims if k not in [var,'q']]:
    #                     y = y.mean(k)
    #                 ax.fill_between(dists[var], y.loc[:,17], y.loc[:,83], alpha=0.5)
    #                 ax.plot(dists[var], y.loc[:,50], label='w%s' %(weather))
    #             ax.set_ylabel('bias in '+emu._indicator_dict[var])
    #             ax.set_xlabel(emu._indicator_dict[var])
    #             ax.legend()
    #             ax.set_xlim(xlim)
    #             out_file = train_folder+'/_comp_wS_'+dt['wS']+'/dist_weather_'+var+'.png'
    #             nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, upper_left=dt['l'], upper_right=dt['name'], edgeC='k')
    #             # nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, upper_left=letter_, upper_right=dt['vc'], edgeC=dt['c'])
    
    # # residuals vs SSTs  #
    # for indicator,lim in zip(['wind','ACE','Hur','MajHur','genesis'],[(-6000,6000),(-200,200),(-10,10),(-6,6),(-15,15)]):
    #     fig,ax,out_file = emu.vali_residuals_IND_vs_SST(indicator, sst_MDR.groupby('time.year').mean('time'))
    #     nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), upper_left=dt['l'], upper_right=dt['name'], edgeC=dt['c'])

    
    
    # NOTE(review): this cell is a fragment of a loop body (it relies on `dt`, `tag` and `emu`
    # of the enclosing `for dt in comps_todo:` loop) and cannot be executed on its own.
    print(emu._validation['ACE']['trend_median'], emu._validation['ACE']['trend_mean'])

    if dt['wS'] != 'wS100nnQrSST' or tag == 'gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0':
        print('wind plots', dt['wS'])
        letter_ = next(alphabet_wS)

        # one list of daily wind speeds per observed storm; the builtin float replaces
        # the np.float alias, which no longer exists in NumPy >= 1.24
        storms = [[float(w) for w in tracks.loc[(tracks.storm==storm), 'wind']] for storm in np.unique(tracks.loc[np.isin(tracks.year,years),'storm'])]
        # one entry per simulated storm, flattened over all years and simulations
        seasons = [winds for season in emu._seasons.values() for storms in season for winds in storms.values()]

        # max. wind vs storm length #stoMaxWind_vs_stoLen_N1000_sens.png
        # NOTE(review): the file name starts with 'hsto...' while the comment above says 'sto...' - typo?
        out_file = emu._dir_plot+'hstoMaxWind_vs_stoLen_N%s.png' %(emu._N)
        if os.path.isfile(out_file.replace('.png','.pkl')):
            with open(out_file.replace('.png','.pkl'),'rb') as infile:
                fig, ax, im_simu, im_obs = pickle.load(infile)
        else:
            stoLen = np.array([len(winds) for winds in storms])
            maxWind = np.array([max(winds) for winds in storms])
            obs = pd.DataFrame(np.vstack((stoLen,maxWind)).T, columns=('stoLen','stoMaxWind'))
            stoLen = np.array([len(winds) for winds in seasons])
            maxWind = np.array([max(winds) for winds in seasons])
            simu = pd.DataFrame(np.vstack((stoLen,maxWind)).T)
            fig, ax, im_simu, im_obs, _ = emu.vali_2D_distr_plot(obs,simu, ranges=[0,14,0,120], bw_method=0.2, nBins=50)
            with open(out_file.replace('.png','.pkl'), 'wb') as outfile:
                pickle.dump((fig,ax,im_simu,im_obs), outfile)
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), upper_left=letter_, upper_right=dt['name'], edgeC=dt['c'])
        # fig, ax, out_file = emu.vali_2D_distr_KL_divergence(obs,simu, ranges=[np.arange(0.5,14.5,1),np.arange(15,115,10)])

        # daily wind speed vs storm day #
        out_file = emu._dir_plot+'stoD_vs_wind_N%s.png' %(emu._N)
        if os.path.isfile(out_file.replace('.png','.pkl')):
            with open(out_file.replace('.png','.pkl'),'rb') as infile:
                fig, ax, im_simu, im_obs = pickle.load(infile)
        else:
            stoDay = np.array([i for winds in storms for i in range(len(winds))])
            wind = np.array([w for winds in storms for w in winds]).flatten()
            obs = pd.DataFrame(np.vstack((stoDay,wind)).T, columns=('stoD','wind'))
            stoDay = np.array([i for winds in seasons for i in range(len(winds))]).flatten()
            wind = np.array([w for winds in seasons for w in winds]).flatten()
            simu = pd.DataFrame(np.vstack((stoDay,wind)).T)
            fig, ax, im_simu, im_obs, _ = emu.vali_2D_distr_plot(obs,simu, ranges=[0,10,0,100], bw_method=0.2, nBins=50)
            with open(out_file.replace('.png','.pkl'), 'wb') as outfile:
                pickle.dump((fig,ax,im_simu,im_obs), outfile)
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), upper_left=letter_, upper_right=dt['name'], edgeC=dt['c'])
        # fig, ax, out_file = emu.vali_2D_distr_KL_divergence(obs,simu, ranges=[np.arange(0.5,10.5,1),np.arange(5,105,10)])

        # daily change in wind speed vs storm day #
        out_file = emu._dir_plot+'stoD_vs_dWind_N%s.png' %(emu._N)
        if os.path.isfile(out_file.replace('.png','.pkl')):
            with open(out_file.replace('.png','.pkl'),'rb') as infile:
                fig, ax, im_simu, im_obs = pickle.load(infile)
        else:
            # Change relative to the previous day. The former expression
            # `w-winds[i-1]` with `enumerate(winds[1:])` subtracted the LAST value of
            # the storm for the first step and a two-day-old value afterwards.
            stoDay = np.array([i+2 for winds in storms for i in range(len(winds)-1)])
            wind = np.array([after-before for winds in storms for before,after in zip(winds[:-1], winds[1:])]).flatten()
            obs = pd.DataFrame(np.vstack((stoDay,wind)).T, columns=('stoD','dWind'))
            stoDay = np.array([i+2 for winds in seasons for i in range(len(winds)-1)]).flatten()
            wind = np.array([after-before for winds in seasons for before,after in zip(winds[:-1], winds[1:])]).flatten()
            simu = pd.DataFrame(np.vstack((stoDay,wind)).T)
            fig, ax, im_simu, im_obs, _ = emu.vali_2D_distr_plot(obs,simu, ranges=[0,10,-50,50], bw_method=0.2, nBins=50)
            with open(out_file.replace('.png','.pkl'), 'wb') as outfile:
                pickle.dump((fig,ax,im_simu,im_obs), outfile)
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), upper_left=letter_, upper_right=dt['name'], edgeC=dt['c'])
        # fig, ax, out_file = emu.vali_2D_distr_KL_divergence(obs,simu, ranges=[np.arange(0.5,10.5,1),np.arange(-55,65,10)])

        # # outliers #
        # fig,ax,out_file = emu.vali_outliers(indicator)
        # nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), upper_left=dt['l'], ylim=(0,111), edgeC=dt['c'])


In [None]:
import sys,os,importlib,gc, re, string
import xarray as xr
import numpy as np
import pandas as pd

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# All relative paths used below ('./', '../scripts', 'tracks/...') are resolved against this directory.
# NOTE(review): hard-coded absolute path - the notebook only runs on a machine with this layout.
os.chdir('/home/peter/Projects/tc_emulator/results')

# Make the project's script folder importable and reload the module so that
# edits to it are picked up without restarting the kernel.
sys.path.append('../scripts')
import _weather_pattern_class; importlib.reload(_weather_pattern_class)

# Build the weather-pattern object from the 'ERA5_VWS-MSLP_noTC3' input and restrict it to 1982-2020.
atl = _weather_pattern_class.weather_patterns(source='ERA5', working_directory='./')
atl.load_input('ERA5_VWS-MSLP_noTC3')
years = np.array(range(1982,2021))
atl.set_split(years=years)
# Grid of 5 x 4 nodes; the clustering tag evaluates to 'SOM_pcaInit5x4_v1'.
nrows,ncols = 5,4
tag = 'SOM_pcaInit%sx%s_v1' % (nrows,ncols)
atl.define_plot_environment(pre_mapping='mapping_raw', clustering=tag, post_mapping='mapping_sammon_1982-2020', nrows=nrows, ncols=ncols)
# Presumably loads (or reuses, since overwrite=False) the TC statistics that fill atl._tracks - confirm in the class.
atl.stats_TC(file='tracks/tracks_ibtracks.csv', overwrite=False)

years = np.arange(1982,2021,1)
# get sst
# NOTE(review): the file name advertises 1981-2019 while `years` runs to 2020 - confirm
# that the file really covers 2020, otherwise the date lookups below cannot succeed.
sst_MDR = xr.load_dataset('/home/peter/Projects/data/SST/OISST_sst_MDR_1981-2019_daily.nc')['sst']
sst_MDR = sst_MDR[np.isin(sst_MDR.time.dt.year,years)]
sst_MDR = sst_MDR[np.isin(sst_MDR.time.dt.month,atl._months['mon'])]
# truncate the time stamps to whole days so that they can be matched with track / weather dates
sst_MDR = sst_MDR.assign_coords(time=np.array([str(d)[:10] for d in sst_MDR.time.values], np.datetime64))

# prepare tracks:
# here ssts are added. this will be needed in the wind component
# .copy() gives an independent frame, so the column assignments below do not
# operate on a slice of the original table (avoids SettingWithCopyWarning).
atl._tracks = atl._tracks.loc[np.isin(atl._tracks.year,years)].copy()
# day-truncated time stamps of the track points; computed once and reused
times = np.array([str(d)[:10] for d in atl._tracks.time.values], np.datetime64)
atl._tracks['time'] = times
atl._tracks['sst'] = sst_MDR.loc[times].values
atl._tracks['weather_0'] = atl._tracks['label_lag0']
# keep only track points with a weather label and a positive `distance`
tracks = atl._tracks.loc[np.isfinite(atl._tracks.weather_0)]
tracks = tracks.loc[tracks.distance > 0, ['weather_0','sst','wind','genesis','storm','ACE','year','storm_day','wind_before','month']]

# prepare genesis input
# this is a dataframe with an entry for each day
# this is required to get genesis probabilities
weather_sst = pd.DataFrame()
weather_sst['time'] =  np.array([str(d)[:10] for d in  atl._vector_time.values], np.datetime64)
weather_sst['year'] = atl._vector_time.dt.year
weather_sst['weather_0'] = atl._clust_labels
# lagged labels: np.roll wraps around at the start of the record; the first
# three days of every season are dropped at the end of this section.
weather_sst['weather_1'] = np.roll(atl._clust_labels,1)
weather_sst['weather_2'] = np.roll(atl._clust_labels,2)
weather_sst['weather_3'] = np.roll(atl._clust_labels,3)
weather_sst = weather_sst.loc[np.isin(atl._vector_time.dt.year,years)].copy()

genesis = weather_sst.copy()
# number of storm formations per day (sum of the `genesis` column over all track points of that day)
genesis['genesis'] = [atl._tracks.loc[atl._tracks.time==np.datetime64(tt),'genesis'].sum() for tt in genesis.time]
genesis['sst'] = sst_MDR.sel(time=weather_sst.time.values)

weather_sst['sst'] = sst_MDR.sel(time=weather_sst.time.values)

# running day counter that restarts at 0 in every year
genesis['day_in_season'] = 0
weather_sst['day_in_season'] = 0
for year in np.unique(weather_sst.time.dt.year):
    in_year = weather_sst.time.dt.year==year
    day_index = np.arange(in_year.sum())
    weather_sst.loc[in_year,'day_in_season'] = day_index
    genesis.loc[(genesis.time.dt.year==year),'day_in_season'] = day_index

# drop the first three days of each season: their lagged weather labels are not valid
weather_sst = weather_sst.loc[(weather_sst.day_in_season>=3) & np.isin(weather_sst.year,years)]
genesis = genesis.loc[(genesis.day_in_season>=3) & np.isin(genesis.year,years)]

# Train/test split by decades: one column per hold-out period, in which the
# years of that period are marked 'test' and all remaining years 'train'.
train_test = pd.DataFrame()
train_test['year'] = list(range(1982,2021))
for first_year, last_year in [(1982,1990), (1991,2000), (2001,2010), (2011,2020)]:
    period = '%s-%s' % (first_year, last_year)
    train_test[period] = 'train'
    held_out = train_test.year.between(first_year, last_year)
    train_test.loc[held_out, period] = 'test'

# Human-readable labels used when naming emulator variants in figures.
# Keys are either component identifiers or the three component kinds.
comp_names = {
    # formation components (an empty label is used by the main configuration)
    'gWeaLag2Weight' : '',
    'gWeaLag2' : 'equal weight',
    'gWea' : 'no lag',
    'gnnWeaSST' : 'NN weather + SST',
    # duration components
    'sLWeaNeigh' : '',
    'sLWea' : 'no neighbors',
    'sLAll' : 'random',
    # intensification components
    'wS100nnQrSST' : '',
    'wS100nn' : '100 nn',
    'wS50nn' : '50 nn',
    'wS20nn' : '20 nn',
    'wS100nnNoSST' : 'no SST',
    'wS100nnQrSSTnoHist' : 'no history',
    'wS100nnQrSSTnoWeather' : 'no weather',
    # component kinds
    'g' : 'formation',
    'sL' : 'duration',
    'wS': 'intensification'
}

def siggi(s):
    """Return the significance marker for a p-value.

    Parameters
    ----------
    s : float
        p-value of a statistical test; may be NaN.

    Returns
    -------
    str
        '*' if `s` is a number below 0.1, otherwise the empty string
        (this includes NaN).
    """
    is_significant = (not np.isnan(s)) and s < 0.1
    return '*' if is_significant else ''

def nicer_plot(fig, ax, out_file, ylim=None, upper_left='', upper_right='', edgeC='w', text=''):
    """Annotate a finished figure, fix its y-range and save it to ``out_file``.

    Parameters
    ----------
    fig : figure whose frame is styled and which is written to disk
    ax : axes that receive the annotations and the y-limits
    out_file : destination path handed to ``fig.savefig``
    ylim : y-limits to apply; ``None`` re-applies the limits the axes already have
    upper_left : bold label near the top-left corner (figure coordinates)
    upper_right : bold label near the top-right corner (figure coordinates),
        drawn in colour ``edgeC``
    edgeC : colour of the upper-right label and of the saved figure's edge
    text : free text placed at the top left inside the axes (axes coordinates)
    """
    ax.annotate(upper_left, xy=(0.03,0.95), xycoords='figure fraction', ha='left', va='top', fontweight='bold', fontsize=12, backgroundcolor='w')
    ax.annotate(upper_right, xy=(0.97,0.97), xycoords='figure fraction', ha='right', va='top', fontweight='bold', fontsize=12, color=edgeC, backgroundcolor='w')
    ax.annotate(text, xy=(0.03,0.97), xycoords='axes fraction', ha='left', va='top', backgroundcolor='none')
    if ylim is None:
        ylim = ax.get_ylim()
    ax.set_ylim(ylim)
    # Operate on the figure that was passed in rather than on pyplot's current
    # figure, so the right figure is saved even if another one is active.
    fig.patch.set_linewidth(3)
    fig.savefig(out_file, bbox_inches='tight', dpi=200, edgecolor=edgeC)


# choose components
# Independent letter / version counters used to label configurations and panels.
alphabet = iter(list(string.ascii_uppercase))
alphabet_sL = iter(list(string.ascii_uppercase))
alphabet_wS = iter(list(string.ascii_uppercase))
version = iter(range(1,100))

# Reference configuration; each variant below replaces exactly one component of it.
main_comps = {'g':'gWeaLag2Weight', 'sL':'sLWea', 'wS':'wS100nnQrSST', 'Emu':'Emu0'}
comps_todo = [
    dict(main_comps, name='main', l=next(alphabet), c='c', v='main', vc=''),
]

# formation variants (disabled):
# for i,g in enumerate(['gWea', 'gWeaLag2', 'gnnWeaSST']):
    # comps_todo.append({'g':g, 'sL':'sLWea', 'wS':'wS100nnQrSST', 'Emu':'Emu0', 'name':'formation: '+comp_names[g],'l':next(alphabet),'c':'m', 'v':'v%s' %(next(version)), 'vc':'vG%s' %(i)})

# duration variants
for i,sL in enumerate(['sLAll','sLWeaNeigh']):
    variant = dict(main_comps, sL=sL)
    variant.update({
        'name': 'duration: '+comp_names[sL],
        'l': next(alphabet),
        'c': 'orange',
        'v': 'v%s' %(next(version)),
        'vc': 'vD%s' %(i),
    })
    comps_todo.append(variant)

# intensification variants
# (alternative list: ['wS20nn','wS50nn','wS100nn'][::-1])
for i,wS in enumerate(['wS100nn','wS20nn','wS100nnNoSST','wS100nnQrSSTnoWeather','wS100nnQrSSTnoHist']):
    variant = dict(main_comps, wS=wS)
    variant.update({
        'name': 'intensification: '+comp_names[wS],
        'l': next(alphabet),
        'c': 'r',
        'v': 'v%s' %(next(version)),
        'vc': 'vI%s' %(i),
    })
    comps_todo.append(variant)

N = 1000            # passed as N to emulate_seasons_serial and used in the simulation file names
overwrite = False   # when False, cached component pickles are loaded instead of refitted
validations = {}    # filled per configuration tag with emu._validation
# For every component configuration:
#   1. fit (or load cached) genesis / storm-length / wind components on the
#      training years of each hold-out period and emulate the test years,
#   2. collect all test-year simulations into one cross-validation emulator,
#   3. produce the validation statistics and plots.
for dt in comps_todo:
    # tag identifies the configuration in folder and file names
    tag = '_'.join([dt[k] for k in ['g','sL','wS','Emu']])
    print(tag)
    # Reload the emulator core and the component modules selected in dt.
    # The star imports rebind the names used below (presumably genesis_pred,
    # storm_length_estimator, wind_estimator, storm_emulator, ... -- they are
    # not defined in this notebook section).
    import _emulator; importlib.reload(_emulator); from _emulator import *
    for k,v in {k:v for k,v in dt.items() if k in ['g','sL','wS','Emu']}.items():
        exec("import %s; importlib.reload(%s); from %s import *" % tuple(['components.'+k+'.'+v]*3))
    for test_period in [tt for tt in train_test.columns if tt != 'year']:
        train_years = train_test.loc[train_test[test_period]=='train', 'year'].values
        test_years = train_test.loc[train_test[test_period]=='test', 'year'].values
        train_folder = atl._dir_lvl4 + '/emulator/' + str(test_period)+'/'
        # NOTE: the caches below are pickles; only load files written by this workflow.
        # genesis
        comp_file = train_folder+'/_comp_g_'+dt['g']+'/genesis_obj.pkl'
        if os.path.isfile(comp_file) and overwrite == False:
            with open(comp_file, 'rb') as infile:
                genesis_obj = pickle.load(infile)
        else:
            genesis_obj = genesis_pred(dir=train_folder+'/_comp_g_'+dt['g']+'/', df=genesis.loc[np.isin(genesis.time.dt.year,train_years)])
            genesis_obj.fit(atl)
            genesis_obj.save()
            # print(genesis_obj._probs)
        # stormLength
        comp_file = train_folder+'/_comp_sL_'+dt['sL']+'/end_obj.pkl'
        if os.path.isfile(comp_file) and overwrite == False:
            with open(comp_file, 'rb') as infile:
                stormL_obj = pickle.load(infile)
        else:
            stormL_obj = storm_length_estimator(dir=train_folder+'/_comp_sL_'+dt['sL']+'/', atl=atl, tracks=tracks.loc[np.isin(tracks.year,train_years)])
            stormL_obj.save()
            stormL_obj.plot_simulated_storm_length(atl=atl, tracks=tracks.loc[np.isin(tracks.year,train_years)])
        # windSpeed
        comp_file = train_folder+'/_comp_wS_'+dt['wS']+'/wind_obj.pkl'
        if os.path.isfile(comp_file) and overwrite == False:
            with open(comp_file, 'rb') as infile:
                wind_obj = pickle.load(infile)
        else:
            wind_obj = wind_estimator(dir=train_folder+'/_comp_wS_'+dt['wS']+'/', df=tracks.loc[np.isin(tracks.year,train_years)])
            wind_obj.get_analogue_pdfs(atl=atl)
            wind_obj.load_pdfs()
            wind_obj.save()
        exec("import %s; importlib.reload(%s); from %s import *" % tuple(['components.wS._helping_functions']*3))
        quantiles, wind_quR_params = sst_vs_wind_quantile_regression(tracks.loc[np.isin(tracks.year,train_years)], plot_dir=train_folder+'/_comp_wS_'+dt['wS']+'/', sst_var='sst')
        # print(wind_obj._lr)
        # wind_obj.plot_pdfs()
        emu = storm_emulator(dir=train_folder, tag=tag, emulate_season_function=emulate_season_function)
        # truncate the time stamps to whole days
        atl._vector_time.values = np.array([str(d)[:10] for d in atl._vector_time.values], np.datetime64)
        # emu.prepare_input(atl, sst_tropics, sst_MDR_rel, years, fileName = atl._dir_lvl4 + '/emulator/weather_sst_input.csv', overwrite=overwrite)
        emu._weather_sst = weather_sst
        emu.emulate_seasons_serial(genesis_obj, wind_obj, stormL_obj, years=test_years, N=N, overwrite=overwrite)

    # Cross-validation emulator: holds, for every year, the seasons simulated
    # while that year was part of the test period.
    emu = storm_emulator(dir=atl._dir_lvl4 + '/emulator/xValid/', tag=tag, emulate_season_function=None)
    emu._seasons = {}
    for test_period in [tt for tt in train_test.columns if tt != 'year']:
        test_years = train_test.loc[train_test[test_period]=='test', 'year'].values
        for test_year in test_years:
            with open(atl._dir_lvl4 + '/emulator/'+test_period+'/'+tag+'/sim/'+tag+'_'+str(test_year)+'_N'+str(N)+'.pkl', 'rb') as infile:
                emu._seasons[test_year] = pickle.load(infile)
    emu._N = N
    emu._weather_sst = weather_sst
    # emu.get_simu_tracks(overwrite=True)
    # emu.get_other_stats_for_tracks(tracks)
    emu.get_stats_seasonal_simu(overwrite=False)
    emu.get_stats_seasonal_obs(tracks, train_test.year.values)
    emu._sets = [{'years':[years], 'label':'xValid', 'color':'c'}]
    # Axis labels per indicator. A former first 'wind' entry
    # ('max. daily wind speed') was a duplicate key that Python silently
    # overrode with the value kept here, so it has been dropped.
    emu._indicator_dict = {
        'genesis' : 'storm formations',
        'storm_days' : 'storm days in season',
        'wind' : 'acc. daily max. wind speeds',
        'ACE' : 'ACE',
        'Hur' : 'hurricanes',
        'MajHur' : 'major hurricanes',
        'stoMaxWind' : 'max. intensity of storm [kts]',
        'stoLen' : 'storm duration',
        'stoD' : 'day of storm',
        'dWind' : 'change in storm intensity [kts]',
        'wind_before' : 'intensity on the day before',
        'sst' : 'SST',
    }
    if tag == 'gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0':
        # fig 3
        axes_ = []
        fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8,5), sharex=True)
        # zip stops at the shortest input: the 2x2 grid has four axes, so the
        # fifth indicator ('Hur') is not drawn.
        for indicator,letter,ylim,ax in zip(['genesis','storm_days','MajHur','ACE', 'Hur'],['A','B','C','D','E'],[(0,30),(0,200),(0,10),(0,350),(0,16)], axes.flatten()):
            ax = emu.vali_year_to_year_variability(indicator, show_legend=False, ax=ax)
            ax.annotate(letter, xy=(0.05,0.95), xycoords='axes fraction', ha='left', va='top', fontweight='bold', fontsize=12, backgroundcolor='w')
            corr_ = 'corr: %.2f%s' %(emu._validation[indicator]['pearson_median']['coef'],siggi(emu._validation[indicator]['pearson_median']['pval']))
            ax.annotate(corr_, xy=(0.95,0.95), xycoords='axes fraction', ha='right', va='top', fontweight='bold', fontsize=12, backgroundcolor='w')
        plt.tight_layout()
        plt.savefig(emu._dir_plot+'fig3.png', dpi=300)


    if dt['sL'] != 'sLWeaNeigh' or tag == 'gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0':
        letter_ = next(alphabet_sL)

        # observed storms: one list of wind values per storm
        # (built-in float: the np.float alias no longer exists in current NumPy)
        storms = [[float(w) for w in tracks.loc[(tracks.storm==storm), 'wind']] for storm in np.unique(tracks.loc[np.isin(tracks.year,years),'storm'])]
        # simulated storms: flattened over all years and all entries of each season
        seasons = [winds for season in emu._seasons.values() for storm_dict in season for winds in storm_dict.values()]

        # HIST #
        out_file = emu._dir_plot+'hist_duration_N'+str(emu._N)+'.png'
        # 'and False' keeps the cached-figure branch switched off: the histogram is always recomputed
        if os.path.isfile(out_file.replace('.png','.pkl')) and False:
            with open(out_file.replace('.png','.pkl'),'rb') as infile:
                fig, ax = pickle.load(infile)
        else:
            obs = np.array([len(winds) for winds in storms])
            simu = np.array([len(winds) for winds in seasons])
            fig, ax, out_file = emu.vali_distr_tests(obs=obs, simu=simu, out_file=out_file, indicator='duration', bins=np.arange(0,30,1))
            with open(out_file.replace('.png','.pkl'), 'wb') as outfile:
                pickle.dump((fig,ax), outfile)
        ax.set_xlabel('storm duration [days]')
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), upper_left=letter_, upper_right=dt['name'], edgeC=dt['c'])


    version_text = '\n'.join(['%s: %s' %(comp_names[k],comp_names[dt[k]]) for k in ['g','sL','wS']])
    for indicator,ylim in zip(['genesis','storm_days','Hur','MajHur','ACE'],[(0,35),(0,220),(0,18),(0,12),(0,360)]):
        # Year to Year and correlation #
        fig,ax,out_file = emu.vali_year_to_year_variability(indicator, show_legend=False)
        vali = emu._validation[indicator]
        text = 'pearson corr.: %s%s' %(vali['pearson_mean']['coef'].round(2), siggi(vali['pearson_mean']['pval']))
        text += '\nspearman corr.: %s%s' %(vali['spearman_mean']['coef'].round(2), siggi(vali['spearman_mean']['pval']))
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, text=text)
    for indicator,ylim in zip(['genesis','storm_days','MajHur','ACE', 'Hur'],[(-16,16),(-120,120),(-7,7),(-240,240),(-10,10)]):
        # trend in residuals #
        fig,ax,out_file = emu.vali_residuals_and_long_term_trend(indicator)
        vali = emu._validation[indicator]
        text = 'MannKendall: %s%s' %(vali['MK_mean']['trend'], siggi(vali['MK_mean']['pval']))
        text += '\nLinear trend: %s%s' %(vali['trend_mean']['slope'].round(2), siggi(vali['trend_mean']['pval']))
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, text=text)

        # RMSD
        emu.vali_RMSD(indicator)

    # residuals vs SST 
    for indicator,ylim in zip(['genesis','storm_days','MajHur','ACE', 'Hur'],[(-16,16),(-120,120),(-7,7),(-240,240),(-10,10)]):
        fig,ax,out_file = emu.vali_residuals_IND_vs_SST(indicator, sst_MDR.groupby('time.year').mean('time'))
        vali = emu._validation[indicator]
        text = 'MannKendall: %s%s' %(vali['MK_vs_SST_median']['trend'], siggi(vali['MK_vs_SST_median']['pval']))
        text += '\nLinear regression: %s%s' %(vali['trend_vs_SST_median']['slope'].round(2), siggi(vali['trend_vs_SST_median']['pval']))
        nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, text=text)

    if dt['wS'] != 'wS100nnQrSST' or tag == 'gWeaLag2Weight_sLWeaNeigh_wS100nnQrSST_Emu0':
        #####################
        # deviations in KNN #
        #####################
        for var,ylim,xlim in zip(['wind_before','sst'],[(-100,40),(-0.5,0.5)],[(0,160),(27,29)]):
            plt.close('all')
            fig,ax = plt.subplots(nrows=1, figsize=(4,3))
            # train_folder still holds the value of the last test period from the loop above
            dists = xr.open_dataset(train_folder+'/_comp_wS_'+dt['wS']+'/distances.nc')['distances']
            if var in dists.dims:
                for weather in [15,6,1,12]:
                    y = dists.sel({'q':[17,50,83],'d':var}).squeeze()
                    if 'weather_0' in dists.dims:
                        y = y.sel({'weather_0':weather})
                    # average over every remaining dimension except var and q
                    for k in [k for k in y.dims if k not in [var,'q']]:
                        y = y.mean(k)
                    # band between the q=17 and q=83 entries, line for q=50
                    ax.fill_between(dists[var], y.loc[:,17], y.loc[:,83], alpha=0.5)
                    ax.plot(dists[var], y.loc[:,50], label='w%s' %(weather))
                ax.set_ylabel('bias in \n'+emu._indicator_dict[var])
                ax.set_xlabel(emu._indicator_dict[var])
                ax.legend()
                ax.set_xlim(xlim)
                out_file = train_folder+'/_comp_wS_'+dt['wS']+'/dist_weather_'+var+'.png'
                nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim)
                # nicer_plot(fig,ax,out_file.replace('.png','_sens.png'), ylim=ylim, upper_left=letter_, upper_right=dt['vc'], edgeC=dt['c'])

    # keep the validation results of this configuration
    validations[tag] = emu._validation