-
Notifications
You must be signed in to change notification settings - Fork 222
/
stochastic_volatility.py
137 lines (109 loc) · 4.78 KB
/
stochastic_volatility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""
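Example: Stochastic Volatility
==============================

Generative model:

.. math::
    :nowrap:

    \\begin{align}
    \\sigma & \\sim \\text{Exponential}(50) \\\\
    \\nu & \\sim \\text{Exponential}(0.1) \\\\
    s_i & \\sim \\text{Normal}(s_{i-1}, \\sigma) \\\\
    r_i & \\sim \\text{StudentT}(\\nu, 0, \\exp(s_i))
    \\end{align}

Here the second argument of Normal and the third argument of StudentT are scale
parameters (for the Normal, its standard deviation), matching the ``scale``
arguments used in ``model``.
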
This example is from PyMC3 [1], which itself is adapted from the original experiment
in [2]. A discussion about translating this model into Pyro appears in [3].

We use this example to illustrate how to use the functional interface ``hmc``. However,
we recommend that readers use the ``MCMC`` class, as in the other examples, because it
is more stable and supports more features.

**References:**

1. *Stochastic Volatility Model*, https://docs.pymc.io/notebooks/stochastic_volatility.html
2. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
   https://arxiv.org/pdf/1111.4246.pdf
3. Pyro forum discussion, https://forum.pyro.ai/t/problems-transforming-a-pymc3-model-to-pyro-mcmc/208/14

.. image:: ../_static/img/examples/stochastic_volatility.png
    :align: center
"""
import argparse
import os
import matplotlib
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import SP500, load_dataset
from numpyro.infer.hmc import hmc
from numpyro.infer.util import initialize_model
from numpyro.util import fori_collect
# Select the non-interactive Agg backend: this script only saves its figure to
# a file and never calls plt.show(), so no display/GUI toolkit is required.
matplotlib.use("Agg")  # noqa: E402
def model(returns):
    """NumPyro model of returns driven by a latent log-volatility random walk.

    Priors: the random-walk step scale ``sigma`` ~ Exponential(50) and the
    StudentT degrees of freedom ``nu`` ~ Exponential(0.1). The latent ``s`` is
    a Gaussian random walk with scale ``sigma`` and one step per entry along the
    leading axis of ``returns``. Each return is StudentT(nu, 0, exp(s)).

    :param returns: observed returns; conditions the ``"r"`` sample site.
    :return: the value produced by ``numpyro.sample`` for the observed ``"r"`` site.
    """
    n_steps = jnp.shape(returns)[0]
    # Site order (sigma, s, nu, r) is kept as-is; it determines how NumPyro
    # assigns random keys to the sites.
    sigma = numpyro.sample("sigma", dist.Exponential(50.0))
    log_vol = numpyro.sample(
        "s", dist.GaussianRandomWalk(scale=sigma, num_steps=n_steps)
    )
    nu = numpyro.sample("nu", dist.Exponential(0.1))
    likelihood = dist.StudentT(df=nu, loc=0.0, scale=jnp.exp(log_vol))
    return numpyro.sample("r", likelihood, obs=returns)
def print_results(
    posterior, dates, every=180, quantiles=(0.2, 0.4, 0.5, 0.6, 0.8)
):
    """Print posterior quantile tables for ``sigma``, ``nu`` and the volatility.

    Each table is a header row of quantile labels (e.g. ``(p50)``) followed by a
    row of quantile values. The volatility section prints one table per
    ``every``-th date, showing quantiles of ``exp(posterior["s"][:, i])`` and
    labelled with ``dates[i]``.

    :param dict posterior: samples keyed by site name; must contain ``"sigma"``,
        ``"nu"`` and ``"s"``. ``s`` is indexed as ``s[:, i]`` for date index
        ``i`` (presumably shape (num_samples, num_dates) -- confirm with caller).
    :param dates: date labels aligned with the second axis of ``s``.
    :param int every: stride between printed dates (default 180, the value
        previously hard-coded).
    :param quantiles: quantile levels in [0, 1] to report.
    """
    q_levels = jnp.asarray(quantiles)
    # round(), not int(): float error makes int(0.29 * 100) == 28.
    columns = ["(p{})".format(round(float(q) * 100)) for q in quantiles]

    def _print_row(values, row_name=""):
        # Label column is 8 wide unless the label is longer (a "YYYY-MM-DD"
        # date is 10 chars); header and row share the width so columns align.
        width = max(8, len(str(row_name)))
        q_values = jnp.quantile(values, q_levels, axis=0)
        header = "{:>{w}}".format("", w=width) + "".join(
            "{:>12}".format(col) for col in columns
        )
        row = "{:>{w}}".format(row_name, w=width) + "".join(
            "{:>12.3f}".format(float(v)) for v in q_values
        )
        print(header)
        print(row)
        print("\n")

    print("=" * 20, "sigma", "=" * 20)
    _print_row(posterior["sigma"])
    print("=" * 20, "nu", "=" * 20)
    _print_row(posterior["nu"])
    print("=" * 20, "volatility", "=" * 20)
    for i in range(0, len(dates), every):
        _print_row(jnp.exp(posterior["s"][:, i]), dates[i])
def main(args):
    """Fit the stochastic volatility model to S&P 500 returns with NUTS.

    Loads the SP500 dataset, runs the functional ``hmc`` interface (NUTS) for
    ``args.num_warmup`` warmup plus ``args.num_samples`` sampling iterations,
    prints posterior summaries, and saves a returns/volatility plot to
    ``stochastic_volatility_plot.pdf`` in the current working directory.

    :param argparse.Namespace args: parsed CLI arguments; reads ``rng_seed``,
        ``num_warmup`` and ``num_samples``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo="NUTS")
    hmc_state = init_kernel(
        model_info.param_info, args.num_warmup, rng_key=sample_rng_key
    )
    # fori_collect runs up to num_warmup + num_samples iterations but only
    # collects those from index num_warmup on; each kept state is mapped back
    # to constrained sample space by postprocess_fn.
    hmc_states = fori_collect(
        args.num_warmup,
        args.num_warmup + args.num_samples,
        sample_kernel,
        hmc_state,
        transform=lambda state: model_info.postprocess_fn(state.z),
        # Disable the progress bar while the docs are being built.
        progbar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    print_results(hmc_states, dates)
    _plot_volatility(dates, returns, hmc_states)


def _plot_volatility(dates, returns, samples):
    """Plot returns plus posterior volatility paths; save to stochastic_volatility_plot.pdf."""
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())
    # One faint red line per posterior draw of exp(s).
    ax.plot(dates, jnp.exp(samples["s"].T), "r", alpha=0.01)
    legend = ax.legend(["returns", "volatility"], loc="upper right")
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9
    # removed the old name; support both.
    handles = (
        legend.legend_handles
        if hasattr(legend, "legend_handles")
        else legend.legendHandles
    )
    # Make the volatility legend entry visible despite the 0.01 line alpha.
    handles[1].set_alpha(0.6)
    ax.set(xlabel="time", ylabel="returns", title="Volatility of S&P500 over time")
    fig.savefig("stochastic_volatility_plot.pdf")
    plt.close(fig)
if __name__ == "__main__":
    # Examples are pinned to the NumPyro release they were written against.
    assert numpyro.__version__.startswith("0.9.2"), (
        "this example targets NumPyro 0.9.2, found " + numpyro.__version__
    )
    parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
    # No nargs="?": without a `const`, a bare `-n` / `--num-warmup` would set
    # the value to None and crash later inside main().
    parser.add_argument("-n", "--num-samples", default=600, type=int)
    parser.add_argument("--num-warmup", default=600, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    # "--rng-seed" matches the other hyphenated flags; "--rng_seed" still works.
    parser.add_argument(
        "--rng_seed",
        "--rng-seed",
        dest="rng_seed",
        default=21,
        type=int,
        help="random number generator seed",
    )
    args = parser.parse_args()
    numpyro.set_platform(args.device)
    main(args)