Skip to content

Commit

Permalink
Create qis subpackage and move some existing code to qis.states and q…
Browse files Browse the repository at this point in the history
…is.measures (#2808)

Moves the following things to `linalg/states.py` from `sim/wave_function.py` and `sim/density_matrix_utils.py`:
```
STATE_VECTOR_LIKE
bloch_vector_from_state_vector
density_matrix_from_state_vector
dirac_notation
to_valid_state_vector
validate_normalized_state
to_valid_density_matrix
von_neumann_entropy
```
This may not be their final resting place (see #2797 (comment)) but it certainly makes more sense than before, given that we have a file called `states.py`. Also, it has the benefit of removing some circular dependency import hacks.
  • Loading branch information
kevinsung committed Apr 29, 2020
1 parent e9b881a commit 33194ad
Show file tree
Hide file tree
Showing 27 changed files with 1,358 additions and 1,130 deletions.
26 changes: 15 additions & 11 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
protocols,
value,
linalg,
qis,
ops,
devices,
study,
Expand Down Expand Up @@ -119,9 +120,7 @@
diagonalize_real_symmetric_matrix,
dot,
expand_matrix_in_orthogonal_basis,
fidelity,
hilbert_schmidt_inner_product,
eye_tensor,
is_diagonal,
is_hermitian,
is_normal,
Expand All @@ -141,7 +140,6 @@
map_eigenvalues,
match_global_phase,
matrix_from_basis_coefficients,
one_hot,
partial_trace,
PAULI_BASIS,
scatter_plot_normalized_kak_interaction_coefficients,
Expand Down Expand Up @@ -298,21 +296,32 @@
two_qubit_matrix_to_operations,
)

from cirq.sim import (
from cirq.qis import (
bloch_vector_from_state_vector,
density_matrix_from_state_vector,
dirac_notation,
eye_tensor,
fidelity,
one_hot,
STATE_VECTOR_LIKE,
to_valid_density_matrix,
to_valid_state_vector,
validate_normalized_state,
von_neumann_entropy,
)

from cirq.sim import (
StabilizerStateChForm,
CIRCUIT_LIKE,
CliffordSimulator,
CliffordState,
CliffordSimulatorStepResult,
CliffordTableau,
CliffordTrialResult,
density_matrix_from_state_vector,
DensityMatrixSimulator,
DensityMatrixSimulatorState,
DensityMatrixStepResult,
DensityMatrixTrialResult,
dirac_notation,
measure_density_matrix,
measure_state_vector,
final_density_matrix,
Expand All @@ -329,13 +338,8 @@
SimulationTrialResult,
Simulator,
SparseSimulatorStep,
STATE_VECTOR_LIKE,
StateVectorMixin,
StepResult,
to_valid_density_matrix,
to_valid_state_vector,
validate_normalized_state,
von_neumann_entropy,
WaveFunctionSimulatorState,
WaveFunctionStepResult,
WaveFunctionTrialResult,
Expand Down
32 changes: 32 additions & 0 deletions cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import functools
import warnings
from typing import Any, Callable, Optional, Dict, Tuple
from types import ModuleType

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -180,3 +181,34 @@ def decorated_func(*args, **kwargs) -> Any:
return decorated_func

return decorator


def wrap_module(module: ModuleType,
deprecated_attributes: Dict[str, Tuple[str, str]]):
"""Wrap a module with deprecated attributes.
Args:
module: The module to wrap.
deprecated_attributes: A dictionary from attribute name to pair of
strings, where the first string gives the version that the attribute
will be removed in, and the second string describes what the user
should do instead of accessing this deprecated attribute.
Returns:
Wrapped module with deprecated attributes.
"""

class Wrapped(ModuleType):

def __getattr__(self, name):
if name in deprecated_attributes:
deadline, fix = deprecated_attributes[name]
warnings.warn(
f'{name} was used but is deprecated.\n'
f'It will be removed in cirq {deadline}.\n'
f'{fix}\n',
DeprecationWarning,
stacklevel=2)
return getattr(module, name)

return Wrapped(module.__name__)
7 changes: 3 additions & 4 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import re
import numpy as np

from cirq import devices, linalg, ops, protocols
from cirq import devices, ops, protocols, qis
from cirq.circuits._bucket_priority_queue import BucketPriorityQueue
from cirq.circuits.insert_strategy import InsertStrategy
from cirq.circuits.text_diagram_drawer import TextDiagramDrawer
Expand Down Expand Up @@ -1493,7 +1493,7 @@ def unitary(self,
qid_shape = self.qid_shape(qubit_order=qs)
side_len = np.product(qid_shape, dtype=int)

state = linalg.eye_tensor(qid_shape, dtype=dtype)
state = qis.eye_tensor(qid_shape, dtype=dtype)

result = _apply_unitary_circuit(self, state, qs, dtype)
return result.reshape((side_len, side_len))
Expand Down Expand Up @@ -1573,8 +1573,7 @@ def final_wavefunction(
qid_shape = self.qid_shape(qubit_order=qs)
state_len = np.product(qid_shape, dtype=int)

from cirq import sim
state = sim.to_valid_state_vector(initial_state,
state = qis.to_valid_state_vector(initial_state,
qid_shape=qid_shape,
dtype=dtype).reshape(qid_shape)
result = _apply_unitary_circuit(self, state, qs, dtype)
Expand Down
3 changes: 0 additions & 3 deletions cirq/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@
diagonalize_real_symmetric_matrix,
)

from cirq.linalg.distance_measures import (
fidelity,)

from cirq.linalg.operator_spaces import (
expand_matrix_in_orthogonal_basis,
hilbert_schmidt_inner_product,
Expand Down
4 changes: 4 additions & 0 deletions cirq/linalg/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

import numpy as np

from cirq._compat import deprecated


@deprecated(deadline='v0.9', fix='Use cirq.one_hot instead.')
def one_hot(*,
index: Union[None, int, Sequence[int]] = None,
shape: Union[int, Sequence[int]],
Expand All @@ -42,6 +45,7 @@ def one_hot(*,
return result


@deprecated(deadline='v0.9', fix='Use cirq.eye_tensor instead.')
def eye_tensor(
half_shape: Tuple[int, ...],
*, # Force keyword args
Expand Down
54 changes: 12 additions & 42 deletions cirq/linalg/states_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import cirq
from cirq._compat_test import capture_logging


def test_one_hot():
result = cirq.one_hot(shape=4, dtype=np.int32)
assert result.dtype == np.int32
np.testing.assert_array_equal(result, [1, 0, 0, 0])

np.testing.assert_array_equal(
cirq.one_hot(shape=[2, 3], dtype=np.complex64), [[1, 0, 0], [0, 0, 0]])

np.testing.assert_array_equal(
cirq.one_hot(shape=[2, 3], dtype=np.complex64, index=(0, 2)),
[[0, 0, 1], [0, 0, 0]])

np.testing.assert_array_equal(
cirq.one_hot(shape=5, dtype=np.complex128, index=3), [0, 0, 0, 1, 0])

def test_deprecated():
with capture_logging() as log:
_ = cirq.linalg.eye_tensor((1,), dtype=float)
assert len(log) == 1
assert "cirq.eye_tensor" in log[0].getMessage()
assert "deprecated" in log[0].getMessage()

def test_eye_tensor():
assert np.all(cirq.eye_tensor((), dtype=int) == np.array(1))
assert np.all(cirq.eye_tensor((1,), dtype=int) == np.array([[1]]))
assert np.all(cirq.eye_tensor((2,), dtype=int) == np.array([
[1, 0],
[0, 1]])) # yapf: disable
assert np.all(cirq.eye_tensor((2, 2), dtype=int) == np.array([
[[[1, 0], [0, 0]],
[[0, 1], [0, 0]]],
[[[0, 0], [1, 0]],
[[0, 0], [0, 1]]]])) # yapf: disable
assert np.all(cirq.eye_tensor((2, 3), dtype=int) == np.array([
[[[1, 0, 0], [0, 0, 0]],
[[0, 1, 0], [0, 0, 0]],
[[0, 0, 1], [0, 0, 0]]],
[[[0, 0, 0], [1, 0, 0]],
[[0, 0, 0], [0, 1, 0]],
[[0, 0, 0], [0, 0, 1]]]])) # yapf: disable
assert np.all(cirq.eye_tensor((3, 2), dtype=int) == np.array([
[[[1, 0], [0, 0], [0, 0]],
[[0, 1], [0, 0], [0, 0]]],
[[[0, 0], [1, 0], [0, 0]],
[[0, 0], [0, 1], [0, 0]]],
[[[0, 0], [0, 0], [1, 0]],
[[0, 0], [0, 0], [0, 1]]]])) # yapf: disable
with capture_logging() as log:
_ = cirq.linalg.one_hot(shape=(1,), dtype=float)
assert len(log) == 1
assert "cirq.one_hot" in log[0].getMessage()
assert "deprecated" in log[0].getMessage()
4 changes: 2 additions & 2 deletions cirq/ops/controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import itertools
import numpy as np

from cirq import protocols, linalg, value
from cirq import protocols, qis, value
from cirq.ops import raw_types, gate_operation, controlled_gate
from cirq.type_workarounds import NotImplementedType

Expand Down Expand Up @@ -129,7 +129,7 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
return NotImplemented
qid_shape = protocols.qid_shape(self)
sub_n = len(qid_shape) - len(self.controls)
tensor = linalg.eye_tensor(qid_shape, dtype=sub_matrix.dtype)
tensor = qis.eye_tensor(qid_shape, dtype=sub_matrix.dtype)
sub_tensor = sub_matrix.reshape(qid_shape[len(self.controls):] * 2)
for control_vals in itertools.product(*self.control_values):
active = (*(v for v in control_vals), *(slice(None),) * sub_n) * 2
Expand Down
22 changes: 9 additions & 13 deletions cirq/ops/linear_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np

from cirq import protocols, value
from cirq import protocols, qis, value
from cirq._doc import document
from cirq.linalg import operator_spaces
from cirq.ops import identity, raw_types, pauli_gates, pauli_string
Expand Down Expand Up @@ -392,12 +392,10 @@ def expectation_from_wavefunction(self,
"with shape `(2 ** n,)` or `(2, ..., 2)`.")

if check_preconditions:
# HACK: avoid circular import
from cirq.sim.wave_function import validate_normalized_state
validate_normalized_state(state=state,
qid_shape=(2,) * num_qubits,
dtype=state.dtype,
atol=atol)
qis.validate_normalized_state(state=state,
qid_shape=(2,) * num_qubits,
dtype=state.dtype,
atol=atol)
return sum(
p._expectation_from_wavefunction_no_validation(state, qubit_map)
for p in self)
Expand Down Expand Up @@ -445,14 +443,12 @@ def expectation_from_density_matrix(self,
"with shape `(2 ** n, 2 ** n)` or `(2, ..., 2)`.")

if check_preconditions:
# HACK: avoid circular import
from cirq.sim.density_matrix_utils import to_valid_density_matrix
# Do not enforce reshaping if the state all axes are dimension 2.
_ = to_valid_density_matrix(density_matrix_rep=state.reshape(
_ = qis.to_valid_density_matrix(density_matrix_rep=state.reshape(
dim, dim),
num_qubits=num_qubits,
dtype=state.dtype,
atol=atol)
num_qubits=num_qubits,
dtype=state.dtype,
atol=atol)
return sum(
p._expectation_from_density_matrix_no_validation(state, qubit_map)
for p in self)
Expand Down
22 changes: 9 additions & 13 deletions cirq/ops/pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np

from cirq import value, protocols, linalg
from cirq import value, protocols, linalg, qis
from cirq._doc import document
from cirq.ops import (
clifford_gate,
Expand Down Expand Up @@ -399,12 +399,10 @@ def expectation_from_wavefunction(self,

_validate_qubit_mapping(qubit_map, self.qubits, num_qubits)
if check_preconditions:
# HACK: avoid circular import
from cirq.sim.wave_function import validate_normalized_state
validate_normalized_state(state=state,
qid_shape=(2,) * num_qubits,
dtype=state.dtype,
atol=atol)
qis.validate_normalized_state(state=state,
qid_shape=(2,) * num_qubits,
dtype=state.dtype,
atol=atol)
return self._expectation_from_wavefunction_no_validation(
state, qubit_map)

Expand Down Expand Up @@ -499,14 +497,12 @@ def expectation_from_density_matrix(self,

_validate_qubit_mapping(qubit_map, self.qubits, num_qubits)
if check_preconditions:
# HACK: avoid circular import
from cirq.sim.density_matrix_utils import to_valid_density_matrix
# Do not enforce reshaping if the state all axes are dimension 2.
_ = to_valid_density_matrix(density_matrix_rep=state.reshape(
_ = qis.to_valid_density_matrix(density_matrix_rep=state.reshape(
dim, dim),
num_qubits=num_qubits,
dtype=state.dtype,
atol=atol)
num_qubits=num_qubits,
dtype=state.dtype,
atol=atol)
return self._expectation_from_density_matrix_no_validation(
state, qubit_map)

Expand Down
8 changes: 4 additions & 4 deletions cirq/protocols/apply_unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import numpy as np
from typing_extensions import Protocol

from cirq import linalg
from cirq import linalg, qis
from cirq.protocols import qid_shape_protocol
from cirq.protocols.decompose_protocol import (
_try_decompose_into_operations_and_qubits,)
Expand Down Expand Up @@ -109,9 +109,9 @@ def default(num_qubits: Optional[int] = None,
qid_shape = (2,) * num_qubits
qid_shape = cast(Tuple[int, ...], qid_shape) # Satisfy mypy
num_qubits = len(qid_shape)
state = linalg.one_hot(index=(0,) * num_qubits,
shape=qid_shape,
dtype=np.complex128)
state = qis.one_hot(index=(0,) * num_qubits,
shape=qid_shape,
dtype=np.complex128)
return ApplyUnitaryArgs(state, np.empty_like(state), range(num_qubits))

def with_axes_transposed_to_start(self) -> 'ApplyUnitaryArgs':
Expand Down
4 changes: 2 additions & 2 deletions cirq/protocols/has_unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from cirq.protocols.apply_unitary_protocol import ApplyUnitaryArgs
from cirq.protocols.decompose_protocol import (
_try_decompose_into_operations_and_qubits,)
from cirq import linalg
from cirq import qis

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -150,7 +150,7 @@ def _strat_has_unitary_from_apply_unitary(val: Any) -> Optional[bool]:
val_qid_shape = qid_shape_protocol.qid_shape(val, None)
if val_qid_shape is None:
return None
state = linalg.one_hot(shape=val_qid_shape, dtype=np.complex64)
state = qis.one_hot(shape=val_qid_shape, dtype=np.complex64)
buffer = np.empty_like(state)
result = method(ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))))
if result is NotImplemented:
Expand Down

0 comments on commit 33194ad

Please sign in to comment.