Skip to content

Commit

Permalink
Speed up wavefunction sampling (#2460)
Browse files Browse the repository at this point in the history
Gives ~1.5x speedup by using vectorized operations.

Benchmark script:
```
import timeit
import numpy as np
import cirq


def random_superposition(dim: int, seed: int) -> np.ndarray:
    random_state = np.random.RandomState(seed)
    state_vector = random_state.randn(dim).astype(complex)
    state_vector += 1j * random_state.randn(dim)
    state_vector /= np.linalg.norm(state_vector)
    return state_vector


for n in range(8, 22, 2):
    state = random_superposition(1 << n, seed=1234)
    t0 = timeit.default_timer()
    _ = cirq.sample_state_vector(state, range(n), repetitions=1000)
    t1 = timeit.default_timer()
    print(f"n={n}: {t1-t0} secs")
```

Output before:
```
n=8: 0.01701263594441116 secs
n=10: 0.03182271297555417 secs
n=12: 0.101432592025958 secs
n=14: 0.47042544500436634 secs
n=16: 1.6910279580624774 secs
n=18: 6.5820552540244535 secs
n=20: 27.832146479049698 secs
```

Output after:
```
n=8: 0.015078706084750593 secs
n=10: 0.02267816790845245 secs
n=12: 0.06910240300931036 secs
n=14: 0.28539576393086463 secs
n=16: 1.0683796659577638 secs
n=18: 4.231483687995933 secs
n=20: 17.637282963027246 secs
```
  • Loading branch information
kevinsung authored and CirqBot committed Oct 31, 2019
1 parent c266da8 commit ef86e7a
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions cirq/sim/wave_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,21 +563,23 @@ def measure_state_vector(


def _probs(state: np.ndarray, indices: List[int],
qid_shape: Tuple[int, ...]) -> List[float]:
qid_shape: Tuple[int, ...]) -> np.ndarray:
"""Returns the probabilities for a measurement on the given indices."""
# Tensor of squared amplitudes, shaped a rank [2, 2, .., 2] tensor.
tensor = np.reshape(state, qid_shape)

# Calculate the probabilities for measuring the particular results.
meas_shape = tuple(qid_shape[i] for i in indices)
probs = [
np.linalg.norm(tensor[linalg.slice_for_qubits_equal_to(
indices, big_endian_qureg_value=b, qid_shape=qid_shape)])**2
probs = np.abs([
tensor[linalg.slice_for_qubits_equal_to(indices,
big_endian_qureg_value=b,
qid_shape=qid_shape)]
for b in range(np.prod(meas_shape, dtype=int))
]
])**2
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))

# To deal with rounding issues, ensure that the probabilities sum to 1.
probs /= sum(probs) # type: ignore
probs /= np.sum(probs)
return probs


Expand Down

0 comments on commit ef86e7a

Please sign in to comment.