From 74a9bae377ffd170975404bcfb27eb55f255ab2a Mon Sep 17 00:00:00 2001 From: Weiguo Ma Date: Wed, 27 Aug 2025 09:53:21 +0800 Subject: [PATCH 1/7] Add clip(), floor() for abstract_backend. Add reshaped() (corresponding to reshape2) for qudit systems. --- tensorcircuit/backends/abstract_backend.py | 74 ++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tensorcircuit/backends/abstract_backend.py b/tensorcircuit/backends/abstract_backend.py index 83720805..e7de5654 100644 --- a/tensorcircuit/backends/abstract_backend.py +++ b/tensorcircuit/backends/abstract_backend.py @@ -9,6 +9,7 @@ from operator import mul from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +import math import numpy as np from ..utils import return_partial @@ -405,6 +406,31 @@ def reshape2(self: Any, a: Tensor) -> Tensor: a = self.reshape(a, [2 for _ in range(nleg)]) return a + def reshaped(self: Any, a: Tensor, d: int) -> Tensor: + """ + Reshape a tensor to the [d, d, ...] shape. + + :param a: Input tensor + :type a: Tensor + :param d: edge length for each dimension + :type d: int + :return: the reshaped tensor + :rtype: Tensor + """ + if not isinstance(d, int) or d <= 0: + raise ValueError("d must be a positive integer.") + + size = self.sizen(a) + if size == 0: + return self.reshape(a, (0,)) + + nleg_float = math.log(size, d) + nleg = int(round(nleg_float)) + if d**nleg != size: + raise ValueError(f"cannot reshape: size {size} is not a power of d={d}") + + return self.reshape(a, (d,) * nleg) + def reshapem(self: Any, a: Tensor) -> Tensor: """ Reshape a tensor to the [l, l] shape. @@ -839,6 +865,54 @@ def mod(self: Any, x: Tensor, y: Tensor) -> Tensor: "Backend '{}' has not implemented `mod`.".format(self.name) ) + def floor(self: Any, x: Tensor) -> Tensor: + """ + Compute the element-wise floor of the input tensor. + + This operation returns a new tensor with the largest integers + less than or equal to each element of the input tensor, + i.e. it rounds each value down towards negative infinity. + + :param x: Input tensor containing numeric values. + :type x: Tensor + :return: A tensor with the same shape as `x`, where each element + is the floored value of the corresponding element in `x`. + :rtype: Tensor + + :raises NotImplementedError: If the backend does not provide an + implementation for `floor`. + """ + raise NotImplementedError( + "Backend '{}' has not implemented `floor`.".format(self.name) + ) + + def clip(self: Any, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: + """ + Clip (limit) the values of a tensor element-wise to the range [a_min, a_max]. + + Each element in the input tensor `a` is compared against the corresponding + bounds `a_min` and `a_max`. If a value in `a` is less than `a_min`, it is set + to `a_min`; if greater than `a_max`, it is set to `a_max`. Otherwise, the + value is left unchanged. The result preserves the dtype and device of the input. + + :param a: Input tensor containing values to be clipped. + :type a: Tensor + :param a_min: Lower bound (minimum value) for clipping. Can be a scalar tensor + or broadcastable to the shape of `a`. + :type a_min: Tensor + :param a_max: Upper bound (maximum value) for clipping. Can be a scalar tensor + or broadcastable to the shape of `a`. + :type a_max: Tensor + :return: A tensor with the same shape as `a`, where all values are clipped + to lie within the interval [a_min, a_max]. + :rtype: Tensor + + :raises NotImplementedError: If the backend does not implement `clip`. + """ + raise NotImplementedError( + "Backend '{}' has not implemented `clip`.".format(self.name) + ) + def reverse(self: Any, a: Tensor) -> Tensor: """ return ``a[::-1]``, only 1D tensor is guaranteed for consistent behavior From 70b37717ff42f493c22925773df1f2a45acf0fcd Mon Sep 17 00:00:00 2001 From: Weiguo Ma Date: Wed, 27 Aug 2025 09:54:16 +0800 Subject: [PATCH 2/7] Add clip(), floor() for backends. --- tensorcircuit/backends/jax_backend.py | 6 ++++++ tensorcircuit/backends/numpy_backend.py | 6 ++++++ tensorcircuit/backends/pytorch_backend.py | 6 ++++++ 3 files changed, 18 insertions(+) diff --git a/tensorcircuit/backends/jax_backend.py b/tensorcircuit/backends/jax_backend.py index d678b42d..44c85866 100644 --- a/tensorcircuit/backends/jax_backend.py +++ b/tensorcircuit/backends/jax_backend.py @@ -349,6 +349,12 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tenso def mod(self, x: Tensor, y: Tensor) -> Tensor: return jnp.mod(x, y) + def floor(self, x: Tensor) -> Tensor: + return jnp.floor(x) + + def clip(self, x: Tensor, lower: Tensor, upper: Tensor) -> Tensor: + return jnp.clip(x, lower, upper) + def right_shift(self, x: Tensor, y: Tensor) -> Tensor: return jnp.right_shift(x, y) diff --git a/tensorcircuit/backends/numpy_backend.py b/tensorcircuit/backends/numpy_backend.py index 8678f3dc..ffb81c92 100644 --- a/tensorcircuit/backends/numpy_backend.py +++ b/tensorcircuit/backends/numpy_backend.py @@ -250,6 +250,12 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tenso def mod(self, x: Tensor, y: Tensor) -> Tensor: return np.mod(x, y) + def floor(self, x: Tensor) -> Tensor: + return np.floor(x) + + def clip(self, x: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: + return np.clip(x, a_min, a_max) + def right_shift(self, x: Tensor, y: Tensor) -> Tensor: return np.right_shift(x, y) diff --git a/tensorcircuit/backends/pytorch_backend.py b/tensorcircuit/backends/pytorch_backend.py index cd037b6d..3274235d 100644 --- a/tensorcircuit/backends/pytorch_backend.py +++ b/tensorcircuit/backends/pytorch_backend.py @@ -429,6 +429,12 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tenso def mod(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.fmod(x, y) + def floor(self, x: Tensor) -> Tensor: + return torchlib.floor(x) + + def clip(self, x: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: + return torchlib.clamp(x, a_min, a_max) + def right_shift(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.bitwise_right_shift(x, y) From ab2b64143730accebcc0b0dba09cae87451aa7ab Mon Sep 17 00:00:00 2001 From: Weiguo Ma Date: Wed, 27 Aug 2025 09:55:02 +0800 Subject: [PATCH 3/7] Add clip(), floor() for backends. TensorFlow does not support to apply math.floor() to int directly. --- tensorcircuit/backends/tensorflow_backend.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tensorcircuit/backends/tensorflow_backend.py b/tensorcircuit/backends/tensorflow_backend.py index f7c7ae5c..43ea25aa 100644 --- a/tensorcircuit/backends/tensorflow_backend.py +++ b/tensorcircuit/backends/tensorflow_backend.py @@ -573,6 +573,22 @@ def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return tf.stack(a, axis=axis) + def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: + return tf.clip_by_value(a, a_min, a_max) + + def floor(self, x: Tensor) -> Tensor: + dtype_str = x.dtype.name if hasattr(x.dtype, "name") else str(x.dtype) + if x.dtype.is_floating: + return tf.math.floor(x) + elif x.dtype.is_integer: + return x + elif x.dtype.is_complex: + raise TypeError( + f"tf.math.floor does not support complex dtype ({dtype_str})" + ) + else: + raise TypeError(f"Unsupported dtype for floor: {dtype_str}") + def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return tf.concat(a, axis=axis) From ba4e1af768e3a76c8a4f42619790bd899f018b91 Mon Sep 17 00:00:00 2001 From: Weiguo Ma Date: Wed, 27 Aug 2025 09:55:44 +0800 Subject: [PATCH 4/7] Add tests for clip(), floor() functions in all backends. --- tests/test_backends.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_backends.py b/tests/test_backends.py index b589163e..3361b4c1 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -435,6 +435,39 @@ def test_backend_methods_2(backend): # assert tc.dtype(a) == "float32" +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_floor(backend): + """Test floor method (element-wise, dtype/device preservation, integers unchanged).""" + a = tc.backend.convert_to_tensor([-1.7, -0.0, 0.0, 0.2, 3.9]) + r = tc.backend.floor(a) + expected = tc.backend.convert_to_tensor([-2.0, -0.0, 0.0, 0.0, 3.0]) + np.testing.assert_allclose(r, expected, atol=1e-6) + assert tc.backend.dtype(r) == tc.backend.dtype(a) + assert tc.backend.device(r) == tc.backend.device(a) + ai = tc.backend.convert_to_tensor([0, 1, -2, 3]) + ri = tc.backend.floor(ai) + np.testing.assert_allclose(ri, ai) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_clip(backend): + """Test clip method (scalar/tensor bounds, broadcasting, dtype/device).""" + a = tc.backend.convert_to_tensor([-2.0, -0.5, 0.0, 0.5, 10.0]) + a_min = tc.backend.convert_to_tensor(-1.0) + a_max = tc.backend.convert_to_tensor(1.0) + r = tc.backend.clip(a, a_min, a_max) + expected = tc.backend.convert_to_tensor([-1.0, -0.5, 0.0, 0.5, 1.0]) + np.testing.assert_allclose(r, expected, atol=1e-6) + assert tc.backend.dtype(r) == tc.backend.dtype(a) + assert tc.backend.device(r) == tc.backend.device(a) + a2 = tc.backend.convert_to_tensor([[-5.0, 0.0, 5.0], [1.0, 2.0, 3.0]]) + a2_min = tc.backend.convert_to_tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 2.0]]) + a2_max = tc.backend.convert_to_tensor([[0.0, 0.0, 4.0], [1.0, 2.0, 2.0]]) + r2 = tc.backend.clip(a2, a2_min, a2_max) + expected2 = tc.backend.convert_to_tensor([[-1.0, 0.0, 4.0], [1.0, 2.0, 2.0]]) + np.testing.assert_allclose(r2, expected2, atol=1e-6) + + @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) def test_device_cpu_only(backend): a = tc.backend.ones([]) From d1a360fc11bd4d2a0fb8d41dcce2060c0f769cb3 Mon Sep 17 00:00:00 2001 From: Weiguo Ma Date: Wed, 27 Aug 2025 11:25:03 +0800 Subject: [PATCH 5/7] According to the maintainer`s relevant suggestions, the bugs have been fixed. --- tensorcircuit/backends/jax_backend.py | 4 ++-- tensorcircuit/backends/numpy_backend.py | 4 ++-- tensorcircuit/backends/pytorch_backend.py | 4 ++-- tensorcircuit/backends/tensorflow_backend.py | 12 ++---------- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/tensorcircuit/backends/jax_backend.py b/tensorcircuit/backends/jax_backend.py index 44c85866..f1589776 100644 --- a/tensorcircuit/backends/jax_backend.py +++ b/tensorcircuit/backends/jax_backend.py @@ -352,8 +352,8 @@ def mod(self, x: Tensor, y: Tensor) -> Tensor: def floor(self, x: Tensor) -> Tensor: return jnp.floor(x) - def clip(self, x: Tensor, lower: Tensor, upper: Tensor) -> Tensor: - return jnp.clip(x, lower, upper) + def clip(self, x: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: + return jnp.clip(x, a_min, a_max) def right_shift(self, x: Tensor, y: Tensor) -> Tensor: return jnp.right_shift(x, y) diff --git a/tensorcircuit/backends/numpy_backend.py b/tensorcircuit/backends/numpy_backend.py index ffb81c92..7cca2680 100644 --- a/tensorcircuit/backends/numpy_backend.py +++ b/tensorcircuit/backends/numpy_backend.py @@ -253,8 +253,8 @@ def mod(self, x: Tensor, y: Tensor) -> Tensor: def floor(self, x: Tensor) -> Tensor: return np.floor(x) - def clip(self, x: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: - return np.clip(x, a_min, a_max) + def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: + return np.clip(a, a_min, a_max) def right_shift(self, x: Tensor, y: Tensor) -> Tensor: return np.right_shift(x, y) diff --git a/tensorcircuit/backends/pytorch_backend.py b/tensorcircuit/backends/pytorch_backend.py index 3274235d..20757416 100644 --- a/tensorcircuit/backends/pytorch_backend.py +++ b/tensorcircuit/backends/pytorch_backend.py @@ -432,8 +432,8 @@ def mod(self, x: Tensor, y: Tensor) -> Tensor: def floor(self, x: Tensor) -> Tensor: return torchlib.floor(x) - def clip(self, x: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: - return torchlib.clamp(x, a_min, a_max) + def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: + return torchlib.clamp(a, a_min, a_max) def right_shift(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.bitwise_right_shift(x, y) diff --git a/tensorcircuit/backends/tensorflow_backend.py b/tensorcircuit/backends/tensorflow_backend.py index 43ea25aa..7feb707f 100644 --- a/tensorcircuit/backends/tensorflow_backend.py +++ b/tensorcircuit/backends/tensorflow_backend.py @@ -577,17 +577,9 @@ def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: return tf.clip_by_value(a, a_min, a_max) def floor(self, x: Tensor) -> Tensor: - dtype_str = x.dtype.name if hasattr(x.dtype, "name") else str(x.dtype) - if x.dtype.is_floating: - return tf.math.floor(x) - elif x.dtype.is_integer: + if x.dtype.is_integer: return x - elif x.dtype.is_complex: - raise TypeError( - f"tf.math.floor does not support complex dtype ({dtype_str})" - ) - else: - raise TypeError(f"Unsupported dtype for floor: {dtype_str}") + return tf.math.floor(x) def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return tf.concat(a, axis=axis) From 738b2144e3bc375daf5b8ac8a09f203ea54d9b12 Mon Sep 17 00:00:00 2001 From: Weiguo Ma Date: Wed, 27 Aug 2025 11:47:30 +0800 Subject: [PATCH 6/7] x -> a --- tensorcircuit/backends/jax_backend.py | 8 ++++---- tensorcircuit/backends/numpy_backend.py | 4 ++-- tensorcircuit/backends/pytorch_backend.py | 4 ++-- tensorcircuit/backends/tensorflow_backend.py | 8 ++++---- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tensorcircuit/backends/jax_backend.py b/tensorcircuit/backends/jax_backend.py index f1589776..880ad4e8 100644 --- a/tensorcircuit/backends/jax_backend.py +++ b/tensorcircuit/backends/jax_backend.py @@ -349,11 +349,11 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tenso def mod(self, x: Tensor, y: Tensor) -> Tensor: return jnp.mod(x, y) - def floor(self, x: Tensor) -> Tensor: - return jnp.floor(x) + def floor(self, a: Tensor) -> Tensor: + return jnp.floor(a) - def clip(self, x: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: - return jnp.clip(x, a_min, a_max) + def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: + return jnp.clip(a, a_min, a_max) def right_shift(self, x: Tensor, y: Tensor) -> Tensor: return jnp.right_shift(x, y) diff --git a/tensorcircuit/backends/numpy_backend.py b/tensorcircuit/backends/numpy_backend.py index 7cca2680..02669ca7 100644 --- a/tensorcircuit/backends/numpy_backend.py +++ b/tensorcircuit/backends/numpy_backend.py @@ -250,8 +250,8 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tenso def mod(self, x: Tensor, y: Tensor) -> Tensor: return np.mod(x, y) - def floor(self, x: Tensor) -> Tensor: - return np.floor(x) + def floor(self, a: Tensor) -> Tensor: + return np.floor(a) def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: return np.clip(a, a_min, a_max) diff --git a/tensorcircuit/backends/pytorch_backend.py b/tensorcircuit/backends/pytorch_backend.py index 20757416..dba8e1a0 100644 --- a/tensorcircuit/backends/pytorch_backend.py +++ b/tensorcircuit/backends/pytorch_backend.py @@ -429,8 +429,8 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tenso def mod(self, x: Tensor, y: Tensor) -> Tensor: return torchlib.fmod(x, y) - def floor(self, x: Tensor) -> Tensor: - return torchlib.floor(x) + def floor(self, a: Tensor) -> Tensor: + return torchlib.floor(a) def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: return torchlib.clamp(a, a_min, a_max) diff --git a/tensorcircuit/backends/tensorflow_backend.py b/tensorcircuit/backends/tensorflow_backend.py index 7feb707f..82a3e3d6 100644 --- a/tensorcircuit/backends/tensorflow_backend.py +++ b/tensorcircuit/backends/tensorflow_backend.py @@ -576,10 +576,10 @@ def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor: return tf.clip_by_value(a, a_min, a_max) - def floor(self, x: Tensor) -> Tensor: - if x.dtype.is_integer: - return x - return tf.math.floor(x) + def floor(self, a: Tensor) -> Tensor: + if a.dtype.is_integer: + return a + return tf.math.floor(a) def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return tf.concat(a, axis=axis) From 60eff6079543a37aac3b6c71c400b654969400f9 Mon Sep 17 00:00:00 2001 From: Weiguo Ma Date: Thu, 28 Aug 2025 08:51:34 +0800 Subject: [PATCH 7/7] Add tests for reshaped() in abstractbackend.py. --- tests/test_backends.py | 69 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/test_backends.py b/tests/test_backends.py index 3361b4c1..01422429 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -497,6 +497,75 @@ def test_dlpack(backend): np.testing.assert_allclose(a, a1, atol=1e-5) +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_reshaped_basic(backend): + a1 = tc.backend.convert_to_tensor(np.arange(27)) + r1 = tc.backend.reshaped(a1, 3) + assert r1.shape == (3, 3, 3) + np.testing.assert_allclose(tc.backend.numpy(r1), np.arange(27).reshape(3, 3, 3)) + d, n = 4, 3 + dim = d**n + mat = np.arange(dim * dim, dtype=np.float32).reshape(dim, dim) + a2 = tc.backend.convert_to_tensor(mat) + r2 = tc.backend.reshaped(a2, d) + assert r2.shape == (d,) * (2 * n) + np.testing.assert_allclose(tc.backend.numpy(r2), mat.reshape((d,) * (2 * n))) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_reshaped_zero_size(backend): + """size == 0 returns a canonical empty vector shape (0,) regardless of input shape.""" + a0 = tc.backend.convert_to_tensor(np.array([], dtype=np.float32)) + r0 = tc.backend.reshaped(a0, 3) + assert r0.shape == (0,) + assert tc.backend.sizen(r0) == 0 + + a1 = tc.backend.convert_to_tensor(np.zeros((2, 0), dtype=np.float32)) + r1 = tc.backend.reshaped(a1, 5) + assert r1.shape == (0,) + assert tc.backend.sizen(r1) == 0 + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_reshaped_dtype_device_preserved(backend): + """Reshape should not change dtype or device.""" + a = tc.backend.ones([16], dtype="float32") + dev = tc.backend.device(a) + r = tc.backend.reshaped(a, 2) + assert r.shape == (2, 2, 2, 2) + assert tc.backend.dtype(r) == tc.backend.dtype(a) + assert tc.backend.device(r) == dev + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_reshaped_scalar_size_one(backend): + """size == 1 stays scalar: nleg = 0 so shape () is kept.""" + a = tc.backend.ones([]) # scalar tensor, total size = 1 + r = tc.backend.reshaped(a, 2) + assert r.shape == () + np.testing.assert_allclose(tc.backend.numpy(r), tc.backend.numpy(a)) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_reshaped_invalid_d_raises(backend): + """d must be a positive integer: non-int or <=0 should raise.""" + a = tc.backend.ones([4], dtype="float32") + with pytest.raises(ValueError): + tc.backend.reshaped(a, 0) + with pytest.raises(ValueError): + tc.backend.reshaped(a, -2) + with pytest.raises(ValueError): + tc.backend.reshaped(a, 2.5) # not an int + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_reshaped_non_power_raises(backend): + """When size is not a power of d, raise ValueError.""" + a = tc.backend.convert_to_tensor(np.arange(10)) + with pytest.raises(ValueError): + tc.backend.reshaped(a, 3) + + @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) def test_arg_cmp(backend): np.testing.assert_allclose(tc.backend.argmax(tc.backend.ones([3], "float64")), 0)