Skip to content

Commit

Permalink
Define cirq.STATE_VECTOR_LIKE (#2376)
Browse files Browse the repository at this point in the history
- Add list-of-qid-values to the supported STATE_VECTOR_LIKE types
- Add tensor-of-amplitudes to the supported STATE_VECTOR_LIKE types
- Add non-numpy sequences to supported STATE_VECTOR_LIKE types
- Infer values-vs-tensor-vs-vector using shape, fail when ambiguous (can only occur for non-standard shapes involving qudits of dimension 1)
- Make `num_qubits` argument of `to_valid_state_vector` optional (if `qid_shape` specified)
  • Loading branch information
Strilanc authored and CirqBot committed Nov 6, 2019
1 parent efb3afe commit 58a794f
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 54 deletions.
1 change: 1 addition & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@
SimulationTrialResult,
Simulator,
SparseSimulatorStep,
STATE_VECTOR_LIKE,
StateVectorMixin,
StepResult,
to_valid_density_matrix,
Expand Down
29 changes: 16 additions & 13 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ def unitary(self,

def final_wavefunction(
self,
initial_state: Union[int, np.ndarray] = 0,
initial_state: 'cirq.STATE_VECTOR_LIKE' = 0,
qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT,
qubits_that_should_be_present: Iterable['cirq.Qid'] = (),
ignore_terminal_measurements: bool = True,
Expand All @@ -1440,12 +1440,18 @@ def final_wavefunction(
way.
Args:
initial_state: The input state for the circuit. This can be an int
or a vector. When this is an int, it refers to a computational
initial_state: The input state for the circuit. This can be a list
of qudit values, a big endian int encoding the qudit values,
a vector of amplitudes, or a tensor of amplitudes.
When this is an int, it refers to a computational
basis state (e.g. 5 means initialize to ``|5⟩ = |...000101⟩``).
If this is a state vector, it directly specifies the initial
state's amplitudes. The vector must be a flat numpy array with a
type that can be converted to np.complex128.
If this is a vector of amplitudes (a flat numpy array of the
correct length for the system) or a tensor of amplitudes (a
numpy array whose shape equals this circuit's `qid_shape`), it
directly specifies the initial state's amplitudes. The vector
type must be convertible to the given `dtype` argument.
qubit_order: Determines how qubits are ordered when passing matrices
into np.kron.
qubits_that_should_be_present: Qubits that may or may not appear
Expand Down Expand Up @@ -1485,13 +1491,10 @@ def final_wavefunction(
qid_shape = self.qid_shape(qubit_order=qs)
state_len = np.product(qid_shape, dtype=int)

if isinstance(initial_state, int):
state = np.zeros(state_len, dtype=dtype)
state[initial_state] = 1
else:
state = initial_state.astype(dtype)
state.shape = qid_shape

from cirq import sim
state = sim.to_valid_state_vector(initial_state,
qid_shape=qid_shape,
dtype=dtype).reshape(qid_shape)
result = _apply_unitary_circuit(self, state, qs, dtype)
return result.reshape((state_len,))

Expand Down
7 changes: 5 additions & 2 deletions cirq/linalg/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@
from cirq._doc import document


def kron(*factors: Union[np.ndarray, complex, float]) -> np.ndarray:
def kron(*factors: Union[np.ndarray, complex, float],
shape_len: int = 2) -> np.ndarray:
"""Computes the kronecker product of a sequence of values.
A *args version of lambda args: functools.reduce(np.kron, args).
Args:
*factors: The matrices, tensors, and/or scalars to combine together
using np.kron.
shape_len: The expected number of dimensions in the output. Mainly
determines the behavior of the empty kron product.
Returns:
The kronecker product of all the inputs.
"""
product = np.eye(1)
product = np.ones(shape=(1,) * shape_len)
for m in factors:
product = np.kron(product, m)
return np.array(product)
Expand Down
8 changes: 8 additions & 0 deletions cirq/linalg/combinators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def test_dot():


def test_kron_multiplies_sizes():
assert cirq.kron(np.array([1, 2])).shape == (1, 2)
assert cirq.kron(np.array([1, 2]), shape_len=1).shape == (2,)
assert cirq.kron(np.array([1, 2]), np.array([3, 4, 5]),
shape_len=1).shape == (6,)
assert cirq.kron(shape_len=0).shape == ()
assert cirq.kron(shape_len=1).shape == (1,)
assert cirq.kron(shape_len=2).shape == (1, 1)

assert np.allclose(cirq.kron(1j, np.array([2, 3])), np.array([2j, 3j]))
assert np.allclose(cirq.kron(), np.eye(1))
assert np.allclose(cirq.kron(np.eye(1)), np.eye(1))
Expand Down
6 changes: 3 additions & 3 deletions cirq/ops/pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def test_to_z_basis_ops():
q4: cirq.Z, q5: cirq.Z})
circuit = cirq.Circuit(pauli_string.to_z_basis_ops())

initial_state = cirq.kron(x0, x1, y0, y1, z0, z1)
initial_state = cirq.kron(x0, x1, y0, y1, z0, z1, shape_len=1)
z_basis_state = circuit.final_wavefunction(initial_state)

expected_state = np.zeros(2 ** 6)
Expand Down Expand Up @@ -962,7 +962,7 @@ def test_pauli_string_expectation_from_wavefunction_pure_state():
x0z1 = cirq.PauliString({qubits[0]: cirq.X, qubits[1]: cirq.Z})
x3 = cirq.PauliString({qubits[3]: cirq.X})

for state in [wf, wf.reshape(2, 2, 2, 2)]:
for state in [wf, wf.reshape((2, 2, 2, 2))]:
np.testing.assert_allclose(
z0z1.expectation_from_wavefunction(state, q_map), -1)
np.testing.assert_allclose(
Expand Down Expand Up @@ -1206,7 +1206,7 @@ def test_pauli_string_expectation_from_density_matrix_pure_state():
x0z1 = cirq.PauliString({qubits[0]: cirq.X, qubits[1]: cirq.Z})
x3 = cirq.PauliString({qubits[3]: cirq.X})

for state in [rho, rho.reshape(2, 2, 2, 2, 2, 2, 2, 2)]:
for state in [rho, rho.reshape((2, 2, 2, 2, 2, 2, 2, 2))]:
np.testing.assert_allclose(
z0z1.expectation_from_density_matrix(state, q_map), -1)
np.testing.assert_allclose(
Expand Down
1 change: 1 addition & 0 deletions cirq/protocols/json_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def test_fail_to_resolve():
'ParamResolverOrSimilarType',
'PauliSumLike',
'QubitOrderOrList',
'STATE_VECTOR_LIKE',
'Sweepable',
'TParamVal',
'ParamDictType',
Expand Down
1 change: 1 addition & 0 deletions cirq/sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
dirac_notation,
measure_state_vector,
sample_state_vector,
STATE_VECTOR_LIKE,
StateVectorMixin,
to_valid_state_vector,
validate_normalized_state,
Expand Down
6 changes: 3 additions & 3 deletions cirq/sim/density_matrix_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def test_to_valid_density_matrix_from_state():


def test_to_valid_density_matrix_from_state_invalid_state():
with pytest.raises(ValueError, match="2 qubits"):
cirq.to_valid_density_matrix(np.array([1, 0]), num_qubits=2)
with pytest.raises(ValueError, match="shape was neither"):
cirq.to_valid_density_matrix(np.array([1, 0, 0]), num_qubits=2)


def test_to_valid_density_matrix_from_computational_basis():
Expand All @@ -224,7 +224,7 @@ def test_to_valid_density_matrix_from_computational_basis():


def test_to_valid_density_matrix_from_state_invalid_computational_basis():
with pytest.raises(ValueError, match="positive"):
with pytest.raises(ValueError, match="out of range"):
cirq.to_valid_density_matrix(-1, num_qubits=2)


Expand Down
14 changes: 9 additions & 5 deletions cirq/sim/sparse_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@

import collections

from typing import Dict, Iterator, List, Optional, Tuple, Type, Union
from typing import Dict, Iterator, List, Optional, Tuple, Type, Union, \
TYPE_CHECKING

import numpy as np

from cirq import circuits, linalg, ops, protocols, study
from cirq.sim import simulator, wave_function, wave_function_simulator

if TYPE_CHECKING:
import cirq


class _FlipGate(ops.SingleQubitGate):
"""A unitary gate that flips the |0> state with another state.
Expand Down Expand Up @@ -213,7 +217,7 @@ def _simulator_iterator(
circuit: circuits.Circuit,
param_resolver: study.ParamResolver,
qubit_order: ops.QubitOrderOrList,
initial_state: Union[int, np.ndarray],
initial_state: 'cirq.STATE_VECTOR_LIKE',
) -> Iterator:
"""See definition in `cirq.SimulatesIntermediateState`.
Expand All @@ -236,8 +240,8 @@ def _base_iterator(
self,
circuit: circuits.Circuit,
qubit_order: ops.QubitOrderOrList,
initial_state: Union[int, np.ndarray],
perform_measurements: bool=True,
initial_state: 'cirq.STATE_VECTOR_LIKE',
perform_measurements: bool = True,
) -> Iterator:
qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(
circuit.all_qubits())
Expand Down Expand Up @@ -430,7 +434,7 @@ def state_vector(self):
"""
return self._simulator_state().state_vector

def set_state_vector(self, state: Union[int, np.ndarray]):
def set_state_vector(self, state: 'cirq.STATE_VECTOR_LIKE'):
update_state = wave_function.to_valid_state_vector(
state,
len(self.qubit_map),
Expand Down

0 comments on commit 58a794f

Please sign in to comment.