Skip to content

Commit

Permalink
[4173] Add atol/rtol for unitary checks in MatrixGate (#4220)
Browse files Browse the repository at this point in the history
Add rtol/atol for unitary checks in MatrixGate.

Fixes #4173.
  • Loading branch information
shivanth committed Jun 21, 2021
1 parent 4278277 commit 7d9e603
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
11 changes: 9 additions & 2 deletions cirq-core/cirq/ops/matrix_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@
class MatrixGate(raw_types.Gate):
"""A unitary qubit or qudit gate defined entirely by its matrix."""

def __init__(self, matrix: np.ndarray, *, qid_shape: Optional[Iterable[int]] = None) -> None:
def __init__(
self,
matrix: np.ndarray,
*,
qid_shape: Optional[Iterable[int]] = None,
unitary_check_rtol: float = 1e-5,
unitary_check_atol: float = 1e-8,
) -> None:
"""Initializes a matrix gate.
Args:
matrix: The matrix that defines the gate.
Expand Down Expand Up @@ -59,7 +66,7 @@ def __init__(self, matrix: np.ndarray, *, qid_shape: Optional[Iterable[int]] = N
f'qid_shape: {self._qid_shape}\n'
)

if not linalg.is_unitary(matrix):
if not linalg.is_unitary(matrix, rtol=unitary_check_rtol, atol=unitary_check_atol):
raise ValueError(f'Not a unitary matrix: {self._matrix}')

def _json_dict_(self) -> Dict[str, Any]:
Expand Down
23 changes: 23 additions & 0 deletions cirq-core/cirq/ops/matrix_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,26 @@ def test_protocols_and_repr():
cirq.testing.assert_implements_consistent_protocols(
cirq.MatrixGate(np.diag([1, 1j, -1]), qid_shape=(3,))
)


def test_matrixgate_unitary_tolerance():
## non-unitary matrix
with pytest.raises(ValueError):
_ = cirq.MatrixGate(np.array([[1, 0], [0, -0.6]]), unitary_check_atol=0.5)

# very high atol -> check converges quickly
_ = cirq.MatrixGate(np.array([[1, 0], [0, 1]]), unitary_check_atol=1)

# very high rtol -> check converges quickly
_ = cirq.MatrixGate(np.array([[1, 0], [0, -0.6]]), unitary_check_rtol=1)

## unitary matrix
_ = cirq.MatrixGate(np.array([[0.707, 0.707], [-0.707, 0.707]]), unitary_check_atol=0.5)

# very low atol -> the check never converges
with pytest.raises(ValueError):
_ = cirq.MatrixGate(np.array([[0.707, 0.707], [-0.707, 0.707]]), unitary_check_atol=1e-10)

# very low atol -> the check never converges
with pytest.raises(ValueError):
_ = cirq.MatrixGate(np.array([[0.707, 0.707], [-0.707, 0.707]]), unitary_check_rtol=1e-10)

0 comments on commit 7d9e603

Please sign in to comment.