Skip to content

Commit

Permalink
Refactor and speed up cirq.transformers.stratify (#6013)
Browse files Browse the repository at this point in the history
* refactor cirq.transformers.stratify

* fix one failing test

* nit renaming

* fix coverage

* formatting fix

* pylint fix

* fix bug with measurements in stratification

* add missing import

* fix bug with finding time index for op

* fix test, and nit change to keeping track of time indices

* hopefully fix measurement bug

* minor fix with ignored ops

* fix test

* only store shortest circuit found

* minor bugfix

* nit typing fix

* one more silly bugfig

* store shortest stratified circuit properly

* fix bug with overlapping measurements

* clean up handling of ignored ops

* further clean up logic deciding where to put an op

* factor out logic for finding earliest accomodating moment

* fix typo

* fix typo

* remove unnecesaary use of defaultdict

* further simplify logic in get_earliest_accommodating_moment_index

* fix lint check

* fix minor bug

* nit docstring update

* separately update qubit/mkey/ckey moments in cirq.stratify

* fix bug with max only getting one argument

* fix coverage check

* fix typo

* fix lint check

---------

Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
  • Loading branch information
perlinm and tanujkhattar committed Apr 3, 2023
1 parent 663d404 commit 6a97cca
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 121 deletions.
129 changes: 86 additions & 43 deletions cirq-core/cirq/circuits/circuit.py
Expand Up @@ -1776,12 +1776,10 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
Non-moment entries will be inserted according to the EARLIEST
insertion strategy.
"""
# These are dicts from the qubit/key to the greatest moment index that has it. It is safe
# to default to `-1`, as that is interpreted as meaning the zeroth index onward does not
# have this value.
qubit_indexes: Dict['cirq.Qid', int] = defaultdict(lambda: -1)
mkey_indexes: Dict['cirq.MeasurementKey', int] = defaultdict(lambda: -1)
ckey_indexes: Dict['cirq.MeasurementKey', int] = defaultdict(lambda: -1)
# These are dicts from the qubit/key to the greatest moment index that has it.
qubit_indices: Dict['cirq.Qid', int] = {}
mkey_indices: Dict['cirq.MeasurementKey', int] = {}
ckey_indices: Dict['cirq.MeasurementKey', int] = {}

# We also maintain the dict from moment index to moments/ops that go into it, for use when
# building the actual moments at the end.
Expand All @@ -1793,46 +1791,17 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):

# "mop" means current moment-or-operation
for mop in ops.flatten_to_ops_or_moments(contents):
mop_qubits = mop.qubits
mop_mkeys = protocols.measurement_key_objs(mop)
mop_ckeys = protocols.control_keys(mop)

# Both branches define `i`, the moment index at which to place the mop.
# Identify the index of the moment to place this `mop` into.
placement_index = get_earliest_accommodating_moment_index(
mop, qubit_indices, mkey_indices, ckey_indices, length
)
length = max(length, placement_index + 1) # update the length of the circuit thus far

if isinstance(mop, Moment):
# We always append moment to the end, to be consistent with `self.append`
i = length
moments_by_index[i] = mop
moments_by_index[placement_index] = mop
else:
# Initially we define `i` as the greatest moment index that has a conflict. `-1` is
# the initial conflict, and we search for larger ones. Once we get the largest one,
# we increment i by 1 to set the placement index.
i = -1

# Look for the maximum conflict; i.e. a moment that has a qubit the same as one of
# this op's qubits, that has a measurement or control key the same as one of this
# op's measurement keys, or that has a measurement key the same as one of this op's
# control keys. (Control keys alone can commute past each other). The `ifs` are
# logically unnecessary but seem to make this slightly faster.
if mop_qubits:
i = max(i, *[qubit_indexes[q] for q in mop_qubits])
if mop_mkeys:
i = max(i, *[mkey_indexes[k] for k in mop_mkeys])
i = max(i, *[ckey_indexes[k] for k in mop_mkeys])
if mop_ckeys:
i = max(i, *[mkey_indexes[k] for k in mop_ckeys])
i += 1
op_lists_by_index[i].append(mop)

# Update our dicts with data from the latest mop placement. Note `i` will always be
# greater than the existing value for all of these, by construction, so there is no
# need to do a `max(i, existing)`.
for q in mop_qubits:
qubit_indexes[q] = i
for k in mop_mkeys:
mkey_indexes[k] = i
for k in mop_ckeys:
ckey_indexes[k] = i
length = max(length, i + 1)
op_lists_by_index[placement_index].append(mop)

# Finally, once everything is placed, we can construct and append the actual moments for
# each index.
Expand Down Expand Up @@ -2753,3 +2722,77 @@ def _group_until_different(items: Iterable[_TIn], key: Callable[[_TIn], _TKey],
Tuples containing the group key and item values.
"""
return ((k, [val(i) for i in v]) for (k, v) in itertools.groupby(items, key))


def get_earliest_accommodating_moment_index(
moment_or_operation: Union['cirq.Moment', 'cirq.Operation'],
qubit_indices: Dict['cirq.Qid', int],
mkey_indices: Dict['cirq.MeasurementKey', int],
ckey_indices: Dict['cirq.MeasurementKey', int],
length: Optional[int] = None,
) -> int:
"""Get the index of the earliest moment that can accomodate the given moment or operation.
Updates the dictionaries keeping track of the last moment index addressing a given qubit,
measurement key, and control key.
Args:
moment_or_operation: The moment operation in question.
qubit_indices: A dictionary mapping qubits to the latest moments that address them.
mkey_indices: A dictionary mapping measureent keys to the latest moments that address them.
ckey_indices: A dictionary mapping control keys to the latest moments that address them.
length: The length of the circuit that we are trying to insert a moment or operation into.
Should probably be equal to the maximum of the values in `qubit_indices`,
`mkey_indices`, and `ckey_indices`.
Returns:
The integer index of the earliest moment that can accomodate the given moment or operation.
"""
mop_qubits = moment_or_operation.qubits
mop_mkeys = protocols.measurement_key_objs(moment_or_operation)
mop_ckeys = protocols.control_keys(moment_or_operation)

if isinstance(moment_or_operation, Moment):
# For consistency with `Circuit.append`, moments always get placed at the end of a circuit.
if length is not None:
last_conflict = length - 1
else:
last_conflict = max(
[*qubit_indices.values(), *mkey_indices.values(), *ckey_indices.values(), -1]
)

else:
# We start by searching for the `latest_conflict` moment index, which we will increment by
# `1` to identify the earliest moment that *does not* conflict with the given operation.
# The `latest_conflict` is initialized to `-1` before searching for later conflicting
# moments.
last_conflict = -1

# Look for the maximum conflict; i.e. a moment that has a qubit the same as one of this op's
# qubits, that has a measurement or control key the same as one of this op's measurement
# keys, or that has a measurement key the same as one of this op's control keys. (Control
# keys alone can commute past each other). The `ifs` are logically unnecessary but seem to
# make this slightly faster.
if mop_qubits:
last_conflict = max(
last_conflict, *[qubit_indices.get(qubit, -1) for qubit in mop_qubits]
)
if mop_mkeys:
last_conflict = max(last_conflict, *[mkey_indices.get(key, -1) for key in mop_mkeys])
last_conflict = max(last_conflict, *[ckey_indices.get(key, -1) for key in mop_mkeys])
if mop_ckeys:
last_conflict = max(last_conflict, *[mkey_indices.get(key, -1) for key in mop_ckeys])

# The index of the moment to place this moment or operaton ("mop") into.
mop_index = last_conflict + 1

# Update our dicts with data from this `mop` placement. Note `mop_index` will always be greater
# than the existing value for all of these, by construction.
for qubit in mop_qubits:
qubit_indices[qubit] = mop_index
for key in mop_mkeys:
mkey_indices[key] = mop_index
for key in mop_ckeys:
ckey_indices[key] = mop_index

return mop_index
11 changes: 11 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Expand Up @@ -834,6 +834,17 @@ def test_insert_moment():
assert c.operation_at(qubit, actual_index) == operation[0]


def test_circuit_length_inference():
# tests that `get_earliest_accommodating_moment_index` properly computes circuit length
circuit = cirq.Circuit(cirq.X(cirq.q(0)))
qubit_indices = {cirq.q(0): 0}
mkey_indices = {}
ckey_indices = {}
assert circuits.circuit.get_earliest_accommodating_moment_index(
cirq.Moment(), qubit_indices, mkey_indices, ckey_indices
) == len(circuit)


def test_insert_inline_near_start():
a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
Expand Down
186 changes: 124 additions & 62 deletions cirq-core/cirq/transformers/stratify.py
Expand Up @@ -15,10 +15,10 @@
"""Transformer pass to repack circuits avoiding simultaneous operations with different classes."""

import itertools
from typing import TYPE_CHECKING, Type, Callable, Optional, Union, Iterable, Sequence, List, Tuple
from typing import TYPE_CHECKING, Type, Callable, Dict, Optional, Union, Iterable, Sequence, List

from cirq import ops, circuits, _import
from cirq.transformers import transformer_api, transformer_primitives
from cirq import ops, circuits, protocols, _import
from cirq.transformers import transformer_api

drop_empty_moments = _import.LazyLoader('drop_empty_moments', globals(), 'cirq.transformers')

Expand Down Expand Up @@ -61,38 +61,36 @@ def stratified_circuit(
Returns:
A copy of the original circuit, but with re-arranged operations.
"""

# Normalize categories into classifier functions.
classifiers = [_category_to_classifier(category) for category in categories]
# Make the classifiers exhaustive by adding an "everything else" bucket.
and_the_rest = lambda op: all(not classifier(op) for classifier in classifiers)
classifiers_and_the_rest = [*classifiers, and_the_rest]
classifiers = _get_classifiers(circuit, categories)

# Try the algorithm with each permutation of the classifiers.
classifiers_permutations = list(itertools.permutations(classifiers_and_the_rest))
smallest_depth = protocols.num_qubits(circuit) * len(circuit) + 1
shortest_stratified_circuit = circuits.Circuit()
reversed_circuit = circuit[::-1]
solutions = []
for c in classifiers_permutations:
solutions.append(
_stratify_circuit(
circuit,
classifiers=list(c),
context=context or transformer_api.TransformerContext(),
)
for ordered_classifiers in itertools.permutations(classifiers):
solution = _stratify_circuit(
circuit,
classifiers=ordered_classifiers,
context=context or transformer_api.TransformerContext(),
)
if len(solution) < smallest_depth:
shortest_stratified_circuit = solution
smallest_depth = len(solution)

# Do the same thing, except this time in reverse. This helps for some
# circuits because it inserts operations at the end instead of at the
# beginning.
solutions.append(
_stratify_circuit(
reversed_circuit,
classifiers=list(c),
context=context or transformer_api.TransformerContext(),
)[::-1]
)
solution = _stratify_circuit(
reversed_circuit,
classifiers=ordered_classifiers,
context=context or transformer_api.TransformerContext(),
)[::-1]
if len(solution) < smallest_depth:
shortest_stratified_circuit = solution
smallest_depth = len(solution)

# Return the shortest circuit.
return min(solutions, key=lambda c: len(c))
return shortest_stratified_circuit


def _stratify_circuit(
Expand All @@ -116,43 +114,88 @@ def _stratify_circuit(
Returns:
The stratified circuit.
"""
num_categories = len(classifiers) + 1

def map_func(m: 'cirq.Moment', _) -> Sequence['cirq.Moment']:
stratified_ops: List[List['cirq.Operation']] = [[] for _ in range(num_categories)]
for op in m:
if set(op.tags) & set(context.tags_to_ignore):
stratified_ops[0].append(op)
continue
for i, classifier in enumerate(classifiers):
if classifier(op):
stratified_ops[i + 1].append(op)
break
return [circuits.Moment(op_list) for op_list in stratified_ops]

stratified_circuit = transformer_primitives.map_moments(circuit, map_func).unfreeze(copy=False)
assert len(stratified_circuit) == len(circuit) * num_categories

# Try to move operations to the left to reduce circuit depth, preserving stratification.
for curr_idx, moment in enumerate(stratified_circuit):
curr_category = curr_idx % num_categories
if curr_category == 0:
# Moment containing tagged operations to be ignored.
continue
batch_removals: List[Tuple[int, 'cirq.Operation']] = []
batch_inserts: List[Tuple[int, 'cirq.Operation']] = []
num_classes = len(classifiers) + 1 # include one "extra" category for ignored operations
new_moments: List[List['cirq.Operation']] = []

# Keep track of the the latest time index for each qubit, measurement key, and control key.
qubit_time_index: Dict['cirq.Qid', int] = {}
measurement_time_index: Dict['cirq.MeasurementKey', int] = {}
control_time_index: Dict['cirq.MeasurementKey', int] = {}

# The minimum time index for operations with a tag in context.tags_to_ignore.
last_ignored_ops_time_index = 0

for moment in circuit:
# Identify the new time indices that operations should be moved into.
ignored_ops = []
op_time_indices = {}
for op in moment:
prv_idx = stratified_circuit.earliest_available_moment(op, end_moment_index=curr_idx)
prv_category = prv_idx % num_categories
should_move_to_next_batch = curr_category < prv_category
prv_idx += curr_category - prv_category + num_categories * should_move_to_next_batch
assert prv_idx <= curr_idx and prv_idx % num_categories == curr_idx % num_categories
if prv_idx < curr_idx:
batch_inserts.append((prv_idx, op))
batch_removals.append((curr_idx, op))
stratified_circuit.batch_remove(batch_removals)
stratified_circuit.batch_insert_into(batch_inserts)
return drop_empty_moments.drop_empty_moments(stratified_circuit)

# Identify the earliest moment that can accommodate this op.
min_time_index_for_op = circuits.circuit.get_earliest_accommodating_moment_index(
op, qubit_time_index, measurement_time_index, control_time_index
)

# Identify the "class" of this operation (by index).
ignored_op = any(tag in op.tags for tag in context.tags_to_ignore)
if not ignored_op:
op_class = _get_op_class(op, classifiers)
else:
op_class = len(classifiers)
ignored_ops.append(op)
min_time_index_for_op = max(min_time_index_for_op, last_ignored_ops_time_index + 1)

# Identify the time index to place this operation into.
time_index = (min_time_index_for_op // num_classes) * num_classes + op_class
if time_index < min_time_index_for_op:
time_index += num_classes
op_time_indices[op] = time_index

# Assign ignored operations to the same moment.
if ignored_ops:
last_ignored_ops_time_index = max(op_time_indices[op] for op in ignored_ops)
for op in ignored_ops:
op_time_indices[op] = last_ignored_ops_time_index

# Move the operations into their assigned moments.
for op, time_index in op_time_indices.items():
if time_index >= len(new_moments):
new_moments += [[] for _ in range(num_classes)]
new_moments[time_index].append(op)

# Update qubit, measurment key, and control key moments.
for qubit in op.qubits:
qubit_time_index[qubit] = time_index
for key in protocols.measurement_key_objs(op):
measurement_time_index[key] = time_index
for key in protocols.control_keys(op):
control_time_index[key] = time_index

return circuits.Circuit(circuits.Moment(moment) for moment in new_moments if moment)


def _get_classifiers(
circuit: circuits.AbstractCircuit, categories: Iterable[Category]
) -> List[Classifier]:
"""Convert a collection of categories into a list of classifiers.
The returned list of classifiers is:
- Exhaustive, meaning every operation in the circuit is classified by at least one classifier.
- Minimal, meaning unused classifiers are forgotten.
"""
# Convert all categories into classifiers, and make the list exhaustive by adding a dummy
# classifier for otherwise unclassified ops.
classifiers = [_category_to_classifier(cat) for cat in categories] + [_dummy_classifier]

# Figure out which classes are actually used in the circuit.
class_is_used = [False for _ in classifiers]
for op in circuit.all_operations():
class_is_used[_get_op_class(op, classifiers)] = True
if all(class_is_used):
break

# Return only the classifiers that are used.
return [classifier for classifier, is_used in zip(classifiers, class_is_used) if is_used]


# No type for `category` because mypy does not keep the return type when
Expand All @@ -177,3 +220,22 @@ def _category_to_classifier(category) -> Classifier:
f'Type[cirq.Gate], Type[cirq.Operation], '
f'or Callable[[cirq.Operation], bool].'
)


def _dummy_classifier(op: 'cirq.Operation') -> bool:
"""Dummy classifier, used to "complete" a collection of classifiers and make it exhaustive."""


def _get_op_class(op: 'cirq.Operation', classifiers: Sequence[Classifier]) -> int:
"""Get the "class" of an operator, by index."""
for class_index, classifier in enumerate(classifiers):
if classifier is _dummy_classifier:
dummy_classifier_index = class_index
elif classifier(op):
return class_index
# If we got this far, the operation did not match any "actual" classifier,
# so return the index of the dummy classifer.
try:
return dummy_classifier_index
except NameError:
raise ValueError(f"Operation {op} not identified by any classifier")

0 comments on commit 6a97cca

Please sign in to comment.