-
Notifications
You must be signed in to change notification settings - Fork 223
/
test_pickle.py
239 lines (198 loc) · 7.8 KB
/
test_pickle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import pickle
import numpy as np
from numpy.testing import assert_allclose
import pytest
from jax import random
import jax.numpy as jnp
from jax.tree_util import tree_all, tree_map
import numpyro
from numpyro.contrib.funsor import config_kl
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.constraints import (
boolean,
circular,
corr_cholesky,
corr_matrix,
greater_than,
interval,
l1_ball,
lower_cholesky,
nonnegative_integer,
ordered_vector,
positive,
positive_definite,
positive_integer,
positive_ordered_vector,
real,
real_matrix,
real_vector,
scaled_unit_lower_cholesky,
simplex,
softplus_lower_cholesky,
softplus_positive,
sphere,
unit_interval,
)
from numpyro.infer import (
HMC,
HMCECS,
MCMC,
NUTS,
SA,
SVI,
BarkerMH,
DiscreteHMCGibbs,
MixedHMC,
Predictive,
TraceEnum_ELBO,
)
from numpyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoNormal
def normal_model():
    """Model with a single standard-normal latent site named ``x``."""
    prior = dist.Normal(0, 1)
    numpyro.sample("x", prior)
def bernoulli_model():
    """Model with a single discrete latent site ``x`` drawn from Bernoulli(0.5)."""
    coin = dist.Bernoulli(0.5)
    numpyro.sample("x", coin)
def logistic_regression():
    """Toy subsampled Bernoulli-logit model used to exercise HMCECS pickling.

    A scalar latent ``x`` is the shared logit. Observations come from a fixed
    length-10 array, read through a plate that subsamples 2 entries at a time.
    """
    # Bernoulli observations must lie in {0, 1}; a bare ``jnp.arange(10)`` would
    # feed the out-of-support values 2..9 into the likelihood.
    data = jnp.arange(10) % 2
    x = numpyro.sample("x", dist.Normal(0, 1))
    # Derive the plate size from the data instead of repeating the literal 10.
    with numpyro.plate("N", data.shape[0], subsample_size=2):
        batch = numpyro.subsample(data, 0)
        numpyro.sample("obs", dist.Bernoulli(logits=x), obs=batch)
def gmm(data, K):
    """Gaussian mixture model with ``K`` unit-variance components.

    Args:
        data: observed values; indexed by ``data.shape[0]`` along the data plate.
        K: number of mixture components.

    The per-point ``assignments`` site requests parallel enumeration via its
    ``infer`` dict, so the inference machinery is expected to marginalize it.
    """
    num_points = data.shape[0]
    weights = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
    with numpyro.plate("num_clusters", K, dim=-1):
        # Component k's prior mean is k, which breaks label symmetry.
        locs = numpyro.sample("cluster_means", dist.Normal(jnp.arange(K), 1.0))
    with numpyro.plate("data", num_points, dim=-1):
        z = numpyro.sample(
            "assignments",
            dist.Categorical(weights),
            infer={"enumerate": "parallel"},
        )
        numpyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)
@pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA])
def test_pickle_hmc(kernel):
    """A fitted MCMC object keeps identical samples after a pickle round trip."""
    original = MCMC(kernel(normal_model), num_warmup=10, num_samples=10)
    original.run(random.PRNGKey(0))
    restored = pickle.loads(pickle.dumps(original))
    # ``tree_map`` runs ``assert_allclose`` leaf by leaf and also fails on any
    # mismatch in the dict structure of the two sample sets.
    expected, actual = original.get_samples(), restored.get_samples()
    tree_all(tree_map(assert_allclose, expected, actual))
@pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA])
def test_pickle_hmc_enumeration(kernel):
    """Pickling preserves samples for a mixture model with an enumerated site."""
    K, N = 3, 1000
    true_cluster_means = jnp.array([1.0, 5.0, 10.0])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])

    # Simulate N observations from the ground-truth mixture.
    labels = dist.Categorical(true_mix_proportions).sample(random.PRNGKey(0), (N,))
    observed = dist.Normal(true_cluster_means[labels], 1.0).sample(random.PRNGKey(1))

    sampler = MCMC(kernel(gmm), num_warmup=10, num_samples=10)
    sampler.run(random.PRNGKey(0), observed, K)
    restored = pickle.loads(pickle.dumps(sampler))
    tree_all(tree_map(assert_allclose, sampler.get_samples(), restored.get_samples()))
@pytest.mark.parametrize("kernel", [DiscreteHMCGibbs, MixedHMC])
def test_pickle_discrete_hmc(kernel):
    """Pickling preserves samples for kernels that wrap HMC on a discrete model."""
    inner = HMC(bernoulli_model)
    sampler = MCMC(kernel(inner), num_warmup=10, num_samples=10)
    sampler.run(random.PRNGKey(0))
    clone = pickle.loads(pickle.dumps(sampler))
    expected = sampler.get_samples()
    actual = clone.get_samples()
    tree_all(tree_map(assert_allclose, expected, actual))
def test_pickle_hmcecs():
    """Pickling preserves samples drawn by HMCECS on a subsampled model."""
    ecs_kernel = HMCECS(NUTS(logistic_regression))
    sampler = MCMC(ecs_kernel, num_warmup=10, num_samples=10)
    sampler.run(random.PRNGKey(0))
    clone = pickle.loads(pickle.dumps(sampler))
    expected = sampler.get_samples()
    actual = clone.get_samples()
    tree_all(tree_map(assert_allclose, expected, actual))
def poisson_regression(x, N):
    """Poisson observation model with a Gamma(1, 1) prior on the rate ``param``.

    Args:
        x: observed counts, or ``None`` for prior/predictive sampling.
        N: total plate size.

    When ``x`` is given, ``len(x)`` is passed as the plate's subsample size.
    """
    rate = numpyro.sample("param", dist.Gamma(1.0, 1.0))
    if x is None:
        batch_size = None
    else:
        batch_size = len(x)
    with numpyro.plate("batch", N, batch_size):
        numpyro.sample("x", dist.Poisson(rate), obs=x)
@pytest.mark.parametrize("guide_class", [AutoDelta, AutoDiagonalNormal, AutoNormal])
def test_pickle_autoguide(guide_class):
    """A fitted autoguide survives pickling and still drives ``Predictive``.

    The unpickled guide must return the expected sites, and its predictive
    draws must match the original guide's for the same params and PRNG key.
    """
    # A seeded generator keeps the training data reproducible; the legacy
    # global ``np.random`` state is shared across tests and unseeded.
    x = np.random.default_rng(0).poisson(1.0, size=(100,))
    guide = guide_class(poisson_regression)
    optim = numpyro.optim.Adam(1e-2)
    svi = SVI(poisson_regression, guide, optim, numpyro.infer.Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(1), 3, x, len(x))
    pickled_guide = pickle.loads(pickle.dumps(guide))

    def predict(g):
        # Same model, params and key for both guides, so outputs are comparable.
        predictive = Predictive(
            poisson_regression,
            guide=g,
            params=svi_result.params,
            num_samples=1,
            return_sites=["param", "x"],
        )
        return predictive(random.PRNGKey(1), None, 1)

    samples = predict(pickled_guide)
    assert set(samples.keys()) == {"param", "x"}
    # Pickling must not change what the guide produces.
    tree_map(assert_allclose, predict(guide), samples)
def test_pickle_singleton_constraint():
    """Pickling must preserve the identity of singleton constraint instances.

    Some numpyro constraint classes (such as ``constraints._Real``) are only
    reachable through a public singleton instance (such as ``constraints.real``).
    Pickle's default behavior would build a new instance on load, which would
    break singleton semantics. This test checks that the round trip returns the
    very same object, and that ordinary parameterized constraints still pickle.
    """
    singletons = (
        boolean,
        circular,
        corr_cholesky,
        corr_matrix,
        l1_ball,
        lower_cholesky,
        nonnegative_integer,
        ordered_vector,
        positive,
        positive_definite,
        positive_integer,
        positive_ordered_vector,
        real,
        real_matrix,
        real_vector,
        scaled_unit_lower_cholesky,
        simplex,
        softplus_lower_cholesky,
        softplus_positive,
        sphere,
        unit_interval,
    )
    for singleton in singletons:
        # Identity, not mere equality: the unpickled object must BE the original.
        assert pickle.loads(pickle.dumps(singleton)) is singleton

    # Newly created parameterized constraints are neither singletons nor
    # module-level names, so pickle serializes them by value. Make sure that
    # still works and that their bounds survive the round trip.
    fresh_cases = (
        (interval(1.0, 2.0), ("lower_bound", "upper_bound")),
        (greater_than(1.0), ("lower_bound",)),
    )
    for fresh, bound_attrs in fresh_cases:
        restored = pickle.loads(pickle.dumps(fresh))
        assert type(restored) is type(fresh)
        for attr in bound_attrs:
            assert getattr(fresh, attr) == getattr(restored, attr)
def test_mcmc_pickle_post_warmup():
    """An unpickled MCMC object can resume sampling from its last chain state.

    The restored sampler's ``last_state`` is assigned to ``post_warmup_state``
    and sampling is run again. The run must complete and yield the configured
    number of draws.
    """
    mcmc = MCMC(NUTS(normal_model), num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))
    pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
    pickled_mcmc.post_warmup_state = pickled_mcmc.last_state
    pickled_mcmc.run(random.PRNGKey(1))
    # Go beyond "did not raise": the resumed run must actually yield draws of
    # the scalar site ``x``, one per requested sample.
    samples = pickled_mcmc.get_samples()
    assert samples["x"].shape == (10,)
def bernoulli_regression(data):
    """Beta-Bernoulli model: one shared success probability ``beta`` for all ``data``."""
    num_obs = len(data)
    prob = numpyro.sample("beta", dist.Beta(1.0, 1.0))
    with numpyro.plate("N", num_obs):
        numpyro.sample("obs", dist.Bernoulli(prob), obs=data)
def test_beta_bernoulli():
    """A pickled ``config_kl``-wrapped model trains to the same params as the original."""
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    optim = numpyro.optim.Adam(1e-2)

    def fit(model):
        # Three SVI steps with a fixed key, so both fits are directly comparable.
        svi = SVI(model, guide, optim, TraceEnum_ELBO())
        return svi.run(random.PRNGKey(0), 3, data).params

    pickled_model = pickle.loads(pickle.dumps(config_kl(bernoulli_regression)))
    params = fit(config_kl(bernoulli_regression))
    pickled_params = fit(pickled_model)
    tree_all(tree_map(assert_allclose, params, pickled_params))