
Fix typing errors in the torch.distributions module #45689

Closed · wants to merge 1 commit

Conversation

@xuzhao9 (Contributor) commented Oct 1, 2020

Fixes #42979.

@xuzhao9 xuzhao9 requested a review from malfet October 1, 2020 19:23
@xuzhao9 xuzhao9 self-assigned this Oct 1, 2020
@xuzhao9 xuzhao9 added module: typing Related to mypy type annotations module: distributions Related to torch.distributions labels Oct 1, 2020
mypy.ini Outdated
Comment on lines 93 to 96
[mypy-torch.jit.quantized]
ignore_errors = True

[mypy-torch.nn.functional]
Contributor:

Are these actually intended to be added? Or is this an artifact of merging master?

Contributor Author:

They were added because, after annotating the distributions.* directory, a lot of typing errors appeared under these two modules. It may be due to a master update; I am looking into it.

@@ -71,7 +72,7 @@ def event_shape(self):
"""
return self._event_shape

@property
@property # type: ignore[no-redef]
Contributor:

Do we have a link to a mypy GitHub issue? (Is it python/mypy#6185?)

Contributor Author:

> do we have a mypy git issue link? (is it: python/mypy#6185?)

No, it is not. This happens simply because the name `arg_constraints` is reused as both a member variable and a method name. I couldn't find a related mypy issue. It could be resolved by renaming one of them, but that would introduce too much code churn, IMO.
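The clash described above can be reproduced in a few lines. This is a hypothetical minimal sketch (the class names and the constraint value are illustrative, not the actual torch.distributions code):

```python
# Minimal sketch of the name clash: the base class exposes the name as
# a property, while subclasses assign a plain class attribute with the
# same name. mypy flags this as a redefinition, hence the
# `# type: ignore[no-redef]` suppressions in this PR.
class Base:
    @property
    def arg_constraints(self):
        # Read-only property on the base class.
        raise NotImplementedError


class Concrete(Base):
    # Plain class attribute shadowing the inherited property.
    arg_constraints = {"probs": "simplex"}


# At runtime the override works fine: attribute lookup finds the
# subclass attribute before the inherited property in the MRO.
print(Concrete().arg_constraints)  # {'probs': 'simplex'}
```

Renaming either the property or the attribute would avoid the suppression, which is the trade-off discussed below.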

Contributor:

It would be nice to add a comment explaining this.

Contributor:

Hmm, if arg_constraints is an internal variable, perhaps an underscore prefix should be added to it? (Which could be done in a separate PR.)

@@ -2,7 +2,7 @@
from torch._six import inf
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
from numbers import Number
import numbers
Contributor:

It seems only Integral was used? Should we just keep the original import form?

Suggested change:
- import numbers
+ from numbers import Integral

Contributor Author:

> seems like only Integral was used? should we just keep original format?

Nice catch. Will update.

@@ -88,7 +89,7 @@ def param_shape(self):

def sample(self, sample_shape=torch.Size()):
sample_shape = torch.Size(sample_shape)
samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
samples = self._categorical.sample(torch.Size((int(self.total_count),).__iter__()) + sample_shape)
Contributor:

I was wondering if we can remove `__iter__()`; it seems like `(int(val),)` should already return a `Tuple[int]`, yes?

Contributor Author:

You are right! I have removed it.
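The point above can be checked without torch at all, since `torch.Size` subclasses `tuple`. A quick sketch with made-up values for `total_count` and `sample_shape`:

```python
# (int(val),) already produces a tuple of int, so the extra
# .__iter__() call was unnecessary. Plain tuples stand in for
# torch.Size here, which is a tuple subclass.
total_count = 10.0      # e.g. a float that must be coerced to int
sample_shape = (3, 2)   # made-up shape for illustration

extended = (int(total_count),) + sample_shape
print(extended)  # (10, 3, 2)
```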

Contributor Author:

That's correct! Will update.

@@ -94,7 +96,7 @@ class lazy_property(object):
"""
def __init__(self, wrapped):
self.wrapped = wrapped
update_wrapper(self, wrapped)
update_wrapper(self, wrapped) # type: ignore[arg-type]
Contributor:

Is there a link explaining the reason?

Contributor Author:

Fixing this error would require annotating a built-in Python package (functools), so I just ignore the error.
More details: `update_wrapper(wrapper: Callable, wrapped)` is a functools function. Here argument 1 is `self`, an instance of `lazy_property`, so the type mismatches.

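A minimal sketch of the pattern being discussed. This is simplified from the real `torch.distributions.utils.lazy_property` (the caching-on-first-access behavior shown here is an assumption about the real implementation):

```python
from functools import update_wrapper


class lazy_property:
    """Descriptor that computes a value once and caches it.

    update_wrapper expects its first argument to be a callable, but it
    receives the descriptor instance (self) here, so mypy reports an
    [arg-type] error; the suppression comment mirrors the one in the PR.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        update_wrapper(self, wrapped)  # type: ignore[arg-type]

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        value = self.wrapped(instance)
        # Cache the result on the instance; being a non-data descriptor,
        # this class is bypassed on subsequent lookups.
        setattr(instance, self.wrapped.__name__, value)
        return value


class Model:
    @lazy_property
    def expensive(self):
        return 42


m = Model()
print(m.expensive)  # 42, computed once and then cached on the instance
```

At runtime `update_wrapper` happily copies `__name__`, `__doc__`, etc. onto the descriptor instance; the complaint is purely about the typeshed signature.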

@xuzhao9 xuzhao9 requested a review from walterddr October 5, 2020 21:18
@walterddr (Contributor) left a comment:

Thanks for the update.

Also, I think you might've rebased incorrectly, given all those third-party changes. Could you use the instructions @janeyx99 just added in #45903 to update your PR?


@@ -1,4 +1,5 @@
import math
import numbers
Contributor:

same here

Suggested change:
- import numbers
+ from numbers import Real

@@ -1,3 +1,4 @@
import numbers
Contributor:

same here

Suggested change:
- import numbers
+ from numbers import Real

@walterddr (Contributor) left a comment:

lgtm

@codecov bot commented Oct 7, 2020

Codecov Report

Merging #45689 into master will increase coverage by 0.01%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master   #45689      +/-   ##
==========================================
+ Coverage   68.25%   68.26%   +0.01%     
==========================================
  Files         410      410              
  Lines       53246    53266      +20     
==========================================
+ Hits        36344    36363      +19     
- Misses      16902    16903       +1     
Impacted Files Coverage Δ
torch/distributions/__init__.py 100.00% <100.00%> (ø)
torch/distributions/beta.py 96.49% <100.00%> (ø)
torch/distributions/distribution.py 90.72% <100.00%> (ø)
torch/distributions/independent.py 91.37% <100.00%> (+0.15%) ⬆️
torch/distributions/kl.py 97.52% <100.00%> (+<0.01%) ⬆️
torch/distributions/mixture_same_family.py 95.74% <100.00%> (+0.04%) ⬆️
torch/distributions/multinomial.py 95.31% <100.00%> (ø)
torch/distributions/normal.py 96.87% <100.00%> (+0.04%) ⬆️
torch/distributions/transformed_distribution.py 97.53% <100.00%> (+0.03%) ⬆️
torch/distributions/transforms.py 94.34% <100.00%> (+0.15%) ⬆️
... and 3 more


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 9679e1a...3928b46.

@@ -155,4 +156,4 @@
'register_kl',
'transform_to',
]
__all__.extend(transforms.__all__)
__all__.extend(transform_all)
Contributor:

Can you please add a comment on why this is needed?

torch/distributions/beta.py (resolved)

@@ -81,7 +83,8 @@ def arg_constraints(self):
"""
raise NotImplementedError

@property
# Ignore the mypy type error caused by redefining `support` as a method
@property # type: ignore[no-redef]
Contributor:

Same as above, perhaps the variable should be renamed to something else.

@@ -500,8 +511,8 @@ def __eq__(self, other):

@property
def sign(self):
if isinstance(self.scale, numbers.Number):
return 1 if self.scale > 0 else -1 if self.scale < 0 else 0
if isinstance(self.scale, numbers.Integral):
Contributor:

Hmm, changing Number to Integral here would lead to a crash if scale is a float, right? (Because there is no sign path for the float type, is there?)
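The reviewer's concern can be checked directly with the standard library: `numbers.Integral` matches ints but not floats, so narrowing the `isinstance` check changes behavior for a float `scale`. A quick sketch:

```python
import numbers

# Integral matches ints (and bools), but not floats;
# Number matches both, which is what the original check relied on.
print(isinstance(2, numbers.Integral))    # True
print(isinstance(2.5, numbers.Integral))  # False
print(isinstance(2.5, numbers.Number))    # True
```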

@xuzhao9 xuzhao9 force-pushed the xzhao9/distributions-typing branch from f1e9efc to 2e718f1 Compare October 8, 2020 20:53
@xuzhao9 xuzhao9 requested a review from malfet October 8, 2020 20:56
@@ -50,7 +51,7 @@ def variance(self):
return self.total_count * self.probs * (1 - self.probs)

def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
if not isinstance(total_count, Number):
if not isinstance(total_count, Integral):
Contributor:

Are you sure that total_count must be an integral value?

@@ -40,6 +40,7 @@ class Multinomial(Distribution):
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real}
total_count: Integral
Contributor:

Why can't this one be just int?

Suggested change:
- total_count: Integral
+ total_count: int
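The suggestion loses nothing at runtime: every Python `int` is registered with the `numbers.Integral` ABC, while the concrete `int` annotation is far better supported by mypy and typeshed. A quick check:

```python
import numbers

# The concrete annotation mypy handles cleanly.
total_count: int = 5

# int still satisfies the Integral ABC at runtime, so isinstance
# checks against numbers.Integral keep working unchanged.
assert isinstance(total_count, numbers.Integral)
print(type(total_count).__name__)  # int
```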

@xuzhao9 xuzhao9 force-pushed the xzhao9/distributions-typing branch 3 times, most recently from f32e6de to 47cb358 Compare October 9, 2020 03:30
@dr-ci bot commented Oct 9, 2020

💊 CI failures summary and remediations

As of commit 3928b46 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@facebook-github-bot (Contributor) left a comment:

@xuzhao9 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@xuzhao9 xuzhao9 force-pushed the xzhao9/distributions-typing branch 2 times, most recently from 6ff4772 to 3928b46 Compare October 9, 2020 22:25

@xuzhao9 xuzhao9 deleted the xzhao9/distributions-typing branch October 12, 2020 17:46
@facebook-github-bot (Contributor):

@xuzhao9 merged this pull request in 146721f.

facebook-github-bot pushed a commit that referenced this pull request Mar 8, 2022
Summary:
Fixes hiding of the `CatTransform` docstring due to a misplaced type annotation in #45689.
Also fixes that annotation and adds it to `StackTransform` to keep it similar to `CatTransform`.

Pull Request resolved: #73747

Reviewed By: mruberry

Differential Revision: D34649927

Pulled By: neerajprad

fbshipit-source-id: e4594fd76020a37ac4cada88e5fb08191984e911
pytorchmergebot and cyyever/pytorch_private pushed commits referencing this pull request with the same change (cherry picked from commit cec256c3242d1cf55073a980060af87c1fd59ac9).
Labels: Merged · module: distributions (Related to torch.distributions) · module: typing (Related to mypy type annotations)
Projects: None yet
Linked issues: Enable torch.distributions typechecks during CI
4 participants