Skip to content

Commit

Permalink
Add stochastic volatility model to examples (#143)
Browse files Browse the repository at this point in the history
* stash

* Add stochastic volatility model to examples

* fix lint

* revert dependency change

* Fix distribution test

* sort import

* Pin jax and jaxlib versions to prevent breakage

* address comments
  • Loading branch information
neerajprad authored and fehiepsi committed May 11, 2019
1 parent 7da9574 commit dde3ebf
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 4 deletions.
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Dirichlet,
Exponential,
Gamma,
GaussianRandomWalk,
HalfCauchy,
HalfNormal,
LKJCholesky,
Expand Down Expand Up @@ -48,6 +49,7 @@
'Distribution',
'Exponential',
'Gamma',
'GaussianRandomWalk',
'HalfCauchy',
'HalfNormal',
'LogNormal',
Expand Down
39 changes: 38 additions & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
from numpyro.distributions.constraints import AbsTransform, AffineTransform, ExpTransform
from numpyro.distributions.distribution import Distribution, TransformedDistribution
from numpyro.distributions.util import (
cumsum,
matrix_to_tril_vec,
multigammaln,
promote_shapes,
signed_stick_breaking_tril,
standard_gamma,
vec_to_tril_matrix
vec_to_tril_matrix,
)


Expand Down Expand Up @@ -491,6 +492,42 @@ def support(self):
return constraints.greater_than(self.scale)


class GaussianRandomWalk(Distribution):
    """
    Distribution over length-``num_steps`` paths built from accumulated
    Gaussian increments.

    The first element is scored under ``Normal(0, scale)`` and every later
    element under ``Normal(previous_element, scale)``.

    :param scale: positive scale of each increment; its shape becomes the
        batch shape.
    :param num_steps: scalar number of steps; the event shape is
        ``(num_steps,)``.
    :param validate_args: forwarded unchanged to ``Distribution.__init__``.
    """
    arg_constraints = {'num_steps': constraints.positive, 'scale': constraints.positive}
    support = constraints.real
    # FIXME: cannot take grad through random.normal with dynamic shape
    reparametrized_params = []

    def __init__(self, scale, num_steps=1, validate_args=None):
        # Only a scalar step count is supported.
        assert np.shape(num_steps) == ()
        self.num_steps = num_steps
        self.scale = scale
        super(GaussianRandomWalk, self).__init__(np.shape(scale), (num_steps,),
                                                 validate_args=validate_args)

    def sample(self, key, size=()):
        """Draw walks of shape ``size + batch_shape + event_shape`` using ``key``."""
        unit_steps = random.normal(key, shape=size + self.batch_shape + self.event_shape)
        # Trailing axis added so that `scale` broadcasts against the step axis.
        step_scale = np.expand_dims(self.scale, axis=-1)
        # NOTE(review): `cumsum` is the helper from numpyro.distributions.util,
        # presumably accumulating along the last (step) axis -- confirm there.
        return cumsum(unit_steps) * step_scale

    def log_prob(self, value):
        """Return the log density of ``value``, summed over the step axis."""
        if self._validate_args:
            self._validate_sample(value)
        first, previous, current = value[..., 0], value[..., :-1], value[..., 1:]
        # The initial position is a single Gaussian step away from zero.
        init_prob = Normal(0., self.scale).log_prob(first)
        # Each later position is centred on the one before it.
        transition = Normal(previous, np.expand_dims(self.scale, -1))
        step_probs = transition.log_prob(current)
        return init_prob + np.sum(step_probs, axis=-1)

    @property
    def mean(self):
        """Zeros of shape ``batch_shape + event_shape``."""
        return np.zeros(self.batch_shape + self.event_shape)

    @property
    def variance(self):
        """``scale ** 2 * k`` for step ``k = 1..num_steps``, broadcast to the full shape."""
        step_count = np.arange(1, self.num_steps + 1)
        per_step_variance = np.expand_dims(self.scale, -1) ** 2 * step_count
        return np.broadcast_to(per_step_variance, self.batch_shape + self.event_shape)


class StudentT(Distribution):
arg_constraints = {'df': constraints.positive, 'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real
Expand Down
21 changes: 21 additions & 0 deletions numpyro/examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
'https://d2fefpcigoriu7.cloudfront.net/datasets/UCBadmit.csv',
])

SP500 = dset('SP500', [
'https://d2fefpcigoriu7.cloudfront.net/datasets/SP500.csv',
])


def _download(dset):
for url in dset.urls:
Expand Down Expand Up @@ -86,6 +90,21 @@ def train_test_split(file):
'test': (test, player_names)}


def _load_sp500():
    """
    Fetch the SP500 dataset and parse ``SP500.csv`` from ``DATA_DIR``.

    :return: ``{'train': (date, value)}`` where ``date`` stacks the raw
        strings of the ``DATE`` column and ``value`` stacks the ``VALUE``
        column converted to ``float``.
    """
    _download(SP500)

    csv_path = os.path.join(DATA_DIR, 'SP500.csv')
    with open(csv_path, 'r') as handle:
        # Materialise the rows while the file is still open.
        rows = list(csv.DictReader(handle, quoting=csv.QUOTE_NONE))

    dates = np.stack([row['DATE'] for row in rows])
    values = np.stack([float(row['VALUE']) for row in rows])
    return {'train': (dates, values)}


def _load_ucbadmit():
_download(UCBADMIT)

Expand All @@ -111,6 +130,8 @@ def _load(dset):
return _load_mnist()
elif dset == BASEBALL:
return _load_baseball()
elif dset == SP500:
return _load_sp500()
elif dset == UCBADMIT:
return _load_ucbadmit()
raise ValueError('Dataset - {} not found.'.format(dset.name))
Expand Down
86 changes: 86 additions & 0 deletions numpyro/examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import argparse

import numpy as onp

import jax.numpy as np
import jax.random as random
from jax.config import config as jax_config

import numpyro.distributions as dist
from numpyro.examples.datasets import SP500, load_dataset
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.util import fori_collect


"""
Generative model:
sigma ~ Exponential(50)
nu ~ Exponential(.1)
s_i ~ Normal(s_{i-1}, sigma^{-2})
r_i ~ StudentT(nu, 0, exp(-2 s_i))
This example is from PyMC3 [1], which itself is adapted from the original experiment
from [2]. A discussion about translating this in Pyro appears in [3].
For more details, refer to:
1. *Stochastic Volatility Model*, https://docs.pymc.io/notebooks/stochastic_volatility.html
2. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
https://arxiv.org/pdf/1111.4246.pdf
3. Forum discussion, https://forum.pyro.ai/t/problems-transforming-a-pymc3-model-to-pyro-mcmc/208/14
"""


def model(returns):
    """
    Stochastic volatility model conditioned on ``returns``.

    Sample sites, in order: ``sigma`` (random-walk step scale), ``s`` (a
    Gaussian random walk with one step per entry along the first axis of
    ``returns``), ``nu`` (Student-t degrees of freedom) and the observed
    site ``r``.

    :param returns: observed series passed as ``obs`` to the ``r`` site.
    :return: the value of the ``r`` sample statement.
    """
    num_obs = np.shape(returns)[0]
    step_size = sample('sigma', dist.Exponential(50.))
    latent_walk = dist.GaussianRandomWalk(scale=step_size, num_steps=num_obs)
    s = sample('s', latent_walk)
    nu = sample('nu', dist.Exponential(.1))
    likelihood = dist.StudentT(df=nu, loc=0., scale=np.exp(-2*s))
    return sample('r', likelihood, obs=returns)


def print_results(posterior, dates):
    """
    Print quantile tables (20/40/50/60/80 %) for the posterior samples.

    A table is printed for ``posterior['sigma']``, one for ``posterior['nu']``
    and, under the ``volatility`` heading, one for
    ``exp(-2 * posterior['s'][:, i])`` at every 180th index ``i`` of
    ``dates``, labelled with ``dates[i]``.

    :param posterior: mapping with the keys ``'sigma'``, ``'nu'`` and ``'s'``.
    :param dates: sequence of row labels, indexed in step with the second
        axis of ``posterior['s']``.
    """
    quantile_levels = [0.2, 0.4, 0.5, 0.6, 0.8]
    column_labels = ['(p{})'.format(level * 100) for level in quantile_levels]

    def _print_row(values, row_name=''):
        # The label column is exactly as wide as the label itself.
        label_field = '{{:>{}}}'.format(len(row_name))
        header_template = label_field + '{:>12}' * 5
        row_template = label_field + '{:>12.3f}' * 5
        summary = onp.quantile(values, quantile_levels, axis=0)
        print(header_template.format('', *column_labels))
        print(row_template.format(row_name, *summary))
        print('\n')

    rule = '=' * 5
    for site in ('sigma', 'nu'):
        print(rule, site, rule)
        _print_row(posterior[site])

    print(rule, 'volatility', rule)
    for idx in range(0, len(dates), 180):
        volatility = np.exp(-2 * posterior['s'][:, idx])
        _print_row(volatility, dates[idx])


def main(args):
    """
    Run NUTS on the stochastic volatility model for the SP500 data and print
    a summary of the collected samples.

    :param args: parsed command-line namespace; reads ``device``, ``rng``,
        ``num_warmup_steps`` and ``num_samples``.
    """
    jax_config.update('jax_platform_name', args.device)

    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()

    # Separate keys for model initialisation and for the sampler.
    root_rng = random.PRNGKey(args.rng)
    init_rng, sample_rng = random.split(root_rng)

    init_params, potential_fn, transform_fn = initialize_model(init_rng, model, (returns,), {})
    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
    hmc_state = init_kernel(init_params, args.num_warmup_steps, rng=sample_rng)

    def _collect(state):
        # Pass each collected state's `z` through the transform returned by initialize_model.
        return transform_fn(state.z)

    hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
                              transform=_collect)
    print_results(hmc_states, dates)


if __name__ == "__main__":
    # Script entry point: build the command-line interface and hand the
    # parsed options straight to main().
    cli = argparse.ArgumentParser(description="Stochastic Volatility Model")
    cli.add_argument('-n', '--num-samples', nargs='?', default=3000, type=int)
    cli.add_argument('--num-warmup-steps', nargs='?', default=1500, type=int)
    cli.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    cli.add_argument('--rng', default=21, type=int, help='random number generator seed')
    main(cli.parse_args())
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
author='Uber AI Labs',
author_email='npradhan@uber.com',
install_requires=[
'jax>=0.1.26',
'jaxlib>=0.1.13',
'jax==0.1.26',
'jaxlib==0.1.14',
'tqdm',
],
extras_require={
Expand Down
2 changes: 2 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def __new__(cls, jax_dist, *params):
T(dist.Exponential, np.array([4., 2.])),
T(dist.Gamma, np.array([1.7]), np.array([[2.], [3.]])),
T(dist.Gamma, np.array([0.5, 1.3]), np.array([[1.], [3.]])),
T(dist.GaussianRandomWalk, 0.1, 10),
T(dist.GaussianRandomWalk, np.array([0.1, 0.3, 0.25]), 10),
T(dist.HalfCauchy, 1.),
T(dist.HalfCauchy, np.array([1., 2.])),
T(dist.HalfNormal, 1.),
Expand Down
8 changes: 7 additions & 1 deletion test/test_example_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax.numpy as np
from jax import lax

from numpyro.examples.datasets import BASEBALL, MNIST, load_dataset
from numpyro.examples.datasets import BASEBALL, MNIST, SP500, load_dataset


def test_mnist_data_load():
Expand All @@ -20,3 +20,9 @@ def test_baseball_data_load():
dataset = fetch(0, idx)
assert np.shape(dataset[0]) == (18, 2)
assert np.shape(dataset[1]) == (18,)


def test_sp500_data_load():
    """Check that the SP500 train split yields two equally sized 1-D columns."""
    _, fetch = load_dataset(SP500, split='train', shuffle=False)
    date, value = fetch()
    # Both columns must be checked; comparing `date` with itself would let a
    # malformed `value` column pass unnoticed.
    assert np.shape(date) == np.shape(value) == (2427,)

0 comments on commit dde3ebf

Please sign in to comment.