Skip to content

Commit

Permalink
Support adding numpy arrays to GridQubits (#3131)
Browse files Browse the repository at this point in the history
  • Loading branch information
dabacon committed Jul 10, 2020
1 parent 190e842 commit c869e0e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 28 deletions.
19 changes: 14 additions & 5 deletions cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import abc

import numpy as np

from cirq import ops, protocols

if TYPE_CHECKING:
Expand Down Expand Up @@ -78,8 +80,8 @@ def __add__(self: TSelf, other: Tuple[int, int]) -> 'TSelf':
f"Got {self.dimension} and {other.dimension}")
return self._with_row_col(row=self.row + other.row,
col=self.col + other.col)
if not (isinstance(other, tuple) and len(other) == 2 and
all(isinstance(x, int) for x in other)):
if not (isinstance(other, (tuple, np.ndarray)) and len(other) == 2 and
all(isinstance(x, (int, np.integer)) for x in other)):
raise TypeError('Can only add integer tuples of length 2 to '
f'{type(self).__name__}. Instead was {other}')
return self._with_row_col(row=self.row + other[0],
Expand All @@ -93,8 +95,8 @@ def __sub__(self: TSelf, other: Tuple[int, int]) -> 'TSelf':
f"Got {self.dimension} and {other.dimension}")
return self._with_row_col(row=self.row - other.row,
col=self.col - other.col)
if not (isinstance(other, tuple) and len(other) == 2 and
all(isinstance(x, int) for x in other)):
if not (isinstance(other, (tuple, np.ndarray)) and len(other) == 2 and
all(isinstance(x, (int, np.integer)) for x in other)):
raise TypeError("Can only subtract integer tuples of length 2 to "
f"{type(self).__name__}. Instead was {other}")
return self._with_row_col(row=self.row - other[0],
Expand All @@ -118,13 +120,17 @@ class GridQid(_BaseGridQid):
GridQid(0, 0, dimension=2) < GridQid(0, 1, dimension=2)
< GridQid(1, 0, dimension=2) < GridQid(1, 1, dimension=2)
New GridQid can be constructed by adding or subtracting tuples
New GridQid can be constructed by adding or subtracting tuples or numpy
arrays
>>> cirq.GridQid(2, 3, dimension=2) + (3, 1)
cirq.GridQid(5, 4, dimension=2)
>>> cirq.GridQid(2, 3, dimension=2) - (1, 2)
cirq.GridQid(1, 1, dimension=2)
>>> cirq.GridQid(2, 3, dimension=2) + np.array([3, 1], dtype=int)
cirq.GridQid(5, 4, dimension=2)
"""

def __init__(self, row: int, col: int, *, dimension: int) -> None:
Expand Down Expand Up @@ -264,6 +270,9 @@ class GridQubit(_BaseGridQid):
>>> cirq.GridQubit(2, 3) - (1, 2)
cirq.GridQubit(1, 1)
>>> cirq.GridQubit(2, 3,) + np.array([3, 1], dtype=int)
cirq.GridQubit(5, 4)
"""

@property
Expand Down
76 changes: 53 additions & 23 deletions cirq/devices/grid_qubit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import pytest

import numpy as np

import cirq


Expand Down Expand Up @@ -217,29 +219,23 @@ def test_addition_subtraction():
assert cirq.GridQubit(1, -2) + cirq.GridQubit(3, 5) == cirq.GridQubit(4, 3)

# GridQids
assert cirq.GridQid(1, 2, dimension=3) + (2, 5) == cirq.GridQid(3,
7,
dimension=3)
assert cirq.GridQid(1, 2, dimension=3) + (0, 0) == cirq.GridQid(1,
2,
dimension=3)
assert cirq.GridQid(1, 2, dimension=3) + (-1, 0) == cirq.GridQid(
0, 2, dimension=3)
assert cirq.GridQid(1, 2, dimension=3) - (2, 5) == cirq.GridQid(-1,
-3,
dimension=3)
assert cirq.GridQid(1, 2, dimension=3) - (0, 0) == cirq.GridQid(1,
2,
dimension=3)
assert cirq.GridQid(1, 2, dimension=3) - (-1, 0) == cirq.GridQid(
2, 2, dimension=3)

assert (2, 5) + cirq.GridQid(1, 2, dimension=3) == cirq.GridQid(3,
7,
dimension=3)
assert (2, 5) - cirq.GridQid(1, 2, dimension=3) == cirq.GridQid(1,
3,
dimension=3)
assert (cirq.GridQid(1, 2, dimension=3) + (2, 5) == cirq.GridQid(
3, 7, dimension=3))
assert (cirq.GridQid(1, 2, dimension=3) + (0, 0) == cirq.GridQid(
1, 2, dimension=3))
assert (cirq.GridQid(1, 2, dimension=3) + (-1, 0) == cirq.GridQid(
0, 2, dimension=3))
assert (cirq.GridQid(1, 2, dimension=3) - (2, 5) == cirq.GridQid(
-1, -3, dimension=3))
assert (cirq.GridQid(1, 2, dimension=3) - (0, 0) == cirq.GridQid(
1, 2, dimension=3))
assert (cirq.GridQid(1, 2, dimension=3) - (-1, 0) == cirq.GridQid(
2, 2, dimension=3))

assert ((2, 5) + cirq.GridQid(1, 2, dimension=3) == cirq.GridQid(
3, 7, dimension=3))
assert ((2, 5) - cirq.GridQid(1, 2, dimension=3) == cirq.GridQid(
1, 3, dimension=3))

assert cirq.GridQid(1, 2, dimension=3) + cirq.GridQid(
3, 5, dimension=3) == cirq.GridQid(4, 7, dimension=3)
Expand All @@ -249,6 +245,35 @@ def test_addition_subtraction():
3, 5, dimension=3) == cirq.GridQid(4, 3, dimension=3)


@pytest.mark.parametrize('dtype', (np.int8, np.int16, np.int32, np.int64, int))
def test_addition_subtraction_numpy_array(dtype):
assert cirq.GridQubit(1, 2) + np.array([1, 2],
dtype=dtype) == cirq.GridQubit(2, 4)
assert cirq.GridQubit(1, 2) + np.array([0, 0],
dtype=dtype) == cirq.GridQubit(1, 2)
assert (cirq.GridQubit(1, 2) +
np.array([-1, 0], dtype=dtype) == cirq.GridQubit(0, 2))
assert cirq.GridQubit(1, 2) - np.array([1, 2],
dtype=dtype) == cirq.GridQubit(0, 0)
assert cirq.GridQubit(1, 2) - np.array([0, 0],
dtype=dtype) == cirq.GridQubit(1, 2)
assert (cirq.GridQid(1, 2, dimension=3) -
np.array([-1, 0], dtype=dtype) == cirq.GridQid(2, 2, dimension=3))

assert cirq.GridQid(1, 2, dimension=3) + np.array(
[1, 2], dtype=dtype) == cirq.GridQid(2, 4, dimension=3)
assert cirq.GridQid(1, 2, dimension=3) + np.array(
[0, 0], dtype=dtype) == cirq.GridQid(1, 2, dimension=3)
assert (cirq.GridQid(1, 2, dimension=3) +
np.array([-1, 0], dtype=dtype) == cirq.GridQid(0, 2, dimension=3))
assert cirq.GridQid(1, 2, dimension=3) - np.array(
[1, 2], dtype=dtype) == cirq.GridQid(0, 0, dimension=3)
assert cirq.GridQid(1, 2, dimension=3) - np.array(
[0, 0], dtype=dtype) == cirq.GridQid(1, 2, dimension=3)
assert (cirq.GridQid(1, 2, dimension=3) -
np.array([-1, 0], dtype=dtype) == cirq.GridQid(2, 2, dimension=3))


def test_unsupported_add():
with pytest.raises(TypeError, match='1'):
_ = cirq.GridQubit(1, 1) + 1
Expand All @@ -262,6 +287,11 @@ def test_unsupported_add():
with pytest.raises(TypeError, match='1'):
_ = cirq.GridQubit(1, 1) - 1

with pytest.raises(TypeError, match='[1., 2.]'):
_ = cirq.GridQubit(1, 1) + np.array([1.0, 2.0])
with pytest.raises(TypeError, match='[1, 2, 3]'):
_ = cirq.GridQubit(1, 1) + np.array([1, 2, 3], dtype=int)


def test_addition_subtraction_type_error():
with pytest.raises(TypeError, match="bort"):
Expand Down

0 comments on commit c869e0e

Please sign in to comment.