## Forecasting with structured VAEs

For background see [(Johnson et al. 2016)](https://arxiv.org/abs/1603.06277).

In [None]:
import pyro
import torch
import matplotlib.pyplot as plt
import pyro.distributions as dist
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.timeseries.stable import LogStableCoxProcess
from pyro.nn import PyroParam, PyroSample
from torch.distributions import constraints

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
# Turn on Pyro's runtime validation checks for the rest of the session.
pyro.enable_validation(True)
# Fix the random seed so that repeated runs of the notebook are reproducible.
pyro.set_rng_seed(2020012917)

## Data: 9 years of hourly rider counts among 50 train stations

In [None]:
# Load the BART origin-destination ridership dataset and summarize it.
dataset = load_bart_od()
counts = dataset["counts"]

# One summary item per line: available keys, tensor shape, start date.
print(dataset.keys(), counts.shape, dataset["start_date"], sep="\n")
# Station codes, separated by single spaces on one line.
print(*dataset["stations"])

## Univariate forecasting

In [None]:
# Build a univariate series for a single station and plot its first weeks.
station = "EMBR"
station_id = dataset["stations"].index(station)
# Sum over the last axis, keeping it as a trailing dimension of size 1.
# NOTE(review): presumably counts is (time, origin, destination), making this
# the total departures per hour from `station` -- confirm against the loader.
data = counts[:, station_id].sum(dim=-1, keepdim=True)

hours_shown = 4 * 24 * 7  # four weeks of hourly samples
first_day = dataset["start_date"][0].strftime("%Y-%m-%d")

plt.figure(figsize=(9, 2.5))
plt.plot(data[:hours_shown])
plt.title("Departures from {}".format(station))
plt.xlabel("hour after {}".format(first_day))
plt.ylabel("# riders")
plt.tight_layout()

In [None]:
# Start from an empty Pyro parameter store so that parameters left over from
# an earlier run of this cell cannot leak into the new fit.
pyro.clear_param_store()

# Univariate model: one hidden dimension, one observed dimension.
# NOTE(review): "foo" is presumably a name/prefix for the model's parameters
# and max_rate an upper bound on the Cox process rate -- confirm against
# LogStableCoxProcess.
process = LogStableCoxProcess(
    "foo",
    hidden_dim=1,
    obs_dim=1,
    max_rate=1e5,
)
process.fit(data, learning_rate=1e-3)

In [None]:
# Compare the exact density of log(X), where X ~ Gamma(alpha, beta), with a
# Normal approximation whose mean and standard deviation match those of log(X).
x = torch.linspace(-80, 8, 1000)
alpha = torch.tensor(0.1)
beta = 1.

# Moments of log(X) for X ~ Gamma(concentration=alpha, rate=beta):
#   E[log X]   = digamma(alpha) - log(beta)
#   Var[log X] = trigamma(alpha) = polygamma(1, alpha)
# Subtracting log(beta) changes nothing for beta == 1, but keeps the Normal
# curve aligned with the exact density if the rate is ever changed.
loc = alpha.digamma() - torch.as_tensor(beta).log()
scale = alpha.polygamma(1).sqrt()

# ExpTransform().inv is the log transform, so this is the distribution of
# log(X); exponentiating log_prob gives the density on the grid x.
log_gamma = dist.TransformedDistribution(dist.Gamma(alpha, beta),
                                         dist.transforms.ExpTransform().inv)
y = log_gamma.log_prob(x).exp()
y2 = dist.Normal(loc, scale).log_prob(x).exp()

plt.plot(x, y, 'k--', label="log-Gamma (exact)")
plt.plot(x, y2, 'r-', label="Normal (moment-matched)")
plt.xlabel("log(X)")
plt.ylabel("density")
# Trailing semicolon suppresses the notebook's echo of the return value.
plt.legend();