Skip to content

Commit

Permalink
Add _decompose_with_context_ protocol to enable passing qubit manag…
Browse files Browse the repository at this point in the history
…er within decompose (#6118)

* Add _decompose_with_context_ protocol to enable passing qubit manager within decompose

* Add more test cases and use typing_extensions for runtime_checkable

* Fix lint and coverage tests

* another attempt to fix coverage

* Fix mypy type check
  • Loading branch information
tanujkhattar committed Jun 5, 2023
1 parent 9177708 commit 20b3d93
Show file tree
Hide file tree
Showing 14 changed files with 375 additions and 45 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Expand Up @@ -284,6 +284,7 @@
qft,
Qid,
QuantumFourierTransformGate,
QubitManager,
QubitOrder,
QubitOrderOrList,
QubitPermutationGate,
Expand Down Expand Up @@ -566,6 +567,7 @@
decompose,
decompose_once,
decompose_once_with_qubits,
DecompositionContext,
DEFAULT_RESOLVERS,
definitely_commutes,
equal_up_to_global_phase,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Expand Up @@ -120,6 +120,8 @@

from cirq.ops.controlled_operation import ControlledOperation

from cirq.ops.qubit_manager import BorrowableQubit, CleanQubit, QubitManager, SimpleQubitManager

from cirq.ops.qubit_order import QubitOrder

from cirq.ops.qubit_order_or_list import QubitOrderOrList
Expand Down
7 changes: 6 additions & 1 deletion cirq-core/cirq/ops/classically_controlled_operation.py
Expand Up @@ -105,7 +105,12 @@ def with_qubits(self, *new_qubits):
)

def _decompose_(self):
result = protocols.decompose_once(self._sub_operation, NotImplemented, flatten=False)
return self._decompose_with_context_()

def _decompose_with_context_(self, context: Optional['cirq.DecompositionContext'] = None):
result = protocols.decompose_once(
self._sub_operation, NotImplemented, flatten=False, context=context
)
if result is NotImplemented:
return NotImplemented

Expand Down
15 changes: 13 additions & 2 deletions cirq-core/cirq/ops/controlled_gate.py
Expand Up @@ -151,6 +151,11 @@ def _qid_shape_(self) -> Tuple[int, ...]:

def _decompose_(
self, qubits: Tuple['cirq.Qid', ...]
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
return self._decompose_with_context_(qubits)

def _decompose_with_context_(
self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None
) -> Union[None, NotImplementedType, 'cirq.OP_TREE']:
if (
protocols.has_unitary(self.sub_gate)
Expand Down Expand Up @@ -192,15 +197,21 @@ def _decompose_(
)
)
if self != controlled_z:
return protocols.decompose_once_with_qubits(controlled_z, qubits, NotImplemented)
return protocols.decompose_once_with_qubits(
controlled_z, qubits, NotImplemented, context=context
)

if isinstance(self.sub_gate, matrix_gates.MatrixGate):
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
# local phase in the controlled variant and hence cannot be ignored.
return NotImplemented

result = protocols.decompose_once_with_qubits(
self.sub_gate, qubits[self.num_controls() :], NotImplemented, flatten=False
self.sub_gate,
qubits[self.num_controls() :],
NotImplemented,
flatten=False,
context=context,
)
if result is NotImplemented:
return NotImplemented
Expand Down
9 changes: 7 additions & 2 deletions cirq-core/cirq/ops/controlled_operation.py
Expand Up @@ -146,8 +146,11 @@ def with_qubits(self, *new_qubits):
)

def _decompose_(self):
return self._decompose_with_context_()

def _decompose_with_context_(self, context: Optional['cirq.DecompositionContext'] = None):
result = protocols.decompose_once_with_qubits(
self.gate, self.qubits, NotImplemented, flatten=False
self.gate, self.qubits, NotImplemented, flatten=False, context=context
)
if result is not NotImplemented:
return result
Expand All @@ -157,7 +160,9 @@ def _decompose_(self):
# local phase in the controlled variant and hence cannot be ignored.
return NotImplemented

result = protocols.decompose_once(self.sub_operation, NotImplemented, flatten=False)
result = protocols.decompose_once(
self.sub_operation, NotImplemented, flatten=False, context=context
)
if result is NotImplemented:
return NotImplemented

Expand Down
7 changes: 6 additions & 1 deletion cirq-core/cirq/ops/gate_operation.py
Expand Up @@ -160,8 +160,13 @@ def _num_qubits_(self):
return len(self._qubits)

def _decompose_(self) -> 'cirq.OP_TREE':
return self._decompose_with_context_()

def _decompose_with_context_(
self, context: Optional['cirq.DecompositionContext'] = None
) -> 'cirq.OP_TREE':
return protocols.decompose_once_with_qubits(
self.gate, self.qubits, NotImplemented, flatten=False
self.gate, self.qubits, NotImplemented, flatten=False, context=context
)

def _pauli_expansion_(self) -> value.LinearDict[str]:
Expand Down
91 changes: 91 additions & 0 deletions cirq-core/cirq/ops/qubit_manager.py
@@ -0,0 +1,91 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import dataclasses
from typing import Iterable, List, TYPE_CHECKING
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq


class QubitManager(metaclass=abc.ABCMeta):
@abc.abstractmethod
def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']:
"""Allocate `n` clean qubits, i.e. qubits guaranteed to be in state |0>."""

@abc.abstractmethod
def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']:
"""Allocate `n` dirty qubits, i.e. the returned qubits can be in any state."""

@abc.abstractmethod
def qfree(self, qubits: Iterable['cirq.Qid']) -> None:
"""Free pre-allocated clean or dirty qubits managed by this qubit manager."""


@dataclasses.dataclass(frozen=True)
class _BaseAncillaQid(raw_types.Qid):
id: int
dim: int = 2

def _comparison_key(self) -> int:
return self.id

@property
def dimension(self) -> int:
return self.dim

def __repr__(self) -> str:
dim_str = f', dim={self.dim}' if self.dim != 2 else ''
return f"cirq.ops.{type(self).__name__}({self.id}{dim_str})"


class CleanQubit(_BaseAncillaQid):
"""An internal qid type that represents a clean ancilla allocation."""

def __str__(self) -> str:
dim_str = f' (d={self.dimension})' if self.dim != 2 else ''
return f"_c({self.id}){dim_str}"


class BorrowableQubit(_BaseAncillaQid):
"""An internal qid type that represents a dirty ancilla allocation."""

def __str__(self) -> str:
dim_str = f' (d={self.dimension})' if self.dim != 2 else ''
return f"_b({self.id}){dim_str}"


class SimpleQubitManager(QubitManager):
"""Allocates a new `CleanQubit`/`BorrowableQubit` for every `qalloc`/`qborrow` request."""

def __init__(self):
self._clean_id = 0
self._borrow_id = 0

def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']:
self._clean_id += n
return [CleanQubit(i, dim) for i in range(self._clean_id - n, self._clean_id)]

def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']:
self._borrow_id = self._borrow_id + n
return [BorrowableQubit(i, dim) for i in range(self._borrow_id - n, self._borrow_id)]

def qfree(self, qubits: Iterable['cirq.Qid']) -> None:
for q in qubits:
good = isinstance(q, CleanQubit) and q.id < self._clean_id
good |= isinstance(q, BorrowableQubit) and q.id < self._borrow_id
if not good:
raise ValueError(f"{q} was not allocated by {self}")
66 changes: 66 additions & 0 deletions cirq-core/cirq/ops/qubit_manager_test.py
@@ -0,0 +1,66 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq
from cirq.ops import qubit_manager as cqi
import pytest


def test_clean_qubits():
q = cqi.CleanQubit(1)
assert q.id == 1
assert q.dimension == 2
assert str(q) == '_c(1)'
assert repr(q) == 'cirq.ops.CleanQubit(1)'

q = cqi.CleanQubit(2, dim=3)
assert q.id == 2
assert q.dimension == 3
assert str(q) == '_c(2) (d=3)'
assert repr(q) == 'cirq.ops.CleanQubit(2, dim=3)'

assert cqi.CleanQubit(1) < cqi.CleanQubit(2)


def test_borrow_qubits():
q = cqi.BorrowableQubit(10)
assert q.id == 10
assert q.dimension == 2
assert str(q) == '_b(10)'
assert repr(q) == 'cirq.ops.BorrowableQubit(10)'

q = cqi.BorrowableQubit(20, dim=4)
assert q.id == 20
assert q.dimension == 4
assert str(q) == '_b(20) (d=4)'
assert repr(q) == 'cirq.ops.BorrowableQubit(20, dim=4)'

assert cqi.BorrowableQubit(1) < cqi.BorrowableQubit(2)


@pytest.mark.parametrize('_', range(2))
def test_simple_qubit_manager(_):
qm = cirq.ops.SimpleQubitManager()
assert qm.qalloc(1) == [cqi.CleanQubit(0)]
assert qm.qalloc(2) == [cqi.CleanQubit(1), cqi.CleanQubit(2)]
assert qm.qalloc(1, dim=3) == [cqi.CleanQubit(3, dim=3)]
assert qm.qborrow(1) == [cqi.BorrowableQubit(0)]
assert qm.qborrow(2) == [cqi.BorrowableQubit(1), cqi.BorrowableQubit(2)]
assert qm.qborrow(1, dim=3) == [cqi.BorrowableQubit(3, dim=3)]
qm.qfree([cqi.CleanQubit(i) for i in range(3)] + [cqi.CleanQubit(3, dim=3)])
qm.qfree([cqi.BorrowableQubit(i) for i in range(3)] + [cqi.BorrowableQubit(3, dim=3)])
with pytest.raises(ValueError, match="not allocated"):
qm.qfree([cqi.CleanQubit(10)])
with pytest.raises(ValueError, match="not allocated"):
qm.qfree([cqi.BorrowableQubit(10)])
18 changes: 16 additions & 2 deletions cirq-core/cirq/ops/raw_types.py
Expand Up @@ -830,7 +830,14 @@ def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['sub_operation', 'tags'])

def _decompose_(self) -> 'cirq.OP_TREE':
return protocols.decompose_once(self.sub_operation, default=None, flatten=False)
return self._decompose_with_context_()

def _decompose_with_context_(
self, context: Optional['cirq.DecompositionContext'] = None
) -> 'cirq.OP_TREE':
return protocols.decompose_once(
self.sub_operation, default=None, flatten=False, context=context
)

def _pauli_expansion_(self) -> value.LinearDict[str]:
return protocols.pauli_expansion(self.sub_operation)
Expand Down Expand Up @@ -979,7 +986,14 @@ def __pow__(self, power):
return NotImplemented

def _decompose_(self, qubits):
return protocols.inverse(protocols.decompose_once_with_qubits(self._original, qubits))
return self._decompose_with_context_(qubits)

def _decompose_with_context_(
self, qubits: Sequence['cirq.Qid'], context: Optional['cirq.DecompositionContext'] = None
) -> 'cirq.OP_TREE':
return protocols.inverse(
protocols.decompose_once_with_qubits(self._original, qubits, context=context)
)

def _has_unitary_(self):
from cirq import protocols, devices
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/raw_types_test.py
Expand Up @@ -201,6 +201,8 @@ def __repr__(self):
assert i**-1 == t
assert t**-1 == i
assert cirq.decompose(i) == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)]
assert [*i._decompose_()] == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)]
assert [*i.gate._decompose_([a, b])] == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)]
cirq.testing.assert_allclose_up_to_global_phase(
cirq.unitary(i), cirq.unitary(t).conj().T, atol=1e-8
)
Expand Down Expand Up @@ -618,6 +620,7 @@ def test_tagged_operation_forwards_protocols():
np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h))
assert cirq.has_unitary(tagged_h)
assert cirq.decompose(tagged_h) == cirq.decompose(h)
assert [*tagged_h._decompose_()] == cirq.decompose(h)
assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
assert cirq.equal_up_to_global_phase(h, tagged_h)
assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all()
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/__init__.py
Expand Up @@ -50,6 +50,7 @@
decompose,
decompose_once,
decompose_once_with_qubits,
DecompositionContext,
SupportsDecompose,
SupportsDecomposeWithQubits,
)
Expand Down

0 comments on commit 20b3d93

Please sign in to comment.