
Add broadcast_shapes() function and use it in MultivariateNormal #43935

Closed
wants to merge 13 commits
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -447,6 +447,7 @@ Other Operations
bincount
block_diag
broadcast_tensors
broadcast_shapes
bucketize
cartesian_prod
cdist
17 changes: 17 additions & 0 deletions test/test_view_ops.py
@@ -1021,6 +1021,23 @@ def test_broadcast_tensors(self, device, dtype):
        self.assertTrue(y1.size() == expected_size)
        self.assertTrue(y2.size() == expected_size)


    @onlyCPU
    def test_broadcast_shapes(self, device):
        examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
        for s0 in examples:
            x0 = torch.randn(s0)
            expected = torch.broadcast_tensors(x0)[0].shape
            actual = torch.broadcast_shapes(s0)
            self.assertEqual(expected, actual)

            for s1 in examples:
                x1 = torch.randn(s1)
                expected = torch.broadcast_tensors(x0, x1)[0].shape
                actual = torch.broadcast_shapes(s0, s1)
                self.assertEqual(expected, actual)


    def test_view(self, device):
        tensor = torch.rand(15, device=device)
        template = torch.rand(3, 5, device=device)
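For reference, one concrete instance of the equivalence the nested loops above check, runnable on a build that includes this PR (the shapes here are our own illustration):

    import torch

    # broadcast_shapes on plain shape tuples agrees with broadcast_tensors
    # on actual tensors of those shapes:
    x0, x1 = torch.randn(3, 1), torch.randn(4, 1, 2)
    assert torch.broadcast_shapes((3, 1), (4, 1, 2)) == torch.broadcast_tensors(x0, x1)[0].shape
    print(torch.broadcast_shapes((3, 1), (4, 1, 2)))  # torch.Size([4, 3, 2])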
14 changes: 8 additions & 6 deletions torch/distributions/multivariate_normal.py
@@ -122,25 +122,27 @@ def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
        if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
            raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")

-        loc_ = loc.unsqueeze(-1)  # temporarily add dim on right
        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError("scale_tril matrix must be at least two-dimensional, "
                                 "with optional leading batch dimensions")
-            self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_)
+            batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
+            self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError("covariance_matrix must be at least two-dimensional, "
                                 "with optional leading batch dimensions")
-            self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_)
+            batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1])
+            self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
        else:
            if precision_matrix.dim() < 2:
                raise ValueError("precision_matrix must be at least two-dimensional, "
                                 "with optional leading batch dimensions")
-            self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_)
-        self.loc = loc_[..., 0]  # drop rightmost dim
+            batch_shape = torch.broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1])
+            self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
+        self.loc = loc.expand(batch_shape + (-1,))

-        batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
+        event_shape = self.loc.shape[-1:]
        super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)

        if scale_tril is not None:
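The effect of the change above is that the batch dimensions of loc and of the covariance factor are now broadcast directly at the shape level, rather than by unsqueezing loc and broadcasting it against the full matrix. A small sketch of the resulting shapes (the shape values are our own illustration, not from the PR):

    import torch
    from torch.distributions import MultivariateNormal

    loc = torch.zeros(4, 1, 3)           # batch dims (4, 1), event dim 3
    cov = torch.eye(3).expand(5, 3, 3)   # batch dims (5,), event dims (3, 3)
    d = MultivariateNormal(loc, covariance_matrix=cov)
    print(d.batch_shape, d.event_shape)  # torch.Size([4, 5]) torch.Size([3])

Here batch_shape is torch.broadcast_shapes((4, 1), (5,)), i.e. torch.Size([4, 5]), and both loc and covariance_matrix are expanded to that batch shape as zero-copy views.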
34 changes: 34 additions & 0 deletions torch/functional.py
@@ -19,6 +19,7 @@
    'atleast_2d',
    'atleast_3d',
    'align_tensors',
    'broadcast_shapes',
    'broadcast_tensors',
    'cartesian_prod',
    'block_diag',
@@ -72,6 +73,39 @@ def broadcast_tensors(*tensors):
    return _VF.broadcast_tensors(tensors)  # type: ignore


def broadcast_shapes(*shapes):
    r"""broadcast_shapes(*shapes) -> Size

    Similar to :func:`broadcast_tensors` but for shapes.

    This is equivalent to
    ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
    but avoids the need to create intermediate tensors. This is useful for
    broadcasting tensors of common batch shape but different rightmost shape,
    e.g. to broadcast mean vectors with covariance matrices.

    Example::

        >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
        torch.Size([1, 3, 2])

    Args:
        \*shapes (torch.Size): Shapes of tensors.

    Returns:
        shape (torch.Size): A shape compatible with all input shapes.

    Raises:
        RuntimeError: If shapes are incompatible.
    """
    # TODO Move this to C++ once the jit has better support for torch.Size.
    with torch.no_grad():
        scalar = torch.zeros((), device="cpu")
        tensors = [scalar.expand(shape) for shape in shapes]
        tensors = broadcast_tensors(*tensors)
        return tensors[0].shape
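A brief usage sketch of the documented behavior, including the error path (the example shapes are ours):

    import torch

    # Agrees with the documented broadcast_tensors equivalence:
    shapes = [(2,), (3, 1), (1, 1, 1)]
    assert torch.broadcast_shapes(*shapes) == \
        torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape

    # Incompatible shapes raise RuntimeError, as with broadcast_tensors:
    try:
        torch.broadcast_shapes((3,), (2,))
    except RuntimeError as err:
        print("RuntimeError:", err)

Because the implementation expands a zero-dimensional scalar, the intermediate tensors are zero-stride views, so no memory proportional to the input shapes is allocated.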


def split(tensor, split_size_or_sections, dim=0):
    r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.

1 change: 1 addition & 0 deletions torch/overrides.py
@@ -119,6 +119,7 @@ def get_ignored_functions() -> Set[Callable]:
        torch.as_strided,
        torch.bartlett_window,
        torch.blackman_window,
        torch.broadcast_shapes,
        torch.can_cast,
        torch.cudnn_affine_grid_generator,
        torch.cudnn_batch_norm,