
Inconsistent behavior in torch.distributions log_prob for float input in Uniform distribution #22970

Closed
heidekrueger opened this issue Jul 17, 2019 · 7 comments
Labels
good first issue module: distributions Related to torch.distributions small We think this is a small issue to fix. Consider knocking off high priority small issues triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@heidekrueger

🐛 Bug

Most (all?) torch.distributions' cdf and log_prob methods work for inputs that are either torch.Tensors or native Python floats. However, torch.distributions.uniform.Uniform.log_prob() fails for Python float inputs.

To Reproduce

Steps to reproduce the behavior:

import torch

f = 1.               # native Python float input
t = torch.tensor(f)  # torch input

prior = torch.distributions.Exponential(0.5)
prior.cdf(t), prior.cdf(f), prior.log_prob(t), prior.log_prob(f)
>>> (tensor(0.3935), tensor(0.3935), tensor(-1.1931), tensor(-1.1931))

prior = torch.distributions.Normal(5,2)
prior.cdf(t), prior.cdf(f), prior.log_prob(t), prior.log_prob(f)
>>> (tensor(0.0228), tensor(0.0228), tensor(-3.6121), tensor(-3.6121))

prior = torch.distributions.Uniform(0,2)
prior.cdf(t)
prior.cdf(f)
prior.log_prob(t)
prior.log_prob(f)

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-56-e6a2d5d8fcb4> in <module>
      3 prior.cdf(f)
      4 prior.log_prob(t)
----> 5 prior.log_prob(f)

/opt/anaconda/anaconda3/envs/bnelearn/lib/python3.7/site-packages/torch/distributions/uniform.py in log_prob(self, value)
     71         if self._validate_args:
     72             self._validate_sample(value)
---> 73         lb = value.ge(self.low).type_as(self.low)
     74         ub = value.lt(self.high).type_as(self.low)
     75         return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)

AttributeError: 'float' object has no attribute 'ge'

Expected behavior

Either one of the following:

  • cdf, icdf and log_prob fail consistently for native Python inputs across all distributions, or
  • torch.distributions.Uniform.log_prob returns the correct log-probability when given native Python inputs (see the sketch below).
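
For reference, a minimal sketch of what the second option would look like; the expected value follows from the Uniform(0, 2) density, where log(1/2) ≈ -0.6931:

import torch

prior = torch.distributions.Uniform(0, 2)
prior.log_prob(torch.tensor(1.))  # tensor(-0.6931), i.e. log(1/2)
prior.log_prob(1.)                # should match the line above instead of raising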

Environment

Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti
GPU 4: GeForce RTX 2080 Ti
GPU 5: GeForce RTX 2080 Ti
GPU 6: GeForce RTX 2080 Ti
GPU 7: GeForce RTX 2080 Ti

Nvidia driver version: 430.26
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.1.0
[pip] torchvision==0.3.0
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.1.0 py3.7_cuda10.0.130_cudnn7.5.1_0 pytorch
[conda] torchvision 0.3.0 py37_cu10.0.130_1 pytorch

@vishwakftw vishwakftw added the module: distributions Related to torch.distributions label Jul 17, 2019
@fmassa fmassa added good first issue small We think this is a small issue to fix. Consider knocking off high priority small issues triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jul 17, 2019
@fmassa
Member

fmassa commented Jul 17, 2019

This should be an easy fix. Those lines should be made to work with Python floats as well: replacing ge etc. with >= would be a first step, or wrapping value in a tensor if it's a Python float.

lb = value.ge(self.low).type_as(self.low)
ub = value.lt(self.high).type_as(self.low)
return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
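
For illustration, a minimal sketch of the second suggestion (wrapping value in a tensor); the torch.as_tensor call and the dtype/device handling are my assumptions, not necessarily how the eventual fix will look:

import torch

def log_prob(self, value):
    if self._validate_args:
        self._validate_sample(value)
    # Assumption: self.low is already a tensor (Uniform broadcasts its
    # parameters at construction). Coerce a Python float to match it;
    # an actual tensor passes through torch.as_tensor unchanged.
    value = torch.as_tensor(value, dtype=self.low.dtype, device=self.low.device)
    lb = value.ge(self.low).type_as(self.low)
    ub = value.lt(self.high).type_as(self.low)
    return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)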

@rrkarim

rrkarim commented Jul 17, 2019

@fmassa, there is some ambiguity with self._validate_args. What is the point of including it in instance methods (the head of log_prob, for example)? Maybe it would be better to remove it and check for the right arguments explicitly?

@vishwakftw
Contributor

vishwakftw commented Jul 18, 2019

@rrkarim self._validate_args causes the samples passed to the instance methods to be checked. When you are sure that the samples are of the right size and within the support of the distribution, you might as well turn off this checking, which could consequently give you some performance gains.

You can take a look at the pull request in which this was added: #5358.
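
For example (a hypothetical snippet, not from this issue), passing validate_args=False at construction skips the _validate_sample check in log_prob:

import torch

# With validation on, an out-of-support value raises a ValueError;
# with it off, log_prob skips the check and just evaluates the formula.
d = torch.distributions.Exponential(0.5, validate_args=False)
d.log_prob(torch.tensor(-1.))  # no support check is performed here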

@vbsinha
Contributor

vbsinha commented Jul 18, 2019

@fmassa I would like to take up this issue. I'm working on it.

@rrkarim

rrkarim commented Jul 18, 2019

are of the right size and within the support of the distribution

@vishwakftw then we could just write all the checks explicitly in the instance methods. OK, I see, the performance gain can be a reason. Still, it's a pretty odd design: sizes can be checked with generic methods, but the per-distribution support checks need a better design. (Maybe the quick fix is enough.)

@vishwakftw
Contributor

vishwakftw commented Jul 18, 2019

I can't find any case where a user needs to pass wrong arguments.

I think it wouldn't be appropriate to assume this.

then we can just explicitly write all the checks in the instance methods

Well, it would be too verbose and against the DRY principle.

This is the actual discussion thread regarding this addition: #5248.

cc: @fritzo @neerajprad

@rrkarim

rrkarim commented Jul 18, 2019

DRY is fine for the checks in instance methods; performance is another matter. Everything is pretty much discussed in that thread (including all the tf.distributions stuff), so I have nothing to add. (Still pretty odd for my taste.)
