Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Base class for quantum states #5065

Merged
merged 50 commits into from Mar 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
00b7d45
Extract BufferedDensityMatrix from ActOnDensityMatrixArgs
daxfohl Feb 10, 2022
69574e7
state vector
daxfohl Feb 11, 2022
efa5e78
clean up code
daxfohl Feb 13, 2022
160cb4d
clean up code
daxfohl Feb 13, 2022
3097da4
clean up code
daxfohl Feb 13, 2022
6a323d6
format
daxfohl Feb 13, 2022
107c8e2
docs
daxfohl Feb 13, 2022
52c7546
test
daxfohl Feb 13, 2022
6eca8d9
coverage
daxfohl Feb 13, 2022
1a5b3f3
Merge branch 'master' into quantumstate
daxfohl Feb 13, 2022
a8b0ac4
Merge branch 'quantumstate' into quantumstate-sv
daxfohl Feb 13, 2022
1cf0323
Merge branch 'master' into quantumstate
daxfohl Feb 23, 2022
97ad2be
improve state vector
daxfohl Feb 24, 2022
9e623cd
improve state vector
daxfohl Feb 24, 2022
eabee77
replace deleted functions
daxfohl Feb 24, 2022
57177f7
replace deleted functions
daxfohl Feb 24, 2022
074ad2b
replace deleted functions
daxfohl Feb 24, 2022
f1fa7df
replace deleted functions
daxfohl Feb 24, 2022
a794315
Merge branch 'quantumstate-sv' into quantumstatem
daxfohl Feb 24, 2022
b9f6679
lint
daxfohl Feb 24, 2022
d7100e2
mps quantum state
daxfohl Feb 24, 2022
9165ab8
mps quantum state
daxfohl Feb 24, 2022
ee30b09
mps quantum state
daxfohl Feb 24, 2022
713cc94
mps quantum state
daxfohl Feb 24, 2022
4ce8a7b
mps quantum state
daxfohl Feb 24, 2022
4803e68
mps quantum state
daxfohl Feb 24, 2022
2ba1b14
mps quantum state
daxfohl Feb 24, 2022
9df8ef3
mps quantum state
daxfohl Feb 24, 2022
95fa8be
quantum base state
daxfohl Feb 25, 2022
c2a5f40
fix up clifford
daxfohl Feb 25, 2022
2dbd53b
mypy
daxfohl Feb 25, 2022
56d0b5e
format
daxfohl Feb 25, 2022
d783087
default sample
daxfohl Feb 25, 2022
bc4e04b
coverage
daxfohl Feb 25, 2022
b43bdb1
Merge branch 'quantumstate' into quantumstate-base
daxfohl Feb 25, 2022
180bda5
remove dupe functions
daxfohl Feb 25, 2022
d8e58ec
lint
daxfohl Feb 25, 2022
115a396
Merge branch 'master' into quantumstate
daxfohl Feb 25, 2022
fe19260
Add ActOnStabilizerArgs to should_not_be_serialized
daxfohl Feb 25, 2022
066ca17
Merge branch 'quantumstate' into quantumstate-base2
daxfohl Feb 25, 2022
7ad298b
fix merge errors
daxfohl Feb 25, 2022
9b2f629
Merge branch 'quantumstate' into quantumstate-base
daxfohl Feb 25, 2022
0815848
coverage
daxfohl Mar 9, 2022
7a7d838
Update tests
daxfohl Mar 13, 2022
d3e33cf
mypy
daxfohl Mar 13, 2022
762dc41
cover
daxfohl Mar 13, 2022
02e6c96
Merge branch 'master' into quantumstatebasem
daxfohl Mar 21, 2022
ac6ff53
Merge branch 'master' into quantumstate-base
CirqBot Mar 25, 2022
9852ff7
Merge branch 'master' into quantumstate-base
daxfohl Mar 26, 2022
0b78979
Merge branch 'master' into quantumstate-base
CirqBot Mar 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Expand Up @@ -433,6 +433,7 @@
operation_to_superoperator,
QUANTUM_STATE_LIKE,
QuantumState,
QuantumStateRepresentation,
quantum_state,
STATE_VECTOR_LIKE,
StabilizerState,
Expand Down
40 changes: 14 additions & 26 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Expand Up @@ -24,7 +24,7 @@
import numpy as np
import quimb.tensor as qtn

from cirq import devices, protocols, value
from cirq import devices, protocols, qis, value
from cirq._compat import deprecated
from cirq.sim import simulator_base
from cirq.sim.act_on_args import ActOnArgs
Expand Down Expand Up @@ -220,7 +220,7 @@ def _simulator_state(self):


@value.value_equality
class _MPSHandler:
class _MPSHandler(qis.QuantumStateRepresentation):
"""Quantum state of the MPS simulation."""

def __init__(
Expand Down Expand Up @@ -604,21 +604,24 @@ def __init__(
Raises:
ValueError: If the grouping does not cover the qubits.
"""
qubit_map = {q: i for i, q in enumerate(qubits)}
final_grouping = qubit_map if grouping is None else grouping
if final_grouping.keys() != qubit_map.keys():
raise ValueError('Grouping must cover exactly the qubits.')
state = _MPSHandler.create(
initial_state=initial_state,
qid_shape=tuple(q.dimension for q in qubits),
simulation_options=simulation_options,
grouping={qubit_map[k]: v for k, v in final_grouping.items()},
)
super().__init__(
state=state,
prng=prng,
qubits=qubits,
log_of_measurement_results=log_of_measurement_results,
classical_data=classical_data,
)
final_grouping = self.qubit_map if grouping is None else grouping
if final_grouping.keys() != self.qubit_map.keys():
raise ValueError('Grouping must cover exactly the qubits.')
self._state = _MPSHandler.create(
initial_state=initial_state,
qid_shape=tuple(q.dimension for q in qubits),
simulation_options=simulation_options,
grouping={self.qubit_map[k]: v for k, v in final_grouping.items()},
)
self._state: _MPSHandler = state

def i_str(self, i: int) -> str:
# Returns the index name for the i'th qid.
Expand All @@ -636,9 +639,6 @@ def __str__(self) -> str:
def _value_equality_values_(self) -> Any:
return self.qubits, self._state

def _on_copy(self, target: 'MPSState', deep_copy_buffers: bool = True):
target._state = self._state.copy(deep_copy_buffers)

def state_vector(self) -> np.ndarray:
"""Returns the full state vector.

Expand Down Expand Up @@ -709,15 +709,3 @@ def perform_measurement(
tolerance specified in simulation options.
"""
return self._state._measure(self.get_axes(qubits), prng, collapse_state_vector)

def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Measures the axes specified by the simulator."""
return self._state.measure(self.get_axes(qubits), self.prng)

def sample(
self,
qubits: Sequence['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
return self._state.sample(self.get_axes(qubits), repetitions, seed)
2 changes: 1 addition & 1 deletion cirq-core/cirq/qis/__init__.py
Expand Up @@ -25,7 +25,7 @@
superoperator_to_kraus,
)

from cirq.qis.clifford_tableau import CliffordTableau, StabilizerState
from cirq.qis.clifford_tableau import CliffordTableau, QuantumStateRepresentation, StabilizerState

from cirq.qis.measures import (
entanglement_fidelity,
Expand Down
93 changes: 89 additions & 4 deletions cirq-core/cirq/qis/clifford_tableau.py
Expand Up @@ -13,17 +13,97 @@
# limitations under the License.

import abc
from typing import Any, Dict, List, TYPE_CHECKING
from typing import Any, Dict, List, Sequence, Tuple, TYPE_CHECKING, TypeVar
import numpy as np

from cirq import protocols
from cirq import protocols, value
from cirq.value import big_endian_int_to_digits, linear_dict

if TYPE_CHECKING:
import cirq

TSelf = TypeVar('TSelf', bound='QuantumStateRepresentation')

class StabilizerState(metaclass=abc.ABCMeta):

class QuantumStateRepresentation(metaclass=abc.ABCMeta):
@abc.abstractmethod
def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
"""Creates a copy of the object.
Args:
deep_copy_buffers: If True, buffers will also be deep-copied.
Otherwise the copy will share a reference to the original object's
buffers.
Returns:
A copied instance.
"""

@abc.abstractmethod
def measure(
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
) -> List[int]:
"""Measures the state.

Args:
axes: The axes to measure.
seed: The random number seed to use.
Returns:
The measurements in order.
"""

def sample(
self,
axes: Sequence[int],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
"""Samples the state. Subclasses can override with more performant method.

Args:
axes: The axes to sample.
repetitions: The number of samples to make.
seed: The random number seed to use.
Returns:
The samples in order.
"""
prng = value.parse_random_state(seed)
measurements = []
for _ in range(repetitions):
state = self.copy()
measurements.append(state.measure(axes, prng))
return np.array(measurements, dtype=bool)

def kron(self: TSelf, other: TSelf) -> TSelf:
"""Joins two state spaces together."""
raise NotImplementedError()

def factor(
self: TSelf, axes: Sequence[int], *, validate=True, atol=1e-07
) -> Tuple[TSelf, TSelf]:
"""Splits two state spaces after a measurement or reset."""
raise NotImplementedError()

def reindex(self: TSelf, axes: Sequence[int]) -> TSelf:
"""Physically reindexes the state by the new basis.
Args:
axes: The desired axis order.
Returns:
The state with qubit order transposed and underlying representation
updated.
"""
raise NotImplementedError()

@property
def supports_factor(self) -> bool:
"""Subclasses that allow factorization should override this."""
return False

@property
def can_represent_mixed_states(self) -> bool:
"""Subclasses that can represent mixed states should override this."""
return False


class StabilizerState(QuantumStateRepresentation, metaclass=abc.ABCMeta):
"""Interface for quantum stabilizer state representations.

This interface is used for CliffordTableau and StabilizerChForm quantum
Expand Down Expand Up @@ -222,7 +302,7 @@ def __eq__(self, other):
def __copy__(self) -> 'CliffordTableau':
return self.copy()

def copy(self) -> 'CliffordTableau':
def copy(self, deep_copy_buffers: bool = True) -> 'CliffordTableau':
state = CliffordTableau(self.n)
state.rs = self.rs.copy()
state.xs = self.xs.copy()
Expand Down Expand Up @@ -578,3 +658,8 @@ def apply_cx(

def apply_global_phase(self, coefficient: linear_dict.Scalar):
pass

def measure(
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
) -> List[int]:
return [self._measure(axis, seed) for axis in axes]
49 changes: 38 additions & 11 deletions cirq-core/cirq/sim/act_on_args.py
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objects and methods for acting efficiently on a state tensor."""
import abc
import copy
import inspect
import warnings
from typing import (
Any,
cast,
Expand All @@ -28,7 +28,6 @@
TYPE_CHECKING,
Tuple,
)
import warnings

import numpy as np

Expand Down Expand Up @@ -59,6 +58,7 @@ def __init__(
log_of_measurement_results: Optional[Dict[str, List[int]]] = None,
ignore_measurement_results: bool = False,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
state: Optional['cirq.QuantumStateRepresentation'] = None,
):
"""Inits ActOnArgs.

Expand All @@ -76,6 +76,7 @@ def __init__(
simulators that can represent mixed states.
classical_data: The shared classical data container for this
simulation.
state: The underlying quantum state of the simulation.
"""
if prng is None:
prng = cast(np.random.RandomState, np.random)
Expand All @@ -90,6 +91,7 @@ def __init__(
}
)
self._ignore_measurement_results = ignore_measurement_results
self._state = state

@property
def prng(self) -> np.random.RandomState:
Expand Down Expand Up @@ -148,10 +150,21 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
return [self.qubit_map[q] for q in qubits]

@abc.abstractmethod
def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Child classes that perform measurements should implement this with
the implementation."""
"""Delegates the call to measure the density matrix."""
if self._state is not None:
return self._state.measure(self.get_axes(qubits), self.prng)
raise NotImplementedError()

def sample(
self,
qubits: Sequence['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
if self._state is not None:
return self._state.sample(self.get_axes(qubits), repetitions, seed)
raise NotImplementedError()

def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
"""Creates a copy of the object.
Expand All @@ -165,6 +178,10 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
A copied instance.
"""
args = copy.copy(self)
args._classical_data = self._classical_data.copy()
if self._state is not None:
args._state = self._state.copy(deep_copy_buffers=deep_copy_buffers)
return args
if 'deep_copy_buffers' in inspect.signature(self._on_copy).parameters:
self._on_copy(args, deep_copy_buffers)
else:
Expand All @@ -176,7 +193,6 @@ def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
DeprecationWarning,
)
self._on_copy(args)
args._classical_data = self._classical_data.copy()
return args

def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True):
Expand All @@ -190,7 +206,10 @@ def create_merged_state(self: TSelf) -> TSelf:
def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf:
"""Joins two state spaces together."""
args = self if inplace else copy.copy(self)
self._on_kronecker_product(other, args)
if self._state is not None and other._state is not None:
args._state = self._state.kron(other._state)
else:
self._on_kronecker_product(other, args)
args._set_qubits(self.qubits + other.qubits)
return args

Expand Down Expand Up @@ -225,15 +244,20 @@ def factor(
"""Splits two state spaces after a measurement or reset."""
extracted = copy.copy(self)
remainder = self if inplace else copy.copy(self)
self._on_factor(qubits, extracted, remainder, validate, atol)
if self._state is not None:
e, r = self._state.factor(self.get_axes(qubits), validate=validate, atol=atol)
extracted._state = e
remainder._state = r
else:
self._on_factor(qubits, extracted, remainder, validate, atol)
extracted._set_qubits(qubits)
remainder._set_qubits([q for q in self.qubits if q not in qubits])
return extracted, remainder

@property
def allows_factoring(self):
"""Subclasses that allow factorization should override this."""
return False
return self._state.supports_factor if self._state is not None else False

def _on_factor(
self: TSelf,
Expand Down Expand Up @@ -265,7 +289,10 @@ def transpose_to_qubit_order(
if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits):
raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}')
args = self if inplace else copy.copy(self)
self._on_transpose_to_qubit_order(qubits, args)
if self._state is not None:
args._state = self._state.reindex(self.get_axes(qubits))
else:
self._on_transpose_to_qubit_order(qubits, args)
args._set_qubits(qubits)
return args

Expand Down Expand Up @@ -356,7 +383,7 @@ def __iter__(self) -> Iterator[Optional['cirq.Qid']]:

@property
def can_represent_mixed_states(self) -> bool:
return False
return self._state.can_represent_mixed_states if self._state is not None else False


def strat_act_on_from_apply_decompose(
Expand Down
15 changes: 0 additions & 15 deletions cirq-core/cirq/sim/act_on_args_container_test.py
Expand Up @@ -41,25 +41,10 @@ def _act_on_fallback_(
) -> bool:
return True

def _on_copy(self, args):
pass

def _on_kronecker_product(self, other, target):
pass

def _on_transpose_to_qubit_order(self, qubits, target):
pass

def _on_factor(self, qubits, extracted, remainder, validate=True, atol=1e-07):
pass

@property
def allows_factoring(self):
return True

def sample(self, qubits, repetitions=1, seed=None):
pass


q0, q1, q2 = qs3 = cirq.LineQubit.range(3)
qs2 = cirq.LineQubit.range(2)
Expand Down