From 18252aab81cc72af0c0a623b4a48f60619244393 Mon Sep 17 00:00:00 2001 From: patel-zeel Date: Fri, 25 Jun 2021 12:41:07 +0530 Subject: [PATCH 1/2] Add prod --- README.md | 1 + lab/autograd/generic.py | 5 +++++ lab/generic.py | 16 ++++++++++++++++ lab/jax/generic.py | 5 +++++ lab/numpy/generic.py | 5 +++++ lab/tensorflow/generic.py | 5 +++++ lab/torch/generic.py | 8 ++++++++ tests/test_generic.py | 1 + 8 files changed, 46 insertions(+) diff --git a/README.md b/README.md index 841db9c..00cb5ad 100644 --- a/README.md +++ b/README.md @@ -325,6 +325,7 @@ leaky_relu(a, alpha) min(a, axis=None, squeeze=True) max(a, axis=None, squeeze=True) sum(a, axis=None, squeeze=True) +prod(a, axis=None, squeeze=True) mean(a, axis=None, squeeze=True) std(a, axis=None, squeeze=True) logsumexp(a, axis=None, squeeze=True) diff --git a/lab/autograd/generic.py b/lab/autograd/generic.py index 270bb48..b2143fb 100644 --- a/lab/autograd/generic.py +++ b/lab/autograd/generic.py @@ -192,6 +192,11 @@ def sum(a: Numeric, axis=None, squeeze=True): return anp.sum(a, axis=axis, keepdims=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + return anp.prod(a, axis=axis, keepdims=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): return anp.mean(a, axis=axis, keepdims=not squeeze) diff --git a/lab/generic.py b/lab/generic.py index 2eeb81d..0763c28 100644 --- a/lab/generic.py +++ b/lab/generic.py @@ -74,6 +74,7 @@ "max", "argmax", "sum", + "prod", "nansum", "mean", "nanmean", @@ -971,6 +972,21 @@ def sum(a: Numeric, axis=None, squeeze=True): # pragma: no cover """ +@abstract() +def prod(a: Numeric, axis=None, squeeze=True): # pragma: no cover + """Product of all elements in a tensor, possibly along an axis. + + Args: + a (tensor): Tensor. + axis (int, optional): Optional axis. + squeeze (bool, optional): Squeeze the dimension after the reduction. Defaults + to `True`. + + Returns: + tensor: Reduced tensor. + """ + + @dispatch def nansum(x, **kw_args): """Like :func:`sum`, but ignores `NaN`s.""" diff --git a/lab/jax/generic.py b/lab/jax/generic.py index b08ffa6..499d73b 100644 --- a/lab/jax/generic.py +++ b/lab/jax/generic.py @@ -237,6 +237,11 @@ def sum(a: Numeric, axis=None, squeeze=True): return jnp.sum(a, axis=axis, keepdims=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + return jnp.prod(a, axis=axis, keepdims=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): return jnp.mean(a, axis=axis, keepdims=not squeeze) diff --git a/lab/numpy/generic.py b/lab/numpy/generic.py index 9c5a3a2..7978709 100644 --- a/lab/numpy/generic.py +++ b/lab/numpy/generic.py @@ -216,6 +216,11 @@ def sum(a: Numeric, axis=None, squeeze=True): return np.sum(a, axis=axis, keepdims=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + return np.prod(a, axis=axis, keepdims=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): return np.mean(a, axis=axis, keepdims=not squeeze) diff --git a/lab/tensorflow/generic.py b/lab/tensorflow/generic.py index 7a01540..078ae7f 100644 --- a/lab/tensorflow/generic.py +++ b/lab/tensorflow/generic.py @@ -231,6 +231,11 @@ def sum(a: Numeric, axis=None, squeeze=True): return tf.reduce_sum(a, axis=axis, keepdims=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + return tf.reduce_prod(a, axis=axis, keepdims=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): return tf.reduce_mean(a, axis=axis, keepdims=not squeeze) diff --git a/lab/torch/generic.py b/lab/torch/generic.py index 3b373de..19c7950 100644 --- a/lab/torch/generic.py +++ b/lab/torch/generic.py @@ -238,6 +238,14 @@ def sum(a: Numeric, axis=None, squeeze=True): return torch.sum(a, dim=axis, keepdim=not squeeze) +@dispatch +def prod(a: Numeric, axis=None, squeeze=True): + if axis is None: + return torch.prod(a) + else: + return torch.prod(a, dim=axis, keepdim=not squeeze) + + @dispatch def mean(a: Numeric, axis=None, squeeze=True): if axis is None: diff --git a/tests/test_generic.py b/tests/test_generic.py index b457ad0..d745142 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -356,6 +356,7 @@ def test_binary_positive_first(f, check_lazy_shapes): (B.min, True), (B.max, True), (B.sum, True), + (B.prod, True), (B.nansum, True), (B.mean, True), (B.nanmean, True), From 1e8d01edbecf18338430d19311df8edacd71c98a Mon Sep 17 00:00:00 2001 From: patel-zeel Date: Fri, 25 Jun 2021 12:44:46 +0530 Subject: [PATCH 2/2] a minor fix --- lab/generic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lab/generic.py b/lab/generic.py index 0763c28..ef11e3f 100644 --- a/lab/generic.py +++ b/lab/generic.py @@ -972,6 +972,7 @@ def sum(a: Numeric, axis=None, squeeze=True): # pragma: no cover """ +@dispatch @abstract() def prod(a: Numeric, axis=None, squeeze=True): # pragma: no cover """Product of all elements in a tensor, possibly along an axis.