Skip to content

Commit

Permalink
Adds cirq.num_cnots_required and cirq.to_special (#2892)
Browse files Browse the repository at this point in the history
Adds cirq.num_two_qubit_gates_required and cirq.to_special. 

- `cirq.num_cnots_required`: Based on simple linear algebra users can calculate the minimum required two-qubit gates (CZ, CNOT) to implement a two-qubit unitary. 
- `cirq.to_special`: converts a unitary to a special unitary

Context: this is the first PR breaking up #2873.

For an overview of the high level design decisions of the whole project see: https://drive.google.com/open?id=1SDEtttIAaTwfV9AUs7XAxeW3AUd0qLoY.
  • Loading branch information
balopat committed Sep 29, 2020
1 parent bba0153 commit fb164bb
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
match_global_phase,
matrix_commutes,
matrix_from_basis_coefficients,
num_cnots_required,
partial_trace,
partial_trace_of_state_vector_as_mixture,
PAULI_BASIS,
Expand All @@ -156,6 +157,7 @@
sub_state_vector,
targeted_conjugate_about,
targeted_left_multiply,
to_special,
unitary_eig,
wavefunction_partial_trace_as_mixture,
)
Expand Down
2 changes: 2 additions & 0 deletions cirq/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
KakDecomposition,
kron_factor_4x4_to_2x2s,
map_eigenvalues,
num_cnots_required,
unitary_eig,
scatter_plot_normalized_kak_interaction_coefficients,
so4_to_magic_su2s,
Expand Down Expand Up @@ -89,5 +90,6 @@
sub_state_vector,
targeted_conjugate_about,
targeted_left_multiply,
to_special,
wavefunction_partial_trace_as_mixture,
)
55 changes: 54 additions & 1 deletion cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from cirq import value, protocols
from cirq._compat import proper_repr
from cirq.linalg import combinators, diagonalize, predicates
from cirq.linalg import combinators, diagonalize, predicates, transformations

if TYPE_CHECKING:
import cirq
Expand All @@ -36,8 +36,16 @@
[0, 1j, 1, 0],
[0, 1j, -1, 0],
[1, 0, 0, -1j]]) * np.sqrt(0.5)

MAGIC_CONJ_T = np.conj(MAGIC.T)

# yapf: disable
YY = np.array([[0, 0, 0, -1],
[0, 0, 1, 0],
[0, 1, 0, 0],
[-1, 0, 0, 0]])
# yapf: enable


def _phase_matrix(angle: float) -> np.ndarray:
return np.diag([1, np.exp(1j * angle)])
Expand Down Expand Up @@ -992,3 +1000,48 @@ def _canonicalize_kak_vector(k_vec: np.ndarray, atol: float) -> np.ndarray:
k_vec[need_diff, 2] *= -1

return k_vec


def num_cnots_required(u: np.ndarray, atol: float = 1e-8) -> int:
"""Returns the min number of CNOT/CZ gates required by a two-qubit unitary.
See Proposition III.1, III.2, III.3 in Shende et al. “Recognizing Small-
Circuit Structure in Two-Qubit Operators and Timing Hamiltonians to Compute
Controlled-Not Gates”. https://arxiv.org/abs/quant-ph/0308045
Args:
u: a two-qubit unitary
Returns:
the number of CNOT or CZ gates required to implement the unitary
"""
if u.shape != (4, 4):
raise ValueError(f"Expected unitary of shape (4,4), instead "
f"got {u.shape}")
g = _gamma(transformations.to_special(u))
# see Fadeev-LeVerrier formula
a3 = -np.trace(g)
# no need to check a2 = 6, as a3 = +-4 only happens if the eigenvalues are
# either all +1 or -1, which unambiguously implies that a2 = 6
if np.abs(a3 - 4) < atol or np.abs(a3 + 4) < atol:
return 0
# see Fadeev-LeVerrier formula
a2 = (a3 * a3 - np.trace(g @ g)) / 2
if np.abs(a3) < atol and np.abs(a2 - 2) < atol:
return 1
if np.abs(a3.imag) < atol:
return 2
return 3


def _gamma(u: np.ndarray) -> np.ndarray:
"""Gamma function to convert u to the magic basis.
See Definition IV.1 in Shende et al. "Minimal Universal Two-Qubit CNOT-based
Circuits." https://arxiv.org/abs/quant-ph/0308033
Args:
u: a member of SU(4)
Returns:
u @ yy @ u.T @ yy, where yy = Y ⊗ Y
"""
return u @ YY @ u.T @ YY
35 changes: 35 additions & 0 deletions cirq/linalg/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,38 @@ def test_kak_decompose(unitary: np.ndarray):
np.testing.assert_allclose(cirq.unitary(circuit), unitary, atol=1e-8)
assert len(circuit) == 5
assert len(list(circuit.all_operations())) == 8


def test_num_two_qubit_gates_required():
for i in range(4):
assert cirq.num_cnots_required(
_two_qubit_circuit_with_cnots(i).unitary()) == i

assert cirq.num_cnots_required(np.eye(4)) == 0


def test_num_two_qubit_gates_required_invalid():
with pytest.raises(ValueError, match="(4,4)"):
cirq.num_cnots_required(np.array([[1]]))


def _two_qubit_circuit_with_cnots(num_cnots=3, a=None, b=None):
random.seed(32123)
if a is None or b is None:
a, b = cirq.LineQubit.range(2)

def random_one_qubit_gate():
return cirq.PhasedXPowGate(phase_exponent=random.random(),
exponent=random.random())

def one_cz():
return [
cirq.CZ.on(a, b),
random_one_qubit_gate().on(a),
random_one_qubit_gate().on(b),
]

return cirq.Circuit([
random_one_qubit_gate().on(a),
random_one_qubit_gate().on(b), [one_cz() for _ in range(num_cnots)]
])
17 changes: 17 additions & 0 deletions cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,20 @@ def sub_state_vector(state_vector: np.ndarray,
@deprecated(deadline='v0.10.0', fix='Use `cirq.sub_state_vector` instead.')
def subwavefunction(*args, **kwargs):
return sub_state_vector(*args, **kwargs)


def to_special(u: np.ndarray) -> np.ndarray:
"""Converts a unitary matrix to a special unitary matrix.
All unitary matrices u have |det(u)| = 1.
Also for all d dimensional unitary matrix u, and scalar s:
det(u * s) = det(u) * s^(d)
To find a special unitary matrix from u:
u * det(u)^{-1/d}
Args:
u: the unitary matrix
Returns:
the special unitary matrix
"""
return u * (np.linalg.det(u)**(-1 / len(u)))
7 changes: 7 additions & 0 deletions cirq/linalg/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,10 @@ def test_deprecated():
# pylint: disable=unexpected-keyword-arg,no-value-for-parameter
_ = cirq.partial_trace_of_state_vector_as_mixture(wavefunction=a,
keep_indices=[0])


def test_to_special():
u = cirq.testing.random_unitary(4)
su = cirq.to_special(u)
assert not cirq.is_special_unitary(u)
assert cirq.is_special_unitary(su)
2 changes: 2 additions & 0 deletions rtd_docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ Algebra and Representation
cirq.match_global_phase
cirq.matrix_commutes
cirq.matrix_from_basis_coefficients
cirq.num_cnots_required
cirq.partial_trace
cirq.partial_trace_of_state_vector_as_mixture
cirq.reflection_matrix_pow
Expand All @@ -679,6 +680,7 @@ Algebra and Representation
cirq.sub_state_vector
cirq.targeted_conjugate_about
cirq.targeted_left_multiply
cirq.to_special
cirq.unitary_eig
cirq.AxisAngleDecomposition
cirq.Duration
Expand Down

0 comments on commit fb164bb

Please sign in to comment.