[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ourownstory/neural_prophet/blob/main/tutorials/feature-use/sparse_autoregression_yosemite_temps.ipynb)

# Sparse Autoregression
Here we fit NeuralProphet to Yosemite temperature data recorded at a 5-minute resolution.
This is a continuation of the tutorial notebook `Tutorial 4: Auto regression`. While Tutorial 4 covers autoregression in general, this notebook focuses on sparsity.

In [1]:
if "google.colab" in str(get_ipython()):
    !pip install git+https://github.com/ourownstory/neural_prophet.git # may take a while
    #!pip install neuralprophet # much faster, but may not have the latest upgrades/bugfixes

import pandas as pd
from neuralprophet import NeuralProphet, set_log_level

set_log_level("ERROR")

In [2]:
data_location = "https://raw.githubusercontent.com/ourownstory/neuralprophet-data/main/datasets/"
df = pd.read_csv(data_location + "yosemite_temps.csv")

## Sparsifying the AR coefficients
The autoregression component of NeuralProphet is defined as an AR-Net ([paper](https://arxiv.org/abs/1911.12436), [github](https://github.com/ourownstory/AR-Net)).
Thus, if we would like to induce sparsity in the AR coefficients, we can set `ar_reg` to a number greater than zero.
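Why does a penalty produce weights that are exactly zero rather than merely small? A minimal illustration, using a plain L1 (lasso) penalty as a stand-in: the regularizer behind `ar_reg` is not necessarily L1 and is not reproduced here. Under L1, one proximal-gradient step shrinks every weight toward zero by a threshold and sets any weight whose magnitude is at or below the threshold to exactly zero. The weights below are made up.

```python
import numpy as np


def soft_threshold(w: np.ndarray, thresh: float) -> np.ndarray:
    """Proximal operator of thresh * ||w||_1: shrink toward zero, clip at zero.

    Stand-in for illustration only; not NeuralProphet's ar_reg regularizer.
    """
    return np.sign(w) * np.maximum(np.abs(w) - thresh, 0.0)


# Made-up AR weights, one per lag (lag 1 first).
weights = np.array([0.9, -0.4, 0.25, 0.1, -0.05, 0.02])

for thresh in [0.0, 0.1, 0.3, 1.0]:
    shrunk = soft_threshold(weights, thresh)
    print(f"threshold={thresh}: {np.count_nonzero(shrunk)} non-zero weights")
```

This prints 6, 3, 2 and 0 non-zero weights for the four thresholds: the larger the threshold, the more of the small lags are removed entirely.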

However, a model with multiple components and regularizations can be harder to fit, and in some cases you may need to take manual control over the training hyperparameters.


We start by setting `ar_reg` to 0.1.
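The model below uses `n_lags=6 * 12` and `n_forecasts=3 * 12`. At a 5-minute frequency there are 12 steps per hour, so the model looks back 6 hours and forecasts 3 hours ahead. The sketch below checks that arithmetic with pandas and, assuming the usual sliding-window setup for multi-step autoregression (an illustration, not NeuralProphet's internal data pipeline), shows the shapes of the input and target windows. `make_windows` is a hypothetical helper and the toy series is made up.

```python
import numpy as np
import pandas as pd
from numpy.lib.stride_tricks import sliding_window_view

freq = pd.Timedelta("5min")
steps_per_hour = pd.Timedelta("1h") // freq  # 12 five-minute steps per hour

n_lags = 6 * steps_per_hour       # 72 past values used as AR inputs
n_forecasts = 3 * steps_per_hour  # 36 future values predicted at once

print("lookback:", n_lags * freq)      # 0 days 06:00:00
print("horizon:", n_forecasts * freq)  # 0 days 03:00:00


def make_windows(y: np.ndarray, n_lags: int, n_forecasts: int):
    """Hypothetical helper: split a series into (inputs, targets) windows.

    Each row of X holds n_lags consecutive values; the matching row of Y
    holds the n_forecasts values that immediately follow them.
    """
    windows = sliding_window_view(y, n_lags + n_forecasts)
    return windows[:, :n_lags], windows[:, n_lags:]


y = np.arange(200, dtype=float)  # toy series standing in for temperatures
X, Y = make_windows(y, n_lags, n_forecasts)
print(X.shape, Y.shape)  # (93, 72) (93, 36)
```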

In [3]:
m = NeuralProphet(
    n_lags=6 * 12,
    n_forecasts=3 * 12,
    n_changepoints=0,
    weekly_seasonality=False,
    daily_seasonality=False,
    learning_rate=0.01,
    ar_reg=0.1,
)
metrics = m.fit(df, freq="5min")  # validate_each_epoch=True, plot_live_loss=True


In [4]:
m.plot_parameters()

[Figure: model parameters, with a "Trend" panel and an "AR" panel of weights by lag]

In [5]:
m = m.highlight_nth_step_ahead_of_each_forecast(1)
m.plot_parameters()

[Figure: model parameters, with a "Trend" panel and an "AR" panel of weights by lag]

In [6]:
m = m.highlight_nth_step_ahead_of_each_forecast(36)
m.plot_parameters()

[Figure: model parameters, with a "Trend" panel and an "AR" panel of weights by lag]

## Further reducing the non-zero AR coefficients
By setting `ar_reg` higher, we can further reduce the number of non-zero weights.
Here we set it to 1.
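To see the effect of a stronger penalty on a full fit, here is a minimal sketch that fits a linear AR model with an L1 penalty by proximal gradient descent (ISTA) on a simulated AR(1) series. It swaps in the lasso for NeuralProphet's own regularizer and training loop; the series, the `make_lagged` and `fit_l1_ar` helpers, and the penalty values are all hypothetical.

```python
import numpy as np


def make_lagged(y, n_lags):
    """Design matrix with column k holding lag k+1, plus the aligned targets."""
    X = np.column_stack(
        [y[n_lags - k - 1 : len(y) - k - 1] for k in range(n_lags)]
    )
    return X, y[n_lags:]


def fit_l1_ar(y, n_lags, lam, n_iter=2000):
    """Minimize (1/2n)||Xw - t||^2 + lam * ||w||_1 with ISTA (lasso stand-in)."""
    X, t = make_lagged(y, n_lags)
    n = len(t)
    # Step size 1/L, where L = sigma_max(X)^2 / n is the gradient's Lipschitz constant.
    step = n / np.linalg.norm(X, 2) ** 2
    w = np.zeros(n_lags)
    for _ in range(n_iter):
        grad = X.T @ (X @ w - t) / n
        z = w - step * grad
        w = np.sign(z) * np.maximum(np.abs(z) - step * lam, 0.0)  # soft-threshold
    return w


# Simulated AR(1) series: only lag 1 truly matters.
rng = np.random.default_rng(0)
y = np.zeros(2000)
for i in range(1, len(y)):
    y[i] = 0.8 * y[i - 1] + rng.normal()

for lam in [0.0, 0.1, 1.0]:
    w = fit_l1_ar(y, n_lags=10, lam=lam)
    print(f"lam={lam}: {np.count_nonzero(w)} non-zero of 10, lag-1 weight {w[0]:.2f}")
```

With `lam=0.0` this reduces to ordinary least squares, so every lag keeps some weight; as `lam` grows, the lasso sets more lags exactly to zero and shrinks the surviving weights. The exact counts depend on the simulated noise.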

In [7]:
m = NeuralProphet(
    n_lags=6 * 12,
    n_forecasts=3 * 12,
    n_changepoints=0,
    daily_seasonality=False,
    weekly_seasonality=False,
    learning_rate=0.01,
    ar_reg=1,
)
metrics = m.fit(df, freq="5min")


In [8]:
m.plot_parameters()

[Figure: model parameters, with a "Trend" panel and an "AR" panel of weights by lag]

In [9]:
m = m.highlight_nth_step_ahead_of_each_forecast(1)
m.plot_parameters()

[Figure: model parameters, with a "Trend" panel and an "AR" panel of weights by lag]

In [10]:
m = m.highlight_nth_step_ahead_of_each_forecast(36)
m.plot_parameters()

[Figure: model parameters, with a "Trend" panel and an "AR" panel of weights by lag]

## Extreme sparsity
The higher we set `ar_reg`, the fewer non-zero weights are fitted by the model. Here we set it to 10, which should lead to a single non-zero lag.

Note: Extreme values can lead to training instability.
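The plots below also look at the first and the 36th forecast step separately, because each forecast step can have its own set of AR weights. When comparing `ar_reg` settings, it can help to list which lags survive for each step. Assuming you already have the AR weights as a NumPy array with one row per forecast step and one column per lag (lag 1 first; how to extract this from a fitted model depends on the NeuralProphet version and is not shown here), a hypothetical helper could look like this. The toy weights are made up.

```python
import numpy as np


def nonzero_lags_per_step(ar_weights: np.ndarray, tol: float = 1e-3) -> dict:
    """Map each forecast step (1-based) to the lags (1-based) whose |weight| > tol.

    ar_weights is assumed to have shape (n_forecasts, n_lags), lag 1 in column 0.
    A tolerance is used because gradient-based training typically leaves
    tiny residual weights instead of exact zeros.
    """
    return {
        step + 1: (np.flatnonzero(np.abs(row) > tol) + 1).tolist()
        for step, row in enumerate(ar_weights)
    }


# Made-up weights for 3 forecast steps and 5 lags.
toy = np.array([
    [0.95, 0.0002, 0.0, 0.0, 0.0],
    [0.60, 0.30, 0.0, 0.0, 0.0005],
    [0.0, 0.0, 0.0, 0.0, 0.80],
])
print(nonzero_lags_per_step(toy))  # {1: [1], 2: [1, 2], 3: [5]}
```

The choice of `tol` is a judgment call: lowering it to `1e-4` would also count the tiny weights at lag 2 of step 1 and lag 5 of step 2.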

In [11]:
m = NeuralProphet(
    n_lags=6 * 12,
    n_forecasts=3 * 12,
    n_changepoints=0,
    daily_seasonality=False,
    weekly_seasonality=False,
    learning_rate=0.01,
    ar_reg=10,
)
metrics = m.fit(df, freq="5min")


In [12]:
m.plot_parameters()

[Figure: model parameters, with a "Trend" panel and an "AR" panel of weights by lag]

In [13]:
m = m.highlight_nth_step_ahead_of_each_forecast(1)
m.plot_parameters()

[Figure: model parameters, with a "Trend" panel and an "AR" panel of weights by lag]

In [14]:
m = m.highlight_nth_step_ahead_of_each_forecast(36)
m.plot_parameters()

[Figure: model parameters, with a "Trend" panel and an "AR" panel of weights by lag]