Skip to content

Commit

Permalink
Support multiple inheritance in torch.distributions (#16772)
Browse files Browse the repository at this point in the history
Summary:
This adds calls to `super().__init__()` in three classes in torch.distributions.

This is needed when `Distribution` and `Transform` objects are used with multiple inheritance, as e.g. combined with `torch.nn.Module`s. For example
```py
class MyModule(torch.distributions.Transform, torch.nn.Module):
    ...
```
cc  martinjankowiak esling who have wanted to use this pattern, e.g. in #16756
Pull Request resolved: #16772

Differential Revision: D13978633

Pulled By: soumith

fbshipit-source-id: 8bc6cca1747cd74d32135ee2fe588bba2ea796f1
  • Loading branch information
fritzo authored and facebook-github-bot committed Feb 7, 2019
1 parent 2681af1 commit 0d366e1
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 0 deletions.
1 change: 1 addition & 0 deletions torch/distributions/constraint_registry.py
Expand Up @@ -82,6 +82,7 @@ class ConstraintRegistry(object):
"""
def __init__(self):
self._registry = {}
super(ConstraintRegistry, self).__init__()

def register(self, constraint, factory=None):
"""
Expand Down
1 change: 1 addition & 0 deletions torch/distributions/distribution.py
Expand Up @@ -34,6 +34,7 @@ def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_
continue # skip checking lazily-constructed args
if not constraint.check(getattr(self, param)).all():
raise ValueError("The parameter {} has invalid values".format(param))
super(Distribution, self).__init__()

def expand(self, batch_shape, _instance=None):
"""
Expand Down
1 change: 1 addition & 0 deletions torch/distributions/transforms.py
Expand Up @@ -83,6 +83,7 @@ def __init__(self, cache_size=0):
self._cached_x_y = None, None
else:
raise ValueError('cache_size must be 0 or 1')
super(Transform, self).__init__()

@property
def inv(self):
Expand Down

0 comments on commit 0d366e1

Please sign in to comment.