Skip to content

Commit

Permalink
Fixes current density estimator bug (#1155)
Browse files Browse the repository at this point in the history
* fixes current bug!

* Added tests
  • Loading branch information
manuelgloeckler committed May 7, 2024
1 parent 9a8c7c0 commit fe55b1c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
8 changes: 2 additions & 6 deletions sbi/neural_nets/density_estimators/nflows_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,8 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
num_samples = torch.Size(sample_shape).numel()

samples = self.net.sample(num_samples, context=condition)

return samples.reshape((
*sample_shape,
condition_batch_dim,
-1,
))
samples = samples.transpose(0, 1)
return samples.reshape((*sample_shape, condition_batch_dim, *self.input_shape))

def sample_and_log_prob(
self, sample_shape: torch.Size, condition: Tensor, **kwargs
Expand Down
69 changes: 69 additions & 0 deletions tests/density_estimator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,75 @@ def test_correctness_of_density_estimator_log_prob(
assert torch.allclose(log_probs[0, :], log_probs[1, :])


@pytest.mark.parametrize(
"density_estimator_build_fn",
(
build_mdn,
build_maf,
build_maf_rqs,
build_nsf,
build_zuko_bpf,
build_zuko_gf,
build_zuko_maf,
build_zuko_naf,
build_zuko_ncsf,
build_zuko_nice,
build_zuko_nsf,
build_zuko_sospf,
build_zuko_unaf,
build_categoricalmassestimator,
build_mnle,
),
)
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
@pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
def test_correctness_of_batched_vs_seperate_sample_and_log_prob(
density_estimator_build_fn, input_event_shape, condition_event_shape
):
input_sample_dim = 2
batch_dim = 2
density_estimator, inputs, condition = _build_density_estimator_and_tensors(
density_estimator_build_fn,
input_event_shape,
condition_event_shape,
batch_dim,
input_sample_dim,
)
# Batched vs separate sampling
samples = density_estimator.sample((1000,), condition=condition)
samples_separate1 = density_estimator.sample(
(1000,), condition=condition[0][None, ...]
)
samples_separate2 = density_estimator.sample(
(1000,), condition=condition[1][None, ...]
)

# Check if means are approx. same
samples_m = torch.mean(samples, dim=0, dtype=torch.float32)
samples_separate1_m = torch.mean(samples_separate1, dim=0, dtype=torch.float32)
samples_separate2_m = torch.mean(samples_separate2, dim=0, dtype=torch.float32)
samples_sep_m = torch.cat([samples_separate1_m, samples_separate2_m], dim=0)

assert torch.allclose(
samples_m, samples_sep_m, atol=0.5, rtol=0.5
), "Batched sampling is not consistent with separate sampling."

# Batched vs separate log_prob
log_probs = density_estimator.log_prob(inputs, condition=condition)

log_probs_separate1 = density_estimator.log_prob(
inputs[:, :1], condition=condition[0][None, ...]
)
log_probs_separate2 = density_estimator.log_prob(
inputs[:, 1:], condition=condition[1][None, ...]
)
log_probs_sep = torch.hstack([log_probs_separate1, log_probs_separate2])

assert torch.allclose(
log_probs, log_probs_sep, atol=1e-2, rtol=1e-2
), "Batched log_prob is not consistent with separate log_prob."


def _build_density_estimator_and_tensors(
density_estimator_build_fn: str,
input_event_shape: Tuple[int],
Expand Down

0 comments on commit fe55b1c

Please sign in to comment.