Stray singleton dimension in mcabc.py? #12

Closed
atiyo opened this issue May 28, 2021 · 7 comments · Fixed by #13

@atiyo
Contributor

atiyo commented May 28, 2021

Thanks for building out and maintaining this package! There was definitely a need for something like this in the ABC/Likelihood Free community.

I'm hitting a seemingly stray dimension in mcabc.py:

import sbibm
from sbibm.algorithms import rej_abc
task = sbibm.get_task("two_moons")
posterior_samples, _, _ = rej_abc(task=task, num_samples=10_000, num_observation=1, num_simulations=100_000)

which is returning a stacktrace like:

ValueError                                Traceback (most recent call last)
<ipython-input-128-10fe8b131cec> in <module>
      1 from sbibm.algorithms import rej_abc
      2 task = sbibm.get_task("two_moons")
----> 3 posterior_samples, _, _ = rej_abc(task=task, num_samples=10_000, num_observation=1, num_simulations=100_000)

~/.pyenv/versions/miniforge3-4.9.2/lib/python3.8/site-packages/sbibm/algorithms/sbi/mcabc.py in run(task, num_samples, num_simulations, num_observation, observation, num_top_samples, quantile, eps, distance, batch_size, save_distances, kde_bandwidth, sass, sass_fraction, sass_feature_expansion_degree, lra)
    118     if num_observation is not None:
    119         true_parameters = task.get_true_parameters(num_observation=num_observation)
--> 120         log_prob_true_parameters = posterior.log_prob(true_parameters)
    121         return samples, simulator.num_simulations, log_prob_true_parameters
    122     else:

~/.pyenv/versions/miniforge3-4.9.2/lib/python3.8/site-packages/pyro/distributions/empirical.py in log_prob(self, value)
     94         if self._validate_args:
     95             if value.shape != self.batch_shape + self.event_shape:
---> 96                 raise ValueError("``value.shape`` must be {}".format(self.batch_shape + self.event_shape))
     97         if self.batch_shape:
     98             value = value.unsqueeze(self._aggregation_dim)

ValueError: ``value.shape`` must be torch.Size([2])

A bit of digging shows that true_parameters here has shape [1, 2]. Changing this line to log_prob_true_parameters = posterior.log_prob(true_parameters.squeeze()) does indeed make it run.
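For reference, a minimal torch-only sketch of the shape change described above. The zero tensors are placeholders assuming the batch-of-one layout [1, D] reported here, not sbibm's actual values. One design point worth noting: a bare `squeeze()` removes every size-1 dimension, so for a task with a one-dimensional parameter it would also strip the event dimension, whereas `squeeze(0)` only drops the leading batch dimension.

```python
import torch

# Placeholder for what task.get_true_parameters(...) returns in this issue:
# a batch of one parameter vector, shape [1, D]. Here D = 2, as for two_moons.
true_parameters = torch.zeros(1, 2)

# For D = 2 both forms give the shape the Empirical posterior expects.
print(true_parameters.squeeze().shape)   # torch.Size([2])
print(true_parameters.squeeze(0).shape)  # torch.Size([2])

# With a one-dimensional parameter space (D = 1), a bare squeeze() also
# removes the event dimension, while squeeze(0) keeps it.
one_dim = torch.zeros(1, 1)
print(one_dim.squeeze().shape)   # torch.Size([])
print(one_dim.squeeze(0).shape)  # torch.Size([1])
```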

However, I'm not sure whether the correct fix is to squeeze the tensor further upstream instead.

Thanks for any help!

@jan-matthis
Contributor

Thanks for reporting this! Which version of PyTorch and Pyro are you on?

I suspect this might have something to do with upstream changes; we will have a look.

@atiyo
Contributor Author

atiyo commented May 28, 2021

I'm on 1.8.0 for PyTorch and 1.6.0 for Pyro.

Thank you for looking into this!

@janfb
Contributor

janfb commented May 31, 2021

Thanks again for reporting this, @atiyo. It is indeed a bug: true_parameters always comes as a batch, so we need to squeeze it before passing it to log_prob.
We will fix it here and in smcabc, unless you want to contribute a PR?

@jan-matthis
Contributor

jan-matthis commented May 31, 2021

It seems the error is due to a change in PyTorch: with torch==1.7.1 the example runs fine, with torch==1.8.0 it fails. Unfortunately, the list of changes in the 1.8 release is long (https://github.com/pytorch/pytorch/releases/tag/v1.8.0), and I am not yet sure which upstream change causes this. It would be great to pinpoint it exactly, since I am not sure whether it might cause problems elsewhere. In the meantime, fixing the error with the proposed solution would be great.

Update: @janfb and I drilled down further. It turns out this is due to a change that enables validation of distributions (including the empirical one used here), which raises an error before any automatic squeezing can happen. We could disable validation globally with pyro.enable_validation(False), which makes the code work with torch==1.8.0. However, this might silence useful error messages, so adding the squeeze in a PR seems the best way forward.

@atiyo
Contributor Author

atiyo commented Jun 2, 2021

I'd be happy to contribute an attempted fix.

Is it only mcabc.py and smcabc.py that need changing? E.g. is this line in SNPE ok?

Or perhaps this discussion is better suited for a PR anyway.

@jan-matthis
Contributor

Great, thanks a lot!

> Is it only mcabc.py and smcabc.py that need changing? E.g. is this line in SNPE ok?

Yes, SNPE should be OK as it is, since its posterior distribution is part of sbi. The empirical distribution that (s)mcabc.py uses, on the other hand, is part of Pyro and is affected because it inherits from torch.distributions.

@jan-matthis
Contributor

I've merged the PR and released a new version including it as v1.0.6. Thank you!
