## Forecasting with structured VAEs

For background, see [Johnson et al. (2016)](https://arxiv.org/abs/1603.06277).

In [None]:
import pyro
import torch
import matplotlib.pyplot as plt
import pyro.distributions as dist
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.timeseries.stable import LogStableCoxProcess
from torch.distributions import constraints

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
# Turn on Pyro's validation checks.
pyro.enable_validation(True)
# Fix the random seed so repeated runs of the notebook are reproducible.
pyro.set_rng_seed(2020012917)

## Data: 9 years of hourly rider counts among 50 train stations

In [None]:
# Load the BART origin-destination dataset; the result is indexed by string
# keys below ("counts", "start_date", "stations").
dataset = load_bart_od()
# Keep a direct handle on the ridership counts (presumably the hourly
# station-to-station counts described in the heading above -- TODO confirm).
counts = dataset["counts"]
# Quick inspection: available fields, size of the counts, the starting date,
# and the station names joined on a single line.
print(dataset.keys())
print(counts.shape)
print(dataset["start_date"])
print(" ".join(dataset["stations"]))

## Univariate forecasting

In [None]:
# Compare two densities on a common grid:
#   * log(G) with G ~ Gamma(alpha, beta), and
#   * a Normal with mean digamma(alpha) and std sqrt(trigamma(alpha)),
#     which are the mean and std of log(G) when beta == 1.
alpha = torch.tensor(0.1)
beta = 1.
x = torch.linspace(start=-80, end=8, steps=1000)

# Parameters of the Normal curve; polygamma of order 1 is the trigamma function.
loc = torch.digamma(alpha)
scale = torch.sqrt(torch.polygamma(1, alpha))

# Push the Gamma through log, i.e. the inverse of the exp transform.
log_gamma_dist = dist.TransformedDistribution(
    dist.Gamma(alpha, beta),
    dist.transforms.ExpTransform().inv,
)
normal_dist = dist.Normal(loc, scale)

# Evaluate log-densities on the grid and exponentiate to get densities.
y = torch.exp(log_gamma_dist.log_prob(x))
y2 = torch.exp(normal_dist.log_prob(x))

plt.plot(x, y, 'k--')   # log-Gamma density: dashed black
plt.plot(x, y2, 'r-');  # Normal density: solid red (';' hides the cell's repr output)