From 848bfde611182be49496a3e56772d18d7ed3299e Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 18 Mar 2022 06:04:13 +0530 Subject: [PATCH] Add default decomposition for cirq.QubitPermutationGate in terms of adjacent swaps (#5093) - Adds decomposition to `cirq.QubitPermutationGate` in terms of minimum number of adjacent swap operations on qubits. - Part of https://github.com/quantumlib/Cirq/issues/4858 Closes https://github.com/quantumlib/Cirq/issues/5090 --- cirq-core/cirq/ops/permutation_gate.py | 23 +++++++++++++++++++-- cirq-core/cirq/ops/permutation_gate_test.py | 11 ++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/ops/permutation_gate.py b/cirq-core/cirq/ops/permutation_gate.py index 188e82f4da2..919e84a3a62 100644 --- a/cirq-core/cirq/ops/permutation_gate.py +++ b/cirq-core/cirq/ops/permutation_gate.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING +from typing import Any, Dict, Iterable, Sequence, Tuple, TYPE_CHECKING from cirq import protocols, value from cirq._compat import deprecated -from cirq.ops import raw_types +from cirq.ops import raw_types, swap_gates if TYPE_CHECKING: import cirq @@ -74,6 +74,25 @@ def num_qubits(self): def _has_unitary_(self): return True + def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': + n = len(qubits) + qubit_ids = [*range(n)] + is_sorted = False + + def _swap_if_out_of_order(idx: int) -> Iterable['cirq.Operation']: + nonlocal is_sorted + if self._permutation[qubit_ids[idx]] > self._permutation[qubit_ids[idx + 1]]: + yield swap_gates.SWAP(qubits[idx], qubits[idx + 1]) + qubit_ids[idx + 1], qubit_ids[idx] = qubit_ids[idx], qubit_ids[idx + 1] + is_sorted = False + + while not is_sorted: + is_sorted = True + for i in range(0, n - 1, 2): + yield from _swap_if_out_of_order(i) + for i in range(1, n - 1, 2): + yield from _swap_if_out_of_order(i) + def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'): # Compute the permutation index list. permuted_axes = list(range(len(args.target_tensor.shape))) diff --git a/cirq-core/cirq/ops/permutation_gate_test.py b/cirq-core/cirq/ops/permutation_gate_test.py index ed0b27c1d3b..7b74992f2d7 100644 --- a/cirq-core/cirq/ops/permutation_gate_test.py +++ b/cirq-core/cirq/ops/permutation_gate_test.py @@ -15,6 +15,7 @@ import pytest import cirq +import numpy as np from cirq.ops import QubitPermutationGate @@ -30,8 +31,12 @@ def test_permutation_gate_repr(): cirq.testing.assert_equivalent_repr(QubitPermutationGate([0, 1])) -def test_permutation_gate_consistent_protocols(): - gate = QubitPermutationGate([1, 0, 2, 3]) +rs = np.random.RandomState(seed=1234) + + +@pytest.mark.parametrize('permutation', [rs.permutation(i) for i in range(3, 7)]) +def test_permutation_gate_consistent_protocols(permutation): + gate = QubitPermutationGate(list(permutation)) cirq.testing.assert_implements_consistent_protocols(gate) @@ -98,6 +103,8 @@ def test_permutation_gate_maps(maps, permutation): permutationOp = cirq.QubitPermutationGate(permutation).on(*qs) circuit = cirq.Circuit(permutationOp) cirq.testing.assert_equivalent_computational_basis_map(maps, circuit) + circuit = cirq.Circuit(cirq.I.on_each(*qs), cirq.decompose(permutationOp)) + cirq.testing.assert_equivalent_computational_basis_map(maps, circuit) def test_setters_deprecated():