Skip to content

Commit

Permalink
Implemented 8n T complexity decomposition of LessThanEqual gate (#6156)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoureldinYosri committed Jul 17, 2023
1 parent cb05a69 commit c93224e
Show file tree
Hide file tree
Showing 4 changed files with 374 additions and 12 deletions.
2 changes: 2 additions & 0 deletions cirq-ft/cirq_ft/algos/__init__.py
Expand Up @@ -20,6 +20,8 @@
ContiguousRegisterGate,
LessThanEqualGate,
LessThanGate,
SingleQubitCompare,
BiQubitsMixer,
)
from cirq_ft.algos.generic_select import GenericSelect
from cirq_ft.algos.hubbard_model import PrepareHubbard, SelectHubbard
Expand Down
271 changes: 267 additions & 4 deletions cirq-ft/cirq_ft/algos/arithmetic_gates.py
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Optional, Sequence, Tuple, Union
from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator

from cirq._compat import cached_property
import attr
import cirq
from cirq_ft import infra
Expand Down Expand Up @@ -78,7 +79,7 @@ def _decompose_with_context_(
return
adjoint = []

[are_equal] = context.qubit_manager.qalloc(1)
(are_equal,) = context.qubit_manager.qalloc(1)

# Initially our belief is that the numbers are equal.
yield cirq.X(are_equal)
Expand Down Expand Up @@ -130,6 +131,147 @@ def _t_complexity_(self) -> infra.TComplexity:
)


@attr.frozen
class BiQubitsMixer(infra.GateWithRegisters):
"""Implements the COMPARE2 (Fig. 1) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
This gates mixes the values in a way that preserves the result of comparison.
The registers being compared are 2-qubit registers where
x = 2*x_msb + x_lsb
y = 2*y_msb + y_lsb
The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb'
are the final values of x_lsb' and y_lsb'.
""" # pylint: disable=line-too-long

adjoint: bool = False

@cached_property
def registers(self) -> infra.Registers:
return infra.Registers.build(x=2, y=2, ancilla=3)

def __repr__(self) -> str:
return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
) -> cirq.OP_TREE:
x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla']
x_msb, x_lsb = x
y_msb, y_lsb = y

def _cswap(control: cirq.Qid, a: cirq.Qid, b: cirq.Qid, aux: cirq.Qid) -> cirq.OP_TREE:
"""A CSWAP with 4T complexity and whose adjoint has 0T complexity.
A controlled SWAP that swaps `a` and `b` based on `control`.
It uses an extra qubit `aux` so that its adjoint would have
a T complexity of zero.
"""
yield cirq.CNOT(a, b)
yield and_gate.And(adjoint=self.adjoint).on(control, b, aux)
yield cirq.CNOT(aux, a)
yield cirq.CNOT(a, b)

def _decomposition():
# computes the difference of x - y where
# x = 2*x_msb + x_lsb
# y = 2*y_msb + y_lsb
# And stores the result in x_lsb and y_lsb such that
# sign(x - y) = sign(x_lsb - y_lsb)
# This decomposition uses 3 ancilla qubits in order to have a
# T complexity of 8.
yield cirq.X(ancilla[0])
yield cirq.CNOT(y_msb, x_msb)
yield cirq.CNOT(y_lsb, x_lsb)
yield from _cswap(x_msb, x_lsb, ancilla[0], ancilla[1])
yield from _cswap(x_msb, y_msb, y_lsb, ancilla[2])
yield cirq.CNOT(y_lsb, x_lsb)

if self.adjoint:
yield from reversed(tuple(cirq.flatten_to_ops(_decomposition())))
else:
yield from _decomposition()

def __pow__(self, power: int) -> cirq.Gate:
if power == 1:
return self
if power == -1:
return BiQubitsMixer(adjoint=not self.adjoint)
return NotImplemented # coverage: ignore

def _t_complexity_(self) -> infra.TComplexity:
if self.adjoint:
return infra.TComplexity(clifford=18)
return infra.TComplexity(t=8, clifford=28)

def _has_unitary_(self):
return not self.adjoint


@attr.frozen
class SingleQubitCompare(infra.GateWithRegisters):
"""Applies U|a>|b>|0>|0> = |a> |a=b> |(a<b)> |(a>b)>
Source: (FIG. 3) in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
""" # pylint: disable=line-too-long

adjoint: bool = False

@cached_property
def registers(self) -> infra.Registers:
return infra.Registers.build(a=1, b=1, less_than=1, greater_than=1)

def __repr__(self) -> str:
return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
) -> cirq.OP_TREE:
a = quregs['a']
b = quregs['b']
less_than = quregs['less_than']
greater_than = quregs['greater_than']

def _decomposition() -> Iterator[cirq.Operation]:
yield and_gate.And((0, 1), adjoint=self.adjoint).on(*a, *b, *less_than)
yield cirq.CNOT(*less_than, *greater_than)
yield cirq.CNOT(*b, *greater_than)
yield cirq.CNOT(*a, *b)
yield cirq.CNOT(*a, *greater_than)
yield cirq.X(*b)

if self.adjoint:
yield from reversed(tuple(_decomposition()))
else:
yield from _decomposition()

def __pow__(self, power: int) -> cirq.Gate:
if not isinstance(power, int):
raise ValueError('SingleQubitCompare is only defined for integer powers.')
if power % 2 == 0:
return cirq.IdentityGate(4)
if power < 0:
return SingleQubitCompare(adjoint=not self.adjoint)
return self

def _t_complexity_(self) -> infra.TComplexity:
if self.adjoint:
return infra.TComplexity(clifford=11)
return infra.TComplexity(t=4, clifford=16)


def _equality_with_zero(
context: cirq.DecompositionContext, qubits: Sequence[cirq.Qid], z: cirq.Qid
) -> cirq.OP_TREE:
if len(qubits) == 1:
(q,) = qubits
yield cirq.X(q)
yield cirq.CNOT(q, z)
return

ancilla = context.qubit_manager.qalloc(len(qubits) - 2)
yield and_gate.And(cv=[0] * len(qubits)).on(*qubits, *ancilla, z)


@attr.frozen
class LessThanEqualGate(cirq.ArithmeticGate):
"""Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>"""
Expand Down Expand Up @@ -161,9 +303,130 @@ def __pow__(self, power: int):
def __repr__(self) -> str:
return f'cirq_ft.LessThanEqualGate({self.x_bitsize}, {self.y_bitsize})'

def _decompose_via_tree(
self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid]
) -> cirq.OP_TREE:
"""Returns comparison oracle from https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
This decomposition follows the tree structure of (FIG. 2)
""" # pylint: disable=line-too-long
if len(X) == 1:
return
if len(X) == 2:
yield BiQubitsMixer().on_registers(x=X, y=Y, ancilla=context.qubit_manager.qalloc(3))
return

m = len(X) // 2
yield self._decompose_via_tree(context, X[:m], Y[:m])
yield self._decompose_via_tree(context, X[m:], Y[m:])
yield BiQubitsMixer().on_registers(
x=(X[m - 1], X[-1]), y=(Y[m - 1], Y[-1]), ancilla=context.qubit_manager.qalloc(3)
)

def _decompose_with_context_(
self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None
) -> cirq.OP_TREE:
"""Decomposes the gate in a T-complexity optimal way.
The construction can be broken in 4 parts:
1. In case of differing bitsizes then a multicontrol And Gate
- Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether
the extra prefix is equal to zero:
- result stored in: `prefix_equality` qubit.
2. The tree structure (FIG. 2) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf
followed by a SingleQubitCompare to compute the result of comparison of
the suffixes of equal length:
- result stored in: `less_than` and `greater_than` with equality in qubits[-2]
3. The results from the previous two steps are combined to update the target qubit.
4. The adjoint of the previous operations is added to restore the input qubits
to their original state and clean the ancilla qubits.
""" # pylint: disable=line-too-long

if context is None:
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())

lhs, rhs, target = qubits[: self.x_bitsize], qubits[self.x_bitsize : -1], qubits[-1]

n = min(len(lhs), len(rhs))

prefix_equality = None
adjoint: List[cirq.Operation] = []

# if one of the registers is longer than the other store equality with |0--0>
# into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T.
if len(lhs) != len(rhs):
(prefix_equality,) = context.qubit_manager.qalloc(1)
if len(lhs) > len(rhs):
for op in cirq.flatten_to_ops(
_equality_with_zero(context, lhs[:-n], prefix_equality)
):
yield op
adjoint.append(cirq.inverse(op))
else:
for op in cirq.flatten_to_ops(
_equality_with_zero(context, rhs[:-n], prefix_equality)
):
yield op
adjoint.append(cirq.inverse(op))

yield cirq.X(target), cirq.CNOT(prefix_equality, target)

# compare the remaing suffix of P and Q
lhs = lhs[-n:]
rhs = rhs[-n:]
for op in cirq.flatten_to_ops(self._decompose_via_tree(context, lhs, rhs)):
yield op
adjoint.append(cirq.inverse(op))

less_than, greater_than = context.qubit_manager.qalloc(2)
yield SingleQubitCompare().on_registers(
a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than
)
adjoint.append(
SingleQubitCompare(adjoint=True).on_registers(
a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than
)
)

if prefix_equality is None:
yield cirq.X(target)
yield cirq.CNOT(greater_than, target)
else:
(less_than_or_equal,) = context.qubit_manager.qalloc(1)
yield and_gate.And([1, 0]).on(prefix_equality, greater_than, less_than_or_equal)
adjoint.append(
and_gate.And([1, 0], adjoint=True).on(
prefix_equality, greater_than, less_than_or_equal
)
)

yield cirq.CNOT(less_than_or_equal, target)

yield from reversed(adjoint)

def _t_complexity_(self) -> infra.TComplexity:
# TODO(#112): This is rough cost that ignores cliffords.
return infra.TComplexity(t=4 * (self.x_bitsize + self.y_bitsize))
n = min(self.x_bitsize, self.y_bitsize)
d = max(self.x_bitsize, self.y_bitsize) - n
is_second_longer = self.y_bitsize > self.x_bitsize
if d == 0:
# When both registers are of the same size the T complexity is
# 8n - 4 same as in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf. pylint: disable=line-too-long
return infra.TComplexity(t=8 * n - 4, clifford=46 * n - 17)
else:
# When the registers differ in size and `n` is the size of the smaller one and
# `d` is the difference in size. The T complexity is the sum of the tree
# decomposition as before giving 8n + O(1) and the T complexity of an `And` gate
# over `d` registers giving 4d + O(1) totaling 8n + 4d + O(1).
# From the decomposition we get that the constant is -4 as well as the clifford counts.
if d == 1:
return infra.TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer)
else:
return infra.TComplexity(
t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer
)

def _has_unitary_(self):
return True


@attr.frozen
Expand Down

0 comments on commit c93224e

Please sign in to comment.