74 changes: 74 additions & 0 deletions tensorcircuit/backends/abstract_backend.py
@@ -9,6 +9,7 @@
from operator import mul
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import math
import numpy as np
from ..utils import return_partial

@@ -405,6 +406,31 @@ def reshape2(self: Any, a: Tensor) -> Tensor:
a = self.reshape(a, [2 for _ in range(nleg)])
return a

def reshaped(self: Any, a: Tensor, d: int) -> Tensor:
"""
Reshape a tensor to the [d, d, ...] shape.

:param a: Input tensor
:type a: Tensor
:param d: edge length for each dimension
:type d: int
:return: the reshaped tensor
:rtype: Tensor
"""
if not isinstance(d, int) or d <= 0:
raise ValueError("d must be a positive integer.")

size = self.sizen(a)
if size == 0:
return self.reshape(a, (0,))
Review comment (Member): what does reshape(a, (0,)) mean? just return a?

Reply (Contributor, author): Typically, a tensor with size = 0 can still have several possible shapes, e.g. (2, 0, 5), (0,), or (0, 3); all of these are size-0 tensors. To ensure consistent output from this API and to avoid potential shape errors, we use reshape(a, (0,)).


nleg_float = math.log(size, d)
nleg = int(round(nleg_float))
if d**nleg != size:
raise ValueError(f"cannot reshape: size {size} is not a power of d={d}")

return self.reshape(a, (d,) * nleg)

def reshapem(self: Any, a: Tensor) -> Tensor:
"""
Reshape a tensor to the [l, l] shape.
@@ -839,6 +865,54 @@ def mod(self: Any, x: Tensor, y: Tensor) -> Tensor:
"Backend '{}' has not implemented `mod`.".format(self.name)
)

def floor(self: Any, x: Tensor) -> Tensor:
"""
Compute the element-wise floor of the input tensor.

This operation returns a new tensor with the largest integers
less than or equal to each element of the input tensor,
i.e. it rounds each value down towards negative infinity.

:param x: Input tensor containing numeric values.
:type x: Tensor
:return: A tensor with the same shape as `x`, where each element
is the floored value of the corresponding element in `x`.
:rtype: Tensor

:raises NotImplementedError: If the backend does not provide an
implementation for `floor`.
"""
raise NotImplementedError(
"Backend '{}' has not implemented `floor`.".format(self.name)
)

def clip(self: Any, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
"""
Clip (limit) the values of a tensor element-wise to the range [a_min, a_max].

Each element in the input tensor `a` is compared against the corresponding
bounds `a_min` and `a_max`. If a value in `a` is less than `a_min`, it is set
to `a_min`; if greater than `a_max`, it is set to `a_max`. Otherwise, the
value is left unchanged. The result preserves the dtype and device of the input.

:param a: Input tensor containing values to be clipped.
:type a: Tensor
:param a_min: Lower bound (minimum value) for clipping. Can be a scalar tensor
or broadcastable to the shape of `a`.
:type a_min: Tensor
:param a_max: Upper bound (maximum value) for clipping. Can be a scalar tensor
or broadcastable to the shape of `a`.
:type a_max: Tensor
:return: A tensor with the same shape as `a`, where all values are clipped
to lie within the interval [a_min, a_max].
:rtype: Tensor

:raises NotImplementedError: If the backend does not implement `clip`.
"""
raise NotImplementedError(
"Backend '{}' has not implemented `clip`.".format(self.name)
)

def reverse(self: Any, a: Tensor) -> Tensor:
"""
return ``a[::-1]``, only 1D tensor is guaranteed for consistent behavior
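To make the semantics of the new reshaped method concrete (including the size-0 canonicalization discussed in the review thread above), here is a minimal usage sketch, assuming the NumPy backend is active; the values are illustrative only:

import numpy as np
import tensorcircuit as tc

tc.set_backend("numpy")

# A size-27 vector becomes a rank-3 tensor with edge length d = 3.
a = tc.backend.convert_to_tensor(np.arange(27))
print(tc.backend.reshaped(a, 3).shape)  # (3, 3, 3)

# Any size-0 tensor is canonicalized to the empty vector shape (0,).
b = tc.backend.convert_to_tensor(np.zeros((2, 0, 5)))
print(tc.backend.reshaped(b, 3).shape)  # (0,)

# A size-1 (scalar) tensor keeps shape (), since nleg = 0.
s = tc.backend.ones([])
print(tc.backend.reshaped(s, 2).shape)  # ()

# A size that is not a power of d raises ValueError.
c = tc.backend.convert_to_tensor(np.arange(10))
# tc.backend.reshaped(c, 3)  # ValueError: cannot reshape: size 10 is not a power of d=3
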
6 changes: 6 additions & 0 deletions tensorcircuit/backends/jax_backend.py
@@ -349,6 +349,12 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
def mod(self, x: Tensor, y: Tensor) -> Tensor:
return jnp.mod(x, y)

def floor(self, a: Tensor) -> Tensor:
return jnp.floor(a)

def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
return jnp.clip(a, a_min, a_max)

def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
return jnp.right_shift(x, y)

6 changes: 6 additions & 0 deletions tensorcircuit/backends/numpy_backend.py
@@ -250,6 +250,12 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
def mod(self, x: Tensor, y: Tensor) -> Tensor:
return np.mod(x, y)

def floor(self, a: Tensor) -> Tensor:
return np.floor(a)

def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
return np.clip(a, a_min, a_max)

def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
return np.right_shift(x, y)

6 changes: 6 additions & 0 deletions tensorcircuit/backends/pytorch_backend.py
@@ -429,6 +429,12 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tensor:
def mod(self, x: Tensor, y: Tensor) -> Tensor:
return torchlib.fmod(x, y)

def floor(self, a: Tensor) -> Tensor:
return torchlib.floor(a)

def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
return torchlib.clamp(a, a_min, a_max)

def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
return torchlib.bitwise_right_shift(x, y)

8 changes: 8 additions & 0 deletions tensorcircuit/backends/tensorflow_backend.py
@@ -573,6 +573,14 @@ def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
return tf.stack(a, axis=axis)

def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
return tf.clip_by_value(a, a_min, a_max)

def floor(self, a: Tensor) -> Tensor:
if a.dtype.is_integer:
return a
return tf.math.floor(a)

def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
return tf.concat(a, axis=axis)

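The new floor and clip methods share the same calling convention on every backend. A minimal usage sketch, again assuming the NumPy backend; the values are illustrative only:

import tensorcircuit as tc

tc.set_backend("numpy")

x = tc.backend.convert_to_tensor([-1.7, 0.2, 3.9])
print(tc.backend.floor(x))  # [-2.  0.  3.]

lo = tc.backend.convert_to_tensor(-1.0)
hi = tc.backend.convert_to_tensor(1.0)
print(tc.backend.clip(x, lo, hi))  # [-1.   0.2  1. ]
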
102 changes: 102 additions & 0 deletions tests/test_backends.py
@@ -435,6 +435,39 @@ def test_backend_methods_2(backend):
# assert tc.dtype(a) == "float32"


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_backend_floor(backend):
"""Test floor method (element-wise, dtype/device preservation, integers unchanged)."""
a = tc.backend.convert_to_tensor([-1.7, -0.0, 0.0, 0.2, 3.9])
r = tc.backend.floor(a)
expected = tc.backend.convert_to_tensor([-2.0, -0.0, 0.0, 0.0, 3.0])
np.testing.assert_allclose(r, expected, atol=1e-6)
assert tc.backend.dtype(r) == tc.backend.dtype(a)
assert tc.backend.device(r) == tc.backend.device(a)
ai = tc.backend.convert_to_tensor([0, 1, -2, 3])
ri = tc.backend.floor(ai)
np.testing.assert_allclose(ri, ai)


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_backend_clip(backend):
"""Test clip method (scalar/tensor bounds, broadcasting, dtype/device)."""
a = tc.backend.convert_to_tensor([-2.0, -0.5, 0.0, 0.5, 10.0])
a_min = tc.backend.convert_to_tensor(-1.0)
a_max = tc.backend.convert_to_tensor(1.0)
r = tc.backend.clip(a, a_min, a_max)
expected = tc.backend.convert_to_tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
np.testing.assert_allclose(r, expected, atol=1e-6)
assert tc.backend.dtype(r) == tc.backend.dtype(a)
assert tc.backend.device(r) == tc.backend.device(a)
a2 = tc.backend.convert_to_tensor([[-5.0, 0.0, 5.0], [1.0, 2.0, 3.0]])
a2_min = tc.backend.convert_to_tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 2.0]])
a2_max = tc.backend.convert_to_tensor([[0.0, 0.0, 4.0], [1.0, 2.0, 2.0]])
r2 = tc.backend.clip(a2, a2_min, a2_max)
expected2 = tc.backend.convert_to_tensor([[-1.0, 0.0, 4.0], [1.0, 2.0, 2.0]])
np.testing.assert_allclose(r2, expected2, atol=1e-6)


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_device_cpu_only(backend):
a = tc.backend.ones([])
@@ -464,6 +497,75 @@ def test_dlpack(backend):
np.testing.assert_allclose(a, a1, atol=1e-5)


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_backend_reshaped_basic(backend):
a1 = tc.backend.convert_to_tensor(np.arange(27))
r1 = tc.backend.reshaped(a1, 3)
assert r1.shape == (3, 3, 3)
np.testing.assert_allclose(tc.backend.numpy(r1), np.arange(27).reshape(3, 3, 3))
d, n = 4, 3
dim = d**n
mat = np.arange(dim * dim, dtype=np.float32).reshape(dim, dim)
a2 = tc.backend.convert_to_tensor(mat)
r2 = tc.backend.reshaped(a2, d)
assert r2.shape == (d,) * (2 * n)
np.testing.assert_allclose(tc.backend.numpy(r2), mat.reshape((d,) * (2 * n)))


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_backend_reshaped_zero_size(backend):
"""size == 0 returns a canonical empty vector shape (0,) regardless of input shape."""
a0 = tc.backend.convert_to_tensor(np.array([], dtype=np.float32))
r0 = tc.backend.reshaped(a0, 3)
assert r0.shape == (0,)
assert tc.backend.sizen(r0) == 0

a1 = tc.backend.convert_to_tensor(np.zeros((2, 0), dtype=np.float32))
r1 = tc.backend.reshaped(a1, 5)
assert r1.shape == (0,)
assert tc.backend.sizen(r1) == 0


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_backend_reshaped_dtype_device_preserved(backend):
"""Reshape should not change dtype or device."""
a = tc.backend.ones([16], dtype="float32")
dev = tc.backend.device(a)
r = tc.backend.reshaped(a, 2)
assert r.shape == (2, 2, 2, 2)
assert tc.backend.dtype(r) == tc.backend.dtype(a)
assert tc.backend.device(r) == dev


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_backend_reshaped_scalar_size_one(backend):
"""size == 1 stays scalar: nleg = 0 so shape () is kept."""
a = tc.backend.ones([]) # scalar tensor, total size = 1
r = tc.backend.reshaped(a, 2)
assert r.shape == ()
np.testing.assert_allclose(tc.backend.numpy(r), tc.backend.numpy(a))


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_backend_reshaped_invalid_d_raises(backend):
"""d must be a positive integer: non-int or <=0 should raise."""
a = tc.backend.ones([4], dtype="float32")
with pytest.raises(ValueError):
tc.backend.reshaped(a, 0)
with pytest.raises(ValueError):
tc.backend.reshaped(a, -2)
with pytest.raises(ValueError):
tc.backend.reshaped(a, 2.5) # not an int


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_backend_reshaped_non_power_raises(backend):
"""When size is not a power of d, raise ValueError."""
a = tc.backend.convert_to_tensor(np.arange(10))
with pytest.raises(ValueError):
tc.backend.reshaped(a, 3)


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_arg_cmp(backend):
np.testing.assert_allclose(tc.backend.argmax(tc.backend.ones([3], "float64")), 0)