Skip to content

Commit

Permalink
Update asteroid version (#2976)
Browse files Browse the repository at this point in the history
Previously we were blocked on upgrading asteroid due to a bug
pylint-dev/astroid#650
However the bug now says this is fixed, so verifying that our tests pass.

...and it turns out they don't pass.

Fixes:
* `ndarray.reshape` technically takes lists of arguments like `a.reshape(1, 3, 4)`.  However sometimes lint gets mad about this.  Fixed this and changed other uses in this test to use same pattern for consistency.
* Found a missing assertion right nearby and fixed this.  Needed to fix order bug in test.
* Didn't fix a lint error about the result of getattr not being callable.  I don't think there is any fix besides this as there are lots of bugs about pylint messing this up.  So added a disable.
  • Loading branch information
dabacon committed May 6, 2020
1 parent 78fc4c2 commit 2899c1f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 20 deletions.
30 changes: 16 additions & 14 deletions cirq/linalg/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def test_subwavefunction():
def test_subwavefunction_bad_subset():
a = cirq.testing.random_superposition(4)
b = cirq.testing.random_superposition(8)
state = np.kron(a, b).reshape(2, 2, 2, 2, 2)
state = np.kron(a, b).reshape((2, 2, 2, 2, 2))

for q1 in range(5):
assert cirq.subwavefunction(state, [q1], default=None,
Expand All @@ -462,7 +462,7 @@ def test_subwavefunction_bad_subset():
def test_subwavefunction_non_kron():
a = np.array([1, 0, 0, 0, 0, 0, 0, 1]) / np.sqrt(2)
b = np.array([1, 1]) / np.sqrt(2)
state = np.kron(a, b).reshape(2, 2, 2, 2)
state = np.kron(a, b).reshape((2, 2, 2, 2))

for q1 in [0, 1, 2]:
assert cirq.subwavefunction(state, [q1], default=None,
Expand All @@ -487,7 +487,8 @@ def test_subwavefunction_invalid_inputs():

# State shape does not conform to input requirements.
with pytest.raises(ValueError, match='shaped'):
cirq.subwavefunction(np.arange(16).reshape(2, 4, 2), [1, 2], atol=1e-8)
cirq.subwavefunction(np.arange(16).reshape((2, 4, 2)), [1, 2],
atol=1e-8)
with pytest.raises(ValueError, match='shaped'):
cirq.subwavefunction(np.arange(16).reshape((16, 1)), [1, 2], atol=1e-8)

Expand All @@ -499,7 +500,7 @@ def test_subwavefunction_invalid_inputs():
with pytest.raises(ValueError, match='2, 2'):
cirq.subwavefunction(state, [1, 2, 2], atol=1e-8)

state = np.array([1, 0, 0, 0]).reshape(2, 2)
state = np.array([1, 0, 0, 0]).reshape((2, 2))
with pytest.raises(ValueError, match='invalid'):
cirq.subwavefunction(state, [5], atol=1e-8)
with pytest.raises(ValueError, match='invalid'):
Expand All @@ -512,7 +513,7 @@ def test_wavefunction_partial_trace_as_mixture_invalid_input():
cirq.wavefunction_partial_trace_as_mixture(np.arange(7), [1, 2],
atol=1e-8)

bad_shape = np.arange(16).reshape(2, 4, 2)
bad_shape = np.arange(16).reshape((2, 4, 2))
with pytest.raises(ValueError, match='shaped'):
cirq.wavefunction_partial_trace_as_mixture(bad_shape, [1], atol=1e-8)
bad_shape = np.arange(16).reshape((16, 1))
Expand All @@ -526,7 +527,7 @@ def test_wavefunction_partial_trace_as_mixture_invalid_input():
with pytest.raises(ValueError, match='2, 2'):
cirq.wavefunction_partial_trace_as_mixture(state, [1, 2, 2], atol=1e-8)

state = np.array([1, 0, 0, 0]).reshape(2, 2)
state = np.array([1, 0, 0, 0]).reshape((2, 2))
with pytest.raises(ValueError, match='invalid'):
cirq.wavefunction_partial_trace_as_mixture(state, [5], atol=1e-8)
with pytest.raises(ValueError, match='invalid'):
Expand Down Expand Up @@ -560,15 +561,15 @@ def test_wavefunction_partial_trace_as_mixture_pure_result():
assert mixtures_equal(
cirq.wavefunction_partial_trace_as_mixture(state, [0, 1, 2, 3, 4],
atol=1e-8),
((1.0, np.kron(a, b).reshape(2, 2, 2, 2, 2)),))
((1.0, np.kron(a, b).reshape((2, 2, 2, 2, 2))),))
assert mixtures_equal(
cirq.wavefunction_partial_trace_as_mixture(state, [0, 1, 5, 6, 7, 8],
atol=1e-8),
((1.0, np.kron(a, c).reshape(2, 2, 2, 2, 2, 2)),))
((1.0, np.kron(a, c).reshape((2, 2, 2, 2, 2, 2))),))
assert mixtures_equal(
cirq.wavefunction_partial_trace_as_mixture(state, [2, 3, 4, 5, 6, 7, 8],
atol=1e-8),
((1.0, np.kron(b, c).reshape(2, 2, 2, 2, 2, 2, 2)),))
((1.0, np.kron(b, c).reshape((2, 2, 2, 2, 2, 2, 2))),))

# Shapes of states in the output mixture conform to the input's shape.
state = state.reshape(2**9)
Expand Down Expand Up @@ -603,21 +604,22 @@ def test_wavefunction_partial_trace_as_mixture_mixed_result():
atol=1e-8)
assert mixtures_equal(mixture, truth)

state = np.array([0, 1, 1, 0, 1, 0, 0, 0]).reshape(2, 2, 2) / np.sqrt(3)
truth = ((2 / 3, np.array([1.0, 0.0])), (1 / 3, np.array([0.0, 1.0])))
state = np.array([0, 1, 1, 0, 1, 0, 0, 0]).reshape((2, 2, 2)) / np.sqrt(3)
truth = ((1 / 3, np.array([0.0, 1.0])), (2 / 3, np.array([1.0, 0.0])))
for q1 in [0, 1, 2]:
mixture = cirq.wavefunction_partial_trace_as_mixture(state, [q1],
atol=1e-8)
assert mixtures_equal(mixture, truth)

state = np.array([1, 0, 0, 0, 0, 0, 0, 1]).reshape(2, 2, 2) / np.sqrt(2)
state = np.array([1, 0, 0, 0, 0, 0, 0, 1]).reshape((2, 2, 2)) / np.sqrt(2)
truth = ((0.5, np.array([1, 0])), (0.5, np.array([0, 1])))
for q1 in [0, 1, 2]:
mixture = cirq.wavefunction_partial_trace_as_mixture(state, [q1],
atol=1e-8)
assert mixtures_equal(mixture, truth)

truth = ((0.5, np.array([1, 0, 0, 0]).reshape(2, 2)),
(0.5, np.array([0, 0, 0, 1]).reshape(2, 2)))
truth = ((0.5, np.array([1, 0, 0, 0]).reshape(
(2, 2))), (0.5, np.array([0, 0, 0, 1]).reshape((2, 2))))
for (q1, q2) in [(0, 1), (0, 2), (1, 2)]:
mixture = cirq.wavefunction_partial_trace_as_mixture(state, [q1, q2],
atol=1e-8)
Expand Down
2 changes: 2 additions & 0 deletions cirq/protocols/inverse_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def inverse(val: Any, default: Any = RaiseTypeErrorIfNotProvided) -> Any:

# Check if object defines an inverse via __pow__.
raiser = getattr(val, '__pow__', None)

# pylint: disable=not-callable
result = NotImplemented if raiser is None else raiser(-1)
if result is not NotImplemented:
return result
Expand Down
6 changes: 3 additions & 3 deletions cirq/sim/density_matrix_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def test_measure_density_matrix_not_square():
with pytest.raises(ValueError, match='not square'):
cirq.measure_density_matrix(np.array([1, 0, 0]), [1])
with pytest.raises(ValueError, match='not square'):
cirq.measure_density_matrix(np.array([1, 0, 0, 0]).reshape(2, 1, 2),
cirq.measure_density_matrix(np.array([1, 0, 0, 0]).reshape((2, 1, 2)),
[1],
qid_shape=(2, 1))

Expand All @@ -321,8 +321,8 @@ def test_measure_density_matrix_higher_powers_of_two():

def test_measure_density_matrix_tensor_different_left_right_shape():
with pytest.raises(ValueError, match='not equal'):
cirq.measure_density_matrix(np.array([1, 0, 0, 0]).reshape(2, 2, 1, 1),
[1],
cirq.measure_density_matrix(np.array([1, 0, 0, 0]).reshape(
(2, 2, 1, 1)), [1],
qid_shape=(2, 1))


Expand Down
3 changes: 0 additions & 3 deletions dev_tools/conf/pip-list-dev-tools.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,3 @@ sphinx-markdown-tables
nbsphinx
ipython
ipykernel

# Need to pin pylint's parser to 2.1 instead of 2.2 until https://github.com/PyCQA/astroid/issues/650 is fixed
astroid~=2.1.0

0 comments on commit 2899c1f

Please sign in to comment.