From d48181c5c909190c8dd2e1feee6a3da04149e1d9 Mon Sep 17 00:00:00 2001
From: Wessel Bruinsma
Date: Thu, 22 Sep 2022 10:35:03 +0200
Subject: [PATCH] Fix broadcasting of `randgamma`

---
 lab/jax/random.py        |  4 +++-
 lab/numpy/random.py      |  5 +++--
 lab/shaping.py           | 21 +++++++++++++++++++++
 lab/tensorflow/random.py | 11 +++++------
 lab/torch/random.py      |  7 ++++---
 lab/util.py              | 15 +++++++++++++++
 tests/test_random.py     | 11 +++++++++++
 7 files changed, 62 insertions(+), 12 deletions(-)

diff --git a/lab/jax/random.py b/lab/jax/random.py
index 4c47b5a..5df5cf5 100644
--- a/lab/jax/random.py
+++ b/lab/jax/random.py
@@ -4,6 +4,7 @@
 
 from . import dispatch, B, Numeric
 from ..types import Int, JAXDType, JAXNumeric, JAXRandomState
+from ..util import broadcast_shapes
 
 __all__ = []
 
@@ -138,8 +139,9 @@ def randgamma(
     scale: Numeric,
 ):
     state, key = jax.random.split(state)
+    shape = shape + broadcast_shapes(B.shape(alpha), B.shape(scale))
     sample = B.to_active_device(jax.random.gamma(key, alpha, shape, dtype=dtype))
-    sample = B.multiply(sample, B.to_active_device(B.cast(dtype, scale)))
+    sample = sample * B.to_active_device(B.cast(dtype, scale))
     return state, sample
 
 
diff --git a/lab/numpy/random.py b/lab/numpy/random.py
index d1a7868..052bd4d 100644
--- a/lab/numpy/random.py
+++ b/lab/numpy/random.py
@@ -4,8 +4,8 @@
 from plum import Union
 
 from . import dispatch, B, Numeric
-from ..shape import unwrap_dimension
 from ..types import NPDType, NPRandomState, Int
+from ..util import broadcast_shapes
 
 __all__ = []
 
@@ -107,7 +107,8 @@ def randgamma(
     scale: Numeric,
 ):
     _warn_dtype(dtype)
-    return state, B.cast(dtype, state.gamma(alpha, scale=scale, size=shape))
+    shape = shape + broadcast_shapes(B.shape(alpha), B.shape(scale))
+    return state, B.cast(dtype, state.gamma(alpha, size=shape) * scale)
 
 
 @dispatch
diff --git a/lab/shaping.py b/lab/shaping.py
index cb80175..69f5ace 100644
--- a/lab/shaping.py
+++ b/lab/shaping.py
@@ -34,6 +34,7 @@
     "concat",
     "concat2d",
     "tile",
+    "repeat",
     "take",
     "submatrix",
 ]
@@ -490,6 +491,26 @@ def tile(a: Numeric, *repeats: Int):  # pragma: no cover
     """
 
 
+@dispatch
+def repeat(x, *repeats: Int):
+    """Repeat a tensor a number of times, adding new dimensions to the beginning.
+
+    Args:
+        x (tensor): Tensor to repeat.
+        *repeats (int): Repetitions per dimension.
+
+    Returns:
+        x: Repeated tensor.
+    """
+    if repeats == ():
+        return x
+    return tile(
+        expand_dims(x, axis=0, times=len(repeats)),
+        *repeats,
+        *((1,) * rank(x)),
+    )
+
+
 @dispatch
 def take(a: Numeric, indices_or_mask, axis: Int = 0):
     """Take particular elements along an axis.
diff --git a/lab/tensorflow/random.py b/lab/tensorflow/random.py
index bd695f5..bfa7d85 100644
--- a/lab/tensorflow/random.py
+++ b/lab/tensorflow/random.py
@@ -1,12 +1,11 @@
 import logging
 
 import tensorflow as tf
-from plum import Union
 
 from . import dispatch, B, Numeric
-from ..types import TFDType, TFNumeric, Int, TFRandomState
-from ..util import compress_batch
 from ..random import _randcat_last_first
+from ..types import TFDType, TFNumeric, Int, TFRandomState
+from ..util import compress_batch, broadcast_shapes
 
 __all__ = []
 
@@ -123,13 +122,13 @@ def randgamma(
     alpha: Numeric,
     scale: Numeric,
 ):
-    return state, tf.random.stateless_gamma(
-        shape,
+    sample = tf.random.stateless_gamma(
+        shape + broadcast_shapes(B.shape(alpha), B.shape(scale)),
         alpha=alpha,
-        beta=B.divide(1, scale),
         seed=state.make_seeds()[:, 0],
         dtype=dtype,
     )
+    return state, sample * B.to_active_device(B.cast(dtype, scale))
 
 
 @dispatch
diff --git a/lab/torch/random.py b/lab/torch/random.py
index 798267a..4a64bd1 100644
--- a/lab/torch/random.py
+++ b/lab/torch/random.py
@@ -1,10 +1,9 @@
 import torch
-from plum import Union
 
 from . import dispatch, B, Numeric
+from ..random import _randcat_last_first
 from ..types import TorchNumeric, TorchDType, Int, TorchRandomState
 from ..util import compress_batch
-from ..random import _randcat_last_first
 
 __all__ = []
 
@@ -137,7 +136,9 @@ def randgamma(
 ):
     alpha = B.to_active_device(B.cast(dtype, alpha))
     scale = B.to_active_device(B.cast(dtype, scale))
-    alpha = B.broadcast_to(alpha, *shape)
+    alpha, scale = torch.broadcast_tensors(alpha, scale)
+    alpha = B.repeat(alpha, *shape)
+    scale = B.repeat(scale, *shape)
     return state, torch._standard_gamma(alpha, generator=state) * scale
 
 
 @dispatch
diff --git a/lab/util.py b/lab/util.py
index 5c0e779..fdb2b63 100644
--- a/lab/util.py
+++ b/lab/util.py
@@ -13,6 +13,7 @@
     "batch_computation",
     "abstract",
     "compress_batch",
+    "broadcast_shapes",
 ]
 
 _dispatch = plum.Dispatcher()
@@ -241,3 +242,17 @@ def uncompress(y):
         return B.reshape(y, *shape[:-n], *B.shape(y)[1:])
 
     return B.reshape(x, -1, *shape[-n:]), uncompress
+
+
+@_dispatch
+def broadcast_shapes(*shapes):
+    """Broadcast shapes.
+
+    Args:
+        *shapes (shape): Shapes to broadcast.
+
+    Return:
+        tuple[int]: Broadcasted shape.
+    """
+    shapes = [tuple(int(d) for d in shape) for shape in shapes]
+    return np.broadcast_shapes(*shapes)
diff --git a/tests/test_random.py b/tests/test_random.py
index 0b8a01b..3a60204 100644
--- a/tests/test_random.py
+++ b/tests/test_random.py
@@ -144,6 +144,17 @@ def test_randgamma_parameters(t, check_lazy_shapes):
     approx(B.randgamma(t, alpha=1, scale=0), 0, atol=1e-6)
 
 
+@pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32])
+def test_randgamma_broadcasting(t, check_lazy_shapes):
+    assert B.shape(B.randgamma(t, alpha=1, scale=0)) == ()
+    assert B.shape(B.randgamma(t, alpha=B.rand(5), scale=0)) == (5,)
+    assert B.shape(B.randgamma(t, alpha=B.rand(5), scale=B.rand(5))) == (5,)
+    assert B.shape(B.randgamma(t, alpha=1, scale=B.rand(5))) == (5,)
+    assert B.shape(B.randgamma(t, 3, alpha=B.rand(5), scale=0)) == (3, 5)
+    assert B.shape(B.randgamma(t, 3, alpha=B.rand(5), scale=B.rand(5))) == (3, 5)
+    assert B.shape(B.randgamma(t, 3, alpha=1, scale=B.rand(5))) == (3, 5)
+
+
 @pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32])
 def test_randbeta_parameters(t, check_lazy_shapes):
     approx(B.randbeta(t, alpha=1e-6, beta=1), 0, atol=1e-6)
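
Note (not part of the patch): a minimal usage sketch of the broadcasting behaviour this change introduces, mirroring the new `test_randgamma_broadcasting` cases. It assumes the NumPy backend and that the package is imported as `B` (`import lab as B`).

    import numpy as np

    import lab as B

    # `alpha` and `scale` broadcast against each other, so a vector `alpha` with a
    # scalar `scale` yields a sample of shape (5,).
    assert B.shape(B.randgamma(np.float32, alpha=B.rand(5), scale=1)) == (5,)

    # Leading `shape` arguments are prepended to the broadcast parameter shape,
    # giving a batch of 3 draws for each of the 5 parameter settings.
    assert B.shape(B.randgamma(np.float32, 3, alpha=B.rand(5), scale=B.rand(5))) == (3, 5)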