Skip to content

Commit

Permalink
Fix most numpy type errors in cirq/linalg (#4000)
Browse files Browse the repository at this point in the history
Using `check/mypy --next | grep cirq/linalg` this fixes all the
problems.

#3767
  • Loading branch information
mpharrigan committed Apr 12, 2021
1 parent 845836a commit 9c710fb
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 43 deletions.
23 changes: 13 additions & 10 deletions cirq/linalg/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
"""Utility methods for combining matrices."""

import functools
from typing import Union, Type
from typing import Union, TYPE_CHECKING

import numpy as np

from cirq._doc import document

if TYPE_CHECKING:
from numpy.typing import DTypeLike, ArrayLike


def kron(*factors: Union[np.ndarray, complex, float], shape_len: int = 2) -> np.ndarray:
"""Computes the kronecker product of a sequence of values.
Expand Down Expand Up @@ -104,7 +107,7 @@ def kron_with_controls(*factors: Union[np.ndarray, complex, float]) -> np.ndarra
return product


def dot(*values: Union[float, complex, np.ndarray]) -> Union[float, complex, np.ndarray]:
def dot(*values: 'ArrayLike') -> np.ndarray:
"""Computes the dot/matrix product of a sequence of values.
Performs the computation in serial order without regard to the matrix
Expand All @@ -117,20 +120,20 @@ def dot(*values: Union[float, complex, np.ndarray]) -> Union[float, complex, np.
Returns:
The resulting value or matrix.
"""
if len(values) == 0:
raise ValueError("cirq.dot must be called with arguments")

if len(values) == 1:
# note: it's important that we copy input arrays.
return np.array(values[0])

if len(values) <= 1:
if len(values) == 0:
raise ValueError("cirq.dot must be called with arguments")
if isinstance(values[0], np.ndarray):
return np.array(values[0])
return values[0]
result = values[0]
result = np.asarray(values[0])
for value in values[1:]:
result = np.dot(result, value)
return result


def _merge_dtypes(dtype1: Type[np.number], dtype2: Type[np.number]) -> Type[np.number]:
def _merge_dtypes(dtype1: 'DTypeLike', dtype2: 'DTypeLike') -> np.dtype:
return (np.zeros(0, dtype1) + np.zeros(0, dtype2)).dtype


Expand Down
10 changes: 6 additions & 4 deletions cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _group_similar(items: List[T], comparer: Callable[[T, T], bool]) -> List[Lis

def unitary_eig(
matrix: np.ndarray, check_preconditions: bool = True, atol: float = 1e-8
) -> Tuple[np.array, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray]:
"""Gives the guaranteed unitary eigendecomposition of a normal matrix.
All hermitian and unitary matrices are normal matrices. This method was
Expand Down Expand Up @@ -337,7 +337,7 @@ def _unitary_(self) -> np.ndarray:

def __str__(self) -> str:
axis_terms = '+'.join(
'{:.3g}*{}'.format(e, a) if e < 0.9999 else a
f'{e:.3g}*{a}' if e < 0.9999 else a
for e, a in zip(self.axis, ['X', 'Y', 'Z'])
if abs(e) >= 1e-8
).replace('+-', '-')
Expand Down Expand Up @@ -648,11 +648,13 @@ def coord_transform(

# parse input and extract KAK vector
if not isinstance(interactions, np.ndarray):
interactions = [
interactions_extracted: List[np.ndarray] = [
a if isinstance(a, np.ndarray) else protocols.unitary(a) for a in interactions
]
else:
interactions_extracted = [interactions]

points = kak_vector(interactions) * 4 / np.pi
points = kak_vector(interactions_extracted) * 4 / np.pi

ax.scatter(*coord_transform(points), **kwargs)
ax.set_xlim(0, +1)
Expand Down
3 changes: 3 additions & 0 deletions cirq/linalg/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,9 @@ def test_scatter_plot_normalized_kak_interaction_coefficients():
)
assert ax2 is ax

ax3 = cirq.scatter_plot_normalized_kak_interaction_coefficients(data[1], ax=ax)
assert ax3 is ax


def _vector_kron(first: np.ndarray, second: np.ndarray) -> np.ndarray:
"""Vectorized implementation of kron for square matrices."""
Expand Down
12 changes: 6 additions & 6 deletions cirq/linalg/diagonalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@ def bidiagonalize_real_matrix_pair_with_symmetric_products(
raise ValueError('mat1 must be real.')
if np.any(np.imag(mat2) != 0):
raise ValueError('mat2 must be real.')
if not predicates.is_hermitian(mat1.dot(mat2.T), rtol=rtol, atol=atol):
if not predicates.is_hermitian(np.dot(mat1, mat2.T), rtol=rtol, atol=atol):
raise ValueError('mat1 @ mat2.T must be symmetric.')
if not predicates.is_hermitian(mat1.T.dot(mat2), rtol=rtol, atol=atol):
if not predicates.is_hermitian(np.dot(mat1.T, mat2), rtol=rtol, atol=atol):
raise ValueError('mat1.T @ mat2 must be symmetric.')

# Use SVD to bi-diagonalize the first matrix.
Expand All @@ -200,7 +200,7 @@ def bidiagonalize_real_matrix_pair_with_symmetric_products(
base_diag = base_diag[:rank, :rank]

# Try diagonalizing the second matrix with the same factors as the first.
semi_corrected = base_left.T.dot(np.real(mat2)).dot(base_right.T)
semi_corrected = combinators.dot(base_left.T, np.real(mat2), base_right.T)

# Fix up the part of the second matrix's diagonalization that's matched
# against non-zero diagonal entries in the first matrix's diagonalization
Expand All @@ -218,15 +218,15 @@ def bidiagonalize_real_matrix_pair_with_symmetric_products(
# Merge the fixup factors into the initial diagonalization.
left_adjust = combinators.block_diag(overlap_adjust, extra_left_adjust)
right_adjust = combinators.block_diag(overlap_adjust.T, extra_right_adjust)
left = left_adjust.T.dot(base_left.T)
right = base_right.T.dot(right_adjust.T)
left = np.dot(left_adjust.T, base_left.T)
right = np.dot(base_right.T, right_adjust.T)

return left, right


def bidiagonalize_unitary_with_special_orthogonals(
mat: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8, check_preconditions: bool = True
) -> Tuple[np.ndarray, np.array, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Finds orthogonal matrices L, R such that L @ matrix @ R is diagonal.
Args:
Expand Down
11 changes: 5 additions & 6 deletions cirq/linalg/operator_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

"""Utilities for manipulating linear operators as elements of vector space."""

from typing import Dict, Tuple

import numpy as np
Expand All @@ -32,7 +31,7 @@

def kron_bases(*bases: Dict[str, np.ndarray], repeat: int = 1) -> Dict[str, np.ndarray]:
"""Creates tensor product of bases."""
product_basis = {'': 1}
product_basis = {'': np.ones(1)}
for basis in bases * repeat:
product_basis = {
name1 + name2: np.kron(matrix1, matrix2)
Expand Down Expand Up @@ -98,14 +97,14 @@ def pow_pauli_combination(
if exponent == 0:
return 1, 0, 0, 0

v = np.sqrt(ax * ax + ay * ay + az * az)
s = np.power(ai + v, exponent)
t = np.power(ai - v, exponent)
v = np.sqrt(ax * ax + ay * ay + az * az).item()
s = (ai + v) ** exponent
t = (ai - v) ** exponent

ci = (s + t) / 2
if s == t:
# v is near zero, only one term in binomial expansion survives
cxyz = exponent * np.power(ai, exponent - 1)
cxyz = exponent * ai ** (exponent - 1)
else:
# v is non-zero, account for all terms of binomial expansion
cxyz = (s - t) / 2
Expand Down
4 changes: 2 additions & 2 deletions cirq/linalg/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cirq import value


def is_diagonal(matrix: np.ndarray, *, atol: float = 1e-8) -> bool:
def is_diagonal(matrix: np.ndarray, *, atol: float = 1e-8) -> np.bool_:
"""Determines if a matrix is a approximately diagonal.
A matrix is diagonal if i!=j implies m[i,j]==0.
Expand Down Expand Up @@ -72,7 +72,7 @@ def is_orthogonal(matrix: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8)
"""
return (
matrix.shape[0] == matrix.shape[1]
and np.all(np.imag(matrix) == 0)
and np.all(np.imag(matrix) == 0).item()
and np.allclose(matrix.dot(matrix.T), np.eye(matrix.shape[0]), rtol=rtol, atol=atol)
)

Expand Down
11 changes: 6 additions & 5 deletions cirq/linalg/tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

"""Utility for testing approximate equality of matrices and scalars within
tolerances."""
from typing import Union, Iterable
from typing import Union, Iterable, TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
from numpy.typing import ArrayLike

def all_near_zero(
a: Union[float, complex, Iterable[float], np.ndarray], *, atol: float = 1e-8
) -> bool:

def all_near_zero(a: 'ArrayLike', *, atol: float = 1e-8) -> np.bool_:
"""Checks if the tensor's elements are all near zero.
Args:
Expand All @@ -33,7 +34,7 @@ def all_near_zero(

def all_near_zero_mod(
a: Union[float, complex, Iterable[float], np.ndarray], period: float, *, atol: float = 1e-8
) -> bool:
) -> np.bool_:
"""Checks if the tensor's elements are all near multiples of the period.
Args:
Expand Down
28 changes: 18 additions & 10 deletions cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Utility methods for transforming matrices or vectors."""

from typing import Tuple, Optional, Sequence, List, Union, TypeVar
from typing import Tuple, Optional, Sequence, List, Union

import numpy as np

Expand All @@ -26,9 +26,7 @@
# of type np.ndarray to ensure the method has the correct type signature in that
# case. It is checked for using `is`, so it won't have a false positive if the
# user provides a different np.array([]) value.
RaiseValueErrorIfNotProvided = np.array([]) # type: np.ndarray

TDefault = TypeVar('TDefault')
RaiseValueErrorIfNotProvided: np.ndarray = np.array([])


def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float):
Expand Down Expand Up @@ -326,6 +324,10 @@ def partial_trace(tensor: np.ndarray, keep_indices: List[int]) -> np.ndarray:
return np.einsum(tensor, left_indices + right_indices)


class EntangledStateError(ValueError):
"""Raised when a product state is expected, but an entangled state is provided."""


def partial_trace_of_state_vector_as_mixture(
state_vector: np.ndarray, keep_indices: List[int], *, atol: Union[int, float] = 1e-8
) -> Tuple[Tuple[float, np.ndarray], ...]:
Expand Down Expand Up @@ -357,9 +359,13 @@ def partial_trace_of_state_vector_as_mixture(
"""

# Attempt to do efficient state factoring.
state = sub_state_vector(state_vector, keep_indices, default=None, atol=atol)
if state is not None:
try:
state = sub_state_vector(
state_vector, keep_indices, default=RaiseValueErrorIfNotProvided, atol=atol
)
return ((1.0, state),)
except EntangledStateError:
pass

# Fall back to a (non-unique) mixture representation.
keep_dims = 1 << len(keep_indices)
Expand All @@ -382,7 +388,7 @@ def sub_state_vector(
state_vector: np.ndarray,
keep_indices: List[int],
*,
default: TDefault = RaiseValueErrorIfNotProvided,
default: np.ndarray = RaiseValueErrorIfNotProvided,
atol: Union[int, float] = 1e-8,
) -> np.ndarray:
r"""Attempts to factor a state vector into two parts and return one of them.
Expand Down Expand Up @@ -424,8 +430,10 @@ def sub_state_vector(
Raises:
ValueError: if the `state_vector` is not of the correct shape or the
indices are not a valid subset of the input `state_vector`'s indices, or
the result of factoring is not a pure state.
indices are not a valid subset of the input `state_vector`'s indices
EntangledStateError: If the result of factoring is not a pure state and
`default` is not provided.
"""

if not np.log2(state_vector.size).is_integer():
Expand Down Expand Up @@ -471,7 +479,7 @@ def sub_state_vector(
if default is not RaiseValueErrorIfNotProvided:
return default

raise ValueError(
raise EntangledStateError(
"Input state vector could not be factored into pure state over "
"indices {}".format(keep_indices)
)
Expand Down

0 comments on commit 9c710fb

Please sign in to comment.