Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds cirq.num_cnots_required and cirq.to_special #2892

Merged
merged 23 commits into from
Sep 29, 2020
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
match_global_phase,
matrix_commutes,
matrix_from_basis_coefficients,
num_two_qubit_gates_required,
partial_trace,
partial_trace_of_state_vector_as_mixture,
PAULI_BASIS,
Expand All @@ -156,6 +157,7 @@
sub_state_vector,
targeted_conjugate_about,
targeted_left_multiply,
to_special,
unitary_eig,
wavefunction_partial_trace_as_mixture,
)
Expand Down
2 changes: 2 additions & 0 deletions cirq/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
KakDecomposition,
kron_factor_4x4_to_2x2s,
map_eigenvalues,
num_two_qubit_gates_required,
unitary_eig,
scatter_plot_normalized_kak_interaction_coefficients,
so4_to_magic_su2s,
Expand Down Expand Up @@ -89,5 +90,6 @@
sub_state_vector,
targeted_conjugate_about,
targeted_left_multiply,
to_special,
wavefunction_partial_trace_as_mixture,
)
56 changes: 55 additions & 1 deletion cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from cirq import value, protocols
from cirq._compat import proper_repr
from cirq.linalg import combinators, diagonalize, predicates
from cirq.linalg import combinators, diagonalize, predicates, transformations

if TYPE_CHECKING:
import cirq
Expand All @@ -36,8 +36,16 @@
[0, 1j, 1, 0],
[0, 1j, -1, 0],
[1, 0, 0, -1j]]) * np.sqrt(0.5)

MAGIC_CONJ_T = np.conj(MAGIC.T)

# yapf: disable
YY = np.array([[0, 0, 0, -1],
[0, 0, 1, 0],
[0, 1, 0, 0],
[-1, 0, 0, 0]])
# yapf: enable


def _phase_matrix(angle: float) -> np.ndarray:
return np.diag([1, np.exp(1j * angle)])
Expand Down Expand Up @@ -992,3 +1000,49 @@ def _canonicalize_kak_vector(k_vec: np.ndarray, atol: float) -> np.ndarray:
k_vec[need_diff, 2] *= -1

return k_vec


def num_two_qubit_gates_required(u: np.ndarray, atol: float = 1e-8) -> int:
balopat marked this conversation as resolved.
Show resolved Hide resolved
"""Returns the min number of 2-qubit gates required by a two-qubit unitary.

See Proposition III.1, III.2, III.3 in Shende et al. “Recognizing Small-
Circuit Structure in Two-Qubit Operators and Timing Hamiltonians to Compute
Controlled-Not Gates”. https://arxiv.org/abs/quant-ph/0308045
Args:
balopat marked this conversation as resolved.
Show resolved Hide resolved
u: a two-qubit unitary
Returns:
the number of two-qubit gates (CNOT or CZ) required to implement
the unitary
"""
if u.shape != (4, 4):
raise ValueError(f"Expected unitary of shape (4,4), instead "
f"got {u.shape}")
g = _gamma(transformations.to_special(u))
# see Fadeev-LeVerrier formula
a3 = -np.trace(g)
# no need to check a2 = 6, as a3 = +-4 only happens if the eigenvalues are
# either all +1 or -1, which unambiguously implies that a2 = 6
if np.abs(a3 - 4) < atol or np.abs(a3 + 4) < atol:
return 0
# see Fadeev-LeVerrier formula
a2 = (a3 * a3 - np.trace(g @ g)) / 2
if np.abs(a3) < atol and np.abs(a2 - 2) < atol:
return 1
if np.abs(a3.imag) < atol:
return 2
return 3


def _gamma(u: np.ndarray) -> np.ndarray:
"""Gamma function to convert u to the magic basis.

See Definition IV.1 in Shende et al. "Minimal Universal Two-Qubit CNOT-based
Circuits." https://arxiv.org/abs/quant-ph/0308033
balopat marked this conversation as resolved.
Show resolved Hide resolved

Args:
u: a member of SU(4)
Returns:
u @ yy @ u.T @ yy, where yy = Y ⊗ Y
"""

balopat marked this conversation as resolved.
Show resolved Hide resolved
return u @ YY @ u.T @ YY
35 changes: 35 additions & 0 deletions cirq/linalg/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,38 @@ def test_kak_decompose(unitary: np.ndarray):
np.testing.assert_allclose(cirq.unitary(circuit), unitary, atol=1e-8)
assert len(circuit) == 5
assert len(list(circuit.all_operations())) == 8


def test_num_two_qubit_gates_required():
for i in range(4):
assert cirq.num_two_qubit_gates_required(
_two_qubit_circuit_with_cnots(i).unitary()) == i

assert cirq.num_two_qubit_gates_required(np.eye(4)) == 0


def test_num_two_qubit_gates_required_invalid():
with pytest.raises(ValueError, match="(4,4)"):
cirq.num_two_qubit_gates_required(np.array([[1]]))


def _two_qubit_circuit_with_cnots(num_cnots=3, a=None, b=None):
random.seed(32123)
if a is None or b is None:
a, b = cirq.LineQubit.range(2)

def random_one_qubit_gate():
return cirq.PhasedXPowGate(phase_exponent=random.random(),
exponent=random.random())

def one_cz():
return [
random_one_qubit_gate().on(a),
random_one_qubit_gate().on(b),
cirq.CZ.on(a, b)
balopat marked this conversation as resolved.
Show resolved Hide resolved
]

return cirq.Circuit([
random_one_qubit_gate().on(a),
random_one_qubit_gate().on(b), [one_cz() for _ in range(num_cnots)]
])
17 changes: 17 additions & 0 deletions cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,20 @@ def sub_state_vector(state_vector: np.ndarray,
@deprecated(deadline='v0.10.0', fix='Use `cirq.sub_state_vector` instead.')
def subwavefunction(*args, **kwargs):
return sub_state_vector(*args, **kwargs)


def to_special(u: np.ndarray) -> np.ndarray:
"""Converts a unitary matrix to a special unitary matrix.

All unitary matrices u have |det(u)| = 1.
Also for all d dimensional unitary matrix u, and scalar s:
det(u * s) = det(u) * s^(d)
To find a special unitary matrix from u:
u * det(u)^{-1/d}

Args:
u: the unitary matrix
Returns:
the special unitary matrix
"""
return u * (np.linalg.det(u)**(-1 / len(u)))
7 changes: 7 additions & 0 deletions cirq/linalg/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,10 @@ def test_deprecated():
# pylint: disable=unexpected-keyword-arg,no-value-for-parameter
_ = cirq.partial_trace_of_state_vector_as_mixture(wavefunction=a,
keep_indices=[0])


def test_to_special():
u = cirq.testing.random_unitary(4)
su = cirq.to_special(u)
assert not cirq.is_special_unitary(u)
assert cirq.is_special_unitary(su)
2 changes: 2 additions & 0 deletions rtd_docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ Algebra and Representation
cirq.match_global_phase
cirq.matrix_commutes
cirq.matrix_from_basis_coefficients
cirq.num_two_qubit_gates_required
cirq.partial_trace
cirq.partial_trace_of_state_vector_as_mixture
cirq.reflection_matrix_pow
Expand All @@ -679,6 +680,7 @@ Algebra and Representation
cirq.sub_state_vector
cirq.targeted_conjugate_about
cirq.targeted_left_multiply
cirq.to_special
cirq.unitary_eig
cirq.AxisAngleDecomposition
cirq.Duration
Expand Down