## Forecasting with structured VAEs

For background, see [Johnson et al. (2016)](https://arxiv.org/abs/1603.06277).

In [None]:
import pyro
import torch
import matplotlib.pyplot as plt
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.timeseries.stable import LogStableCoxProcess
from pyro.nn import PyroParam, PyroSample
from torch.distributions import constraints

%matplotlib inline
pyro.enable_validation(True)
pyro.set_rng_seed(2020012917)

## Data: 9 years of hourly origin–destination rider counts among 50 BART stations

In [None]:
# Load the BART origin/destination dataset and summarize what it contains.
dataset = load_bart_od()
print(dataset.keys())
counts = dataset["counts"]
print(counts.shape)
print(dataset["start_date"])
# Space-separated list of station names (print's default separator is a single space).
print(*dataset["stations"])

## Univariate forecasting

In [None]:
station = "EMBR"
station_id = dataset["stations"].index(station)
start_str = dataset["start_date"][0].strftime("%Y-%m-%d")
hours_per_week = 24 * 7
preview_weeks = 4  # length of the hourly preview window, in weeks

# Hourly departures from `station`: select `station_id` along dim 1 of `counts`, sum over
# the last dim, and keep a trailing length-1 dim.
# NOTE(review): the "Departures" title assumes counts is (hour, origin, destination) —
# confirm against load_bart_od.
hourly_data = counts[:, station_id].sum(-1).unsqueeze(-1)
plt.figure(figsize=(9, 2.5))
plt.plot(hourly_data[:preview_weeks * hours_per_week])
plt.title("Departures from {}".format(station))
plt.xlabel("hour after {}".format(start_str))
plt.ylabel("# riders")
plt.xlim(0, preview_weeks * hours_per_week)
plt.ylim(0, None)
plt.tight_layout()
print(hourly_data.shape)

# Aggregate to weekly totals. Drop the trailing partial week so the series reshapes
# evenly into (num_weeks, hours_per_week, 1), then sum over the hour axis.
num_whole_weeks = len(hourly_data) // hours_per_week
weekly_data = hourly_data[:num_whole_weeks * hours_per_week]
weekly_data = weekly_data.reshape(-1, hours_per_week, 1).sum(-2)
plt.figure(figsize=(9, 2.5))
# Plot the entire weekly series. An hourly-sized slice such as [:4*24*7] has no meaning
# for weekly data and would silently truncate the plot for long series.
plt.plot(weekly_data)
plt.title("Departures from {}".format(station))
plt.xlabel("week after {}".format(start_str))
plt.ylabel("# riders")
plt.xlim(0, len(weekly_data))
plt.ylim(0, None)
plt.tight_layout()
print(weekly_data.shape)

In [None]:
%%time
pyro.clear_param_store()
process = LogStableCoxProcess("rides", hidden_dim=1, obs_dim=1, max_rate=1e5)
process.model.stability = 1.9  # FIXME remove this
losses = process.fit(weekly_data, learning_rate=0.1)
plt.figure(figsize=(9, 2.5))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("loss")
plt.tight_layout()

In [None]:
# Draw one sample from the fitted posterior and report each sample site in name order:
# scalar sites print their value, non-scalar sites print their shape.
trace = poutine.trace(process.posterior).get_trace(weekly_data)
for name in sorted(trace.nodes):
    site = trace.nodes[name]
    if site["type"] != "sample":
        continue
    value = site["value"]
    if value.numel() == 1:
        print("{} = {}".format(name, value.item()))
    else:
        print("{}.shape = {}".format(name, tuple(value.shape)))

In [None]:
# Run the fitted process's `detect` on the weekly series. Later cells read the result by
# keys "trans" and "obs".
# NOTE(review): presumably per-time-step transition/observation noise estimates — confirm
# against LogStableCoxProcess.detect.
noise = process.detect(weekly_data)

In [None]:
# Top panel: weekly departures. Bottom panel: the "trans" and "obs" series returned by
# process.detect, sharing the weekly x axis.
fig, axes = plt.subplots(2, figsize=(9, 4), sharex=True)
# Plot the whole weekly series. The hourly-sized slice [:4*24*7] used previously has no
# meaning for weekly data and could silently truncate the top panel.
axes[0].plot(weekly_data)
axes[0].set_title("Departures from {}".format(station))
axes[0].set_ylabel("# riders")
axes[0].set_ylim(0, None)
axes[0].set_xlim(0, len(weekly_data))
# Label both noise curves so they can be told apart.
axes[1].plot(noise["trans"], "b-", label="trans")
axes[1].plot(noise["obs"], "r-", label="obs")
axes[1].set_ylabel("noise")
axes[1].legend(loc="best")
axes[1].set_xlabel("week after {}".format(dataset["start_date"][0].strftime("%Y-%m-%d")))
plt.tight_layout()
print(weekly_data.shape)

In [None]:
# Compare the exact density of log(X), X ~ Gamma(concentration=alpha, rate=beta), with a
# moment-matched Normal approximation. For that Gamma:
#   E[log X]   = digamma(alpha) - log(beta)
#   Var[log X] = polygamma(1, alpha)
x = torch.linspace(-80, 8, 1000)
alpha = torch.tensor(0.1)
beta = torch.tensor(1.)
# The -log(beta) term keeps the match correct for any rate, not just beta == 1.
loc = alpha.digamma() - beta.log()
scale = alpha.polygamma(1).sqrt()
# Density of log(X): push the Gamma through the inverse of exp.
log_gamma = dist.TransformedDistribution(dist.Gamma(alpha, beta),
                                         dist.transforms.ExpTransform().inv)
y = log_gamma.log_prob(x).exp()
y2 = dist.Normal(loc, scale).log_prob(x).exp()
plt.plot(x, y, 'k--', label="log of Gamma(alpha, beta)")
plt.plot(x, y2, 'r-', label="moment-matched Normal")
plt.xlabel("log x")
plt.ylabel("density")
plt.legend();