Skip to content

Commit

Permalink
Boolean Hamiltonian gate yields fewer gates (#4386)
Browse files Browse the repository at this point in the history
* Boolean Hamiltonian gate yields fewer gates

* Address some of the comments

* Address some of the comments

* Add unit test

* Expand unit test

* Fix unit test

* Add test that is failing but should pass

* More comprehensive tests

* Fix code and make unit tests pass

* Address comments
  • Loading branch information
tonybruguier committed Oct 12, 2021
1 parent 3d50921 commit bd2e63c
Show file tree
Hide file tree
Showing 2 changed files with 318 additions and 5 deletions.
197 changes: 192 additions & 5 deletions cirq-core/cirq/ops/boolean_hamiltonian.py
Expand Up @@ -19,7 +19,11 @@
by Stuart Hadfield, https://arxiv.org/pdf/1804.09130.pdf
[2] https://www.youtube.com/watch?v=AOKM9BkweVU is a useful intro
[3] https://github.com/rsln-s/IEEE_QW_2020/blob/master/Slides.pdf
[4] Efficient Quantum Circuits for Diagonal Unitaries Without Ancillas by Jonathan Welch, Daniel
Greenbaum, Sarah Mostame, and Alán Aspuru-Guzik, https://arxiv.org/abs/1306.3991
"""
import itertools
import functools

from typing import Any, Dict, Generator, List, Sequence, Tuple

Expand Down Expand Up @@ -112,6 +116,187 @@ def _decompose_(self):
)


def _gray_code_comparator(k1: Tuple[int, ...], k2: Tuple[int, ...], flip: bool = False) -> int:
"""Compares two Gray-encoded binary numbers.
Args:
k1: A tuple of ints, representing the bits that are one. For example, 6 would be (1, 2).
k2: The second number, represented similarly as k1.
flip: Whether to flip the comparison.
Returns:
-1 if k1 < k2 (or +1 if flip is true)
0 if k1 == k2
+1 if k1 > k2 (or -1 if flip is true)
"""
max_1 = k1[-1] if k1 else -1
max_2 = k2[-1] if k2 else -1
if max_1 != max_2:
return -1 if (max_1 < max_2) ^ flip else 1
if max_1 == -1:
return 0
return _gray_code_comparator(k1[0:-1], k2[0:-1], not flip)


def _simplify_commuting_cnots(
cnots: List[Tuple[int, int]], flip_control_and_target: bool
) -> Tuple[bool, List[Tuple[int, int]]]:
"""Attempts to commute CNOTs and remove cancelling pairs.
Commutation relations are based on 9 (flip_control_and_target=False) or 10
(flip_control_target=True) of [4]:
When flip_control_target=True:
CNOT(j, i) @ CNOT(j, k) = CNOT(j, k) @ CNOT(j, i)
───X─────── ───────X───
│ │
───@───@─── = ───@───@───
│ │
───────X─── ───X───────
When flip_control_target=False:
CNOT(i, j) @ CNOT(k, j) = CNOT(k, j) @ CNOT(i, j)
───@─────── ───────@───
│ │
───X───X─── = ───X───X───
│ │
───────@─── ───@───────
Args:
cnots: A list of CNOTS, encoded as integer tuples (control, target). The code does not make
any assumption as to the order of the CNOTs, but it is likely to work better if its
inputs are from Gray-sorted Hamiltonians. Regardless of the order of the CNOTs, the
code is conservative and should be robust to mis-ordered inputs with the only side
effect being a lack of simplification.
flip_control_and_target: Whether to flip control and target.
Returns:
A tuple containing a Boolean that tells whether a simplification has been performed and the
CNOT list, potentially simplified, encoded as integer tuples (control, target).
"""

target, control = (0, 1) if flip_control_and_target else (1, 0)

i = 0
qubit_to_index: Dict[int, int] = {cnots[i][control]: i} if cnots else {}
for j in range(1, len(cnots)):
if cnots[i][target] != cnots[j][target]:
# The targets (resp. control) don't match, so we reset the search.
i = j
qubit_to_index = {cnots[j][control]: j}
continue

if cnots[j][control] in qubit_to_index:
k = qubit_to_index[cnots[j][control]]
# The controls (resp. targets) are the same, so we can simplify away.
cnots = [cnots[n] for n in range(len(cnots)) if n != j and n != k]
# TODO(#4532): Speed up code by not returning early.
return True, cnots

qubit_to_index[cnots[j][control]] = j

return False, cnots


def _simplify_cnots_triplets(
cnots: List[Tuple[int, int]], flip_control_and_target: bool
) -> Tuple[bool, List[Tuple[int, int]]]:
"""Simplifies CNOT pairs according to equation 11 of [4].
CNOT(i, j) @ CNOT(j, k) == CNOT(j, k) @ CNOT(i, k) @ CNOT(i, j)
───@─────── ───────@───@───
│ │ │
───X───@─── = ───@───┼───X───
│ │ │
───────X─── ───X───X───────
Args:
cnots: A list of CNOTS, encoded as integer tuples (control, target).
flip_control_and_target: Whether to flip control and target.
Returns:
A tuple containing a Boolean that tells whether a simplification has been performed and the
CNOT list, potentially simplified, encoded as integer tuples (control, target).
"""
target, control = (0, 1) if flip_control_and_target else (1, 0)

# We investigate potential pivots sequentially.
for j in range(1, len(cnots) - 1):
# First, we look back for as long as the controls (resp. targets) are the same.
# They all commute, so all are potential candidates for being simplified.
# prev_match_index is qubit to index in `cnots` array.
prev_match_index: Dict[int, int] = {}
for i in range(j - 1, -1, -1):
# These CNOTs have the same target (resp. control) and though they are not candidates
# for simplification, since they commute, we can keep looking for candidates.
if cnots[i][target] == cnots[j][target]:
continue
if cnots[i][control] != cnots[j][control]:
break
# We take a note of the control (resp. target).
prev_match_index[cnots[i][target]] = i

# Next, we look forward for as long as the targets (resp. controls) are the
# same. They all commute, so all are potential candidates for being simplified.
# post_match_index is qubit to index in `cnots` array.
post_match_index: Dict[int, int] = {}
for k in range(j + 1, len(cnots)):
# These CNOTs have the same control (resp. target) and though they are not candidates
# for simplification, since they commute, we can keep looking for candidates.
if cnots[j][control] == cnots[k][control]:
continue
if cnots[j][target] != cnots[k][target]:
break
# We take a note of the target (resp. control).
post_match_index[cnots[k][control]] = k

# Among all the candidates, find if they have a match.
keys = prev_match_index.keys() & post_match_index.keys()
for key in keys:
# We perform the swap which removes the pivot.
new_idx: List[int] = (
# Anything strictly before the pivot that is not the CNOT to swap.
[idx for idx in range(0, j) if idx != prev_match_index[key]]
# The two swapped CNOTs.
+ [post_match_index[key], prev_match_index[key]]
# Anything after the pivot that is not the CNOT to swap.
+ [idx for idx in range(j + 1, len(cnots)) if idx != post_match_index[key]]
)
# Since we removed the pivot, the length should be one fewer.
cnots = [cnots[idx] for idx in new_idx]
# TODO(#4532): Speed up code by not returning early.
return True, cnots

return False, cnots


def _simplify_cnots(cnots: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
"""Takes a series of CNOTs and tries to applies rule to cancel out gates.
Algorithm based on "Efficient quantum circuits for diagonal unitaries without ancillas" by
Jonathan Welch, Daniel Greenbaum, Sarah Mostame, Alán Aspuru-Guzik
https://arxiv.org/abs/1306.3991
Args:
cnots: A list of CNOTs represented as tuples of integer (control, target).
Returns:
The simplified list of CNOTs, encoded as integer tuples (control, target).
"""

found_simplification = True
while found_simplification:
for simplify_fn, flip_control_and_target in itertools.product(
[_simplify_commuting_cnots, _simplify_cnots_triplets], [False, True]
):
found_simplification, cnots = simplify_fn(cnots, flip_control_and_target)
if found_simplification:
break

return cnots


def _get_gates_from_hamiltonians(
hamiltonian_polynomial_list: List['cirq.PauliSum'],
qubit_map: Dict[str, 'cirq.Qid'],
Expand Down Expand Up @@ -145,16 +330,18 @@ def _apply_cnots(prevh: Tuple[int, ...], currh: Tuple[int, ...]):
cnots.extend((prevh[i], prevh[-1]) for i in range(len(prevh) - 1))
cnots.extend((currh[i], currh[-1]) for i in range(len(currh) - 1))

# TODO(tonybruguier): At this point, some CNOT gates can be cancelled out according to:
# "Efficient quantum circuits for diagonal unitaries without ancillas" by Jonathan Welch,
# Daniel Greenbaum, Sarah Mostame, Alán Aspuru-Guzik
# https://arxiv.org/abs/1306.3991
cnots = _simplify_cnots(cnots)

for gate in (cirq.CNOT(qubits[c], qubits[t]) for c, t in cnots):
yield gate

sorted_hamiltonian_keys = sorted(
hamiltonians.keys(), key=functools.cmp_to_key(_gray_code_comparator)
)

previous_h: Tuple[int, ...] = ()
for h, w in hamiltonians.items():
for h in sorted_hamiltonian_keys:
w = hamiltonians[h]
yield _apply_cnots(previous_h, h)

if len(h) >= 1:
Expand Down
126 changes: 126 additions & 0 deletions cirq-core/cirq/ops/boolean_hamiltonian_test.py
Expand Up @@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
import math
import random

import numpy as np
import pytest
import sympy.parsing.sympy_parser as sympy_parser

import cirq
import cirq.ops.boolean_hamiltonian as bh


@pytest.mark.parametrize(
Expand Down Expand Up @@ -98,3 +101,126 @@ def test_with_custom_names():

with pytest.raises(ValueError, match='Length of replacement qubits must be the same'):
original_op.with_qubits(q2)


@pytest.mark.parametrize(
'n_bits,expected_hs',
[
(1, [(), (0,)]),
(2, [(), (0,), (0, 1), (1,)]),
(3, [(), (0,), (0, 1), (1,), (1, 2), (0, 1, 2), (0, 2), (2,)]),
],
)
def test_gray_code_sorting(n_bits, expected_hs):
hs_template = []
for x in range(2 ** n_bits):
h = []
for i in range(n_bits):
if x % 2 == 1:
h.append(i)
x -= 1
x //= 2
hs_template.append(tuple(sorted(h)))

for seed in range(10):
random.seed(seed)

hs = hs_template.copy()
random.shuffle(hs)

sorted_hs = sorted(list(hs), key=functools.cmp_to_key(bh._gray_code_comparator))

np.testing.assert_array_equal(sorted_hs, expected_hs)


@pytest.mark.parametrize(
'seq_a,seq_b,expected',
[
((), (), 0),
((), (0,), -1),
((0,), (), 1),
((0,), (0,), 0),
],
)
def test_gray_code_comparison(seq_a, seq_b, expected):
assert bh._gray_code_comparator(seq_a, seq_b) == expected


@pytest.mark.parametrize(
'input_cnots,input_flip_control_and_target,expected_simplified,expected_output_cnots',
[
# Empty inputs don't get simplified.
([], False, False, []),
([], True, False, []),
# Single CNOTs don't get simplified.
([(0, 1)], False, False, [(0, 1)]),
([(0, 1)], True, False, [(0, 1)]),
# Simplify away two CNOTs that are identical:
([(0, 1), (0, 1)], False, True, []),
([(0, 1), (0, 1)], True, True, []),
# Also simplify away if there's another CNOT in between.
([(0, 1), (2, 1), (0, 1)], False, True, [(2, 1)]),
([(0, 1), (0, 2), (0, 1)], True, True, [(0, 2)]),
# However, the in-between has to share the same target/control.
([(0, 1), (0, 2), (0, 1)], False, False, [(0, 1), (0, 2), (0, 1)]),
([(0, 1), (2, 1), (0, 1)], True, False, [(0, 1), (2, 1), (0, 1)]),
# Can simplify, but violates CNOT ordering assumption
([(0, 1), (2, 3), (0, 1)], False, False, [(0, 1), (2, 3), (0, 1)]),
],
)
def test_simplify_commuting_cnots(
input_cnots, input_flip_control_and_target, expected_simplified, expected_output_cnots
):
actual_simplified, actual_output_cnots = bh._simplify_commuting_cnots(
input_cnots, input_flip_control_and_target
)
assert actual_simplified == expected_simplified
assert actual_output_cnots == expected_output_cnots


@pytest.mark.parametrize(
'input_cnots,input_flip_control_and_target,expected_simplified,expected_output_cnots',
[
# Empty inputs don't get simplified.
([], False, False, []),
([], True, False, []),
# Single CNOTs don't get simplified.
([(0, 1)], False, False, [(0, 1)]),
([(0, 1)], True, False, [(0, 1)]),
# Simplify according to equation 11 of [4].
([(2, 1), (2, 0), (1, 0)], False, True, [(1, 0), (2, 1)]),
([(1, 2), (0, 2), (0, 1)], True, True, [(0, 1), (1, 2)]),
# Same as above, but with a intervening CNOTs that prevent simplifications.
([(2, 1), (2, 0), (100, 101), (1, 0)], False, False, [(2, 1), (2, 0), (100, 101), (1, 0)]),
([(2, 1), (100, 101), (2, 0), (1, 0)], False, False, [(2, 1), (100, 101), (2, 0), (1, 0)]),
# swap (2, 1) and (1, 0) around (2, 0)
([(2, 1), (2, 3), (2, 0), (3, 0), (1, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]),
([(2, 1), (2, 0), (2, 3), (3, 0), (1, 0)], False, True, [(1, 0), (2, 1), (2, 3), (3, 0)]),
([(2, 3), (2, 1), (2, 0), (3, 0), (1, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]),
([(2, 1), (2, 3), (3, 0), (2, 0), (1, 0)], False, True, [(2, 3), (3, 0), (1, 0), (2, 1)]),
([(2, 1), (2, 3), (2, 0), (1, 0), (3, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]),
],
)
def test_simplify_cnots_triplets(
input_cnots, input_flip_control_and_target, expected_simplified, expected_output_cnots
):
actual_simplified, actual_output_cnots = bh._simplify_cnots_triplets(
input_cnots, input_flip_control_and_target
)
assert actual_simplified == expected_simplified
assert actual_output_cnots == expected_output_cnots

# Check that the unitaries are the same.
qubit_ids = set(sum(input_cnots, ()))
qubits = {qubit_id: cirq.NamedQubit(f"{qubit_id}") for qubit_id in qubit_ids}

target, control = (0, 1) if input_flip_control_and_target else (1, 0)

circuit_input = cirq.Circuit()
for input_cnot in input_cnots:
circuit_input.append(cirq.CNOT(qubits[input_cnot[target]], qubits[input_cnot[control]]))
circuit_actual = cirq.Circuit()
for actual_cnot in actual_output_cnots:
circuit_actual.append(cirq.CNOT(qubits[actual_cnot[target]], qubits[actual_cnot[control]]))

np.testing.assert_allclose(cirq.unitary(circuit_input), cirq.unitary(circuit_actual), atol=1e-6)

0 comments on commit bd2e63c

Please sign in to comment.