Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for allocating qubits in decompose to cirq.unitary #6112

Merged
merged 13 commits into from
Jun 5, 2023
53 changes: 41 additions & 12 deletions cirq-core/cirq/protocols/apply_unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,12 @@ def subspace_index(
qid_shape=self.target_tensor.shape,
)

@classmethod
def for_unitary(cls, qid_shapes: Tuple[int, ...]) -> 'ApplyUnitaryArgs':
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
state = qis.eye_tensor(qid_shapes, dtype=np.complex128)
buffer = np.empty_like(state)
return ApplyUnitaryArgs(state, buffer, range(len(qid_shapes)))


class SupportsConsistentApplyUnitary(Protocol):
"""An object that can be efficiently left-multiplied into tensors."""
Expand Down Expand Up @@ -274,6 +280,10 @@ def _apply_unitary_(
"""


def _strat_apply_unitary_from_unitary_(val: Any, args: ApplyUnitaryArgs) -> Optional[np.ndarray]:
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
return _strat_apply_unitary_from_unitary(val, args, matrix=None)


def apply_unitary(
unitary_value: Any,
args: ApplyUnitaryArgs,
Expand Down Expand Up @@ -346,14 +356,14 @@ def apply_unitary(
if len(args.axes) <= 4:
strats = [
_strat_apply_unitary_from_apply_unitary,
_strat_apply_unitary_from_unitary,
_strat_apply_unitary_from_unitary_,
_strat_apply_unitary_from_decompose,
]
else:
strats = [
_strat_apply_unitary_from_apply_unitary,
_strat_apply_unitary_from_decompose,
_strat_apply_unitary_from_unitary,
_strat_apply_unitary_from_unitary_,
]
if not allow_decompose:
strats.remove(_strat_apply_unitary_from_decompose)
Expand Down Expand Up @@ -410,17 +420,18 @@ def _strat_apply_unitary_from_apply_unitary(


def _strat_apply_unitary_from_unitary(
unitary_value: Any, args: ApplyUnitaryArgs
unitary_value: Any, args: ApplyUnitaryArgs, matrix: Optional[np.ndarray] = None
) -> Optional[np.ndarray]:
# Check for magic method.
method = getattr(unitary_value, '_unitary_', None)
if method is None:
return NotImplemented
if matrix is None:
# Check for magic method.
method = getattr(unitary_value, '_unitary_', None)
if method is None:
return NotImplemented
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved

# Attempt to get the unitary matrix.
matrix = method()
if matrix is NotImplemented or matrix is None:
return matrix
# Attempt to get the unitary matrix.
matrix = method()
if matrix is NotImplemented or matrix is None:
return matrix
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved

if args.slices is None:
val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes))
Expand Down Expand Up @@ -454,7 +465,25 @@ def _strat_apply_unitary_from_decompose(val: Any, args: ApplyUnitaryArgs) -> Opt
operations, qubits, _ = _try_decompose_into_operations_and_qubits(val)
if operations is None:
return NotImplemented
return apply_unitaries(operations, qubits, args, None)
all_qubits = frozenset([q for op in operations for q in op.qubits])
ancilla = tuple(sorted(all_qubits.difference(qubits)))
if not len(ancilla):
return apply_unitaries(operations, qubits, args, None)
ordered_qubits = ancilla + tuple(qubits)
all_qid_shapes = qid_shape_protocol.qid_shape(ordered_qubits)
result = apply_unitaries(
operations,
ordered_qubits,
ApplyUnitaryArgs.for_unitary(qid_shape_protocol.qid_shape(ordered_qubits)),
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
None,
)
if result is None or result is NotImplemented:
return result
result = result.reshape((np.prod(all_qid_shapes, dtype=np.int64), -1))
val_qid_shape = qid_shape_protocol.qid_shape(qubits)
state_vec_length = np.prod(val_qid_shape, dtype=np.int64)
result = result[:state_vec_length, :state_vec_length]
return _strat_apply_unitary_from_unitary(val, args, matrix=result)
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved


def apply_unitaries(
Expand Down
21 changes: 21 additions & 0 deletions cirq-core/cirq/protocols/apply_unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import cirq
from cirq.protocols.apply_unitary_protocol import _incorporate_result_into_target
from cirq import testing


def test_apply_unitary_presence_absence():
Expand Down Expand Up @@ -717,3 +718,23 @@ def test_cast_to_complex():
np.ComplexWarning, match='Casting complex values to real discards the imaginary part'
):
cirq.apply_unitary(y0, args)


def test_strat_apply_unitary_from_decompose():
state = np.eye(2, dtype=np.complex128)
args = cirq.ApplyUnitaryArgs(
target_tensor=state, available_buffer=np.zeros_like(state), axes=(0,)
)
np.testing.assert_allclose(
cirq.apply_unitaries(
[testing.DecomposableGate(cirq.X, False)(cirq.LineQubit(0))], [cirq.LineQubit(0)], args
),
[[0, 1], [1, 0]],
)

with pytest.raises(TypeError):
_ = cirq.apply_unitaries(
[testing.DecomposableGate(testing.NotDecomposableGate(), True)(cirq.LineQubit(0))],
[cirq.LineQubit(0)],
args,
)
24 changes: 16 additions & 8 deletions cirq-core/cirq/protocols/unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
from typing_extensions import Protocol

from cirq import qis
from cirq._doc import doc_private
from cirq.protocols import qid_shape_protocol
from cirq.protocols.apply_unitary_protocol import ApplyUnitaryArgs, apply_unitaries
Expand Down Expand Up @@ -162,9 +161,7 @@ def _strat_unitary_from_apply_unitary(val: Any) -> Optional[np.ndarray]:
return NotImplemented

# Apply unitary effect to an identity matrix.
state = qis.eye_tensor(val_qid_shape, dtype=np.complex128)
buffer = np.empty_like(state)
result = method(ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))))
result = method(ApplyUnitaryArgs.for_unitary(val_qid_shape))

if result is NotImplemented or result is None:
return result
Expand All @@ -179,15 +176,26 @@ def _strat_unitary_from_decompose(val: Any) -> Optional[np.ndarray]:
if operations is None:
return NotImplemented

all_qubits = frozenset(q for op in operations for q in op.qubits)
work_qubits = frozenset(qubits)
ancillas = tuple(sorted(all_qubits.difference(work_qubits)))

ordered_qubits = ancillas + tuple(qubits)
val_qid_shape = qid_shape_protocol.qid_shape(ancillas) + val_qid_shape

# Apply sub-operations' unitary effects to an identity matrix.
state = qis.eye_tensor(val_qid_shape, dtype=np.complex128)
buffer = np.empty_like(state)
result = apply_unitaries(
operations, qubits, ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))), None
operations, ordered_qubits, ApplyUnitaryArgs.for_unitary(val_qid_shape), None
)

# Package result.
if result is None:
return None

state_len = np.prod(val_qid_shape, dtype=np.int64)
return result.reshape((state_len, state_len))
result = result.reshape((state_len, state_len))
# Assuming borrowable qubits are restored to their original state and
# clean qubits restord to the zero state then the desired unitary is
# the upper left square.
work_state_len = np.prod(val_qid_shape[len(ancillas) :], dtype=np.int64)
return result[:work_state_len, :work_state_len]
37 changes: 36 additions & 1 deletion cirq-core/cirq/protocols/unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import cast, Optional

import numpy as np
import pytest

import cirq
from cirq import testing

m0: np.ndarray = np.array([])
# yapf: disable
Expand Down Expand Up @@ -188,6 +189,31 @@ def test_has_unitary():
assert not cirq.has_unitary(FullyImplemented(False))


@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10))
def test_decompose_gate_that_allocates_qubits(theta: float):
from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose

gate = testing.GateThatAllocatesAQubit(theta)
np.testing.assert_allclose(
cast(np.ndarray, _strat_unitary_from_decompose(gate)), gate.target_unitary()
)
np.testing.assert_allclose(
cast(np.ndarray, _strat_unitary_from_decompose(gate(a))), gate.target_unitary()
)


@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10))
@pytest.mark.parametrize('n', [*range(1, 6)])
def test_recusive_decomposition(n: int, theta: float):
from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose

g1 = testing.GateThatDecomposesIntoNGates(n, cirq.H, theta)
g2 = testing.GateThatDecomposesIntoNGates(n, g1, theta)
np.testing.assert_allclose(
cast(np.ndarray, _strat_unitary_from_decompose(g2)), g2.target_unitary()
)


def test_decompose_and_get_unitary():
from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose

Expand All @@ -201,6 +227,15 @@ def test_decompose_and_get_unitary():
np.testing.assert_allclose(_strat_unitary_from_decompose(DummyComposite()), np.eye(1))
np.testing.assert_allclose(_strat_unitary_from_decompose(OtherComposite()), m2)

np.testing.assert_allclose(
_strat_unitary_from_decompose(testing.GateThatAllocatesTwoQubits()),
testing.GateThatAllocatesTwoQubits.target_unitary(),
)
np.testing.assert_allclose(
_strat_unitary_from_decompose(testing.GateThatAllocatesTwoQubits().on(a, b)),
testing.GateThatAllocatesTwoQubits.target_unitary(),
)


def test_decomposed_has_unitary():
# Gates
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,11 @@
)

from cirq.testing.sample_circuits import nonoptimal_toffoli_circuit

from cirq.testing.sample_gates import (
DecomposableGate,
NotDecomposableGate,
GateThatAllocatesAQubit,
GateThatAllocatesTwoQubits,
GateThatDecomposesIntoNGates,
)
99 changes: 99 additions & 0 deletions cirq-core/cirq/testing/sample_gates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023 The Cirq Developers
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import numpy as np
from cirq import ops


class NotDecomposableGate(ops.Gate):
def num_qubits(self):
return 1


class DecomposableGate(ops.Gate):
def __init__(self, sub_gate: ops.Gate, allocate_ancilla: bool) -> None:
super().__init__()
self._sub_gate = sub_gate
self._allocate_ancilla = allocate_ancilla

def num_qubits(self):
return 1

def _decompose_(self, qubits):
if self._allocate_ancilla:
yield ops.Z(ops.NamedQubit('DecomposableGateQubit'))
yield self._sub_gate(qubits[0])


class GateThatAllocatesAQubit(ops.Gate):
def __init__(self, theta: float) -> None:
super().__init__()
self._theta = theta

def _num_qubits_(self):
return 1

def _decompose_(self, q):
anc = ops.NamedQubit("anc")
yield ops.CX(*q, anc)
yield (ops.Z**self._theta)(anc)
yield ops.CX(*q, anc)

def target_unitary(self) -> np.ndarray:
return np.array([[1, 0], [0, (-1 + 0j) ** self._theta]])


class GateThatAllocatesTwoQubits(ops.Gate):
def _num_qubits_(self):
return 2

def _decompose_(self, qs):
q0, q1 = qs
anc = ops.NamedQubit.range(2, prefix='two_ancillas_')

yield ops.X(anc[0])
yield ops.CX(q0, anc[0])
yield (ops.Y)(anc[0])
yield ops.CX(q0, anc[0])

yield ops.CX(q1, anc[1])
yield (ops.Z)(anc[1])
yield ops.CX(q1, anc[1])

@classmethod
def target_unitary(cls) -> np.ndarray:
# Unitary = (-j I_2) \otimes Z
return np.array([[-1j, 0, 0, 0], [0, 1j, 0, 0], [0, 0, 1j, 0], [0, 0, 0, -1j]])


class GateThatDecomposesIntoNGates(ops.Gate):
def __init__(self, n: int, sub_gate: ops.Gate, theta: float) -> None:
super().__init__()
self._n = n
self._subgate = sub_gate
self._name = str(sub_gate)
self._theta = theta

def _num_qubits_(self) -> int:
return self._n

def _decompose_(self, qs):
ancilla = ops.NamedQubit.range(self._n, prefix=self._name)
yield self._subgate.on_each(ancilla)
yield (ops.Z**self._theta).on_each(qs)
yield self._subgate.on_each(ancilla)

def target_unitary(self) -> np.ndarray:
U = np.array([[1, 0], [0, (-1 + 0j) ** self._theta]])
return functools.reduce(np.kron, [U] * self._n)