Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented State Preparation Gate #4482

Merged
merged 20 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@
PhasedXPowGate,
PhasedXZGate,
PhaseFlipChannel,
StatePreparationGate,
ProjectorString,
ProjectorSum,
RandomGateChannel,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def two_qubit_matrix_gate(matrix):
'PhasedISwapPowGate': cirq.PhasedISwapPowGate,
'PhasedXPowGate': cirq.PhasedXPowGate,
'PhasedXZGate': cirq.PhasedXZGate,
'StatePreparationGate': cirq.StatePreparationGate,
'ProjectorString': cirq.ProjectorString,
'ProjectorSum': cirq.ProjectorSum,
'RandomGateChannel': cirq.RandomGateChannel,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,5 @@
wait,
WaitGate,
)

from cirq.ops.state_preparation_gate import StatePreparationGate
125 changes: 125 additions & 0 deletions cirq-core/cirq/ops/state_preparation_gate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Quantum gates to prepare a given target state."""

from typing import Any, Sequence, Dict, List, Tuple, TYPE_CHECKING

import numpy as np

from cirq import protocols
from cirq.ops import raw_types
from cirq.ops.common_channels import ResetChannel
from cirq.ops.matrix_gates import MatrixGate
from cirq._compat import proper_repr

if TYPE_CHECKING:
import cirq


class StatePreparationGate(raw_types.Gate):
"""A unitary qubit gate which resets all qubits to the |0> state
and then prepares the target state."""
AnimeshSinha1309 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, target_state: np.ndarray, name: str = "StatePreparation") -> None:
"""Initializes a State Preparation gate.

Args:
target_state: The state vector that this gate should prepare.
name: the name of the gate

Raises:
ValueError: if the array is not 1D, or does not have 2**n elements for some integer n.
"""
if len(target_state.shape) != 1:
raise ValueError('`target_state` must be a 1d numpy array.')

n = int(np.round(np.log2(target_state.shape[0] or 1)))
if 2 ** n != target_state.shape[0]:
raise ValueError(f'Matrix width ({target_state.shape[0]}) is not a power of 2')

self._state = target_state.astype(np.complex128) / np.linalg.norm(target_state)
self._num_qubits = n
self._name = name
self._qid_shape = (2,) * n

@staticmethod
def _has_unitary_() -> bool:
"""Checks and returns if the gate has a unitary representation.
It doesn't, since the resetting of the channels is a non-unitary operations,
it involves measurement."""
return False

def _json_dict_(self) -> Dict[str, Any]:
"""Converts the gate object into a serializable dictionary"""
return {
'cirq_type': self.__class__.__name__,
'target_state': self._state.tolist(),
}

@classmethod
def _from_json_dict_(cls, target_state, **kwargs):
"""Recreates the gate object from it's serialized form

Args:
target_state: the state to prepare using this gate
kwargs: other keyword arguments, ignored
"""
return cls(target_state=np.array(target_state))

def _num_qubits_(self):
return self._num_qubits

def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def _circuit_diagram_info_(
self, _args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
"""Returns the information required to draw out the circuit diagram for this gate."""
symbols = (
[self._name]
95-martin-orion marked this conversation as resolved.
Show resolved Hide resolved
if self._num_qubits == 1
else [f'{self._name}[{i+1}]' for i in range(0, self._num_qubits)]
)
return protocols.CircuitDiagramInfo(wire_symbols=symbols)

def _get_unitary_transform(self):
AnimeshSinha1309 marked this conversation as resolved.
Show resolved Hide resolved
initial_basis = np.eye(2 ** self._num_qubits, dtype=np.complex128)
final_basis = [self._state]
for vector in initial_basis:
for new_basis_vector in final_basis:
vector -= np.conj(np.dot(new_basis_vector, vector)) * new_basis_vector
if not np.allclose(vector, 0):
vector /= np.linalg.norm(vector)
final_basis.append(vector)
final_basis = np.stack(final_basis[: initial_basis.shape[0]], axis=1)
return final_basis

def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only covered as part of the simulation test. Could you add a separate test to verify that cirq.decompose(PrepareState(...)) behaves as expected?

"""Decompose the n-qubit diagonal gates into a Reset channel and a Matrix Gate."""
decomposed_circ: List[Any] = [ResetChannel(qubit.dimension).on(qubit) for qubit in qubits]
final_basis = self._get_unitary_transform()
decomposed_circ.append(MatrixGate(final_basis).on(*qubits))
return decomposed_circ

def __repr__(self) -> str:
return f'cirq.StatePreparationGate({proper_repr(self._state)})'

@property
def state(self):
return self._state

def __eq__(self, other):
return np.allclose(self._state, other.state)
102 changes: 102 additions & 0 deletions cirq-core/cirq/ops/state_preparation_gate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import cirq
import pytest


@pytest.mark.parametrize(
'target_state',
np.array(
[
[1, 0, 0, 0],
[1, 0, 0, 1],
[3, 5, 2, 7],
[0.7823, 0.12323, 0.4312, 0.12321],
[23, 43, 12, 19],
[1j, 0, 0, 0],
[1j, 0, 0, 1j],
[1j, -1j, -1j, 1j],
[1 + 1j, 0, 0, 0],
[1 + 1j, 0, 1 + 1j, 0],
]
),
)
def test_state_prep_gate(target_state):
gate = cirq.StatePreparationGate(target_state)
qubits = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
[
cirq.H(qubits[0]),
cirq.CNOT(qubits[0], qubits[1]),
gate(qubits[0], qubits[1]),
]
)
simulator = cirq.Simulator()
result = simulator.simulate(circuit, qubit_order=qubits).final_state_vector
assert np.allclose(result, target_state / np.linalg.norm(target_state))


def test_state_prep_gate_printing():
circuit = cirq.Circuit()
qubits = cirq.LineQubit.range(2)
gate = cirq.StatePreparationGate(np.array([1, 0, 0, 1]) / np.sqrt(2))
circuit.append(cirq.H(qubits[0]))
circuit.append(cirq.CNOT(qubits[0], qubits[1]))
circuit.append(gate(qubits[0], qubits[1]))
cirq.testing.assert_has_diagram(
circuit,
"""
0: ───H───@───StatePreparation[1]───
AnimeshSinha1309 marked this conversation as resolved.
Show resolved Hide resolved
│ │
1: ───────X───StatePreparation[2]───
""",
)


@pytest.mark.parametrize('name', ['Prep', 'S'])
def test_state_prep_gate_printing_with_name(name):
circuit = cirq.Circuit()
qubits = cirq.LineQubit.range(2)
gate = cirq.StatePreparationGate(np.array([1, 0, 0, 1]) / np.sqrt(2), name=name)
circuit.append(cirq.H(qubits[0]))
circuit.append(cirq.CNOT(qubits[0], qubits[1]))
circuit.append(gate(qubits[0], qubits[1]))
cirq.testing.assert_has_diagram(
circuit,
f"""
0: ───H───@───{name}[1]───
│ │
1: ───────X───{name}[2]───
""",
)


def test_gate_params():
state = np.array([1, 0, 0, 0], dtype=np.complex64)
gate = cirq.StatePreparationGate(state)
assert gate.num_qubits() == 2
assert not gate._has_unitary_()
assert (
repr(gate)
== 'cirq.StatePreparationGate(np.array([(1+0j), 0j, 0j, 0j], dtype=np.complex128))'
)


def test_gate_error_handling():
with pytest.raises(ValueError):
95-martin-orion marked this conversation as resolved.
Show resolved Hide resolved
cirq.StatePreparationGate(np.eye(2))
with pytest.raises(ValueError):
cirq.StatePreparationGate(np.ones(shape=5))
25 changes: 25 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/StatePreparationGate.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"cirq_type": "StatePreparationGate",
"target_state": [
{
"cirq_type": "complex",
"real": 1.0,
"imag": 0.0
},
{
"cirq_type": "complex",
"real": 0.0,
"imag": 0.0
},
{
"cirq_type": "complex",
"real": 0.0,
"imag": 0.0
},
{
"cirq_type": "complex",
"real": 0.0,
"imag": 0.0
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.StatePreparationGate(np.array([(1+0j), 0j, 0j, 0j], dtype=np.complex128))