Skip to content

Commit

Permalink
refactor sbi tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Nov 8, 2022
1 parent d5981df commit 327844a
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 173 deletions.
28 changes: 0 additions & 28 deletions tests/algorithms/sbi/test_abc_run_methods.py

This file was deleted.

45 changes: 45 additions & 0 deletions tests/algorithms/sbi/test_sbi_run_methods.py
@@ -0,0 +1,45 @@
import pytest

import sbibm
from sbibm.algorithms.sbi import mcabc, smcabc, snle, snpe, snre


# a fast test
@pytest.mark.parametrize(
"run_method", (mcabc, smcabc, snle, snpe, snre)
)
@pytest.mark.parametrize("task_name", ("gaussian_linear",))
@pytest.mark.parametrize("num_observation", (1, 3))
def test_sbi_api(
run_method, task_name, num_observation, num_simulations=2_000, num_samples=100
):
task = sbibm.get_task(task_name)

if run_method in (mcabc, smcabc): # abc algorithms
kwargs = dict()
else: # neural algorithms
kwargs = dict(
num_rounds=1,
max_num_epochs=2,
neural_net="mlp" if run_method == snre else "mdn",
)
if run_method in (snle, snre):
kwargs["mcmc_parameters"] = dict(
num_chains=100, warmup_steps=10, thin=1, init_strategy="proposal"
)

predicted, _, _ = run_method(
task=task,
num_observation=num_observation,
num_simulations=num_simulations,
num_samples=num_samples,
**kwargs,
)

reference_samples = task.get_reference_posterior_samples(
num_observation=num_observation
)

expected = reference_samples[:num_samples, :]

assert expected.shape == predicted.shape
50 changes: 0 additions & 50 deletions tests/algorithms/sbi/test_snle_posterior.py

This file was deleted.

48 changes: 0 additions & 48 deletions tests/algorithms/sbi/test_snpe_posterior.py

This file was deleted.

47 changes: 0 additions & 47 deletions tests/algorithms/sbi/test_snre_posterior.py

This file was deleted.

0 comments on commit 327844a

Please sign in to comment.