Skip to content

Commit

Permalink
Add assert_decompose_ends_at_default_gateset to consistent protocol…
Browse files Browse the repository at this point in the history
…s test to ensure all cirq gates decompose to default gateset (#5107)
  • Loading branch information
tanujkhattar committed Mar 21, 2022
1 parent 6af5387 commit d3c4853
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 11 deletions.
6 changes: 4 additions & 2 deletions cirq-core/cirq/ops/dense_pauli_string_test.py
Expand Up @@ -376,9 +376,11 @@ def test_protocols():
cirq.testing.assert_implements_consistent_protocols(-cirq.DensePauliString('Z'))
cirq.testing.assert_implements_consistent_protocols(1j * cirq.DensePauliString('X'))
cirq.testing.assert_implements_consistent_protocols(2 * cirq.DensePauliString('X'))
cirq.testing.assert_implements_consistent_protocols(t * cirq.DensePauliString('XYIZ'))
cirq.testing.assert_implements_consistent_protocols(
cirq.DensePauliString('XYIZ', coefficient=t + 2)
t * cirq.DensePauliString('XYIZ'), ignore_decompose_to_default_gateset=True
)
cirq.testing.assert_implements_consistent_protocols(
cirq.DensePauliString('XYIZ', coefficient=t + 2), ignore_decompose_to_default_gateset=True
)
cirq.testing.assert_implements_consistent_protocols(-cirq.DensePauliString('XYIZ'))
cirq.testing.assert_implements_consistent_protocols(
Expand Down
12 changes: 8 additions & 4 deletions cirq-core/cirq/ops/random_gate_channel_test.py
Expand Up @@ -80,16 +80,20 @@ def test_eq():

def test_consistent_protocols():
cirq.testing.assert_implements_consistent_protocols(
cirq.RandomGateChannel(sub_gate=cirq.X, probability=1)
cirq.RandomGateChannel(sub_gate=cirq.X, probability=1),
ignore_decompose_to_default_gateset=True,
)
cirq.testing.assert_implements_consistent_protocols(
cirq.RandomGateChannel(sub_gate=cirq.X, probability=0)
cirq.RandomGateChannel(sub_gate=cirq.X, probability=0),
ignore_decompose_to_default_gateset=True,
)
cirq.testing.assert_implements_consistent_protocols(
cirq.RandomGateChannel(sub_gate=cirq.X, probability=sympy.Symbol('x') / 2)
cirq.RandomGateChannel(sub_gate=cirq.X, probability=sympy.Symbol('x') / 2),
ignore_decompose_to_default_gateset=True,
)
cirq.testing.assert_implements_consistent_protocols(
cirq.RandomGateChannel(sub_gate=cirq.X, probability=0.5)
cirq.RandomGateChannel(sub_gate=cirq.X, probability=0.5),
ignore_decompose_to_default_gateset=True,
)


Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/testing/consistent_decomposition.py
Expand Up @@ -53,6 +53,8 @@ def _known_gate_with_no_decomposition(val: Any):
"""Checks whether `val` is a known gate with no default decomposition to default gateset."""
if isinstance(val, ops.MatrixGate):
return protocols.qid_shape(val) not in [(2,), (2,) * 2, (2,) * 3]
if isinstance(val, ops.BaseDensePauliString) and not protocols.has_unitary(val):
return True
if isinstance(val, ops.ControlledGate):
if protocols.is_parameterized(val):
return True
Expand Down
9 changes: 9 additions & 0 deletions cirq-core/cirq/testing/consistent_protocols.py
Expand Up @@ -26,6 +26,7 @@
)
from cirq.testing.consistent_decomposition import (
assert_decompose_is_consistent_with_unitary,
assert_decompose_ends_at_default_gateset,
)
from cirq.testing.consistent_phase_by import (
assert_phase_by_is_consistent_with_unitary,
Expand Down Expand Up @@ -55,6 +56,7 @@ def assert_implements_consistent_protocols(
setup_code: str = 'import cirq\nimport numpy as np\nimport sympy',
global_vals: Optional[Dict[str, Any]] = None,
local_vals: Optional[Dict[str, Any]] = None,
ignore_decompose_to_default_gateset: bool = False,
) -> None:
"""Checks that a value is internally consistent and has a good __repr__."""
global_vals = global_vals or {}
Expand All @@ -66,6 +68,7 @@ def assert_implements_consistent_protocols(
setup_code=setup_code,
global_vals=global_vals,
local_vals=local_vals,
ignore_decompose_to_default_gateset=ignore_decompose_to_default_gateset,
)

for exponent in exponents:
Expand All @@ -77,6 +80,7 @@ def assert_implements_consistent_protocols(
setup_code=setup_code,
global_vals=global_vals,
local_vals=local_vals,
ignore_decompose_to_default_gateset=ignore_decompose_to_default_gateset,
)


Expand All @@ -90,6 +94,7 @@ def assert_eigengate_implements_consistent_protocols(
setup_code: str = 'import cirq\nimport numpy as np\nimport sympy',
global_vals: Optional[Dict[str, Any]] = None,
local_vals: Optional[Dict[str, Any]] = None,
ignore_decompose_to_default_gateset: bool = False,
) -> None:
"""Checks that an EigenGate subclass is internally consistent and has a
good __repr__."""
Expand All @@ -105,6 +110,7 @@ def assert_eigengate_implements_consistent_protocols(
setup_code=setup_code,
global_vals=global_vals,
local_vals=local_vals,
ignore_decompose_to_default_gateset=ignore_decompose_to_default_gateset,
)


Expand Down Expand Up @@ -143,6 +149,7 @@ def _assert_meets_standards_helper(
setup_code: str,
global_vals: Optional[Dict[str, Any]],
local_vals: Optional[Dict[str, Any]],
ignore_decompose_to_default_gateset: bool,
) -> None:
__tracebackhide__ = True # pylint: disable=unused-variable

Expand All @@ -154,6 +161,8 @@ def _assert_meets_standards_helper(
assert_qasm_is_consistent_with_unitary(val)
assert_has_consistent_trace_distance_bound(val)
assert_decompose_is_consistent_with_unitary(val, ignoring_global_phase=ignoring_global_phase)
if not ignore_decompose_to_default_gateset:
assert_decompose_ends_at_default_gateset(val)
assert_phase_by_is_consistent_with_unitary(val)
assert_pauli_expansion_is_consistent_with_unitary(val)
assert_equivalent_repr(
Expand Down
11 changes: 6 additions & 5 deletions cirq-core/cirq/testing/consistent_protocols_test.py
Expand Up @@ -65,9 +65,6 @@ def _decompose_(self, qubits: Sequence[cirq.Qid]) -> cirq.OP_TREE:
q = qubits[0]
z = cirq.Z(q) ** self.phase_exponent
x = cirq.X(q) ** self.exponent
if cirq.is_parameterized(z):
# coverage: ignore
return NotImplemented
return z ** -1, x, z

def _pauli_expansion_(self) -> cirq.LinearDict[str]:
Expand Down Expand Up @@ -260,12 +257,16 @@ def test_assert_implements_consistent_protocols():

def test_assert_eigengate_implements_consistent_protocols():
cirq.testing.assert_eigengate_implements_consistent_protocols(
GoodEigenGate, global_vals={'GoodEigenGate': GoodEigenGate}
GoodEigenGate,
global_vals={'GoodEigenGate': GoodEigenGate},
ignore_decompose_to_default_gateset=True,
)

with pytest.raises(AssertionError):
cirq.testing.assert_eigengate_implements_consistent_protocols(
BadEigenGate, global_vals={'BadEigenGate': BadEigenGate}
BadEigenGate,
global_vals={'BadEigenGate': BadEigenGate},
ignore_decompose_to_default_gateset=True,
)


Expand Down
Expand Up @@ -25,6 +25,7 @@ def test_consistent_protocols():
gate,
setup_code='import cirq\nimport numpy as np\nimport sympy\nimport cirq_google',
qubit_count=2,
ignore_decompose_to_default_gateset=True,
)
assert gate.num_qubits() == 2

Expand Down

0 comments on commit d3c4853

Please sign in to comment.