Skip to content

Commit

Permalink
Update cirq.decompose protocol to perform a DFS instead of a BFS on t…
Browse files Browse the repository at this point in the history
…he decomposed OP-TREE (#6116)

* update cirq.decompose protocol to perform a DFS instead of a BFS on the decomposed OP-TREE

* Fix mypy error

* Use a dataclass instead of kwargs
  • Loading branch information
tanujkhattar committed Jun 1, 2023
1 parent ebec38b commit b1e09a9
Showing 1 changed file with 64 additions and 111 deletions.
175 changes: 64 additions & 111 deletions cirq-core/cirq/protocols/decompose_protocol.py
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
overload,
Expand Down Expand Up @@ -128,6 +129,60 @@ def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> DecomposeResult:
pass


def _try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult:
if decomposer is None or not isinstance(val, ops.Operation):
return None
return decomposer(val)


@dataclasses.dataclass(frozen=True)
class _DecomposeArgs:
intercepting_decomposer: Optional[OpDecomposer]
fallback_decomposer: Optional[OpDecomposer]
keep: Optional[Callable[['cirq.Operation'], bool]]
on_stuck_raise: Union[None, Exception, Callable[['cirq.Operation'], Optional[Exception]]]
preserve_structure: bool


def _decompose_dfs(item: Any, args: _DecomposeArgs) -> Iterator['cirq.Operation']:
from cirq.circuits import CircuitOperation, FrozenCircuit

if isinstance(item, ops.Operation):
item_untagged = item.untagged
if args.preserve_structure and isinstance(item_untagged, CircuitOperation):
new_fc = FrozenCircuit(_decompose_dfs(item_untagged.circuit, args))
yield item_untagged.replace(circuit=new_fc).with_tags(*item.tags)
return
if args.keep is not None and args.keep(item):
yield item
return

decomposed = _try_op_decomposer(item, args.intercepting_decomposer)

if decomposed is NotImplemented or decomposed is None:
decomposed = decompose_once(item, default=None)

if decomposed is NotImplemented or decomposed is None:
decomposed = _try_op_decomposer(item, args.fallback_decomposer)

if decomposed is NotImplemented or decomposed is None:
if not isinstance(item, ops.Operation) and isinstance(item, Iterable):
decomposed = item

if decomposed is NotImplemented or decomposed is None:
if args.keep is not None and args.on_stuck_raise is not None:
if isinstance(args.on_stuck_raise, Exception):
raise args.on_stuck_raise
elif callable(args.on_stuck_raise):
error = args.on_stuck_raise(item)
if error is not None:
raise error
yield item
else:
for val in ops.flatten_to_ops(decomposed):
yield from _decompose_dfs(val, args)


def decompose(
val: Any,
*,
Expand Down Expand Up @@ -200,55 +255,14 @@ def decompose(
"acceptable to keep."
)

if preserve_structure:
return _decompose_preserving_structure(
val,
intercepting_decomposer=intercepting_decomposer,
fallback_decomposer=fallback_decomposer,
keep=keep,
on_stuck_raise=on_stuck_raise,
)

def try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult:
if decomposer is None or not isinstance(val, ops.Operation):
return None
return decomposer(val)

output = []
queue: List[Any] = [val]
while queue:
item = queue.pop(0)
if isinstance(item, ops.Operation) and keep is not None and keep(item):
output.append(item)
continue

decomposed = try_op_decomposer(item, intercepting_decomposer)

if decomposed is NotImplemented or decomposed is None:
decomposed = decompose_once(item, default=None)

if decomposed is NotImplemented or decomposed is None:
decomposed = try_op_decomposer(item, fallback_decomposer)

if decomposed is not NotImplemented and decomposed is not None:
queue[:0] = ops.flatten_to_ops(decomposed)
continue

if not isinstance(item, ops.Operation) and isinstance(item, Iterable):
queue[:0] = ops.flatten_to_ops(item)
continue

if keep is not None and on_stuck_raise is not None:
if isinstance(on_stuck_raise, Exception):
raise on_stuck_raise
elif callable(on_stuck_raise):
error = on_stuck_raise(item)
if error is not None:
raise error

output.append(item)

return output
args = _DecomposeArgs(
intercepting_decomposer=intercepting_decomposer,
fallback_decomposer=fallback_decomposer,
keep=keep,
on_stuck_raise=on_stuck_raise,
preserve_structure=preserve_structure,
)
return [*_decompose_dfs(val, args)]


# pylint: disable=function-redefined
Expand Down Expand Up @@ -383,65 +397,4 @@ def _try_decompose_into_operations_and_qubits(
qid_shape_dict[q] = max(qid_shape_dict[q], level)
qubits = sorted(qubit_set)
return result, qubits, tuple(qid_shape_dict[q] for q in qubits)

return None, (), ()


def _decompose_preserving_structure(
val: Any,
*,
intercepting_decomposer: Optional[OpDecomposer] = None,
fallback_decomposer: Optional[OpDecomposer] = None,
keep: Optional[Callable[['cirq.Operation'], bool]] = None,
on_stuck_raise: Union[
None, Exception, Callable[['cirq.Operation'], Optional[Exception]]
] = _value_error_describing_bad_operation,
) -> List['cirq.Operation']:
"""Preserves structure (e.g. subcircuits) while decomposing ops.
This can be used to reduce a circuit to a particular gateset without
increasing its serialization size. See tests for examples.
"""

# This method provides a generated 'keep' to its decompose() calls.
# If the user-provided keep is not set, on_stuck_raise must be unset to
# ensure that failure to decompose does not generate errors.
on_stuck_raise = on_stuck_raise if keep is not None else None

from cirq.circuits import CircuitOperation, FrozenCircuit

visited_fcs = set()

def keep_structure(op: 'cirq.Operation'):
circuit = getattr(op.untagged, 'circuit', None)
if circuit is not None:
return circuit in visited_fcs
if keep is not None and keep(op):
return True

def dps_interceptor(op: 'cirq.Operation'):
if not isinstance(op.untagged, CircuitOperation):
if intercepting_decomposer is None:
return NotImplemented
return intercepting_decomposer(op)

new_fc = FrozenCircuit(
decompose(
op.untagged.circuit,
intercepting_decomposer=dps_interceptor,
fallback_decomposer=fallback_decomposer,
keep=keep_structure,
on_stuck_raise=on_stuck_raise,
)
)
visited_fcs.add(new_fc)
new_co = op.untagged.replace(circuit=new_fc)
return new_co if not op.tags else new_co.with_tags(*op.tags)

return decompose(
val,
intercepting_decomposer=dps_interceptor,
fallback_decomposer=fallback_decomposer,
keep=keep_structure,
on_stuck_raise=on_stuck_raise,
)

0 comments on commit b1e09a9

Please sign in to comment.