In [None]:
# demo_forecasting_comparison.ipynb

import matplotlib.pyplot as plt
import pandas as pd
from data.data_generator import DataGenerator
from data.google_cluster_data import GoogleClusterData
from clustering.cluster_model import ClusterModel
from forecasting.regression_model import RegressionModel
from forecasting.arima_model import ARIMAModel
from forecasting.prophet_model import ProphetModel
from forecasting.lstm_model import LSTMModel
from forecasting.xgboost_model import XGBoostModel
from forecasting.ensemble_model import EnsembleModel
from utils.utils import calculate_rmse
import warnings

# Silence every warning category for the remainder of the session so the demo
# output shows only the training log and RMSE lines.
# NOTE(review): this also hides any warnings the forecasting libraries emit.
warnings.simplefilter('ignore')

# Load synthetic or Google data.
# `source` picks the dataset: "synthetic" generates 1000 rows with
# DataGenerator; "google" loads data/google_cluster_sample.csv.
source = "synthetic"  # or "google"
if source == "synthetic":
    data = DataGenerator(n_samples=1000).generate()
elif source == "google":
    data = GoogleClusterData("data/google_cluster_sample.csv").load_data()
else:
    # Reject unknown values explicitly so a typo cannot silently load the
    # Google CSV instead of the intended dataset.
    raise ValueError(
        f"Unknown data source {source!r}; expected 'synthetic' or 'google'"
    )

# Every model below is trained to predict 'demand'; fail fast with a clear
# message rather than letting each model fail separately later on.
if 'demand' not in data.columns:
    raise KeyError("Loaded data has no 'demand' target column")

# All remaining columns are used as model inputs.
features = [col for col in data.columns if col != 'demand']

# Cluster data
# Group the rows into 4 clusters over the feature columns. fit_predict returns
# a pair; its first element is stored per row in the new 'cluster' column.
# NOTE(review): `features` was built before this column existed, so 'cluster'
# is not in `features` — confirm whether the models are meant to use it.
cluster_model = ClusterModel(n_clusters=4)
cluster_labels, _ = cluster_model.fit_predict(data, features)
data['cluster'] = cluster_labels

# Initialize models
# Display name -> untrained forecaster. Insertion order determines the order in
# which models are trained and plotted. The ensemble receives its own fresh
# Regression, XGBoost and ARIMA instances, separate from the standalone entries.
models = dict([
    ("Regression", RegressionModel()),
    ("ARIMA", ARIMAModel()),
    ("Prophet", ProphetModel()),
    ("LSTM", LSTMModel(epochs=5)),  # fewer epochs for demo
    ("XGBoost", XGBoostModel()),
    ("Ensemble", EnsembleModel([
        RegressionModel(),
        XGBoostModel(),
        ARIMAModel(),
    ])),
])

# Fit every model, predict on the same rows, and record its RMSE.
# NOTE(review): each model is scored on the data it was fit on, so these RMSEs
# measure in-sample fit rather than out-of-sample forecast accuracy.
results = {}   # name -> {"predictions": ..., "rmse": ...} for scored models
failures = {}  # name -> exception raised while fitting, predicting or scoring

for name, model in models.items():
    print(f"Training {name}...")
    try:
        model.fit(data, features, 'demand')
        preds = model.predict(data, features)
        rmse = calculate_rmse(data['demand'], preds)
        results[name] = {
            "predictions": preds,
            "rmse": rmse
        }
        print(f"{name} RMSE: {rmse:.4f}")
    except Exception as e:
        # Deliberate demo boundary: one failing model must not stop the
        # comparison. Report the exception type too, because str(e) alone can
        # be empty or ambiguous (a KeyError prints only the missing key).
        failures[name] = e
        # Drop a partially recorded entry so only fully scored models are plotted.
        results.pop(name, None)
        print(f"Error with {name}: {type(e).__name__}: {e}")

if failures:
    print(f"Skipped {len(failures)} failed model(s): {', '.join(failures)}")

# Plot actual vs predictions
import numpy as np  # used to put every series on the same positional x-axis

plt.figure(figsize=(14, 8))
# The actual curve is plotted by row position (0..n-1) because `.values` drops
# the DataFrame index. Predictions are converted the same way below. Otherwise
# a pandas Series carrying its own index (datetime, non-contiguous, ...) would
# be drawn against that index and not line up with the actual curve.
plt.plot(data['demand'].values, label='Actual', linewidth=2)

for name, res in results.items():
    # NOTE(review): assumes each model returns one prediction per row of
    # `data`, aligned from the first row; shorter outputs start at x=0.
    plt.plot(
        np.asarray(res["predictions"]),
        label=f'{name} (RMSE: {res["rmse"]:.2f})',
        alpha=0.7,
    )

plt.legend()
plt.title("Demand Forecasting: Actual vs Predictions")
plt.xlabel("Time")
plt.ylabel("Demand")
plt.show()
