In [1]:
import os

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly_resampler import unregister_plotly_resampler

from neuralprophet import NeuralProphet, set_log_level

# Set NeuralProphet's log level to INFO; the "INFO - (NP....)" lines in later cell outputs come from this.
set_log_level("INFO")

In [2]:
def create_metrics_plot(metrics):
    """Plot per-epoch training metrics as a single row of log-scaled subplots.

    One subplot is created for every column of ``metrics`` except ``epoch``,
    ``RegLoss`` and the ``*_val`` columns. In each subplot:
      * the training curve is drawn in blue;
      * a matching ``<metric>_val`` column, if present, is overlaid in black;
      * on the ``Loss`` subplot, ``RegLoss`` (if present) is overlaid as a
        dotted black line so it can be told apart from ``Loss_val``.

    Side effect: calls ``unregister_plotly_resampler()`` because the resampler
    is not compatible with kaleido (image export).

    Args:
        metrics: DataFrame of per-epoch metrics, e.g. the one returned by
            ``NeuralProphet.fit``. Traces are given only ``y`` values, so the
            x-axis is the row position (0, 1, 2, ...), not the DataFrame index.

    Returns:
        The assembled plotly ``Figure``.

    Raises:
        ValueError: if ``metrics`` has no column left to plot.
    """
    # Deactivate the resampler since it is not compatible with kaleido (image export)
    unregister_plotly_resampler()

    # Plotly params
    prediction_color = "#2d92ff"
    actual_color = "black"
    line_width = 2
    xaxis_args = {"showline": True, "mirror": True, "linewidth": 1.5, "showgrid": False}
    yaxis_args = {
        "showline": True,
        "mirror": True,
        "linewidth": 1.5,
        "showgrid": False,
        "rangemode": "tozero",
        "type": "log",
    }
    layout_args = {
        "autosize": True,
        "template": "plotly_white",
        "margin": go.layout.Margin(l=0, r=10, b=0, t=30, pad=0),
        "font": dict(size=10),
        "title": dict(font=dict(size=10)),
        "width": 1000,
        "height": 200,
    }

    excluded_cols = {"RegLoss", "epoch"}
    metric_cols = [col for col in metrics.columns if "_val" not in col and col not in excluded_cols]
    if not metric_cols:
        raise ValueError("create_metrics_plot: metrics has no columns to plot")

    fig = make_subplots(rows=1, cols=len(metric_cols), subplot_titles=metric_cols)

    def add_line(y, name, color, subplot_col, group, dash=None):
        # Add one line trace to subplot (1, subplot_col), grouped in the legend by metric.
        fig.add_trace(
            go.Scatter(
                y=y,
                name=name,
                mode="lines",
                line=dict(color=color, width=line_width, dash=dash),
                legendgroup=group,
            ),
            row=1,
            col=subplot_col,
        )

    for subplot_col, metric in enumerate(metric_cols, start=1):
        add_line(metrics[metric], metric, prediction_color, subplot_col, metric)

        val_col = f"{metric}_val"
        if val_col in metrics.columns:
            add_line(metrics[val_col], val_col, actual_color, subplot_col, metric)

        # RegLoss shares the Loss subplot with Loss_val (same colour), so draw it dotted.
        # Only plot it when the column exists. Zero values cannot be drawn on the log y-axis.
        if metric == "Loss" and "RegLoss" in metrics.columns:
            add_line(metrics["RegLoss"], "RegLoss", actual_color, subplot_col, metric, dash="dot")

    fig.update_xaxes(xaxis_args)
    fig.update_yaxes(yaxis_args)
    fig.update_layout(layout_args)
    return fig

In [3]:
# Root of a local neural_prophet checkout. It can be overridden via the NEURALPROPHET_DIR
# environment variable. "~" is expanded explicitly because os.path.join does not expand it,
# so without this the derived paths would contain a literal "~".
DIR = os.path.expanduser(os.environ.get("NEURALPROPHET_DIR", "~/github/neural_prophet"))
DATA_DIR = os.path.join(DIR, "tests", "test-data")
PEYTON_FILE = os.path.join(DATA_DIR, "wp_log_peyton_manning.csv")
AIR_FILE = os.path.join(DATA_DIR, "air_passengers.csv")
YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv")
ENERGY_PRICE_DAILY_FILE = os.path.join(DATA_DIR, "tutorial04_kaggle_energy_daily_temperature.csv")

In [4]:
# Load the energy-price file and rename the regressor column to "temp".
df = pd.read_csv(ENERGY_PRICE_DAILY_FILE)
df = df.rename(columns={"temperature": "temp"})
df["y"] = pd.to_numeric(df["y"], errors="coerce")

# The original timestamps are discarded and replaced by a synthetic hourly index
# starting 2015-01-01, so the series is modelled as hourly data (fit uses freq="h").
df["ds"] = pd.date_range(start="2015-01-01 00:00:00", periods=len(df), freq="h")
df["ID"] = "test"

# Add a second, scaled copy of the series under ID "test2" so the data set has two IDs.
df_id = df[["ds", "y", "temp"]].copy()
df_id["ID"] = "test2"
df_id["y"] = df_id["y"] * 0.3
df_id["temp"] = df_id["temp"] * 0.4
df = pd.concat([df, df_id], ignore_index=True)

# Conditional Seasonality flags (mutually exclusive, always sum to 1).
# NOTE: "winter" is January only and "summer" is every other month; these are test
# labels, not real seasons.
is_january = df["ds"].dt.month == 1
df["winter"] = is_january.astype("int64")
df["summer"] = (~is_january).astype("int64")

# Normalize Temperature
df["temp"] = (df["temp"] - 65.0) / 50.0

df = df[["ID", "ds", "y", "temp", "winter", "summer"]]

# Split into train / validation at a fixed date.
SPLIT_DATE = "2015-03-01"
df_train = df[df["ds"] < SPLIT_DATE]
df_test = df[df["ds"] >= SPLIT_DATE]
assert len(df_train) + len(df_test) == len(df), "train/test split lost rows"

In [5]:
# Hyperparameter
tuned_params = {
    "n_lags": 10,
    "newer_samples_weight": 2.0,
    "n_changepoints": 0,
    "yearly_seasonality": 10,
    "weekly_seasonality": True,
    "daily_seasonality": False,  # due to conditional daily seasonality
    "batch_size": 32,
    "ar_layers": [8, 4],
    "lagged_reg_layers": [8],
    # not tuned
    "n_forecasts": 5,
    # "learning_rate": 0.1,
    "epochs": 10,
    "trend_global_local": "global",
    "season_global_local": "global",
    "drop_missing": True,
    "normalize": "standardize",
}

# Uncertainty Quantification
def symmetric_quantiles(confidence_level, ndigits=10):
    """Return ``[lower, upper]`` quantiles of a central interval with the given coverage.

    Examples: 0.98 -> [0.01, 0.99], 0.95 -> [0.025, 0.975].

    Rounding to ``ndigits`` only strips floating-point noise such as
    0.010000000000000009. Rounding to 2 digits would distort common levels:
    0.95 would give a lower quantile of 0.03 instead of 0.025.

    Raises:
        ValueError: if ``confidence_level`` is not strictly between 0 and 1.
    """
    if not 0.0 < confidence_level < 1.0:
        raise ValueError(f"confidence_level must be in (0, 1), got {confidence_level}")
    lower = round((1.0 - confidence_level) / 2.0, ndigits)
    upper = round(1.0 - lower, ndigits)
    return [lower, upper]


confidence_lv = 0.98
quantile_list = symmetric_quantiles(confidence_lv)
# quantile_list = None
print(f"quantiles: {quantile_list}")

# Check if GPU is available
# use_gpu = torch.cuda.is_available()  # needs `import torch`
use_gpu = False

# Set trainer configuration
trainer_configs = {
    "accelerator": "gpu" if use_gpu else "cpu",
}
print(f"Using {'GPU' if use_gpu else 'CPU'}")

# Build the model from the hyperparameters, trainer settings and quantiles defined above.
m = NeuralProphet(**tuned_params, **trainer_configs, quantiles=quantile_list)

# Temperature enters as a lagged regressor with its own lag window of 33 steps.
m.add_lagged_regressor(names="temp", n_lags=33, normalize="standardize")

# Conditional daily (period=1) seasonalities, each conditioned on the indicator
# column of the same name ("winter" / "summer").
for season_flag in ("winter", "summer"):
    m.add_seasonality(name=season_flag, period=1, fourier_order=6, condition_name=season_flag)

# US country holidays with a -1 / +1 window around each holiday.
m.add_country_holidays(country_name="US", lower_window=-1, upper_window=1)

quantiles: [0.01, 0.99]
Using CPU


<neuralprophet.forecaster.NeuralProphet at 0x74f53b333610>

In [6]:
# Training: fit on df_train and report validation metrics on df_test every epoch.
# Early stopping is disabled, so all configured epochs run.
#
# Alternative learning-rate schedules that can be passed to fit():
#   scheduler="onecyclelr",
#   scheduler_args={"pct_start": 0.3, "div_factor": 100.0, "final_div_factor": 1000.0,
#                   "anneal_strategy": "cos", "three_phase": False},
# or
#   scheduler="exponentiallr", scheduler_args={"gamma": 0.8},
metrics = m.fit(
    freq="h",
    df=df_train,
    validation_df=df_test,
    early_stopping=False,
)

INFO - (NP.forecaster.fit) - When Global modeling with local normalization, metrics are displayed in normalized scale.
  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.929% of the data.
  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h
  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.929% of the data.
  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h
  contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, d

Training: |          | 0/? [00:00<?, ?it/s]

INFO - (NP.forecaster._train) - No Learning Rate provided. Activating learning rate finder
INFO - (NP.forecaster._train) - Learning rate finder ---- ARGs: {'min_lr': 1e-07, 'max_lr': 10.0, 'num_training': 127, 'early_stop_threshold': None, 'mode': 'exponential'}


Finding best initial lr:   0%|          | 0/127 [00:00<?, ?it/s]



INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- default suggestion: 0.012657915866672028
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- steepest: 0.009470610000772239
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- minimum (not used): 1e-07
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- log-avg: 0.010948889651276864
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- returning: 0.010948889651276864
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- LR (start): [1e-07, 1.336547050891115e-07, 1.5451703926941467e-07, 1.7863580192457368e-07, 2.0651929314796276e-07]
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- LR (end): [5.597981979123278, 6.4717781594068, 7.481966305138837, 8.649836012976683, 10.0]
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- LOSS (start): [1.4179426 1.4179426 1.4179426 1.4179426 1.4179426]
INFO -

Training: |          | 0/? [00:00<?, ?it/s]




Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [7]:
# Training vs. validation curves, one log-scaled subplot per metric.
metrics_fig = create_metrics_plot(metrics)
metrics_fig

In [8]:
# Metrics of the final epoch (last row of the metrics table) as a plain dict.
final_epoch_metrics = metrics.to_dict(orient="records")[-1]
final_epoch_metrics

{'MAE_val': 0.5108088850975037,
 'RMSE_val': 0.5891289114952087,
 'Loss_val': 0.420722633600235,
 'RegLoss_val': 0.0,
 'epoch': 9,
 'MAE': 0.468860924243927,
 'RMSE': 0.6348727941513062,
 'Loss': 0.2833350598812103,
 'RegLoss': 0.0,
 'LR': 0.00014073456986807287}

In [9]:
# Full per-epoch metrics table: training columns, their *_val validation counterparts, and LR.
metrics

Unnamed: 0,MAE_val,RMSE_val,Loss_val,RegLoss_val,epoch,MAE,RMSE,Loss,RegLoss,LR
0,0.51822,0.618144,0.463509,0.0,0,1.071776,1.369143,1.176725,0.0,0.002857
1,0.544104,0.619763,0.485463,0.0,1,0.551133,0.734627,0.364626,0.0,0.00913
2,0.479003,0.555191,0.394251,0.0,2,0.495316,0.666528,0.309311,0.0,0.009187
3,0.516385,0.592248,0.435629,0.0,3,0.481878,0.650643,0.295138,0.0,0.002914
4,0.49294,0.569359,0.405212,0.0,4,0.473522,0.639349,0.288111,0.0,0.001064
5,0.509749,0.587457,0.422531,0.0,5,0.471753,0.637636,0.28606,0.0,0.000888
6,0.51294,0.592517,0.425102,0.0,6,0.470406,0.636353,0.284498,0.0,0.000604
7,0.507922,0.586603,0.418404,0.0,7,0.470709,0.635788,0.283936,0.0,0.000319
8,0.509402,0.588127,0.420135,0.0,8,0.469353,0.635285,0.282983,0.0,0.000142
9,0.510809,0.589129,0.420723,0.0,9,0.468861,0.634873,0.283335,0.0,0.000141


In [10]:
# Predict over the full data set: both IDs ("test", "test2"), covering the train and test periods.
forecast = m.predict(df)


Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.




INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.

Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.



Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.


INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h

Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.


INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.

Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.



Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.


INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h

Series.view

Predicting: |          | 0/? [00:00<?, ?it/s]


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).




Predicting: |          | 0/? [00:00<?, ?it/s]

In [11]:
# Highlight the last forecast step (step m.n_forecasts, i.e. "yhat5" in the plot) and
# plot the forecast for the series with ID "test".
m.highlight_nth_step_ahead_of_each_forecast(m.n_forecasts)
m.plot(forecast, df_name="test")

INFO - (NP.forecaster.plot) - Plotting data from ID test

The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.


FigureWidgetResampler({
    'data': [{'fillcolor': 'rgba(45, 146, 255, 0.2)',
              'line': {'color': 'rgba(45, 146, 255, 0.2)', 'width': 1},
              'mode': 'lines',
              'name': '<b style="color:sandybrown">[R]</b> yhat5 1.0% <i style="color:#fc9944">~1h</i>',
              'type': 'scatter',
              'uid': '1e485c1d-dae9-439f-97e8-f8960bf19265',
              'x': array([datetime.datetime(2015, 1, 2, 13, 0),
                          datetime.datetime(2015, 1, 2, 14, 0),
                          datetime.datetime(2015, 1, 2, 15, 0), ...,
                          datetime.datetime(2015, 3, 2, 17, 0),
                          datetime.datetime(2015, 3, 2, 18, 0),
                          datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),
              'y': array([-8.401451, -8.331238, -7.641697, ..., 35.0834  , 31.378742, 26.125694],
                         dtype=float32)},
             {'fill': 'tonexty',
              'fillcolor': 'rgba(45, 146, 

In [12]:
# Plot the individual forecast components for the series with ID "test".
# FIXME: in the recorded run this cell raised
#   IndexError: index -1 is out of bounds for axis 0 with size 0
# The cause has not been identified yet; the notebook does not survive Restart & Run All until it is fixed.
m.plot_components(forecast, df_name="test")

INFO - (NP.forecaster.plot_components) - Plotting data from ID test

The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series conta

IndexError: index -1 is out of bounds for axis 0 with size 0

In [None]:
# Plot the model's learned parameters (e.g. trend, yearly seasonality).
# NOTE(review): this cell is recorded as "In [None]"; re-run the notebook in order to confirm its output.
m.plot_parameters()


The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



'H' is deprecated and will be removed in a future version, please use 'h' instead.




FigureWidgetResampler({
    'data': [{'fill': 'none',
              'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': 'Trend',
              'type': 'scatter',
              'uid': 'a87ad6c0-8302-4d80-a053-3de55909d7d9',
              'x': array([datetime.datetime(2015, 1, 1, 0, 0),
                          datetime.datetime(2015, 2, 28, 23, 0)], dtype=object),
              'xaxis': 'x',
              'y': array([44.171093, 46.755657], dtype=float32),
              'yaxis': 'y'},
             {'fill': 'none',
              'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': 'yearly',
              'type': 'scatter',
              'uid': '687528c8-278a-4dba-a432-7580315842b3',
              'x': array([datetime.datetime(2017, 1, 1, 0, 0),
                          datetime.datetime(2017, 1, 2, 0, 0),
                          datetime.datetime(2017, 1, 3, 0, 0), ...,
                          datet