Skip to content

Commit

Permalink
Handle shape mismatch for cirq.dirac_notation (#6179)
Browse files Browse the repository at this point in the history
Throwing a value error when there is a shape mismatch between state_vector and qid_shape. It also throws an error when qid_shape is not mentioned and length of state_vector is not a power of 2.

Fixes #6165
  • Loading branch information
jeeva2812 committed Jul 7, 2023
1 parent 9dff011 commit a041ef8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
13 changes: 12 additions & 1 deletion cirq-core/cirq/qis/states.py
Expand Up @@ -714,10 +714,22 @@ def dirac_notation(
Returns:
A pretty string consisting of a sum of computational basis kets
and non-zero floats of the specified accuracy.
Raises:
ValueError: If there is a shape mismatch between state_vector and qid_shape.
Otherwise, when qid_shape is not mentioned and length of state_vector
is not a power of 2.
"""
if qid_shape is None:
qid_shape = (2,) * (len(state_vector).bit_length() - 1)

if len(state_vector) != np.prod(qid_shape, dtype=np.int64):
raise ValueError(
'state_vector has incorrect size. Expected {} but was {}.'.format(
np.prod(qid_shape, dtype=np.int64), len(state_vector)
)
)

digit_separator = '' if max(qid_shape, default=0) < 10 else ','
perm_list = [
digit_separator.join(seq)
Expand Down Expand Up @@ -821,7 +833,6 @@ def to_valid_state_vector(
def _qudit_values_to_state_tensor(
*, state_vector: np.ndarray, qid_shape: Tuple[int, ...], dtype: Optional['DTypeLike']
) -> np.ndarray:

for i in range(len(qid_shape)):
s = state_vector[i]
q = qid_shape[i]
Expand Down
7 changes: 7 additions & 0 deletions cirq-core/cirq/qis/states_test.py
Expand Up @@ -475,6 +475,13 @@ def test_dirac_notation_precision():
assert_dirac_notation_python([sqrt, sqrt], "0.707|0⟩ + 0.707|1⟩", decimals=3)


def test_dirac_notation_invalid():
with pytest.raises(ValueError, match='state_vector has incorrect size'):
_ = cirq.dirac_notation([0.0, 0.0, 1.0])
with pytest.raises(ValueError, match='state_vector has incorrect size'):
_ = cirq.dirac_notation([1.0, 1.0], qid_shape=(3,))


def test_to_valid_state_vector():
with pytest.raises(ValueError, match='Computational basis state is out of range'):
cirq.to_valid_state_vector(2, 1)
Expand Down

0 comments on commit a041ef8

Please sign in to comment.