Use unified type for distributions.constraint API #50616
Labels
module: distributions
Related to torch.distributions
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Comments
How about adding entirely new names:
class Real: ...
class IntegerInterval: ...
# add these for backwards compatibility
real = Real()  # an instance
integer_interval = IntegerInterval  # a class
# maybe add these for one release
_Real = Real
_IntegerInterval = IntegerInterval
This has the advantage of allowing the old and new interfaces to coexist for at least one release.
ngimel added the triaged label Jan 18, 2021
facebook-github-bot pushed a commit to facebookresearch/beanmachine that referenced this issue Jan 21, 2021
Summary: Pull Request resolved: #542

This adds an equality testing utility for `torch.constraint` objects/classes, so that these comparisons can be consolidated within a single utility function `is_constraint_eq`.

Usage:
```
is_constraint_eq(dist.support, (constraints.real, constraints.greater_than))
```
instead of:
```
isinstance(dist.support, (constraints._Real, constraints._GreaterThan))
```
or the more obfuscatory
```
dist.support is constraints.real or isinstance(dist.support, constraints.greater_than)
```
See pytorch/pytorch#50616 for more details (note that the changes suggested in the issue are complementary).
- This avoids usage of the non-public constraint classes (like `constraints._Real`, `constraints._Interval`).
- Makes it possible to consolidate future changes (e.g. those arising out of the introduction of an `Independent` constraint - pytorch/pytorch#50547) within a single function.

This is a pre-requisite to some other fixes that are currently blocking D25918330. I will add these small fixes when I merge D25918330.

Reviewed By: feynmanliang, jpchen
Differential Revision: D25935106
fbshipit-source-id: 76296464f51e01a9cac4506fb927c531145212ce
🚀 Feature
Currently, the API of `torch.distributions.constraints` exposes a mix of singleton `Constraint` instances and `Constraint` classes. The ask is to uniformly expose the more flexible class in the public API, e.g. instead of exposing `real = _Real()`, we should expose `real = _Real`.
Motivation
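The motivation for this is that the current API makes checking for constraint type awkward, e.g. to check whether a constraint is an interval or a real constraint, we need to write `constraint is constraints.real or isinstance(constraint, constraints.interval)`, which requires figuring out which constraint is a type and which is an actual instance, instead of the simpler `isinstance(constraint, (constraints.real, constraints.interval))`.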
Note that this will be a breaking change and we may need to go through a round of deprecation to support this.
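The awkwardness described in this section can be reproduced with a small sketch. The classes below are toy stand-ins assumed for illustration (they only mirror the instance/class mix, not the real torch implementation), and the function names `is_real_or_interval` and `naive_check` are hypothetical.

```python
class Constraint:
    """Toy stand-in for a constraint base class (not the real torch class)."""


class _Real(Constraint):
    pass


class _Interval(Constraint):
    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound


# Mixed public API: one name is an instance, the other is a class.
real = _Real()        # singleton instance
interval = _Interval  # class


def is_real_or_interval(constraint):
    # `real` needs an identity check, `interval` needs isinstance,
    # so the caller must know which public name is which.
    return constraint is real or isinstance(constraint, interval)


def naive_check(constraint):
    # Raises TypeError: isinstance() accepts only classes in its
    # second argument, but `real` is an instance.
    return isinstance(constraint, (real, interval))
```

If both public names were classes, the single `isinstance` call in `naive_check` would be valid, which is the uniformity this issue asks for.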
Alternative
Alternatively, we can instead do `isinstance` checks on the non-public classes, e.g. `isinstance(constraint, (constraints._Real, constraints._Interval))`, but I am not sure whether we have strong reasons not to expose these classes publicly.
Another option is to provide a backward-compatible utility function.
cc @fritzo @neerajprad @alicanb @vishwakftw @nikitaved @fehiepsi