[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ourownstory/neural_prophet/blob/main/tutorials/feature-use/sparse_autoregression_yosemite_temps.ipynb)

# Sparse Autoregression
Here we fit NeuralProphet to data with 5-minute resolution (temperatures at Yosemite).
This is a continuation of the tutorial notebook `Tutorial 4: Auto regression`. While tutorial 4 covers autoregression in general, this notebook focuses on sparsity.

In [1]:
if "google.colab" in str(get_ipython()):
    !pip install git+https://github.com/ourownstory/neural_prophet.git # may take a while
    #!pip install neuralprophet # much faster, but may not have the latest upgrades/bugfixes

import pandas as pd
from neuralprophet import NeuralProphet, set_log_level

set_log_level("ERROR")

In [4]:
data_location = "https://raw.githubusercontent.com/ourownstory/neuralprophet-data/main/datasets/"
df = pd.read_csv(data_location + "yosemite_temps.csv")
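
Before fitting, it can help to confirm the sampling interval, since `freq="5min"` is passed to `fit` later on. The sketch below uses a small synthetic frame as a stand-in for the downloaded data: the values are made up, and only the `ds`/`y` column layout that NeuralProphet expects is assumed to match. The helper name `most_common_step` is hypothetical.

```python
import pandas as pd

# Synthetic stand-in for the real data: 5-minute timestamps in `ds`, values in `y`.
demo = pd.DataFrame(
    {
        "ds": pd.date_range("2017-05-01", periods=12, freq="5min"),
        "y": [27.8, 27.0, 26.8, 26.5, 25.6, 25.0, 24.5, 24.0, 23.8, 23.5, 23.2, 23.0],
    }
)


def most_common_step(frame: pd.DataFrame) -> pd.Timedelta:
    """Return the most frequent gap between consecutive timestamps."""
    steps = pd.to_datetime(frame["ds"]).diff().dropna()
    return steps.mode().iloc[0]


step = most_common_step(demo)
print(step)
```

Running the same helper on `df` should report a 5-minute step if the data matches the stated resolution.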

## Sparsifying the AR coefficients
The autoregression component of NeuralProphet is defined as an AR-Net ([paper](https://arxiv.org/abs/1911.12436), [github](https://github.com/ourownstory/AR-Net)).
Thus, we can set `ar_reg` to a number greater than zero if we want to induce sparsity in the AR coefficients.

However, a model with multiple components and regularization terms can be harder to fit, and in some cases you may need to take manual control of the training hyperparameters.
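
To build intuition for why a regularization penalty drives AR coefficients to zero, here is a minimal sketch on synthetic data. It is not NeuralProphet's training procedure: it swaps in a plain L1 (lasso) penalty solved with proximal gradient descent (ISTA), whereas AR-Net uses its own regularization function and optimizer. The `reg` values are therefore not comparable to `ar_reg`, and all names below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulate an AR process in which only lags 1 and 3 matter.
n, p = 2000, 10
true_w = np.zeros(p)
true_w[0], true_w[2] = 0.5, -0.3
series = np.zeros(n)
noise = rng.normal(scale=0.1, size=n)
for t in range(p, n):
    past = series[t - p:t][::-1]  # past[0] is lag 1, past[1] is lag 2, ...
    series[t] = past @ true_w + noise[t]

# Lagged design matrix: each row holds the p values preceding the target.
X = np.stack([series[t - p:t][::-1] for t in range(p, n)])
y = series[p:]


def fit_l1_ar(X, y, reg, steps=3000):
    """Minimize 0.5 * mean squared error + reg * ||w||_1 with ISTA."""
    w = np.zeros(X.shape[1])
    # Step size from the largest curvature of the squared-error term.
    lr = 1.0 / np.linalg.eigvalsh(X.T @ X / len(y)).max()
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)
        w = w - lr * grad
        # Soft-thresholding sets small weights exactly to zero.
        w = np.sign(w) * np.maximum(np.abs(w) - lr * reg, 0.0)
    return w


# Number of non-zero weights for increasing penalty strength.
counts = {reg: int(np.count_nonzero(fit_l1_ar(X, y, reg))) for reg in (0.0, 0.001, 0.1)}
print(counts)
```

The count of non-zero weights should shrink as `reg` grows, which is the same qualitative effect the following sections show for `ar_reg`.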


We start by setting `ar_reg` to 0.1.
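
The `6 * 12` and `3 * 12` in the model configuration count 5-minute steps: with 12 steps per hour, the model sees 6 hours of history and forecasts 3 hours ahead. A short check of that arithmetic (the variable names are only for illustration):

```python
import pandas as pd

step = pd.Timedelta(minutes=5)  # data resolution stated at the top of the notebook
steps_per_hour = pd.Timedelta(hours=1) // step

n_lags = 6 * steps_per_hour  # same value as 6 * 12
n_forecasts = 3 * steps_per_hour  # same value as 3 * 12

print(n_lags, n_lags * step)
print(n_forecasts, n_forecasts * step)
```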

In [5]:
m = NeuralProphet(
    n_lags=6 * 12,
    n_forecasts=3 * 12,
    n_changepoints=0,
    weekly_seasonality=False,
    daily_seasonality=False,
    learning_rate=0.01,
    ar_reg=0.1,
)
metrics = m.fit(df, freq="5min")  # validate_each_epoch=True, plot_live_loss=True


In [6]:
m.plot_parameters()

[Figure: parameter plot showing the Trend line and a bar chart of AR weights by lag]

In [7]:
m = m.highlight_nth_step_ahead_of_each_forecast(1)
m.plot_parameters()

[Figure: parameter plot for forecast step 1, showing the Trend line and a bar chart of AR weights by lag]

In [8]:
m = m.highlight_nth_step_ahead_of_each_forecast(36)
m.plot_parameters()

[Figure: parameter plot for forecast step 36, showing the Trend line and a bar chart of AR weights by lag]
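
The calls to `highlight_nth_step_ahead_of_each_forecast(n)` above put the n-th forecast step in focus for the parameter plot. As a conceptual sketch only: if the AR part is viewed as a linear map from the 72 lags to the 36 forecast steps, its weights form a 36 by 72 matrix, and highlighting step n corresponds to reading row n - 1. The matrix below is random filler, the layout is an assumption made for illustration, and nothing here reads from a fitted NeuralProphet model.

```python
import numpy as np

n_lags, n_forecasts = 72, 36
rng = np.random.default_rng(1)

# Hypothetical weight matrix: one row per forecast step, one column per lag.
weights = rng.normal(size=(n_forecasts, n_lags))


def weights_for_step(weights: np.ndarray, nth_step: int) -> np.ndarray:
    """Return the lag weights for the nth forecast step (1-based)."""
    if not 1 <= nth_step <= weights.shape[0]:
        raise ValueError("nth_step must be between 1 and n_forecasts")
    return weights[nth_step - 1]


first = weights_for_step(weights, 1)  # analogous to highlighting step 1
last = weights_for_step(weights, 36)  # analogous to highlighting step 36
print(first.shape, last.shape)
```

Under this view, the first and last forecast steps can rely on quite different lags, which is why the notebook inspects both.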

## Further reducing the non-zero AR coefficients
By setting `ar_reg` higher, we can further reduce the number of non-zero weights.
Here we set it to 1.
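
To compare runs numerically rather than by eye, one option is to count the weights whose magnitude exceeds a small tolerance, since weights trained with a penalty often end up close to zero rather than exactly zero. The helper `count_active`, the tolerance, and the example vectors below are all hypothetical. How to pull the weights out of a fitted model depends on the NeuralProphet version and is not shown.

```python
import numpy as np


def count_active(weights, tol=1e-3):
    """Count entries whose absolute value is larger than tol."""
    return int(np.sum(np.abs(np.asarray(weights, dtype=float)) > tol))


# Made-up weight vectors standing in for a weakly and a strongly regularized fit.
weak = [0.40, -0.12, 0.03, 0.0004, -0.02, 0.0]
strong = [0.55, 0.0, 0.0002, 0.0, -0.0001, 0.0]

print(count_active(weak), count_active(strong))
```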

In [9]:
m = NeuralProphet(
    n_lags=6 * 12,
    n_forecasts=3 * 12,
    n_changepoints=0,
    daily_seasonality=False,
    weekly_seasonality=False,
    learning_rate=0.01,
    ar_reg=1,
)
metrics = m.fit(df, freq="5min")


In [10]:
m.plot_parameters()

[Figure: parameter plot showing the Trend line and a bar chart of AR weights by lag]

In [11]:
m = m.highlight_nth_step_ahead_of_each_forecast(1)
m.plot_parameters()

[Figure: parameter plot for forecast step 1, showing the Trend line and a bar chart of AR weights by lag]

In [12]:
m = m.highlight_nth_step_ahead_of_each_forecast(36)
m.plot_parameters()

[Figure: parameter plot for forecast step 36, showing the Trend line and a bar chart of AR weights by lag]

## Extreme sparsity
The higher we set `ar_reg`, the fewer non-zero weights are fitted by the model. Here we set it to 10, which should lead to a single non-zero lag.

Note: Extreme values can lead to training instability.
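
One way to spot instability is to inspect the `metrics` frame returned by `fit` for non-finite values or a loss that ends higher than it started. The sketch below assumes a per-epoch column named `Loss` (column names may differ between versions) and uses made-up numbers in place of real training output. The helper `looks_unstable` is hypothetical.

```python
import numpy as np
import pandas as pd


def looks_unstable(metrics: pd.DataFrame, column: str = "Loss") -> bool:
    """Flag a run whose loss contains NaN/inf or ends above its starting value."""
    loss = metrics[column].to_numpy(dtype=float)
    if not np.all(np.isfinite(loss)):
        return True
    return bool(loss[-1] > loss[0])


# Made-up per-epoch losses: one converging run and one diverging run.
stable = pd.DataFrame({"Loss": [0.90, 0.40, 0.22, 0.18]})
diverging = pd.DataFrame({"Loss": [0.90, 1.70, float("nan"), float("nan")]})

print(looks_unstable(stable), looks_unstable(diverging))
```

If a run is flagged, lowering `learning_rate` or `ar_reg` is a reasonable first thing to try.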

In [13]:
m = NeuralProphet(
    n_lags=6 * 12,
    n_forecasts=3 * 12,
    n_changepoints=0,
    daily_seasonality=False,
    weekly_seasonality=False,
    learning_rate=0.01,
    ar_reg=10,
)
metrics = m.fit(df, freq="5min")


In [14]:
m.plot_parameters()

[Figure: parameter plot showing the Trend line and a bar chart of AR weights by lag]

In [15]:
m = m.highlight_nth_step_ahead_of_each_forecast(1)
m.plot_parameters()

[Figure: parameter plot for forecast step 1, showing the Trend line and a bar chart of AR weights by lag]

In [16]:
m = m.highlight_nth_step_ahead_of_each_forecast(36)
m.plot_parameters()

[Figure: parameter plot for forecast step 36, showing the Trend line and a bar chart of AR weights by lag]