Skip to content

Commit

Permalink
Ensure that cirq.decompose traverses the yielded OP-TREE in dfs order…
Browse files Browse the repository at this point in the history
…ing (#6117)
  • Loading branch information
tanujkhattar committed Jun 2, 2023
1 parent b1e09a9 commit 99e8a13
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 26 deletions.
8 changes: 5 additions & 3 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Expand Up @@ -28,7 +28,7 @@
import sympy

from cirq import protocols, value
from cirq.ops import raw_types
from cirq.ops import op_tree, raw_types

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -105,11 +105,13 @@ def with_qubits(self, *new_qubits):
)

def _decompose_(self):
result = protocols.decompose_once(self._sub_operation, NotImplemented)
result = protocols.decompose_once(self._sub_operation, NotImplemented, flatten=False)
if result is NotImplemented:
return NotImplemented

return [ClassicallyControlledOperation(op, self._conditions) for op in result]
return op_tree.transform_op_tree(
result, lambda op: ClassicallyControlledOperation(op, self._conditions)
)

def _value_equality_values_(self):
return (frozenset(self._conditions), self._sub_operation)
Expand Down
22 changes: 14 additions & 8 deletions cirq-core/cirq/ops/controlled_gate.py
Expand Up @@ -28,7 +28,13 @@
import numpy as np

from cirq import protocols, value, _import
from cirq.ops import raw_types, controlled_operation as cop, matrix_gates, control_values as cv
from cirq.ops import (
raw_types,
controlled_operation as cop,
op_tree,
matrix_gates,
control_values as cv,
)
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
Expand Down Expand Up @@ -194,17 +200,17 @@ def _decompose_(
return NotImplemented

result = protocols.decompose_once_with_qubits(
self.sub_gate, qubits[self.num_controls() :], NotImplemented
self.sub_gate, qubits[self.num_controls() :], NotImplemented, flatten=False
)
if result is NotImplemented:
return NotImplemented

decomposed: List['cirq.Operation'] = []
for op in result:
decomposed.append(
op.controlled_by(*qubits[: self.num_controls()], control_values=self.control_values)
)
return decomposed
return op_tree.transform_op_tree(
result,
lambda op: op.controlled_by(
*qubits[: self.num_controls()], control_values=self.control_values
),
)

def on(self, *qubits: 'cirq.Qid') -> cop.ControlledOperation:
if len(qubits) == 0:
Expand Down
13 changes: 8 additions & 5 deletions cirq-core/cirq/ops/controlled_operation.py
Expand Up @@ -34,6 +34,7 @@
eigen_gate,
gate_operation,
matrix_gates,
op_tree,
raw_types,
control_values as cv,
)
Expand Down Expand Up @@ -145,7 +146,9 @@ def with_qubits(self, *new_qubits):
)

def _decompose_(self):
result = protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)
result = protocols.decompose_once_with_qubits(
self.gate, self.qubits, NotImplemented, flatten=False
)
if result is not NotImplemented:
return result

Expand All @@ -154,13 +157,13 @@ def _decompose_(self):
# local phase in the controlled variant and hence cannot be ignored.
return NotImplemented

result = protocols.decompose_once(self.sub_operation, NotImplemented)
result = protocols.decompose_once(self.sub_operation, NotImplemented, flatten=False)
if result is NotImplemented:
return NotImplemented

return [
op.controlled_by(*self.controls, control_values=self.control_values) for op in result
]
return op_tree.transform_op_tree(
result, lambda op: op.controlled_by(*self.controls, control_values=self.control_values)
)

def _value_equality_values_(self):
sorted_controls, expanded_cvals = tuple(
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/ops/gate_operation.py
Expand Up @@ -160,7 +160,9 @@ def _num_qubits_(self):
return len(self._qubits)

def _decompose_(self) -> 'cirq.OP_TREE':
return protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)
return protocols.decompose_once_with_qubits(
self.gate, self.qubits, NotImplemented, flatten=False
)

def _pauli_expansion_(self) -> value.LinearDict[str]:
getter = getattr(self.gate, '_pauli_expansion_', None)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/raw_types.py
Expand Up @@ -830,7 +830,7 @@ def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['sub_operation', 'tags'])

def _decompose_(self) -> 'cirq.OP_TREE':
return protocols.decompose_once(self.sub_operation, default=None)
return protocols.decompose_once(self.sub_operation, default=None, flatten=False)

def _pauli_expansion_(self) -> value.LinearDict[str]:
return protocols.pauli_expansion(self.sub_operation)
Expand Down
22 changes: 14 additions & 8 deletions cirq-core/cirq/protocols/decompose_protocol.py
Expand Up @@ -160,7 +160,7 @@ def _decompose_dfs(item: Any, args: _DecomposeArgs) -> Iterator['cirq.Operation'
decomposed = _try_op_decomposer(item, args.intercepting_decomposer)

if decomposed is NotImplemented or decomposed is None:
decomposed = decompose_once(item, default=None)
decomposed = decompose_once(item, default=None, flatten=False)

if decomposed is NotImplemented or decomposed is None:
decomposed = _try_op_decomposer(item, args.fallback_decomposer)
Expand Down Expand Up @@ -275,12 +275,14 @@ def decompose_once(val: Any, **kwargs) -> List['cirq.Operation']:

@overload
def decompose_once(
val: Any, default: TDefault, *args, **kwargs
val: Any, default: TDefault, *args, flatten: bool = True, **kwargs
) -> Union[TDefault, List['cirq.Operation']]:
pass


def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwargs):
def decompose_once(
val: Any, default=RaiseTypeErrorIfNotProvided, *args, flatten: bool = True, **kwargs
):
"""Decomposes a value into operations, if possible.
This method decomposes the value exactly once, instead of decomposing it
Expand All @@ -296,6 +298,7 @@ def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwarg
*args: Positional arguments to forward into the `_decompose_` method of
`val`. For example, this is used to tell gates what qubits they are
being applied to.
flatten: If True, the returned OP-TREE will be flattened to a list of operations.
**kwargs: Keyword arguments to forward into the `_decompose_` method of
`val`.
Expand All @@ -311,9 +314,8 @@ def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwarg
"""
method = getattr(val, '_decompose_', None)
decomposed = NotImplemented if method is None else method(*args, **kwargs)

if decomposed is not NotImplemented and decomposed is not None:
return list(ops.flatten_op_tree(decomposed))
return list(ops.flatten_to_ops(decomposed)) if flatten else decomposed

if default is not RaiseTypeErrorIfNotProvided:
return default
Expand All @@ -332,13 +334,16 @@ def decompose_once_with_qubits(val: Any, qubits: Iterable['cirq.Qid']) -> List['

@overload
def decompose_once_with_qubits(
val: Any, qubits: Iterable['cirq.Qid'], default: Optional[TDefault]
val: Any, qubits: Iterable['cirq.Qid'], default: Optional[TDefault], flatten: bool = True
) -> Union[TDefault, List['cirq.Operation']]:
pass


def decompose_once_with_qubits(
val: Any, qubits: Iterable['cirq.Qid'], default=RaiseTypeErrorIfNotProvided
val: Any,
qubits: Iterable['cirq.Qid'],
default=RaiseTypeErrorIfNotProvided,
flatten: bool = True,
):
"""Decomposes a value into operations on the given qubits.
Expand All @@ -355,6 +360,7 @@ def decompose_once_with_qubits(
`_decompose_` method or that method returns `NotImplemented` or
`None`. If not specified, non-decomposable values cause a
`TypeError`.
flatten: If True, the returned OP-TREE will be flattened to a list of operations.
Returns:
The result of `val._decompose_(qubits)`, if `val` has a
Expand All @@ -366,7 +372,7 @@ def decompose_once_with_qubits(
`val` didn't have a `_decompose_` method (or that method returned
`NotImplemented` or `None`) and `default` wasn't set.
"""
return decompose_once(val, default, tuple(qubits))
return decompose_once(val, default, tuple(qubits), flatten=flatten)


# pylint: enable=function-redefined
Expand Down
39 changes: 39 additions & 0 deletions cirq-core/cirq/protocols/decompose_protocol_test.py
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from unittest import mock
import pytest

import cirq
Expand Down Expand Up @@ -308,3 +310,40 @@ def test_decompose_tagged_operation():
'tag',
)
assert cirq.decompose_once(op) == cirq.decompose_once(op.untagged)


def test_decompose_recursive_dfs():
class RecursiveDecompose(cirq.Gate):
def __init__(self, recurse: bool = True, mock_qm: Optional[mock.Mock] = None):
self.recurse = recurse
self.mock_qm = mock.Mock() if mock_qm is None else mock_qm

def _num_qubits_(self) -> int:
return 2

def _decompose_(self, qubits):
self.mock_qm.qalloc(self.recurse)
yield RecursiveDecompose(recurse=False, mock_qm=self.mock_qm).on(
*qubits
) if self.recurse else cirq.Z.on_each(*qubits)
self.mock_qm.qfree(self.recurse)

def _has_unitary_(self):
return True

expected_calls = [
mock.call.qalloc(True),
mock.call.qalloc(False),
mock.call.qfree(False),
mock.call.qfree(True),
]
mock_qm = mock.Mock(spec=["qalloc", "qfree"])
q = cirq.LineQubit.range(3)
gate = RecursiveDecompose(mock_qm=mock_qm)
gate_op = gate.on(*q[:2])
controlled_op = gate_op.controlled_by(q[2])
classically_controlled_op = gate_op.with_classical_controls('key')
for op in [gate_op, controlled_op, classically_controlled_op]:
mock_qm.reset_mock()
_ = cirq.decompose(op)
assert mock_qm.method_calls == expected_calls

0 comments on commit 99e8a13

Please sign in to comment.