Skip to content

Commit

Permalink
Refactor to depend on sbi 0.20.0 (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Feb 18, 2023
1 parent 5cfd0b3 commit 03db23e
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 229 deletions.
1 change: 1 addition & 0 deletions sbibm/algorithms/sbi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sbibm.algorithms.sbi.mcabc import run as mcabc
from sbibm.algorithms.sbi.sl import run as sl
from sbibm.algorithms.sbi.smcabc import run as smcabc
from sbibm.algorithms.sbi.snle import run as snle
from sbibm.algorithms.sbi.snpe import run as snpe
Expand Down
28 changes: 18 additions & 10 deletions sbibm/algorithms/sbi/sl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
import math
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from sbi import inference as inference
from sbi.inference.posteriors.likelihood_based_posterior import LikelihoodBasedPosterior
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential
from torch import Tensor

from sbibm.algorithms.sbi.utils import (
Expand Down Expand Up @@ -95,7 +95,7 @@ def run(
mcmc_method: str = "slice_np",
mcmc_parameters: Dict[str, Any] = {},
diag_eps: float = 0.0,
) -> (torch.Tensor, int, Optional[torch.Tensor]):
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
"""Runs (S)NLE from `sbi`
Args:
Expand Down Expand Up @@ -127,21 +127,29 @@ def run(
simulator = task.get_simulator()

transforms = task._get_transforms(automatic_transforms_enabled)["parameters"]
prior = wrap_prior_dist(prior, transforms)
simulator = wrap_simulator_fn(simulator, transforms)
if automatic_transforms_enabled:
prior = wrap_prior_dist(prior, transforms)
simulator = wrap_simulator_fn(simulator, transforms)

likelihood_estimator = SynthLikNet(
simulator=simulator,
num_simulations_per_step=num_simulations_per_step,
diag_eps=diag_eps,
)

posterior = LikelihoodBasedPosterior(
method_family="snle",
neural_net=likelihood_estimator,
potential_fn, theta_transform = likelihood_estimator_based_potential(
likelihood_estimator=likelihood_estimator,
prior=prior,
x_o=None,
enable_transform=not automatic_transforms_enabled,
)
posterior = MCMCPosterior(
potential_fn=potential_fn,
proposal=prior,
theta_transform=theta_transform,
method=mcmc_method,
x_shape=observation.shape,
mcmc_parameters=mcmc_parameters,
**mcmc_parameters,
)

posterior.set_default_x(observation)
Expand Down
20 changes: 12 additions & 8 deletions sbibm/algorithms/sbi/snle.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ def run(
mcmc_parameters: Dict[str, Any] = {
"num_chains": 100,
"thin": 10,
"warmup_steps": 100,
"warmup_steps": 25,
"init_strategy": "sir",
"sir_batch_size": 1000,
"sir_num_batches": 100,
# NOTE: sir kwargs changed: num_candidate_samples = num_batches * batch_size
"init_strategy_parameters": {
"num_candidate_samples": 10000,
},
},
z_score_x: bool = True,
z_score_theta: bool = True,
max_num_epochs: Optional[int] = None,
max_num_epochs: Optional[int] = 2**31 - 1,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
"""Runs (S)NLE from `sbi`
Expand Down Expand Up @@ -106,8 +108,6 @@ def run(

posteriors = []
proposal = prior
mcmc_parameters["warmup_steps"] = 25
mcmc_parameters["enable_transform"] = False # NOTE: Disable `sbi` auto-transforms, since `sbibm` does its own

for r in range(num_rounds):
theta, x = inference.simulate_for_sbi(
Expand All @@ -121,15 +121,19 @@ def run(
theta, x, from_round=r
).train(
training_batch_size=training_batch_size,
retrain_from_scratch_each_round=False,
retrain_from_scratch=False,
discard_prior_samples=False,
show_train_summary=True,
max_num_epochs=max_num_epochs,
)
if r > 1:
mcmc_parameters["init_strategy"] = "latest_sample"

posterior = inference_method.build_posterior(
density_estimator, mcmc_method=mcmc_method, mcmc_parameters=mcmc_parameters
density_estimator=density_estimator,
sample_with="mcmc",
mcmc_method=mcmc_method,
mcmc_parameters=mcmc_parameters,
)
# Copy hyperparameters, e.g., mcmc_init_samples for "latest_sample" strategy.
if r > 0:
Expand Down
10 changes: 4 additions & 6 deletions sbibm/algorithms/sbi/snpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def run(
automatic_transforms_enabled: bool = False,
z_score_x: bool = True,
z_score_theta: bool = True,
max_num_epochs: Optional[int] = None,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
max_num_epochs: Optional[int] = 2**31 - 1,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
"""Runs (S)NPE from `sbi`
Args:
Expand Down Expand Up @@ -109,15 +109,13 @@ def run(
).train(
num_atoms=num_atoms,
training_batch_size=training_batch_size,
retrain_from_scratch_each_round=False,
retrain_from_scratch=False,
discard_prior_samples=False,
use_combined_loss=False,
show_train_summary=True,
max_num_epochs=max_num_epochs,
)
posterior = inference_method.build_posterior(
density_estimator, sample_with_mcmc=False
)
posterior = inference_method.build_posterior(density_estimator)
proposal = posterior.set_default_x(observation)
posteriors.append(posterior)

Expand Down
18 changes: 10 additions & 8 deletions sbibm/algorithms/sbi/snre.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ def run(
mcmc_parameters: Dict[str, Any] = {
"num_chains": 100,
"thin": 10,
"warmup_steps": 100,
"warmup_steps": 25,
"init_strategy": "sir",
"sir_batch_size": 1000,
"sir_num_batches": 100,
# NOTE: sir kwargs changed: num_candidate_samples = num_batches * batch_size
"init_strategy_parameters": {
"num_candidate_samples": 10000,
},
},
z_score_x: bool = True,
z_score_theta: bool = True,
variant: str = "B",
max_num_epochs: Optional[int] = None,
max_num_epochs: Optional[int] = 2**31 - 1,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
"""Runs (S)NRE from `sbi`
Expand Down Expand Up @@ -116,8 +118,6 @@ def run(

posteriors = []
proposal = prior
mcmc_parameters["warmup_steps"] = 25
mcmc_parameters["enable_transform"] = False # NOTE: Disable `sbi` auto-transforms, since `sbibm` does its own

for r in range(num_rounds):
theta, x = inference.simulate_for_sbi(
Expand All @@ -131,7 +131,7 @@ def run(
theta, x, from_round=r
).train(
training_batch_size=training_batch_size,
retrain_from_scratch_each_round=False,
retrain_from_scratch=False,
discard_prior_samples=False,
show_train_summary=True,
max_num_epochs=max_num_epochs,
Expand All @@ -140,7 +140,9 @@ def run(
if r > 1:
mcmc_parameters["init_strategy"] = "latest_sample"
posterior = inference_method.build_posterior(
density_estimator, mcmc_method=mcmc_method, mcmc_parameters=mcmc_parameters
density_estimator,
mcmc_method=mcmc_method,
mcmc_parameters=mcmc_parameters,
)
# Copy hyperparameters, e.g., mcmc_init_samples for "latest_sample" strategy.
if r > 0:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"pandas>=1.0.0",
"pyabc>=0.10.8",
"pyabcranger>=0.0.48",
"sbi==0.17.2",
"sbi>=0.20.0,<0.22.0",
"pyro-ppl",
"scikit-learn",
"torch>=1.8.0",
Expand Down
28 changes: 0 additions & 28 deletions tests/algorithms/sbi/test_abc_run_methods.py

This file was deleted.

53 changes: 53 additions & 0 deletions tests/algorithms/sbi/test_sbi_run_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

import sbibm
from sbibm.algorithms.sbi import mcabc, smcabc, snle, snpe, snre, sl


# a fast test
@pytest.mark.parametrize(
"run_method",
(
mcabc,
smcabc,
snle,
snpe,
snre,
sl,
),
)
@pytest.mark.parametrize("task_name", ("gaussian_linear",))
@pytest.mark.parametrize("num_observation", (1, 3))
def test_sbi_api(
run_method, task_name, num_observation, num_simulations=2_000, num_samples=100
):
task = sbibm.get_task(task_name)

if run_method in (mcabc, smcabc, sl): # abc algorithms
kwargs = dict()
else: # neural algorithms
kwargs = dict(
num_rounds=1,
max_num_epochs=2,
neural_net="mlp" if run_method == snre else "mdn",
)
if run_method in (snle, snre):
kwargs["mcmc_parameters"] = dict(
num_chains=100, warmup_steps=10, thin=1, init_strategy="proposal"
)

predicted, _, _ = run_method(
task=task,
num_observation=num_observation,
num_simulations=num_simulations,
num_samples=num_samples,
**kwargs,
)

reference_samples = task.get_reference_posterior_samples(
num_observation=num_observation
)

expected = reference_samples[:num_samples, :]

assert expected.shape == predicted.shape
26 changes: 0 additions & 26 deletions tests/algorithms/sbi/test_sl_run_methods.py

This file was deleted.

47 changes: 0 additions & 47 deletions tests/algorithms/sbi/test_snle_posterior.py

This file was deleted.

0 comments on commit 03db23e

Please sign in to comment.