diff --git a/README.md b/README.md index 841db9c..00cb5ad 100644 --- a/README.md +++ b/README.md @@ -325,6 +325,7 @@ leaky_relu(a, alpha) min(a, axis=None, squeeze=True) max(a, axis=None, squeeze=True) sum(a, axis=None, squeeze=True) +prod(a, axis=None, squeeze=True) mean(a, axis=None, squeeze=True) std(a, axis=None, squeeze=True) logsumexp(a, axis=None, squeeze=True) diff --git a/lab/autograd/generic.py b/lab/autograd/generic.py index 270bb48..b2143fb 100644 --- a/lab/autograd/generic.py +++ b/lab/autograd/generic.py @@ -192,6 +192,11 @@ def sum(a: Numeric, axis=None, squeeze=True): return anp.sum(a, axis=axis, keepdims=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + return anp.prod(a, axis=axis, keepdims=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): return anp.mean(a, axis=axis, keepdims=not squeeze) diff --git a/lab/generic.py b/lab/generic.py index 2eeb81d..ef11e3f 100644 --- a/lab/generic.py +++ b/lab/generic.py @@ -74,6 +74,7 @@ "max", "argmax", "sum", + "prod", "nansum", "mean", "nanmean", @@ -971,6 +972,22 @@ def sum(a: Numeric, axis=None, squeeze=True): # pragma: no cover """ +@dispatch +@abstract() +def prod(a: Numeric, axis=None, squeeze=True): # pragma: no cover + """Product of all elements in a tensor, possibly along an axis. + + Args: + a (tensor): Tensor. + axis (int, optional): Optional axis. + squeeze (bool, optional): Squeeze the dimension after the reduction. Defaults + to `True`. + + Returns: + tensor: Reduced tensor. + """ + + @dispatch def nansum(x, **kw_args): """Like :func:`sum`, but ignores `NaN`s.""" diff --git a/lab/jax/generic.py b/lab/jax/generic.py index b08ffa6..499d73b 100644 --- a/lab/jax/generic.py +++ b/lab/jax/generic.py @@ -237,6 +237,11 @@ def sum(a: Numeric, axis=None, squeeze=True): return jnp.sum(a, axis=axis, keepdims=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + return jnp.prod(a, axis=axis, keepdims=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): return jnp.mean(a, axis=axis, keepdims=not squeeze) diff --git a/lab/numpy/generic.py b/lab/numpy/generic.py index 9c5a3a2..7978709 100644 --- a/lab/numpy/generic.py +++ b/lab/numpy/generic.py @@ -216,6 +216,11 @@ def sum(a: Numeric, axis=None, squeeze=True): return np.sum(a, axis=axis, keepdims=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + return np.prod(a, axis=axis, keepdims=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): return np.mean(a, axis=axis, keepdims=not squeeze) diff --git a/lab/tensorflow/generic.py b/lab/tensorflow/generic.py index 7a01540..078ae7f 100644 --- a/lab/tensorflow/generic.py +++ b/lab/tensorflow/generic.py @@ -231,6 +231,11 @@ def sum(a: Numeric, axis=None, squeeze=True): return tf.reduce_sum(a, axis=axis, keepdims=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + return tf.reduce_prod(a, axis=axis, keepdims=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): return tf.reduce_mean(a, axis=axis, keepdims=not squeeze) diff --git a/lab/torch/generic.py b/lab/torch/generic.py index 3b373de..19c7950 100644 --- a/lab/torch/generic.py +++ b/lab/torch/generic.py @@ -238,6 +238,14 @@ def sum(a: Numeric, axis=None, squeeze=True): return torch.sum(a, dim=axis, keepdim=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + if axis is None: + return torch.prod(a) + else: + return torch.prod(a, dim=axis, keepdim=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): if axis is None: diff --git a/tests/test_generic.py b/tests/test_generic.py index b457ad0..d745142 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -356,6 +356,7 @@ def test_binary_positive_first(f, check_lazy_shapes): (B.min, True), (B.max, True), (B.sum, True), + (B.prod, True), (B.nansum, True), (B.mean, True), (B.nanmean, True),