Skip to content

Commit

Permalink
Allow intercepting decomposer while preserving structure. (#4343)
Browse files Browse the repository at this point in the history
This allows users to specify an intercepting decomposer (i.e. one that runs before default decomposition behavior) and use structure-preserving behavior (i.e. keeping `CircuitOperation`s but decomposing their contents) in the same `decompose` call.

This is a prerequisite for allowing `CircuitOperation`s in `cirq.optimized_for_sycamore`.
  • Loading branch information
95-martin-orion committed Jul 21, 2021
1 parent 912110f commit 055db68
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 46 deletions.
35 changes: 8 additions & 27 deletions cirq-core/cirq/protocols/decompose_protocol.py
Expand Up @@ -119,30 +119,6 @@ def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> DecomposeResult:
pass


# pylint: disable=function-redefined
@overload
def decompose(
val: Any,
*,
intercepting_decomposer: Optional[OpDecomposer] = None,
fallback_decomposer: Optional[OpDecomposer] = None,
keep: Optional[Callable[['cirq.Operation'], bool]] = None,
) -> List['cirq.Operation']:
pass


@overload
def decompose(
val: Any,
*,
intercepting_decomposer: Optional[OpDecomposer] = None,
fallback_decomposer: Optional[OpDecomposer] = None,
keep: Optional[Callable[['cirq.Operation'], bool]] = None,
on_stuck_raise: Union[None, TError, Callable[['cirq.Operation'], Optional[TError]]],
) -> List['cirq.Operation']:
pass


def decompose(
val: Any,
*,
Expand Down Expand Up @@ -212,10 +188,9 @@ def decompose(
)

if preserve_structure:
if intercepting_decomposer is not None:
raise ValueError('Cannot specify intercepting_decomposer while preserving structure.')
return _decompose_preserving_structure(
val,
intercepting_decomposer=intercepting_decomposer,
fallback_decomposer=fallback_decomposer,
keep=keep,
on_stuck_raise=on_stuck_raise,
Expand Down Expand Up @@ -263,6 +238,9 @@ def try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> Decompose
return output


# pylint: disable=function-redefined


@overload
def decompose_once(val: Any, **kwargs) -> List['cirq.Operation']:
pass
Expand Down Expand Up @@ -401,6 +379,7 @@ def _try_decompose_into_operations_and_qubits(
def _decompose_preserving_structure(
val: Any,
*,
intercepting_decomposer: Optional[OpDecomposer] = None,
fallback_decomposer: Optional[OpDecomposer] = None,
keep: Optional[Callable[['cirq.Operation'], bool]] = None,
on_stuck_raise: Union[
Expand Down Expand Up @@ -431,7 +410,9 @@ def keep_structure(op: 'cirq.Operation'):

def dps_interceptor(op: 'cirq.Operation'):
if not isinstance(op.untagged, CircuitOperation):
return NotImplemented
if intercepting_decomposer is None:
return NotImplemented
return intercepting_decomposer(op)

new_fc = FrozenCircuit(
decompose(
Expand Down
24 changes: 5 additions & 19 deletions cirq-core/cirq/protocols/decompose_protocol_test.py
Expand Up @@ -250,7 +250,9 @@ def test_decompose_preserving_structure():
assert actual == expected


def test_decompose_preserving_structure_forwards_args():
# Test both intercepting and fallback decomposers.
@pytest.mark.parametrize('decompose_mode', ['intercept', 'fallback'])
def test_decompose_preserving_structure_forwards_args(decompose_mode):
a, b = cirq.LineQubit.range(2)
fc1 = cirq.FrozenCircuit(cirq.SWAP(a, b), cirq.FSimGate(0.1, 0.2).on(a, b))
cop1_1 = cirq.CircuitOperation(fc1).with_tags('test_tag')
Expand All @@ -276,7 +278,8 @@ def x_to_hzh(op: 'cirq.Operation'):
cirq.decompose(
circuit,
keep=keep_func,
fallback_decomposer=x_to_hzh,
intercepting_decomposer=x_to_hzh if decompose_mode == 'intercept' else None,
fallback_decomposer=x_to_hzh if decompose_mode == 'fallback' else None,
preserve_structure=True,
),
)
Expand All @@ -302,20 +305,3 @@ def x_to_hzh(op: 'cirq.Operation'):
cirq.measure(a, b, key='m'),
)
assert actual == expected


def test_decompose_preserving_structure_no_interceptor():
a, b = cirq.LineQubit.range(2)
fc1 = cirq.FrozenCircuit(cirq.SWAP(a, b), cirq.FSimGate(0.1, 0.2).on(a, b))
cop1_1 = cirq.CircuitOperation(fc1).with_tags('test_tag')
cop1_2 = cirq.CircuitOperation(fc1).with_qubit_mapping({a: b, b: a})
fc2 = cirq.FrozenCircuit(cirq.X(a), cop1_1, cop1_2)
cop2 = cirq.CircuitOperation(fc2)

circuit = cirq.Circuit(cop2, cirq.measure(a, b, key='m'))
with pytest.raises(ValueError, match='Cannot specify intercepting_decomposer'):
cirq.decompose(
circuit,
intercepting_decomposer=lambda x: [],
preserve_structure=True,
)

0 comments on commit 055db68

Please sign in to comment.