In [1]:
import pandas as pd
import os
from pathlib import Path
import xgboost as xgb
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from utilsforecast.losses import mae
from mlforecast import MLForecast
#from src.data.feature_engineering import date_features, lags

In [2]:
# Locate the preprocessed train/test CSVs relative to the project root, which
# is taken to be the parent of the current working directory.
PROJECT_ROOT = Path.cwd().parent
PREPROCESSED_DIR = os.path.join(PROJECT_ROOT, 'data', 'preprocessed')
train_path, test_path = (
    os.path.join(PREPROCESSED_DIR, csv_name)
    for csv_name in ('train.csv', 'test.csv')
)

In [None]:
# Read the preprocessed train and test CSVs, parsing the `ds` column as
# datetimes, then report the shape of each frame.
print('>>Loading data...')
train = pd.read_csv(train_path, parse_dates=['ds'])
test = pd.read_csv(test_path, parse_dates=['ds'])
# Status message spelling fixed ("succefully" -> "successfully").
print(f'✅ data Loaded successfully: Train shape: {train.shape}, Test shape: {test.shape}')


>>Loading data...
✅ Train shape: (131470, 7), Test shape: (13896, 6)


In [17]:
# Candidate regressors, keyed by the short name that labels each model's
# prediction column in the cross-validation output.
# The tree-based models are seeded so a fresh-kernel re-run reproduces the
# same fits: scikit-learn permutes candidate features at every split, so an
# unseeded DecisionTreeRegressor can differ between runs when splits tie.
RANDOM_SEED = 42
models = {
    'lreg': LinearRegression(),
    'dt': DecisionTreeRegressor(random_state=RANDOM_SEED),
    'xgb': xgb.XGBRegressor(random_state=RANDOM_SEED),
}


In [23]:
# Forecasting pipeline over the candidate models: frequency alias 'h', lags 1
# and 24, plus day-of-week and hour date features.
ml = MLForecast(
    models=models,
    freq='h',
    lags=[1, 24],
    date_features=['dayofweek', 'hour'],
)

# Cross-validate on the first 2000 rows of the training frame only.
cv_input = train[:2000]
cv_settings = dict(
    h=20,
    n_windows=2,
    step_size=2,
    refit=True,
    static_features=[],
)

print('Training models using cross validation...')
cv_df = ml.cross_validation(df=cv_input, **cv_settings)

print('✅ Models trained successfully!')
print('✅ cross validation results', cv_df.head())

Training models using cross validation...
✅ Models trained successfully!
✅ cross validation results   unique_id                  ds              cutoff        y          lreg  \
0         A 2002-03-24 11:00:00 2002-03-24 10:00:00  27106.0  27075.545280   
1         A 2002-03-24 12:00:00 2002-03-24 10:00:00  26736.0  26591.228522   
2         A 2002-03-24 13:00:00 2002-03-24 10:00:00  26424.0  25947.203023   
3         A 2002-03-24 14:00:00 2002-03-24 10:00:00  25905.0  25480.186936   
4         A 2002-03-24 15:00:00 2002-03-24 10:00:00  25640.0  25423.347949   

        dt           xgb  
0  27361.0  26795.908203  
1  27403.0  27130.167969  
2  28156.0  27544.505859  
3  28700.0  27133.662109  
4  28434.0  26808.230469  
