Skip to content

Commit

Permalink
Defensively copy state vector in simulator (#3115)
Browse files Browse the repository at this point in the history
  • Loading branch information
dabacon committed Jul 6, 2020
1 parent a610aab commit a06be71
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 3 deletions.
12 changes: 10 additions & 2 deletions cirq/sim/sparse_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def _simulator_state(self
return state_vector_simulator.StateVectorSimulatorState(
qubit_map=self.qubit_map, state_vector=self._state_vector)

def state_vector(self):
def state_vector(self, copy: bool = True):
"""Return the state vector at this point in the computation.
The state is returned in the computational basis with these basis
Expand All @@ -315,8 +315,16 @@ def state_vector(self):
| 5 | 1 | 0 | 1 |
| 6 | 1 | 1 | 0 |
| 7 | 1 | 1 | 1 |
Args:
copy: If True, then the returned state is a copy of the state
vector. If False, then the state vector is not copied,
potentially saving memory. If one only needs to read derived
parameters from the state vector and store then using False
can speed up simulation by eliminating a memory copy.
"""
return self._simulator_state().state_vector
vector = self._simulator_state().state_vector
return vector.copy() if copy else vector

def set_state_vector(self, state: 'cirq.STATE_VECTOR_LIKE'):
update_state = qis.to_valid_state_vector(state,
Expand Down
44 changes: 44 additions & 0 deletions cirq/sim/sparse_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import random
from unittest import mock
import numpy as np
Expand Down Expand Up @@ -963,3 +964,46 @@ def test_separated_measurements():
])
sample = cirq.Simulator().sample(c, repetitions=10)
np.testing.assert_array_equal(sample['zero'].values, [0] * 10)


def test_state_vector_copy():
sim = cirq.Simulator()

class InplaceGate(cirq.SingleQubitGate):
"""A gate that modifies the target tensor in place, multiply by -1."""

def _apply_unitary_(self, args):
args.target_tensor *= -1.0
return args.target_tensor

q = cirq.LineQubit(0)
circuit = cirq.Circuit(InplaceGate()(q), InplaceGate()(q))

vectors = []
for step in sim.simulate_moment_steps(circuit):
vectors.append(step.state_vector(copy=True))
for x, y in itertools.combinations(vectors, 2):
assert not np.shares_memory(x, y)

# If the state vector is not copied, then applying second InplaceGate
# causes old state to be modified.
vectors = []
copy_of_vectors = []
for step in sim.simulate_moment_steps(circuit):
state_vector = step.state_vector(copy=False)
vectors.append(state_vector)
copy_of_vectors.append(state_vector.copy())
assert any(
not np.array_equal(x, y) for x, y in zip(vectors, copy_of_vectors))


def test_final_state_vector_is_not_last_object():
sim = cirq.Simulator()

q = cirq.LineQubit(0)
initial_state = np.array([1, 0], dtype=np.complex64)
circuit = cirq.Circuit(cirq.WaitGate(0)(q))
result = sim.simulate(circuit, initial_state=initial_state)
assert result.state_vector() is not initial_state
assert not np.shares_memory(result.state_vector(), initial_state)
np.testing.assert_equal(result.state_vector(), initial_state)
2 changes: 1 addition & 1 deletion cirq/sim/state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def state_vector(self):
| 6 | 1 | 1 | 0 |
| 7 | 1 | 1 | 1 |
"""
return self._final_simulator_state.state_vector
return self._final_simulator_state.state_vector.copy()

def _value_equality_values_(self):
measurements = {
Expand Down
12 changes: 12 additions & 0 deletions cirq/sim/state_vector_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ def test_state_vector_trial_result_qid_shape():
assert cirq.qid_shape(trial_result) == (3, 2)


def test_state_vector_trial_state_vector_is_copy():
final_state_vector = np.array([0, 1])
final_simulator_state = cirq.StateVectorSimulatorState(
qubit_map={cirq.NamedQubit('a'): 0}, state_vector=final_state_vector)
trial_result = cirq.StateVectorTrialResult(
params=cirq.ParamResolver({}),
measurements={},
final_simulator_state=final_simulator_state)
assert final_simulator_state.state_vector is final_state_vector
assert not trial_result.state_vector() is final_state_vector


def test_str_big():
qs = cirq.LineQubit.range(20)
result = cirq.StateVectorTrialResult(
Expand Down

0 comments on commit a06be71

Please sign in to comment.