## Code for applying early warning signals (EWS) and deep learning (DL) classifiers to empirical data

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as ticker
import ewstools
from ewstools.models import simulate_ricker
import datetime
import time
from pylab import cm
from sklearn import metrics

In [None]:
# Report the installed TensorFlow version.
print(tf.__version__)

In [None]:
# Count the GPU devices TensorFlow can see and report the number.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs available:", len(gpu_devices))

In [None]:
# Load the two pre-trained Keras classifiers from their HDF5 files; both are
# applied to the same input windows in the prediction cells below.
model_resnet = tf.keras.models.load_model('./trained_models/resnet_300ep_1024batch.h5')
model_lstm = tf.keras.models.load_model('./trained_models/lstm_300ep_1024batch.h5')

In [None]:
def reshape_inputs(df1, df2, n_features=2):
    """Pair two time-series inputs into a single 3-D array.

    Each input is transposed and laid out as (n_samples, n_time_steps), where
    n_time_steps is the input's number of rows. The two layouts are then
    joined along a new trailing axis.

    Parameters
    ----------
    df1, df2 : pandas.Series or pandas.DataFrame
        Inputs exposing ``.values`` and ``.shape``; both must produce the same
        (n_samples, n_time_steps) layout.
    n_features : int, optional
        Not used by the function body; kept so existing calls remain valid.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_samples, n_time_steps, 2) holding ``df1`` in
        channel 0 and ``df2`` in channel 1.
    """
    # One (n_samples, n_time_steps) matrix per input, in argument order.
    channels = [frame.values.T.reshape(-1, frame.shape[0]) for frame in (df1, df2)]

    # Stacking on a new last axis yields one channel per input.
    return np.stack(channels, axis=-1)

### Obtain EWS

In [None]:
# Load the per-region tweet-count tables. Later cells read the 'Time' and
# 'Count' columns from each of them.
ca_count = pd.read_csv('./twitter_data/ca14_count.csv')
ny_count = pd.read_csv('./twitter_data/ny18_count.csv')
wa_count = pd.read_csv('./twitter_data/wa18_count.csv')
bc_count = pd.read_csv('./twitter_data/bc14_count.csv')
sg_count = pd.read_csv('./twitter_data/sg_count.csv')
mx_count = pd.read_csv('./twitter_data/mx_count.csv')
ar_count = pd.read_csv('./twitter_data/ar_count.csv')
br_count = pd.read_csv('./twitter_data/br_count.csv')

# Parse each table's own 'Time' column into datetimes so it can be used as a
# date axis when plotting. Trailing comments give the date range covered.
ca_count['Time'] = pd.to_datetime(ca_count['Time'])  # 2011-01-01 to 2015-03-29
ny_count['Time'] = pd.to_datetime(ny_count['Time'])  # 2015-05-01 to 2019-05-12
# Washington must be parsed from its own column: assigning New York's parsed
# column here would silently replace Washington's timestamps.
wa_count['Time'] = pd.to_datetime(wa_count['Time'])  # 2015-05-01 to 2019-03-15
bc_count['Time'] = pd.to_datetime(bc_count['Time'])  # 2010-11-01 to 2014-09-27
sg_count['Time'] = pd.to_datetime(sg_count['Time'])  # 2014-01-01 to 2017-12-31
mx_count['Time'] = pd.to_datetime(mx_count['Time'])  # 2014-01-01 to 2017-12-31
ar_count['Time'] = pd.to_datetime(ar_count['Time'])  # 2014-01-01 to 2017-12-31
br_count['Time'] = pd.to_datetime(br_count['Time'])  # 2014-01-01 to 2017-12-31

In [None]:
# Wrap each 'Count' series in an ewstools TimeSeries. For the four outbreak
# regions, `transition` is the row index of the outbreak date noted alongside;
# the remaining four series are created without a transition.
ts_ca = ewstools.TimeSeries(data=ca_count['Count'], transition=1456)  # first case: 2014-12-27
ts_ny = ewstools.TimeSeries(data=ny_count['Count'], transition=1248)  # 2018-09-30
ts_wa = ewstools.TimeSeries(data=wa_count['Count'], transition=1338)  # 2018-12-29
ts_bc = ewstools.TimeSeries(data=bc_count['Count'], transition=1209)  # 2014-02-22
ts_sg = ewstools.TimeSeries(data=sg_count['Count'])
ts_mx = ewstools.TimeSeries(data=mx_count['Count'])
ts_ar = ewstools.TimeSeries(data=ar_count['Count'])
ts_br = ewstools.TimeSeries(data=br_count['Count'])

# Every series gets the same treatment: Gaussian detrending (bandwidth 100),
# then rolling-window (200 points) variance, lag-1 autocorrelation and
# coefficient of variation. The series are independent objects, so they can
# be processed one after another.
for _ts in (ts_ca, ts_ny, ts_wa, ts_bc, ts_sg, ts_mx, ts_ar, ts_br):
    _ts.detrend(method='Gaussian', bandwidth=100)
    _ts.compute_var(rolling_window=200)
    _ts.compute_auto(rolling_window=200, lag=1)
    _ts.compute_cv(rolling_window=200)

### Obtain DL output

In [None]:
# CA 2014
# Rolling-window classifier outputs for California, for every end index from
# 0 up to and including the transition index.
trans = 1456

# Residuals rescaled by their mean absolute value.
ts_res_ca = ts_ca.state['residuals']/np.mean(np.abs(ts_ca.state['residuals']))

# A full 500-point window only exists from end index 500 onwards; earlier
# positions are filled with NaN.
prob_list_ca_resnet = [np.nan] * 500
prob_list_ca_lstm = [np.nan] * 500

for end in range(500, trans + 1):
    # The window covers the 500 points immediately before `end`.
    inputs_test = reshape_inputs(ts_res_ca[end-500:end], ts_ca.state['smoothing'][end-500:end])
    # Keep the first output component of the single sample in the batch.
    prob_list_ca_resnet.append(model_resnet.predict(inputs_test, verbose=0)[0][0])
    prob_list_ca_lstm.append(model_lstm.predict(inputs_test, verbose=0)[0][0])

In [None]:
# NY 2018
# Rolling-window classifier outputs for New York, for every end index from
# 0 up to and including the transition index.
trans = 1248

# Residuals rescaled by their mean absolute value.
ts_res_ny = ts_ny.state['residuals']/np.mean(np.abs(ts_ny.state['residuals']))

# A full 500-point window only exists from end index 500 onwards; earlier
# positions are filled with NaN.
prob_list_ny_resnet = [np.nan] * 500
prob_list_ny_lstm = [np.nan] * 500

for end in range(500, trans + 1):
    # The window covers the 500 points immediately before `end`.
    inputs_test = reshape_inputs(ts_res_ny[end-500:end], ts_ny.state['smoothing'][end-500:end])
    # Keep the first output component of the single sample in the batch.
    prob_list_ny_resnet.append(model_resnet.predict(inputs_test, verbose=0)[0][0])
    prob_list_ny_lstm.append(model_lstm.predict(inputs_test, verbose=0)[0][0])

In [None]:
# WA 2018
# Rolling-window classifier outputs for Washington, for every end index from
# 0 up to and including the transition index.
trans = 1338

# Residuals rescaled by their mean absolute value.
ts_res_wa = ts_wa.state['residuals']/np.mean(np.abs(ts_wa.state['residuals']))

# A full 500-point window only exists from end index 500 onwards; earlier
# positions are filled with NaN.
prob_list_wa_resnet = [np.nan] * 500
prob_list_wa_lstm = [np.nan] * 500

for end in range(500, trans + 1):
    # The window covers the 500 points immediately before `end`.
    inputs_test = reshape_inputs(ts_res_wa[end-500:end], ts_wa.state['smoothing'][end-500:end])
    # Keep the first output component of the single sample in the batch.
    prob_list_wa_resnet.append(model_resnet.predict(inputs_test, verbose=0)[0][0])
    prob_list_wa_lstm.append(model_lstm.predict(inputs_test, verbose=0)[0][0])

In [None]:
# BC 2014
# Rolling-window classifier outputs for British Columbia, for every end index
# from 0 up to and including the transition index.
trans = 1209

# Residuals rescaled by their mean absolute value.
ts_res_bc = ts_bc.state['residuals']/np.mean(np.abs(ts_bc.state['residuals']))

# A full 500-point window only exists from end index 500 onwards; earlier
# positions are filled with NaN.
prob_list_bc_resnet = [np.nan] * 500
prob_list_bc_lstm = [np.nan] * 500

for end in range(500, trans + 1):
    # The window covers the 500 points immediately before `end`.
    inputs_test = reshape_inputs(ts_res_bc[end-500:end], ts_bc.state['smoothing'][end-500:end])
    # Keep the first output component of the single sample in the batch.
    prob_list_bc_resnet.append(model_resnet.predict(inputs_test, verbose=0)[0][0])
    prob_list_bc_lstm.append(model_lstm.predict(inputs_test, verbose=0)[0][0])

In [None]:
# SG
# Rolling-window classifier outputs for Singapore, for every end index from
# 0 up to and including `trans`.
trans = 1461

# Residuals rescaled by their mean absolute value.
ts_res_sg = ts_sg.state['residuals']/np.mean(np.abs(ts_sg.state['residuals']))

# A full 500-point window only exists from end index 500 onwards; earlier
# positions are filled with NaN.
prob_list_sg_resnet = [np.nan] * 500
prob_list_sg_lstm = [np.nan] * 500

for end in range(500, trans + 1):
    # The window covers the 500 points immediately before `end`.
    inputs_test = reshape_inputs(ts_res_sg[end-500:end], ts_sg.state['smoothing'][end-500:end])
    # Keep the first output component of the single sample in the batch.
    prob_list_sg_resnet.append(model_resnet.predict(inputs_test, verbose=0)[0][0])
    prob_list_sg_lstm.append(model_lstm.predict(inputs_test, verbose=0)[0][0])

In [None]:
# MX
# Rolling-window classifier outputs for Mexico, for every end index from
# 0 up to and including `trans`.
trans = 1461

# Residuals rescaled by their mean absolute value.
ts_res_mx = ts_mx.state['residuals']/np.mean(np.abs(ts_mx.state['residuals']))

# A full 500-point window only exists from end index 500 onwards; earlier
# positions are filled with NaN.
prob_list_mx_resnet = [np.nan] * 500
prob_list_mx_lstm = [np.nan] * 500

for end in range(500, trans + 1):
    # The window covers the 500 points immediately before `end`.
    inputs_test = reshape_inputs(ts_res_mx[end-500:end], ts_mx.state['smoothing'][end-500:end])
    # Keep the first output component of the single sample in the batch.
    prob_list_mx_resnet.append(model_resnet.predict(inputs_test, verbose=0)[0][0])
    prob_list_mx_lstm.append(model_lstm.predict(inputs_test, verbose=0)[0][0])

In [None]:
# AR
# Rolling-window classifier outputs for Argentina, for every end index from
# 0 up to and including `trans`.
trans = 1461

# Residuals rescaled by their mean absolute value.
ts_res_ar = ts_ar.state['residuals']/np.mean(np.abs(ts_ar.state['residuals']))

# A full 500-point window only exists from end index 500 onwards; earlier
# positions are filled with NaN.
prob_list_ar_resnet = [np.nan] * 500
prob_list_ar_lstm = [np.nan] * 500

for end in range(500, trans + 1):
    # The window covers the 500 points immediately before `end`.
    inputs_test = reshape_inputs(ts_res_ar[end-500:end], ts_ar.state['smoothing'][end-500:end])
    # Keep the first output component of the single sample in the batch.
    prob_list_ar_resnet.append(model_resnet.predict(inputs_test, verbose=0)[0][0])
    prob_list_ar_lstm.append(model_lstm.predict(inputs_test, verbose=0)[0][0])

In [None]:
# BR
# Rolling-window classifier outputs for Brazil, for every end index from
# 0 up to and including `trans`.
trans = 1461

# Residuals rescaled by their mean absolute value.
ts_res_br = ts_br.state['residuals']/np.mean(np.abs(ts_br.state['residuals']))

# A full 500-point window only exists from end index 500 onwards; earlier
# positions are filled with NaN.
prob_list_br_resnet = [np.nan] * 500
prob_list_br_lstm = [np.nan] * 500

for end in range(500, trans + 1):
    # The window covers the 500 points immediately before `end`.
    inputs_test = reshape_inputs(ts_res_br[end-500:end], ts_br.state['smoothing'][end-500:end])
    # Keep the first output component of the single sample in the batch.
    prob_list_br_resnet.append(model_resnet.predict(inputs_test, verbose=0)[0][0])
    prob_list_br_lstm.append(model_lstm.predict(inputs_test, verbose=0)[0][0])

In [None]:
# For each region and classifier: put the list of classifier outputs in a
# one-column DataFrame ('prob') and wrap that column in an ewstools TimeSeries.
df_test_ca_resnet = pd.DataFrame({'prob':prob_list_ca_resnet})
ts_test_ca_resnet = ewstools.TimeSeries(data=df_test_ca_resnet['prob'])
df_test_ca_lstm = pd.DataFrame({'prob':prob_list_ca_lstm})
ts_test_ca_lstm = ewstools.TimeSeries(data=df_test_ca_lstm['prob'])

df_test_ny_resnet = pd.DataFrame({'prob':prob_list_ny_resnet})
ts_test_ny_resnet = ewstools.TimeSeries(data=df_test_ny_resnet['prob'])
df_test_ny_lstm = pd.DataFrame({'prob':prob_list_ny_lstm})
ts_test_ny_lstm = ewstools.TimeSeries(data=df_test_ny_lstm['prob'])

df_test_wa_resnet = pd.DataFrame({'prob':prob_list_wa_resnet})
ts_test_wa_resnet = ewstools.TimeSeries(data=df_test_wa_resnet['prob'])
df_test_wa_lstm = pd.DataFrame({'prob':prob_list_wa_lstm})
ts_test_wa_lstm = ewstools.TimeSeries(data=df_test_wa_lstm['prob'])

df_test_bc_resnet = pd.DataFrame({'prob':prob_list_bc_resnet})
ts_test_bc_resnet = ewstools.TimeSeries(data=df_test_bc_resnet['prob'])
df_test_bc_lstm = pd.DataFrame({'prob':prob_list_bc_lstm})
ts_test_bc_lstm = ewstools.TimeSeries(data=df_test_bc_lstm['prob'])

df_test_sg_resnet = pd.DataFrame({'prob':prob_list_sg_resnet})
ts_test_sg_resnet = ewstools.TimeSeries(data=df_test_sg_resnet['prob'])
df_test_sg_lstm = pd.DataFrame({'prob':prob_list_sg_lstm})
ts_test_sg_lstm = ewstools.TimeSeries(data=df_test_sg_lstm['prob'])

df_test_mx_resnet = pd.DataFrame({'prob':prob_list_mx_resnet})
ts_test_mx_resnet = ewstools.TimeSeries(data=df_test_mx_resnet['prob'])
df_test_mx_lstm = pd.DataFrame({'prob':prob_list_mx_lstm})
ts_test_mx_lstm = ewstools.TimeSeries(data=df_test_mx_lstm['prob'])

df_test_ar_resnet = pd.DataFrame({'prob':prob_list_ar_resnet})
ts_test_ar_resnet = ewstools.TimeSeries(data=df_test_ar_resnet['prob'])
df_test_ar_lstm = pd.DataFrame({'prob':prob_list_ar_lstm})
ts_test_ar_lstm = ewstools.TimeSeries(data=df_test_ar_lstm['prob'])

df_test_br_resnet = pd.DataFrame({'prob':prob_list_br_resnet})
ts_test_br_resnet = ewstools.TimeSeries(data=df_test_br_resnet['prob'])
df_test_br_lstm = pd.DataFrame({'prob':prob_list_br_lstm})
ts_test_br_lstm = ewstools.TimeSeries(data=df_test_br_lstm['prob'])

# Smooth every output series with a Gaussian filter of bandwidth 20; the
# smoothed curve is what the figures plot.
for _ts_test in (
    ts_test_ca_resnet, ts_test_ca_lstm,
    ts_test_ny_resnet, ts_test_ny_lstm,
    ts_test_wa_resnet, ts_test_wa_lstm,
    ts_test_bc_resnet, ts_test_bc_lstm,
    ts_test_sg_resnet, ts_test_sg_lstm,
    ts_test_mx_resnet, ts_test_mx_lstm,
    ts_test_ar_resnet, ts_test_ar_lstm,
    ts_test_br_resnet, ts_test_br_lstm,
):
    _ts_test.detrend(method='Gaussian', bandwidth=20)

In [None]:
# Figure for the four outbreak regions: a 4x4 grid with one column per region
# and one row each for tweet counts (raw and smoothed), rolling variance,
# lag-1 autocorrelation and the smoothed classifier outputs.
fig2, ax2 = plt.subplots(nrows=4, ncols=4, figsize=(36,20))

# One entry per column: (count table, EWS series, outbreak date, panel title,
# series length used to scale the window-length arrows, ResNet output series,
# LSTM output series).
# NOTE(review): the lengths are presumably the row counts of the CSVs - confirm.
outbreak_panels = [
    (ca_count, ts_ca, pd.to_datetime('2014-12-27'), 'California, United States', 1549,
     ts_test_ca_resnet, ts_test_ca_lstm),
    (ny_count, ts_ny, pd.to_datetime('2018-09-30'), 'New York, United States', 1473,
     ts_test_ny_resnet, ts_test_ny_lstm),
    (wa_count, ts_wa, pd.to_datetime('2018-12-29'), 'Washington, United States', 1415,
     ts_test_wa_resnet, ts_test_wa_lstm),
    (bc_count, ts_bc, pd.to_datetime('2014-02-22'), 'British Columbia, Canada', 1427,
     ts_test_bc_resnet, ts_test_bc_lstm),
]

for col, (count, ts, outbreak, title, n_obs, ts_resnet, ts_lstm) in enumerate(outbreak_panels):
    dates = count['Time']
    date_arr = np.array(dates)
    ax_state, ax_var, ax_ac, ax_dl = ax2[:, col]

    # Row 0: raw counts, smoothed trend, outbreak marker and region title.
    ax_state.plot(dates, ts.state['state'], linewidth=2, color='xkcd:light periwinkle')
    ax_state.plot(dates, ts.state['smoothing'], linewidth=3, color='xkcd:dark periwinkle')
    ax_state.axvline(x=outbreak, linestyle='--', linewidth=2, color='grey')
    ax_state.text(date_arr[30], 132, title, size=30, color='black')
    if col == 0:
        # Only the first panel carries the 'Outbreak' caption.
        ax_state.text(pd.to_datetime('2013-12-15'), 153, 'Outbreak', size=28, color='grey')
    # Arrow above the panel from the outbreak date to the end of the record
    # (x in data coordinates, y in axes fraction).
    ax_state.annotate('', xy=(date_arr[-1], 1.05), xytext=(outbreak, 1.05),
                      xycoords=('data','axes fraction'), arrowprops=dict(arrowstyle="->"))

    # Rows 1-2: rolling variance and lag-1 autocorrelation. The double-headed
    # arrow marks the 200-point rolling window as a fraction of the x-axis.
    for ax, column, colour, label in ((ax_var, 'variance', 'mediumseagreen', 'Var'),
                                      (ax_ac, 'ac1', 'coral', 'AC')):
        ax.plot(dates, ts.ews[column], linewidth=2, color=colour, label=label)
        ax.annotate('', xy=(0, 0.05), xytext=(200/n_obs, 0.05), arrowprops=dict(arrowstyle="<->"), xycoords='axes fraction')
        ax.axvline(x=outbreak, linestyle='--', linewidth=2, color='grey')

    # Row 3: smoothed classifier outputs. The values are read from
    # `.state['smoothing']` (the accessor used for the tweet series above) and
    # plotted against calendar dates so they share the date axis of the other
    # rows. Entry i of an output series is paired with date i; both are
    # trimmed to their common length.
    for ts_pred, colour, label in ((ts_resnet, 'cornflowerblue', 'ResNet'),
                                   (ts_lstm, 'crimson', 'LSTM')):
        probs = ts_pred.state['smoothing'].to_numpy()
        n_shared = min(len(date_arr), len(probs))
        ax_dl.plot(date_arr[:n_shared], probs[:n_shared], linewidth=2, color=colour, label=label)
    ax_dl.annotate('', xy=(0, 0.05), xytext=(530/n_obs, 0.05), arrowprops=dict(arrowstyle="<->"), xycoords='axes fraction')
    if col == 0:
        # In-panel colour key, drawn once in the first column.
        ax_dl.text(pd.to_datetime('2013-01-15'), 0.87, 'ResNet', size=28, color='cornflowerblue')
        ax_dl.text(pd.to_datetime('2014-04-01'), 0.87, 'LSTM', size=28, color='crimson')
    ax_dl.axvline(x=outbreak, linestyle='--', linewidth=2, color='grey')

    # Every row of a column spans that region's full date range.
    for ax in ax2[:, col]:
        ax.set_xlim(left=date_arr[0], right=date_arr[-1])

# Shared y-limits per row; British Columbia's variance panel uses a smaller range.
for row, (bottom, top) in enumerate(((-3, 150), (-6, 600), (-0.2, 1), (-0.05, 1.05))):
    for ax in ax2[row]:
        ax.set_ylim(bottom=bottom, top=top)
ax2[1,3].set_ylim(bottom=-0.6, top=60)

# Axis labels: x on the bottom row only, y on the first column only.
for ax in ax2[3]:
    ax.set_xlabel('Time', fontsize=36)
for ax, label in zip(ax2[:, 0], ('Pro-vaccine\nTweets', 'Variance', 'Lag-1 AC', 'Probability')):
    ax.set_ylabel(label, fontsize=36)

# Yearly ticks formatted as four-digit years on every panel.
for ax in ax2.flatten():
    ax.tick_params(axis='both', which='major', labelsize=30)
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))

# Hide the x tick labels on all but the bottom row.
for ax in ax2[:3].flatten():
    ax.set_xticklabels([])

# Panel letters A-P in row-major order, placed just above each panel.
for ax, letter in zip(ax2.flatten(), 'ABCDEFGHIJKLMNOP'):
    ax.text(0, 1.05, letter, transform=ax.transAxes, size=40, weight='bold')

fig2.tight_layout()
fig2.savefig('empirical_outbreaks.png', dpi=300)
fig2.savefig('empirical_outbreaks.eps', format='eps', bbox_inches='tight')

In [None]:
# Figure for the four regions without a marked transition: a 4x4 grid with one
# column per region and one row each for tweet counts (raw and smoothed),
# rolling variance, lag-1 autocorrelation and the smoothed classifier outputs.
fig2, ax2 = plt.subplots(nrows=4, ncols=4, figsize=(36,20))

# Series length used to scale the window-length arrows; the same value is used
# for all four regions.
n_obs = 1461

# One entry per column: (count table, EWS series, panel title,
# ResNet output series, LSTM output series).
neutral_panels = [
    (sg_count, ts_sg, 'Singapore', ts_test_sg_resnet, ts_test_sg_lstm),
    (mx_count, ts_mx, 'Mexico', ts_test_mx_resnet, ts_test_mx_lstm),
    (ar_count, ts_ar, 'Argentina', ts_test_ar_resnet, ts_test_ar_lstm),
    (br_count, ts_br, 'Brazil', ts_test_br_resnet, ts_test_br_lstm),
]

for col, (count, ts, title, ts_resnet, ts_lstm) in enumerate(neutral_panels):
    dates = count['Time']
    date_arr = np.array(dates)
    ax_state, ax_var, ax_ac, ax_dl = ax2[:, col]

    # Row 0: raw counts, smoothed trend and region title.
    ax_state.plot(dates, ts.state['state'], linewidth=2, color='xkcd:light periwinkle')
    ax_state.plot(dates, ts.state['smoothing'], linewidth=3, color='xkcd:dark periwinkle')
    ax_state.text(date_arr[30], 44, title, size=30, color='black')

    # Rows 1-2: rolling variance and lag-1 autocorrelation. The double-headed
    # arrow marks the 200-point rolling window as a fraction of the x-axis.
    for ax, column, colour, label in ((ax_var, 'variance', 'mediumseagreen', 'Var'),
                                      (ax_ac, 'ac1', 'coral', 'AC')):
        ax.plot(dates, ts.ews[column], linewidth=2, color=colour, label=label)
        ax.annotate('', xy=(0, 0.05), xytext=(200/n_obs, 0.05), arrowprops=dict(arrowstyle="<->"), xycoords='axes fraction')

    # Row 3: smoothed classifier outputs. The values are read from
    # `.state['smoothing']` (the accessor used for the tweet series above) and
    # plotted against calendar dates so they share the date axis of the other
    # rows. Entry i of an output series is paired with date i; the output
    # lists can hold one more entry than there are dates, so both are trimmed
    # to their common length.
    for ts_pred, colour, label in ((ts_resnet, 'cornflowerblue', 'ResNet'),
                                   (ts_lstm, 'crimson', 'LSTM')):
        probs = ts_pred.state['smoothing'].to_numpy()
        n_shared = min(len(date_arr), len(probs))
        ax_dl.plot(date_arr[:n_shared], probs[:n_shared], linewidth=2, color=colour, label=label)
    ax_dl.annotate('', xy=(0, 0.05), xytext=(530/n_obs, 0.05), arrowprops=dict(arrowstyle="<->"), xycoords='axes fraction')
    if col == 0:
        # In-panel colour key, drawn once in the first column.
        ax_dl.text(pd.to_datetime('2016-06-01'), 0.05, 'ResNet', size=28, color='cornflowerblue')
        ax_dl.text(pd.to_datetime('2017-04-01'), 0.05, 'LSTM', size=28, color='crimson')

    # Every row of a column spans that region's full date range.
    for ax in ax2[:, col]:
        ax.set_xlim(left=date_arr[0], right=date_arr[-1])

# Shared y-limits per row.
for row, (bottom, top) in enumerate(((-1, 50), (-0.4, 40), (-0.2, 1), (-0.05, 1.05))):
    for ax in ax2[row]:
        ax.set_ylim(bottom=bottom, top=top)

# Axis labels: x on the bottom row only, y on the first column only.
for ax in ax2[3]:
    ax.set_xlabel('Time', fontsize=36)
for ax, label in zip(ax2[:, 0], ('Pro-vaccine\nTweets', 'Variance', 'Lag-1 AC', 'Probability')):
    ax.set_ylabel(label, fontsize=36)

# Yearly ticks formatted as four-digit years on every panel.
for ax in ax2.flatten():
    ax.tick_params(axis='both', which='major', labelsize=30)
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))

# Hide the x tick labels on all but the bottom row.
for ax in ax2[:3].flatten():
    ax.set_xticklabels([])

# Panel letters A-P in row-major order, placed just above each panel.
for ax, letter in zip(ax2.flatten(), 'ABCDEFGHIJKLMNOP'):
    ax.text(0, 1.05, letter, transform=ax.transAxes, size=40, weight='bold')

fig2.tight_layout()
fig2.savefig('empirical_neutral.png', dpi=300)
fig2.savefig('empirical_neutral.eps', format='eps', bbox_inches='tight')