Skip to content

Commit

Permalink
Fix occasional bug in gate_product_tabulation (#2700)
Browse files Browse the repository at this point in the history
Addresses #2696
  • Loading branch information
dkafri authored and CirqBot committed Jan 23, 2020
1 parent 3d9d2df commit d470089
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
38 changes: 31 additions & 7 deletions cirq/google/optimizers/two_qubit_gates/gate_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def _tabulate_kak_vectors(
Returns:
The newly tabulated KAK vectors and the local unitaries used to generate
them.
them. This function also updates already_tabulated to include the
indices of these vectors (within kak_mesh).
"""
shapes = {pair[0].shape for pair in local_unitary_pairs}
shapes.update({pair[0].shape for pair in local_unitary_pairs})
Expand Down Expand Up @@ -340,27 +341,50 @@ def gate_product_tabulation(base_gate: np.ndarray,
base_gate_dag = base_gate.conj().T
for ind in missing_vec_inds:
missing_vec = mesh_points[ind]
# Unitary A we wish to solve for
missing_unitary = kak_vector_to_unitary(missing_vec)

# Products of the from base_gate^\dagger k A
products = np.einsum('ab,...bc,cd', base_gate_dag, u_locals,
missing_unitary)
# KAK vectors for these products
kaks = linalg.kak_vector(products, check_preconditions=False)
kaks = kaks[..., np.newaxis, :]

# Check if any of the product KAK vectors are close to a previously
# tabulated KAK vector
dists2 = np.sum((kaks - kak_vecs_single)**2, axis=-1)
min_dist_inds = np.unravel_index(dists2.argmin(), dists2.shape)
min_dist = np.sqrt(dists2[min_dist_inds])
if min_dist < tabulation_cutoff:
# If so, compute the single qubit unitary k_L such that
# base_gate^\dagger k A = kL base_gate k0 base_gate kR
# where k0 is the old (previously tabulated) single qubit unitary
# and k is one of the single qubit unitaries used above.
# Indices below are for k, k0 respectively
new_ind, old_ind = min_dist_inds

old_sq_cycle = sq_cycles_single[old_ind][0]
old_k = np.kron(*old_sq_cycle)
base_product = base_gate @ old_k @ base_gate
# Special case where the RHS is just base_gate (no single qubit
# gates yet applied). I.e. base_gate^\dagger k A ~ base_gate
# which implies base_gate^\dagger k A = k_L base_gate k_R
new_product = products[new_ind]
if old_ind == 0:
assert not sq_cycles_single[old_ind]
base_product = base_gate
_, kL, actual = _outer_locals_for_unitary(
new_product, base_product)
# Add to the enumeration
sq_cycles.append((kL,))
else: # typical case mentioned above
assert len(sq_cycles_single[old_ind]) == 1
old_sq_cycle = sq_cycles_single[old_ind][0]
old_k = np.kron(*old_sq_cycle)
base_product = base_gate @ old_k @ base_gate
_, kL, actual = _outer_locals_for_unitary(
new_product, base_product)
# Add to the enumeration
sq_cycles.append((old_sq_cycle, kL))

_, kL, actual = _outer_locals_for_unitary(new_product, base_product)
# Add to the enumeration
sq_cycles.append((old_sq_cycle, kL))
kak_vecs.append(
linalg.kak_vector(base_gate @ actual,
check_preconditions=False))
Expand Down
11 changes: 11 additions & 0 deletions cirq/google/optimizers/two_qubit_gates/gate_compilation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,14 @@ def test_gate_compilation_missing_points_raises_error():
0.4,
allow_missed_points=False,
random_state=_rng)


@pytest.mark.parametrize('seed', [0, 1])
def test_sycamore_gate_tabulation(seed):
base_gate = unitary(FSimGate(np.pi / 2, np.pi / 6))
tab = gate_product_tabulation(base_gate,
0.1,
sample_scaling=2,
random_state=np.random.RandomState(seed))
result = tab.compile_two_qubit_gate(base_gate)
assert result.success

0 comments on commit d470089

Please sign in to comment.