From 2c7c746e8145c41091aa3a7fe8c76f7b7449f22a Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Fri, 13 Mar 2020 02:14:28 +0800 Subject: [PATCH] fix np.clip scalar input case (#17788) --- python/mxnet/numpy/multiarray.py | 5 +++++ tests/python/unittest/test_numpy_op.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index fceaaf3a282f..9a803d48b5b3 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -6174,6 +6174,11 @@ def clip(a, a_min, a_max, out=None): >>> np.clip(a, 3, 6, out=a) array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32) """ + from numbers import Number + if isinstance(a, Number): + # In case input is a scalar, the computation would fall back to native numpy. + # The value returned would be a python scalar. + return _np.clip(a, a_min, a_max, out=None) return _mx_nd_np.clip(a, a_min, a_max, out=out) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 8a7d8f030ac3..0f35dece5354 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3644,6 +3644,16 @@ def __init__(self, a_min=None, a_max=None): def hybrid_forward(self, F, x): return x.clip(self._a_min, self._a_max) + + # Test scalar case + for _, a_min, a_max, throw_exception in workloads: + a = _np.random.uniform() # A scalar + if throw_exception: + # No need to test the exception case here. + continue + mx_ret = np.clip(a, a_min, a_max) + np_ret = _np.clip(a, a_min, a_max) + assert_almost_equal(mx_ret, np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False) for shape, a_min, a_max, throw_exception in workloads: for dtype in dtypes: