Fix broadcasting of randgamma
wesselb committed Sep 22, 2022
1 parent f5e092c commit d48181c
Showing 7 changed files with 62 additions and 12 deletions.
lab/jax/random.py (4 changes: 3 additions & 1 deletion)
@@ -4,6 +4,7 @@

from . import dispatch, B, Numeric
from ..types import Int, JAXDType, JAXNumeric, JAXRandomState
from ..util import broadcast_shapes

__all__ = []

@@ -138,8 +139,9 @@ def randgamma(
scale: Numeric,
):
state, key = jax.random.split(state)
shape = shape + broadcast_shapes(B.shape(alpha), B.shape(scale))
sample = B.to_active_device(jax.random.gamma(key, alpha, shape, dtype=dtype))
sample = B.multiply(sample, B.to_active_device(B.cast(dtype, scale)))
sample = sample * B.to_active_device(B.cast(dtype, scale))
return state, sample
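
Note on the JAX change above: the sample is drawn at unit scale over the broadcast of the requested shape and the parameter shapes, and scale is applied afterwards, using the fact that if X ~ Gamma(alpha, 1) then scale * X ~ Gamma(alpha, scale). A minimal NumPy-only sketch of the same shape logic (illustrative, not part of the commit):

    import numpy as np

    alpha = np.random.rand(5)   # element-wise shape parameters, shape (5,)
    scale = 2.0                 # scalar scale, broadcast against alpha
    out_shape = (3,) + np.broadcast_shapes(np.shape(alpha), np.shape(scale))  # (3, 5)
    sample = np.random.standard_gamma(alpha, size=out_shape) * scale          # Gamma(alpha, scale) draws
    assert sample.shape == (3, 5)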


lab/numpy/random.py (5 changes: 3 additions & 2 deletions)
@@ -4,8 +4,8 @@
from plum import Union

from . import dispatch, B, Numeric
from ..shape import unwrap_dimension
from ..types import NPDType, NPRandomState, Int
from ..util import broadcast_shapes

__all__ = []

@@ -107,7 +107,8 @@ def randgamma(
scale: Numeric,
):
_warn_dtype(dtype)
return state, B.cast(dtype, state.gamma(alpha, scale=scale, size=shape))
shape = shape + broadcast_shapes(B.shape(alpha), B.shape(scale))
return state, B.cast(dtype, state.gamma(alpha, size=shape) * scale)


@dispatch
lab/shaping.py (21 changes: 21 additions & 0 deletions)
@@ -34,6 +34,7 @@
"concat",
"concat2d",
"tile",
"repeat",
"take",
"submatrix",
]
@@ -490,6 +491,26 @@ def tile(a: Numeric, *repeats: Int): # pragma: no cover
"""


@dispatch
def repeat(x, *repeats: Int):
"""Repeat a tensor a number of times, adding new dimensions to the beginning.
Args:
x (tensor): Tensor to repeat.
*repeats (int): Repetitions per dimension.
Returns:
x: Repeated tensor.
"""
if repeats == ():
return x
return tile(
expand_dims(x, axis=0, times=len(repeats)),
*repeats,
*((1,) * rank(x)),
)
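
For reference, with no repeats the input is returned unchanged; otherwise len(repeats) new leading dimensions are added and tiled. A small illustration of the intended semantics (assuming the usual B.rand and B.shape helpers):

    x = B.rand(5)
    B.shape(B.repeat(x, 3))     # (3, 5)
    B.shape(B.repeat(x, 2, 3))  # (2, 3, 5)
    B.repeat(x)                 # no repeats: returns x unchanged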


@dispatch
def take(a: Numeric, indices_or_mask, axis: Int = 0):
"""Take particular elements along an axis.
lab/tensorflow/random.py (11 changes: 5 additions & 6 deletions)
@@ -1,12 +1,11 @@
import logging

import tensorflow as tf
from plum import Union

from . import dispatch, B, Numeric
from ..types import TFDType, TFNumeric, Int, TFRandomState
from ..util import compress_batch
from ..random import _randcat_last_first
from ..types import TFDType, TFNumeric, Int, TFRandomState
from ..util import compress_batch, broadcast_shapes

__all__ = []

@@ -123,13 +122,13 @@ def randgamma(
alpha: Numeric,
scale: Numeric,
):
return state, tf.random.stateless_gamma(
shape,
sample = tf.random.stateless_gamma(
shape + broadcast_shapes(B.shape(alpha), B.shape(scale)),
alpha=alpha,
beta=B.divide(1, scale),
seed=state.make_seeds()[:, 0],
dtype=dtype,
)
return state, sample * B.to_active_device(B.cast(dtype, scale))


@dispatch
lab/torch/random.py (7 changes: 4 additions & 3 deletions)
@@ -1,10 +1,9 @@
import torch
from plum import Union

from . import dispatch, B, Numeric
from ..random import _randcat_last_first
from ..types import TorchNumeric, TorchDType, Int, TorchRandomState
from ..util import compress_batch
from ..random import _randcat_last_first

__all__ = []

@@ -137,7 +136,9 @@ def randgamma(
):
alpha = B.to_active_device(B.cast(dtype, alpha))
scale = B.to_active_device(B.cast(dtype, scale))
alpha = B.broadcast_to(alpha, *shape)
alpha, scale = torch.broadcast_tensors(alpha, scale)
alpha = B.repeat(alpha, *shape)
scale = B.repeat(scale, *shape)
return state, torch._standard_gamma(alpha, generator=state) * scale
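
In the PyTorch version, alpha and scale are first broadcast against each other and then repeated via the new B.repeat so that torch._standard_gamma receives a parameter tensor with the full output shape. A plain-torch sketch of the resulting shapes (illustrative only):

    import torch

    alpha = torch.rand(5)                                  # shape (5,)
    scale = torch.tensor(0.5)                              # shape ()
    alpha, scale = torch.broadcast_tensors(alpha, scale)   # both now shape (5,)
    alpha, scale = alpha.repeat(3, 1), scale.repeat(3, 1)  # prepend sample shape: (3, 5)
    sample = torch._standard_gamma(alpha) * scale          # Gamma(alpha, scale) draws, shape (3, 5)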


lab/util.py (15 changes: 15 additions & 0 deletions)
@@ -13,6 +13,7 @@
"batch_computation",
"abstract",
"compress_batch",
"broadcast_shapes",
]

_dispatch = plum.Dispatcher()
@@ -241,3 +242,17 @@ def uncompress(y):
return B.reshape(y, *shape[:-n], *B.shape(y)[1:])

return B.reshape(x, -1, *shape[-n:]), uncompress


@_dispatch
def broadcast_shapes(*shapes):
"""Broadcast shapes.
Args:
*shapes (shape): Shapes to broadcast.
Return:
tuple[int]: Broadcasted shape.
"""
shapes = [tuple(int(d) for d in shape) for shape in shapes]
return np.broadcast_shapes(*shapes)
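
The helper coerces every dimension to a plain int (so wrapped or lazy dimensions are accepted) and defers to np.broadcast_shapes. For example:

    broadcast_shapes((5,), ())        # (5,)
    broadcast_shapes((3, 1), (1, 4))  # (3, 4)
    broadcast_shapes((2, 5), (5,))    # (2, 5)
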
tests/test_random.py (11 changes: 11 additions & 0 deletions)
@@ -144,6 +144,17 @@ def test_randgamma_parameters(t, check_lazy_shapes):
approx(B.randgamma(t, alpha=1, scale=0), 0, atol=1e-6)


@pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32])
def test_randgamma_broadcasting(t, check_lazy_shapes):
assert B.shape(B.randgamma(t, alpha=1, scale=0)) == ()
assert B.shape(B.randgamma(t, alpha=B.rand(5), scale=0)) == (5,)
assert B.shape(B.randgamma(t, alpha=B.rand(5), scale=B.rand(5))) == (5,)
assert B.shape(B.randgamma(t, alpha=1, scale=B.rand(5))) == (5,)
assert B.shape(B.randgamma(t, 3, alpha=B.rand(5), scale=0)) == (3, 5)
assert B.shape(B.randgamma(t, 3, alpha=B.rand(5), scale=B.rand(5))) == (3, 5)
assert B.shape(B.randgamma(t, 3, alpha=1, scale=B.rand(5))) == (3, 5)


@pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32])
def test_randbeta_parameters(t, check_lazy_shapes):
approx(B.randbeta(t, alpha=1e-6, beta=1), 0, atol=1e-6)
