Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SimulatorBase independent qubits optimization #4100

Merged
merged 97 commits into from
Jul 2, 2021
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
97 commits
Select commit Hold shift + click to select a range
293374d
split
daxfohl May 6, 2021
4faff08
Allow config param split_entangled_states
daxfohl May 6, 2021
8c202a4
default split to off
daxfohl May 7, 2021
54f1eea
ensure consistent_act_on circuits have a qubit.
daxfohl May 7, 2021
265d939
lint
daxfohl May 7, 2021
099ecc5
lint
daxfohl May 7, 2021
a698893
mps
daxfohl May 7, 2021
251eeb4
lint
daxfohl May 7, 2021
a255905
lint
daxfohl May 7, 2021
6f58842
run sparse by default
daxfohl May 7, 2021
b7254f9
fix tests
daxfohl May 7, 2021
79e77b2
fix tests
daxfohl May 7, 2021
4d301e5
fix tests
daxfohl May 7, 2021
58ed8fa
mps
daxfohl May 11, 2021
ba93627
Merge branch 'master' into split
daxfohl May 11, 2021
cba2940
tableau
daxfohl May 11, 2021
95683b7
test simulator
daxfohl May 11, 2021
e472b5a
test simulator
daxfohl May 11, 2021
c5e2485
Update simulator_base.py
daxfohl May 11, 2021
44791ae
Drop mps/join
daxfohl May 11, 2021
febd6ab
Merge branch 'split' of https://github.com/daxfohl/Cirq into split
daxfohl May 11, 2021
a77314a
Merge branch 'master' into split
daxfohl May 11, 2021
6613a05
Fix clifford extract
daxfohl May 12, 2021
10cffe3
Merge branch 'master' into split
daxfohl May 12, 2021
0c02c2b
lint
daxfohl May 12, 2021
462a842
remove split/join from ch-form
daxfohl May 14, 2021
ce08aae
Merge branch 'master' into split
daxfohl May 14, 2021
d56ebd3
Add default arg for zero qubit circuits
daxfohl May 26, 2021
0af64a5
Merge branch 'master' into split
daxfohl May 26, 2021
352eed4
Have last repetition reuse original state repr
daxfohl May 26, 2021
f51a811
Merge branch 'split' of https://github.com/daxfohl/Cirq into split
daxfohl May 26, 2021
766d4b8
Remove cast
daxfohl May 26, 2021
1e3a208
Split all pure initial states by default
daxfohl May 26, 2021
8629086
Detangle on reset channels
daxfohl Jun 3, 2021
f37d5a0
docstrings
daxfohl Jun 3, 2021
dd1fbf8
docstrings
daxfohl Jun 3, 2021
26402b0
docstrings
daxfohl Jun 3, 2021
ce93863
docstrings
daxfohl Jun 3, 2021
1e83ee6
Merge branch 'master' into split
daxfohl Jun 4, 2021
9828350
fix merge
daxfohl Jun 4, 2021
cfa3ada
lint
daxfohl Jun 4, 2021
fe16073
Add unit test for integer states
daxfohl Jun 4, 2021
b2fb89c
format
daxfohl Jun 4, 2021
184d0d9
Add tests for splitting and joining
daxfohl Jun 4, 2021
1bc5176
remove unnecessary qubits param
daxfohl Jun 5, 2021
f874bcd
Clean up default args
daxfohl Jun 5, 2021
e4c5fcc
Fix failing test
daxfohl Jun 5, 2021
bd2f726
Add ActOnArgsContainer
daxfohl Jun 5, 2021
534a59e
Add ActOnArgsContainer
daxfohl Jun 5, 2021
d701f5e
Clean up tests
daxfohl Jun 6, 2021
6f01a77
Clean up tests
daxfohl Jun 6, 2021
e42892c
Clean up tests
daxfohl Jun 6, 2021
1ad4cf8
format
daxfohl Jun 6, 2021
5420919
Fix tests and coverage
daxfohl Jun 6, 2021
8957e9b
Add OperationTarget interface
daxfohl Jun 6, 2021
50f99b6
Fix unit tests
daxfohl Jun 6, 2021
c372796
mypy, lint, mocks, coverage
daxfohl Jun 6, 2021
9f01c2d
coverage
daxfohl Jun 6, 2021
6af822c
Merge branch 'master' into split
daxfohl Jun 6, 2021
5dc90dd
add log to container
daxfohl Jun 7, 2021
0860077
fix logs
daxfohl Jun 7, 2021
3922452
dead code
daxfohl Jun 7, 2021
47a81e0
EmptyActOnArgs
daxfohl Jun 7, 2021
3066010
simplify dummyargs
daxfohl Jun 7, 2021
191a7e8
lint
daxfohl Jun 8, 2021
278bcd3
Add [] to actonargs
daxfohl Jun 10, 2021
f9b2080
Merge branch 'master' into split
daxfohl Jun 10, 2021
9958740
rename _create_act_on_arg
daxfohl Jun 10, 2021
63fe48d
coverage
daxfohl Jun 10, 2021
4719f5f
coverage
daxfohl Jun 10, 2021
2f3b6f6
Default sparse sim to split=false
daxfohl Jun 12, 2021
eae56b1
format
daxfohl Jun 12, 2021
2993f28
Default sparse sim to split=false
daxfohl Jun 12, 2021
8217b07
Default density matrix sim to split=false
daxfohl Jun 12, 2021
ec50741
lint
daxfohl Jun 12, 2021
6653a54
lint
daxfohl Jun 14, 2021
df55deb
lint
daxfohl Jun 14, 2021
50e6978
lint
daxfohl Jun 14, 2021
d834de1
Merge branch 'master' into split
daxfohl Jun 14, 2021
78a1cc1
address review comments
daxfohl Jun 18, 2021
bc64bc6
Merge branch 'master' into temp
daxfohl Jun 19, 2021
c0647f2
lint
daxfohl Jun 19, 2021
41ebd25
Defaults back to split=false
daxfohl Jun 19, 2021
8537aba
add error if setting state when split is enabled
daxfohl Jun 19, 2021
2ccf2ab
Unit tests
daxfohl Jun 19, 2021
dbe38f5
coverage
daxfohl Jun 19, 2021
1d2b325
coverage
daxfohl Jun 19, 2021
da040e6
coverage
daxfohl Jun 19, 2021
145d52c
docs
daxfohl Jun 19, 2021
06dfa48
seed
daxfohl Jun 21, 2021
dd4b5d0
Merge branch 'master' into split
daxfohl Jun 21, 2021
824c4e0
format
daxfohl Jun 21, 2021
ced5621
Merge branch 'master' into split
daxfohl Jun 22, 2021
5e0948e
rename core functions
daxfohl Jul 1, 2021
f545a44
Add optional validation to factor methods
daxfohl Jul 1, 2021
2d1e5ad
coverage
daxfohl Jul 2, 2021
0d587fb
Merge branch 'master' into split
95-martin-orion Jul 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 26 additions & 3 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import dataclasses
import math
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Iterable, Union
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Iterable, Union, Tuple

import numpy as np
import quimb.tensor as qtn
Expand Down Expand Up @@ -84,10 +84,11 @@ def __init__(
seed=seed,
)

def _create_act_on_args(
def _create_act_on_arg(
self,
initial_state: Union[int, 'MPSState'],
qubits: Sequence['cirq.Qid'],
logs: Dict[str, Any],
) -> 'MPSState':
"""Creates MPSState args for simulating the Circuit.

Expand All @@ -110,6 +111,7 @@ def _create_act_on_args(
simulation_options=self.simulation_options,
grouping=self.grouping,
initial_state=initial_state,
log_of_measurement_results=logs,
)

def _create_step_result(
Expand Down Expand Up @@ -220,7 +222,6 @@ def sample(
return np.array(measurements, dtype=int)


@value.value_equality
class MPSState(ActOnArgs):
"""A state of the MPS simulation."""

Expand Down Expand Up @@ -310,11 +311,33 @@ def copy(self) -> 'MPSState':
prng=self.prng,
simulation_options=self.simulation_options,
grouping=self.grouping,
log_of_measurement_results=self.log_of_measurement_results,
)
state.M = [x.copy() for x in self.M]
state.estimated_gate_error_list = self.estimated_gate_error_list
return state

def join(self, other: 'MPSState') -> 'MPSState':
# TODO MPS simulator currently does not enable split_untangled_states
# so this will never be called during simulation, and MPS gains nothing
# from running in split_untangled_states mode, so this is not necessary,
# however it may be useful if other use cases arise.
raise NotImplementedError()

def extract(self, qubits: Sequence['cirq.Qid']) -> Tuple['MPSState', 'MPSState']:
# TODO MPS simulator currently does not enable split_untangled_states
# so this will never be called during simulation, and MPS gains nothing
# from running in split_untangled_states mode, so this is not necessary,
# however it may be useful if other use cases arise.
raise NotImplementedError()

def reorder(self, qubits: Sequence['cirq.Qid']) -> 'MPSState':
# TODO MPS simulator currently does not enable split_untangled_states
# so this will never be called during simulation, and MPS gains nothing
# from running in split_untangled_states mode, so this is not necessary,
# however may be useful if other use cases arise.
raise NotImplementedError()

def state_vector(self) -> np.ndarray:
"""Returns the full state vector.

Expand Down
46 changes: 46 additions & 0 deletions cirq-core/cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,49 @@ def to_special(u: np.ndarray) -> np.ndarray:
the special unitary matrix
"""
return u * (np.linalg.det(u) ** (-1 / len(u)))


daxfohl marked this conversation as resolved.
Show resolved Hide resolved
def merge_state_vectors(
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
t1: np.ndarray,
t2: np.ndarray,
) -> np.ndarray:
return np.outer(t1, t2).reshape(t1.shape + t2.shape)


def merge_density_matrices(
t1: np.ndarray,
t2: np.ndarray,
) -> np.ndarray:
t = merge_state_vectors(t1, t2)
t1_len = len(t1.shape)
t1_dim = int(t1_len / 2)
t2_len = len(t2.shape)
t2_dim = int(t2_len / 2)
shape = t1.shape[:t1_dim] + t2.shape[:t2_dim]
return np.moveaxis(t, range(t1_len, t1_len + t2_dim), range(t1_dim, t1_dim + t2_dim)).reshape(
shape * 2
)


def split_state_vectors(
t: np.ndarray,
axes: Sequence[int],
) -> Tuple[np.ndarray, np.ndarray]:
n_axes = len(axes)
t = np.moveaxis(t, axes, range(n_axes))
pivot = np.unravel_index(np.abs(t).argmax(), t.shape)
slices1 = (slice(None),) * n_axes + pivot[n_axes:]
slices2 = pivot[:n_axes] + (slice(None),) * (t.ndim - n_axes)
extracted = t[slices1]
extracted = extracted / np.sum(abs(extracted) ** 2) ** 0.5
remainder = t[slices2]
remainder = remainder / np.sum(abs(remainder) ** 2) ** 0.5
return extracted, remainder


def split_density_matrices(
t: np.ndarray,
axes: Sequence[int],
) -> Tuple[np.ndarray, np.ndarray]:
axes = list(axes) + [i + int(t.ndim / 2) for i in axes]
return split_state_vectors(t, axes)
14 changes: 13 additions & 1 deletion cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Objects and methods for acting efficiently on a state tensor."""
import abc
from typing import Any, Iterable, Dict, List, TypeVar, TYPE_CHECKING, Sequence
from typing import Any, Iterable, Dict, List, TypeVar, TYPE_CHECKING, Sequence, Tuple

import numpy as np

Expand Down Expand Up @@ -85,6 +85,18 @@ def _perform_measurement(self) -> List[int]:
def copy(self: TSelf) -> TSelf:
"""Creates a copy of the object."""

@abc.abstractmethod
def join(self: TSelf, other: TSelf) -> TSelf:
"""Joins two state spaces together."""

@abc.abstractmethod
def extract(self: TSelf, qubits: Sequence['cirq.Qid']) -> Tuple[TSelf, TSelf]:
"""Splits two state spaces after a measurement or reset."""

@abc.abstractmethod
def reorder(self: TSelf, qubits: Sequence['cirq.Qid']) -> TSelf:
"""Physically reindexes the state by the new basis."""


def strat_act_on_from_apply_decompose(
val: Any,
Expand Down
59 changes: 58 additions & 1 deletion cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from cirq import protocols, sim
from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose
from cirq.linalg import transformations as tf

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -107,7 +108,63 @@ def copy(self) -> 'cirq.ActOnDensityMatrixArgs':
axes=self.axes,
qid_shape=self.qid_shape,
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results.copy(),
log_of_measurement_results=self.log_of_measurement_results,
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
)

def join(self, other: 'cirq.ActOnDensityMatrixArgs') -> 'cirq.ActOnDensityMatrixArgs':
target_tensor = tf.merge_density_matrices(self.target_tensor, other.target_tensor)
buffer = [np.empty_like(target_tensor) for _ in self.available_buffer]
return ActOnDensityMatrixArgs(
target_tensor=target_tensor,
available_buffer=buffer,
qubits=self.qubits + other.qubits,
axes=(),
qid_shape=target_tensor.shape[: int(target_tensor.ndim / 2)],
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)

def extract(
self, qubits: Sequence['cirq.Qid']
) -> Tuple['cirq.ActOnDensityMatrixArgs', 'cirq.ActOnDensityMatrixArgs']:
axes = [self.qubit_map[q] for q in qubits]
extracted_tensor, remainder_tensor = tf.split_density_matrices(self.target_tensor, axes)
buffer = [np.empty_like(extracted_tensor) for _ in self.available_buffer]
extracted_args = ActOnDensityMatrixArgs(
target_tensor=extracted_tensor,
available_buffer=buffer,
qubits=qubits,
qid_shape=extracted_tensor.shape[: int(extracted_tensor.ndim / 2)],
axes=(),
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
buffer = [np.empty_like(remainder_tensor) for _ in self.available_buffer]
remainder_args = ActOnDensityMatrixArgs(
target_tensor=remainder_tensor,
available_buffer=buffer,
qubits=tuple(q for q in self.qubits if q not in qubits),
qid_shape=remainder_tensor.shape[: int(remainder_tensor.ndim / 2)],
axes=(),
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
return extracted_args, remainder_args

def reorder(self, qubits: Sequence['cirq.Qid']) -> 'cirq.ActOnDensityMatrixArgs':
assert len(qubits) == len(self.qubits)
axes = [self.qubit_map[q] for q in qubits]
axes = axes + [i + len(qubits) for i in axes]
new_tensor = np.moveaxis(self.target_tensor, axes, range(len(qubits) * 2))
buffer = [np.empty_like(new_tensor) for _ in self.available_buffer]
return ActOnDensityMatrixArgs(
target_tensor=new_tensor,
available_buffer=buffer,
qubits=qubits,
qid_shape=new_tensor.shape[: int(new_tensor.ndim / 2)],
axes=(),
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)


Expand Down
54 changes: 53 additions & 1 deletion cirq-core/cirq/sim/act_on_state_vector_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from cirq import linalg, protocols, sim
from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose
from cirq.linalg import transformations as tf
daxfohl marked this conversation as resolved.
Show resolved Hide resolved

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -178,9 +179,60 @@ def copy(self) -> 'cirq.ActOnStateVectorArgs':
qubits=self.qubits,
axes=self.axes,
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results.copy(),
log_of_measurement_results=self.log_of_measurement_results,
)

def join(self, other: 'cirq.ActOnStateVectorArgs') -> 'cirq.ActOnStateVectorArgs':
target_tensor = tf.merge_state_vectors(self.target_tensor, other.target_tensor)
buffer = np.empty_like(target_tensor)
offset = len(self.target_tensor.shape)
axes = self.axes + tuple(a + offset for a in other.axes)
return ActOnStateVectorArgs(
target_tensor=target_tensor,
available_buffer=buffer,
qubits=self.qubits + other.qubits,
axes=axes,
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)

def extract(
self, qubits: Sequence['cirq.Qid']
) -> Tuple['cirq.ActOnStateVectorArgs', 'cirq.ActOnStateVectorArgs']:
axes = [self.qubit_map[q] for q in qubits]
extracted_tensor, remainder_tensor = tf.split_state_vectors(self.target_tensor, axes)
extracted_args = ActOnStateVectorArgs(
target_tensor=extracted_tensor,
available_buffer=np.empty_like(extracted_tensor),
qubits=qubits,
axes=(),
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
remainder_args = ActOnStateVectorArgs(
target_tensor=remainder_tensor,
available_buffer=np.empty_like(remainder_tensor),
qubits=tuple(q for q in self.qubits if q not in qubits),
axes=(),
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
return extracted_args, remainder_args

def reorder(self, qubits: Sequence['cirq.Qid']) -> 'cirq.ActOnStateVectorArgs':
assert len(qubits) == len(self.qubits)
axes = [self.qubit_map[q] for q in qubits]
new_tensor = np.moveaxis(self.target_tensor, axes, range(len(qubits)))
new_args = ActOnStateVectorArgs(
target_tensor=new_tensor,
available_buffer=np.empty_like(new_tensor),
qubits=qubits,
axes=(),
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
return new_args


def _strat_act_on_state_vector_from_apply_unitary(
unitary_value: Any,
Expand Down
18 changes: 16 additions & 2 deletions cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""A protocol for implementing high performance clifford tableau evolutions
for Clifford Simulator."""

from typing import Any, Dict, Iterable, TYPE_CHECKING, List, Sequence
from typing import Any, Dict, Iterable, TYPE_CHECKING, List, Sequence, Tuple

import numpy as np

Expand Down Expand Up @@ -87,9 +87,23 @@ def copy(self) -> 'cirq.ActOnCliffordTableauArgs':
qubits=self.qubits,
axes=self.axes,
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results.copy(),
log_of_measurement_results=self.log_of_measurement_results,
)

def join(self, other: 'ActOnCliffordTableauArgs') -> 'ActOnCliffordTableauArgs':
# Unnecessary for now but can be added later if there is a use case.
raise NotImplementedError()

def extract(
self, qubits: Sequence['cirq.Qid']
) -> Tuple['ActOnCliffordTableauArgs', 'ActOnCliffordTableauArgs']:
# Unnecessary for now but can be added later if there is a use case.
raise NotImplementedError()

def reorder(self, qubits: Sequence['cirq.Qid']) -> 'ActOnCliffordTableauArgs':
# Unnecessary for now but can be added later if there is a use case.
raise NotImplementedError()


def _strat_act_on_clifford_tableau_from_single_qubit_decompose(
val: Any, args: 'cirq.ActOnCliffordTableauArgs'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,4 @@ def test_copy():
assert args.tableau == args1.tableau
assert args.axes == args1.axes
assert args.prng is args1.prng
assert args.log_of_measurement_results is not args1.log_of_measurement_results
assert args.log_of_measurement_results == args.log_of_measurement_results
assert args.log_of_measurement_results is args1.log_of_measurement_results
18 changes: 16 additions & 2 deletions cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, TYPE_CHECKING, List, Sequence
from typing import Any, Dict, Iterable, TYPE_CHECKING, List, Sequence, Tuple

import numpy as np

Expand Down Expand Up @@ -82,9 +82,23 @@ def copy(self) -> 'cirq.ActOnStabilizerCHFormArgs':
qubits=self.qubits,
axes=self.axes,
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results.copy(),
log_of_measurement_results=self.log_of_measurement_results,
)

def join(self, other: 'cirq.ActOnStabilizerCHFormArgs') -> 'cirq.ActOnStabilizerCHFormArgs':
# Unnecessary for now but can be added later if there is a use case.
raise NotImplementedError()

def extract(
self, qubits: Sequence['cirq.Qid']
) -> Tuple['cirq.ActOnStabilizerCHFormArgs', 'cirq.ActOnStabilizerCHFormArgs']:
# Unnecessary for now but can be added later if there is a use case.
raise NotImplementedError()

def reorder(self, qubits: Sequence['cirq.Qid']) -> 'cirq.ActOnStabilizerCHFormArgs':
# Unnecessary for now but can be added later if there is a use case.
raise NotImplementedError()


def _strat_act_on_stabilizer_ch_form_from_single_qubit_decompose(
val: Any, args: 'cirq.ActOnStabilizerCHFormArgs'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,4 @@ def test_copy():
np.testing.assert_equal(args.state.state_vector(), args1.state.state_vector())
assert args.axes == args1.axes
assert args.prng is args1.prng
assert args.log_of_measurement_results is not args1.log_of_measurement_results
assert args.log_of_measurement_results == args.log_of_measurement_results
assert args.log_of_measurement_results is args1.log_of_measurement_results