Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that cirq.decompose traverses the yielded OP-TREE in dfs ordering #6117

Merged
merged 3 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import sympy

from cirq import protocols, value
from cirq.ops import raw_types
from cirq.ops import op_tree, raw_types

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -105,11 +105,13 @@ def with_qubits(self, *new_qubits):
)

def _decompose_(self):
result = protocols.decompose_once(self._sub_operation, NotImplemented)
result = protocols.decompose_once(self._sub_operation, NotImplemented, flatten=False)
if result is NotImplemented:
return NotImplemented

return [ClassicallyControlledOperation(op, self._conditions) for op in result]
return op_tree.transform_op_tree(
result, lambda op: ClassicallyControlledOperation(op, self._conditions)
)

def _value_equality_values_(self):
return (frozenset(self._conditions), self._sub_operation)
Expand Down
22 changes: 14 additions & 8 deletions cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
import numpy as np

from cirq import protocols, value, _import
from cirq.ops import raw_types, controlled_operation as cop, matrix_gates, control_values as cv
from cirq.ops import (
raw_types,
controlled_operation as cop,
op_tree,
matrix_gates,
control_values as cv,
)
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
Expand Down Expand Up @@ -194,17 +200,17 @@ def _decompose_(
return NotImplemented

result = protocols.decompose_once_with_qubits(
self.sub_gate, qubits[self.num_controls() :], NotImplemented
self.sub_gate, qubits[self.num_controls() :], NotImplemented, flatten=False
)
if result is NotImplemented:
return NotImplemented

decomposed: List['cirq.Operation'] = []
for op in result:
decomposed.append(
op.controlled_by(*qubits[: self.num_controls()], control_values=self.control_values)
)
return decomposed
return op_tree.transform_op_tree(
result,
lambda op: op.controlled_by(
*qubits[: self.num_controls()], control_values=self.control_values
),
)

def on(self, *qubits: 'cirq.Qid') -> cop.ControlledOperation:
if len(qubits) == 0:
Expand Down
13 changes: 8 additions & 5 deletions cirq-core/cirq/ops/controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
eigen_gate,
gate_operation,
matrix_gates,
op_tree,
raw_types,
control_values as cv,
)
Expand Down Expand Up @@ -145,7 +146,9 @@ def with_qubits(self, *new_qubits):
)

def _decompose_(self):
result = protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)
result = protocols.decompose_once_with_qubits(
self.gate, self.qubits, NotImplemented, flatten=False
)
if result is not NotImplemented:
return result

Expand All @@ -154,13 +157,13 @@ def _decompose_(self):
# local phase in the controlled variant and hence cannot be ignored.
return NotImplemented

result = protocols.decompose_once(self.sub_operation, NotImplemented)
result = protocols.decompose_once(self.sub_operation, NotImplemented, flatten=False)
if result is NotImplemented:
return NotImplemented

return [
op.controlled_by(*self.controls, control_values=self.control_values) for op in result
]
return op_tree.transform_op_tree(
result, lambda op: op.controlled_by(*self.controls, control_values=self.control_values)
)

def _value_equality_values_(self):
sorted_controls, expanded_cvals = tuple(
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def _num_qubits_(self):
return len(self._qubits)

def _decompose_(self) -> 'cirq.OP_TREE':
return protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)
return protocols.decompose_once_with_qubits(
self.gate, self.qubits, NotImplemented, flatten=False
)

def _pauli_expansion_(self) -> value.LinearDict[str]:
getter = getattr(self.gate, '_pauli_expansion_', None)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['sub_operation', 'tags'])

def _decompose_(self) -> 'cirq.OP_TREE':
return protocols.decompose_once(self.sub_operation, default=None)
return protocols.decompose_once(self.sub_operation, default=None, flatten=False)

def _pauli_expansion_(self) -> value.LinearDict[str]:
return protocols.pauli_expansion(self.sub_operation)
Expand Down
22 changes: 14 additions & 8 deletions cirq-core/cirq/protocols/decompose_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _decompose_dfs(item: Any, args: _DecomposeArgs) -> Iterator['cirq.Operation'
decomposed = _try_op_decomposer(item, args.intercepting_decomposer)

if decomposed is NotImplemented or decomposed is None:
decomposed = decompose_once(item, default=None)
decomposed = decompose_once(item, default=None, flatten=False)

if decomposed is NotImplemented or decomposed is None:
decomposed = _try_op_decomposer(item, args.fallback_decomposer)
Expand Down Expand Up @@ -275,12 +275,14 @@ def decompose_once(val: Any, **kwargs) -> List['cirq.Operation']:

@overload
def decompose_once(
val: Any, default: TDefault, *args, **kwargs
val: Any, default: TDefault, *args, flatten: bool = True, **kwargs
) -> Union[TDefault, List['cirq.Operation']]:
pass


def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwargs):
def decompose_once(
val: Any, default=RaiseTypeErrorIfNotProvided, *args, flatten: bool = True, **kwargs
):
"""Decomposes a value into operations, if possible.

This method decomposes the value exactly once, instead of decomposing it
Expand All @@ -296,6 +298,7 @@ def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwarg
*args: Positional arguments to forward into the `_decompose_` method of
`val`. For example, this is used to tell gates what qubits they are
being applied to.
flatten: If True, the returned OP-TREE will be flattened to a list of operations.
**kwargs: Keyword arguments to forward into the `_decompose_` method of
`val`.

Expand All @@ -311,9 +314,8 @@ def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwarg
"""
method = getattr(val, '_decompose_', None)
decomposed = NotImplemented if method is None else method(*args, **kwargs)

if decomposed is not NotImplemented and decomposed is not None:
return list(ops.flatten_op_tree(decomposed))
return list(ops.flatten_to_ops(decomposed)) if flatten else decomposed

if default is not RaiseTypeErrorIfNotProvided:
return default
Expand All @@ -332,13 +334,16 @@ def decompose_once_with_qubits(val: Any, qubits: Iterable['cirq.Qid']) -> List['

@overload
def decompose_once_with_qubits(
val: Any, qubits: Iterable['cirq.Qid'], default: Optional[TDefault]
val: Any, qubits: Iterable['cirq.Qid'], default: Optional[TDefault], flatten: bool = True
) -> Union[TDefault, List['cirq.Operation']]:
pass


def decompose_once_with_qubits(
val: Any, qubits: Iterable['cirq.Qid'], default=RaiseTypeErrorIfNotProvided
val: Any,
qubits: Iterable['cirq.Qid'],
default=RaiseTypeErrorIfNotProvided,
flatten: bool = True,
):
"""Decomposes a value into operations on the given qubits.

Expand All @@ -355,6 +360,7 @@ def decompose_once_with_qubits(
`_decompose_` method or that method returns `NotImplemented` or
`None`. If not specified, non-decomposable values cause a
`TypeError`.
flatten: If True, the returned OP-TREE will be flattened to a list of operations.

Returns:
The result of `val._decompose_(qubits)`, if `val` has a
Expand All @@ -366,7 +372,7 @@ def decompose_once_with_qubits(
`val` didn't have a `_decompose_` method (or that method returned
`NotImplemented` or `None`) and `default` wasn't set.
"""
return decompose_once(val, default, tuple(qubits))
return decompose_once(val, default, tuple(qubits), flatten=flatten)


# pylint: enable=function-redefined
Expand Down
39 changes: 39 additions & 0 deletions cirq-core/cirq/protocols/decompose_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from unittest import mock
import pytest

import cirq
Expand Down Expand Up @@ -308,3 +310,40 @@ def test_decompose_tagged_operation():
'tag',
)
assert cirq.decompose_once(op) == cirq.decompose_once(op.untagged)


def test_decompose_recursive_dfs():
class RecursiveDecompose(cirq.Gate):
def __init__(self, recurse: bool = True, mock_qm: Optional[mock.Mock] = None):
self.recurse = recurse
self.mock_qm = mock.Mock() if mock_qm is None else mock_qm
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see few places in Cirq where stubbing is used, lets not normalize that. Stubs can go horribly wrong => tests that are smoke tests, change detectors or worse tests that pass ragardless of the change.

lets do the test in a better way.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this specific case, I think its okay to use stubs since we want to test the ordering in which cirq.decompose will traverse the yielded OP-TREE and this test asserts that specifically. I don't care about the behavior of a Qubit Manager in this test -- it will become relevant when we add a qubit manager and that would be the time to add an integration test.

I could modify the test that asserts the traversal order via some other mechanism (eg: keep a global counter and append items to a list whenever the qalloc / qfree statements are executed) but the mock is essentially doing the same thing so I don't see a reason to replace it with some other equivalent mechanism.

While I generally appreciate the concern for not overusing mocks to write tests, I think the use here is well justified. I'll be curious to hear any alternative suggestions you have in mind for rewriting the test though.


def _num_qubits_(self) -> int:
return 2

def _decompose_(self, qubits):
self.mock_qm.qalloc(self.recurse)
yield RecursiveDecompose(recurse=False, mock_qm=self.mock_qm).on(
*qubits
) if self.recurse else cirq.Z.on_each(*qubits)
self.mock_qm.qfree(self.recurse)

def _has_unitary_(self):
return True

expected_calls = [
mock.call.qalloc(True),
mock.call.qalloc(False),
mock.call.qfree(False),
mock.call.qfree(True),
]
mock_qm = mock.Mock(spec=["qalloc", "qfree"])
q = cirq.LineQubit.range(3)
gate = RecursiveDecompose(mock_qm=mock_qm)
gate_op = gate.on(*q[:2])
controlled_op = gate_op.controlled_by(q[2])
classically_controlled_op = gate_op.with_classical_controls('key')
for op in [gate_op, controlled_op, classically_controlled_op]:
mock_qm.reset_mock()
_ = cirq.decompose(op)
assert mock_qm.method_calls == expected_calls