Skip to content

Commit

Permalink
Add default decomposition for cirq.QubitPermutationGate in terms of a…
Browse files Browse the repository at this point in the history
…djacent swaps (#5093)

- Adds decomposition to `cirq.QubitPermutationGate` in terms of minimum number of adjacent swap operations on qubits. 
- Part of #4858

Closes #5090
  • Loading branch information
tanujkhattar committed Mar 18, 2022
1 parent e0f7432 commit 848bfde
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
23 changes: 21 additions & 2 deletions cirq-core/cirq/ops/permutation_gate.py
Expand Up @@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING
from typing import Any, Dict, Iterable, Sequence, Tuple, TYPE_CHECKING

from cirq import protocols, value
from cirq._compat import deprecated
from cirq.ops import raw_types
from cirq.ops import raw_types, swap_gates

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -74,6 +74,25 @@ def num_qubits(self):
def _has_unitary_(self):
return True

def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
n = len(qubits)
qubit_ids = [*range(n)]
is_sorted = False

def _swap_if_out_of_order(idx: int) -> Iterable['cirq.Operation']:
nonlocal is_sorted
if self._permutation[qubit_ids[idx]] > self._permutation[qubit_ids[idx + 1]]:
yield swap_gates.SWAP(qubits[idx], qubits[idx + 1])
qubit_ids[idx + 1], qubit_ids[idx] = qubit_ids[idx], qubit_ids[idx + 1]
is_sorted = False

while not is_sorted:
is_sorted = True
for i in range(0, n - 1, 2):
yield from _swap_if_out_of_order(i)
for i in range(1, n - 1, 2):
yield from _swap_if_out_of_order(i)

def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'):
# Compute the permutation index list.
permuted_axes = list(range(len(args.target_tensor.shape)))
Expand Down
11 changes: 9 additions & 2 deletions cirq-core/cirq/ops/permutation_gate_test.py
Expand Up @@ -15,6 +15,7 @@
import pytest

import cirq
import numpy as np
from cirq.ops import QubitPermutationGate


Expand All @@ -30,8 +31,12 @@ def test_permutation_gate_repr():
cirq.testing.assert_equivalent_repr(QubitPermutationGate([0, 1]))


def test_permutation_gate_consistent_protocols():
gate = QubitPermutationGate([1, 0, 2, 3])
rs = np.random.RandomState(seed=1234)


@pytest.mark.parametrize('permutation', [rs.permutation(i) for i in range(3, 7)])
def test_permutation_gate_consistent_protocols(permutation):
gate = QubitPermutationGate(list(permutation))
cirq.testing.assert_implements_consistent_protocols(gate)


Expand Down Expand Up @@ -98,6 +103,8 @@ def test_permutation_gate_maps(maps, permutation):
permutationOp = cirq.QubitPermutationGate(permutation).on(*qs)
circuit = cirq.Circuit(permutationOp)
cirq.testing.assert_equivalent_computational_basis_map(maps, circuit)
circuit = cirq.Circuit(cirq.I.on_each(*qs), cirq.decompose(permutationOp))
cirq.testing.assert_equivalent_computational_basis_map(maps, circuit)


def test_setters_deprecated():
Expand Down

0 comments on commit 848bfde

Please sign in to comment.