Skip to content

Commit

Permalink
Apply variable-spaced optimization to QROM circuits (#6257)
Browse files Browse the repository at this point in the history
* Apply variable-spaced optimization to QROM circuits

* Fix flaky test due to a flakiness bug in GreedyQubitManager

* Fix mypy issues

* More tests and update hash for QROM since T-complexity now depends upon the data

* Fix typo and failing test
  • Loading branch information
tanujkhattar committed Aug 24, 2023
1 parent 6abc740 commit 5bbdc22
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 21 deletions.
33 changes: 29 additions & 4 deletions cirq-ft/cirq_ft/algos/qrom.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Sequence, Tuple
from typing import Callable, Sequence, Tuple, Set

import attr
import cirq
Expand All @@ -29,7 +29,11 @@ class QROM(unary_iteration_gate.UnaryIterationGate):
"""Gate to load data[l] in the target register when the selection stores an index l.
In the case of multi-dimensional data[p,q,r,...] we use multiple named
selection registers [p, q, r, ...] to index and load the data.
selection registers [p, q, r, ...] to index and load the data. Here `p, q, r, ...`
correspond to registers named `selection0`, `selection1`, `selection2`, ... etc.
When the input data elements contain consecutive entries of identical data elements to
load, the QROM also implements the "variable-spaced" QROM optimization described in Ref[2].
Args:
data: List of numpy ndarrays specifying the data to load. If the length
Expand All @@ -44,6 +48,15 @@ class QROM(unary_iteration_gate.UnaryIterationGate):
registers. This can be deduced from the maximum element of each of the
datasets. Should be of length len(data), i.e. the number of datasets.
num_controls: The number of control registers.
References:
[Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity]
(https://arxiv.org/abs/1805.03662).
Babbush et. al. (2018). Figure 1.
[Compilation of Fault-Tolerant Quantum Heuristics for Combinatorial Optimization]
(https://arxiv.org/abs/2007.07391).
Babbush et. al. (2020). Figure 3.
"""

data: Sequence[NDArray]
Expand Down Expand Up @@ -152,11 +165,22 @@ def decompose_zero_selection(
yield cirq.inverse(multi_controlled_and)
context.qubit_manager.qfree(and_ancilla + [and_target])

def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int):
global_unique_element: Set[int] = set()
for data in self.data:
unique_element = np.unique(data[selection_index_prefix][l:r])
if len(unique_element) > 1:
return False
global_unique_element.add(unique_element[0])
if len(global_unique_element) > 1:
return False
return True

def nth_operation(
self, context: cirq.DecompositionContext, control: cirq.Qid, **kwargs
) -> cirq.OP_TREE:
selection_idx = tuple(kwargs[reg.name] for reg in self.selection_registers)
target_regs = {k: v for k, v in kwargs.items() if k in self.target_registers}
target_regs = {reg.name: kwargs[reg.name] for reg in self.target_registers}
yield self._load_nth_data(selection_idx, lambda q: cirq.CNOT(control, q), **target_regs)

def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo:
Expand All @@ -172,4 +196,5 @@ def __pow__(self, power: int):
return NotImplemented # pragma: no cover

def _value_equality_values_(self):
return (self.selection_registers, self.target_registers, self.control_registers)
data_tuple = tuple(tuple(d.flatten()) for d in self.data)
return (self.selection_registers, self.target_registers, self.control_registers, data_tuple)
74 changes: 74 additions & 0 deletions cirq-ft/cirq_ft/algos/qrom_test.py
Expand Up @@ -116,6 +116,80 @@ def test_t_complexity(data):
assert cirq_ft.t_complexity(g.gate).t == max(0, 4 * n - 8), n


def _assert_qrom_has_diagram(qrom: cirq_ft.QROM, expected_diagram: str):
gh = cirq_ft.testing.GateHelper(qrom)
op = gh.operation
context = cirq.DecompositionContext(qubit_manager=cirq_ft.GreedyQubitManager(prefix="anc"))
circuit = cirq.Circuit(cirq.decompose_once(op, context=context))
selection = [
*itertools.chain.from_iterable(gh.quregs[reg.name] for reg in qrom.selection_registers)
]
selection = [q for q in selection if q in circuit.all_qubits()]
anc = sorted(set(circuit.all_qubits()) - set(op.qubits))
selection_and_anc = (selection[0],) + sum(zip(selection[1:], anc), ())
qubit_order = cirq.QubitOrder.explicit(selection_and_anc, fallback=cirq.QubitOrder.DEFAULT)
cirq.testing.assert_has_diagram(circuit, expected_diagram, qubit_order=qubit_order)


def test_qrom_variable_spacing():
# Tests for variable spacing optimization applied from https://arxiv.org/abs/2007.07391
data = [1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8] # Figure 3a.
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (8 - 2) * 4
data = [1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5] # Figure 3b.
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (5 - 2) * 4
data = [1, 2, 3, 4, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7] # Negative test: t count is not (g-2)*4
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (8 - 2) * 4
# Works as expected when multiple data arrays are to be loaded.
data = [1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5]
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data, data)).t == (5 - 2) * 4
assert cirq_ft.t_complexity(cirq_ft.QROM.build(data, 2 * np.array(data))).t == (16 - 2) * 4
# Works as expected when multidimensional input data is to be loaded
qrom = cirq_ft.QROM.build(
np.array(
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2, 2],
[2, 2, 2, 2, 2, 2, 2, 2],
]
)
)
# Value to be loaded depends only the on the first bit of outer loop.
_assert_qrom_has_diagram(
qrom,
r'''
selection00: ───X───@───X───@───
│ │
target00: ──────────┼───────X───
target01: ──────────X───────────
''',
)
# When inner loop range is not a power of 2, the inner segment tree cannot be skipped.
qrom = cirq_ft.QROM.build(
np.array(
[[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2]],
dtype=int,
)
)
_assert_qrom_has_diagram(
qrom,
r'''
selection00: ───X───@─────────@───────@──────X───@─────────@───────@──────
│ │ │ │ │ │
selection10: ───────(0)───────┼───────@──────────(0)───────┼───────@──────
│ │ │ │ │ │
anc_1: ─────────────And───@───X───@───And†───────And───@───X───@───And†───
│ │ │ │
target00: ────────────────┼───────┼────────────────────X───────X──────────
│ │
target01: ────────────────X───────X───────────────────────────────────────
''',
)
# No T-gates needed if all elements to load are identical.
assert cirq_ft.t_complexity(cirq_ft.QROM.build([3, 3, 3, 3])).t == 0


@pytest.mark.parametrize(
"data",
[[np.arange(6).reshape(2, 3), 4 * np.arange(6).reshape(2, 3)], [np.arange(8).reshape(2, 2, 2)]],
Expand Down
90 changes: 75 additions & 15 deletions cirq-ft/cirq_ft/algos/unary_iteration_gate.py
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import abc
from typing import Dict, Iterator, List, Sequence, Tuple
from typing import Callable, Dict, Iterator, List, Sequence, Tuple
from numpy.typing import NDArray

import cirq
Expand All @@ -34,6 +34,7 @@ def _unary_iteration_segtree(
r: int,
l_iter: int,
r_iter: int,
break_early: Callable[[int, int], bool],
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
"""Constructs a unary iteration circuit by iterating over nodes of an implicit Segment Tree.
Expand All @@ -53,6 +54,11 @@ def _unary_iteration_segtree(
r: Right index of the range represented by current node of the segment tree.
l_iter: Left index of iteration range over which the segment tree should be constructed.
r_iter: Right index of iteration range over which the segment tree should be constructed.
break_early: For each internal node of the segment tree, `break_early(l, r)` is called to
evaluate whether the unary iteration should terminate early and not recurse in the
subtree of the node representing range `[l, r)`. If True, the internal node is
considered equivalent to a leaf node and the method yields only one tuple
`(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`.
Yields:
One `Tuple[cirq.OP_TREE, cirq.Qid, int]` for each leaf node in the segment tree. The i'th
Expand All @@ -68,8 +74,8 @@ def _unary_iteration_segtree(
if l >= r_iter or l_iter >= r:
# Range corresponding to this node is completely outside of iteration range.
return
if l == (r - 1):
# Reached a leaf node; yield the operations.
if l_iter <= l < r <= r_iter and (l == (r - 1) or break_early(l, r)):
# Reached a leaf node or a "special" internal node; yield the operations.
yield tuple(ops), control, l
ops.clear()
return
Expand All @@ -78,20 +84,24 @@ def _unary_iteration_segtree(
if r_iter <= m:
# Yield only left sub-tree.
yield from _unary_iteration_segtree(
ops, control, selection, ancilla, sl + 1, l, m, l_iter, r_iter
ops, control, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early
)
return
if l_iter >= m:
# Yield only right sub-tree
yield from _unary_iteration_segtree(
ops, control, selection, ancilla, sl + 1, m, r, l_iter, r_iter
ops, control, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early
)
return
anc, sq = ancilla[sl], selection[sl]
ops.append(and_gate.And((1, 0)).on(control, sq, anc))
yield from _unary_iteration_segtree(ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter)
yield from _unary_iteration_segtree(
ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early
)
ops.append(cirq.CNOT(control, anc))
yield from _unary_iteration_segtree(ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter)
yield from _unary_iteration_segtree(
ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early
)
ops.append(and_gate.And(adjoint=True).on(control, sq, anc))


Expand All @@ -101,16 +111,17 @@ def _unary_iteration_zero_control(
ancilla: Sequence[cirq.Qid],
l_iter: int,
r_iter: int,
break_early: Callable[[int, int], bool],
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
sl, l, r = 0, 0, 2 ** len(selection)
m = (l + r) >> 1
ops.append(cirq.X(selection[0]))
yield from _unary_iteration_segtree(
ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter
ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter, break_early
)
ops.append(cirq.X(selection[0]))
yield from _unary_iteration_segtree(
ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter
ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter, break_early
)


Expand All @@ -121,9 +132,12 @@ def _unary_iteration_single_control(
ancilla: Sequence[cirq.Qid],
l_iter: int,
r_iter: int,
break_early: Callable[[int, int], bool],
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
sl, l, r = 0, 0, 2 ** len(selection)
yield from _unary_iteration_segtree(ops, control, selection, ancilla, sl, l, r, l_iter, r_iter)
yield from _unary_iteration_segtree(
ops, control, selection, ancilla, sl, l, r, l_iter, r_iter, break_early
)


def _unary_iteration_multi_controls(
Expand All @@ -133,6 +147,7 @@ def _unary_iteration_multi_controls(
ancilla: Sequence[cirq.Qid],
l_iter: int,
r_iter: int,
break_early: Callable[[int, int], bool],
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
num_controls = len(controls)
and_ancilla = ancilla[: num_controls - 2]
Expand All @@ -142,7 +157,7 @@ def _unary_iteration_multi_controls(
)
ops.append(multi_controlled_and)
yield from _unary_iteration_single_control(
ops, and_target, selection, ancilla[num_controls - 1 :], l_iter, r_iter
ops, and_target, selection, ancilla[num_controls - 1 :], l_iter, r_iter, break_early
)
ops.append(cirq.inverse(multi_controlled_and))

Expand All @@ -154,6 +169,7 @@ def unary_iteration(
controls: Sequence[cirq.Qid],
selection: Sequence[cirq.Qid],
qubit_manager: cirq.QubitManager,
break_early: Callable[[int, int], bool] = lambda l, r: False,
) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]:
"""The method performs unary iteration on `selection` integer in `range(l_iter, r_iter)`.
Expand Down Expand Up @@ -181,6 +197,9 @@ def unary_iteration(
... circuit.append(j_ops)
>>> circuit.append(i_ops)
Note: Unary iteration circuits assume that the selection register stores integers only in the
range `[l, r)` for which the corresponding unary iteration circuit should be built.
Args:
l_iter: Starting index of the iteration range.
r_iter: Ending index of the iteration range.
Expand All @@ -192,6 +211,11 @@ def unary_iteration(
controls: Control register of qubits.
selection: Selection register of qubits.
qubit_manager: A `cirq.QubitManager` to allocate new qubits.
break_early: For each internal node of the segment tree, `break_early(l, r)` is called to
evaluate whether the unary iteration should terminate early and not recurse in the
subtree of the node representing range `[l, r)`. If True, the internal node is
considered equivalent to a leaf node and the method yields only one tuple
`(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`.
Yields:
(r_iter - l_iter) different tuples, each corresponding to an integer in range
Expand All @@ -207,14 +231,16 @@ def unary_iteration(
assert len(selection) > 0
ancilla = qubit_manager.qalloc(max(0, len(controls) + len(selection) - 1))
if len(controls) == 0:
yield from _unary_iteration_zero_control(flanking_ops, selection, ancilla, l_iter, r_iter)
yield from _unary_iteration_zero_control(
flanking_ops, selection, ancilla, l_iter, r_iter, break_early
)
elif len(controls) == 1:
yield from _unary_iteration_single_control(
flanking_ops, controls[0], selection, ancilla, l_iter, r_iter
flanking_ops, controls[0], selection, ancilla, l_iter, r_iter, break_early
)
else:
yield from _unary_iteration_multi_controls(
flanking_ops, controls, selection, ancilla, l_iter, r_iter
flanking_ops, controls, selection, ancilla, l_iter, r_iter, break_early
)
qubit_manager.qfree(ancilla)

Expand All @@ -231,6 +257,9 @@ class UnaryIterationGate(infra.GateWithRegisters):
indexed operations on a target register depending on the index value stored in a selection
register.
Note: Unary iteration circuits assume that the selection register stores integers only in the
range `[l, r)` for which the corresponding unary iteration circuit should be built.
References:
[Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity]
(https://arxiv.org/abs/1805.03662).
Expand Down Expand Up @@ -308,10 +337,38 @@ def decompose_zero_selection(
"""
raise NotImplementedError("Selection register must not be empty.")

def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int) -> bool:
"""Derived classes should override this method to specify an early termination condition.
For each internal node of the unary iteration segment tree, `break_early(l, r)` is called
to evaluate whether the unary iteration should not recurse in the subtree of the node
representing range `[l, r)`. If True, the internal node is considered equivalent to a leaf
node and thus, `self.nth_operation` will be called for only integer `l` in the range [l, r).
When the `UnaryIteration` class is constructed using multiple selection registers, i.e. we
wish to perform nested coherent for-loops, a unary iteration segment tree is constructed
corresponding to each nested coherent for-loop. For every such unary iteration segment tree,
the `_break_early` condition is checked by passing the `selection_index_prefix` tuple.
Args:
selection_index_prefix: To evaluate the early breaking condition for the i'th nested
for-loop, the `selection_index_prefix` contains `i-1` integers corresponding to
the loop variable values for the first `i-1` nested loops.
l: Beginning of range `[l, r)` for internal node of unary iteration segment tree.
r: End (exclusive) of range `[l, r)` for internal node of unary iteration segment tree.
Returns:
True of the `len(selection_index_prefix)`'th unary iteration should terminate early for
the given parameters.
"""
return False

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
if self.selection_registers.total_bits() == 0:
if self.selection_registers.total_bits() == 0 or self._break_early(
(), 0, self.selection_registers[0].iteration_length
):
return self.decompose_zero_selection(context=context, **quregs)

num_loops = len(self.selection_registers)
Expand Down Expand Up @@ -354,20 +411,23 @@ def unary_iteration_loops(
return
# Use recursion to write `num_loops` nested loops using unary_iteration().
ops: List[cirq.Operation] = []
selection_index_prefix = tuple(selection_reg_name_to_val.values())
ith_for_loop = unary_iteration(
l_iter=0,
r_iter=self.selection_registers[nested_depth].iteration_length,
flanking_ops=ops,
controls=controls,
selection=[*quregs[self.selection_registers[nested_depth].name]],
qubit_manager=context.qubit_manager,
break_early=lambda l, r: self._break_early(selection_index_prefix, l, r),
)
for op_tree, control_qid, n in ith_for_loop:
yield op_tree
selection_reg_name_to_val[self.selection_registers[nested_depth].name] = n
yield from unary_iteration_loops(
nested_depth + 1, selection_reg_name_to_val, (control_qid,)
)
selection_reg_name_to_val.pop(self.selection_registers[nested_depth].name)
yield ops

return unary_iteration_loops(0, {}, self.control_registers.merge_qubits(**quregs))
Expand Down

0 comments on commit 5bbdc22

Please sign in to comment.