# Tutorial 3: Seasonality

We will explore the seasonality component of NeuralProphet. It allows the model to capture seasonal effects in the data, for example effects that recur at the same time every year or on the same day every week.

We start with the same code as in the previous tutorial on trends.

In [1]:
import pandas as pd
from neuralprophet import NeuralProphet, set_log_level

# Silence NeuralProphet's info/progress chatter; only errors get reported.
set_log_level("ERROR")

# Pull the tutorial energy dataset directly from GitHub into a DataFrame.
df = pd.read_csv(
    "https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv"
)

# A handful of epochs is enough for a demo and keeps training quick.
EPOCHS = 5

# Baseline model: a single straight trend line (no changepoints) and every
# seasonal component switched off, so later cells can show what each
# seasonality adds on top of it.
m = NeuralProphet(
    epochs=EPOCHS,
    n_changepoints=0,  # no trend changepoints
    yearly_seasonality=False,  # seasonality off ...
    weekly_seasonality=False,  # ... off ...
    daily_seasonality=False,  # ... and off
)
metrics = m.fit(df)
forecast = m.predict(df)
# Last expression of the cell, so the notebook renders the figure.
m.plot(forecast)

Finding best initial lr:   0%|          | 0/229 [00:00<?, ?it/s]

Training: 0it [00:00, ?it/s]

Predicting: 46it [00:00, ?it/s]

FigureWidgetResampler({
    'data': [{'fill': 'none',
              'line': {'color': 'rgba(45, 146, 255, 1.0)', 'width': 2},
              'mode': 'lines',
              'name': '<b style="color:sandybrown">[R]</b> yhat1 <i style="color:#fc9944">~1D</i>',
              'type': 'scatter',
              'uid': 'ca539807-26a0-455b-a4d7-1e82c53c2066',
              'x': array([datetime.datetime(2014, 12, 31, 0, 0),
                          datetime.datetime(2015, 1, 1, 0, 0),
                          datetime.datetime(2015, 1, 2, 0, 0), ...,
                          datetime.datetime(2018, 12, 28, 0, 0),
                          datetime.datetime(2018, 12, 29, 0, 0),
                          datetime.datetime(2018, 12, 31, 0, 0)], dtype=object),
              'y': array([72.06973267, 72.05545044, 72.04116821, ..., 51.24927521, 51.23500061,
                          51.20643616])},
             {'marker': {'color': 'black', 'size': 4},
              'mode': 'markers',
              'n

In [2]:
# Decompose the baseline forecast; with all seasonality disabled, the only component shown is the trend.
m.plot_components(forecast)

FigureWidgetResampler({
    'data': [{'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': '<b style="color:sandybrown">[R]</b> Trend <i style="color:#fc9944">~1D</i>',
              'showlegend': False,
              'type': 'scatter',
              'uid': '20a0341a-336c-4f24-8788-0aac3edbf90b',
              'x': array([datetime.datetime(2014, 12, 31, 0, 0),
                          datetime.datetime(2015, 1, 1, 0, 0),
                          datetime.datetime(2015, 1, 2, 0, 0), ...,
                          datetime.datetime(2018, 12, 28, 0, 0),
                          datetime.datetime(2018, 12, 29, 0, 0),
                          datetime.datetime(2018, 12, 31, 0, 0)], dtype=object),
              'xaxis': 'x',
              'y': array([72.06973267, 72.05545044, 72.04116821, ..., 51.24927521, 51.23500061,
                          51.20643616]),
              'yaxis': 'y'}],
    'layout': {'autosize': True,
               'font': {'

Let us re-enable the seasonality components step by step and see what effect each one has on the model.

In [3]:
# Same baseline as before (straight trend, no changepoints), but with the
# yearly seasonal component switched back on.
m = NeuralProphet(
    epochs=EPOCHS,
    n_changepoints=0,  # still no trend changepoints
    yearly_seasonality=True,  # re-enabled in this step
    weekly_seasonality=False,  # remains off
    daily_seasonality=False,  # remains off
)
metrics = m.fit(df)
forecast = m.predict(df)
# Final expression -> the notebook displays the forecast figure.
m.plot(forecast)

Finding best initial lr:   0%|          | 0/229 [00:00<?, ?it/s]

Training: 0it [00:00, ?it/s]

Predicting: 46it [00:00, ?it/s]

FigureWidgetResampler({
    'data': [{'fill': 'none',
              'line': {'color': 'rgba(45, 146, 255, 1.0)', 'width': 2},
              'mode': 'lines',
              'name': '<b style="color:sandybrown">[R]</b> yhat1 <i style="color:#fc9944">~1D</i>',
              'type': 'scatter',
              'uid': '1ecb85f8-0e17-4bea-b315-67efec8c3079',
              'x': array([datetime.datetime(2014, 12, 31, 0, 0),
                          datetime.datetime(2015, 1, 1, 0, 0),
                          datetime.datetime(2015, 1, 2, 0, 0), ...,
                          datetime.datetime(2018, 12, 28, 0, 0),
                          datetime.datetime(2018, 12, 30, 0, 0),
                          datetime.datetime(2018, 12, 31, 0, 0)], dtype=object),
              'y': array([55.95669556, 55.75299072, 55.57995605, ..., 66.74468231, 66.19564819,
                          65.96237946])},
             {'marker': {'color': 'black', 'size': 4},
              'mode': 'markers',
              'n

In [4]:
# Decompose the forecast again; the yearly seasonality should now appear alongside the trend.
m.plot_components(forecast)

FigureWidgetResampler({
    'data': [{'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': '<b style="color:sandybrown">[R]</b> Trend <i style="color:#fc9944">~1D</i>',
              'showlegend': False,
              'type': 'scatter',
              'uid': '196bc2d7-c982-4290-95e4-35dec3a875c3',
              'x': array([datetime.datetime(2014, 12, 31, 0, 0),
                          datetime.datetime(2015, 1, 1, 0, 0),
                          datetime.datetime(2015, 1, 2, 0, 0), ...,
                          datetime.datetime(2018, 12, 28, 0, 0),
                          datetime.datetime(2018, 12, 30, 0, 0),
                          datetime.datetime(2018, 12, 31, 0, 0)], dtype=object),
              'xaxis': 'x',
              'y': array([52.1873703 , 52.1942215 , 52.20107269, ..., 62.17250824, 62.18621063,
                          62.1930542 ]),
              'yaxis': 'y'},
             {'line': {'color': '#2d92ff', 'width': 2},
   

`plot_parameters` also allows you to specify which components to plot, so we will focus only on seasonality for now.

In [5]:
# Restrict the parameter plot to the seasonality components only.
m.plot_parameters(components=["seasonality"])

FigureWidgetResampler({
    'data': [{'fill': 'none',
              'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': 'yearly',
              'type': 'scatter',
              'uid': '8565ec10-08ae-41a3-938c-539ae9d4445f',
              'x': array([datetime.datetime(2017, 1, 1, 0, 0),
                          datetime.datetime(2017, 1, 2, 0, 0),
                          datetime.datetime(2017, 1, 3, 0, 0), ...,
                          datetime.datetime(2017, 12, 29, 0, 0),
                          datetime.datetime(2017, 12, 30, 0, 0),
                          datetime.datetime(2017, 12, 31, 0, 0)], dtype=object),
              'xaxis': 'x',
              'y': array([3.46493933, 3.30069512, 3.16808875, ..., 4.20806715, 3.94671509,
                          3.71386276]),
              'yaxis': 'y'}],
    'layout': {'autosize': True,
               'font': {'size': 10},
               'height': 210,
               'hovermode': 'x unified'

In [6]:
# Model and prediction
m = NeuralProphet(
    epochs=EPOCHS,
    # Disable trend changepoints
    n_changepoints=0,
    # Explicitly enable all seasonality components. Note: this is NOT the
    # same as NeuralProphet's defaults. Each of these parameters defaults to
    # "auto", which decides per component from the data. For example, daily
    # seasonality is skipped for daily-frequency data such as this dataset.
    yearly_seasonality=True,
    weekly_seasonality=True,
    daily_seasonality=True,
)
metrics = m.fit(df)
forecast = m.predict(df)
m.plot(forecast)

Finding best initial lr:   0%|          | 0/229 [00:00<?, ?it/s]

Training: 0it [00:00, ?it/s]

Predicting: 46it [00:00, ?it/s]

FigureWidgetResampler({
    'data': [{'fill': 'none',
              'line': {'color': 'rgba(45, 146, 255, 1.0)', 'width': 2},
              'mode': 'lines',
              'name': '<b style="color:sandybrown">[R]</b> yhat1 <i style="color:#fc9944">~1D</i>',
              'type': 'scatter',
              'uid': '03555f8e-0be0-4eac-aed8-aa0b51670a4d',
              'x': array([datetime.datetime(2014, 12, 31, 0, 0),
                          datetime.datetime(2015, 1, 1, 0, 0),
                          datetime.datetime(2015, 1, 2, 0, 0), ...,
                          datetime.datetime(2018, 12, 28, 0, 0),
                          datetime.datetime(2018, 12, 30, 0, 0),
                          datetime.datetime(2018, 12, 31, 0, 0)], dtype=object),
              'y': array([58.91551208, 59.04046631, 58.51623535, ..., 68.79338074, 60.60398865,
                          67.48673248])},
             {'marker': {'color': 'black', 'size': 4},
              'mode': 'markers',
              'n

In [7]:
# Decompose the forecast, showing only the seasonality components (trend omitted).
m.plot_components(forecast, components=["seasonality"])

FigureWidgetResampler({
    'data': [{'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': ('<b style="color:sandybrown">[R' ... ' style="color:#fc9944">~1D</i>'),
              'showlegend': False,
              'type': 'scatter',
              'uid': 'eb34e47f-54a3-4d5c-b34e-65ed9d2c055b',
              'x': array([datetime.datetime(2014, 12, 31, 0, 0),
                          datetime.datetime(2015, 1, 1, 0, 0),
                          datetime.datetime(2015, 1, 2, 0, 0), ...,
                          datetime.datetime(2018, 12, 28, 0, 0),
                          datetime.datetime(2018, 12, 29, 0, 0),
                          datetime.datetime(2018, 12, 31, 0, 0)], dtype=object),
              'xaxis': 'x',
              'y': array([3.36803365, 3.33818197, 3.32414889, ..., 3.5558598 , 3.47673321,
                          3.36803365]),
              'yaxis': 'y'},
             {'line': {'color': '#2d92ff', 'width': 2},
              

Daily seasonality does not make sense for this dataset: it contains only one observation per day, so there is no within-day pattern to learn. We see this clearly in the plot, where it only shows a flat line. For the weekly seasonality, it is best to zoom in or check out the plotted parameters below.

In [8]:
# Plot the learned seasonality parameters of the model with yearly, weekly and daily seasonality enabled.
m.plot_parameters(components=["seasonality"])

FigureWidgetResampler({
    'data': [{'fill': 'none',
              'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': 'yearly',
              'type': 'scatter',
              'uid': '1306f969-4777-4887-97d5-54fcab02bf86',
              'x': array([datetime.datetime(2017, 1, 1, 0, 0),
                          datetime.datetime(2017, 1, 2, 0, 0),
                          datetime.datetime(2017, 1, 3, 0, 0), ...,
                          datetime.datetime(2017, 12, 29, 0, 0),
                          datetime.datetime(2017, 12, 30, 0, 0),
                          datetime.datetime(2017, 12, 31, 0, 0)], dtype=object),
              'xaxis': 'x',
              'y': array([3.32922256, 3.32287637, 3.33134178, ..., 3.45953307, 3.40107235,
                          3.35906096]),
              'yaxis': 'y'},
             {'fill': 'none',
              'line': {'color': '#2d92ff', 'width': 2},
              'mode': 'lines',
              'name': '