Skip to content

Commit

Permalink
Update dev branch to use PyTorch 1.8 prerelease (#2753)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Feb 22, 2021
1 parent f82841a commit 1f02392
Show file tree
Hide file tree
Showing 53 changed files with 399 additions and 310 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ run_outputs*
data
.data
results
*.csv
examples/*/processed
examples/*/results
examples/*/raw
Expand Down
9 changes: 8 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,14 @@ install:
- pip install -U pip
# Keep track of pyro-api master branch
- pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
- pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Download PyTorch
- pip install numpy
# TODO replace with torch_stable before release
# - pip install torch==1.8.0+cpu torchvision==0.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# TODO replace with torch_test once torchvision binaries are released
# - pip install torch torchvision -f https://download.pytorch.org/whl/test/cpu/torch_test.html
# This is the last nightly release of 1.8.0 before splitting to 1.9.0.
- pip install --pre torch==1.8.0.dev20210210+cpu torchvision==0.9.0.dev20210210+cpu -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- pip install .[test]
- pip install coveralls
- pip freeze
Expand Down
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,26 @@ Make sure that the models come from the same release version of the [Pyro source
### Installing Pyro dev branch

For recent features you can install Pyro from source.
Pyro's dev branch requires PyTorch [nightly builds](https://pytorch.org/get-started/locally/).

**Install using pip:**
**Install PyTorch nightly:**

```sh
pip install git+https://github.com/pyro-ppl/pyro.git
pip install numpy
pip install --pre torch==1.8.0.dev20210210 torchvision==0.9.0.dev20210210 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
```

**Install Pyro using pip:**

```sh
pip install git+https://github.com/pyro-ppl/pyro.git
```
or, with the `extras` dependency to run the probabilistic models included in the `examples`/`tutorials` directories:
```sh
pip install git+https://github.com/pyro-ppl/pyro.git#egg=project[extras]
```

**Install from source:**
**Install Pyro from source:**

```sh
git clone https://github.com/pyro-ppl/pyro
Expand Down
14 changes: 12 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,5 +213,15 @@ def setup(app):
# @jpchen's hack to get rtd builder to install latest pytorch
# See similar line in the install section of .travis.yml
if 'READTHEDOCS' in os.environ:
os.system('pip install torch==1.7.0+cpu torchvision==0.8.1+cpu '
'-f https://download.pytorch.org/whl/torch_stable.html')
os.system('pip install numpy')
# TODO replace with torch_stable before release
# os.system('pip install torch==1.8.0+cpu torchvision==0.9.0+cpu '
# '-f https://download.pytorch.org/whl/torch_stable.html')
# TODO replace with torch_test once torchvision binaries are released
# os.system('pip install torch torchvision '
# '-f https://download.pytorch.org/whl/test/cpu/torch_test.html')
# This is the last nightly release of 1.8.0 before splitting to 1.9.0.
os.system('pip install --pre '
'torch==1.8.0.dev20210210+cpu '
'torchvision==0.9.0.dev20210210+cpu '
'-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html')
2 changes: 1 addition & 1 deletion pyro/contrib/gp/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def _is_real_support(support):
if isinstance(support, pyro.distributions.constraints.IndependentConstraint):
if isinstance(support, pyro.distributions.constraints.independent):
return _is_real_support(support.base_constraint)
else:
return support in [constraints.real, constraints.real_vector]
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/affine_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.distributions import constraints
from torch.distributions.transforms import AffineTransform

from pyro.distributions.torch import Beta, TransformedDistribution
from .torch import Beta, TransformedDistribution


class AffineBeta(TransformedDistribution):
Expand Down
25 changes: 15 additions & 10 deletions pyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import numbers

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

from pyro.distributions.torch import Beta, Binomial, Dirichlet, Gamma, Multinomial, Poisson
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.ops.special import log_beta, log_binomial

from . import constraints
from .torch import Beta, Binomial, Dirichlet, Gamma, Multinomial, Poisson
from .torch_distribution import TorchDistribution
from .util import broadcast_shape


def _log_beta_1(alpha, value, is_sparse):
if is_sparse:
Expand Down Expand Up @@ -124,20 +126,23 @@ class DirichletMultinomial(TorchDistribution):
:param bool is_sparse: Whether to assume value is mostly zero when computing
:meth:`log_prob`, which can speed up computation when data is sparse.
"""
arg_constraints = {'concentration': constraints.positive, 'total_count': constraints.nonnegative_integer}
arg_constraints = {'concentration': constraints.independent(constraints.positive, 1),
'total_count': constraints.nonnegative_integer}
support = Multinomial.support

def __init__(self, concentration, total_count=1, is_sparse=False, validate_args=None):
batch_shape = concentration.shape[:-1]
event_shape = concentration.shape[-1:]
if isinstance(total_count, numbers.Number):
total_count = torch.tensor(total_count, dtype=concentration.dtype, device=concentration.device)
total_count_1 = total_count.unsqueeze(-1)
concentration, total_count = torch.broadcast_tensors(concentration, total_count_1)
total_count = total_count_1.squeeze(-1)
total_count = concentration.new_tensor(total_count)
else:
batch_shape = broadcast_shape(batch_shape, total_count.shape)
concentration = concentration.expand(batch_shape + (-1,))
total_count = total_count.expand(batch_shape)
self._dirichlet = Dirichlet(concentration)
self.total_count = total_count
self.is_sparse = is_sparse
super().__init__(
self._dirichlet._batch_shape, self._dirichlet.event_shape, validate_args=validate_args)
super().__init__(batch_shape, event_shape, validate_args=validate_args)

@property
def concentration(self):
Expand Down
47 changes: 7 additions & 40 deletions pyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,16 @@
from torch.distributions.constraints import * # noqa F403
from torch.distributions.constraints import Constraint
from torch.distributions.constraints import __all__ as torch_constraints
from torch.distributions.constraints import lower_cholesky, positive, positive_definite


# TODO move this upstream to torch.distributions
class IndependentConstraint(Constraint):
"""
Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
dims in :meth:`check`, so that an event is valid only if all its
independent entries are valid.
:param torch.distributions.constraints.Constraint base_constraint: A base
constraint whose entries are incidentally independent.
:param int reinterpreted_batch_ndims: The number of extra event dimensions that will
be considered dependent.
"""
def __init__(self, base_constraint, reinterpreted_batch_ndims):
self.base_constraint = base_constraint
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

def check(self, value):
result = self.base_constraint.check(value)
result = result.reshape(result.shape[:result.dim() - self.reinterpreted_batch_ndims] + (-1,))
result = result.min(-1)[0]
return result
from torch.distributions.constraints import independent, positive, positive_definite


# TODO move this upstream to torch.distributions
class _Integer(Constraint):
"""
Constrain to integers.
"""
is_discrete = True

def check(self, value):
return value % 1 == 0

Expand All @@ -47,6 +26,7 @@ class _Sphere(Constraint):
"""
Constrain to the Euclidean sphere of any dimension.
"""
event_dim = 1
reltol = 10. # Relative to finfo.eps.

def check(self, value):
Expand All @@ -59,22 +39,11 @@ def __repr__(self):
return self.__class__.__name__[1:]


class _CorrCholesky(Constraint):
"""
Constrains to lower-triangular square matrices with positive diagonals and
Euclidean norm of each row is 1, such that `torch.mm(omega, omega.t())` will
have unit diagonal.
"""

def check(self, value):
unit_norm_row = (value.norm(dim=-1).sub(1) < 1e-4).min(-1)[0]
return lower_cholesky.check(value) & unit_norm_row


class _CorrMatrix(Constraint):
"""
Constrains to a correlation matrix.
"""
event_dim = 2

def check(self, value):
# check for diagonal equal to 1
Expand All @@ -88,6 +57,7 @@ class _OrderedVector(Constraint):
Constrains to a real-valued tensor where the elements are monotonically
increasing along the `event_shape` dimension.
"""
event_dim = 1

def check(self, value):
if value.ndim == 0:
Expand All @@ -108,19 +78,16 @@ def check(self, value):
return ordered_vector.check(value) & independent(positive, 1).check(value)


corr_cholesky_constraint = _CorrCholesky()
corr_matrix = _CorrMatrix()
independent = IndependentConstraint
integer = _Integer()
ordered_vector = _OrderedVector()
positive_ordered_vector = _PositiveOrderedVector()
sphere = _Sphere()
corr_cholesky_constraint = corr_cholesky # noqa: F405 DEPRECATED

__all__ = [
'IndependentConstraint',
'corr_cholesky_constraint',
'corr_matrix',
'independent',
'integer',
'ordered_vector',
'positive_ordered_vector',
Expand Down
31 changes: 19 additions & 12 deletions pyro/distributions/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints

from pyro.distributions.torch import Categorical, Gamma, Independent, MultivariateNormal
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape, torch_jit_script_if_tracing
from pyro.ops.gamma_gaussian import (GammaGaussian, gamma_and_mvn_to_gamma_gaussian, gamma_gaussian_tensordot,
matrix_and_mvn_to_gamma_gaussian)
from pyro.ops.gaussian import Gaussian, gaussian_tensordot, matrix_and_mvn_to_gaussian, mvn_to_gaussian
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import cholesky, cholesky_solve

from . import constraints
from .torch import Categorical, Gamma, Independent, MultivariateNormal
from .torch_distribution import TorchDistribution
from .util import broadcast_shape, torch_jit_script_if_tracing


@torch_jit_script_if_tracing
def _linear_integrate(init, trans, shift):
Expand Down Expand Up @@ -309,9 +310,9 @@ def __init__(self, initial_logits, transition_logits, observation_dist,
self.observation_dist = observation_dist
super().__init__(duration, batch_shape, event_shape, validate_args=validate_args)

@property
@constraints.dependent_property(event_dim=2)
def support(self):
return self.observation_dist.support
return constraints.independent(self.observation_dist.support, 1)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(DiscreteHMM, _instance)
Expand Down Expand Up @@ -436,7 +437,7 @@ class GaussianHMM(HiddenMarkovModel):
"""
has_rsample = True
arg_constraints = {}
support = constraints.real
support = constraints.independent(constraints.real, 2)

def __init__(self, initial_dist, transition_matrix, transition_dist,
observation_matrix, observation_dist, validate_args=None, duration=None):
Expand Down Expand Up @@ -717,7 +718,7 @@ class GammaGaussianHMM(HiddenMarkovModel):
are not expanded along the time axis.
"""
arg_constraints = {}
support = constraints.real
support = constraints.independent(constraints.real, 2)

def __init__(self, scale_dist, initial_dist, transition_matrix, transition_dist,
observation_matrix, observation_dist, validate_args=None, duration=None):
Expand Down Expand Up @@ -886,7 +887,7 @@ class LinearHMM(HiddenMarkovModel):
are not expanded along the time axis.
"""
arg_constraints = {}
support = constraints.real
support = constraints.independent(constraints.real, 2)
has_rsample = True

def __init__(self, initial_dist, transition_matrix, transition_dist,
Expand Down Expand Up @@ -953,9 +954,9 @@ def __init__(self, initial_dist, transition_matrix, transition_dist,
self.observation_dist = observation_dist
self.transforms = transforms

@property
@constraints.dependent_property(event_dim=2)
def support(self):
return self.observation_dist.support
return constraints.independent(self.observation_dist.support, 1)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LinearHMM, _instance)
Expand Down Expand Up @@ -1017,7 +1018,7 @@ def __init__(self, base_dist):
super().__init__(batch_shape, event_shape)
self.base_dist = base_dist

@constraints.dependent_property
@constraints.dependent_property(event_dim=2)
def support(self):
return self.base_dist.support

Expand Down Expand Up @@ -1107,6 +1108,11 @@ def __init__(self, initial_dist, transition_dist, observation_dist, validate_arg
self._init = mvn_to_gaussian(initial_dist)
self._trans = mvn_to_gaussian(transition_dist)
self._obs = mvn_to_gaussian(observation_dist)
self._support = constraints.independent(observation_dist.support, 1)

@constraints.dependent_property(event_dim=2)
def support(self):
return self._support

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(GaussianMRF, _instance)
Expand All @@ -1119,6 +1125,7 @@ def expand(self, batch_shape, _instance=None):
new._init = self._init.expand(batch_shape)
new._trans = self._trans
new._obs = self._obs
new._support = self._support
super(GaussianMRF, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self.__dict__.get('_validate_args')
return new
Expand Down
15 changes: 1 addition & 14 deletions pyro/distributions/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

import math

from torch.distributions import TransformedDistribution, kl_divergence, register_kl
from torch.distributions import Independent, MultivariateNormal, Normal, kl_divergence, register_kl

from pyro.distributions.delta import Delta
from pyro.distributions.distribution import Distribution
from pyro.distributions.torch import Independent, MultivariateNormal, Normal
from pyro.distributions.util import sum_rightmost


Expand Down Expand Up @@ -46,16 +45,4 @@ def _kl_independent_mvn(p, q):
raise NotImplementedError


# TODO: move upstream
@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
if p.transforms != q.transforms:
raise NotImplementedError
if p.event_shape != q.event_shape:
raise NotImplementedError
extra_event_dim = len(p.base_dist.batch_shape) - len(p.batch_shape)
base_kl_divergence = kl_divergence(p.base_dist, q.base_dist)
return sum_rightmost(base_kl_divergence, extra_event_dim)


__all__ = []

0 comments on commit 1f02392

Please sign in to comment.