Skip to content

Commit

Permalink
Create consistency check for unitary with ancilla (#6196)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoureldinYosri committed Jul 14, 2023
1 parent 8e70f77 commit cb05a69
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 8 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/testing/__init__.py
Expand Up @@ -109,3 +109,5 @@
from cirq.testing.sample_circuits import nonoptimal_toffoli_circuit

from cirq.testing.sample_gates import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla

from cirq.testing.consistent_unitary import assert_unitary_is_consistent
13 changes: 11 additions & 2 deletions cirq-core/cirq/testing/consistent_decomposition.py
Expand Up @@ -40,8 +40,17 @@ def assert_decompose_is_consistent_with_unitary(val: Any, ignoring_global_phase:
# If there's no decomposition, it's vacuously consistent.
return

actual = circuits.Circuit(dec).unitary(qubit_order=qubits)

c = circuits.Circuit(dec)
if len(c.all_qubits().difference(qubits)):
# The decomposition contains ancilla qubits.
ancilla = tuple(c.all_qubits().difference(qubits))
qubit_order = ancilla + qubits
actual = c.unitary(qubit_order=qubit_order)
qid_shape = protocols.qid_shape(qubits)
vol = np.prod(qid_shape, dtype=np.int64)
actual = actual[:vol, :vol]
else:
actual = c.unitary(qubit_order=qubits)
if ignoring_global_phase:
lin_alg_utils.assert_allclose_up_to_global_phase(actual, expected, atol=1e-8)
else:
Expand Down
9 changes: 8 additions & 1 deletion cirq-core/cirq/testing/consistent_decomposition_test.py
Expand Up @@ -43,6 +43,14 @@ def test_assert_decompose_is_consistent_with_unitary():
GoodGateDecompose().on(cirq.NamedQubit('q'))
)

cirq.testing.assert_decompose_is_consistent_with_unitary(
cirq.testing.PhaseUsingCleanAncilla(theta=0.1, ancilla_bitsize=3)
)

cirq.testing.assert_decompose_is_consistent_with_unitary(
cirq.testing.PhaseUsingDirtyAncilla(phase_state=1, ancilla_bitsize=4)
)

with pytest.raises(AssertionError):
cirq.testing.assert_decompose_is_consistent_with_unitary(BadGateDecompose())

Expand Down Expand Up @@ -83,7 +91,6 @@ def _decompose_(self, qubits):


def test_assert_decompose_ends_at_default_gateset():

cirq.testing.assert_decompose_ends_at_default_gateset(GateDecomposesToDefaultGateset())
cirq.testing.assert_decompose_ends_at_default_gateset(
GateDecomposesToDefaultGateset().on(*cirq.LineQubit.range(2))
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/testing/consistent_protocols.py
Expand Up @@ -37,6 +37,7 @@
from cirq.testing.consistent_specified_has_unitary import assert_specifies_has_unitary_if_unitary
from cirq.testing.equivalent_repr_eval import assert_equivalent_repr
from cirq.testing.consistent_controlled_gate_op import assert_controlled_and_controlled_by_identical
from cirq.testing.consistent_unitary import assert_unitary_is_consistent


def assert_implements_consistent_protocols(
Expand Down Expand Up @@ -153,6 +154,7 @@ def _assert_meets_standards_helper(
assert_qasm_is_consistent_with_unitary(val)
assert_has_consistent_trace_distance_bound(val)
assert_decompose_is_consistent_with_unitary(val, ignoring_global_phase=ignoring_global_phase)
assert_unitary_is_consistent(val, ignoring_global_phase=ignoring_global_phase)
if not ignore_decompose_to_default_gateset:
assert_decompose_ends_at_default_gateset(val)
assert_phase_by_is_consistent_with_unitary(val)
Expand Down
85 changes: 85 additions & 0 deletions cirq-core/cirq/testing/consistent_unitary.py
@@ -0,0 +1,85 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any
import cirq
import numpy as np


def assert_unitary_is_consistent(val: Any, ignoring_global_phase: bool = False):
if not isinstance(val, (cirq.Operation, cirq.Gate)):
return

if not cirq.has_unitary(val):
return

# Ensure that `u` is a unitary.
u = cirq.unitary(val)
assert not (u is None or u is NotImplemented)
assert cirq.is_unitary(u)

if isinstance(val, cirq.Operation):
qubits = val.qubits
decomposition = cirq.decompose_once(val, default=None)
else:
qubits = tuple(cirq.LineQid.for_gate(val))
decomposition = cirq.decompose_once_with_qubits(val, qubits, default=None)

if decomposition is None or decomposition is NotImplemented:
return

c = cirq.Circuit(decomposition)
if len(c.all_qubits().difference(qubits)) == 0:
return

clean_qubits = tuple(q for q in c.all_qubits() if isinstance(q, cirq.ops.CleanQubit))
borrowable_qubits = tuple(q for q in c.all_qubits() if isinstance(q, cirq.ops.BorrowableQubit))
qubit_order = clean_qubits + borrowable_qubits + qubits

# Check that the decomposition uses all data qubits in addition to
# clean and/or borrowable qubits.
assert set(qubit_order) == c.all_qubits()

qid_shape = cirq.qid_shape(qubit_order)
full_unitary = cirq.apply_unitaries(
decomposition,
qubits=qubit_order,
args=cirq.ApplyUnitaryArgs.for_unitary(qid_shape=qid_shape),
default=None,
)
if full_unitary is None:
raise ValueError(f'apply_unitaries failed on the decomposition of {val}')
vol = np.prod(qid_shape, dtype=np.int64)
full_unitary = full_unitary.reshape((vol, vol))

vol = np.prod(cirq.qid_shape(borrowable_qubits + qubits), dtype=np.int64)

# Extract the submatrix acting on the |0..0> subspace of clean qubits.
# This submatirx must be a unitary.
clean_qubits_zero_subspace = full_unitary[:vol, :vol]

# If the borrowable qubits are restored to their initial state, then
# the decomposition's effect on it is the identity matrix.
# This means that the `clean_qubits_zero_subspace` must be I \otimes u.
# So checking that `clean_qubits_zero_subspace` is I \otimes u checks correctness
# for both clean and borrowable qubits at the same time.
expected = np.kron(np.eye(2 ** len(borrowable_qubits), dtype=np.complex128), u)

if ignoring_global_phase:
cirq.testing.assert_allclose_up_to_global_phase(
clean_qubits_zero_subspace, expected, atol=1e-8
)
else:
np.testing.assert_allclose(clean_qubits_zero_subspace, expected, atol=1e-8)
96 changes: 96 additions & 0 deletions cirq-core/cirq/testing/consistent_unitary_test.py
@@ -0,0 +1,96 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq

import pytest
import numpy as np


class InconsistentGate(cirq.Gate):
def _num_qubits_(self) -> int:
return 1

def _unitary_(self) -> np.ndarray:
return np.eye(2, dtype=np.complex128)

def _decompose_with_context_(self, qubits, *, context):
(q,) = context.qubit_manager.qalloc(1)
yield cirq.X(q)
yield cirq.CNOT(q, qubits[0])


class FailsOnDecompostion(cirq.Gate):
def _num_qubits_(self) -> int:
return 1

def _unitary_(self) -> np.ndarray:
return np.eye(2, dtype=np.complex128)

def _has_unitary_(self) -> bool:
return True

def _decompose_with_context_(self, qubits, *, context):
(q,) = context.qubit_manager.qalloc(1)
yield cirq.X(q)
yield cirq.measure(qubits[0])


class CleanCorrectButBorrowableIncorrectGate(cirq.Gate):
"""Ancilla type determines if the decomposition is correct or not."""

def __init__(self, use_clean_ancilla: bool) -> None:
self.ancillas_are_clean = use_clean_ancilla

def _num_qubits_(self):
return 2

def _decompose_with_context_(self, qubits, *, context):
if self.ancillas_are_clean:
anc = context.qubit_manager.qalloc(1)
else:
anc = context.qubit_manager.qborrow(1)
yield cirq.CCNOT(*qubits, *anc)
yield cirq.Z(*anc)
yield cirq.CCNOT(*qubits, *anc)
context.qubit_manager.qfree(anc)


@pytest.mark.parametrize('ignore_phase', [False, True])
@pytest.mark.parametrize(
'g,is_consistent',
[
(cirq.testing.PhaseUsingCleanAncilla(theta=0.1, ancilla_bitsize=3), True),
(cirq.testing.PhaseUsingDirtyAncilla(phase_state=1, ancilla_bitsize=4), True),
(InconsistentGate(), False),
(CleanCorrectButBorrowableIncorrectGate(use_clean_ancilla=True), True),
(CleanCorrectButBorrowableIncorrectGate(use_clean_ancilla=False), False),
],
)
def test_assert_unitary_is_consistent(g, ignore_phase, is_consistent):
if is_consistent:
cirq.testing.assert_unitary_is_consistent(g, ignore_phase)
cirq.testing.assert_unitary_is_consistent(g.on(*cirq.LineQid.for_gate(g)), ignore_phase)
else:
with pytest.raises(AssertionError):
cirq.testing.assert_unitary_is_consistent(g, ignore_phase)
with pytest.raises(AssertionError):
cirq.testing.assert_unitary_is_consistent(g.on(*cirq.LineQid.for_gate(g)), ignore_phase)


def test_failed_decomposition():
with pytest.raises(ValueError):
cirq.testing.assert_unitary_is_consistent(FailsOnDecompostion())

_ = cirq.testing.assert_unitary_is_consistent(cirq.Circuit())
10 changes: 5 additions & 5 deletions cirq-core/cirq/testing/sample_gates.py
Expand Up @@ -15,7 +15,7 @@

import cirq
import numpy as np
from cirq import ops, qis
from cirq import ops, qis, protocols


def _matrix_for_phasing_state(num_qubits, phase_state, phase):
Expand All @@ -39,8 +39,8 @@ class PhaseUsingCleanAncilla(ops.Gate):
def _num_qubits_(self):
return self.target_bitsize

def _decompose_(self, qubits):
anc = ops.NamedQubit.range(self.ancilla_bitsize, prefix="anc")
def _decompose_with_context_(self, qubits, *, context: protocols.DecompositionContext):
anc = context.qubit_manager.qalloc(self.ancilla_bitsize)
cv = [int(x) for x in f'{self.phase_state:0{self.target_bitsize}b}']
cnot_ladder = [cirq.CNOT(anc[i - 1], anc[i]) for i in range(1, self.ancilla_bitsize)]

Expand All @@ -65,8 +65,8 @@ class PhaseUsingDirtyAncilla(ops.Gate):
def _num_qubits_(self):
return self.target_bitsize

def _decompose_(self, qubits):
anc = ops.NamedQubit.range(self.ancilla_bitsize, prefix="anc")
def _decompose_with_context_(self, qubits, *, context: protocols.DecompositionContext):
anc = context.qubit_manager.qalloc(self.ancilla_bitsize)
cv = [int(x) for x in f'{self.phase_state:0{self.target_bitsize}b}']
cnot_ladder = [cirq.CNOT(anc[i - 1], anc[i]) for i in range(1, self.ancilla_bitsize)]
yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv)
Expand Down

0 comments on commit cb05a69

Please sign in to comment.