From 20b3d93edf54ffa4935c7594828a8c2b10ce09f4 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 5 Jun 2023 13:22:05 -0700 Subject: [PATCH] Add `_decompose_with_context_` protocol to enable passing qubit manager within decompose (#6118) * Add _decompose_with_context_ protocol to enable passing qubit manager within decompose * Add more test cases and use typing_extensions for runtime_checkable * Fix lint and coverage tests * another attempt to fix coverage * Fix mypy type check --- cirq-core/cirq/__init__.py | 2 + cirq-core/cirq/ops/__init__.py | 2 + .../ops/classically_controlled_operation.py | 7 +- cirq-core/cirq/ops/controlled_gate.py | 15 ++- cirq-core/cirq/ops/controlled_operation.py | 9 +- cirq-core/cirq/ops/gate_operation.py | 7 +- cirq-core/cirq/ops/qubit_manager.py | 91 +++++++++++++++ cirq-core/cirq/ops/qubit_manager_test.py | 66 +++++++++++ cirq-core/cirq/ops/raw_types.py | 18 ++- cirq-core/cirq/ops/raw_types_test.py | 3 + cirq-core/cirq/protocols/__init__.py | 1 + .../cirq/protocols/decompose_protocol.py | 105 +++++++++++++++--- .../cirq/protocols/decompose_protocol_test.py | 93 ++++++++++++---- .../cirq/protocols/json_test_data/spec.py | 1 + 14 files changed, 375 insertions(+), 45 deletions(-) create mode 100644 cirq-core/cirq/ops/qubit_manager.py create mode 100644 cirq-core/cirq/ops/qubit_manager_test.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 73b21c1cd9d..685f4209e8a 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -284,6 +284,7 @@ qft, Qid, QuantumFourierTransformGate, + QubitManager, QubitOrder, QubitOrderOrList, QubitPermutationGate, @@ -566,6 +567,7 @@ decompose, decompose_once, decompose_once_with_qubits, + DecompositionContext, DEFAULT_RESOLVERS, definitely_commutes, equal_up_to_global_phase, diff --git a/cirq-core/cirq/ops/__init__.py b/cirq-core/cirq/ops/__init__.py index 516ebbbb505..ae352645ce2 100644 --- a/cirq-core/cirq/ops/__init__.py +++ b/cirq-core/cirq/ops/__init__.py @@ -120,6 +120,8 @@ from cirq.ops.controlled_operation import ControlledOperation +from cirq.ops.qubit_manager import BorrowableQubit, CleanQubit, QubitManager, SimpleQubitManager + from cirq.ops.qubit_order import QubitOrder from cirq.ops.qubit_order_or_list import QubitOrderOrList diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 84be33d6bfe..63725fe6b23 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -105,7 +105,12 @@ def with_qubits(self, *new_qubits): ) def _decompose_(self): - result = protocols.decompose_once(self._sub_operation, NotImplemented, flatten=False) + return self._decompose_with_context_() + + def _decompose_with_context_(self, context: Optional['cirq.DecompositionContext'] = None): + result = protocols.decompose_once( + self._sub_operation, NotImplemented, flatten=False, context=context + ) if result is NotImplemented: return NotImplemented diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 8edd8a88e69..641b52fb9ff 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -151,6 +151,11 @@ def _qid_shape_(self) -> Tuple[int, ...]: def _decompose_( self, qubits: Tuple['cirq.Qid', ...] + ) -> Union[None, NotImplementedType, 'cirq.OP_TREE']: + return self._decompose_with_context_(qubits) + + def _decompose_with_context_( + self, qubits: Tuple['cirq.Qid', ...], context: Optional['cirq.DecompositionContext'] = None ) -> Union[None, NotImplementedType, 'cirq.OP_TREE']: if ( protocols.has_unitary(self.sub_gate) @@ -192,7 +197,9 @@ def _decompose_( ) ) if self != controlled_z: - return protocols.decompose_once_with_qubits(controlled_z, qubits, NotImplemented) + return protocols.decompose_once_with_qubits( + controlled_z, qubits, NotImplemented, context=context + ) if isinstance(self.sub_gate, matrix_gates.MatrixGate): # Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is @@ -200,7 +207,11 @@ def _decompose_( return NotImplemented result = protocols.decompose_once_with_qubits( - self.sub_gate, qubits[self.num_controls() :], NotImplemented, flatten=False + self.sub_gate, + qubits[self.num_controls() :], + NotImplemented, + flatten=False, + context=context, ) if result is NotImplemented: return NotImplemented diff --git a/cirq-core/cirq/ops/controlled_operation.py b/cirq-core/cirq/ops/controlled_operation.py index 4a7e35cc878..85c61cb298d 100644 --- a/cirq-core/cirq/ops/controlled_operation.py +++ b/cirq-core/cirq/ops/controlled_operation.py @@ -146,8 +146,11 @@ def with_qubits(self, *new_qubits): ) def _decompose_(self): + return self._decompose_with_context_() + + def _decompose_with_context_(self, context: Optional['cirq.DecompositionContext'] = None): result = protocols.decompose_once_with_qubits( - self.gate, self.qubits, NotImplemented, flatten=False + self.gate, self.qubits, NotImplemented, flatten=False, context=context ) if result is not NotImplemented: return result @@ -157,7 +160,9 @@ def _decompose_(self): # local phase in the controlled variant and hence cannot be ignored. return NotImplemented - result = protocols.decompose_once(self.sub_operation, NotImplemented, flatten=False) + result = protocols.decompose_once( + self.sub_operation, NotImplemented, flatten=False, context=context + ) if result is NotImplemented: return NotImplemented diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 0052e9d35a3..abf238eb04e 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -160,8 +160,13 @@ def _num_qubits_(self): return len(self._qubits) def _decompose_(self) -> 'cirq.OP_TREE': + return self._decompose_with_context_() + + def _decompose_with_context_( + self, context: Optional['cirq.DecompositionContext'] = None + ) -> 'cirq.OP_TREE': return protocols.decompose_once_with_qubits( - self.gate, self.qubits, NotImplemented, flatten=False + self.gate, self.qubits, NotImplemented, flatten=False, context=context ) def _pauli_expansion_(self) -> value.LinearDict[str]: diff --git a/cirq-core/cirq/ops/qubit_manager.py b/cirq-core/cirq/ops/qubit_manager.py new file mode 100644 index 00000000000..b7624cf789c --- /dev/null +++ b/cirq-core/cirq/ops/qubit_manager.py @@ -0,0 +1,91 @@ +# Copyright 2023 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import dataclasses +from typing import Iterable, List, TYPE_CHECKING +from cirq.ops import raw_types + +if TYPE_CHECKING: + import cirq + + +class QubitManager(metaclass=abc.ABCMeta): + @abc.abstractmethod + def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']: + """Allocate `n` clean qubits, i.e. qubits guaranteed to be in state |0>.""" + + @abc.abstractmethod + def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']: + """Allocate `n` dirty qubits, i.e. the returned qubits can be in any state.""" + + @abc.abstractmethod + def qfree(self, qubits: Iterable['cirq.Qid']) -> None: + """Free pre-allocated clean or dirty qubits managed by this qubit manager.""" + + +@dataclasses.dataclass(frozen=True) +class _BaseAncillaQid(raw_types.Qid): + id: int + dim: int = 2 + + def _comparison_key(self) -> int: + return self.id + + @property + def dimension(self) -> int: + return self.dim + + def __repr__(self) -> str: + dim_str = f', dim={self.dim}' if self.dim != 2 else '' + return f"cirq.ops.{type(self).__name__}({self.id}{dim_str})" + + +class CleanQubit(_BaseAncillaQid): + """An internal qid type that represents a clean ancilla allocation.""" + + def __str__(self) -> str: + dim_str = f' (d={self.dimension})' if self.dim != 2 else '' + return f"_c({self.id}){dim_str}" + + +class BorrowableQubit(_BaseAncillaQid): + """An internal qid type that represents a dirty ancilla allocation.""" + + def __str__(self) -> str: + dim_str = f' (d={self.dimension})' if self.dim != 2 else '' + return f"_b({self.id}){dim_str}" + + +class SimpleQubitManager(QubitManager): + """Allocates a new `CleanQubit`/`BorrowableQubit` for every `qalloc`/`qborrow` request.""" + + def __init__(self): + self._clean_id = 0 + self._borrow_id = 0 + + def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']: + self._clean_id += n + return [CleanQubit(i, dim) for i in range(self._clean_id - n, self._clean_id)] + + def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']: + self._borrow_id = self._borrow_id + n + return [BorrowableQubit(i, dim) for i in range(self._borrow_id - n, self._borrow_id)] + + def qfree(self, qubits: Iterable['cirq.Qid']) -> None: + for q in qubits: + good = isinstance(q, CleanQubit) and q.id < self._clean_id + good |= isinstance(q, BorrowableQubit) and q.id < self._borrow_id + if not good: + raise ValueError(f"{q} was not allocated by {self}") diff --git a/cirq-core/cirq/ops/qubit_manager_test.py b/cirq-core/cirq/ops/qubit_manager_test.py new file mode 100644 index 00000000000..56c32292b35 --- /dev/null +++ b/cirq-core/cirq/ops/qubit_manager_test.py @@ -0,0 +1,66 @@ +# Copyright 2023 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cirq +from cirq.ops import qubit_manager as cqi +import pytest + + +def test_clean_qubits(): + q = cqi.CleanQubit(1) + assert q.id == 1 + assert q.dimension == 2 + assert str(q) == '_c(1)' + assert repr(q) == 'cirq.ops.CleanQubit(1)' + + q = cqi.CleanQubit(2, dim=3) + assert q.id == 2 + assert q.dimension == 3 + assert str(q) == '_c(2) (d=3)' + assert repr(q) == 'cirq.ops.CleanQubit(2, dim=3)' + + assert cqi.CleanQubit(1) < cqi.CleanQubit(2) + + +def test_borrow_qubits(): + q = cqi.BorrowableQubit(10) + assert q.id == 10 + assert q.dimension == 2 + assert str(q) == '_b(10)' + assert repr(q) == 'cirq.ops.BorrowableQubit(10)' + + q = cqi.BorrowableQubit(20, dim=4) + assert q.id == 20 + assert q.dimension == 4 + assert str(q) == '_b(20) (d=4)' + assert repr(q) == 'cirq.ops.BorrowableQubit(20, dim=4)' + + assert cqi.BorrowableQubit(1) < cqi.BorrowableQubit(2) + + +@pytest.mark.parametrize('_', range(2)) +def test_simple_qubit_manager(_): + qm = cirq.ops.SimpleQubitManager() + assert qm.qalloc(1) == [cqi.CleanQubit(0)] + assert qm.qalloc(2) == [cqi.CleanQubit(1), cqi.CleanQubit(2)] + assert qm.qalloc(1, dim=3) == [cqi.CleanQubit(3, dim=3)] + assert qm.qborrow(1) == [cqi.BorrowableQubit(0)] + assert qm.qborrow(2) == [cqi.BorrowableQubit(1), cqi.BorrowableQubit(2)] + assert qm.qborrow(1, dim=3) == [cqi.BorrowableQubit(3, dim=3)] + qm.qfree([cqi.CleanQubit(i) for i in range(3)] + [cqi.CleanQubit(3, dim=3)]) + qm.qfree([cqi.BorrowableQubit(i) for i in range(3)] + [cqi.BorrowableQubit(3, dim=3)]) + with pytest.raises(ValueError, match="not allocated"): + qm.qfree([cqi.CleanQubit(10)]) + with pytest.raises(ValueError, match="not allocated"): + qm.qfree([cqi.BorrowableQubit(10)]) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index d78756fe0f2..b38b0c6bd66 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -830,7 +830,14 @@ def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['sub_operation', 'tags']) def _decompose_(self) -> 'cirq.OP_TREE': - return protocols.decompose_once(self.sub_operation, default=None, flatten=False) + return self._decompose_with_context_() + + def _decompose_with_context_( + self, context: Optional['cirq.DecompositionContext'] = None + ) -> 'cirq.OP_TREE': + return protocols.decompose_once( + self.sub_operation, default=None, flatten=False, context=context + ) def _pauli_expansion_(self) -> value.LinearDict[str]: return protocols.pauli_expansion(self.sub_operation) @@ -979,7 +986,14 @@ def __pow__(self, power): return NotImplemented def _decompose_(self, qubits): - return protocols.inverse(protocols.decompose_once_with_qubits(self._original, qubits)) + return self._decompose_with_context_(qubits) + + def _decompose_with_context_( + self, qubits: Sequence['cirq.Qid'], context: Optional['cirq.DecompositionContext'] = None + ) -> 'cirq.OP_TREE': + return protocols.inverse( + protocols.decompose_once_with_qubits(self._original, qubits, context=context) + ) def _has_unitary_(self): from cirq import protocols, devices diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index 81ab8f61935..844d44610c5 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -201,6 +201,8 @@ def __repr__(self): assert i**-1 == t assert t**-1 == i assert cirq.decompose(i) == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)] + assert [*i._decompose_()] == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)] + assert [*i.gate._decompose_([a, b])] == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)] cirq.testing.assert_allclose_up_to_global_phase( cirq.unitary(i), cirq.unitary(t).conj().T, atol=1e-8 ) @@ -618,6 +620,7 @@ def test_tagged_operation_forwards_protocols(): np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h)) assert cirq.has_unitary(tagged_h) assert cirq.decompose(tagged_h) == cirq.decompose(h) + assert [*tagged_h._decompose_()] == cirq.decompose(h) assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h) assert cirq.equal_up_to_global_phase(h, tagged_h) assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all() diff --git a/cirq-core/cirq/protocols/__init__.py b/cirq-core/cirq/protocols/__init__.py index 8e977a2432e..1ea0b16a126 100644 --- a/cirq-core/cirq/protocols/__init__.py +++ b/cirq-core/cirq/protocols/__init__.py @@ -50,6 +50,7 @@ decompose, decompose_once, decompose_once_with_qubits, + DecompositionContext, SupportsDecompose, SupportsDecomposeWithQubits, ) diff --git a/cirq-core/cirq/protocols/decompose_protocol.py b/cirq-core/cirq/protocols/decompose_protocol.py index 05cd8818540..2dd2c3cb935 100644 --- a/cirq-core/cirq/protocols/decompose_protocol.py +++ b/cirq-core/cirq/protocols/decompose_protocol.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import dataclasses +import inspect +from collections import defaultdict from typing import ( TYPE_CHECKING, Any, @@ -27,7 +29,7 @@ TypeVar, Union, ) -from collections import defaultdict +from typing_extensions import runtime_checkable from typing_extensions import Protocol @@ -46,7 +48,17 @@ RaiseTypeErrorIfNotProvided: Any = ([],) DecomposeResult = Union[None, NotImplementedType, 'cirq.OP_TREE'] -OpDecomposer = Callable[['cirq.Operation'], DecomposeResult] + + +@runtime_checkable +class OpDecomposerWithContext(Protocol): + def __call__( + self, __op: 'cirq.Operation', *, context: Optional['cirq.DecompositionContext'] = None + ) -> DecomposeResult: + ... + + +OpDecomposer = Union[Callable[['cirq.Operation'], DecomposeResult], OpDecomposerWithContext] DECOMPOSE_TARGET_GATESET = ops.Gateset( ops.XPowGate, @@ -62,6 +74,18 @@ def _value_error_describing_bad_operation(op: 'cirq.Operation') -> ValueError: return ValueError(f"Operation doesn't satisfy the given `keep` but can't be decomposed: {op!r}") +@dataclasses.dataclass(frozen=True) +class DecompositionContext: + """Stores common configurable options for decomposing composite gates into simpler operations. + + Args: + qubit_manager: A `cirq.QubitManager` instance to allocate clean / dirty ancilla qubits as + part of the decompose protocol. + """ + + qubit_manager: 'cirq.QubitManager' + + class SupportsDecompose(Protocol): """An object that can be decomposed into simpler operations. @@ -105,6 +129,11 @@ class SupportsDecompose(Protocol): def _decompose_(self) -> DecomposeResult: pass + def _decompose_with_context_( + self, *, context: Optional[DecompositionContext] = None + ) -> DecomposeResult: + pass + class SupportsDecomposeWithQubits(Protocol): """An object that can be decomposed into operations on given qubits. @@ -128,15 +157,27 @@ class SupportsDecomposeWithQubits(Protocol): def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> DecomposeResult: pass + def _decompose_with_context_( + self, qubits: Tuple['cirq.Qid', ...], *, context: Optional[DecompositionContext] = None + ) -> DecomposeResult: + pass + -def _try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult: +def _try_op_decomposer( + val: Any, decomposer: Optional[OpDecomposer], *, context: Optional[DecompositionContext] = None +) -> DecomposeResult: if decomposer is None or not isinstance(val, ops.Operation): return None - return decomposer(val) + if 'context' in inspect.signature(decomposer).parameters: + assert isinstance(decomposer, OpDecomposerWithContext) + return decomposer(val, context=context) + else: + return decomposer(val) @dataclasses.dataclass(frozen=True) class _DecomposeArgs: + context: Optional[DecompositionContext] intercepting_decomposer: Optional[OpDecomposer] fallback_decomposer: Optional[OpDecomposer] keep: Optional[Callable[['cirq.Operation'], bool]] @@ -157,13 +198,13 @@ def _decompose_dfs(item: Any, args: _DecomposeArgs) -> Iterator['cirq.Operation' yield item return - decomposed = _try_op_decomposer(item, args.intercepting_decomposer) + decomposed = _try_op_decomposer(item, args.intercepting_decomposer, context=args.context) if decomposed is NotImplemented or decomposed is None: - decomposed = decompose_once(item, default=None, flatten=False) + decomposed = decompose_once(item, default=None, flatten=False, context=args.context) if decomposed is NotImplemented or decomposed is None: - decomposed = _try_op_decomposer(item, args.fallback_decomposer) + decomposed = _try_op_decomposer(item, args.fallback_decomposer, context=args.context) if decomposed is NotImplemented or decomposed is None: if not isinstance(item, ops.Operation) and isinstance(item, Iterable): @@ -193,6 +234,7 @@ def decompose( None, Exception, Callable[['cirq.Operation'], Optional[Exception]] ] = _value_error_describing_bad_operation, preserve_structure: bool = False, + context: Optional[DecompositionContext] = None, ) -> List['cirq.Operation']: """Recursively decomposes a value into `cirq.Operation`s meeting a criteria. @@ -224,6 +266,8 @@ def decompose( preserve_structure: Prevents subcircuits (i.e. `CircuitOperation`s) from being decomposed, but decomposes their contents. If this is True, `intercepting_decomposer` cannot be specified. + context: Decomposition context specifying common configurable options for + controlling the behavior of decompose. Returns: A list of operations that the given value was decomposed into. If @@ -256,6 +300,7 @@ def decompose( ) args = _DecomposeArgs( + context=context, intercepting_decomposer=intercepting_decomposer, fallback_decomposer=fallback_decomposer, keep=keep, @@ -281,7 +326,12 @@ def decompose_once( def decompose_once( - val: Any, default=RaiseTypeErrorIfNotProvided, *args, flatten: bool = True, **kwargs + val: Any, + default=RaiseTypeErrorIfNotProvided, + *args, + flatten: bool = True, + context: Optional[DecompositionContext] = None, + **kwargs, ): """Decomposes a value into operations, if possible. @@ -299,6 +349,8 @@ def decompose_once( `val`. For example, this is used to tell gates what qubits they are being applied to. flatten: If True, the returned OP-TREE will be flattened to a list of operations. + context: Decomposition context specifying common configurable options for + controlling the behavior of decompose. **kwargs: Keyword arguments to forward into the `_decompose_` method of `val`. @@ -312,29 +364,47 @@ def decompose_once( TypeError: `val` didn't have a `_decompose_` method (or that method returned `NotImplemented` or `None`) and `default` wasn't set. """ - method = getattr(val, '_decompose_', None) - decomposed = NotImplemented if method is None else method(*args, **kwargs) + method = getattr(val, '_decompose_with_context_', None) + decomposed = NotImplemented if method is None else method(*args, **kwargs, context=context) + if decomposed is NotImplemented or None: + method = getattr(val, '_decompose_', None) + decomposed = NotImplemented if method is None else method(*args, **kwargs) + if decomposed is not NotImplemented and decomposed is not None: return list(ops.flatten_to_ops(decomposed)) if flatten else decomposed if default is not RaiseTypeErrorIfNotProvided: return default if method is None: - raise TypeError(f"object of type '{type(val)}' has no _decompose_ method.") + raise TypeError( + f"object of type '{type(val)}' has no _decompose_with_context_ or " + f"_decompose_ method." + ) raise TypeError( - "object of type '{}' does have a _decompose_ method, " - "but it returned NotImplemented or None.".format(type(val)) + f"object of type {type(val)} does have a _decompose_ method, " + "but it returned NotImplemented or None." ) @overload -def decompose_once_with_qubits(val: Any, qubits: Iterable['cirq.Qid']) -> List['cirq.Operation']: +def decompose_once_with_qubits( + val: Any, + qubits: Iterable['cirq.Qid'], + *, + flatten: bool = True, + context: Optional['DecompositionContext'] = None, +) -> List['cirq.Operation']: pass @overload def decompose_once_with_qubits( - val: Any, qubits: Iterable['cirq.Qid'], default: Optional[TDefault], flatten: bool = True + val: Any, + qubits: Iterable['cirq.Qid'], + default: Optional[TDefault], + *, + flatten: bool = True, + context: Optional['DecompositionContext'] = None, ) -> Union[TDefault, List['cirq.Operation']]: pass @@ -344,6 +414,7 @@ def decompose_once_with_qubits( qubits: Iterable['cirq.Qid'], default=RaiseTypeErrorIfNotProvided, flatten: bool = True, + context: Optional['DecompositionContext'] = None, ): """Decomposes a value into operations on the given qubits. @@ -361,6 +432,8 @@ def decompose_once_with_qubits( `None`. If not specified, non-decomposable values cause a `TypeError`. flatten: If True, the returned OP-TREE will be flattened to a list of operations. + context: Decomposition context specifying common configurable options for + controlling the behavior of decompose. Returns: The result of `val._decompose_(qubits)`, if `val` has a @@ -372,7 +445,7 @@ def decompose_once_with_qubits( `val` didn't have a `_decompose_` method (or that method returned `NotImplemented` or `None`) and `default` wasn't set. """ - return decompose_once(val, default, tuple(qubits), flatten=flatten) + return decompose_once(val, default, tuple(qubits), flatten=flatten, context=context) # pylint: enable=function-redefined diff --git a/cirq-core/cirq/protocols/decompose_protocol_test.py b/cirq-core/cirq/protocols/decompose_protocol_test.py index 071a96568c9..1049a652b65 100644 --- a/cirq-core/cirq/protocols/decompose_protocol_test.py +++ b/cirq-core/cirq/protocols/decompose_protocol_test.py @@ -61,7 +61,7 @@ def _decompose_(self, qids): def test_decompose_once(): # No default value results in descriptive error. - with pytest.raises(TypeError, match='no _decompose_ method'): + with pytest.raises(TypeError, match='no _decompose_with_context_ or _decompose_ method'): _ = cirq.decompose_once(NoMethod()) with pytest.raises(TypeError, match='returned NotImplemented or None'): _ = cirq.decompose_once(DecomposeNotImplemented()) @@ -90,7 +90,7 @@ def test_decompose_once_with_qubits(): qs = cirq.LineQubit.range(3) # No default value results in descriptive error. - with pytest.raises(TypeError, match='no _decompose_ method'): + with pytest.raises(TypeError, match='no _decompose_with_context_ or _decompose_ method'): _ = cirq.decompose_once_with_qubits(NoMethod(), qs) with pytest.raises(TypeError, match='returned NotImplemented or None'): _ = cirq.decompose_once_with_qubits(DecomposeNotImplemented(), qs) @@ -224,6 +224,23 @@ def test_decompose_intercept(): ) assert actual == [cirq.CNOT(a, b), cirq.CNOT(b, a), cirq.CNOT(a, b)] + # Accepts a context, when provided. + def _intercept_with_context( + op: cirq.Operation, context: Optional[cirq.DecompositionContext] = None + ): + assert context is not None + if op.gate == cirq.SWAP: + q = context.qubit_manager.qalloc(1) + a, b = op.qubits + return [cirq.X(a), cirq.X(*q), cirq.X(b)] + return NotImplemented + + context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) + actual = cirq.decompose( + cirq.SWAP(a, b), intercepting_decomposer=_intercept_with_context, context=context + ) + assert actual == [cirq.X(a), cirq.X(cirq.ops.CleanQubit(0)), cirq.X(b)] + def test_decompose_preserving_structure(): a, b = cirq.LineQubit.range(2) @@ -312,38 +329,72 @@ def test_decompose_tagged_operation(): assert cirq.decompose_once(op) == cirq.decompose_once(op.untagged) -def test_decompose_recursive_dfs(): - class RecursiveDecompose(cirq.Gate): - def __init__(self, recurse: bool = True, mock_qm: Optional[mock.Mock] = None): - self.recurse = recurse - self.mock_qm = mock.Mock() if mock_qm is None else mock_qm +class RecursiveDecompose(cirq.Gate): + def __init__( + self, + recurse: bool = True, + mock_qm=mock.Mock(spec=cirq.QubitManager), + with_context: bool = False, + ): + self.recurse = recurse + self.mock_qm = mock_qm + self.with_context = with_context + + def _num_qubits_(self) -> int: + return 2 - def _num_qubits_(self) -> int: - return 2 + def _decompose_impl(self, qubits, mock_qm: mock.Mock): + mock_qm.qalloc(self.recurse) + yield RecursiveDecompose( + recurse=False, mock_qm=self.mock_qm, with_context=self.with_context + ).on(*qubits) if self.recurse else cirq.Z.on_each(*qubits) + mock_qm.qfree(self.recurse) - def _decompose_(self, qubits): - self.mock_qm.qalloc(self.recurse) - yield RecursiveDecompose(recurse=False, mock_qm=self.mock_qm).on( - *qubits - ) if self.recurse else cirq.Z.on_each(*qubits) - self.mock_qm.qfree(self.recurse) + def _decompose_(self, qubits): + if self.with_context: + assert False + else: + return self._decompose_impl(qubits, self.mock_qm) + + def _decompose_with_context_(self, qubits, context): + if self.with_context: + qm = self.mock_qm if context is None else context.qubit_manager + return self._decompose_impl(qubits, qm) + else: + return NotImplemented + + def _has_unitary_(self): + return True - def _has_unitary_(self): - return True +@pytest.mark.parametrize('with_context', [True, False]) +def test_decompose_recursive_dfs(with_context: bool): expected_calls = [ mock.call.qalloc(True), mock.call.qalloc(False), mock.call.qfree(False), mock.call.qfree(True), ] - mock_qm = mock.Mock(spec=["qalloc", "qfree"]) + mock_qm = mock.Mock(spec=cirq.QubitManager) + context_qm = mock.Mock(spec=cirq.QubitManager) + gate = RecursiveDecompose(mock_qm=mock_qm, with_context=with_context) q = cirq.LineQubit.range(3) - gate = RecursiveDecompose(mock_qm=mock_qm) gate_op = gate.on(*q[:2]) + tagged_op = gate_op.with_tags("custom tag") controlled_op = gate_op.controlled_by(q[2]) classically_controlled_op = gate_op.with_classical_controls('key') - for op in [gate_op, controlled_op, classically_controlled_op]: + moment = cirq.Moment(gate_op) + circuit = cirq.Circuit(moment) + for val in [gate_op, tagged_op, controlled_op, classically_controlled_op, moment, circuit]: mock_qm.reset_mock() - _ = cirq.decompose(op) + _ = cirq.decompose(val) assert mock_qm.method_calls == expected_calls + + mock_qm.reset_mock() + context_qm.reset_mock() + _ = cirq.decompose(val, context=cirq.DecompositionContext(context_qm)) + assert ( + context_qm.method_calls == expected_calls + if with_context + else mock_qm.method_calls == expected_calls + ) diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index 589d03e57bd..7a92491267a 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -81,6 +81,7 @@ 'MEASUREMENT_KEY_SEPARATOR', 'PointOptimizer', # Transformers + 'DecompositionContext', 'TransformerLogger', 'TransformerContext', # Routing utilities