In [44]:
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np
import pandas as pd
import ipywidgets as widgets
from ipywidgets import interact
import seaborn as sns

from utils.load_data import load_data
from utils.preprocessing import preprocess_data
from models.BaseModel import BaseModel
from postprocessing.arima import postprocess_arima, postprocess_arima_auto

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
SEED = 0
EPS = 1e-6
VALID_PROPORTION = 0.2
CONFIDENCE_LEVEL = 0.95

In [46]:
data = load_data()

X, y, countries, y_mean, y_std = preprocess_data(data=data, epsilon=EPS)
countries = countries.to_frame().assign(date=X['date'])

In [47]:
n_valid = int(len(X) * VALID_PROPORTION)

X_train, y_train = X[:-n_valid], y[:-n_valid]
X_valid, y_valid = X[-n_valid:], y[-n_valid:]
country_train, country_valid = countries.iloc[:-n_valid], countries.iloc[-n_valid:]


In [48]:
model = BaseModel(seed=SEED)

model.fit(X_train, y_train)

# Predict on training and validation data
y_pred_train = model.predict(X_train)
y_pred_valid = model.predict(X_valid)

# Calculate Mean Squared Error
mse_train = mean_squared_error(y_train, y_pred_train)
mse_valid = mean_squared_error(y_valid, y_pred_valid)
print(f"Training MSE: {mse_train:.4f}")
print(f"Validation MSE: {mse_valid:.4f}")

Training MSE: 0.0027
Validation MSE: 0.7334


In [None]:
# Associate the result by country
y_pred_valid_country = pd.DataFrame({'date': country_valid['date'].values, 'country': country_valid['country'].values, 'y_pred': y_pred_valid, 'y_true': y_valid, "y_pred_init": y_pred_valid, "y_pred_avg": y_pred_valid})
y_pred_train_country = pd.DataFrame({'date': country_train['date'].values, 'country': country_train['country'].values, 'y_pred': y_pred_train, 'y_true': y_train, "y_pred_init": y_pred_train, "y_pred_avg": y_pred_train})

# predictions = pd.concat([y_pred_train_country, y_pred_valid_country], axis=0)
# predictions_melted = predictions.melt(
#     id_vars=["date", "country"], value_vars=["y_pred", "y_true"],
#     var_name="Type", value_name="Value"
# )
# adjusted_predictions = predictions.copy()

# Apply the post-processing function
#adjusted_predictions = postprocess_arima_auto(y_pred_train_country, y_pred_valid_country)

p = 10 # p = the number of autoregressive terms
d = 1 # d = the number of differences needed for stationarity
q = 10 # q = the number of lags to be used in the model
adjusted_predictions = postprocess_arima(y_pred_train_country, y_pred_valid_country, p, d, q)

# Update y_pred_avg
adjusted_predictions['y_pred_avg'] = (adjusted_predictions['y_pred'] + adjusted_predictions['y_pred_init']) / 2

# Change "y_pred" to "y_pred_arima"
adjusted_predictions = adjusted_predictions.rename(columns={"y_pred": "y_pred_arima"})
adjusted_predictions = adjusted_predictions.rename(columns={"y_pred_init": "y_pred_NN"})

# Evaluate adjusted predictions
valid_adjusted = adjusted_predictions[adjusted_predictions['set'] == 'validation']
mse_valid_adjusted = mean_squared_error(valid_adjusted['y_true'], valid_adjusted['y_pred_arima'])
print(f"Adjusted Validation MSE: {mse_valid_adjusted:.4f}")

# Melting the dataframe for better plotting
predictions_melted = adjusted_predictions.melt(
    id_vars=["date", "country"], value_vars=["y_true", "y_pred_NN", "y_pred_arima", "y_pred_avg"],
    var_name="Type", value_name="Value"
)

  warn('Non-stationary starting autoregressive parameters'
  warn('Non-invertible starting MA parameters found.'
  warn('Non-stationary starting autoregressive parameters'
  warn('Non-invertible starting MA parameters found.'
  warn('Non-stationary starting autoregressive parameters'
  warn('Non-invertible starting MA parameters found.'
  warn('Non-stationary starting autoregressive parameters'
  warn('Non-invertible starting MA parameters found.'
  warn('Non-stationary starting autoregressive parameters'
  warn('Non-invertible starting MA parameters found.'
  warn('Non-stationary starting autoregressive parameters'
  warn('Non-invertible starting MA parameters found.'
  warn('Non-stationary starting autoregressive parameters'
  warn('Non-invertible starting MA parameters found.'


Adjusted Validation MSE: 0.2829




In [57]:
# Function to plot data with confidence intervals for the selected country
def plot_by_country_with_confidence(selected_country):
    filtered_data = predictions_melted[predictions_melted["country"] == selected_country]
    cutoff_date = filtered_data['date'].quantile(1-VALID_PROPORTION)

    plt.figure(figsize=(12, 6))
    
    # Plot predictions and true values
    sns.lineplot(
        data=filtered_data,
        x="date", y="Value", hue="Type", style="Type", markers=True, dashes=False
    )

    # Add a vertical line to indicate where validation starts
    plt.axvline(x=cutoff_date, color='red', linestyle='--', label=f'Validation Start ({(1-VALID_PROPORTION)*100:.0f}%)')
    
    # Enhancing the plot
    plt.title(f"Prediction vs True Values for {selected_country}")
    plt.xlabel("Date")
    plt.ylabel("Values")
    plt.legend(title="Legend")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Create a dropdown widget for selecting the country
countries = adjusted_predictions["country"].unique()
dropdown = widgets.Dropdown(
    options=countries,
    value=countries[0],
    description='Country:'
)

# Use the interact function to link the dropdown with the updated plot function
interact(plot_by_country_with_confidence, selected_country=dropdown)

interactive(children=(Dropdown(description='Country:', options=('Switzerland', 'Germany', 'United Kingdom', 'Jâ€¦

<function __main__.plot_by_country_with_confidence(selected_country)>