Use unified type for distributions.constraint API #50616

Open
neerajprad opened this issue Jan 15, 2021 · 1 comment
Labels: `module: distributions` (Related to torch.distributions), `triaged` (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


neerajprad commented Jan 15, 2021

🚀 Feature

Currently, the `torch.distributions.constraints` API exposes a mix of singleton `Constraint` instances and `Constraint` classes, e.g.:

```python
>>> type(constraints.real)
torch.distributions.constraints._Real

>>> type(constraints.interval)
type
```

The ask is to uniformly expose the more flexible class in the public API: e.g. instead of exposing `real = _Real()`, we should expose `real = _Real`.

Motivation

The motivation for this is that the current API makes checking a constraint's type awkward. For example, to check whether a constraint is an interval or a real constraint, we need to write:

```python
constraint is constraints.real or isinstance(constraint, constraints.interval)
```

This requires knowing which constraint names are types and which are instances, instead of the simpler:

```python
isinstance(constraint, (constraints.real, constraints.interval))
```

Note that this will be a breaking change and we may need to go through a round of deprecation to support this.

Alternative

Alternatively, we can do `isinstance` checks on the non-public classes, e.g. `isinstance(constraint, (constraints._Real, constraints._Interval))`, but I am not sure whether we have strong reasons not to expose these classes publicly.

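Another option is to provide a backward-compatible utility function:

```python
from torch.distributions import constraints


def _unwrap(constraint):
    # Strip any Independent wrapper, then map instances to their class.
    if isinstance(constraint, constraints.independent):
        return _unwrap(constraint.base_constraint)
    return constraint if isinstance(constraint, type) else constraint.__class__


def constraint_type_eq(constraint1, constraint2):
    # True if both constraints (instances or classes) resolve to the same class.
    return _unwrap(constraint1) == _unwrap(constraint2)
```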

cc @fritzo @neerajprad @alicanb @vishwakftw @nikitaved @fehiepsi

neerajprad added the `module: distributions` label on Jan 15, 2021

fritzo commented Jan 16, 2021

How about adding entirely new names like `Real` while maintaining the old interface for backwards compatibility (at least in a Pyro shim module)? CamelCase classes would also seem more standard:

```python
class Real: ...
class IntegerInterval: ...

# add these for backwards compatibility
real = Real()  # an instance
integer_interval = IntegerInterval  # a class

# maybe add these for one release
_Real = Real
_IntegerInterval = IntegerInterval
```

This has the advantage of allowing the old and new interfaces to coexist for at least one release.

ngimel added the `triaged` label on Jan 18, 2021
facebook-github-bot pushed a commit to facebookresearch/beanmachine that referenced this issue Jan 21, 2021
Summary:
Pull Request resolved: #542

This adds an equality testing utility for `torch.distributions.constraints` objects/classes, so that these comparisons can be consolidated within a single utility function `is_constraint_eq`.

Usage:
```
is_constraint_eq(dist.support, (constraints.real, constraints.greater_than))
```

, instead of:

```
isinstance(dist.support, (constraints._Real, constraints._GreaterThan))
```
, or the more obfuscatory

```
dist.support is constraints.real or isinstance(dist.support, constraints.greater_than)
```

See pytorch/pytorch#50616 for more details (note that the changes suggested in the issue are complementary).

 - This avoids usage of the non-public constraint classes (like `constraints._Real`, `constraints._Interval`).
 - Makes it possible to consolidate future changes (e.g. those arising out of the introduction of an `Independent` constraint - pytorch/pytorch#50547) within a single function.

This is a prerequisite to some other fixes that are currently blocking D25918330. I will add these small fixes when I merge D25918330.

Reviewed By: feynmanliang, jpchen

Differential Revision: D25935106

fbshipit-source-id: 76296464f51e01a9cac4506fb927c531145212ce
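A possible shape for the `is_constraint_eq` helper described in the commit above, written against hypothetical stand-in classes rather than torch. This is a sketch under assumptions, not the beanmachine implementation: it only handles the single-constraint and tuple forms shown in the usage example and does not unwrap `Independent` constraints.

```python
from types import SimpleNamespace


class Constraint:
    """Hypothetical stand-in base class (not the torch class)."""


class _Real(Constraint):
    pass


class _GreaterThan(Constraint):
    def __init__(self, lower_bound):
        self.lower_bound = lower_bound


# Mixed exports mimicking torch.distributions.constraints: an instance and a class.
constraints = SimpleNamespace(real=_Real(), greater_than=_GreaterThan)


def _constraint_class(constraint):
    # Map an instance to its class; leave classes unchanged.
    return constraint if isinstance(constraint, type) else type(constraint)


def is_constraint_eq(constraint, check_constraints):
    # Accept a single constraint or a tuple of them (instances or classes).
    if not isinstance(check_constraints, tuple):
        check_constraints = (check_constraints,)
    target = _constraint_class(constraint)
    return any(target is _constraint_class(c) for c in check_constraints)


# Usage mirroring the commit message.
support = constraints.greater_than(0.0)
print(is_constraint_eq(support, (constraints.real, constraints.greater_than)))
```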