# Forecasting BART hourly ridership (multivariate)

In [1]:
import logging
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
from pyro.infer.autoguide import init_to_sample
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import Forecaster, ForecastingModel, eval_crps
from pyro.ops.tensor_utils import periodic_cumsum, periodic_repeat
from pyro.ops.stats import quantile
import matplotlib.pyplot as plt

from bart_multi import Model

%matplotlib inline
# Turn on Pyro's validation checks for the whole session.
pyro.enable_validation(True)
# Lower the threshold to DEBUG on both the "pyro" logger and its first handler;
# this assumes the logger already has at least one handler attached.
logging.getLogger("pyro").setLevel(logging.DEBUG)
logging.getLogger("pyro").handlers[0].setLevel(logging.DEBUG)

In [2]:
# Load the BART origin-destination dataset and summarise its contents.
dataset = load_bart_od()
print(dataset.keys())
print(dataset["counts"].shape)
print(*dataset["stations"])

# Move the leading axis of the counts to the second-to-last position, append a
# trailing unit axis, and switch to log1p-transformed values.
data = dataset["counts"].permute(1, 2, 0)
data = data.unsqueeze(-1).log1p().contiguous()
print(dataset["counts"].shape, data.shape)
# No covariates: one row per step along data's second-to-last axis, zero columns.
covariates = torch.zeros(data.shape[-2], 0)  # empty

dict_keys(['stations', 'start_date', 'counts'])
torch.Size([78888, 50, 50])
12TH 16TH 19TH 24TH ANTC ASHB BALB BAYF BERY CAST CIVC COLM COLS CONC DALY DBRK DELN DUBL EMBR FRMT FTVL GLEN HAYW LAFY LAKE MCAR MLBR MLPT MONT NBRK NCON OAKL ORIN PCTR PHIL PITT PLZA POWL RICH ROCK SANL SBRN SFIA SHAY SSAN UCTY WARM WCRK WDUB WOAK
torch.Size([78888, 50, 50]) torch.Size([50, 50, 78888, 1])


In [3]:
# Indices along the time axis of `data`.
T0 = 0              # first time index (not referenced by the visible cells)
T1 = 90 * 24        # end of the training window used by train() and plot()
T2 = 14 * 24 + T1   # end of the forecast window used by plot()
station = "EMBR"    # station whose column is highlighted by plot()

In [4]:
def init_loc_fn(site):
    """Choose the initial value for one latent site.

    Sites whose name ends in ``"stability"`` start at 1.9 and sites whose
    name ends in ``"dof"`` start at 10.0, each filled to the shape reported
    by ``site["fn"].shape()``. Every other site is delegated to
    ``init_to_sample``.
    """
    # Checked in order; the first matching suffix wins.
    fixed_inits = (("stability", 1.9), ("dof", 10.0))
    for suffix, fill_value in fixed_inits:
        if site["name"].endswith(suffix):
            return torch.full(site["fn"].shape(), fill_value)
    return init_to_sample(site)

def create_plates(zero_data, covariates):
    """Build the origin and destination plates with a shared subsample.

    ``zero_data`` must be 4-dimensional; its first two sizes give the number
    of origins and destinations. ``covariates`` is accepted but not used.

    Returns the pair ``(origin_plate, destin_plate)``.
    """
    # Unpacking all four sizes also rejects inputs that are not 4-dimensional.
    n_origin, n_destin, _duration, _one = zero_data.shape
    origins = pyro.plate("origin", n_origin, subsample_size=10, dim=-3)
    # Enter the origin plate once, purely to read off its subsample, so that
    # the destination plate can be given exactly the same indices.
    with origins as shared_indices:
        pass
    destins = pyro.plate("destin", n_destin, subsample=shared_indices, dim=-2)
    return origins, destins

def train(dist_type, num_steps=1001):
    """Fit a ``Forecaster`` on the first ``T1`` time steps of ``data``.

    Args:
        dist_type: passed straight to ``Model``.
        num_steps: number of training steps handed to ``Forecaster``.

    Returns:
        The fitted ``Forecaster``. As a side effect, every single-element
        entry of ``forecaster.guide.median()`` is printed.
    """
    # Start from a clean parameter store and a fixed seed on every call.
    pyro.clear_param_store()
    pyro.set_rng_seed(20200313)

    model = Model(dist_type)
    train_data = data[..., :T1, :]
    train_covariates = covariates[:T1, :]
    # Subsampling via create_plates=create_plates is intentionally left disabled.
    forecaster = Forecaster(
        model,
        train_data,
        train_covariates,
        init_loc_fn=init_loc_fn,
        learning_rate=0.1,
        num_steps=num_steps,
        log_every=100,
    )

    # Only scalar entries are reported; larger tensors are skipped.
    for site_name, median in forecaster.guide.median().items():
        if median.numel() != 1:
            continue
        print("{} = {:0.4g}".format(site_name, median.item()))
    return forecaster
            
def plot(forecaster):
    """Forecast time steps ``T1``..``T2`` and plot them against the truth.

    Draws 100 forecast samples, prints their shape, and plots the 10%/50%/90%
    quantiles for the first eight origin rows into the column of ``station``,
    overlaid on the observed series from one week before ``T1`` up to ``T2``.
    The CRPS over the forecast window is shown in the title.

    Samples, CRPS and the plotted truth are all in the same units as
    ``data``, which the notebook stores log1p-transformed. No conversion
    back to raw counts is performed.

    Args:
        forecaster: a fitted forecaster, as returned by ``train``.
    """
    samples = forecaster(data[..., :T1, :], covariates[:T2, :], num_samples=100)
    # Clamp in place at zero. Nothing else is applied: the samples must stay
    # in the units of `data` so that eval_crps and the truth overlay below
    # compare like with like. (A chained out-of-place `.expm1()` here would
    # have its result discarded and change nothing.)
    samples.clamp_(min=0)
    # Drop the trailing axis; quantile() stacks the three quantiles along a
    # new leading axis, so each unpacked tensor keeps the time axis last.
    p10, p50, p90 = quantile(samples[..., 0], (0.1, 0.5, 0.9))
    crps = eval_crps(samples, data[..., T1:T2, :])
    print(samples.shape, p10.shape)

    fig, axes = plt.subplots(8, 1, figsize=(9, 10), sharex=True)
    plt.subplots_adjust(hspace=0)
    j = dataset["stations"].index(station)
    axes[0].set_title("# hourly arrivals to {} (CRPS = {:0.3g})"
                      .format(station, crps))

    # Both x-ranges are the same for every panel, so build them once.
    history_start = T1 - 24 * 7
    forecast_hours = torch.arange(T1, T2)
    truth_hours = torch.arange(history_start, T2)
    for i, ax in enumerate(axes):
        ax.fill_between(forecast_hours, p10[i, j], p90[i, j], color="red", alpha=0.3)
        ax.plot(forecast_hours, p50[i, j], 'r-', lw=1, label='forecast')
        ax.plot(truth_hours, data[i, j, history_start: T2, 0],
                'k-', lw=1, label='truth')
        ax.set_ylabel("from {}".format(dataset["stations"][i]))
    # The x-axis is shared, so label and limit it on the bottom panel only.
    axes[-1].set_xlabel("Hour after 2011-01-01")
    axes[-1].set_xlim(history_start, T2)
    axes[0].legend(loc="upper left")

In [None]:
%%time
# Fit with dist_type="normal", then plot the resulting forecast.
forecaster = train("normal")
plot(forecaster)

INFO 	 step    0 loss = 1.46462e+13


In [None]:
%%time
# Fit with dist_type="stable", then plot the resulting forecast.
forecaster = train("stable")
plot(forecaster)

In [None]:
%%time
# Fit with dist_type="studentt", then plot the resulting forecast.
forecaster = train("studentt")
plot(forecaster)