In [None]:
# import (stdlib -> third-party)
import datetime

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import signal

# x-axis date formatters used by the plots: "day month" and "day month" with "hour:minute" on a second line
myFmt = mdates.DateFormatter('%d %b')
myFmt2 = mdates.DateFormatter('%d %b \n%H:%M')

In [None]:
!pip install plotly

## Data loading

In [None]:
# settings
shift = 15    # shift in minutes; also used as the persistence lag (timedelta(minutes=shift)) further down
window = 45   # window size; only used here to build the data / prediction directory names
model_type = "autoreg15"  # column prefix -- for autoreg: autoreg{shift}; for multiNN: PC, AsyH, ap_index, BzGSE, all_parameters
model_type2 = "autoreg"   # directory name -- for autoreg: autoreg; for multiNN: PC, AsyH, ap_index, BzGSE, all_parameters
best_model = True  # load best model (True) or the predictions saved at a given epoch (False)
# epoch / mae are only used when best_model is False; they must match the saved file name
# y_pred_epoch-{epoch}-mae-{mae}.npy
epoch = ""
mae = ""

In [None]:
# data loading
size = 1  # select subset (>1) or whole test_data (1)
data_dir = f"../data/{model_type2}/shift-{shift}-windows-{window}"
# NOTE: allow_pickle=True can execute code stored in the file -- only load trusted .npy files.
param_t = np.load(f"{data_dir}/test_index_timestamp.npy", allow_pickle=True).flatten()
param_ytest = np.load(f"{data_dir}/y_test.npy", allow_pickle=True).flatten()
param_bin = np.load(f"{data_dir}/test_index_bin.npy", allow_pickle=True).flatten()
if best_model:
    # allow_pickle is a keyword argument of np.load, not part of the file path
    param_pred = np.load(
        f"../prediction/y_pred-{model_type2}-shift-{shift}-windows-{window}.npy", allow_pickle=True
    ).flatten()
else:
    if epoch == "" or mae == "":
        raise ValueError("Set `epoch` and `mae` in the settings cell when best_model is False.")
    param_pred = np.load(
        f"../prediction/{model_type2}/shift-{shift}-windows-{window}/y_pred_epoch-{epoch}-mae-{mae}.npy",
        allow_pickle=True,
    ).flatten()

# predictions and targets must line up one-to-one, otherwise the columns below are misaligned
assert len(param_pred) == len(param_ytest), "y_pred and y_test have different lengths"

# create dataframe -- one shared row count so every column is cut identically
n_rows = int(len(param_ytest) / size)
param_df = pd.DataFrame(data={
    f"{model_type}_test": param_ytest[:n_rows],
    f"{model_type}_pred": param_pred[:n_rows],
    f"{model_type}_bin": param_bin[:n_rows],
    f"{model_type}_time": param_t[:n_rows],
})
# keep the first 5 characters of each bin label (labels are strings -- TODO confirm what the prefix encodes)
param_df[f"{model_type}_bin"] = param_df[f"{model_type}_bin"].str[:5]

# "<timestamp> <bin>" key per sample; becomes the index
param_df["index_time_bin"] = param_df[f"{model_type}_time"] + " " + param_df[f"{model_type}_bin"]
param_df = param_df.set_index("index_time_bin")

In [None]:
# create scinti_df: param_df indexed by timestamp, plus a persistence baseline.
# reset_index() keeps "index_time_bin" as a column -- set_index() with a Series replaces the old
# index and would drop it, making the lookup below fail with KeyError.
scinti_df = param_df.reset_index()
scinti_df = scinti_df.set_index(pd.to_datetime(scinti_df[f"{model_type}_time"]))

# key of the sample `shift` minutes earlier in the same bin, in the "<time> <bin>" format of index_time_bin.
# NOTE(review): this only matches if str(Timestamp) reproduces the raw strings in test_index_timestamp.npy
# (e.g. "YYYY-MM-DD HH:MM:SS") -- confirm, otherwise every persistence value is NaN.
scinti_df['persistence_shift_time'] = scinti_df.index - datetime.timedelta(minutes=shift)
scinti_df['persistence_shift_time'] = scinti_df['persistence_shift_time'].astype(str)
scinti_df['persistence_shift_time'] = scinti_df['persistence_shift_time'] + " " + scinti_df[f'{model_type}_bin'].astype(str)

# persistence forecast = true value observed `shift` minutes earlier (NaN when that sample is missing);
# for duplicate keys the last occurrence wins
mapping_true = dict(zip(scinti_df['index_time_bin'], scinti_df[f"{model_type}_test"]))
scinti_df['persistence'] = scinti_df["persistence_shift_time"].map(mapping_true)

# binary flags at the 0.1 rad threshold (column spelling "persistance" kept for later cells);
# NaN persistence compares False and is flagged 0
scinti_threshold = 0.1
scinti_df["persistance_true"] = np.where(scinti_df[f"{model_type}_test"] >= scinti_threshold, 1, 0)
scinti_df["persistance_pred"] = np.where(scinti_df["persistence"] >= scinti_threshold, 1, 0)

scinti_df.head()

In [None]:
# select a single bin. Bin labels are strings (truncated with .str[:5] when loading),
# so compare as strings -- an int literal would never match and leave an empty frame.
selected_bin = "13725"
scinti_df = scinti_df[scinti_df[f"{model_type}_bin"].astype(str) == str(selected_bin)]
assert not scinti_df.empty, f"no rows for bin {selected_bin!r}"
scinti_df = scinti_df.set_index(pd.to_datetime(scinti_df.index))

In [None]:
# visualization of true and predicted values: one figure per `days_per_plot`-day window
scinti_df = scinti_df.set_index(pd.to_datetime(scinti_df[f"{model_type}_time"]))

plt.rcParams["figure.figsize"] = (10, 5)
plot_start = pd.Timestamp('2019-01-24 4:00:00')  # can change start time here
days_per_plot = 2
n_plots = 53
for i in range(n_plots):
    x_start = plot_start + datetime.timedelta(days=i * days_per_plot)
    x_end = x_start + datetime.timedelta(days=days_per_plot)

    fig, ax = plt.subplots()
    ax.scatter(scinti_df.index, scinti_df[f'{model_type}_test'], s=0.8, c="lime", label="True Phi60_Sig1")
    ax.scatter(scinti_df.index, scinti_df[f'{model_type}_pred'], s=0.1, c="dodgerblue", label="Predicted Phi60_Sig1")
    ax.set_xlim(x_start, x_end)
    ax.set_ylim(0, 0.5)
    ax.set_ylabel(r'$\sigma_\phi [rad]$', fontsize=12)
    ax.set_xlabel("Time")
    ax.axhline(0.1, color='silver', linewidth=2)  # 0.1 rad threshold
    ax.grid(True, linewidth=0.3)
    ax.xaxis.set_major_formatter(myFmt2)
    lgnd = ax.legend()
    # enlarge the tiny scatter markers in the legend only
    for handle in lgnd.legend_handles:
        handle.set_sizes([10])
    plt.show()
    plt.close(fig)  # release the figure; 53 figures are created in this loop

In [None]:
# Savitzky-Golay filter. With polyorder=0 the least-squares fit is a constant, i.e. a moving average
# over window_length samples. mode="wrap" extends the series periodically, so values near the start
# and end are influenced by the opposite end of the series.
# Column names follow the f"{model_type}_..." naming used when the DataFrame was built.
# predicted values
scinti_df["sav_gol"] = signal.savgol_filter(scinti_df[f"{model_type}_pred"], window_length=15, polyorder=0, mode="wrap")
# true values
scinti_df["sav_gol_true"] = signal.savgol_filter(scinti_df[f"{model_type}_test"], window_length=60, polyorder=0, mode="wrap")

In [None]:
# Convolution filter with Hann window (60 samples)
win = signal.windows.hann(60)
# mode='same' zero-pads the borders, which pulls the first/last ~30 smoothed values toward 0 if we
# divide by a constant win.sum(). Dividing by the convolution of an all-ones signal renormalises the
# weights at the borders; in the interior it equals win.sum() (up to floating-point rounding).
hann_norm = signal.convolve(np.ones(len(scinti_df)), win, mode='same')
# predicted values
scinti_df["scipy_conv"] = signal.convolve(scinti_df[f"{model_type}_pred"], win, mode='same') / hann_norm
# true values
scinti_df["scipy_conv_true"] = signal.convolve(scinti_df[f"{model_type}_test"], win, mode='same') / hann_norm

In [None]:
# Kaiser window function
beta = [2, 4, 16, 32]  # not used by smooth(): its `beta` parameter shadows this list


def smooth(x, beta, window_len=31):
    """Smooth a 1-D signal with a normalised Kaiser window.

    The signal is mirror-extended (reflection without repeating the edge sample) by
    ``window_len - 1`` samples on both sides so the window can be applied at the
    borders; the output has the same length as ``x`` and is centred on it.

    Parameters
    ----------
    x : array-like
        1-D signal (ndarray or pandas Series; a Series is used by position).
    beta : float
        Kaiser shape parameter passed to ``np.kaiser``.
    window_len : int, optional
        Odd window length in samples (default 31).

    Returns
    -------
    numpy.ndarray
        Smoothed signal, ``len(x)`` samples.

    Raises
    ------
    ValueError
        If ``window_len`` is not a positive odd integer or ``x`` is shorter than it.
    """
    if window_len < 1 or window_len % 2 == 0:
        raise ValueError("window_len must be a positive odd integer")
    # positional, numeric view: avoids label-vs-position ambiguity of Series slicing
    x = np.asarray(x, dtype=float)
    if x.size < window_len:
        raise ValueError("input must be at least window_len samples long")
    # the same reflection on both ends keeps the filter symmetric in time
    s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
    w = np.kaiser(window_len, beta)
    y = np.convolve(w / w.sum(), s, mode='valid')
    # 'valid' output has len(x) + window_len - 1 samples; trim half a window from each side
    half = window_len // 2
    return y[half:len(y) - half]

In [None]:
# Kaiser window smoothing (beta=1); column names follow the f"{model_type}_..." naming of the DataFrame
# true values
scinti_df["kaiser_true"] = smooth(scinti_df[f"{model_type}_test"], 1)
# predicted values
scinti_df["kaiser"] = smooth(scinti_df[f"{model_type}_pred"], 1)

In [None]:
# compare the smoothing filters on the true values: one figure per day.
# Legend window lengths must match the filter cells (Savitzky-Golay true: 60, Hann: 60, Kaiser: 31).
plt.rcParams["figure.figsize"] = (10, 5)
plot_start = pd.Timestamp('2019-01-24 4:00:00')
days_per_plot = 1
n_plots = 53
series_styles = [
    (f'{model_type}_test', "black", "True Phi60_Sig1"),
    ('sav_gol_true', "orange", "Savitzky-Golay filter (60)"),
    ('scipy_conv_true', "mediumvioletred", "Convolution Hann window (60)"),
    ('kaiser_true', "deepskyblue", "Kaiser window smoothing (31)"),
]
for i in range(n_plots):
    x_start = plot_start + datetime.timedelta(days=i * days_per_plot)
    x_end = x_start + datetime.timedelta(days=days_per_plot)

    fig, ax = plt.subplots()
    for column, color, label in series_styles:
        ax.scatter(scinti_df.index, scinti_df[column], s=1, c=color, label=label)
    ax.set_xlim(x_start, x_end)
    ax.set_ylim(0, 0.3)
    ax.set_ylabel(r'$\sigma_\phi [rad]$', fontsize=12)
    ax.set_xlabel("Time", fontsize=12)
    ax.axhline(0.1, color='silver', linewidth=2)  # 0.1 rad threshold
    ax.grid(True, linewidth=0.3)
    ax.xaxis.set_major_formatter(myFmt2)
    lgnd = ax.legend()
    # enlarge the scatter markers in the legend only
    for handle in lgnd.legend_handles:
        handle.set_sizes([10])
    plt.show()
    plt.close(fig)  # release the figure; 53 figures are created in this loop
