In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as ticker
import ewstools
from ewstools.models import simulate_ricker
import datetime
import time
from pylab import cm
from sklearn import metrics

In [None]:
# Report the installed TensorFlow version.
print("%s" % tf.__version__)

In [None]:
# Number of GPUs visible to TensorFlow. `tf.config.list_physical_devices` is the
# stable replacement for the `tf.config.experimental` alias.
print("Num GPUs available:", len(tf.config.list_physical_devices('GPU')))

In [None]:
# Load the two pre-trained classifiers (ResNet and LSTM) from their '.h5' files.
model_resnet, model_lstm = (
    tf.keras.models.load_model(model_path)
    for model_path in ('resnet_300ep_1024batch.h5', 'lstm_300ep_1024batch.h5')
)

In [None]:
def reshape_inputs(df1, df2, n_features=2):
    """Combine two aligned series/frames into a (samples, time_steps, 2) array.

    Each argument is either a 1-D pandas Series (a single window, giving one
    sample) or a DataFrame with time along the rows and one sample per column.
    Each input is transposed and reshaped to (n_samples, n_time_steps, 1), and
    the two results are concatenated along the last axis: channel 0 holds
    ``df1`` and channel 1 holds ``df2``.

    ``n_features`` is accepted for call compatibility but is not used; the
    result always has exactly two channels.
    """
    def as_channel(frame):
        # (time, samples) -> (samples, time, 1); a Series yields one sample.
        return frame.values.T.reshape(-1, frame.shape[0], 1)

    channels = [as_channel(df1), as_channel(df2)]
    return np.concatenate(channels, axis=2)

In [None]:
# Test files
# For the full evaluation, use sde_simulation.ipynb to generate samples, save
# them in the folder './test_data' as test_<name>.csv (e.g. test_x_ts.csv), and
# set TEST_DATA_DIR = './test_data' and TEST_FILE_SUFFIX = ''.
# By default we load the samples provided for demonstration, which are named
# './demo_data/test_<name>_demo.csv'.
TEST_DATA_DIR = './demo_data'
TEST_FILE_SUFFIX = '_demo'


def _read_test_csv(name):
    """Load '<TEST_DATA_DIR>/test_<name><TEST_FILE_SUFFIX>.csv' as a DataFrame."""
    return pd.read_csv(f'{TEST_DATA_DIR}/test_{name}{TEST_FILE_SUFFIX}.csv')


df_x_ts_test = _read_test_csv('x_ts')
df_x_sm_test = _read_test_csv('x_sm')
df_x_var_test = _read_test_csv('x_var')
df_x_ac_test = _read_test_csv('x_ac')
df_tw_ts_test = _read_test_csv('tw_ts')
df_tw_sm_test = _read_test_csv('tw_sm')
df_tw_res_test = _read_test_csv('tw_res')
df_tw_var_test = _read_test_csv('tw_var')
df_tw_ac_test = _read_test_csv('tw_ac')
df_pt_test = _read_test_csv('pt')
df_lb_test = _read_test_csv('lb')

In [None]:
# Replace the CSV header labels with positional integer column names, so that
# test series i can be addressed as df[i] in the cells below.
for _frame in (df_x_ts_test, df_x_sm_test, df_x_var_test, df_x_ac_test,
               df_tw_ts_test, df_tw_sm_test, df_tw_res_test, df_tw_var_test,
               df_tw_ac_test):
    _frame.columns = range(_frame.shape[1])
del _frame

In [None]:
# Indices of the positive (label 1) and negative (label 0) samples chosen below.
pos_list, neg_list = [], []

In [None]:
# Choose positive samples and negative samples for testing: the first
# `num_sample` sample indices labelled 1 and the first `num_sample` labelled 0,
# in file order. The lists are rebuilt here rather than appended to lists
# created in another cell, so re-running this cell cannot add duplicates.
num_sample = 100
test_labels = df_lb_test['label'].to_numpy()
# Positions in the label table. These equal the row labels used elsewhere,
# because read_csv gives df_lb_test a default RangeIndex.
pos_list = [int(k) for k in np.flatnonzero(test_labels == 1)[:num_sample]]
neg_list = [int(k) for k in np.flatnonzero(test_labels == 0)[:num_sample]]

In [None]:
# Classifier probabilities for the selected samples:
# *_1_* collects positives (label 1), *_0_* collects negatives (label 0).
prob_list_1_resnet, prob_list_0_resnet = [], []
prob_list_1_lstm, prob_list_0_lstm = [], []

In [None]:
# Positive samples: score the 500-step window that ends just before each
# sample's transition point ('trans'). All windows are stacked into one batch,
# so each model is called once instead of once per sample. The result lists
# are rebuilt, so re-running this cell does not append duplicates.
window_len = 500
pos_windows = []
for index in pos_list:
    trans = df_pt_test['trans'][index]
    if trans < window_len:
        # A negative slice start would silently give a wrong/empty window.
        raise ValueError(
            f'positive sample {index}: transition at {trans} leaves fewer than '
            f'{window_len} points to score'
        )
    pos_windows.append(reshape_inputs(df_tw_res_test[index][trans-window_len:trans],
                                      df_tw_sm_test[index][trans-window_len:trans]))
if pos_windows:
    inputs_test = np.concatenate(pos_windows, axis=0)
    prob_list_1_resnet = list(model_resnet.predict(inputs_test, verbose=0)[:, 0])
    prob_list_1_lstm = list(model_lstm.predict(inputs_test, verbose=0)[:, 0])
else:
    prob_list_1_resnet, prob_list_1_lstm = [], []

In [None]:
# Negative samples (no transition): score the 500 points that end one step
# before the last point of each series. Windows are batched as for the
# positives, and the result lists are rebuilt so a re-run does not duplicate them.
# NOTE(review): the sliding-window cell below scores negatives with windows
# that include the last point ([n-500:n]); confirm which is intended.
neg_windows = [
    reshape_inputs(df_tw_res_test[index][-501:-1], df_tw_sm_test[index][-501:-1])
    for index in neg_list
]
if neg_windows:
    inputs_test = np.concatenate(neg_windows, axis=0)
    prob_list_0_resnet = list(model_resnet.predict(inputs_test, verbose=0)[:, 0])
    prob_list_0_lstm = list(model_lstm.predict(inputs_test, verbose=0)[:, 0])
else:
    prob_list_0_resnet, prob_list_0_lstm = [], []

In [None]:
# Accuracy of each classifier at a 0.5 decision threshold. The denominator is
# the number of samples actually scored. That number is smaller than
# `num_sample` whenever the test set holds fewer samples of a class.
def _report_accuracy(probs, positive, model_name):
    """Print how many of `probs` fall on the correct side of 0.5."""
    if positive:
        n_correct = sum(p > 0.5 for p in probs)
        cls = 'positive'
    else:
        n_correct = sum(p < 0.5 for p in probs)
        cls = 'negative'
    print(f'{n_correct} out of {len(probs)} {cls} samples are correctly predicted by {model_name}.')


_report_accuracy(prob_list_1_resnet, True, 'ResNet')
_report_accuracy(prob_list_1_lstm, True, 'LSTM')
_report_accuracy(prob_list_0_resnet, False, 'ResNet')
_report_accuracy(prob_list_0_lstm, False, 'LSTM')

In [None]:
# Sliding-window probabilities per test series: {series index: [prob, ...]}.
prob_list_resnet, prob_list_lstm = {}, {}

In [None]:
# Generate outputs by DL classifiers for every test series.
# For each series, score the `n_windows` consecutive 500-step windows whose end
# points run up to `last`:
#   - positives (trans != -1): `last` is `lead_time` steps before the transition;
#   - negatives: `last` is the end of the series.
# All windows of a series go through each model in one batched predict() call,
# which is much faster than one call per window.
window_len = 500
n_windows = 200
lead_time = 0  # 0 / 100 / 200 / 400 (previously edited by hand as trans-100 etc.)
for j in range(df_tw_res_test.shape[1]):
    trans = df_pt_test['trans'][j]
    if trans != -1:
        last = trans - lead_time
    else:
        last = df_tw_res_test.shape[0]
    first_end = last - n_windows + 1
    if first_end < window_len:
        # A negative slice start would silently give a wrong window.
        raise ValueError(
            f'series {j}: windows ending at {first_end}..{last} need at least '
            f'{window_len} preceding points'
        )
    inputs_test = np.concatenate(
        [reshape_inputs(df_tw_res_test[j][end-window_len:end], df_tw_sm_test[j][end-window_len:end])
         for end in range(first_end, last + 1)],
        axis=0,
    )
    prob_list_resnet[j] = list(model_resnet.predict(inputs_test, verbose=0)[:, 0])
    prob_list_lstm[j] = list(model_lstm.predict(inputs_test, verbose=0)[:, 0])

In [None]:
# Tabulate the sliding-window probabilities: one column per test series,
# one row per window end point.
df_prob_list_resnet, df_prob_list_lstm = (
    pd.DataFrame(probs) for probs in (prob_list_resnet, prob_list_lstm)
)

### Demonstration of applying DL models on theoretical data

In [None]:
def demo_probabilities(sample, trans, window_len=500):
    """Score test series `sample` with both classifiers at every step up to `trans`.

    For each window end point ``end`` in ``0, 1, ..., trans``, the models see
    the ``window_len`` points of ``df_tw_res_test[sample]`` and
    ``df_tw_sm_test[sample]`` that precede ``end`` (``end`` itself excluded).
    End points with fewer than ``window_len`` preceding points get NaN. Both
    returned lists (ResNet, LSTM) therefore have length ``trans + 1``, and
    entry ``end`` belongs to time step ``end``.

    All complete windows are scored in one batched ``predict`` call per model
    instead of one call per step.
    """
    trans = int(trans)
    n_pad = min(window_len, max(trans + 1, 0))
    resnet_probs = [np.nan] * n_pad
    lstm_probs = [np.nan] * n_pad
    ends = range(window_len, trans + 1)
    if len(ends) > 0:
        res = df_tw_res_test[sample]
        sm = df_tw_sm_test[sample]
        batch = np.concatenate(
            [reshape_inputs(res[end - window_len:end], sm[end - window_len:end]) for end in ends],
            axis=0,
        )
        resnet_probs += list(model_resnet.predict(batch, verbose=0)[:, 0])
        lstm_probs += list(model_lstm.predict(batch, verbose=0)[:, 0])
    return resnet_probs, lstm_probs


demo_1_1 = 0    # index of selected positive sample in df_tw_res_test or df_tw_sm_test
demo_1_2 = 1    # index of another selected positive sample
demo_0_1 = 100  # index of selected negative sample in df_tw_res_test or df_tw_sm_test
demo_0_2 = 101  # index of another selected negative sample
# Negative samples have no transition; score them up to this fixed end point.
demo_neg_last = 4001

prob_list_resnet_demo_1_1, prob_list_lstm_demo_1_1 = demo_probabilities(demo_1_1, df_pt_test['trans'][demo_1_1])
prob_list_resnet_demo_1_2, prob_list_lstm_demo_1_2 = demo_probabilities(demo_1_2, df_pt_test['trans'][demo_1_2])
prob_list_resnet_demo_0_1, prob_list_lstm_demo_0_1 = demo_probabilities(demo_0_1, demo_neg_last)
prob_list_resnet_demo_0_2, prob_list_lstm_demo_0_2 = demo_probabilities(demo_0_2, demo_neg_last)

In [None]:
# Detrend DL probabilities: wrap each demo probability sequence in an
# ewstools.TimeSeries and apply Gaussian detrending with bandwidth 20.
def _detrended_probs(probs):
    """Return an ewstools.TimeSeries of `probs` after Gaussian detrending."""
    series = ewstools.TimeSeries(data=probs)
    series.detrend(method='Gaussian', bandwidth=20)
    return series


resnet_demo_1_1 = _detrended_probs(prob_list_resnet_demo_1_1)
lstm_demo_1_1 = _detrended_probs(prob_list_lstm_demo_1_1)
resnet_demo_1_2 = _detrended_probs(prob_list_resnet_demo_1_2)
lstm_demo_1_2 = _detrended_probs(prob_list_lstm_demo_1_2)
resnet_demo_0_1 = _detrended_probs(prob_list_resnet_demo_0_1)
lstm_demo_0_1 = _detrended_probs(prob_list_lstm_demo_0_1)
resnet_demo_0_2 = _detrended_probs(prob_list_resnet_demo_0_2)
lstm_demo_0_2 = _detrended_probs(prob_list_lstm_demo_0_2)

In [None]:
# Theoretical model: pro-vaccine fraction x for the four demo samples.
# Columns: two positive samples, then two negative samples.
# Rows: raw trajectory with its smoothing, variance, lag-1 autocorrelation.
# A dashed grey line marks the transition of each positive sample.
demo_samples = [demo_1_1, demo_1_2, demo_0_1, demo_0_2]
demo_has_transition = [True, True, False, False]
x_row_ylabels = ['Pro-vaccine %', 'Variance', 'Lag-1 AC']

fig, ax = plt.subplots(nrows=3, ncols=4, figsize=(36,15))

for col, (sample, has_transition) in enumerate(zip(demo_samples, demo_has_transition)):
    # Row 0: raw trajectory (light) and its smoothing (dark)
    ax[0, col].plot(df_x_ts_test[sample], linewidth=2, color='xkcd:light periwinkle')  # plt.cm.Purples(64)
    ax[0, col].plot(df_x_sm_test[sample], linewidth=3, color='xkcd:dark periwinkle')  # plt.cm.Purples(192)
    ax[0, col].set_ylim(bottom=0.96, top=1.001)

    # Row 1: variance
    ax[1, col].plot(df_x_var_test[sample], linewidth=2, color='mediumseagreen', label='Var')
    ax[1, col].set_ylim(bottom=-1e-7, top=1e-5)
    ax[1, col].yaxis.get_offset_text().set_fontsize(24)

    # Row 2: lag-1 autocorrelation
    ax[2, col].plot(df_x_ac_test[sample], linewidth=2, color='coral', label='AC')
    ax[2, col].set_ylim(bottom=0.75, top=1.005)
    ax[2, col].set_xlabel('Time', fontsize=36)

    # Settings shared by the panels of this column
    for row in range(3):
        panel = ax[row, col]
        panel.set_xlim(left=0, right=4000)
        if row < 2:
            panel.set_xticklabels([])  # only the bottom row shows time labels
        if col == 0:
            panel.set_ylabel(x_row_ylabels[row], fontsize=36)
        if has_transition:
            panel.axvline(x=df_pt_test['trans'][sample], linestyle='--', linewidth=2, color='grey')
        if row > 0:
            panel.annotate('', xy=(0, 0.05), xytext=(0.05, 0.05), arrowprops=dict(arrowstyle="<->"), xycoords='axes fraction')
        letter_y = 1.11 if row == 1 else 1.05
        panel.text(0, letter_y, 'ABCDEFGHIJKL'[4 * row + col], transform=panel.transAxes, size=40, weight='bold')

ax[0, 0].text(2560, 1.0018, 'Transition', size=28, color='grey')

# Use a separate loop name so `ax` still refers to the axes grid afterwards.
for panel in ax.flat:
    panel.tick_params(axis='both', which='major', labelsize=30)

fig.tight_layout()
#plt.savefig('dl_demo_x.png', dpi=300)
#plt.savefig('dl_demo_x.eps', format='eps', bbox_inches='tight')

In [None]:
# Tweets: simulated pro-vaccine tweet counts for the four demo samples.
# Columns: two positive samples, then two negative samples.
# Rows: raw counts with their smoothing, variance, lag-1 autocorrelation, and
# the Gaussian-smoothed ResNet / LSTM probabilities.
# A dashed grey line marks the transition of each positive sample.
demo_samples = [demo_1_1, demo_1_2, demo_0_1, demo_0_2]
demo_has_transition = [True, True, False, False]
demo_detrended_probs = [
    (resnet_demo_1_1, lstm_demo_1_1),
    (resnet_demo_1_2, lstm_demo_1_2),
    (resnet_demo_0_1, lstm_demo_0_1),
    (resnet_demo_0_2, lstm_demo_0_2),
]
# Variance y-range per column; the second positive sample uses a wider one.
tw_var_ylims = [(-0.01, 2), (-0.1, 20), (-0.01, 2), (-0.01, 2)]
tw_row_ylabels = ['Pro-vaccine\nTweets', 'Variance', 'Lag-1 AC', 'Probability']

fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(36,20))

for col, (sample, has_transition, (resnet_ts, lstm_ts), (var_lo, var_hi)) in enumerate(
        zip(demo_samples, demo_has_transition, demo_detrended_probs, tw_var_ylims)):
    # Row 0: raw tweet counts (light) and their smoothing (dark)
    ax[0, col].plot(df_tw_ts_test[sample], linewidth=2, color='xkcd:light periwinkle')  # plt.cm.Purples(64)
    ax[0, col].plot(df_tw_sm_test[sample], linewidth=3, color='xkcd:dark periwinkle')  # plt.cm.Purples(192)
    ax[0, col].set_ylim(bottom=-0.5, top=25)

    # Row 1: variance
    ax[1, col].plot(df_tw_var_test[sample], linewidth=2, color='mediumseagreen', label='Var')
    ax[1, col].set_ylim(bottom=var_lo, top=var_hi)
    ax[1, col].yaxis.get_offset_text().set_fontsize(24)

    # Row 2: lag-1 autocorrelation
    ax[2, col].plot(df_tw_ac_test[sample], linewidth=2, color='coral', label='AC')
    ax[2, col].set_ylim(bottom=-0.5, top=0.5)

    # Row 3: smoothed classifier probabilities. The detrended series is held
    # in the TimeSeries' `state` DataFrame (the ROC cell reads it the same
    # way); the ewstools TimeSeries object itself is not indexable.
    ax[3, col].plot(resnet_ts.state['smoothing'], linewidth=2, color='cornflowerblue', label='ResNet')
    ax[3, col].plot(lstm_ts.state['smoothing'], linewidth=2, color='crimson', label='LSTM')
    ax[3, col].set_ylim(bottom=-0.05, top=1.05)
    ax[3, col].set_xlabel('Time', fontsize=36)

    # Settings shared by the panels of this column
    for row in range(4):
        panel = ax[row, col]
        panel.set_xlim(left=0, right=4000)
        if row < 3:
            panel.set_xticklabels([])  # only the bottom row shows time labels
        if col == 0:
            panel.set_ylabel(tw_row_ylabels[row], fontsize=36)
        if has_transition:
            panel.axvline(x=df_pt_test['trans'][sample], linestyle='--', linewidth=2, color='grey')
        if row > 0:
            # The probability row's arrow spans 530/4000 of the axis; the others span 0.05.
            arrow_end = 530/4000 if row == 3 else 0.05
            panel.annotate('', xy=(0, 0.05), xytext=(arrow_end, 0.05), arrowprops=dict(arrowstyle="<->"), xycoords='axes fraction')
        panel.text(0, 1.05, 'ABCDEFGHIJKLMNOP'[4 * row + col], transform=panel.transAxes, size=40, weight='bold')

ax[0, 0].text(2560, 25.5, 'Transition', size=28, color='grey')
ax[3, 0].text(3125, 0.2, 'ResNet', size=28, color='cornflowerblue')
ax[3, 0].text(3125, 0.94, 'LSTM', size=28, color='crimson')

# Use a separate loop name so `ax` still refers to the axes grid afterwards.
for panel in ax.flat:
    panel.tick_params(axis='both', which='major', labelsize=30)

fig.tight_layout()
#plt.savefig('dl_demo_tw.png', dpi=300)
#plt.savefig('dl_demo_tw.eps', format='eps', bbox_inches='tight')

## ROC curves

In [None]:
# Results of DL classifiers are given by 'df_prob_list_resnet' and
# 'df_prob_list_lstm' from above; here we load the demonstration copies.
def _read_prob_csv(path):
    """Read a probability table and relabel its columns 0..n-1."""
    table = pd.read_csv(path)
    table.columns = range(table.shape[1])
    return table


df_resnet = _read_prob_csv('./demo_data/resnet_test_demo.csv')
df_lstm = _read_prob_csv('./demo_data/lstm_test_demo.csv')
df_resnet_pre100 = _read_prob_csv('./demo_data/resnet_test_pre100_demo.csv')
df_lstm_pre100 = _read_prob_csv('./demo_data/lstm_test_pre100_demo.csv')
df_resnet_pre200 = _read_prob_csv('./demo_data/resnet_test_pre200_demo.csv')
df_lstm_pre200 = _read_prob_csv('./demo_data/lstm_test_pre200_demo.csv')
df_resnet_pre400 = _read_prob_csv('./demo_data/resnet_test_pre400_demo.csv')
df_lstm_pre400 = _read_prob_csv('./demo_data/lstm_test_pre400_demo.csv')

In [None]:
# Results of EWS (Kendall tau) are given by sde_simulation.ipynb.
# NOTE(review): these are read from './test_data', while the DL results above
# come from './demo_data'. The files must be generated first, and they must list
# the series in the same order as the DL scores (y_true below assumes this).
(df_ktau_x, df_ktau_x_pre100, df_ktau_x_pre200, df_ktau_x_pre400,
 df_ktau_tw, df_ktau_tw_pre100, df_ktau_tw_pre200, df_ktau_tw_pre400) = (
    pd.read_csv(path) for path in (
        './test_data/test_ktau_x.csv',
        './test_data/test_ktau_x_pre100.csv',
        './test_data/test_ktau_x_pre200.csv',
        './test_data/test_ktau_x_pre400.csv',
        './test_data/test_ktau_tw.csv',
        './test_data/test_ktau_tw_pre100.csv',
        './test_data/test_ktau_tw_pre200.csv',
        './test_data/test_ktau_tw_pre400.csv',
    )
)

In [None]:
# One final smoothed DL score per test series, for the '', '_pre100',
# '_pre200' and '_pre400' result sets.
score_resnet, score_lstm = [], []
score_resnet_pre100, score_lstm_pre100 = [], []
score_resnet_pre200, score_lstm_pre200 = [], []
score_resnet_pre400, score_lstm_pre400 = [], []

In [None]:
# Detrend: smooth each series' sequence of sliding-window probabilities with
# ewstools' Gaussian detrending (bandwidth 20), and keep the smoothed value at
# row 199. That row is presumably the last of the 200 windows produced by the
# sliding-window cell above, i.e. the window ending at the evaluation point.
# The score lists are rebuilt here rather than appended to, so re-running this
# cell no longer duplicates scores (which would break roc_curve below).
SCORE_ROW = 199


def _final_smoothed_score(probs):
    """Gaussian-detrend `probs` with ewstools and return its smoothing at SCORE_ROW."""
    series = ewstools.TimeSeries(data=probs)
    series.detrend(method='Gaussian', bandwidth=20)
    return series.state['smoothing'][SCORE_ROW]


(score_resnet, score_lstm,
 score_resnet_pre100, score_lstm_pre100,
 score_resnet_pre200, score_lstm_pre200,
 score_resnet_pre400, score_lstm_pre400) = (
    [_final_smoothed_score(frame[j]) for j in range(df_resnet.shape[1])]
    for frame in (df_resnet, df_lstm,
                  df_resnet_pre100, df_lstm_pre100,
                  df_resnet_pre200, df_lstm_pre200,
                  df_resnet_pre400, df_lstm_pre400)
)

In [None]:
# Ground-truth labels. This assumes the score lists and the Kendall-tau tables
# list 500 positive series followed by 500 negative ones.
y_true = [1] * 500 + [0] * 500


def _roc_with_auc(scores):
    """Return (fpr, tpr, auc) of `scores` as a detector of y_true == 1."""
    fpr, tpr, _thresholds = metrics.roc_curve(y_true, scores)
    return fpr, tpr, metrics.auc(fpr, tpr)


# DL classifiers
fpr_resnet, tpr_resnet, auc_resnet = _roc_with_auc(score_resnet)
fpr_lstm, tpr_lstm, auc_lstm = _roc_with_auc(score_lstm)
fpr_resnet_pre100, tpr_resnet_pre100, auc_resnet_pre100 = _roc_with_auc(score_resnet_pre100)
fpr_lstm_pre100, tpr_lstm_pre100, auc_lstm_pre100 = _roc_with_auc(score_lstm_pre100)
fpr_resnet_pre200, tpr_resnet_pre200, auc_resnet_pre200 = _roc_with_auc(score_resnet_pre200)
fpr_lstm_pre200, tpr_lstm_pre200, auc_lstm_pre200 = _roc_with_auc(score_lstm_pre200)
fpr_resnet_pre400, tpr_resnet_pre400, auc_resnet_pre400 = _roc_with_auc(score_resnet_pre400)
fpr_lstm_pre400, tpr_lstm_pre400, auc_lstm_pre400 = _roc_with_auc(score_lstm_pre400)

# Kendall tau of variance / lag-1 AC, theoretical model x
fpr_ktau_x_var, tpr_ktau_x_var, auc_ktau_x_var = _roc_with_auc(df_ktau_x['var'])
fpr_ktau_x_ac, tpr_ktau_x_ac, auc_ktau_x_ac = _roc_with_auc(df_ktau_x['ac'])
fpr_ktau_x_var_pre100, tpr_ktau_x_var_pre100, auc_ktau_x_var_pre100 = _roc_with_auc(df_ktau_x_pre100['var'])
fpr_ktau_x_ac_pre100, tpr_ktau_x_ac_pre100, auc_ktau_x_ac_pre100 = _roc_with_auc(df_ktau_x_pre100['ac'])
fpr_ktau_x_var_pre200, tpr_ktau_x_var_pre200, auc_ktau_x_var_pre200 = _roc_with_auc(df_ktau_x_pre200['var'])
fpr_ktau_x_ac_pre200, tpr_ktau_x_ac_pre200, auc_ktau_x_ac_pre200 = _roc_with_auc(df_ktau_x_pre200['ac'])
fpr_ktau_x_var_pre400, tpr_ktau_x_var_pre400, auc_ktau_x_var_pre400 = _roc_with_auc(df_ktau_x_pre400['var'])
fpr_ktau_x_ac_pre400, tpr_ktau_x_ac_pre400, auc_ktau_x_ac_pre400 = _roc_with_auc(df_ktau_x_pre400['ac'])

# Kendall tau of variance / lag-1 AC, tweets
fpr_ktau_tw_var, tpr_ktau_tw_var, auc_ktau_tw_var = _roc_with_auc(df_ktau_tw['var'])
fpr_ktau_tw_ac, tpr_ktau_tw_ac, auc_ktau_tw_ac = _roc_with_auc(df_ktau_tw['ac'])
fpr_ktau_tw_var_pre100, tpr_ktau_tw_var_pre100, auc_ktau_tw_var_pre100 = _roc_with_auc(df_ktau_tw_pre100['var'])
fpr_ktau_tw_ac_pre100, tpr_ktau_tw_ac_pre100, auc_ktau_tw_ac_pre100 = _roc_with_auc(df_ktau_tw_pre100['ac'])
fpr_ktau_tw_var_pre200, tpr_ktau_tw_var_pre200, auc_ktau_tw_var_pre200 = _roc_with_auc(df_ktau_tw_pre200['var'])
fpr_ktau_tw_ac_pre200, tpr_ktau_tw_ac_pre200, auc_ktau_tw_ac_pre200 = _roc_with_auc(df_ktau_tw_pre200['ac'])
fpr_ktau_tw_var_pre400, tpr_ktau_tw_var_pre400, auc_ktau_tw_var_pre400 = _roc_with_auc(df_ktau_tw_pre400['var'])
fpr_ktau_tw_ac_pre400, tpr_ktau_tw_ac_pre400, auc_ktau_tw_ac_pre400 = _roc_with_auc(df_ktau_tw_pre400['ac'])

In [None]:
# ROC curves of the Kendall-tau EWS indicators: variance in mediumseagreen,
# lag-1 autocorrelation in coral. The dashed grey diagonal is the ROC of a
# random classifier. Columns use the '', '_pre100', '_pre200' and '_pre400'
# result sets; row 0 is the theoretical model x, row 1 the tweets.
fig, ax = plt.subplots(nrows=3, ncols=4, figsize=(36,27))

# Row 0: theoretical model x
ax[0,0].plot(fpr_ktau_x_var, tpr_ktau_x_var, color='mediumseagreen', linewidth=3)
ax[0,0].plot(fpr_ktau_x_ac, tpr_ktau_x_ac, color='coral', linewidth=3)
ax[0,0].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[0,0].set_ylabel('True Positive Rate', fontsize=36)
ax[0,0].text(0.37, 0.61, 'Var($x$)', size=28, color='mediumseagreen')
ax[0,0].text(0.2, 0.77, 'AC($x$)', size=28, color='coral')
ax[0,0].text(0.6, 0.18, r'$\mathregular{AUC_{Var}} = $'+r'{:.2f}'.format(auc_ktau_x_var), size=24, color='black')
ax[0,0].text(0.6, 0.1, r'$\mathregular{AUC_{AC}} = $'+r'{:.2f}'.format(auc_ktau_x_ac), size=24, color='black')
ax[0,0].text(0, 1.025, 'A', transform=ax[0,0].transAxes, size=40, weight='bold')

ax[0,1].plot(fpr_ktau_x_var_pre100, tpr_ktau_x_var_pre100, color='mediumseagreen', linewidth=3)
ax[0,1].plot(fpr_ktau_x_ac_pre100, tpr_ktau_x_ac_pre100, color='coral', linewidth=3)
ax[0,1].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[0,1].text(0.6, 0.18, r'$\mathregular{AUC_{Var}} = $'+r'{:.2f}'.format(auc_ktau_x_var_pre100), size=24, color='black')
ax[0,1].text(0.6, 0.1, r'$\mathregular{AUC_{AC}} = $'+r'{:.2f}'.format(auc_ktau_x_ac_pre100), size=24, color='black')
ax[0,1].text(0, 1.025, 'B', transform=ax[0,1].transAxes, size=40, weight='bold')

ax[0,2].plot(fpr_ktau_x_var_pre200, tpr_ktau_x_var_pre200, color='mediumseagreen', linewidth=3)
ax[0,2].plot(fpr_ktau_x_ac_pre200, tpr_ktau_x_ac_pre200, color='coral', linewidth=3)
ax[0,2].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[0,2].text(0.6, 0.18, r'$\mathregular{AUC_{Var}} = $'+r'{:.2f}'.format(auc_ktau_x_var_pre200), size=24, color='black')
ax[0,2].text(0.6, 0.1, r'$\mathregular{AUC_{AC}} = $'+r'{:.2f}'.format(auc_ktau_x_ac_pre200), size=24, color='black')
ax[0,2].text(0, 1.025, 'C', transform=ax[0,2].transAxes, size=40, weight='bold')

ax[0,3].plot(fpr_ktau_x_var_pre400, tpr_ktau_x_var_pre400, color='mediumseagreen', linewidth=3)
ax[0,3].plot(fpr_ktau_x_ac_pre400, tpr_ktau_x_ac_pre400, color='coral', linewidth=3)
ax[0,3].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[0,3].text(0.6, 0.18, r'$\mathregular{AUC_{Var}} = $'+r'{:.2f}'.format(auc_ktau_x_var_pre400), size=24, color='black')
ax[0,3].text(0.6, 0.1, r'$\mathregular{AUC_{AC}} = $'+r'{:.2f}'.format(auc_ktau_x_ac_pre400), size=24, color='black')
ax[0,3].text(0, 1.025, 'D', transform=ax[0,3].transAxes, size=40, weight='bold')

# Row 1: tweets
ax[1,0].plot(fpr_ktau_tw_var, tpr_ktau_tw_var, color='mediumseagreen', linewidth=3)
ax[1,0].plot(fpr_ktau_tw_ac, tpr_ktau_tw_ac, color='coral', linewidth=3)
ax[1,0].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[1,0].set_ylabel('True Positive Rate', fontsize=36)
ax[1,0].text(0.1, 0.93, 'Var($T_p$)', size=28, color='mediumseagreen')
ax[1,0].text(0.6, 0.46, 'AC($T_p$)', size=28, color='coral')
ax[1,0].text(0.6, 0.18, r'$\mathregular{AUC_{Var}} = $'+r'{:.2f}'.format(auc_ktau_tw_var), size=24, color='black')
ax[1,0].text(0.6, 0.1, r'$\mathregular{AUC_{AC}} = $'+r'{:.2f}'.format(auc_ktau_tw_ac), size=24, color='black')
ax[1,0].text(0, 1.025, 'E', transform=ax[1,0].transAxes, size=40, weight='bold')

ax[1,1].plot(fpr_ktau_tw_var_pre100, tpr_ktau_tw_var_pre100, color='mediumseagreen', linewidth=3)
ax[1,1].plot(fpr_ktau_tw_ac_pre100, tpr_ktau_tw_ac_pre100, color='coral', linewidth=3)
ax[1,1].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[1,1].text(0.6, 0.18, r'$\mathregular{AUC_{Var}} = $'+r'{:.2f}'.format(auc_ktau_tw_var_pre100), size=24, color='black')
ax[1,1].text(0.6, 0.1, r'$\mathregular{AUC_{AC}} = $'+r'{:.2f}'.format(auc_ktau_tw_ac_pre100), size=24, color='black')
ax[1,1].text(0, 1.025, 'F', transform=ax[1,1].transAxes, size=40, weight='bold')

ax[1,2].plot(fpr_ktau_tw_var_pre200, tpr_ktau_tw_var_pre200, color='mediumseagreen', linewidth=3)
ax[1,2].plot(fpr_ktau_tw_ac_pre200, tpr_ktau_tw_ac_pre200, color='coral', linewidth=3)
ax[1,2].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[1,2].text(0.6, 0.18, r'$\mathregular{AUC_{Var}} = $'+r'{:.2f}'.format(auc_ktau_tw_var_pre200), size=24, color='black')
ax[1,2].text(0.6, 0.1, r'$\mathregular{AUC_{AC}} = $'+r'{:.2f}'.format(auc_ktau_tw_ac_pre200), size=24, color='black')
ax[1,2].text(0, 1.025, 'G', transform=ax[1,2].transAxes, size=40, weight='bold')

ax[1,3].plot(fpr_ktau_tw_var_pre400, tpr_ktau_tw_var_pre400, color='mediumseagreen', linewidth=3)
ax[1,3].plot(fpr_ktau_tw_ac_pre400, tpr_ktau_tw_ac_pre400, color='coral', linewidth=3)
ax[1,3].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[1,3].text(0.6, 0.18, r'$\mathregular{AUC_{Var}} = $'+r'{:.2f}'.format(auc_ktau_tw_var_pre400), size=24, color='black')
ax[1,3].text(0.6, 0.1, r'$\mathregular{AUC_{AC}} = $'+r'{:.2f}'.format(auc_ktau_tw_ac_pre400), size=24, color='black')
ax[1,3].text(0, 1.025, 'H', transform=ax[1,3].transAxes, size=40, weight='bold')

ax[2,0].plot(fpr_resnet, tpr_resnet, color='cornflowerblue', linewidth=3)
ax[2,0].plot(fpr_lstm, tpr_lstm, color='crimson', linewidth=3)
ax[2,0].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[2,0].set_xlabel('False Positive Rate', fontsize=36)
ax[2,0].set_ylabel('True Positive Rate', fontsize=36)
ax[2,0].text(0.2, 0.91, 'ResNet', size=28, color='cornflowerblue')
ax[2,0].text(-0.04, 0.99, 'LSTM', size=28, color='crimson')
ax[2,0].text(0.6, 0.18, r'$\mathregular{AUC_{ResNet}} = $'+r'{:.2f}'.format(auc_resnet), size=24, color='black')
ax[2,0].text(0.6, 0.1, r'$\mathregular{AUC_{LSTM}} = $'+r'{:.2f}'.format(auc_lstm), size=24, color='black')
ax[2,0].text(0, 1.025, 'I', transform=ax[2,0].transAxes, size=40, weight='bold')

ax[2,1].plot(fpr_resnet_pre100, tpr_resnet_pre100, color='cornflowerblue', linewidth=3)
ax[2,1].plot(fpr_lstm_pre100, tpr_lstm_pre100, color='crimson', linewidth=3)
ax[2,1].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[2,1].set_xlabel('False Positive Rate', fontsize=36)
ax[2,1].text(0.6, 0.18, r'$\mathregular{AUC_{ResNet}} = $'+r'{:.2f}'.format(auc_resnet_pre100), size=24, color='black')
ax[2,1].text(0.6, 0.1, r'$\mathregular{AUC_{LSTM}} = $'+r'{:.2f}'.format(auc_lstm_pre100), size=24, color='black')
ax[2,1].text(0, 1.025, 'J', transform=ax[2,1].transAxes, size=40, weight='bold')

ax[2,2].plot(fpr_resnet_pre200, tpr_resnet_pre200, color='cornflowerblue', linewidth=3)
ax[2,2].plot(fpr_lstm_pre200, tpr_lstm_pre200, color='crimson', linewidth=3)
ax[2,2].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[2,2].set_xlabel('False Positive Rate', fontsize=36)
ax[2,2].text(0.6, 0.18, r'$\mathregular{AUC_{ResNet}} = $'+r'{:.2f}'.format(auc_resnet_pre200), size=24, color='black')
ax[2,2].text(0.6, 0.1, r'$\mathregular{AUC_{LSTM}} = $'+r'{:.2f}'.format(auc_lstm_pre200), size=24, color='black')
ax[2,2].text(0, 1.025, 'K', transform=ax[2,2].transAxes, size=40, weight='bold')

ax[2,3].plot(fpr_resnet_pre400, tpr_resnet_pre400, color='cornflowerblue', linewidth=3)
ax[2,3].plot(fpr_lstm_pre400, tpr_lstm_pre400, color='crimson', linewidth=3)
ax[2,3].plot([0, 1], [0, 1], linestyle='--', color='grey', linewidth=3)
ax[2,3].set_xlabel('False Positive Rate', fontsize=36)
ax[2,3].text(0.6, 0.18, r'$\mathregular{AUC_{ResNet}} = $'+r'{:.2f}'.format(auc_resnet_pre400), size=24, color='black')
ax[2,3].text(0.6, 0.1, r'$\mathregular{AUC_{LSTM}} = $'+r'{:.2f}'.format(auc_lstm_pre400), size=24, color='black')
ax[2,3].text(0, 1.025, 'L', transform=ax[2,3].transAxes, size=40, weight='bold')

for ax in ax.flatten():
    ax.tick_params(axis='both', which='major', labelsize=30)

fig.tight_layout()

#plt.savefig('roc_300ep.png', dpi=300)
#plt.savefig('roc_300ep.eps', format='eps', bbox_inches='tight')