Skip to content

Commit

Permalink
Replace isinstance(op, GateOperation) checks in cirq_google optimiz…
Browse files Browse the repository at this point in the history
…ers to support other operation types. (#4459)

Proliferation of `isinstance(op, GateOperation)` checks results in many inconsistencies due to different available operation types like `ControlledOperations` and `TaggedOperations`. This PR fixes #4152 and is a first step towards fixing #3556 

Note that `TaggedOperations` which were earlier ignored by the optimizers would now be considered, and hence this is potentially a breaking change if people were implicitly relying on TaggedOperations not getting compiled by the optimizers. Since the optimizer doesn't document / test this behavior, I consider it to be a bug rather than a feature and an explicit `NoCompile` tag should be implemented as part of #4253 


This PR is blocked on submitting #4167 (tests will stop failing once the PR is submitted and this rebased). 

Update: This is now ready for review.
  • Loading branch information
tanujkhattar committed Oct 12, 2021
1 parent bd2e63c commit 9f3034c
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 20 deletions.
2 changes: 1 addition & 1 deletion cirq-google/cirq_google/devices/xmon_device.py
Expand Up @@ -98,7 +98,7 @@ def validate_gate(self, gate: cirq.Gate):
raise ValueError(f'Unsupported gate type: {gate!r}')

def validate_operation(self, operation: cirq.Operation):
if not isinstance(operation, cirq.GateOperation):
if operation.gate is None:
raise ValueError(f'Unsupported operation: {operation!r}')

self.validate_gate(operation.gate)
Expand Down
45 changes: 32 additions & 13 deletions cirq-google/cirq_google/devices/xmon_device_test.py
Expand Up @@ -153,20 +153,39 @@ def test_validate_operation_existing_qubits():
d.validate_operation(cirq.CZ(cirq.GridQubit(1, 0), cirq.GridQubit(1, 1)))


def test_validate_operation_supported_gate():
class MyGate(cirq.Gate):
def num_qubits(self):
return 1


q = cirq.GridQubit.rect(1, 3)
matrix_gate = cirq.MatrixGate(cirq.testing.random_unitary(2))


@pytest.mark.parametrize(
'op,is_valid',
[
(cirq.Z(cirq.GridQubit(0, 0)), True),
(cirq.Z(cirq.GridQubit(0, 0)).with_tags('test_tag'), True),
(
cirq.Z(cirq.GridQubit(0, 0)).with_tags('test_tag').controlled_by(cirq.GridQubit(0, 1)),
True,
),
(
cirq.Z(cirq.GridQubit(0, 0)).controlled_by(cirq.GridQubit(0, 1)).with_tags('test_tag'),
True,
),
(NotImplementedOperation(), False),
(MyGate()(cirq.GridQubit(0, 0)), False),
],
)
def test_validate_operation_supported_gate(op, is_valid):
d = square_device(3, 3)

class MyGate(cirq.Gate):
def num_qubits(self):
return 1

d.validate_operation(cirq.GateOperation(cirq.Z, [cirq.GridQubit(0, 0)]))

assert MyGate().num_qubits() == 1
with pytest.raises(ValueError):
d.validate_operation(cirq.GateOperation(MyGate(), [cirq.GridQubit(0, 0)]))
with pytest.raises(ValueError):
d.validate_operation(NotImplementedOperation())
if is_valid:
d.validate_operation(op)
else:
with pytest.raises(ValueError):
d.validate_operation(op)


def test_validate_circuit_repeat_measurement_keys():
Expand Down
Expand Up @@ -113,7 +113,7 @@ def _convert_one(self, op: cirq.Operation) -> cirq.OP_TREE:
"""
if len(op.qubits) == 1:
return _phased_x_z_ops(cirq.unitary(op, None), op.qubits[0])
elif len(op.qubits) == 2 and isinstance(op, cirq.GateOperation):
elif len(op.qubits) == 2:
return known_two_q_operations_to_sycamore_operations(
op.qubits[0], op.qubits[1], op, self.tabulation
)
Expand All @@ -139,7 +139,7 @@ def on_stuck_raise(bad):
def optimization_at(
self, circuit: cirq.Circuit, index: int, op: cirq.Operation
) -> Optional[cirq.PointOptimizationSummary]:
if not isinstance(op, cirq.GateOperation):
if op.gate is None:
return None

gate = op.gate
Expand All @@ -151,7 +151,7 @@ def optimization_at(
next_index = circuit.next_moment_operating_on(op.qubits, index + 1)
if next_index is not None:
ops_in_front = list({circuit.operation_at(q, next_index) for q in op.qubits})
if len(ops_in_front) == 1 and isinstance(ops_in_front[0], cirq.GateOperation):
if len(ops_in_front) == 1 and ops_in_front[0] is not None:
gate2 = ops_in_front[0].gate
else:
next_index = 0
Expand Down
Expand Up @@ -278,3 +278,23 @@ def test_sycamore_invalid_tabulation():
sycamore_tabulation = {}
with pytest.raises(ValueError):
cgoc.ConvertToSycamoreGates(sycamore_tabulation)


q = cirq.GridQubit.rect(1, 3)
matrix_gate = cirq.MatrixGate(cirq.testing.random_unitary(2))


@pytest.mark.parametrize(
'op, is_valid',
[
(cirq.CircuitOperation(cirq.FrozenCircuit(matrix_gate(q[0]))), False),
(matrix_gate(q[0]), True),
(matrix_gate(q[0]).with_tags('test_tags'), True),
(matrix_gate(q[0]).controlled_by(q[1]), True),
(matrix_gate(q[0]).controlled_by(q[1]).with_tags('test_tags'), True),
(matrix_gate(q[0]).with_tags('test_tags').controlled_by(q[1]), True),
],
)
def test_supported_operation(op, is_valid):
c = cirq.Circuit(op)
assert (cirq_google.ConvertToSycamoreGates().optimization_at(c, 0, op) is not None) == is_valid
5 changes: 4 additions & 1 deletion cirq-google/cirq_google/optimizers/convert_to_xmon_gates.py
Expand Up @@ -65,7 +65,7 @@ def _is_native_xmon_op(self, op: cirq.Operation) -> bool:
"""
from cirq_google.devices import XmonDevice

return isinstance(op, cirq.GateOperation) and XmonDevice.is_supported_gate(op.gate)
return op.gate is not None and XmonDevice.is_supported_gate(op.gate)

def convert(self, op: cirq.Operation) -> List[cirq.Operation]:
def on_stuck_raise(bad):
Expand All @@ -86,6 +86,9 @@ def on_stuck_raise(bad):
def optimization_at(
self, circuit: cirq.Circuit, index: int, op: cirq.Operation
) -> Optional[cirq.PointOptimizationSummary]:
if op.gate is None:
return None

converted = self.convert(op)
if len(converted) == 1 and converted[0] is op:
return None
Expand Down
23 changes: 21 additions & 2 deletions cirq-google/cirq_google/optimizers/convert_to_xmon_gates_test.py
Expand Up @@ -45,8 +45,27 @@ def test_avoids_infinite_cycle_when_matrix_available():
cirq.protocols.decompose(c)


q = cirq.GridQubit.rect(1, 3)
matrix_gate = cirq.MatrixGate(cirq.testing.random_unitary(2))


def test_bad_operation():
qubits = cirq.GridQubit.rect(1, 3)
c = cirq.Circuit(NonNativeGate().on(qubits[0]))
c = cirq.Circuit(NonNativeGate().on(q[0]))
with pytest.raises(TypeError):
cirq_google.ConvertToXmonGates().optimize_circuit(c)


@pytest.mark.parametrize(
'op, is_valid',
[
(cirq.CircuitOperation(cirq.FrozenCircuit(matrix_gate(q[0]))), False),
(matrix_gate(q[0]), True),
(matrix_gate(q[0]).with_tags('test_tags'), True),
(matrix_gate(q[0]).controlled_by(q[1]), True),
(matrix_gate(q[0]).controlled_by(q[1]).with_tags('test_tags'), True),
(matrix_gate(q[0]).with_tags('test_tags').controlled_by(q[1]), True),
],
)
def test_supported_operation(op, is_valid):
c = cirq.Circuit(op)
assert (cirq_google.ConvertToXmonGates().optimization_at(c, 0, op) is not None) == is_valid

0 comments on commit 9f3034c

Please sign in to comment.