Skip to content

Commit

Permalink
Speed up sampling in the case of measuring all qudits (#2463)
Browse files Browse the repository at this point in the history
When all qudits are being measured we can avoid looping over measurement outcomes and computing array indices which is very slow. At large qubit numbers this gives over 100x speedup.

Benchmark script:
```
import timeit
import cirq


for n in range(8, 22, 2):
    state = cirq.testing.random_superposition(1 << n, random_state=1234)
    t0 = timeit.default_timer()
    _ = cirq.sample_state_vector(state, range(n), repetitions=1000)
    t1 = timeit.default_timer()
    print(f"n={n}: {t1-t0} secs")
```
Output before:
```
n=8: 0.014439163962379098 secs
n=10: 0.024910022970288992 secs
n=12: 0.07276041503064334 secs
n=14: 0.3017184369964525 secs
n=16: 1.0823597890557721 secs
n=18: 4.320284072891809 secs
n=20: 17.91134182305541 secs
```
Output after:
```
n=8: 0.00962069199886173 secs
n=10: 0.01253505703061819 secs
n=12: 0.013009974965825677 secs
n=14: 0.0381400550249964 secs
n=16: 0.04483150108717382 secs
n=18: 0.06963291903957725 secs
n=20: 0.10559344501234591 secs
```
  • Loading branch information
kevinsung authored and CirqBot committed Oct 31, 2019
1 parent e54f0f1 commit 1ec8a68
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
23 changes: 15 additions & 8 deletions cirq/sim/density_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,21 @@ def _probs(density_matrix: np.ndarray, indices: List[int],
tensor = np.reshape(all_probs, qid_shape)

# Calculate the probabilities for measuring the particular results.
meas_shape = tuple(qid_shape[i] for i in indices)
probs = np.abs([
tensor[linalg.slice_for_qubits_equal_to(indices,
big_endian_qureg_value=b,
qid_shape=qid_shape)]
for b in range(np.prod(meas_shape, dtype=int))
])
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))
if len(indices) == len(qid_shape):
# We're measuring every qudit, so no need for fancy indexing
probs = np.abs(tensor)
probs = np.transpose(probs, indices)
probs = np.reshape(probs, np.prod(probs.shape))
else:
# Fancy indexing required
meas_shape = tuple(qid_shape[i] for i in indices)
probs = np.abs([
tensor[linalg.slice_for_qubits_equal_to(indices,
big_endian_qureg_value=b,
qid_shape=qid_shape)]
for b in range(np.prod(meas_shape, dtype=int))
])
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))

# To deal with rounding issues, ensure that the probabilities sum to 1.
probs /= np.sum(probs)
Expand Down
25 changes: 15 additions & 10 deletions cirq/sim/wave_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,18 +565,23 @@ def measure_state_vector(
def _probs(state: np.ndarray, indices: List[int],
qid_shape: Tuple[int, ...]) -> np.ndarray:
"""Returns the probabilities for a measurement on the given indices."""
# Tensor of squared amplitudes, shaped a rank [2, 2, .., 2] tensor.
tensor = np.reshape(state, qid_shape)

# Calculate the probabilities for measuring the particular results.
meas_shape = tuple(qid_shape[i] for i in indices)
probs = np.abs([
tensor[linalg.slice_for_qubits_equal_to(indices,
big_endian_qureg_value=b,
qid_shape=qid_shape)]
for b in range(np.prod(meas_shape, dtype=int))
])**2
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))
if len(indices) == len(qid_shape):
# We're measuring every qudit, so no need for fancy indexing
probs = np.abs(tensor)**2
probs = np.transpose(probs, indices)
probs = np.reshape(probs, np.prod(probs.shape))
else:
# Fancy indexing required
meas_shape = tuple(qid_shape[i] for i in indices)
probs = np.abs([
tensor[linalg.slice_for_qubits_equal_to(indices,
big_endian_qureg_value=b,
qid_shape=qid_shape)]
for b in range(np.prod(meas_shape, dtype=int))
])**2
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))

# To deal with rounding issues, ensure that the probabilities sum to 1.
probs /= np.sum(probs)
Expand Down

0 comments on commit 1ec8a68

Please sign in to comment.