Skip to content

Commit

Permalink
Add support for allocating qubits in decompose to cirq.unitary (#6112)
Browse files Browse the repository at this point in the history
* Add support for allocating qubits in decompose to cirq.unitary

* fixed apply_unitaries

* fix mypy

* refactored tests

* addressing comments

* added sample_gates_test.py

* Improved sample_gates.py implementation and unitary_protocol tests. Also added docstrings

* fixed lint

* retrigger checks

---------

Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
  • Loading branch information
NoureldinYosri and tanujkhattar committed Jun 5, 2023
1 parent 99e8a13 commit 9177708
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 22 deletions.
74 changes: 60 additions & 14 deletions cirq-core/cirq/protocols/apply_unitary_protocol.py
Expand Up @@ -133,6 +133,33 @@ def default(
state = qis.one_hot(index=(0,) * num_qubits, shape=qid_shape, dtype=np.complex128)
return ApplyUnitaryArgs(state, np.empty_like(state), range(num_qubits))

@classmethod
def for_unitary(
cls, num_qubits: Optional[int] = None, *, qid_shape: Optional[Tuple[int, ...]] = None
) -> 'ApplyUnitaryArgs':
"""A default instance corresponding to an identity matrix.
Specify exactly one argument.
Args:
num_qubits: The number of qubits to make space for in the state.
qid_shape: A tuple representing the number of quantum levels of each
qubit the identity matrix applies to. `qid_shape` is (2, 2, 2) for
a three-qubit identity operation tensor.
Raises:
TypeError: If exactly neither `num_qubits` or `qid_shape` is provided or
both are provided.
"""
if (num_qubits is None) == (qid_shape is None):
raise TypeError('Specify exactly one of num_qubits or qid_shape.')
if num_qubits is not None:
qid_shape = (2,) * num_qubits
qid_shape = cast(Tuple[int, ...], qid_shape) # Satisfy mypy
num_qubits = len(qid_shape)
state = qis.eye_tensor(qid_shape, dtype=np.complex128)
return ApplyUnitaryArgs(state, np.empty_like(state), range(num_qubits))

def with_axes_transposed_to_start(self) -> 'ApplyUnitaryArgs':
"""Returns a transposed view of the same arguments.
Expand Down Expand Up @@ -409,19 +436,7 @@ def _strat_apply_unitary_from_apply_unitary(
return _incorporate_result_into_target(args, sub_args, sub_result)


def _strat_apply_unitary_from_unitary(
unitary_value: Any, args: ApplyUnitaryArgs
) -> Optional[np.ndarray]:
# Check for magic method.
method = getattr(unitary_value, '_unitary_', None)
if method is None:
return NotImplemented

# Attempt to get the unitary matrix.
matrix = method()
if matrix is NotImplemented or matrix is None:
return matrix

def _apply_unitary_from_matrix(matrix: np.ndarray, unitary_value: Any, args: ApplyUnitaryArgs):
if args.slices is None:
val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes))
slices = tuple(slice(0, size) for size in val_qid_shape)
Expand Down Expand Up @@ -450,11 +465,42 @@ def _strat_apply_unitary_from_unitary(
return _incorporate_result_into_target(args, sub_args, sub_result)


def _strat_apply_unitary_from_unitary(
unitary_value: Any, args: ApplyUnitaryArgs
) -> Optional[np.ndarray]:
# Check for magic method.
method = getattr(unitary_value, '_unitary_', None)
if method is None:
return NotImplemented

# Attempt to get the unitary matrix.
matrix = method()
if matrix is NotImplemented or matrix is None:
return matrix

return _apply_unitary_from_matrix(matrix, unitary_value, args)


def _strat_apply_unitary_from_decompose(val: Any, args: ApplyUnitaryArgs) -> Optional[np.ndarray]:
operations, qubits, _ = _try_decompose_into_operations_and_qubits(val)
if operations is None:
return NotImplemented
return apply_unitaries(operations, qubits, args, None)
all_qubits = frozenset([q for op in operations for q in op.qubits])
ancilla = tuple(sorted(all_qubits.difference(qubits)))
if not len(ancilla):
return apply_unitaries(operations, qubits, args, None)
ordered_qubits = ancilla + tuple(qubits)
all_qid_shapes = qid_shape_protocol.qid_shape(ordered_qubits)
result = apply_unitaries(
operations, ordered_qubits, ApplyUnitaryArgs.for_unitary(qid_shape=all_qid_shapes), None
)
if result is None or result is NotImplemented:
return result
result = result.reshape((np.prod(all_qid_shapes, dtype=np.int64), -1))
val_qid_shape = qid_shape_protocol.qid_shape(qubits)
state_vec_length = np.prod(val_qid_shape, dtype=np.int64)
result = result[:state_vec_length, :state_vec_length]
return _apply_unitary_from_matrix(result, val, args)


def apply_unitaries(
Expand Down
50 changes: 50 additions & 0 deletions cirq-core/cirq/protocols/apply_unitary_protocol_test.py
Expand Up @@ -717,3 +717,53 @@ def test_cast_to_complex():
np.ComplexWarning, match='Casting complex values to real discards the imaginary part'
):
cirq.apply_unitary(y0, args)


class NotDecomposableGate(cirq.Gate):
def num_qubits(self):
return 1


class DecomposableGate(cirq.Gate):
def __init__(self, sub_gate: cirq.Gate, allocate_ancilla: bool) -> None:
super().__init__()
self._sub_gate = sub_gate
self._allocate_ancilla = allocate_ancilla

def num_qubits(self):
return 1

def _decompose_(self, qubits):
if self._allocate_ancilla:
yield cirq.Z(cirq.NamedQubit('DecomposableGateQubit'))
yield self._sub_gate(qubits[0])


def test_strat_apply_unitary_from_decompose():
state = np.eye(2, dtype=np.complex128)
args = cirq.ApplyUnitaryArgs(
target_tensor=state, available_buffer=np.zeros_like(state), axes=(0,)
)
np.testing.assert_allclose(
cirq.apply_unitaries(
[DecomposableGate(cirq.X, False)(cirq.LineQubit(0))], [cirq.LineQubit(0)], args
),
[[0, 1], [1, 0]],
)

with pytest.raises(TypeError):
_ = cirq.apply_unitaries(
[DecomposableGate(NotDecomposableGate(), True)(cirq.LineQubit(0))],
[cirq.LineQubit(0)],
args,
)


def test_unitary_construction():
with pytest.raises(TypeError):
_ = cirq.ApplyUnitaryArgs.for_unitary()

np.testing.assert_allclose(
cirq.ApplyUnitaryArgs.for_unitary(num_qubits=3).target_tensor,
cirq.eye_tensor((2,) * 3, dtype=np.complex128),
)
24 changes: 16 additions & 8 deletions cirq-core/cirq/protocols/unitary_protocol.py
Expand Up @@ -17,7 +17,6 @@
import numpy as np
from typing_extensions import Protocol

from cirq import qis
from cirq._doc import doc_private
from cirq.protocols import qid_shape_protocol
from cirq.protocols.apply_unitary_protocol import ApplyUnitaryArgs, apply_unitaries
Expand Down Expand Up @@ -162,9 +161,7 @@ def _strat_unitary_from_apply_unitary(val: Any) -> Optional[np.ndarray]:
return NotImplemented

# Apply unitary effect to an identity matrix.
state = qis.eye_tensor(val_qid_shape, dtype=np.complex128)
buffer = np.empty_like(state)
result = method(ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))))
result = method(ApplyUnitaryArgs.for_unitary(qid_shape=val_qid_shape))

if result is NotImplemented or result is None:
return result
Expand All @@ -179,15 +176,26 @@ def _strat_unitary_from_decompose(val: Any) -> Optional[np.ndarray]:
if operations is None:
return NotImplemented

all_qubits = frozenset(q for op in operations for q in op.qubits)
work_qubits = frozenset(qubits)
ancillas = tuple(sorted(all_qubits.difference(work_qubits)))

ordered_qubits = ancillas + tuple(qubits)
val_qid_shape = qid_shape_protocol.qid_shape(ancillas) + val_qid_shape

# Apply sub-operations' unitary effects to an identity matrix.
state = qis.eye_tensor(val_qid_shape, dtype=np.complex128)
buffer = np.empty_like(state)
result = apply_unitaries(
operations, qubits, ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))), None
operations, ordered_qubits, ApplyUnitaryArgs.for_unitary(qid_shape=val_qid_shape), None
)

# Package result.
if result is None:
return None

state_len = np.prod(val_qid_shape, dtype=np.int64)
return result.reshape((state_len, state_len))
result = result.reshape((state_len, state_len))
# Assuming borrowable qubits are restored to their original state and
# clean qubits restord to the zero state then the desired unitary is
# the upper left square.
work_state_len = np.prod(val_qid_shape[len(ancillas) :], dtype=np.int64)
return result[:work_state_len, :work_state_len]
37 changes: 37 additions & 0 deletions cirq-core/cirq/protocols/unitary_protocol_test.py
Expand Up @@ -17,6 +17,7 @@
import pytest

import cirq
from cirq import testing

m0: np.ndarray = np.array([])
# yapf: disable
Expand Down Expand Up @@ -188,6 +189,42 @@ def test_has_unitary():
assert not cirq.has_unitary(FullyImplemented(False))


def _test_gate_that_allocates_qubits(gate):
from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose

op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate)))
moment = cirq.Moment(op)
circuit = cirq.FrozenCircuit(op)
circuit_op = cirq.CircuitOperation(circuit)
for val in [gate, op, moment, circuit, circuit_op]:
unitary_from_strat = _strat_unitary_from_decompose(val)
assert unitary_from_strat is not None
np.testing.assert_allclose(unitary_from_strat, gate.narrow_unitary())


@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10))
@pytest.mark.parametrize('phase_state', [0, 1])
@pytest.mark.parametrize('target_bitsize', [1, 2, 3])
@pytest.mark.parametrize('ancilla_bitsize', [1, 4])
def test_decompose_gate_that_allocates_clean_qubits(
theta: float, phase_state: int, target_bitsize: int, ancilla_bitsize: int
):

gate = testing.PhaseUsingCleanAncilla(theta, phase_state, target_bitsize, ancilla_bitsize)
_test_gate_that_allocates_qubits(gate)


@pytest.mark.parametrize('phase_state', [0, 1])
@pytest.mark.parametrize('target_bitsize', [1, 2, 3])
@pytest.mark.parametrize('ancilla_bitsize', [1, 4])
def test_decompose_gate_that_allocates_dirty_qubits(
phase_state: int, target_bitsize: int, ancilla_bitsize: int
):

gate = testing.PhaseUsingDirtyAncilla(phase_state, target_bitsize, ancilla_bitsize)
_test_gate_that_allocates_qubits(gate)


def test_decompose_and_get_unitary():
from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose

Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/testing/__init__.py
Expand Up @@ -107,3 +107,5 @@
)

from cirq.testing.sample_circuits import nonoptimal_toffoli_circuit

from cirq.testing.sample_gates import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla
79 changes: 79 additions & 0 deletions cirq-core/cirq/testing/sample_gates.py
@@ -0,0 +1,79 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses

import cirq
import numpy as np
from cirq import ops, qis


def _matrix_for_phasing_state(num_qubits, phase_state, phase):
matrix = qis.eye_tensor((2,) * num_qubits, dtype=np.complex128)
matrix = matrix.reshape((2**num_qubits, 2**num_qubits))
matrix[phase_state, phase_state] = phase
print(num_qubits, phase_state, phase)
print(matrix)
return matrix


@dataclasses.dataclass(frozen=True)
class PhaseUsingCleanAncilla(ops.Gate):
r"""Phases the state $|phase_state>$ by $\exp(1j * \pi * \theta)$ using one clean ancilla."""

theta: float
phase_state: int = 1
target_bitsize: int = 1
ancilla_bitsize: int = 1

def _num_qubits_(self):
return self.target_bitsize

def _decompose_(self, qubits):
anc = ops.NamedQubit.range(self.ancilla_bitsize, prefix="anc")
cv = [int(x) for x in f'{self.phase_state:0{self.target_bitsize}b}']
cnot_ladder = [cirq.CNOT(anc[i - 1], anc[i]) for i in range(1, self.ancilla_bitsize)]

yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv)
yield [cnot_ladder, ops.Z(anc[-1]) ** self.theta, reversed(cnot_ladder)]
yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv)

def narrow_unitary(self) -> np.ndarray:
"""Narrowed unitary corresponding to the unitary effect applied on target qubits."""
phase = np.exp(1j * np.pi * self.theta)
return _matrix_for_phasing_state(self.target_bitsize, self.phase_state, phase)


@dataclasses.dataclass(frozen=True)
class PhaseUsingDirtyAncilla(ops.Gate):
r"""Phases the state $|phase_state>$ by -1 using one dirty ancilla."""

phase_state: int = 1
target_bitsize: int = 1
ancilla_bitsize: int = 1

def _num_qubits_(self):
return self.target_bitsize

def _decompose_(self, qubits):
anc = ops.NamedQubit.range(self.ancilla_bitsize, prefix="anc")
cv = [int(x) for x in f'{self.phase_state:0{self.target_bitsize}b}']
cnot_ladder = [cirq.CNOT(anc[i - 1], anc[i]) for i in range(1, self.ancilla_bitsize)]
yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv)
yield [cnot_ladder, ops.Z(anc[-1]), reversed(cnot_ladder)]
yield ops.X(anc[0]).controlled_by(*qubits, control_values=cv)
yield [cnot_ladder, ops.Z(anc[-1]), reversed(cnot_ladder)]

def narrow_unitary(self) -> np.ndarray:
"""Narrowed unitary corresponding to the unitary effect applied on target qubits."""
return _matrix_for_phasing_state(self.target_bitsize, self.phase_state, -1)
59 changes: 59 additions & 0 deletions cirq-core/cirq/testing/sample_gates_test.py
@@ -0,0 +1,59 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import numpy as np
from cirq.testing import sample_gates
import cirq


@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 20))
def test_phase_using_clean_ancilla(theta: float):
g = sample_gates.PhaseUsingCleanAncilla(theta)
q = cirq.LineQubit(0)
qubit_order = cirq.QubitOrder.explicit([q], fallback=cirq.QubitOrder.DEFAULT)
decomposed_unitary = cirq.Circuit(cirq.decompose_once(g.on(q))).unitary(qubit_order=qubit_order)
phase = np.exp(1j * np.pi * theta)
np.testing.assert_allclose(g.narrow_unitary(), np.array([[1, 0], [0, phase]]))
np.testing.assert_allclose(
decomposed_unitary,
# fmt: off
np.array(
[
[1 , 0 , 0 , 0],
[0 , phase, 0 , 0],
[0 , 0 , phase, 0],
[0 , 0 , 0 , 1],
]
),
# fmt: on
)


@pytest.mark.parametrize(
'target_bitsize, phase_state', [(1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (2, 3)]
)
@pytest.mark.parametrize('ancilla_bitsize', [1, 4])
def test_phase_using_dirty_ancilla(target_bitsize, phase_state, ancilla_bitsize):
g = sample_gates.PhaseUsingDirtyAncilla(phase_state, target_bitsize, ancilla_bitsize)
q = cirq.LineQubit.range(target_bitsize)
qubit_order = cirq.QubitOrder.explicit(q, fallback=cirq.QubitOrder.DEFAULT)
decomposed_circuit = cirq.Circuit(cirq.decompose_once(g.on(*q)))
decomposed_unitary = decomposed_circuit.unitary(qubit_order=qubit_order)
phase_matrix = np.eye(2**target_bitsize)
phase_matrix[phase_state, phase_state] = -1
np.testing.assert_allclose(g.narrow_unitary(), phase_matrix)
np.testing.assert_allclose(
decomposed_unitary, np.kron(phase_matrix, np.eye(2**ancilla_bitsize)), atol=1e-5
)

0 comments on commit 9177708

Please sign in to comment.