# GluonTS - Simple model fit and evaluation

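This example shows how to fit a simple feed-forward model (`SimpleFeedForwardEstimator`) on the M4 hourly dataset and evaluate its probabilistic forecasts on the test split.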

- Based on `model_evaluate.py` from GluonTS examples

In [1]:
# imports 
import json

from gluonts.dataset.repository.datasets import get_dataset, dataset_recipes
from gluonts.evaluation import Evaluator
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.trainer import Trainer

if __name__ == "__main__":
    # Inside a notebook `__name__` is always "__main__", so this body always
    # runs; the guard is a leftover from the original script version.

    # GluonTS repository dataset to fit and evaluate on. It should be one of
    # the keys listed by `dataset_recipes.keys()` printed below.
    DATASET_NAME = "m4_hourly"

    print("Available datasets:")
    print(dataset_recipes.keys())

    # Load the dataset. With regenerate=False an already-processed copy on
    # disk is reused (see the "already processed" log line in the output).
    dataset = get_dataset(DATASET_NAME, regenerate=False)

    # Define the estimator. Forecast horizon and frequency are taken from the
    # dataset metadata so they always match the loaded data; the trainer budget
    # (10 epochs x 10 batches per epoch) is deliberately small for a demo.
    estimator = SimpleFeedForwardEstimator(
        prediction_length=dataset.metadata.prediction_length,
        freq=dataset.metadata.freq,
        trainer=Trainer(epochs=10, num_batches_per_epoch=10),
    )


INFO:root:Using CPU


Available datasets:
odict_keys(['constant', 'exchange_rate', 'solar-energy', 'electricity', 'traffic', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly'])


INFO:root:using dataset already processed in path C:\Users\TM\.mxnet\gluon-ts\datasets\m4_hourly.
INFO:root:Using CPU


In [2]:
%%time

# Fit the estimator on the training split; the returned predictor is what
# generates forecasts in the next cell.
predictor = estimator.train(dataset.train)

INFO:root:Start model training
INFO:root:Number of parameters in SimpleFeedForwardTrainingNetwork: 1963
INFO:root:Epoch[0] Learning rate is 0.001
100%|██████████| 10/10 [00:00<00:00, 58.86it/s, avg_epoch_loss=6.77]
INFO:root:Epoch[0] Elapsed time 0.176 seconds
INFO:root:Epoch[0] Evaluation metric 'epoch_loss'=6.765577
INFO:root:Epoch[1] Learning rate is 0.001
100%|██████████| 10/10 [00:00<00:00, 65.83it/s, avg_epoch_loss=6.43]
INFO:root:Epoch[1] Elapsed time 0.159 seconds
INFO:root:Epoch[1] Evaluation metric 'epoch_loss'=6.429865
INFO:root:Epoch[2] Learning rate is 0.001
100%|██████████| 10/10 [00:00<00:00, 66.71it/s, avg_epoch_loss=5.01]
INFO:root:Epoch[2] Elapsed time 0.157 seconds
INFO:root:Epoch[2] Evaluation metric 'epoch_loss'=5.011007
INFO:root:Epoch[3] Learning rate is 0.001
100%|██████████| 10/10 [00:00<00:00, 57.84it/s, avg_epoch_loss=4.29]
INFO:root:Epoch[3] Elapsed time 0.191 seconds
INFO:root:Epoch[3] Evaluation metric 'epoch_loss'=4.291459
INFO:root:Epoch[4] Learning rate

Wall time: 1.79 s


In [3]:
%%time

# Build paired iterators over the test split: forecasts (100 sample paths per
# series) and the corresponding ground-truth series.
# NOTE(review): this cell finishes in ~1 ms while evaluation takes seconds,
# which suggests the iterators are consumed lazily -- the real prediction cost
# is paid in the evaluation cell, not here.
forecast_it, ts_it = make_evaluation_predictions(
    dataset.test,
    predictor=predictor,
    num_eval_samples=100,
)


Wall time: 998 µs


In [4]:
%%time

# Evaluate forecasts against the ground truth, scoring only the median (0.5)
# quantile. Yields aggregate metrics over all series and a per-series table.
# NOTE: `ts_it` / `forecast_it` are presumably single-pass iterators, so
# re-run the prediction cell before re-running this one.
evaluator = Evaluator(quantiles=[0.5])
agg_metrics, item_metrics = evaluator(
    iter(ts_it),
    iter(forecast_it),
    num_series=len(dataset.test),
)

Running evaluation: 100%|██████████| 414/414 [00:03<00:00, 106.99it/s]


Wall time: 3.9 s


In [5]:
# Show the aggregate (all-series) forecast metrics as indented JSON.
# `json` is imported in the top imports cell.
print(json.dumps(agg_metrics, indent=4))

{
    "MSE": 6154841.6253044605,
    "abs_error": 8651116.674880981,
    "abs_target_sum": 145558863.59960938,
    "abs_target_mean": 7324.822041043147,
    "seasonal_error": 336.9046924038302,
    "MASE": 2.941287154279122,
    "sMAPE": 0.19852873327463827,
    "MSIS": 117.65148549229245,
    "QuantileLoss[0.5]": 8651116.619938374,
    "Coverage[0.5]": 0.43523550724637655,
    "RMSE": 2480.895327357537,
    "NRMSE": 0.3386970104469905,
    "ND": 0.05943380197497088,
    "wQuantileLoss[0.5]": 0.05943380159751117,
    "mean_wQuantileLoss": 0.05943380159751117,
    "MAE_Coverage": 0.06476449275362345
}


In [6]:
# Per-series metrics table (one row per evaluated time series); preview
# only the first rows rather than dumping the whole frame.
item_metrics.head(5)

Unnamed: 0,item_id,MSE,abs_error,abs_target_sum,abs_target_mean,seasonal_error,MASE,sMAPE,MSIS,QuantileLoss[0.5],Coverage[0.5]
0,,1530.751953,1509.054565,31644.0,659.25,42.371302,0.741979,0.046207,29.679178,1509.054565,0.708333
1,,119231.614583,15254.0,124149.0,2586.4375,165.107988,1.92475,0.119965,76.990016,15253.999146,1.0
2,,29918.789062,6537.868164,65030.0,1354.791667,78.889053,1.726546,0.09494,69.061846,6537.868225,0.208333
3,,183152.75,16456.787109,235783.0,4912.145833,258.982249,1.323835,0.068058,52.953396,16456.787109,0.458333
4,,88222.541667,10659.682617,131088.0,2731.0,200.494083,1.107647,0.076674,44.305888,10659.682251,0.645833
