-
Notifications
You must be signed in to change notification settings - Fork 989
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Density Matrix Simulator ActOn migration (#3841)
Fixes #3825 Sparse and Clifford simulators both have a proper ActOnXXXStateArgs where simulator state is maintained and acted upon by various operations. This was first added in #3019 with the act_on protocol. DensityMatrix and MPS simulators were never updated with the new protocol, leaving the codebase inconsistent and the migration unfinished. This PR finishes the task for DensityMatrix. A separate PR for MPS simulator hopefully will follow. Aditionally, this PR introduces an `ActOnArgs` base class to eliminate the code duplication and type checking that was required in `MeasurementGate` (it had grown to four `if isinstance` blocks, each with duplicate bitshift/logging code), and replace it all with a single function call back into the `args` object.
- Loading branch information
Showing
14 changed files
with
395 additions
and
199 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Objects and methods for acting efficiently on a state tensor.""" | ||
import abc | ||
from typing import Any, Iterable, Dict, List | ||
|
||
import numpy as np | ||
|
||
from cirq import protocols | ||
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits | ||
|
||
|
||
class ActOnArgs: | ||
"""State and context for an operation acting on a state tensor.""" | ||
|
||
def __init__( | ||
self, | ||
axes: Iterable[int], | ||
prng: np.random.RandomState, | ||
log_of_measurement_results: Dict[str, Any], | ||
): | ||
""" | ||
Args: | ||
axes: The indices of axes corresponding to the qubits that the | ||
operation is supposed to act upon. | ||
prng: The pseudo random number generator to use for probabilistic | ||
effects. | ||
log_of_measurement_results: A mutable object that measurements are | ||
being recorded into. Edit it easily by calling | ||
`ActOnStateVectorArgs.record_measurement_result`. | ||
""" | ||
self.axes = tuple(axes) | ||
self.prng = prng | ||
self.log_of_measurement_results = log_of_measurement_results | ||
|
||
def measure(self, key, invert_mask): | ||
"""Adds a measurement result to the log. | ||
Args: | ||
key: The key the measurement result should be logged under. Note | ||
that operations should only store results under keys they have | ||
declared in a `_measurement_keys_` method. | ||
invert_mask: The invert mask for the measurement. | ||
""" | ||
bits = self._perform_measurement() | ||
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)] | ||
if key in self.log_of_measurement_results: | ||
raise ValueError(f"Measurement already logged to key {key!r}") | ||
self.log_of_measurement_results[key] = corrected | ||
|
||
@abc.abstractmethod | ||
def _perform_measurement(self) -> List[int]: | ||
"""Child classes that perform measurements should implement this with | ||
the implementation.""" | ||
|
||
|
||
def strat_act_on_from_apply_decompose( | ||
val: Any, | ||
args: ActOnArgs, | ||
) -> bool: | ||
operations, qubits, _ = _try_decompose_into_operations_and_qubits(val) | ||
if operations is None: | ||
return NotImplemented | ||
assert len(qubits) == len(args.axes) | ||
qubit_map = {q: args.axes[i] for i, q in enumerate(qubits)} | ||
|
||
old_axes = args.axes | ||
try: | ||
for operation in operations: | ||
args.axes = tuple(qubit_map[q] for q in operation.qubits) | ||
protocols.act_on(operation, args) | ||
finally: | ||
args.axes = old_axes | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
import cirq | ||
from cirq.sim import act_on_args | ||
|
||
|
||
def test_measurements(): | ||
class DummyArgs(cirq.ActOnArgs): | ||
def _perform_measurement(self) -> List[int]: | ||
return [5, 3] | ||
|
||
args = DummyArgs(axes=[], prng=np.random.RandomState(), log_of_measurement_results={}) | ||
args.measure("test", [1]) | ||
assert args.log_of_measurement_results["test"] == [5] | ||
|
||
|
||
def test_decompose(): | ||
class DummyArgs(cirq.ActOnArgs): | ||
def _act_on_fallback_(self, action, allow_decompose): | ||
return True | ||
|
||
class Composite(cirq.Gate): | ||
def num_qubits(self) -> int: | ||
return 1 | ||
|
||
def _decompose_(self, qubits): | ||
yield cirq.X(*qubits) | ||
|
||
args = DummyArgs(axes=[0], prng=np.random.RandomState(), log_of_measurement_results={}) | ||
assert act_on_args.strat_act_on_from_apply_decompose(Composite(), args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Objects and methods for acting efficiently on a density matrix.""" | ||
|
||
from typing import Any, Iterable, Dict, List, Tuple | ||
|
||
import numpy as np | ||
|
||
from cirq import protocols, sim | ||
from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose | ||
|
||
|
||
class ActOnDensityMatrixArgs(ActOnArgs): | ||
"""State and context for an operation acting on a density matrix. | ||
To act on this object, directly edit the `target_tensor` property, which is | ||
storing the density matrix of the quantum system with one axis per qubit. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
target_tensor: np.ndarray, | ||
available_buffer: List[np.ndarray], | ||
axes: Iterable[int], | ||
qid_shape: Tuple[int, ...], | ||
prng: np.random.RandomState, | ||
log_of_measurement_results: Dict[str, Any], | ||
): | ||
""" | ||
Args: | ||
target_tensor: The state vector to act on, stored as a numpy array | ||
with one dimension for each qubit in the system. Operations are | ||
expected to perform inplace edits of this object. | ||
available_buffer: A workspace with the same shape and dtype as | ||
`target_tensor`. Used by operations that cannot be applied to | ||
`target_tensor` inline, in order to avoid unnecessary | ||
allocations. | ||
axes: The indices of axes corresponding to the qubits that the | ||
operation is supposed to act upon. | ||
qid_shape: The shape of the target tensor. | ||
prng: The pseudo random number generator to use for probabilistic | ||
effects. | ||
log_of_measurement_results: A mutable object that measurements are | ||
being recorded into. Edit it easily by calling | ||
`ActOnStateVectorArgs.record_measurement_result`. | ||
""" | ||
super().__init__(axes, prng, log_of_measurement_results) | ||
self.target_tensor = target_tensor | ||
self.available_buffer = available_buffer | ||
self.qid_shape = qid_shape | ||
|
||
def _act_on_fallback_(self, action: Any, allow_decompose: bool): | ||
strats = [ | ||
_strat_apply_channel_to_state, | ||
] | ||
if allow_decompose: | ||
strats.append(strat_act_on_from_apply_decompose) # type: ignore | ||
|
||
# Try each strategy, stopping if one works. | ||
for strat in strats: | ||
result = strat(action, self) | ||
if result is False: | ||
break # coverage: ignore | ||
if result is True: | ||
return True | ||
assert result is NotImplemented, str(result) | ||
raise TypeError( | ||
"Can't simulate operations that don't implement " | ||
"SupportsUnitary, SupportsConsistentApplyUnitary, " | ||
"SupportsMixture, SupportsChannel or is a measurement: {!r}".format(action) | ||
) | ||
|
||
def _perform_measurement(self) -> List[int]: | ||
"""Delegates the call to measure the density matrix.""" | ||
bits, _ = sim.measure_density_matrix( | ||
self.target_tensor, | ||
self.axes, | ||
out=self.target_tensor, | ||
qid_shape=self.qid_shape, | ||
seed=self.prng, | ||
) | ||
return bits | ||
|
||
|
||
def _strat_apply_channel_to_state( | ||
action: Any, | ||
args: ActOnDensityMatrixArgs, | ||
) -> bool: | ||
"""Apply channel to state.""" | ||
result = protocols.apply_channel( | ||
action, | ||
args=protocols.ApplyChannelArgs( | ||
target_tensor=args.target_tensor, | ||
out_buffer=args.available_buffer[0], | ||
auxiliary_buffer0=args.available_buffer[1], | ||
auxiliary_buffer1=args.available_buffer[2], | ||
left_axes=args.axes, | ||
right_axes=[e + len(args.qid_shape) for e in args.axes], | ||
), | ||
default=None, | ||
) | ||
if result is None: | ||
return NotImplemented | ||
for i in range(3): | ||
if result is args.available_buffer[i]: | ||
args.available_buffer[i] = args.target_tensor | ||
args.target_tensor = result | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import cirq | ||
|
||
|
||
def test_decomposed_fallback(): | ||
class Composite(cirq.Gate): | ||
def num_qubits(self) -> int: | ||
return 1 | ||
|
||
def _decompose_(self, qubits): | ||
yield cirq.X(*qubits) | ||
|
||
qid_shape = (2,) | ||
tensor = cirq.to_valid_density_matrix( | ||
0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64 | ||
) | ||
args = cirq.ActOnDensityMatrixArgs( | ||
target_tensor=tensor, | ||
available_buffer=[np.empty_like(tensor) for _ in range(3)], | ||
axes=[0], | ||
prng=np.random.RandomState(), | ||
log_of_measurement_results={}, | ||
qid_shape=qid_shape, | ||
) | ||
|
||
cirq.act_on(Composite(), args) | ||
np.testing.assert_allclose( | ||
args.target_tensor, cirq.one_hot(index=(1, 1), shape=(2, 2), dtype=np.complex64) | ||
) | ||
|
||
|
||
def test_cannot_act(): | ||
class NoDetails: | ||
pass | ||
|
||
qid_shape = (2,) | ||
tensor = cirq.to_valid_density_matrix( | ||
0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64 | ||
) | ||
args = cirq.ActOnDensityMatrixArgs( | ||
target_tensor=tensor, | ||
available_buffer=[np.empty_like(tensor) for _ in range(3)], | ||
axes=[0], | ||
prng=np.random.RandomState(), | ||
log_of_measurement_results={}, | ||
qid_shape=qid_shape, | ||
) | ||
with pytest.raises(TypeError, match="Can't simulate operations"): | ||
cirq.act_on(NoDetails(), args) |
Oops, something went wrong.