Skip to content

Commit

Permalink
Sane way of doing diagonal basis commuting (#148)
Browse files Browse the repository at this point in the history
Cleaned up the method and made the key something sane that keeps track
of fewest physical labels.  Should be safe to use now with the new
PauliTerm objects that use an OrderedDict as a storage device for tensor
product terms.
  • Loading branch information
ncrubin committed Apr 4, 2018
1 parent 2df9832 commit c1603ec
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 66 deletions.
2 changes: 1 addition & 1 deletion grove/measurements/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def get_parity(pauli_terms, bitstring_results):
for term in pauli_terms:
qubit_set.extend(list(term.get_qubits()))
active_qubit_indices = sorted(list(set(qubit_set)))

index_mapper = dict(zip(active_qubit_indices,
range(len(active_qubit_indices))))

Expand Down Expand Up @@ -140,6 +139,7 @@ def estimate_pauli_sum(pauli_terms, basis_transform_dict, program,

pauli_for_rotations = PauliTerm.from_list(
[(value, key) for key, value in basis_transform_dict.items()])

post_rotations = get_rotation_program(pauli_for_rotations)

coeff_vec = np.array(
Expand Down
138 changes: 80 additions & 58 deletions grove/measurements/term_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
This augments the existing infrastructure in pyquil that finds commuting sets
of PauliTerms.
"""
from pyquil.paulis import check_commutation, is_identity, PauliTerm, PauliSum
from functools import reduce
from pyquil.paulis import is_identity, PauliTerm, PauliSum


def _commutes(p1, p2):
Expand All @@ -20,7 +21,48 @@ def _commutes(p1, p2):
return p1.id() == p2.id()


def _max_key_overlap(pauli_term, diagonal_sets, active_qubits):
def diagonal_basis_commutes(pauli_a, pauli_b):
"""
Test if `pauli_a` and `pauli_b` share a diagonal basis
Example:
Check if [A, B] with the constraint that A & B must share a one-qubit
diagonalizing basis. If the inputs were [sZ(0), sZ(0) * sZ(1)] then this
function would return True. If the inputs were [sX(5), sZ(4)] this
function would return True. If the inputs were [sX(0), sY(0) * sZ(2)]
this function would return False.
:param pauli_a: Pauli term to check commutation against `pauli_b`
:param pauli_b: Pauli term to check commutation against `pauli_a`
:return: Boolean of commutation result
:rtype: Bool
"""
overlapping_active_qubits = set(pauli_a.get_qubits()) & set(pauli_b.get_qubits())
for qubit_index in overlapping_active_qubits:
if (pauli_a[qubit_index] != 'I' and pauli_b[qubit_index] != 'I' and
pauli_a[qubit_index] != pauli_b[qubit_index]):
return False

return True


def get_diagonalizing_basis(list_of_pauli_terms):
"""
Find the Pauli Term with the most non-identity terms
:param list_of_pauli_terms: List of Pauli terms to check
:return: The highest weight Pauli Term
:rtype: PauliTerm
"""
qubit_ops = set(reduce(lambda x, y: x + y,
[list(term._ops.items()) for term in list_of_pauli_terms]))
qubit_ops = sorted(list(qubit_ops), key=lambda x: x[0])

return PauliTerm.from_list(list(map(lambda x: tuple(reversed(x)), qubit_ops)))


def _max_key_overlap(pauli_term, diagonal_sets):
"""
Calculate the max overlap of a pauli term ID with keys of diagonal_sets
Expand All @@ -34,41 +76,45 @@ def _max_key_overlap(pauli_term, diagonal_sets, active_qubits):
and list of PauliTerms that share that basis
:rtype: dict
"""
hash_ptp = tuple([pauli_term[n] for n in active_qubits])

keys = list(diagonal_sets.keys())
# if there are keys check for collisions if not return updated
# diagonal_set dictionary with the key and term added
for key in keys: # for each key check any collisions
for idx, pauli_tensor_element in enumerate(key):
if ((pauli_tensor_element != 'I' and hash_ptp[idx] != 'I')
and hash_ptp[idx] != pauli_tensor_element):
# item has collision with this key
# so must be a different key or new key
break
else:
# we've gotten to the end without finding a difference!
# that means this key works with this pauli term!
# Now we must select the longer of the two keys
# longer is the key or item with fewer identities
new_key = []
for ii in range(len(hash_ptp)):
if hash_ptp[ii] != 'I':
new_key.append(hash_ptp[ii])
elif key[ii] != 'I':
new_key.append(key[ii])
else:
new_key.append('I')

if tuple(new_key) in diagonal_sets.keys():
diagonal_sets[tuple(new_key)].append(pauli_term)
else:
diagonal_sets[tuple(new_key)] = diagonal_sets[key]
diagonal_sets[tuple(new_key)].append(pauli_term)
# a lot of the ugliness comes from the fact that
# list(PauliTerm._ops.items()) is not the appropriate input for
# Pauliterm.from_list()
for key in list(diagonal_sets.keys()):
pauli_from_key = PauliTerm.from_list(
list(map(lambda x: tuple(reversed(x)), key)))
if diagonal_basis_commutes(pauli_term, pauli_from_key):
updated_pauli_set = diagonal_sets[key] + [pauli_term]
diagonalizing_term = get_diagonalizing_basis(updated_pauli_set)
if len(diagonalizing_term) > len(key):
del diagonal_sets[key]
new_key = tuple(sorted(diagonalizing_term._ops.items(),
key=lambda x: x[0]))
diagonal_sets[new_key] = updated_pauli_set
else:
diagonal_sets[key] = updated_pauli_set
return diagonal_sets
# made it through all keys and sets so need to make a new set
else:
# always need to sort because new pauli term functionality
new_key = tuple(sorted(pauli_term._ops.items(), key=lambda x: x[0]))
diagonal_sets[new_key] = [pauli_term]
return diagonal_sets


def commuting_sets_by_zbasis(pauli_sums):
"""
Computes commuting sets based on terms having the same diagonal basis
Following the technique outlined in the appendix of arXiv:1704.05018.
:param pauli_sums: PauliSum object to group
:return: dictionary where key value pair is a tuple corresponding to the
basis and a list of PauliTerms associated with that basis.
"""
diagonal_sets = {}
for term in pauli_sums:
diagonal_sets = _max_key_overlap(term, diagonal_sets)

diagonal_sets[hash_ptp] = [pauli_term]
return diagonal_sets


Expand Down Expand Up @@ -127,30 +173,6 @@ def commuting_sets_by_indices(pauli_sums, commutation_check):
return group_inds


def commuting_sets_by_zbasis(pauli_sums):
"""
Computes commuting sets based on terms having the same diagonal basis
Following the technique outlined in the appendix of arXiv:1704.05018.
:param pauli_sums: PauliSum object to group
:return: dictionary where key value pair is a tuple corresponding to the
basis and a list of PauliTerms associated with that basis.
"""
active_qubits = []
for term in pauli_sums:
active_qubits += list(term.get_qubits())
# get unique indices and put in order from least to greatest
# NOTE: translation layer to physical qubits is likely to be needed
active_qubits = sorted(list(set(active_qubits)))

diagonal_sets = {}
for term in pauli_sums:
diagonal_sets = _max_key_overlap(term, diagonal_sets, active_qubits)

return diagonal_sets


def commuting_sets_trivial(pauli_sum):
"""
Group a pauli term into commuting sets using trivial check
Expand Down
58 changes: 51 additions & 7 deletions grove/tests/measurements/test_term_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from grove.measurements.term_grouping import (check_trivial_commutation,
commuting_sets_by_indices,
commuting_sets_by_zbasis,
commuting_sets_trivial)
commuting_sets_trivial,
get_diagonalizing_basis,
diagonal_basis_commutes)


def test_check_trivial_commutation_type():
Expand Down Expand Up @@ -114,6 +116,19 @@ def test_commuting_sets_4():
assert actual == desired


def test_term_grouping_weird_term():
term1 = PauliTerm.from_list([('X', 1), ('Z', 2), ('Y', 3), ('Y', 5),
('Z', 6), ('X', 7)],
coefficient=0.012870253243021476)

term2 = PauliTerm.from_list([('Z', 0), ('Z', 6)],
coefficient=0.13131672212575296)

term_dictionary = commuting_sets_by_zbasis(term1 + term2)
true_term_key = ((0, 'Z'), (1, 'X'), (2, 'Z'), (3, 'Y'), (5, 'Y'), (6, 'Z'), (7, 'X'))
assert list(term_dictionary.keys())[0] == true_term_key


def test_term_grouping():
"""
Test clumping terms into terms that share the same diagonal basis
Expand All @@ -124,8 +139,9 @@ def test_term_grouping():
zz_term = sZ(0) * sZ(1)
h2_hamiltonian = zz_term + z2_term + z1_term + x_term
clumped_terms = commuting_sets_by_zbasis(h2_hamiltonian)
true_set = {('X', 'X'): set([x_term.id()]),
('Z', 'Z'): set([z1_term.id(), z2_term.id(), zz_term.id()])}
true_set = {((0, 'X'), (1, 'X')): set([x_term.id()]),
((0, 'Z'), (1, 'Z')): set([z1_term.id(), z2_term.id(), zz_term.id()])}

for key, value in clumped_terms.items():
assert set(map(lambda x: x.id(), clumped_terms[key])) == true_set[key]

Expand All @@ -140,10 +156,38 @@ def test_term_grouping():
pauli_sum = zzzz_terms + xzxz_terms + xxxx_terms + yyyy_terms
clumped_terms = commuting_sets_by_zbasis(pauli_sum)

true_set = {('Z', 'Z', 'Z', 'Z'): set(map(lambda x: x.id(), zzzz_terms)),
('X', 'Z', 'X', 'Z'): set(map(lambda x: x.id(), xzxz_terms)),
('X', 'X', 'X', 'X'): set(map(lambda x: x.id(), xxxx_terms)),
('Y', 'Y', 'Y', 'Y'): set(map(lambda x: x.id(), yyyy_terms))}
true_set = {((1, 'Z'), (2, 'Z'), (3, 'Z'), (4, 'Z')): set(map(lambda x: x.id(), zzzz_terms)),
((1, 'X'), (2, 'Z'), (3, 'X'), (4, 'Z')): set(map(lambda x: x.id(), xzxz_terms)),
((1, 'X'), (2, 'X'), (3, 'X'), (4, 'X')): set(map(lambda x: x.id(), xxxx_terms)),
((1, 'Y'), (2, 'Y'), (3, 'Y'), (4, 'Y')): set(map(lambda x: x.id(), yyyy_terms))}
for key, value in clumped_terms.items():
assert set(map(lambda x: x.id(), clumped_terms[key])) == true_set[key]


def test_get_diagonal_basis():
xxxx_terms = sX(1) * sX(2) + sX(2) + sX(3) * sX(4) + sX(4) + \
sX(1) * sX(3) * sX(4) + sX(1) * sX(4) + sX(1) * sX(2) * sX(3)
true_term = sX(1) * sX(2) * sX(3) * sX(4)
assert get_diagonalizing_basis(xxxx_terms.terms) == true_term

zzzz_terms = sZ(1) * sZ(2) + sZ(3) * sZ(4) + \
sZ(1) * sZ(3) + sZ(1) * sZ(3) * sZ(4)
assert get_diagonalizing_basis(zzzz_terms.terms) == sZ(1) * sZ(2) * \
sZ(3) * sZ(4)


def test_diagonal_basis_commutation():
x_term = sX(0) * sX(1)
z1_term = sZ(1)
z2_term = sZ(0)
zz_term = sZ(0) * sZ(1)
assert not diagonal_basis_commutes(x_term, z1_term)
assert not diagonal_basis_commutes(zz_term, x_term)

assert diagonal_basis_commutes(z1_term, z2_term)
assert diagonal_basis_commutes(zz_term, z2_term)
assert diagonal_basis_commutes(zz_term, z1_term)
assert diagonal_basis_commutes(zz_term, sI(1))
assert diagonal_basis_commutes(zz_term, sI(2))
assert diagonal_basis_commutes(zz_term, sX(5) * sY(7))

0 comments on commit c1603ec

Please sign in to comment.