Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace isinstance(op, GateOperation) checks in cirq_google optimizers to support other operation types. #4459

Merged
merged 9 commits into from
Oct 12, 2021
2 changes: 1 addition & 1 deletion cirq-google/cirq_google/devices/xmon_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def validate_gate(self, gate: cirq.Gate):
raise ValueError(f'Unsupported gate type: {gate!r}')

def validate_operation(self, operation: cirq.Operation):
if not isinstance(operation, cirq.GateOperation):
if operation.gate is None:
raise ValueError(f'Unsupported operation: {operation!r}')

self.validate_gate(operation.gate)
Expand Down
45 changes: 32 additions & 13 deletions cirq-google/cirq_google/devices/xmon_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,39 @@ def test_validate_operation_existing_qubits():
d.validate_operation(cirq.CZ(cirq.GridQubit(1, 0), cirq.GridQubit(1, 1)))


def test_validate_operation_supported_gate():
class MyGate(cirq.Gate):
def num_qubits(self):
return 1


q = cirq.GridQubit.rect(1, 3)
matrix_gate = cirq.MatrixGate(cirq.testing.random_unitary(2))


@pytest.mark.parametrize(
'op,is_valid',
[
(cirq.Z(cirq.GridQubit(0, 0)), True),
(cirq.Z(cirq.GridQubit(0, 0)).with_tags('test_tag'), True),
(
cirq.Z(cirq.GridQubit(0, 0)).with_tags('test_tag').controlled_by(cirq.GridQubit(0, 1)),
True,
),
(
cirq.Z(cirq.GridQubit(0, 0)).controlled_by(cirq.GridQubit(0, 1)).with_tags('test_tag'),
True,
),
(NotImplementedOperation(), False),
(MyGate()(cirq.GridQubit(0, 0)), False),
],
)
def test_validate_operation_supported_gate(op, is_valid):
d = square_device(3, 3)

class MyGate(cirq.Gate):
def num_qubits(self):
return 1

d.validate_operation(cirq.GateOperation(cirq.Z, [cirq.GridQubit(0, 0)]))

assert MyGate().num_qubits() == 1
with pytest.raises(ValueError):
d.validate_operation(cirq.GateOperation(MyGate(), [cirq.GridQubit(0, 0)]))
with pytest.raises(ValueError):
d.validate_operation(NotImplementedOperation())
if is_valid:
d.validate_operation(op)
else:
with pytest.raises(ValueError):
d.validate_operation(op)


def test_validate_circuit_repeat_measurement_keys():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _convert_one(self, op: cirq.Operation) -> cirq.OP_TREE:
"""
if len(op.qubits) == 1:
return _phased_x_z_ops(cirq.unitary(op, None), op.qubits[0])
elif len(op.qubits) == 2 and isinstance(op, cirq.GateOperation):
elif len(op.qubits) == 2:
return known_two_q_operations_to_sycamore_operations(
op.qubits[0], op.qubits[1], op, self.tabulation
)
Expand All @@ -139,7 +139,7 @@ def on_stuck_raise(bad):
def optimization_at(
self, circuit: cirq.Circuit, index: int, op: cirq.Operation
) -> Optional[cirq.PointOptimizationSummary]:
if not isinstance(op, cirq.GateOperation):
if op.gate is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[No change needed] #4511 will modify this to support CircuitOperations, whose gate is None. Are we still okay with using isinstance to identify CircuitOperations after this PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we should not use isinstance to identify CircuitOperations, or else we will get into the same problem. For example, a controlled / tagged CircuitOperation would not get identified if use isinstance checks.

return None

gate = op.gate
Expand All @@ -151,7 +151,7 @@ def optimization_at(
next_index = circuit.next_moment_operating_on(op.qubits, index + 1)
if next_index is not None:
ops_in_front = list({circuit.operation_at(q, next_index) for q in op.qubits})
if len(ops_in_front) == 1 and isinstance(ops_in_front[0], cirq.GateOperation):
if len(ops_in_front) == 1 and ops_in_front[0] is not None:
gate2 = ops_in_front[0].gate
else:
next_index = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,23 @@ def test_sycamore_invalid_tabulation():
sycamore_tabulation = {}
with pytest.raises(ValueError):
cgoc.ConvertToSycamoreGates(sycamore_tabulation)


q = cirq.GridQubit.rect(1, 3)
matrix_gate = cirq.MatrixGate(cirq.testing.random_unitary(2))


@pytest.mark.parametrize(
'op, is_valid',
[
(cirq.CircuitOperation(cirq.FrozenCircuit(matrix_gate(q[0]))), False),
(matrix_gate(q[0]), True),
(matrix_gate(q[0]).with_tags('test_tags'), True),
(matrix_gate(q[0]).controlled_by(q[1]), True),
(matrix_gate(q[0]).controlled_by(q[1]).with_tags('test_tags'), True),
(matrix_gate(q[0]).with_tags('test_tags').controlled_by(q[1]), True),
],
)
def test_supported_operation(op, is_valid):
c = cirq.Circuit(op)
assert (cirq_google.ConvertToSycamoreGates().optimization_at(c, 0, op) is not None) == is_valid
5 changes: 4 additions & 1 deletion cirq-google/cirq_google/optimizers/convert_to_xmon_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _is_native_xmon_op(self, op: cirq.Operation) -> bool:
"""
from cirq_google.devices import XmonDevice

return isinstance(op, cirq.GateOperation) and XmonDevice.is_supported_gate(op.gate)
return op.gate is not None and XmonDevice.is_supported_gate(op.gate)

def convert(self, op: cirq.Operation) -> List[cirq.Operation]:
def on_stuck_raise(bad):
Expand All @@ -86,6 +86,9 @@ def on_stuck_raise(bad):
def optimization_at(
self, circuit: cirq.Circuit, index: int, op: cirq.Operation
) -> Optional[cirq.PointOptimizationSummary]:
if op.gate is None:
return None

converted = self.convert(op)
if len(converted) == 1 and converted[0] is op:
return None
Expand Down
23 changes: 21 additions & 2 deletions cirq-google/cirq_google/optimizers/convert_to_xmon_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,27 @@ def test_avoids_infinite_cycle_when_matrix_available():
cirq.protocols.decompose(c)


q = cirq.GridQubit.rect(1, 3)
matrix_gate = cirq.MatrixGate(cirq.testing.random_unitary(2))


def test_bad_operation():
qubits = cirq.GridQubit.rect(1, 3)
c = cirq.Circuit(NonNativeGate().on(qubits[0]))
c = cirq.Circuit(NonNativeGate().on(q[0]))
with pytest.raises(TypeError):
cirq_google.ConvertToXmonGates().optimize_circuit(c)


@pytest.mark.parametrize(
'op, is_valid',
[
(cirq.CircuitOperation(cirq.FrozenCircuit(matrix_gate(q[0]))), False),
(matrix_gate(q[0]), True),
(matrix_gate(q[0]).with_tags('test_tags'), True),
(matrix_gate(q[0]).controlled_by(q[1]), True),
(matrix_gate(q[0]).controlled_by(q[1]).with_tags('test_tags'), True),
(matrix_gate(q[0]).with_tags('test_tags').controlled_by(q[1]), True),
],
)
def test_supported_operation(op, is_valid):
c = cirq.Circuit(op)
assert (cirq_google.ConvertToXmonGates().optimization_at(c, 0, op) is not None) == is_valid