# Forecasting with Pyro: multivariate, hierarchical, and heavy tailed

This tutorial introduces the [pyro.contrib.forecast](http://docs.pyro.ai/en/latest/contrib.forecast.html) module, a framework for forecasting with Pyro models.

#### Summary
- To create a forecasting model:
  1. Create a subclass of the `ForecastingModel` class.
  2. Implement the `.model(data, covariates)` method using standard Pyro syntax.
  3. Sample all time-local variables inside the `self.time_plate` context.
  4. Finally call the `self.predict(noise_dist, prediction)` method.
- To train a forecasting model, create a `Forecaster` object.
- To forecast the future, draw samples from a `Forecaster` object conditioned on data and covariates.
- To evaluate results, use the `backtest()` helper or the low-level loss functions.

In [None]:
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster, backtest
from pyro.infer.reparam import LocScaleReparam
from pyro.ops.tensor_utils import periodic_cumsum, periodic_repeat, periodic_features
from pyro.ops.stats import quantile
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.rc = {'figure.facecolor': (1, 1, 1, 1)}
%config InlineBackend.figure_formats = ['svg']
assert pyro.__version__.startswith('1.2.1')
pyro.enable_validation(True)
pyro.set_rng_seed(20200221)

In [None]:
dataset = load_bart_od()
print(dataset.keys())
print(dataset["counts"].shape)
print(" ".join(dataset["stations"]))

## Intro to Pyro's forecasting framework

Pyro's forecasting framework consists of:
- a `ForecastingModel` base class, whose ``.model()`` method can be implemented for custom forecasting models,
- a `Forecaster` class that trains and forecasts using ``ForecastingModel``s, and
- a `backtest()` helper to evaluate models on a number of metrics.

### The ``ForecastingModel`` class

Consider a simple univariate dataset, say weekly ridership aggregated over all stations.

In [None]:
T, O, D = dataset["counts"].shape
data = dataset["counts"][:T // (24 * 7) * 24 * 7].reshape(T // (24 * 7), -1).sum(-1).log()
data = data.unsqueeze(-1)
plt.figure(figsize=(9, 3))
plt.plot(data)
plt.title("Total weekly ridership")
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(0, len(data));
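
The reshape-and-sum above first truncates the hourly series to a whole number of weeks, then totals each block of `24 * 7` hours, then takes a log. As a dependency-free illustration, the sketch below does the same bookkeeping on a plain Python list. `weekly_log_totals` is a hypothetical helper, not part of Pyro, and it assumes the counts have already been summed over all origin-destination pairs.

```python
import math

HOURS_PER_WEEK = 24 * 7

def weekly_log_totals(hourly_counts):
    """Sum hourly counts into whole weeks and return the log of each weekly total.

    Hypothetical helper for illustration; assumes every complete week has a positive total.
    """
    num_weeks = len(hourly_counts) // HOURS_PER_WEEK  # drop the trailing partial week
    totals = [
        sum(hourly_counts[w * HOURS_PER_WEEK:(w + 1) * HOURS_PER_WEEK])
        for w in range(num_weeks)
    ]
    return [math.log(total) for total in totals]

# Two full weeks of one ride per hour, plus 5 leftover hours that get dropped.
example = weekly_log_totals([1] * (2 * HOURS_PER_WEEK + 5))
```

Dropping the partial week keeps every entry comparable: each one covers exactly 168 hours.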

Let's start with a simple log-linear regression model, with no trend or seasonality.

In [None]:
# First we need some boilerplate to create a class and define a .model() method.
class Model(ForecastingModel):
    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)
        feature_dim = covariates.size(-1)

        # The first part of the model is a probabilistic program to create a prediction.
        # We use the zero_data as a template for the shape of the prediction.
        bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))
        prediction = bias + (weight * covariates).sum(-1, keepdim=True)
        # The prediction should have the same shape as zero_data (duration, obs_dim),
        # but may have additional sample dimensions on the left.
        assert prediction.shape[-2:] == zero_data.shape

        # The next part of the model creates a likelihood or noise distribution.
        # Again we'll be Bayesian and write this as a probabilistic program with
        # priors over parameters, and again we'll use zero_data as a noise template.
        noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
        noise_dist = dist.Normal(zero_data, noise_scale)

        # The final step is to call the .predict() method.
        self.predict(noise_dist, prediction)
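
Written as equations, and reading `self.predict(noise_dist, prediction)` as "data equals prediction plus centered noise", the model above is roughly the following. The symbols are introduced here for illustration only: $y_t$ is the log weekly ridership, $x_t$ the covariate vector at week $t$, and the second argument of each normal distribution is a standard deviation, matching `dist.Normal(loc, scale)`.

$$
\begin{aligned}
b &\sim \mathcal{N}(0,\,10), \qquad w_j \sim \mathcal{N}(0,\,0.1), \qquad \sigma \sim \mathrm{LogNormal}(-5,\,5),\\
y_t &= b + w^\top x_t + \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0,\,\sigma).
\end{aligned}
$$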

We can now train this model by creating a `Forecaster` object.

In [None]:
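T0 = 0              # beginning of the training window
T2 = data.size(-2)  # end of the data (and of the test window)
T1 = T2 - 52        # train/test split: hold out the final 52 weeks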

In [None]:
%%time
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = torch.stack([time], dim=-1)
forecaster = Forecaster(Model(), data[:T1], covariates[:T1], learning_rate=0.1)

In [None]:
samples = forecaster(data[:T1], covariates, num_samples=1000)
p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)
print(samples.shape, p10.shape)

plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(data, 'k-', label='truth')
plt.title("Total weekly ridership")
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(0, None)
plt.legend(loc="best");

In [None]:
plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')
plt.title("Total weekly ridership")
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(T1, None)
plt.legend(loc="best");

We could add a yearly seasonal component simply by adding new covariates.

In [None]:
%%time
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = torch.cat([time.unsqueeze(-1),
                        periodic_features(T2, 365.25 / 7)], dim=-1)
forecaster = Forecaster(Model(), data[:T1], covariates[:T1], learning_rate=0.1)
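
Seasonal covariates of this kind are typically sine and cosine waves at the seasonal period and its harmonics, which a linear model can combine into a smooth periodic curve. The sketch below builds such features in plain Python. `fourier_features` and its `num_harmonics` argument are hypothetical and only illustrate the idea; they are not the implementation or signature of `periodic_features`.

```python
import math

def fourier_features(duration, period, num_harmonics):
    """Return a duration x (2 * num_harmonics) table of seasonal features.

    Row t holds cos and sin of 2*pi*k*t/period for k = 1..num_harmonics.
    Hypothetical illustration only, not Pyro's periodic_features.
    """
    rows = []
    for t in range(duration):
        row = []
        for k in range(1, num_harmonics + 1):
            angle = 2 * math.pi * k * t / period
            row.extend([math.cos(angle), math.sin(angle)])
        rows.append(row)
    return rows

# Weekly time steps with a yearly period of 365.25 / 7 weeks, two harmonics.
features = fourier_features(duration=104, period=365.25 / 7, num_harmonics=2)
```

Because each feature repeats with the yearly period, a weight vector learned on the training window extrapolates the same seasonal shape into the forecast window.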

In [None]:
samples = forecaster(data[:T1], covariates, num_samples=1000)
p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)

plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(data, 'k-', label='truth')
plt.title("Total weekly ridership")
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(0, None)
plt.legend(loc="best");

In [None]:
plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')
plt.title("Total weekly ridership")
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(T1, None)
plt.legend(loc="best");

### Time-local random variables: `self.time_plate`

So far we've seen the ``ForecastingModel.model()`` method and ``self.predict()``. The last piece of forecasting-specific syntax is the ``self.time_plate`` context for time-local variables. To see how this works, consider changing our global linear trend model above to a local level model.

In [None]:
class Model(ForecastingModel):
    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)
        feature_dim = covariates.size(-1)
        bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))

        # We'll sample a time-global scale parameter outside the time plate,
        # then time-local iid noise inside the time plate.
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5).expand([1]).to_event(1))
        with self.time_plate:
            # We'll use a reparameterizer to improve variational fit.
            with poutine.reparam(config={"drift": LocScaleReparam()}):
                drift = pyro.sample("drift", dist.Normal(zero_data, drift_scale).to_event(1))
        motion = drift.cumsum(-2)  # A Brownian motion.
        
        prediction = motion + bias + (weight * covariates).sum(-1, keepdim=True)
        assert prediction.shape[-2:] == zero_data.shape

        # The next part of the model creates a likelihood or noise distribution.
        # Again we'll be Bayesian and write this as a probabilistic program with
        # priors over parameters, and again we'll use zero_data as a noise template.
        noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
        noise_dist = dist.Normal(zero_data, noise_scale)

        # The final step is to call the .predict() method.
        self.predict(noise_dist, prediction)            
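
The key line is `drift.cumsum(-2)`: summing iid increments over time yields a random walk whose level can wander, unlike the fixed straight line of the global trend. A plain-Python sketch of that construction follows. The function name, seed, and step count are arbitrary choices for illustration.

```python
import random
from itertools import accumulate

def random_walk(num_steps, drift_scale, seed=0):
    """Cumulatively sum iid Normal(0, drift_scale) increments (illustrative only)."""
    rng = random.Random(seed)
    drift = [rng.gauss(0.0, drift_scale) for _ in range(num_steps)]
    motion = list(accumulate(drift))  # motion[t] = drift[0] + ... + drift[t]
    return drift, motion

drift, motion = random_walk(num_steps=52, drift_scale=0.1)
```

Since the increments are independent, the variance of `motion[t]` grows linearly with `t` for a fixed `drift_scale`, so forecast bands from a local level model tend to widen further into the future.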

In [None]:
%%time
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = periodic_features(T2, 365.25 / 7)
forecaster = Forecaster(Model(), data[:T1], covariates[:T1], learning_rate=0.1)

In [None]:
samples = forecaster(data[:T1], covariates, num_samples=1000)
p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)

plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(data, 'k-', label='truth')
plt.title("Total weekly ridership")
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(0, None)
plt.legend(loc="best");

In [None]:
plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')
plt.title("Total weekly ridership")
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(T1, None)
plt.legend(loc="best");

## A bivariate seasonal model

The full dataset has all station->station ridership counts for all 50 train stations. In this simple example we will model only the aggregate counts to and from a single station, Embarcadero.

In [None]:
i = dataset["stations"].index("EMBR")
arrivals = dataset["counts"][:, :, i].sum(-1)
departures = dataset["counts"][:, i, :].sum(-1)
data = torch.stack([arrivals, departures], dim=-1).log1p()
covariates = torch.zeros(len(data), 0)

In [None]:
class Model(ForecastingModel):
    def model(self, zero_data, covariates):
        period = 24 * 7
        duration, dim = zero_data.shape[-2:]
        assert dim == 2  # Data is bivariate: (arrivals, departures).

        # Sample global parameters.
        noise_scale = pyro.sample("noise_scale",
                                  dist.LogNormal(torch.full((dim,), -3), 1).to_event(1))
        assert noise_scale.shape[-1:] == (dim,)
        trans_timescale = pyro.sample("trans_timescale",
                                      dist.LogNormal(torch.zeros(dim), 1).to_event(1))
        assert trans_timescale.shape[-1:] == (dim,)

        trans_loc = pyro.sample("trans_loc", dist.Cauchy(0, 1 / period))
        trans_loc = trans_loc.unsqueeze(-1).expand(trans_loc.shape + (dim,))
        assert trans_loc.shape[-1:] == (dim,)

        trans_scale = pyro.sample("trans_scale",
                                  dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
        trans_corr = pyro.sample("trans_corr",
                                 dist.LKJCorrCholesky(dim, torch.ones(())))
        trans_scale_tril = trans_scale.unsqueeze(-1) * trans_corr
        assert trans_scale_tril.shape[-2:] == (dim, dim)

        obs_scale = pyro.sample("obs_scale",
                                dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
        obs_corr = pyro.sample("obs_corr",
                               dist.LKJCorrCholesky(dim, torch.ones(())))
        obs_scale_tril = obs_scale.unsqueeze(-1) * obs_corr
        assert obs_scale_tril.shape[-2:] == (dim, dim)

        with pyro.plate("season_plate", period,  dim=-1):
            season_init = pyro.sample("season_init",
                                      dist.Normal(torch.zeros(dim), 1).to_event(1))
            assert season_init.shape[-2:] == (period, dim)

        # Sample independent noise at each time step.
        with self.time_plate:
            season_noise = pyro.sample("season_noise",
                                       dist.Normal(0, noise_scale).to_event(1))
            assert season_noise.shape[-2:] == (duration, dim)

        prediction = (periodic_repeat(season_init, duration, dim=-2) +
                      periodic_cumsum(season_noise, period, dim=-2))
        assert prediction.shape[-2:] == (duration, dim)

        init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)
        trans_mat = trans_timescale.neg().exp().diag_embed()
        trans_dist = dist.MultivariateNormal(trans_loc, scale_tril=trans_scale_tril)
        obs_mat = torch.eye(dim)
        obs_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=obs_scale_tril)
        noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                                       duration=duration)
        assert noise_model.event_shape == (duration, dim)

        self.predict(noise_model, prediction)
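
The seasonal `prediction` combines two periodic operations: `periodic_repeat` tiles the initial weekly profile out to the full duration, and `periodic_cumsum` accumulates noise separately for each hour of the week, so that each value adds to the one a full period earlier. The plain-Python functions below are hypothetical one-dimensional stand-ins written to show those semantics; they are not the tensor implementations in `pyro.ops.tensor_utils`.

```python
def periodic_repeat_list(values, size):
    """Tile a period-length list until it has `size` entries."""
    period = len(values)
    return [values[t % period] for t in range(size)]

def periodic_cumsum_list(values, period):
    """Cumulative sum with stride `period`: result[t] = values[t] + result[t - period]."""
    result = list(values[:period])
    for t in range(period, len(values)):
        result.append(values[t] + result[t - period])
    return result

season_init = [10, 20]           # a toy "season" with period 2
season_noise = [1, 2, 3, 4, 5]   # one increment per time step
repeated = periodic_repeat_list(season_init, len(season_noise))   # [10, 20, 10, 20, 10]
accumulated = periodic_cumsum_list(season_noise, period=2)        # [1, 2, 4, 6, 9]
prediction = [r + a for r, a in zip(repeated, accumulated)]       # [11, 22, 14, 26, 19]
```

The result is a seasonal profile that can drift slowly from one period to the next rather than repeating exactly.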

In [None]:
pyro.clear_param_store()

t_train = 90 * 24
forecaster = Forecaster(Model(), data[:t_train], covariates[:t_train],
                        learning_rate=0.05,
                        num_steps=501,
                        log_every=50)

In [None]:
plt.figure(figsize=(9, 3))
plt.plot(forecaster.losses)
plt.ylabel("ELBO loss")
plt.xlabel("SVI step")
plt.yscale("log")

In [None]:
t_test = 2 * 7 * 24
samples = forecaster(data[:t_train], covariates[:t_train + t_test], num_samples=101)
samples.clamp_(min=0);
p05 = samples.kthvalue(5, dim=0).values
p50 = samples.median(dim=0).values
p95 = samples.kthvalue(95, dim=0).values
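
With 101 samples, the k-th smallest value along the sample dimension serves as a simple empirical quantile: `kthvalue` counts `k` from 1, so the 5th and 95th smallest values approximate the 5% and 95% levels and the 51st is the median. The toy sketch below shows the same indexing on a plain list; `kth_smallest` is a hypothetical helper for illustration.

```python
def kth_smallest(values, k):
    """Return the k-th smallest entry, with k counted from 1."""
    return sorted(values)[k - 1]

samples = list(range(100, -1, -1))   # 101 toy "samples": 100, 99, ..., 0
p05 = kth_smallest(samples, 5)       # 4
p50 = kth_smallest(samples, 51)      # 50, the median of 101 values
p95 = kth_smallest(samples, 95)      # 94
```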

In [None]:
plt.figure(figsize=(8, 4))
plt.plot(torch.arange(t_train - t_test // 2, t_train + t_test),
         data[t_train - t_test // 2:t_train + t_test, 0], "k-")
plt.plot(torch.arange(t_train - t_test // 2, t_train + t_test),
         -data[t_train - t_test // 2:t_train + t_test, 1], "k-")
plt.fill_between(torch.arange(t_train, t_train + t_test), p05[..., 0], p95[..., 0],
                 color="pink")
plt.fill_between(torch.arange(t_train, t_train + t_test), -p05[..., 1], -p95[..., 1],
                 color="pink")
plt.plot(torch.arange(t_train, t_train + t_test), p50[..., 0], "r-")
plt.plot(torch.arange(t_train, t_train + t_test), -p50[..., 1], "r-")
plt.ylabel("← departures  arrivals →")
plt.xlabel("hour after 2011-01-01")
plt.title("Forecast ridership at Embarcadero station")
plt.tight_layout()

## Hierarchical modeling using `pyro.plate`

Hierarchical forecasting models use Pyro's standard mechanism for hierarchical modeling: the [pyro.plate](http://docs.pyro.ai/en/latest/primitives.html#pyro.plate) context manager. Let's start with a simple two-level hierarchy, say modeling arrivals to each station.

In [None]:
class Model(ForecastingModel):
    def model(self, zero_data, covariates):
        num_stations, duration, one = zero_data.shape

        coefs = pyro.sample("coefs",
                            dist.Normal(0, 1).expand(covariates.shape[-1:]).to_event(1))
        global_seasonality = (coefs * covariates).sum(-1, keepdim=True)
        with pyro.plate("station", )