Skip to content

Commit

Permalink
Add try..except around torch imports
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Feb 27, 2024
1 parent 4708b75 commit 9432a0a
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 57 deletions.
92 changes: 49 additions & 43 deletions pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,50 +3,55 @@

import pyro.distributions.torch_patch # noqa F403

# Import both * to get new distributions and explicitly to help mypy.
# Import * to get the latest upstream distributions.
from pyro.distributions.torch import * # noqa F403
from pyro.distributions.torch import (
Bernoulli,
Beta,
Binomial,
Categorical,
Cauchy,
Chi2,
ContinuousBernoulli,
Dirichlet,
ExponentialFamily,
Exponential,
FisherSnedecor,
Gamma,
Geometric,
Gumbel,
HalfCauchy,
HalfNormal,
Independent,
Kumaraswamy,
Laplace,
LKJCholesky,
LogNormal,
LogisticNormal,
LowRankMultivariateNormal,
MixtureSameFamily,
Multinomial,
MultivariateNormal,
NegativeBinomial,
Normal,
OneHotCategorical,
OneHotCategoricalStraightThrough,
Pareto,
Poisson,
RelaxedBernoulli,
RelaxedOneHotCategorical,
StudentT,
TransformedDistribution,
Uniform,
VonMises,
Weibull,
Wishart,
)

# Additionally try to import explicitly to help mypy static analysis.
try:
from pyro.distributions.torch import (
Bernoulli,
Beta,
Binomial,
Categorical,
Cauchy,
Chi2,
ContinuousBernoulli,
Dirichlet,
Exponential,
ExponentialFamily,
FisherSnedecor,
Gamma,
Geometric,
Gumbel,
HalfCauchy,
HalfNormal,
Independent,
Kumaraswamy,
Laplace,
LKJCholesky,
LogisticNormal,
LogNormal,
LowRankMultivariateNormal,
MixtureSameFamily,
Multinomial,
MultivariateNormal,
NegativeBinomial,
Normal,
OneHotCategorical,
OneHotCategoricalStraightThrough,
Pareto,
Poisson,
RelaxedBernoulli,
RelaxedOneHotCategorical,
StudentT,
TransformedDistribution,
Uniform,
VonMises,
Weibull,
Wishart,
)
except ImportError:
pass

# isort: split

Expand Down Expand Up @@ -255,4 +260,5 @@

# Import all torch distributions from `pyro.distributions.torch_distribution`
__all__.extend(torch_dists)
__all__[:] = sorted(set(__all__))
del torch_dists
79 changes: 71 additions & 8 deletions pyro/distributions/constraints.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,50 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

# Import * to get the latest upstream constraints.
from torch.distributions.constraints import * # noqa F403

# Additionally try to import explicitly to help mypy static analysis.
try:
from torch.distributions.constraints import (
Constraint,
boolean,
cat,
corr_cholesky,
dependent,
dependent_property,
greater_than,
greater_than_eq,
half_open_interval,
independent,
integer_interval,
interval,
is_dependent,
less_than,
lower_cholesky,
lower_triangular,
multinomial,
nonnegative,
nonnegative_integer,
one_hot,
positive,
positive_definite,
positive_integer,
positive_semidefinite,
real,
real_vector,
simplex,
square,
stack,
symmetric,
unit_interval,
)
except ImportError:
pass

# isort: split

import torch
from torch.distributions.constraints import (
Constraint,
independent,
lower_cholesky,
positive,
positive_definite,
)
from torch.distributions.constraints import __all__ as torch_constraints


Expand Down Expand Up @@ -129,19 +161,50 @@ def check(self, value):
corr_cholesky_constraint = corr_cholesky # noqa: F405 DEPRECATED

__all__ = [
"Constraint",
"boolean",
"cat",
"corr_cholesky",
"corr_cholesky_constraint",
"corr_matrix",
"dependent",
"dependent_property",
"greater_than",
"greater_than_eq",
"half_open_interval",
"independent",
"integer",
"integer_interval",
"interval",
"is_dependent",
"less_than",
"lower_cholesky",
"lower_triangular",
"multinomial",
"nonnegative",
"nonnegative_integer",
"one_hot",
"ordered_vector",
"positive",
"positive_definite",
"positive_integer",
"positive_ordered_vector",
"positive_semidefinite",
"real",
"real_vector",
"simplex",
"softplus_lower_cholesky",
"softplus_positive",
"sphere",
"square",
"stack",
"symmetric",
"unit_interval",
"unit_lower_cholesky",
]

__all__.extend(torch_constraints)
__all__ = sorted(set(__all__))
__all__[:] = sorted(set(__all__))
del torch_constraints


Expand Down
55 changes: 49 additions & 6 deletions pyro/distributions/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,39 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

# Import * to get the latest upstream transforms.
from torch.distributions.transforms import * # noqa F403

# Additionally try to import explicitly to help mypy static analysis.
try:
from torch.distributions.transforms import (
AbsTransform,
AffineTransform,
CatTransform,
ComposeTransform,
# CorrCholeskyTransform, # Use Pyro's version below.
CumulativeDistributionTransform,
ExpTransform,
IndependentTransform,
LowerCholeskyTransform,
PositiveDefiniteTransform,
PowerTransform,
ReshapeTransform,
SigmoidTransform,
SoftmaxTransform,
# SoftplusTransform, # Use Pyro's version below.
StackTransform,
StickBreakingTransform,
TanhTransform,
Transform,
identity_transform,
)
except ImportError:
pass

# isort: split

from torch.distributions import biject_to, transform_to
from torch.distributions.transforms import (
ComposeTransform,
ExpTransform,
LowerCholeskyTransform,
)
from torch.distributions.transforms import __all__ as torch_transforms

from .. import constraints
Expand Down Expand Up @@ -150,12 +173,15 @@ def iterated(repeats, base_fn, *args, **kwargs):


__all__ = [
"iterated",
"AbsTransform",
"AffineAutoregressive",
"AffineCoupling",
"AffineTransform",
"BatchNorm",
"BlockAutoregressive",
"CatTransform",
"CholeskyTransform",
"ComposeTransform",
"ComposeTransformModule",
"ConditionalAffineAutoregressive",
"ConditionalAffineCoupling",
Expand All @@ -167,31 +193,45 @@ def iterated(repeats, base_fn, *args, **kwargs):
"ConditionalRadial",
"ConditionalSpline",
"ConditionalSplineAutoregressive",
"CorrCholeskyTransform",
"CorrLCholeskyTransform",
"CorrMatrixCholeskyTransform",
"CumulativeDistributionTransform",
"DiscreteCosineTransform",
"ELUTransform",
"ExpTransform",
"GeneralizedChannelPermute",
"HaarTransform",
"Householder",
"IndependentTransform",
"LeakyReLUTransform",
"LowerCholeskyAffine",
"LowerCholeskyTransform",
"MatrixExponential",
"NeuralAutoregressive",
"Normalize",
"OrderedTransform",
"Permute",
"Planar",
"Polynomial",
"PositiveDefiniteTransform",
"PositivePowerTransform",
"PowerTransform",
"Radial",
"ReshapeTransform",
"SigmoidTransform",
"SimplexToOrderedTransform",
"SoftmaxTransform",
"SoftplusLowerCholeskyTransform",
"SoftplusTransform",
"Spline",
"SplineAutoregressive",
"SplineCoupling",
"StackTransform",
"StickBreakingTransform",
"Sylvester",
"TanhTransform",
"Transform",
"affine_autoregressive",
"affine_coupling",
"batchnorm",
Expand All @@ -209,6 +249,8 @@ def iterated(repeats, base_fn, *args, **kwargs):
"elu",
"generalized_channel_permute",
"householder",
"identity_transform",
"iterated",
"leaky_relu",
"matrix_exponential",
"neural_autoregressive",
Expand All @@ -223,4 +265,5 @@ def iterated(repeats, base_fn, *args, **kwargs):
]

__all__.extend(torch_transforms)
__all__[:] = sorted(set(__all__))
del torch_transforms

0 comments on commit 9432a0a

Please sign in to comment.