In [198]:
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np
import pandas as pd
import ipywidgets as widgets
from ipywidgets import interact
import seaborn as sns

from utils.load_data import load_data
from utils.preprocessing import preprocess_data
from models.BaseModel import BaseModel
from postprocessing.arima import postprocess_arima

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [199]:
SEED = 0
EPS = 1e-6
VALID_PROPORTION = 0.1
CONFIDENCE_LEVEL = 0.95

In [200]:
data = load_data()

X, y, countries, y_mean, y_std = preprocess_data(data=data, epsilon=EPS)
countries = countries.to_frame().assign(date=X['date'])

X

Unnamed: 0,date,Expense_average,Research_and_development_average,Capital_expenditure_average,Business_average,Cost_average,Tax_average,Financial_capital_average,Investment_average,Gross_domestic_product_average,...,Artificial_intelligence_average,International_Financial_Reporting_Standards_average,Employment_average,country_Canada,country_Germany,country_Japan,country_Korea,country_Switzerland,country_United Kingdom,country_United States
2,-1.771126,2.782184,-2.417180,-1.696290,1.426370,-1.290762,1.327271,-0.604001,2.530535,-3.273764,...,-0.043561,-1.710910,0.900544,-0.415202,-0.415202,-0.365754,-0.412192,2.404072,-0.415202,-0.415202
1227,-1.771126,-0.446089,2.210188,-1.696290,0.516251,0.636960,0.313333,-0.604001,0.875985,3.790688,...,-0.336456,-1.710910,1.839678,-0.415202,-0.415202,-0.365754,2.421629,-0.415202,-0.415202,-0.415202
492,-1.771126,-1.009755,2.751049,-1.696290,1.718909,-1.037115,-0.773029,1.073548,2.594171,2.276877,...,-0.277877,0.745671,2.102636,-0.415202,-0.415202,-0.365754,-0.412192,-0.415202,2.404072,-0.415202
1472,-1.771126,-0.036149,3.231815,0.589877,1.653900,-0.682008,1.399695,0.082269,1.703260,0.763066,...,-0.453614,-0.598496,1.689417,-0.415202,-0.415202,-0.365754,-0.412192,-0.415202,-0.415202,2.404072
247,-1.771126,1.398638,0.707796,0.793092,1.231345,-1.595140,1.822169,-0.604001,1.957806,0.618893,...,-0.395035,0.652970,0.862979,-0.415202,2.404072,-0.365754,-0.412192,-0.415202,-0.415202,-0.415202
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
732,1.715275,-0.138634,1.248657,0.132644,0.776285,1.093526,2.727470,-0.451496,1.575987,3.574430,...,4.759909,-0.274042,1.802113,-0.415202,-0.415202,-0.365754,-0.412192,-0.415202,2.404072,-0.415202
487,1.715275,0.783729,-0.313831,1.605952,0.256216,2.311035,2.425703,-0.604001,1.639624,0.979325,...,4.818488,-0.459444,-0.865028,-0.415202,2.404072,-0.365754,-0.412192,-0.415202,-0.415202,-0.415202
242,1.715275,1.808578,-1.395553,1.656756,0.841293,2.361764,3.270651,-0.604001,2.530535,1.988532,...,5.111383,-0.181340,1.464025,-0.415202,-0.415202,-0.365754,-0.412192,2.404072,-0.415202,-0.415202
1222,1.715275,2.116032,-0.361907,1.555148,1.114329,2.006658,3.149944,-0.222740,1.130532,1.700187,...,5.111383,0.421217,0.862979,2.404072,-0.415202,-0.365754,-0.412192,-0.415202,-0.415202,-0.415202


In [201]:
n_valid = int(len(X) * VALID_PROPORTION)

X_train, y_train = X[:-n_valid], y[:-n_valid]
X_valid, y_valid = X[-n_valid:], y[-n_valid:]
country_train, country_valid = countries.iloc[:-n_valid], countries.iloc[-n_valid:]


In [202]:
model = BaseModel(seed=SEED)

model.fit(X_train, y_train)

# Predict on training and validation data
y_pred_train = model.predict(X_train)
y_pred_valid = model.predict(X_valid)

# Calculate Mean Squared Error
mse_train = mean_squared_error(y_train, y_pred_train)
mse_valid = mean_squared_error(y_valid, y_pred_valid)
print(f"Training MSE: {mse_train:.4f}")
print(f"Validation MSE: {mse_valid:.4f}")

Training MSE: 0.0046
Validation MSE: 0.4993


In [None]:
# Associate the result by country
y_pred_valid_country = pd.DataFrame({'date': country_valid['date'].values, 'country': country_valid['country'].values, 'y_pred': y_pred_valid, 'y_true': y_valid})
y_pred_train_country = pd.DataFrame({'date': country_train['date'].values, 'country': country_train['country'].values, 'y_pred': y_pred_train, 'y_true': y_train})

print(y_pred_train_country)

# Apply the post-processing function
adjusted_predictions = postprocess_arima(y_pred_train_country, y_pred_valid_country)

# Evaluate adjusted predictions
valid_adjusted = adjusted_predictions[adjusted_predictions['set'] == 'validation']
mse_valid_adjusted = mean_squared_error(valid_adjusted['y_true'], valid_adjusted['y_pred'])
print(f"Adjusted Validation MSE: {mse_valid_adjusted:.4f}")

# Melting the dataframe for better plotting
predictions_melted = adjusted_predictions.melt(
    id_vars=["date", "country"], value_vars=["y_pred", "y_true"],
    var_name="Type", value_name="Value"
)

          date         country    y_pred    y_true
2    -1.771126     Switzerland -0.667305 -0.636078
1227 -1.771126           Korea -1.826280 -1.837559
492  -1.771126  United Kingdom -1.208107 -1.228650
1472 -1.771126   United States -0.488987 -0.539413
247  -1.771126         Germany -1.238617 -1.219147
...        ...             ...       ...       ...
953   1.366397           Japan -0.101467 -0.146885
1688  1.366397   United States  1.946929  2.057429
463   1.366397         Germany  1.387534  1.365879
1198  1.366397          Canada  0.883712  0.938508
1443  1.366397           Korea  0.218995  0.173998

[495 rows x 4 columns]


TypeError: postprocess_arima() missing 3 required positional arguments: 'p', 'd', and 'q'

In [197]:
# Function to plot data with confidence intervals for the selected country
def plot_by_country_with_confidence(selected_country):
    filtered_data = predictions_melted[predictions_melted["country"] == selected_country]
    cutoff_date = filtered_data['date'].quantile(1-VALID_PROPORTION)

    plt.figure(figsize=(12, 6))
    
    # Plot predictions and true values
    sns.lineplot(
        data=filtered_data,
        x="date", y="Value", hue="Type", style="Type", markers=True, dashes=False
    )

    # Add a vertical line to indicate where validation starts
    plt.axvline(x=cutoff_date, color='red', linestyle='--', label=f'Validation Start ({(1-VALID_PROPORTION)*100:.0f}%)')
    
    # Enhancing the plot
    plt.title(f"Prediction vs True Values for {selected_country}")
    plt.xlabel("Date")
    plt.ylabel("Values")
    plt.legend(title="Legend")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Create a dropdown widget for selecting the country
countries = adjusted_predictions["country"].unique()
dropdown = widgets.Dropdown(
    options=countries,
    value=countries[0],
    description='Country:'
)

# Use the interact function to link the dropdown with the updated plot function
interact(plot_by_country_with_confidence, selected_country=dropdown)

interactive(children=(Dropdown(description='Country:', options=('Switzerland', 'Germany', 'United Kingdom', 'J…

<function __main__.plot_by_country_with_confidence(selected_country)>