/
jax.py
727 lines (618 loc) · 25.3 KB
/
jax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import sys
from datetime import datetime
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import arviz as az
import jax
import numpy as np
import pytensor.tensor as pt
from arviz.data.base import make_attrs
from jax.experimental.maps import SerialLoop, xmap
from pytensor.compile import SharedVariable, Supervisor, mode
from pytensor.graph.basic import graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.raise_op import Assert
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.shape import SpecifyShape
from pymc import Model, modelcontext
from pymc.backends.arviz import find_constants, find_observations
from pymc.initial_point import StartDict
from pymc.logprob.utils import CheckParameterValue
from pymc.sampling.mcmc import _init_jitter
from pymc.util import (
RandomSeed,
RandomState,
_get_seeds_per_chain,
get_default_varnames,
)
xla_flags_env = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split()
os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
# Public API of this module: graph compilation helpers and the two JAX-based
# NUTS samplers (blackjax and numpyro).
__all__ = (
    "get_jaxified_graph",
    "get_jaxified_logp",
    "sample_blackjax_nuts",
    "sample_numpyro_nuts",
)
@jax_funcify.register(Assert)
@jax_funcify.register(CheckParameterValue)
@jax_funcify.register(SpecifyShape)
def jax_funcify_Assert(op, **kwargs):
    """JAX conversion shared by ``Assert``, ``CheckParameterValue`` and ``SpecifyShape``.

    JAX does not allow assertions whose values are not known during JIT
    compilation inside its JIT-ed code
    (https://github.com/google/jax/issues/2273#issuecomment-589098722), so each
    of these Ops is lowered to a simple pass-through: the returned function
    gives back its first input unchanged and ignores any remaining inputs.
    """

    def passthrough(value, *_ignored_inputs):
        return value

    return passthrough
def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable]:
    """Return a copy of ``graph`` whose shared-variable inputs are replaced by constants.

    Every ``SharedVariable`` feeding ``graph`` is substituted with
    ``pt.constant`` of its current value (read with ``borrow=True``), and the
    graph is cloned with these replacements.

    Raises
    ------
    ValueError
        If any shared input has a ``RandomType`` type (checked first), or if
        any shared input has a ``default_update``.
    """
    shared_inputs = [
        node for node in graph_inputs(graph) if isinstance(node, SharedVariable)
    ]

    if any(isinstance(shared.type, RandomType) for shared in shared_inputs):
        raise ValueError(
            "Graph contains shared RandomType variables which cannot be safely replaced"
        )
    if any(shared.default_update is not None for shared in shared_inputs):
        raise ValueError(
            "Graph contains shared variables with default_update which cannot be safely replaced."
        )

    as_constants = {}
    for shared in shared_inputs:
        as_constants[shared] = pt.constant(shared.get_value(borrow=True))
    return clone_replace(graph, replace=as_constants)
def get_jaxified_graph(
    inputs: Optional[List[TensorVariable]] = None,
    outputs: Optional[List[TensorVariable]] = None,
) -> Callable:
    """Compile a PyTensor graph into an optimized JAX function.

    Parameters
    ----------
    inputs : list of TensorVariable, optional
        Graph inputs. The returned function is called with one positional
        argument per input.
    outputs : list of TensorVariable, optional
        Graph outputs. Shared variables feeding them are first replaced by
        constants (see ``_replace_shared_variables``).

    Returns
    -------
    Callable
        The function produced by ``jax_funcify`` for the rewritten graph; it
        returns a sequence with one entry per output.

    Raises
    ------
    ValueError
        Propagated from ``_replace_shared_variables`` when a shared input cannot
        be safely replaced by a constant.
    """
    graph = _replace_shared_variables(outputs) if outputs is not None else None

    fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)

    # We need to add a Supervisor to the fgraph to be able to run the
    # JAX sequential optimizer without warnings. We made sure there
    # are no mutable input variables, so we only need to check for
    # "destroyers". This should be automatically handled by PyTensor
    # once https://github.com/pytensor-devs/pytensor/issues/637 is fixed.
    fgraph.attach_feature(
        Supervisor(
            graph_input
            for graph_input in fgraph.inputs
            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([graph_input]))
        )
    )
    mode.JAX.optimizer.rewrite(fgraph)

    # We now jaxify the optimized fgraph
    return jax_funcify(fgraph)
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
    """Build a JAX function evaluating the model's joint log-probability.

    Parameters
    ----------
    model : Model
        Model whose ``model.logp()`` graph is compiled.
    negative_logp : bool, default True
        Note the inverted naming: ``True`` returns ``model.logp()`` unchanged,
        while ``False`` returns its negation (this is what
        ``sample_numpyro_nuts`` passes to NumPyro as ``potential_fn``).

    Returns
    -------
    Callable
        Function taking a single sequence of values ordered like
        ``model.value_vars`` and returning the (possibly negated) logp.
    """
    signed_logp = model.logp() if negative_logp else -model.logp()
    compiled = get_jaxified_graph(inputs=model.value_vars, outputs=[signed_logp])

    def jax_logp(x):
        # The compiled graph returns one entry per output; there is only one.
        return compiled(*x)[0]

    return jax_logp
# Adopted from arviz numpyro extractor
def _sample_stats_to_xarray(posterior):
"""Extract sample_stats from NumPyro posterior."""
rename_key = {
"potential_energy": "lp",
"adapt_state.step_size": "step_size",
"num_steps": "n_steps",
"accept_prob": "acceptance_rate",
}
data = {}
for stat, value in posterior.get_extra_fields(group_by_chain=True).items():
if isinstance(value, (dict, tuple)):
continue
name = rename_key.get(stat, stat)
value = value.copy()
data[name] = value
if stat == "num_steps":
data["tree_depth"] = np.log2(value).astype(int) + 1
return data
def _postprocess_samples(
jax_fn: List[TensorVariable],
raw_mcmc_samples: List[TensorVariable],
postprocessing_backend: str,
num_chunks: Optional[int] = None,
) -> List[TensorVariable]:
if num_chunks is not None:
loop = xmap(
jax_fn,
in_axes=["chain", "samples", ...],
out_axes=["chain", "samples", ...],
axis_resources={"samples": SerialLoop(num_chunks)},
)
f = xmap(loop, in_axes=[...], out_axes=[...])
return f(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))
else:
return jax.vmap(jax.vmap(jax_fn))(
*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
)
def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
"""Extract compatible stats from blackjax NUTS sampler
with PyMC/Arviz naming conventions.
Parameters
----------
sample_stats: NUTSInfo
Blackjax NUTSInfo object containing sampler statistics
potential_energy: ArrayLike
Potential energy values of sampled positions.
Returns
-------
Dict[str, ArrayLike]
Dictionary of sampler statistics.
"""
rename_key = {
"is_divergent": "diverging",
"energy": "energy",
"num_trajectory_expansions": "tree_depth",
"num_integration_steps": "n_steps",
"acceptance_rate": "acceptance_rate", # naming here is
"acceptance_probability": "acceptance_rate", # depending on blackjax version
}
converted_stats = {}
converted_stats["lp"] = potential_energy
for old_name, new_name in rename_key.items():
value = getattr(sample_stats, old_name, None)
if value is None:
continue
converted_stats[new_name] = value
return converted_stats
def _get_log_likelihood(
    model: Model, samples, backend=None, num_chunks: Optional[int] = None
) -> Dict:
    """Evaluate the elementwise log-likelihood of each observed variable.

    ``model.logp(model.observed_RVs, sum=False)`` is compiled to JAX and mapped
    over ``samples`` via ``_postprocess_samples``. The result maps each
    observed variable's name to its evaluated log-likelihood values.
    """
    pointwise_logps = model.logp(model.observed_RVs, sum=False)
    loglike_fn = get_jaxified_graph(inputs=model.value_vars, outputs=pointwise_logps)
    evaluated = _postprocess_samples(loglike_fn, samples, backend, num_chunks=num_chunks)
    return dict(zip((observed.name for observed in model.observed_RVs), evaluated))
def _get_batched_jittered_initial_points(
    model: Model,
    chains: int,
    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
    random_seed: RandomSeed,
    jitter: bool = True,
    jitter_max_retries: int = 10,
) -> Union[np.ndarray, List[np.ndarray]]:
    """Get jittered initial points in the format expected by the JAX samplers.

    One jittered initial point is drawn per chain with ``_init_jitter``, each
    seeded from ``_get_seeds_per_chain(random_seed, chains)``.

    Returns
    -------
    out: list of ndarrays
        One item per variable. When ``chains > 1`` each item is the per-chain
        values stacked along a new leading chain axis (``np.stack``). When
        ``chains == 1`` the single chain's values are returned as-is, without
        a chain axis.
    """
    chain_seeds = _get_seeds_per_chain(random_seed, chains)
    jittered_points = _init_jitter(
        model,
        initvals,
        seeds=chain_seeds,
        jitter=jitter,
        jitter_max_retries=jitter_max_retries,
    )
    per_chain_values = [list(point.values()) for point in jittered_points]

    if chains != 1:
        per_variable = zip(*per_chain_values)
        return [np.stack(chain_values) for chain_values in per_variable]
    return per_chain_values[0]
def _update_coords_and_dims(
coords: Dict[str, Any], dims: Dict[str, Any], idata_kwargs: Dict[str, Any]
) -> None:
"""Update 'coords' and 'dims' dicts with values in 'idata_kwargs'."""
if "coords" in idata_kwargs:
coords.update(idata_kwargs.pop("coords"))
if "dims" in idata_kwargs:
dims.update(idata_kwargs.pop("dims"))
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
def _blackjax_inference_loop(
    seed,
    init_position,
    logprob_fn,
    draws,
    tune,
    target_accept,
    algorithm=None,
):
    """Run blackjax window adaptation, then ``draws`` sampling steps, for one chain.

    Parameters
    ----------
    seed : PRNGKey
        Key for this chain. It is split into independent keys for the warmup
        and the sampling phases.
    init_position
        Initial position passed to ``adapt.run``.
    logprob_fn : callable
        Log density of the target (static argument for ``jax.jit``).
    draws : int
        Number of post-warmup steps (static).
    tune : int
        Number of window-adaptation steps (static).
    target_accept : float
        Target acceptance rate for step-size adaptation (static).
    algorithm : optional
        Blackjax sampling algorithm; defaults to ``blackjax.nuts`` (static).

    Returns
    -------
    tuple
        ``(states, infos)`` as stacked by ``jax.lax.scan``, with a leading
        axis of length ``draws``.
    """
    import blackjax

    if algorithm is None:
        algorithm = blackjax.nuts

    # Use independent keys for warmup and sampling: reusing ``seed`` for both
    # phases would correlate the random numbers they draw.
    adapt_seed, sample_seed = jax.random.split(seed)

    adapt = blackjax.window_adaptation(
        algorithm=algorithm,
        logprob_fn=logprob_fn,
        num_steps=tune,
        target_acceptance_rate=target_accept,
    )
    last_state, kernel, _ = adapt.run(adapt_seed, init_position)

    def inference_loop(rng_key, initial_state):
        def one_step(state, step_key):
            state, info = kernel(step_key, state)
            return state, (state, info)

        keys = jax.random.split(rng_key, draws)
        _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

        return states, infos

    return inference_loop(sample_seed, last_state)
def sample_blackjax_nuts(
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    target_accept: float = 0.8,
    random_seed: Optional[RandomState] = None,
    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
    model: Optional[Model] = None,
    var_names: Optional[Sequence[str]] = None,
    keep_untransformed: bool = False,
    chain_method: str = "parallel",
    postprocessing_backend: Optional[str] = None,
    postprocessing_chunks: Optional[int] = None,
    idata_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> az.InferenceData:
    """
    Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

    Parameters
    ----------
    draws : int, default 1000
        The number of samples to draw. The number of tuned samples are discarded by
        default.
    tune : int, default 1000
        Number of iterations to tune. Samplers adjust the step sizes, scalings or
        similar during tuning. Tuning samples will be drawn in addition to the number
        specified in the ``draws`` argument.
    chains : int, default 4
        The number of chains to sample.
    target_accept : float in [0, 1].
        The step size is tuned such that we approximate this acceptance rate. Higher
        values like 0.9 or 0.95 often work better for problematic posteriors.
    random_seed : int, RandomState or Generator, optional
        Random seed used by the sampling steps.
    initvals: StartDict or Sequence[Optional[StartDict]], optional
        Initial values for random variables provided as a dictionary (or sequence of
        dictionaries) mapping the random variable (by name or reference) to desired
        starting values.
    model : Model, optional
        Model to sample from. The model needs to have free random variables. When inside
        a ``with`` model context, it defaults to that model, otherwise the model must be
        passed explicitly.
    var_names : sequence of str, optional
        Names of variables for which to compute the posterior samples. Defaults to all
        variables in the posterior. Names are matched against
        ``model.unobserved_value_vars``; variables are also accepted and matched by
        their ``name``.
    keep_untransformed : bool, default False
        Passed as ``include_transformed`` to ``get_default_varnames``: when True,
        transformed (unconstrained) value variables are kept in the posterior
        samples as well.
    chain_method : str, default "parallel"
        Specify how samples should be drawn. The choices include "parallel", and
        "vectorized".
    postprocessing_backend : str, optional
        Specify how postprocessing should be computed. gpu or cpu
    postprocessing_chunks: Optional[int], default None
        Specify the number of chunks the postprocessing should be computed in. More
        chunks reduces memory usage at the cost of losing some vectorization, None
        uses jax.vmap
    idata_kwargs : dict, optional
        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
        value for the ``log_likelihood`` key; set it to ``True`` to compute the
        pointwise log likelihood and include it in the returned object (it is not
        computed by default). Values for ``observed_data``, ``constant_data``,
        ``coords``, and ``dims`` are inferred from the ``model`` argument if not
        provided in ``idata_kwargs``. If ``coords`` and ``dims`` are provided, they
        are used to update the inferred dictionaries.
    **kwargs
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    InferenceData
        ArviZ ``InferenceData`` object that contains the posterior samples, together
        with their respective sample stats and, if requested via ``idata_kwargs``,
        pointwise log likelihood values.

    Raises
    ------
    ValueError
        If ``chain_method`` is neither "parallel" nor "vectorized".
    """
    import blackjax

    model = modelcontext(model)

    # Resolve requested names to the graphs that can actually be compiled;
    # passing bare strings on to ``get_jaxified_graph`` would fail.
    if var_names is None:
        filtered_var_names = model.unobserved_value_vars
    else:
        requested_names = {v if isinstance(v, str) else v.name for v in var_names}
        filtered_var_names = [
            v for v in model.unobserved_value_vars if v.name in requested_names
        ]
    vars_to_sample = list(
        get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
    )

    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        if cvals is not None
    }

    dims = {
        var_name: [dim for dim in var_dims if dim is not None]
        for var_name, var_dims in model.named_vars_to_dims.items()
    }

    (random_seed,) = _get_seeds_per_chain(random_seed, 1)

    tic1 = datetime.now()
    print("Compiling...", file=sys.stdout)

    init_params = _get_batched_jittered_initial_points(
        model=model,
        chains=chains,
        initvals=initvals,
        random_seed=random_seed,
    )

    if chains == 1:
        # A single chain comes back without a chain axis; add one so that
        # ``map_fn`` below can map over it.
        init_params = [np.stack(init_state) for init_state in zip(init_params)]

    logprob_fn = get_jaxified_logp(model)

    seed = jax.random.PRNGKey(random_seed)
    keys = jax.random.split(seed, chains)

    get_posterior_samples = partial(
        _blackjax_inference_loop,
        logprob_fn=logprob_fn,
        tune=tune,
        draws=draws,
        target_accept=target_accept,
    )

    tic2 = datetime.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    # Adapted from numpyro
    if chain_method == "parallel":
        map_fn = jax.pmap
    elif chain_method == "vectorized":
        map_fn = jax.vmap
    else:
        raise ValueError(
            "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
        )

    states, stats = map_fn(get_posterior_samples)(keys, init_params)
    raw_mcmc_samples = states.position
    potential_energy = states.potential_energy
    tic3 = datetime.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    result = _postprocess_samples(
        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
    )
    mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
    mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)

    tic4 = datetime.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    # Copy so that popping keys below does not mutate the caller's dict.
    if idata_kwargs is None:
        idata_kwargs = {}
    else:
        idata_kwargs = idata_kwargs.copy()

    if idata_kwargs.pop("log_likelihood", False):
        tic5 = datetime.now()
        print("Computing Log Likelihood...", file=sys.stdout)
        log_likelihood = _get_log_likelihood(
            model,
            raw_mcmc_samples,
            backend=postprocessing_backend,
            num_chunks=postprocessing_chunks,
        )
        tic6 = datetime.now()
        print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
    else:
        log_likelihood = None

    attrs = {
        "sampling_time": (tic3 - tic2).total_seconds(),
    }

    posterior = mcmc_samples
    # Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
    # and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
    _update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
    # Use 'partial' to set default arguments before passing 'idata_kwargs'
    to_trace = partial(
        az.from_dict,
        log_likelihood=log_likelihood,
        observed_data=find_observations(model),
        constant_data=find_constants(model),
        sample_stats=mcmc_stats,
        coords=coords,
        dims=dims,
        attrs=make_attrs(attrs, library=blackjax),
    )
    az_trace = to_trace(posterior=posterior, **idata_kwargs)
    return az_trace
def _numpyro_nuts_defaults() -> Dict[str, Any]:
"""Defaults parameters for Numpyro NUTS."""
return {
"adapt_step_size": True,
"adapt_mass_matrix": True,
"dense_mass": False,
}
def _update_numpyro_nuts_kwargs(nuts_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return PyMC's NumPyro NUTS defaults, overridden by ``nuts_kwargs`` when given.

    A fresh dict is returned; ``nuts_kwargs`` itself is not modified.
    """
    merged = _numpyro_nuts_defaults()
    if nuts_kwargs is None:
        return merged
    merged.update(nuts_kwargs)
    return merged
def sample_numpyro_nuts(
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    target_accept: float = 0.8,
    random_seed: Optional[RandomState] = None,
    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
    model: Optional[Model] = None,
    var_names: Optional[Sequence[str]] = None,
    progressbar: bool = True,
    keep_untransformed: bool = False,
    chain_method: str = "parallel",
    postprocessing_backend: Optional[str] = None,
    postprocessing_chunks: Optional[int] = None,
    idata_kwargs: Optional[Dict] = None,
    nuts_kwargs: Optional[Dict] = None,
    **kwargs,
) -> az.InferenceData:
    """
    Draw samples from the posterior using the NUTS method from the ``numpyro`` library.

    Parameters
    ----------
    draws : int, default 1000
        The number of samples to draw. The number of tuned samples are discarded by
        default.
    tune : int, default 1000
        Number of iterations to tune. Samplers adjust the step sizes, scalings or
        similar during tuning. Tuning samples will be drawn in addition to the number
        specified in the ``draws`` argument.
    chains : int, default 4
        The number of chains to sample.
    target_accept : float in [0, 1].
        The step size is tuned such that we approximate this acceptance rate. Higher
        values like 0.9 or 0.95 often work better for problematic posteriors.
    random_seed : int, RandomState or Generator, optional
        Random seed used by the sampling steps.
    initvals: StartDict or Sequence[Optional[StartDict]], optional
        Initial values for random variables provided as a dictionary (or sequence of
        dictionaries) mapping the random variable (by name or reference) to desired
        starting values.
    model : Model, optional
        Model to sample from. The model needs to have free random variables. When inside
        a ``with`` model context, it defaults to that model, otherwise the model must be
        passed explicitly.
    var_names : sequence of str, optional
        Names of variables for which to compute the posterior samples. Defaults to all
        variables in the posterior. Names are matched against
        ``model.unobserved_value_vars``; variables are also accepted and matched by
        their ``name``.
    progressbar : bool, default True
        Whether or not to display a progress bar in the command line. The bar shows the
        percentage of completion, the sampling speed in samples per second (SPS), and
        the estimated remaining time until completion ("expected time of arrival"; ETA).
    keep_untransformed : bool, default False
        Passed as ``include_transformed`` to ``get_default_varnames``: when True,
        transformed (unconstrained) value variables are kept in the posterior
        samples as well.
    chain_method : str, default "parallel"
        Specify how samples should be drawn. The choices include "sequential",
        "parallel", and "vectorized".
    postprocessing_backend : Optional[str]
        Specify how postprocessing should be computed. gpu or cpu
    postprocessing_chunks: Optional[int], default None
        Specify the number of chunks the postprocessing should be computed in. More
        chunks reduces memory usage at the cost of losing some vectorization, None
        uses jax.vmap
    idata_kwargs : dict, optional
        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
        value for the ``log_likelihood`` key; set it to ``True`` to compute the
        pointwise log likelihood and include it in the returned object (it is not
        computed by default). Values for ``observed_data``, ``constant_data``,
        ``coords``, and ``dims`` are inferred from the ``model`` argument if not
        provided in ``idata_kwargs``. If ``coords`` and ``dims`` are provided, they
        are used to update the inferred dictionaries.
    nuts_kwargs: dict, optional
        Keyword arguments for :func:`numpyro.infer.NUTS`; they override the PyMC
        defaults (``adapt_step_size``, ``adapt_mass_matrix``, ``dense_mass``).
    **kwargs
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    InferenceData
        ArviZ ``InferenceData`` object that contains the posterior samples, together
        with their respective sample stats and, if requested via ``idata_kwargs``,
        pointwise log likelihood values.
    """

    import numpyro

    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    # Resolve requested names to the graphs that can actually be compiled;
    # passing bare strings on to ``get_jaxified_graph`` would fail.
    if var_names is None:
        filtered_var_names = model.unobserved_value_vars
    else:
        requested_names = {v if isinstance(v, str) else v.name for v in var_names}
        filtered_var_names = [
            v for v in model.unobserved_value_vars if v.name in requested_names
        ]
    vars_to_sample = list(
        get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
    )

    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        if cvals is not None
    }

    dims = {
        var_name: [dim for dim in var_dims if dim is not None]
        for var_name, var_dims in model.named_vars_to_dims.items()
    }

    (random_seed,) = _get_seeds_per_chain(random_seed, 1)

    tic1 = datetime.now()
    print("Compiling...", file=sys.stdout)

    init_params = _get_batched_jittered_initial_points(
        model=model,
        chains=chains,
        initvals=initvals,
        random_seed=random_seed,
    )

    # ``negative_logp=False`` yields -model.logp(), i.e. NumPyro's potential energy.
    logp_fn = get_jaxified_logp(model, negative_logp=False)

    nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        **nuts_kwargs,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progressbar,
    )

    tic2 = datetime.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    # One key for a single chain, a batch of keys (one per chain) otherwise.
    map_seed = jax.random.PRNGKey(random_seed)
    if chains > 1:
        map_seed = jax.random.split(map_seed, chains)

    pmap_numpyro.run(
        map_seed,
        init_params=init_params,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )

    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = datetime.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    result = _postprocess_samples(
        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
    )
    mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

    tic4 = datetime.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    # Copy so that popping keys below does not mutate the caller's dict.
    if idata_kwargs is None:
        idata_kwargs = {}
    else:
        idata_kwargs = idata_kwargs.copy()

    if idata_kwargs.pop("log_likelihood", False):
        tic5 = datetime.now()
        print("Computing Log Likelihood...", file=sys.stdout)
        log_likelihood = _get_log_likelihood(
            model,
            raw_mcmc_samples,
            backend=postprocessing_backend,
            num_chunks=postprocessing_chunks,
        )
        tic6 = datetime.now()
        print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
    else:
        log_likelihood = None

    attrs = {
        "sampling_time": (tic3 - tic2).total_seconds(),
    }

    posterior = mcmc_samples
    # Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
    # and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
    _update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
    # Use 'partial' to set default arguments before passing 'idata_kwargs'
    to_trace = partial(
        az.from_dict,
        log_likelihood=log_likelihood,
        observed_data=find_observations(model),
        constant_data=find_constants(model),
        sample_stats=_sample_stats_to_xarray(pmap_numpyro),
        coords=coords,
        dims=dims,
        attrs=make_attrs(attrs, library=numpyro),
    )
    az_trace = to_trace(posterior=posterior, **idata_kwargs)
    return az_trace