From 91777087f568503a9d652c15b123b95099187bf2 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Mon, 5 Jun 2023 11:09:14 +0100 Subject: [PATCH] Add support for allocating qubits in decompose to cirq.unitary (#6112) * Add support for allocating qubits in decompose to cirq.unitary * fixed apply_unitaries * fix mypy * refactored tests * addressing comments * added sample_gates_test.py * Improved sample_gates.py implementation and unitary_protocol tests. Also added docstrings * fixed lint * retrigger checks --------- Co-authored-by: Tanuj Khattar --- .../cirq/protocols/apply_unitary_protocol.py | 74 +++++++++++++---- .../protocols/apply_unitary_protocol_test.py | 50 ++++++++++++ cirq-core/cirq/protocols/unitary_protocol.py | 24 ++++-- .../cirq/protocols/unitary_protocol_test.py | 37 +++++++++ cirq-core/cirq/testing/__init__.py | 2 + cirq-core/cirq/testing/sample_gates.py | 79 +++++++++++++++++++ cirq-core/cirq/testing/sample_gates_test.py | 59 ++++++++++++++ 7 files changed, 303 insertions(+), 22 deletions(-) create mode 100644 cirq-core/cirq/testing/sample_gates.py create mode 100644 cirq-core/cirq/testing/sample_gates_test.py diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol.py b/cirq-core/cirq/protocols/apply_unitary_protocol.py index 61881eddbf6..e3dc092dca1 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol.py @@ -133,6 +133,33 @@ def default( state = qis.one_hot(index=(0,) * num_qubits, shape=qid_shape, dtype=np.complex128) return ApplyUnitaryArgs(state, np.empty_like(state), range(num_qubits)) + @classmethod + def for_unitary( + cls, num_qubits: Optional[int] = None, *, qid_shape: Optional[Tuple[int, ...]] = None + ) -> 'ApplyUnitaryArgs': + """A default instance corresponding to an identity matrix. + + Specify exactly one argument. + + Args: + num_qubits: The number of qubits to make space for in the state. + qid_shape: A tuple representing the number of quantum levels of each + qubit the identity matrix applies to. `qid_shape` is (2, 2, 2) for + a three-qubit identity operation tensor. + + Raises: + TypeError: If exactly neither `num_qubits` or `qid_shape` is provided or + both are provided. + """ + if (num_qubits is None) == (qid_shape is None): + raise TypeError('Specify exactly one of num_qubits or qid_shape.') + if num_qubits is not None: + qid_shape = (2,) * num_qubits + qid_shape = cast(Tuple[int, ...], qid_shape) # Satisfy mypy + num_qubits = len(qid_shape) + state = qis.eye_tensor(qid_shape, dtype=np.complex128) + return ApplyUnitaryArgs(state, np.empty_like(state), range(num_qubits)) + def with_axes_transposed_to_start(self) -> 'ApplyUnitaryArgs': """Returns a transposed view of the same arguments. @@ -409,19 +436,7 @@ def _strat_apply_unitary_from_apply_unitary( return _incorporate_result_into_target(args, sub_args, sub_result) -def _strat_apply_unitary_from_unitary( - unitary_value: Any, args: ApplyUnitaryArgs -) -> Optional[np.ndarray]: - # Check for magic method. - method = getattr(unitary_value, '_unitary_', None) - if method is None: - return NotImplemented - - # Attempt to get the unitary matrix. - matrix = method() - if matrix is NotImplemented or matrix is None: - return matrix - +def _apply_unitary_from_matrix(matrix: np.ndarray, unitary_value: Any, args: ApplyUnitaryArgs): if args.slices is None: val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes)) slices = tuple(slice(0, size) for size in val_qid_shape) @@ -450,11 +465,42 @@ def _strat_apply_unitary_from_unitary( return _incorporate_result_into_target(args, sub_args, sub_result) +def _strat_apply_unitary_from_unitary( + unitary_value: Any, args: ApplyUnitaryArgs +) -> Optional[np.ndarray]: + # Check for magic method. + method = getattr(unitary_value, '_unitary_', None) + if method is None: + return NotImplemented + + # Attempt to get the unitary matrix. + matrix = method() + if matrix is NotImplemented or matrix is None: + return matrix + + return _apply_unitary_from_matrix(matrix, unitary_value, args) + + def _strat_apply_unitary_from_decompose(val: Any, args: ApplyUnitaryArgs) -> Optional[np.ndarray]: operations, qubits, _ = _try_decompose_into_operations_and_qubits(val) if operations is None: return NotImplemented - return apply_unitaries(operations, qubits, args, None) + all_qubits = frozenset([q for op in operations for q in op.qubits]) + ancilla = tuple(sorted(all_qubits.difference(qubits))) + if not len(ancilla): + return apply_unitaries(operations, qubits, args, None) + ordered_qubits = ancilla + tuple(qubits) + all_qid_shapes = qid_shape_protocol.qid_shape(ordered_qubits) + result = apply_unitaries( + operations, ordered_qubits, ApplyUnitaryArgs.for_unitary(qid_shape=all_qid_shapes), None + ) + if result is None or result is NotImplemented: + return result + result = result.reshape((np.prod(all_qid_shapes, dtype=np.int64), -1)) + val_qid_shape = qid_shape_protocol.qid_shape(qubits) + state_vec_length = np.prod(val_qid_shape, dtype=np.int64) + result = result[:state_vec_length, :state_vec_length] + return _apply_unitary_from_matrix(result, val, args) def apply_unitaries( diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py index 1b455c6bd69..1473d6fc117 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py @@ -717,3 +717,53 @@ def test_cast_to_complex(): np.ComplexWarning, match='Casting complex values to real discards the imaginary part' ): cirq.apply_unitary(y0, args) + + +class NotDecomposableGate(cirq.Gate): + def num_qubits(self): + return 1 + + +class DecomposableGate(cirq.Gate): + def __init__(self, sub_gate: cirq.Gate, allocate_ancilla: bool) -> None: + super().__init__() + self._sub_gate = sub_gate + self._allocate_ancilla = allocate_ancilla + + def num_qubits(self): + return 1 + + def _decompose_(self, qubits): + if self._allocate_ancilla: + yield cirq.Z(cirq.NamedQubit('DecomposableGateQubit')) + yield self._sub_gate(qubits[0]) + + +def test_strat_apply_unitary_from_decompose(): + state = np.eye(2, dtype=np.complex128) + args = cirq.ApplyUnitaryArgs( + target_tensor=state, available_buffer=np.zeros_like(state), axes=(0,) + ) + np.testing.assert_allclose( + cirq.apply_unitaries( + [DecomposableGate(cirq.X, False)(cirq.LineQubit(0))], [cirq.LineQubit(0)], args + ), + [[0, 1], [1, 0]], + ) + + with pytest.raises(TypeError): + _ = cirq.apply_unitaries( + [DecomposableGate(NotDecomposableGate(), True)(cirq.LineQubit(0))], + [cirq.LineQubit(0)], + args, + ) + + +def test_unitary_construction(): + with pytest.raises(TypeError): + _ = cirq.ApplyUnitaryArgs.for_unitary() + + np.testing.assert_allclose( + cirq.ApplyUnitaryArgs.for_unitary(num_qubits=3).target_tensor, + cirq.eye_tensor((2,) * 3, dtype=np.complex128), + ) diff --git a/cirq-core/cirq/protocols/unitary_protocol.py b/cirq-core/cirq/protocols/unitary_protocol.py index e5acd30e7df..4882dd96022 100644 --- a/cirq-core/cirq/protocols/unitary_protocol.py +++ b/cirq-core/cirq/protocols/unitary_protocol.py @@ -17,7 +17,6 @@ import numpy as np from typing_extensions import Protocol -from cirq import qis from cirq._doc import doc_private from cirq.protocols import qid_shape_protocol from cirq.protocols.apply_unitary_protocol import ApplyUnitaryArgs, apply_unitaries @@ -162,9 +161,7 @@ def _strat_unitary_from_apply_unitary(val: Any) -> Optional[np.ndarray]: return NotImplemented # Apply unitary effect to an identity matrix. - state = qis.eye_tensor(val_qid_shape, dtype=np.complex128) - buffer = np.empty_like(state) - result = method(ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape)))) + result = method(ApplyUnitaryArgs.for_unitary(qid_shape=val_qid_shape)) if result is NotImplemented or result is None: return result @@ -179,15 +176,26 @@ def _strat_unitary_from_decompose(val: Any) -> Optional[np.ndarray]: if operations is None: return NotImplemented + all_qubits = frozenset(q for op in operations for q in op.qubits) + work_qubits = frozenset(qubits) + ancillas = tuple(sorted(all_qubits.difference(work_qubits))) + + ordered_qubits = ancillas + tuple(qubits) + val_qid_shape = qid_shape_protocol.qid_shape(ancillas) + val_qid_shape + # Apply sub-operations' unitary effects to an identity matrix. - state = qis.eye_tensor(val_qid_shape, dtype=np.complex128) - buffer = np.empty_like(state) result = apply_unitaries( - operations, qubits, ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))), None + operations, ordered_qubits, ApplyUnitaryArgs.for_unitary(qid_shape=val_qid_shape), None ) # Package result. if result is None: return None + state_len = np.prod(val_qid_shape, dtype=np.int64) - return result.reshape((state_len, state_len)) + result = result.reshape((state_len, state_len)) + # Assuming borrowable qubits are restored to their original state and + # clean qubits restord to the zero state then the desired unitary is + # the upper left square. + work_state_len = np.prod(val_qid_shape[len(ancillas) :], dtype=np.int64) + return result[:work_state_len, :work_state_len] diff --git a/cirq-core/cirq/protocols/unitary_protocol_test.py b/cirq-core/cirq/protocols/unitary_protocol_test.py index 46448e76d16..5d972c082ce 100644 --- a/cirq-core/cirq/protocols/unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/unitary_protocol_test.py @@ -17,6 +17,7 @@ import pytest import cirq +from cirq import testing m0: np.ndarray = np.array([]) # yapf: disable @@ -188,6 +189,42 @@ def test_has_unitary(): assert not cirq.has_unitary(FullyImplemented(False)) +def _test_gate_that_allocates_qubits(gate): + from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose + + op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate))) + moment = cirq.Moment(op) + circuit = cirq.FrozenCircuit(op) + circuit_op = cirq.CircuitOperation(circuit) + for val in [gate, op, moment, circuit, circuit_op]: + unitary_from_strat = _strat_unitary_from_decompose(val) + assert unitary_from_strat is not None + np.testing.assert_allclose(unitary_from_strat, gate.narrow_unitary()) + + +@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10)) +@pytest.mark.parametrize('phase_state', [0, 1]) +@pytest.mark.parametrize('target_bitsize', [1, 2, 3]) +@pytest.mark.parametrize('ancilla_bitsize', [1, 4]) +def test_decompose_gate_that_allocates_clean_qubits( + theta: float, phase_state: int, target_bitsize: int, ancilla_bitsize: int +): + + gate = testing.PhaseUsingCleanAncilla(theta, phase_state, target_bitsize, ancilla_bitsize) + _test_gate_that_allocates_qubits(gate) + + +@pytest.mark.parametrize('phase_state', [0, 1]) +@pytest.mark.parametrize('target_bitsize', [1, 2, 3]) +@pytest.mark.parametrize('ancilla_bitsize', [1, 4]) +def test_decompose_gate_that_allocates_dirty_qubits( + phase_state: int, target_bitsize: int, ancilla_bitsize: int +): + + gate = testing.PhaseUsingDirtyAncilla(phase_state, target_bitsize, ancilla_bitsize) + _test_gate_that_allocates_qubits(gate) + + def test_decompose_and_get_unitary(): from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index 7e831b4d480..1c7ffaba28e 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -107,3 +107,5 @@ ) from cirq.testing.sample_circuits import nonoptimal_toffoli_circuit + +from cirq.testing.sample_gates import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla diff --git a/cirq-core/cirq/testing/sample_gates.py b/cirq-core/cirq/testing/sample_gates.py new file mode 100644 index 00000000000..c4bb2c9b95f --- /dev/null +++ b/cirq-core/cirq/testing/sample_gates.py @@ -0,0 +1,79 @@ +# Copyright 2023 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses + +import cirq +import numpy as np +from cirq import ops, qis + + +def _matrix_for_phasing_state(num_qubits, phase_state, phase): + matrix = qis.eye_tensor((2,) * num_qubits, dtype=np.complex128) + matrix = matrix.reshape((2**num_qubits, 2**num_qubits)) + matrix[phase_state, phase_state] = phase + print(num_qubits, phase_state, phase) + print(matrix) + return matrix + + +@dataclasses.dataclass(frozen=True) +class PhaseUsingCleanAncilla(ops.Gate): + r"""Phases the state $|phase_state>$ by $\exp(1j * \pi * \theta)$ using one clean ancilla.""" + + theta: float + phase_state: int = 1 + target_bitsize: int = 1 + ancilla_bitsize: int = 1 + + def _num_qubits_(self): + return self.target_bitsize + + def _decompose_(self, qubits): + anc = ops.NamedQubit.range(self.ancilla_bitsize, prefix="anc") + cv = [int(x) for x in f'{self.phase_state:0{self.target_bitsize}b}'] + cnot_ladder = [cirq.CNOT(anc[i - 1], anc[i]) for i in range(1, self.ancilla_bitsize)] + + yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv) + yield [cnot_ladder, ops.Z(anc[-1]) ** self.theta, reversed(cnot_ladder)] + yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv) + + def narrow_unitary(self) -> np.ndarray: + """Narrowed unitary corresponding to the unitary effect applied on target qubits.""" + phase = np.exp(1j * np.pi * self.theta) + return _matrix_for_phasing_state(self.target_bitsize, self.phase_state, phase) + + +@dataclasses.dataclass(frozen=True) +class PhaseUsingDirtyAncilla(ops.Gate): + r"""Phases the state $|phase_state>$ by -1 using one dirty ancilla.""" + + phase_state: int = 1 + target_bitsize: int = 1 + ancilla_bitsize: int = 1 + + def _num_qubits_(self): + return self.target_bitsize + + def _decompose_(self, qubits): + anc = ops.NamedQubit.range(self.ancilla_bitsize, prefix="anc") + cv = [int(x) for x in f'{self.phase_state:0{self.target_bitsize}b}'] + cnot_ladder = [cirq.CNOT(anc[i - 1], anc[i]) for i in range(1, self.ancilla_bitsize)] + yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv) + yield [cnot_ladder, ops.Z(anc[-1]), reversed(cnot_ladder)] + yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv) + yield [cnot_ladder, ops.Z(anc[-1]), reversed(cnot_ladder)] + + def narrow_unitary(self) -> np.ndarray: + """Narrowed unitary corresponding to the unitary effect applied on target qubits.""" + return _matrix_for_phasing_state(self.target_bitsize, self.phase_state, -1) diff --git a/cirq-core/cirq/testing/sample_gates_test.py b/cirq-core/cirq/testing/sample_gates_test.py new file mode 100644 index 00000000000..848928c0e33 --- /dev/null +++ b/cirq-core/cirq/testing/sample_gates_test.py @@ -0,0 +1,59 @@ +# Copyright 2023 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import numpy as np +from cirq.testing import sample_gates +import cirq + + +@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 20)) +def test_phase_using_clean_ancilla(theta: float): + g = sample_gates.PhaseUsingCleanAncilla(theta) + q = cirq.LineQubit(0) + qubit_order = cirq.QubitOrder.explicit([q], fallback=cirq.QubitOrder.DEFAULT) + decomposed_unitary = cirq.Circuit(cirq.decompose_once(g.on(q))).unitary(qubit_order=qubit_order) + phase = np.exp(1j * np.pi * theta) + np.testing.assert_allclose(g.narrow_unitary(), np.array([[1, 0], [0, phase]])) + np.testing.assert_allclose( + decomposed_unitary, + # fmt: off + np.array( + [ + [1 , 0 , 0 , 0], + [0 , phase, 0 , 0], + [0 , 0 , phase, 0], + [0 , 0 , 0 , 1], + ] + ), + # fmt: on + ) + + +@pytest.mark.parametrize( + 'target_bitsize, phase_state', [(1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (2, 3)] +) +@pytest.mark.parametrize('ancilla_bitsize', [1, 4]) +def test_phase_using_dirty_ancilla(target_bitsize, phase_state, ancilla_bitsize): + g = sample_gates.PhaseUsingDirtyAncilla(phase_state, target_bitsize, ancilla_bitsize) + q = cirq.LineQubit.range(target_bitsize) + qubit_order = cirq.QubitOrder.explicit(q, fallback=cirq.QubitOrder.DEFAULT) + decomposed_circuit = cirq.Circuit(cirq.decompose_once(g.on(*q))) + decomposed_unitary = decomposed_circuit.unitary(qubit_order=qubit_order) + phase_matrix = np.eye(2**target_bitsize) + phase_matrix[phase_state, phase_state] = -1 + np.testing.assert_allclose(g.narrow_unitary(), phase_matrix) + np.testing.assert_allclose( + decomposed_unitary, np.kron(phase_matrix, np.eye(2**ancilla_bitsize)), atol=1e-5 + )