You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
When running the sbibm demo code (commit 074e06a):
import sbibm

# Load a benchmark task by name; sbibm.get_available_tasks() lists all tasks.
task = sbibm.get_task("two_moons")

# Prior sampler and simulator for this task.
prior = task.get_prior()
simulator = task.get_simulator()
# Observations are selected by number (10 per task).
observation = task.get_observation(num_observation=1)

# Custom inference: draw 10,000 parameter sets from the prior and run the
# simulator on them.
thetas = prior(num_samples=10_000)
xs = simulator(thetas)

# Built-in inference: rejection ABC (see help(rej_abc) for all keywords).
from sbibm.algorithms import rej_abc

# rej_abc returns three values; only the posterior samples are kept here.
# NOTE(review): in the reporter's environment this call raises
# `TypeError: __call__() got an unexpected keyword argument 'return_distances'`
# from inside rej_abc, because the installed sbi's MCABC rejects that keyword.
posterior_samples, _, _ = rej_abc(
    task=task,
    num_samples=10_000,
    num_observation=1,
    num_simulations=100_000,
)
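I get the following error. The `MCABC.__call__` of the installed `sbi` package rejects the `return_distances` keyword that sbibm's `rej_abc` passes to it: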
task = <sbibm.tasks.two_moons.task.TwoMoons object at 0x7ff456f40f10>, num_samples = 50, num_simulations = 500, num_observation = 1
observation = tensor([[-0.6397, 0.1623]]), num_top_samples = 100, quantile = 0.2, eps = None, distance = 'l2', batch_size = 1000, save_distances = False
kde_bandwidth = 'cv', sass = False, sass_fraction = 0.5, sass_feature_expansion_degree = 3, lra = False
def run(
task: Task,
num_samples: int,
num_simulations: int,
num_observation: Optional[int] = None,
observation: Optional[torch.Tensor] = None,
num_top_samples: Optional[int] = 100,
quantile: Optional[float] = None,
eps: Optional[float] = None,
distance: str = "l2",
batch_size: int = 1000,
save_distances: bool = False,
kde_bandwidth: Optional[str] = "cv",
sass: bool = False,
sass_fraction: float = 0.5,
sass_feature_expansion_degree: int = 3,
lra: bool = False,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
"""Runs REJ-ABC from `sbi`
Choose one of `num_top_samples`, `quantile`, `eps`.
Args:
task: Task instance
num_samples: Number of samples to generate from posterior
num_simulations: Simulation budget
num_observation: Observation number to load, alternative to `observation`
observation: Observation, alternative to `num_observation`
num_top_samples: If given, will use `top=True` with num_top_samples
quantile: Quantile to use
eps: Epsilon threshold to use
distance: Distance to use
batch_size: Batch size for simulator
save_distances: If True, stores distances of samples to disk
kde_bandwidth: If not None, will resample using KDE when necessary, set
e.g. to "cv" for cross-validated bandwidth selection
sass: If True, summary statistics are learned as in
Fearnhead & Prangle 2012.
sass_fraction: Fraction of simulation budget to use for sass.
sass_feature_expansion_degree: Degree of polynomial expansion of the summary
statistics.
lra: If True, posterior samples are adjusted with
linear regression as in Beaumont et al. 2002.
Returns:
Samples from posterior, number of simulator calls, log probability of true params if computable
"""
assert not (num_observation is None and observation is None)
assert not (num_observation is not None and observation is not None)
assert not (num_top_samples is None and quantile is None and eps is None)
log = sbibm.get_logger(__name__)
log.info(f"Running REJ-ABC")
prior = task.get_prior_dist()
simulator = task.get_simulator(max_calls=num_simulations)
if observation is None:
observation = task.get_observation(num_observation)
if num_top_samples is not None and quantile is None:
if sass:
quantile = num_top_samples / (
num_simulations - int(sass_fraction * num_simulations)
)
else:
quantile = num_top_samples / num_simulations
inference_method = MCABC(
simulator=simulator,
prior=prior,
simulation_batch_size=batch_size,
distance=distance,
show_progress_bars=True,
)
> posterior, distances = inference_method(
x_o=observation,
num_simulations=num_simulations,
eps=eps,
quantile=quantile,
return_distances=True,
lra=lra,
sass=sass,
sass_expansion_degree=sass_feature_expansion_degree,
sass_fraction=sass_fraction,
)
E TypeError: __call__() got an unexpected keyword argument 'return_distances'
The text was updated successfully, but these errors were encountered:
When running the sbibm demo code based on commit 074e06a, I get the `TypeError` shown above.
The text was updated successfully, but these errors were encountered: