Skip to content

Commit

Permalink
fixing nested_sampling.py (#1738)
Browse files Browse the repository at this point in the history
  • Loading branch information
renecotyfanboy committed Feb 21, 2024
1 parent aec6bd5 commit a92bd0d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ funsor
ipython
jax
jaxlib
jaxns==2.2.6
jaxns==2.4.8
Jinja2
matplotlib
multipledispatch
Expand Down
20 changes: 11 additions & 9 deletions numpyro/contrib/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@

try:
from jaxns import (
ExactNestedSampler as OrigNestedSampler,
DefaultNestedSampler,
Model,
NestedSamplerResults,
Prior,
TerminationCondition,
plot_cornerplot,
plot_diagnostics,
resample,
summary,
)
from jaxns.utils import NestedSamplerResults

except ImportError as e:
raise ImportError(
"To use this module, please install `jaxns` package. It can be"
Expand Down Expand Up @@ -257,10 +258,10 @@ def prior_model():

default_constructor_kwargs = dict(
num_live_points=model.U_ndims * 25,
num_parallel_samplers=1,
num_parallel_workers=1,
max_samples=1e4,
)
default_termination_kwargs = dict(live_evidence_frac=1e-4)
default_termination_kwargs = dict(dlogZ=1e-4)
# Fill-in missing values with defaults. This allows user to inspect what was actually used by inspecting
# these dictionaries
list(
Expand All @@ -276,16 +277,17 @@ def prior_model():
)
)

exact_ns = OrigNestedSampler(
default_ns = DefaultNestedSampler(
model=model,
**self.constructor_kwargs,
)

termination_reason, state = exact_ns(
rng_sampling,
term_cond=TerminationCondition(**self.termination_kwargs),
termination_reason, state = default_ns(
rng_sampling, term_cond=TerminationCondition(**self.termination_kwargs)
)
results = default_ns.to_results(
termination_reason=termination_reason, state=state
)
results = exact_ns.to_results(state, termination_reason)

# transform base samples back to original domains
# Here we only transform the first valid num_samples samples
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"flax",
"funsor>=0.4.1",
"graphviz",
"jaxns==2.2.6",
"jaxns==2.4.8",
"matplotlib",
"optax>=0.0.6",
"pylab-sdk", # jaxns dependency
Expand Down

0 comments on commit a92bd0d

Please sign in to comment.