Alignment with SBI ABC API #20

Closed
psteinb opened this issue Nov 2, 2021 · 2 comments
psteinb commented Nov 2, 2021

When running the following sbibm demo code (based on commit 074e06a):

```python
import sbibm

task = sbibm.get_task("two_moons")  # See sbibm.get_available_tasks() for all tasks
prior = task.get_prior()
simulator = task.get_simulator()
observation = task.get_observation(num_observation=1)  # 10 per task

# These objects can then be used for custom inference algorithms, e.g.
# we might want to generate simulations by sampling from prior:
thetas = prior(num_samples=10_000)
xs = simulator(thetas)

# Alternatively, we can import existing algorithms, e.g:
from sbibm.algorithms import rej_abc  # See help(rej_abc) for keywords
posterior_samples, _, _ = rej_abc(task=task, num_samples=10_000, num_observation=1, num_simulations=100_000)
```

I get the following error:

```
task = <sbibm.tasks.two_moons.task.TwoMoons object at 0x7ff456f40f10>, num_samples = 50, num_simulations = 500, num_observation = 1
observation = tensor([[-0.6397,  0.1623]]), num_top_samples = 100, quantile = 0.2, eps = None, distance = 'l2', batch_size = 1000, save_distances = False
kde_bandwidth = 'cv', sass = False, sass_fraction = 0.5, sass_feature_expansion_degree = 3, lra = False

    def run(
        task: Task,
        num_samples: int,
        num_simulations: int,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
        num_top_samples: Optional[int] = 100,
        quantile: Optional[float] = None,
        eps: Optional[float] = None,
        distance: str = "l2",
        batch_size: int = 1000,
        save_distances: bool = False,
        kde_bandwidth: Optional[str] = "cv",
        sass: bool = False,
        sass_fraction: float = 0.5,
        sass_feature_expansion_degree: int = 3,
        lra: bool = False,
    ) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
        """Runs REJ-ABC from `sbi`
    
        Choose one of `num_top_samples`, `quantile`, `eps`.
    
        Args:
            task: Task instance
            num_samples: Number of samples to generate from posterior
            num_simulations: Simulation budget
            num_observation: Observation number to load, alternative to `observation`
            observation: Observation, alternative to `num_observation`
            num_top_samples: If given, will use `top=True` with num_top_samples
            quantile: Quantile to use
            eps: Epsilon threshold to use
            distance: Distance to use
            batch_size: Batch size for simulator
            save_distances: If True, stores distances of samples to disk
            kde_bandwidth: If not None, will resample using KDE when necessary, set
                e.g. to "cv" for cross-validated bandwidth selection
            sass: If True, summary statistics are learned as in
                Fearnhead & Prangle 2012.
            sass_fraction: Fraction of simulation budget to use for sass.
            sass_feature_expansion_degree: Degree of polynomial expansion of the summary
                statistics.
            lra: If True, posterior samples are adjusted with
                linear regression as in Beaumont et al. 2002.
        Returns:
            Samples from posterior, number of simulator calls, log probability of true params if computable
        """
        assert not (num_observation is None and observation is None)
        assert not (num_observation is not None and observation is not None)
    
        assert not (num_top_samples is None and quantile is None and eps is None)
    
        log = sbibm.get_logger(__name__)
        log.info(f"Running REJ-ABC")
    
        prior = task.get_prior_dist()
        simulator = task.get_simulator(max_calls=num_simulations)
        if observation is None:
            observation = task.get_observation(num_observation)
    
        if num_top_samples is not None and quantile is None:
            if sass:
                quantile = num_top_samples / (
                    num_simulations - int(sass_fraction * num_simulations)
                )
            else:
                quantile = num_top_samples / num_simulations
    
        inference_method = MCABC(
            simulator=simulator,
            prior=prior,
            simulation_batch_size=batch_size,
            distance=distance,
            show_progress_bars=True,
        )
>       posterior, distances = inference_method(
            x_o=observation,
            num_simulations=num_simulations,
            eps=eps,
            quantile=quantile,
            return_distances=True,
            lra=lra,
            sass=sass,
            sass_expansion_degree=sass_feature_expansion_degree,
            sass_fraction=sass_fraction,
        )
E       TypeError: __call__() got an unexpected keyword argument 'return_distances'
```
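The `quantile = 0.2` in the failing call comes from the branch in `run` that converts `num_top_samples` into a quantile when no quantile is given. The sketch below re-implements only that argument check and conversion in plain Python, so it runs without `sbi` or `sbibm`. The function name `top_samples_to_quantile` is hypothetical and not part of either library.

```python
from typing import Optional


def top_samples_to_quantile(
    num_simulations: int,
    num_top_samples: Optional[int] = 100,
    quantile: Optional[float] = None,
    eps: Optional[float] = None,
    sass: bool = False,
    sass_fraction: float = 0.5,
) -> Optional[float]:
    """Hypothetical helper mirroring the quantile logic of sbibm's rej_abc `run`."""
    # At least one acceptance criterion must be given.
    if num_top_samples is None and quantile is None and eps is None:
        raise ValueError("Choose one of num_top_samples, quantile, eps.")

    if num_top_samples is not None and quantile is None:
        if sass:
            # With SASS, a fraction of the budget is set aside for learning
            # summary statistics, so the quantile is taken over the rest.
            remaining = num_simulations - int(sass_fraction * num_simulations)
            quantile = num_top_samples / remaining
        else:
            quantile = num_top_samples / num_simulations
    return quantile


# Values from the failing call: 100 / 500
print(top_samples_to_quantile(num_simulations=500))  # 0.2
# With sass=True, 250 simulations remain: 100 / 250
print(top_samples_to_quantile(num_simulations=500, sass=True))  # 0.4
```

An explicitly passed `quantile` is returned unchanged, and `eps` alone leaves the quantile as `None`.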
jan-matthis commented

Thanks for reporting this!

@janfb: It seems the APIs diverged with this sbi commit: sbi-dev/sbi@9ed18ca?
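A divergence like this can be confirmed without reading the sbi diff by inspecting the signature of the object being called. The sketch below uses only the standard-library `inspect` module. `OldABC` and `NewABC` are hypothetical stand-ins with abbreviated parameter lists, not the actual sbi signatures before and after the commit.

```python
import inspect
from typing import Callable


def accepts_kwarg(fn: Callable, name: str) -> bool:
    """Return True if `fn` can be called with the keyword argument `name`."""
    for p in inspect.signature(fn).parameters.values():
        # A **kwargs parameter accepts any keyword.
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        # Positional-only parameters cannot be passed by keyword.
        if p.name == name and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


# Hypothetical stand-ins for the call signature before and after an API change.
class OldABC:
    def __call__(self, x_o, num_simulations, eps=None, quantile=None, return_distances=False):
        ...


class NewABC:
    def __call__(self, x_o, num_simulations, eps=None, quantile=None):
        ...


print(accepts_kwarg(OldABC.__call__, "return_distances"))  # True
print(accepts_kwarg(NewABC.__call__, "return_distances"))  # False
```

With sbi installed, the same check could be pointed at `MCABC.__call__`. Silently dropping the unsupported keyword would not be a fix here, because `run` also unpacks two return values (`posterior, distances`) that depend on it.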

jan-matthis changed the title from "alignment with SBI API" to "Alignment with SBI ABC API" on Nov 3, 2021
psteinb commented Nov 3, 2021

I am therefore wondering whether setup.py should pin a specific sbi version.
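In setup.py, pinning would typically mean a constraint such as `"sbi>=A,<B"` in `install_requires`. This thread does not establish which sbi releases are compatible, so `A` and `B` remain placeholders. As a complement, a runtime guard could fail early with a clear message instead of a `TypeError` deep inside `run`. The sketch below uses only the standard library (`importlib.metadata`, Python 3.8+). `check_sbi_version`, `parse_release`, `in_range`, and the bounds are hypothetical names and placeholder values.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple

# Placeholder bounds: the compatible sbi range is NOT established in this issue.
MIN_SBI = "0.0.0"
MAX_SBI_EXCLUSIVE = "999.0.0"


def parse_release(v: str) -> Tuple[int, ...]:
    """Parse the leading numeric release of a version, e.g. '0.17.2' -> (0, 17, 2).

    Simplified: suffixes such as 'rc1' or '.dev0' are ignored, and trailing
    zeros are dropped so that '1.0' and '1.0.0' compare equal.
    """
    m = re.match(r"\d+(?:\.\d+)*", v.strip())
    if m is None:
        raise ValueError(f"cannot parse version {v!r}")
    parts = [int(p) for p in m.group(0).split(".")]
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def in_range(installed: str, minimum: str, maximum_exclusive: str) -> bool:
    """True if minimum <= installed < maximum_exclusive (release parts only)."""
    return parse_release(minimum) <= parse_release(installed) < parse_release(maximum_exclusive)


def check_sbi_version(minimum: str = MIN_SBI, maximum_exclusive: str = MAX_SBI_EXCLUSIVE) -> str:
    """Raise a descriptive ImportError if sbi is missing or outside the range."""
    try:
        installed = version("sbi")
    except PackageNotFoundError as err:
        raise ImportError("sbi is not installed") from err
    if not in_range(installed, minimum, maximum_exclusive):
        raise ImportError(
            f"sbibm expects sbi>={minimum},<{maximum_exclusive}, found {installed}"
        )
    return installed
```

Where the third-party `packaging` library is available, `packaging.version.Version` handles pre-releases and other PEP 440 details more robustly than this simplified parser.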
