In [1]:
import os
import pathlib
import torch

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly_resampler import unregister_plotly_resampler

from neuralprophet import NeuralProphet, set_random_seed, set_log_level

# Emit NeuralProphet log messages at INFO level (the "INFO - (NP....)" lines in the cell outputs).
set_log_level("INFO")

In [2]:
def create_metrics_plot(metrics: "pd.DataFrame") -> "go.Figure":
    """Plot per-epoch training metrics, one log-scaled subplot per metric.

    A subplot is created for every column of ``metrics`` except ``epoch``,
    ``RegLoss`` and validation columns (names ending in ``_val``). Each subplot
    shows the training curve and, when a matching ``<metric>_val`` column
    exists, the validation curve in the same legend group. ``RegLoss`` is
    overlaid on the ``Loss`` subplot when that column is present.

    Parameters
    ----------
    metrics : pd.DataFrame
        Frame with one row per epoch and string column names, such as the one
        returned by ``m.fit(...)`` in this notebook.

    Returns
    -------
    go.Figure
        A 1 x N subplot figure (N = number of plotted metrics).

    Raises
    ------
    ValueError
        If ``metrics`` has no column left to plot after the exclusions above.

    Notes
    -----
    Calling this function unregisters the plotly resampler as a side effect.
    """
    excluded = ("RegLoss", "epoch")
    # Suffix match (not substring) so that only real validation columns are skipped.
    metric_cols = [col for col in metrics.columns if not (col.endswith("_val") or col in excluded)]
    if not metric_cols:
        # Fail early with a clear message instead of an opaque make_subplots(cols=0) error.
        raise ValueError(
            "metrics contains no plottable metric columns "
            "(only 'epoch', 'RegLoss' and '*_val' columns were found)"
        )

    # Deactivate the resampler since it is not compatible with kaleido (image export)
    unregister_plotly_resampler()

    # Plotly params
    prediction_color = "#2d92ff"
    actual_color = "black"
    line_width = 2
    xaxis_args = {"showline": True, "mirror": True, "linewidth": 1.5, "showgrid": False}
    yaxis_args = {
        "showline": True,
        "mirror": True,
        "linewidth": 1.5,
        "showgrid": False,
        "rangemode": "tozero",
        "type": "log",
    }
    layout_args = {
        "autosize": True,
        "template": "plotly_white",
        "margin": go.layout.Margin(l=0, r=10, b=0, t=30, pad=0),
        "font": dict(size=10),
        "title": dict(font=dict(size=10)),
        "width": 1000,
        "height": 200,
    }

    fig = make_subplots(rows=1, cols=len(metric_cols), subplot_titles=metric_cols)

    def add_line(column_name, color, group, subplot_col, dash=None):
        """Add one line trace for ``metrics[column_name]`` to the given subplot column."""
        line_style = dict(color=color, width=line_width)
        if dash is not None:
            line_style["dash"] = dash
        fig.add_trace(
            go.Scatter(
                y=metrics[column_name],
                name=column_name,
                mode="lines",
                line=line_style,
                legendgroup=group,
            ),
            row=1,
            col=subplot_col,
        )

    for subplot_col, metric in enumerate(metric_cols, start=1):
        add_line(metric, prediction_color, metric, subplot_col)

        val_name = f"{metric}_val"
        if val_name in metrics.columns:
            add_line(val_name, actual_color, metric, subplot_col)

        # RegLoss shares the Loss panel; guard the lookup so frames without a
        # RegLoss column no longer raise KeyError, and dot the line so it can be
        # told apart from the (also black) validation curve.
        if metric == "Loss" and "RegLoss" in metrics.columns:
            add_line("RegLoss", actual_color, metric, subplot_col, dash="dot")

    fig.update_xaxes(xaxis_args)
    fig.update_yaxes(yaxis_args)
    fig.update_layout(layout_args)
    return fig

In [3]:
# Location of the local NeuralProphet checkout. "~" is expanded explicitly so the
# resulting paths also work with consumers that do not expand it themselves
# (e.g. the built-in open()).
DIR = os.path.expanduser("~/github/neural_prophet")
DATA_DIR = os.path.join(DIR, "tests", "test-data")

# Data files inside the checkout's test-data directory.
PEYTON_FILE = os.path.join(DATA_DIR, "wp_log_peyton_manning.csv")
AIR_FILE = os.path.join(DATA_DIR, "air_passengers.csv")
YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv")
ENERGY_PRICE_DAILY_FILE = os.path.join(DATA_DIR, "tutorial04_kaggle_energy_daily_temperature.csv")

In [4]:
# Load the energy data and expose the temperature column under the shorter name
# "temp". errors="raise" keeps the KeyError for a missing "temperature" column.
df = pd.read_csv(ENERGY_PRICE_DAILY_FILE).rename(columns={"temperature": "temp"}, errors="raise")
df["y"] = pd.to_numeric(df["y"], errors="coerce")

# The timestamps from the file are not used: the rows are placed on a synthetic
# hourly grid instead. (The former pd.to_datetime() parse of "ds" was discarded
# by the very next statement, so it has been removed.)
df["ds"] = pd.date_range(start="2015-01-01 00:00:00", periods=len(df), freq="h")
df["ID"] = "test"

# Second series for global modelling: a rescaled copy of the first one.
df_id = df[["ds", "y", "temp"]].copy()
df_id["ID"] = "test2"
df_id["y"] = df_id["y"] * 0.3
df_id["temp"] = df_id["temp"] * 0.4
df = pd.concat([df, df_id], ignore_index=True)

# Conditional Seasonality
# 0/1 flags used as seasonality conditions. "winter" is January only and
# "summer" is every other month, so the two flags are exact complements.
# NOTE(review): a one-month "winter" looks like a test setup rather than a
# meteorological definition - confirm this is intended.
month = df["ds"].dt.month
df["winter"] = (month == 1).astype(int)
df["summer"] = (month != 1).astype(int)

# Normalize Temperature
# Fixed offset/scale applied to both series (after the 0.4 rescaling of "test2").
df["temp"] = (df["temp"] - 65.0) / 50.0

# Keep only the columns the model needs, in a fixed order.
df = df[["ID", "ds", "y", "temp", "winter", "summer"]]

# Split
# Chronological split: everything before 2015-03-01 is training data.
split_date = "2015-03-01"
df_train = df[df["ds"] < split_date]
df_test = df[df["ds"] >= split_date]

In [5]:
# Reproducibility: training is stochastic, so fix the random seed before the
# model is built (set_random_seed is NeuralProphet's seeding helper imported above).
set_random_seed(0)

# Hyperparameter
tuned_params = {
    "n_lags": 10,
    "newer_samples_weight": 2.0,
    "n_changepoints": 0,
    "yearly_seasonality": 10,
    "weekly_seasonality": True,
    "daily_seasonality": False,  # due to conditional daily seasonality
    "batch_size": 32,
    "ar_layers": [8, 4],
    "lagged_reg_layers": [8],
    # not tuned
    "n_forecasts": 5,
    # No "learning_rate" is given on purpose; the recorded fit log shows the
    # learning rate finder being activated in that case.
    "epochs": 10,
    "trend_global_local": "global",
    "season_global_local": "global",
    "drop_missing": True,
    "normalize": "standardize",
}

# Uncertainty Quantification
# Symmetric interval around the median: for a 98% interval the lower and upper
# quantiles are 0.01 and 0.99. Set quantile_list to None to disable quantiles.
confidence_lv = 0.98
tail_probability = (1 - confidence_lv) / 2
quantile_list = [round(tail_probability, 2), round(confidence_lv + tail_probability, 2)]
print(f"quantiles: {quantile_list}")

# Check if GPU is available
# Training is pinned to the CPU here; use torch.cuda.is_available() instead of
# False to pick up a GPU automatically.
use_gpu = False

# Set trainer configuration
trainer_configs = {
    "accelerator": "gpu" if use_gpu else "cpu",
}
print(f"Using {'GPU' if use_gpu else 'CPU'}")

# Model
m = NeuralProphet(**tuned_params, **trainer_configs, quantiles=quantile_list)

# Lagged Regressor
m.add_lagged_regressor(names="temp", n_lags=33, normalize="standardize")

# Conditional Seasonality
# Two period-1 seasonalities, each conditioned on one of the 0/1 flag columns
# ("winter" / "summer") of the input frame.
m.add_seasonality(name="winter", period=1, fourier_order=6, condition_name="winter")
m.add_seasonality(name="summer", period=1, fourier_order=6, condition_name="summer")

# Holidays
m.add_country_holidays(country_name="US", lower_window=-1, upper_window=1)

quantiles: [0.01, 0.99]
Using CPU


<neuralprophet.forecaster.NeuralProphet at 0x7bc282101f30>

In [6]:
# Training & Predict
# Fit on the training split; the later split is passed as validation data so the
# returned metrics frame also carries the "*_val" columns.
fit_kwargs = {
    "df": df_train,
    "validation_df": df_test,
    "freq": "h",
    "early_stopping": False,
    # Optional learning-rate schedulers (disabled):
    # "scheduler": "onecyclelr",
    # "scheduler_args": {
    #     "pct_start": 0.3,
    #     "div_factor": 100.0,
    #     "final_div_factor": 1000.0,
    #     "anneal_strategy": "cos",
    #     "three_phase": False,
    # },
    # "scheduler": "exponentiallr",
    # "scheduler_args": {"gamma": 0.8},
}
metrics = m.fit(**fit_kwargs)

INFO - (NP.forecaster.fit) - When Global modeling with local normalization, metrics are displayed in normalized scale.
  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.929% of the data.
  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h
  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.929% of the data.
  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

  converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)

INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h
  contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, d

Training: |          | 0/? [00:00<?, ?it/s]

INFO - (NP.forecaster._train) - No Learning Rate provided. Activating learning rate finder
INFO - (NP.forecaster._train) - Learning rate finder ---- ARGs: {'min_lr': 1e-08, 'max_lr': 10.0, 'num_training': 168, 'early_stop_threshold': None, 'mode': 'exponential'}


Finding best initial lr:   0%|          | 0/168 [00:00<?, ?it/s]



INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- default suggestion: 0.011312834366320213
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- steepest: 0.011312834366320213
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- minimum: 0.8483428982440717
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- log-avg: 0.04770582696143933
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- returning: 0.04770582696143933
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- LR (start): [1e-08, 1.2798022139979537e-08, 1.447819046860873e-08, 1.637893706954064e-08, 1.8529220216409524e-08]
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- LR (end): [6.105402296585325, 6.906940492102072, 7.813707376518102, 8.839517733744358, 10.0]
INFO - (NP.utils.smooth_loss_and_suggest) - Learning rate finder ---- LOSS (start): [4.03609179 4.03609179 4.03609179 4.03609179 4.03609179]


Training: |          | 0/? [00:00<?, ?it/s]




Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [7]:
# Show the training/validation curves for every metric returned by fit().
create_metrics_plot(metrics)

In [8]:
# Metrics of the last epoch as a plain dict (last row of the metrics frame).
metrics.to_dict("records")[-1]

{'MAE_val': 0.688785970211029,
 'RMSE_val': 0.8063586950302124,
 'Loss_val': 0.5771589875221252,
 'RegLoss_val': 0.0,
 'epoch': 9,
 'MAE': 0.45611318945884705,
 'RMSE': 0.6116937398910522,
 'Loss': 0.26712149381637573,
 'RegLoss': 0.0,
 'LR': 0.0006131998379714787}

In [9]:
# Full per-epoch metrics table (one row per epoch).
metrics

Unnamed: 0,MAE_val,RMSE_val,Loss_val,RegLoss_val,epoch,MAE,RMSE,Loss,RegLoss,LR
0,0.710609,0.819329,0.622614,0.0,0,1.144053,1.562842,1.368656,0.0,0.012448
1,0.836989,0.946932,0.733583,0.0,1,0.532974,0.702971,0.343495,0.0,0.039781
2,0.588745,0.704277,0.49729,0.0,2,0.495191,0.658636,0.304316,0.0,0.040028
3,0.699847,0.818369,0.594933,0.0,3,0.475755,0.632438,0.283402,0.0,0.012695
4,0.70467,0.828111,0.594259,0.0,4,0.460465,0.615198,0.271323,0.0,0.004634
5,0.648891,0.755905,0.53024,0.0,5,0.458983,0.614499,0.270039,0.0,0.003871
6,0.715093,0.839661,0.608006,0.0,6,0.459262,0.614916,0.269727,0.0,0.002631
7,0.689927,0.807763,0.577745,0.0,7,0.455545,0.609835,0.266736,0.0,0.001389
8,0.646029,0.753095,0.530006,0.0,8,0.45711,0.61197,0.267934,0.0,0.000618
9,0.688786,0.806359,0.577159,0.0,9,0.456113,0.611694,0.267121,0.0,0.000613


In [10]:
# Predict over the complete frame, i.e. both IDs and both the training and test periods.
forecast = m.predict(df)


Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.


INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.

Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.



Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.


INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h

Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.


INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.

Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.



Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative t

Predicting: |          | 0/? [00:00<?, ?it/s]


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).




Predicting: |          | 0/? [00:00<?, ?it/s]

In [11]:
# Highlight the furthest forecast step (n_forecasts steps ahead), then plot the
# forecast of the series with ID "test".
m.highlight_nth_step_ahead_of_each_forecast(m.n_forecasts)
m.plot(forecast, df_name="test")

INFO - (NP.forecaster.plot) - Plotting data from ID test

The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.


FigureWidgetResampler({
    'data': [{'fillcolor': 'rgba(45, 146, 255, 0.2)',
              'line': {'color': 'rgba(45, 146, 255, 0.2)', 'width': 1},
              'mode': 'lines',
              'name': '<b style="color:sandybrown">[R]</b> yhat5 1.0% <i style="color:#fc9944">~1h</i>',
              'type': 'scatter',
              'uid': 'cceaf554-f88b-47ac-b077-bd98eebd51bd',
              'x': array([datetime.datetime(2015, 1, 2, 13, 0),
                          datetime.datetime(2015, 1, 2, 14, 0),
                          datetime.datetime(2015, 1, 2, 15, 0), ...,
                          datetime.datetime(2015, 3, 2, 17, 0),
                          datetime.datetime(2015, 3, 2, 19, 0),
                          datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),
              'y': array([21.843597, 25.104948, 33.001038, ..., 46.6899  , 41.747295, 48.700737],
                         dtype=float32)},
             {'fill': 'tonexty',
              'fillcolor': 'rgba(45, 146, 

In [12]:
# Plot the forecast components of the series with ID "test".
# NOTE(review): the recorded run of this cell ended with
# "IndexError: index -1 is out of bounds for axis 0 with size 0"; the cause is
# not visible in this notebook - investigate before relying on Run All.
m.plot_components(forecast, df_name="test")

INFO - (NP.forecaster.plot_components) - Plotting data from ID test

The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.



'H' is deprecated and will be removed in a future version, please use 'h' instead.



The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series conta

IndexError: index -1 is out of bounds for axis 0 with size 0

In [None]:
# Plot the fitted model parameters (the recorded output includes "Trend" and "yearly" traces).
m.plot_parameters()


The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result



'H' is deprecated and will be removed in a future version, please use 'h' instead.




FigureWidgetResampler({
    'data': [{'fill': 'none',
              'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': 'Trend',
              'type': 'scatter',
              'uid': '6d93394c-496d-4c52-8c79-e1a55e9bff0d',
              'x': array([datetime.datetime(2015, 1, 1, 0, 0),
                          datetime.datetime(2015, 2, 28, 23, 0)], dtype=object),
              'xaxis': 'x',
              'y': array([35.735615, 26.46712 ], dtype=float32),
              'yaxis': 'y'},
             {'fill': 'none',
              'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': 'yearly',
              'type': 'scatter',
              'uid': '91e577de-4754-46f9-a832-ff20851064d6',
              'x': array([datetime.datetime(2017, 1, 1, 0, 0),
                          datetime.datetime(2017, 1, 2, 0, 0),
                          datetime.datetime(2017, 1, 3, 0, 0), ...,
                          datet