Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,24 @@ def decompose_to_target_gateset(self, op: 'cirq.Operation', moment_idx: int) ->
old_2q_gate_count = sum(1 for o in ops.flatten_to_ops(old_optree) if len(o.qubits) == 2)
new_2q_gate_count = sum(1 for o in ops.flatten_to_ops(new_optree) if len(o.qubits) == 2)
switch_to_new = (
any(op not in self for op in ops.flatten_to_ops(old_optree))
any(
protocols.num_qubits(op) == 2 and op not in self
for op in ops.flatten_to_ops(old_optree)
)
Comment on lines +173 to +176
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this part of the check before constructing new_optree so we don't waste time on it if it's not needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will have to construct the new_optree irrespective of whether this check returns a True or False.

We will not switch to the new_optree only if both the checks are false (i.e. all 2q operations in old_optree are in self and old_2q_gate_count <= new_2q_gate_count.

or new_2q_gate_count < old_2q_gate_count
)
return new_optree if switch_to_new else old_optree
if switch_to_new:
return new_optree
mapped_old_optree: List['cirq.OP_TREE'] = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is different from the old behavior: if all old ops (1q or 2q) are in the target gateset and new_2q_gate_count >= old_2q_gate_count, this previously would have returned the old optree. Is this intended?

(If gateset behavior would prevent this case from occurring, a comment explaining this would be helpful)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this previously would have returned the old optree.

The behavior is unchanged, i.e., if all old ops (1q or 2q) are in the target gateset and new_2q_gate_count >= old_2q_gate_count; it would have previously returned old_optree and it would still return old_optree.

Notice that the mapped_old_optree will be same as old_optree if all operations are in self (the first if condition in the loop below).

for old_op in ops.flatten_to_ops(old_optree):
if old_op in self:
mapped_old_optree.append(old_op)
else:
decomposed_op = self._decompose_single_qubit_operation(old_op, moment_idx)
if decomposed_op is None or decomposed_op is NotImplemented:
return NotImplemented
mapped_old_optree.append(decomposed_op)
return mapped_old_optree

def _decompose_single_qubit_operation(
self, op: 'cirq.Operation', moment_idx: int
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,23 @@ def test_two_qubit_compilation_merge_and_replace_inefficient_component():
m: ═════════════════════════════════════════════════════@═══^═══
''',
)


def test_two_qubit_compilation_replaces_only_if_2q_gate_count_is_less():
class DummyTargetGateset(cirq.TwoQubitCompilationTargetGateset):
def __init__(self):
super().__init__(cirq.X, cirq.CNOT)

def _decompose_two_qubit_operation(self, op: 'cirq.Operation', _) -> DecomposeResult:
q0, q1 = op.qubits
return [cirq.X.on_each(q0, q1), cirq.CNOT(q0, q1)] * 10

def _decompose_single_qubit_operation(self, op: 'cirq.Operation', _) -> DecomposeResult:
return cirq.X(*op.qubits) if op.gate == cirq.Y else NotImplemented

q = cirq.LineQubit.range(2)
ops = [cirq.Y.on_each(*q), cirq.CNOT(*q), cirq.Z.on_each(*q)]
c_orig = cirq.Circuit(ops)
c_expected = cirq.Circuit(cirq.X.on_each(*q), ops[-2:])
c_new = cirq.optimize_for_target_gateset(c_orig, gateset=DummyTargetGateset())
cirq.testing.assert_same_circuits(c_new, c_expected)