In [17]:
# !pip install plotly
# !pip install nbformat
import plotly.express as px
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from math import sqrt
from pathlib import Path
import time

In [18]:
# Experiment folder written by the TimesNet run; its name encodes the run's settings
# (e.g. `sl12` = input length 12, `pl12` = prediction length 12).
timesNetDir = "long_term_forecast_wind_12_12_TimesNet_custom_ftMS_sl12_ll5_pl12_dm64_nh8_el3_dl1_df64_fc3_ebtimeF_dtTrue_Exp_0"
# Called `results_dir` (not `dir`) so the built-in `dir()` is not shadowed for the rest of the session.
results_dir = Path("results") / timesNetDir

# Number of forecast lead times per sample; keep in sync with `pl12` in the folder name.
PRED_LEN = 12


def _horizon_frame(values: np.ndarray, prefix: str, dates: pd.DataFrame) -> pd.DataFrame:
    """Return a date-indexed frame with one column per lead time, named ``{prefix}_p1 .. {prefix}_pN``.

    ``values`` is a 2-D array with one row per test sample and one column per lead time.
    Rows are matched to ``dates`` purely by position (row k of ``values`` gets row k of
    ``dates``), so ``dates`` must contain a ``date`` column and have at least as many rows
    as ``values``; extra trailing dates are ignored by the left join.

    Raises:
        ValueError: if ``dates`` is shorter than ``values`` (those rows would get a NaN date).
    """
    if len(values) > len(dates):
        raise ValueError(
            f"{prefix}: {len(values)} samples but only {len(dates)} test dates; "
            "cannot align them by position"
        )
    frame = pd.DataFrame(values).join(dates)
    frame = frame.rename(columns={i: f"{prefix}_p{i+1}" for i in range(values.shape[1])})
    return frame.set_index("date")


# Load the timestamps of the test set (one row per test sample)
index_df = pd.read_csv('test_index.csv')
index_df['date'] = pd.to_datetime(index_df['date'])

# Load true values
trues = np.load(results_dir / "true.npy").reshape(-1, PRED_LEN)
trues_df = _horizon_frame(trues, "Trues", index_df)

# Load predictions of TimesNet
preds = np.load(results_dir / "pred.npy").reshape(-1, PRED_LEN)
preds_df = _horizon_frame(preds, "TimesNet", index_df)

# Load predictions of LSTM (columns p1..p12 -> LSTM_p1..LSTM_p12)
lstm_df = pd.read_csv("results/LSTM_predictions.csv").rename(
    columns={f"p{i}": f"LSTM_p{i}" for i in range(1, PRED_LEN + 1)}
)
# NOTE(review): shifts LSTM timestamps forward by 12 hours, presumably to match the
# timestamp convention of test_index.csv -- confirm against how both files were produced.
lstm_df['datetime'] = pd.to_datetime(lstm_df['datetime']) + pd.DateOffset(hours=12)
# Inner join against the test dates keeps only LSTM rows that fall inside the test set.
lstm_df = lstm_df.set_index("datetime").join(index_df.set_index("date"), how="inner")

# Join LSTM, TimesNet and true values on the date index (LSTM dates drive the rows)
outputs = lstm_df.join(preds_df).join(trues_df)

In [19]:
# RMSE of each model at every lead time 1..12, stored as {lead_time: rmse}.
# `mean_squared_error(..., squared=False)` was deprecated in scikit-learn 1.4 and removed
# in 1.6, so take the square root of the MSE explicitly; this works on every version.
lstm_rmse, timeNet_rmse = {}, {}
for i in range(1, 13):
    # Local name `y_true` avoids clobbering the `trues` array loaded in the previous cell.
    y_true = outputs[f"Trues_p{i}"]
    timeNet_rmse[i] = sqrt(mean_squared_error(y_true, outputs[f'TimesNet_p{i}']))
    lstm_rmse[i] = sqrt(mean_squared_error(y_true, outputs[f'LSTM_p{i}']))

In [20]:
import os
# Create the output folder, including a missing `plot/` parent; re-running the cell is harmless.
output_dir = Path("plot/LSTM_TimesNet_comparison")
output_dir.mkdir(parents=True, exist_ok=True)

# One interactive line chart per lead time comparing LSTM, TimesNet and the true values.
for i in range(1, 13):
    fig = px.line(
        outputs, x=outputs.index, y=[f'LSTM_p{i}', f'TimesNet_p{i}', f"Trues_p{i}"],
        title=f"Lead Time {i}", labels={'x': 'Time', 'value': 'Wind Speed'},
    )
    # `labels` is keyed by column name, and plotly exposes the index under its own name
    # (not 'x'), so set the x-axis title explicitly.
    fig.update_xaxes(title_text="Time")
    fig.write_html(str(output_dir / f"p{i}.html"))

# Per-lead-time RMSE table: rows are lead times 1..12, columns are the two models.
df = pd.DataFrame([timeNet_rmse, lstm_rmse], index=["TimesNet", "LSTM"]).T
df.to_csv(output_dir / "rmse.csv")