Skip to content

Commit

Permalink
Density Matrix Simulator ActOn migration (#3841)
Browse files Browse the repository at this point in the history
Fixes #3825

Sparse and Clifford simulators both have a proper ActOnXXXStateArgs where simulator state is maintained and acted upon by various operations. This was first added in #3019 with the act_on protocol.

DensityMatrix and MPS simulators were never updated with the new protocol, leaving the codebase inconsistent and the migration unfinished. This PR finishes the task for DensityMatrix. A separate PR for MPS simulator hopefully will follow.

Aditionally, this PR introduces an `ActOnArgs` base class to eliminate the code duplication and type checking that was required in `MeasurementGate` (it had grown to four `if isinstance` blocks, each with duplicate bitshift/logging code), and replace it all with a single function call back into the `args` object.
  • Loading branch information
daxfohl committed Mar 17, 2021
1 parent 8cf7825 commit 06e4892
Show file tree
Hide file tree
Showing 14 changed files with 395 additions and 199 deletions.
2 changes: 2 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,9 @@
)

from cirq.sim import (
ActOnArgs,
ActOnCliffordTableauArgs,
ActOnDensityMatrixArgs,
ActOnStabilizerCHFormArgs,
ActOnStateVectorArgs,
StabilizerStateChForm,
Expand Down
31 changes: 3 additions & 28 deletions cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING
from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING, List

import numpy as np

Expand Down Expand Up @@ -223,33 +223,8 @@ def _has_stabilizer_effect_(self) -> Optional[bool]:
def _act_on_(self, args: Any) -> bool:
from cirq import sim

if isinstance(args, sim.ActOnStateVectorArgs):

invert_mask = self.full_invert_mask()
bits, _ = sim.measure_state_vector(
args.target_tensor,
args.axes,
out=args.target_tensor,
qid_shape=args.target_tensor.shape,
seed=args.prng,
)
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)]
args.record_measurement_result(self.key, corrected)

return True

if isinstance(args, sim.clifford.ActOnCliffordTableauArgs):
invert_mask = self.full_invert_mask()
bits = [args.tableau._measure(q, args.prng) for q in args.axes]
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)]
args.record_measurement_result(self.key, corrected)
return True

if isinstance(args, sim.clifford.ActOnStabilizerCHFormArgs):
invert_mask = self.full_invert_mask()
bits = [args.state._measure(q, args.prng) for q in args.axes]
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)]
args.record_measurement_result(self.key, corrected)
if isinstance(args, sim.ActOnArgs):
args.measure(self.key, self.full_invert_mask())
return True

return NotImplemented
Expand Down
2 changes: 2 additions & 0 deletions cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@
'Heatmap',
'TwoQubitInteractionHeatmap',
# Intermediate states with work buffers and unknown external prng guts.
'ActOnArgs',
'ActOnCliffordTableauArgs',
'ActOnDensityMatrixArgs',
'ActOnStabilizerCHFormArgs',
'ActOnStateVectorArgs',
'ApplyChannelArgs',
Expand Down
8 changes: 8 additions & 0 deletions cirq/sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
"""Base simulation classes and generic simulators."""
from typing import Tuple, Dict

from cirq.sim.act_on_args import (
ActOnArgs,
)

from cirq.sim.act_on_density_matrix_args import (
ActOnDensityMatrixArgs,
)

from cirq.sim.act_on_state_vector_args import (
ActOnStateVectorArgs,
)
Expand Down
85 changes: 85 additions & 0 deletions cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objects and methods for acting efficiently on a state tensor."""
import abc
from typing import Any, Iterable, Dict, List

import numpy as np

from cirq import protocols
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits


class ActOnArgs:
"""State and context for an operation acting on a state tensor."""

def __init__(
self,
axes: Iterable[int],
prng: np.random.RandomState,
log_of_measurement_results: Dict[str, Any],
):
"""
Args:
axes: The indices of axes corresponding to the qubits that the
operation is supposed to act upon.
prng: The pseudo random number generator to use for probabilistic
effects.
log_of_measurement_results: A mutable object that measurements are
being recorded into. Edit it easily by calling
`ActOnStateVectorArgs.record_measurement_result`.
"""
self.axes = tuple(axes)
self.prng = prng
self.log_of_measurement_results = log_of_measurement_results

def measure(self, key, invert_mask):
"""Adds a measurement result to the log.
Args:
key: The key the measurement result should be logged under. Note
that operations should only store results under keys they have
declared in a `_measurement_keys_` method.
invert_mask: The invert mask for the measurement.
"""
bits = self._perform_measurement()
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)]
if key in self.log_of_measurement_results:
raise ValueError(f"Measurement already logged to key {key!r}")
self.log_of_measurement_results[key] = corrected

@abc.abstractmethod
def _perform_measurement(self) -> List[int]:
"""Child classes that perform measurements should implement this with
the implementation."""


def strat_act_on_from_apply_decompose(
val: Any,
args: ActOnArgs,
) -> bool:
operations, qubits, _ = _try_decompose_into_operations_and_qubits(val)
if operations is None:
return NotImplemented
assert len(qubits) == len(args.axes)
qubit_map = {q: args.axes[i] for i, q in enumerate(qubits)}

old_axes = args.axes
try:
for operation in operations:
args.axes = tuple(qubit_map[q] for q in operation.qubits)
protocols.act_on(operation, args)
finally:
args.axes = old_axes
return True
45 changes: 45 additions & 0 deletions cirq/sim/act_on_args_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import numpy as np

import cirq
from cirq.sim import act_on_args


def test_measurements():
class DummyArgs(cirq.ActOnArgs):
def _perform_measurement(self) -> List[int]:
return [5, 3]

args = DummyArgs(axes=[], prng=np.random.RandomState(), log_of_measurement_results={})
args.measure("test", [1])
assert args.log_of_measurement_results["test"] == [5]


def test_decompose():
class DummyArgs(cirq.ActOnArgs):
def _act_on_fallback_(self, action, allow_decompose):
return True

class Composite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
yield cirq.X(*qubits)

args = DummyArgs(axes=[0], prng=np.random.RandomState(), log_of_measurement_results={})
assert act_on_args.strat_act_on_from_apply_decompose(Composite(), args)
119 changes: 119 additions & 0 deletions cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objects and methods for acting efficiently on a density matrix."""

from typing import Any, Iterable, Dict, List, Tuple

import numpy as np

from cirq import protocols, sim
from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose


class ActOnDensityMatrixArgs(ActOnArgs):
"""State and context for an operation acting on a density matrix.
To act on this object, directly edit the `target_tensor` property, which is
storing the density matrix of the quantum system with one axis per qubit.
"""

def __init__(
self,
target_tensor: np.ndarray,
available_buffer: List[np.ndarray],
axes: Iterable[int],
qid_shape: Tuple[int, ...],
prng: np.random.RandomState,
log_of_measurement_results: Dict[str, Any],
):
"""
Args:
target_tensor: The state vector to act on, stored as a numpy array
with one dimension for each qubit in the system. Operations are
expected to perform inplace edits of this object.
available_buffer: A workspace with the same shape and dtype as
`target_tensor`. Used by operations that cannot be applied to
`target_tensor` inline, in order to avoid unnecessary
allocations.
axes: The indices of axes corresponding to the qubits that the
operation is supposed to act upon.
qid_shape: The shape of the target tensor.
prng: The pseudo random number generator to use for probabilistic
effects.
log_of_measurement_results: A mutable object that measurements are
being recorded into. Edit it easily by calling
`ActOnStateVectorArgs.record_measurement_result`.
"""
super().__init__(axes, prng, log_of_measurement_results)
self.target_tensor = target_tensor
self.available_buffer = available_buffer
self.qid_shape = qid_shape

def _act_on_fallback_(self, action: Any, allow_decompose: bool):
strats = [
_strat_apply_channel_to_state,
]
if allow_decompose:
strats.append(strat_act_on_from_apply_decompose) # type: ignore

# Try each strategy, stopping if one works.
for strat in strats:
result = strat(action, self)
if result is False:
break # coverage: ignore
if result is True:
return True
assert result is NotImplemented, str(result)
raise TypeError(
"Can't simulate operations that don't implement "
"SupportsUnitary, SupportsConsistentApplyUnitary, "
"SupportsMixture, SupportsChannel or is a measurement: {!r}".format(action)
)

def _perform_measurement(self) -> List[int]:
"""Delegates the call to measure the density matrix."""
bits, _ = sim.measure_density_matrix(
self.target_tensor,
self.axes,
out=self.target_tensor,
qid_shape=self.qid_shape,
seed=self.prng,
)
return bits


def _strat_apply_channel_to_state(
action: Any,
args: ActOnDensityMatrixArgs,
) -> bool:
"""Apply channel to state."""
result = protocols.apply_channel(
action,
args=protocols.ApplyChannelArgs(
target_tensor=args.target_tensor,
out_buffer=args.available_buffer[0],
auxiliary_buffer0=args.available_buffer[1],
auxiliary_buffer1=args.available_buffer[2],
left_axes=args.axes,
right_axes=[e + len(args.qid_shape) for e in args.axes],
),
default=None,
)
if result is None:
return NotImplemented
for i in range(3):
if result is args.available_buffer[i]:
args.available_buffer[i] = args.target_tensor
args.target_tensor = result
return True
65 changes: 65 additions & 0 deletions cirq/sim/act_on_density_matrix_args_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import cirq


def test_decomposed_fallback():
class Composite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
yield cirq.X(*qubits)

qid_shape = (2,)
tensor = cirq.to_valid_density_matrix(
0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64
)
args = cirq.ActOnDensityMatrixArgs(
target_tensor=tensor,
available_buffer=[np.empty_like(tensor) for _ in range(3)],
axes=[0],
prng=np.random.RandomState(),
log_of_measurement_results={},
qid_shape=qid_shape,
)

cirq.act_on(Composite(), args)
np.testing.assert_allclose(
args.target_tensor, cirq.one_hot(index=(1, 1), shape=(2, 2), dtype=np.complex64)
)


def test_cannot_act():
class NoDetails:
pass

qid_shape = (2,)
tensor = cirq.to_valid_density_matrix(
0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64
)
args = cirq.ActOnDensityMatrixArgs(
target_tensor=tensor,
available_buffer=[np.empty_like(tensor) for _ in range(3)],
axes=[0],
prng=np.random.RandomState(),
log_of_measurement_results={},
qid_shape=qid_shape,
)
with pytest.raises(TypeError, match="Can't simulate operations"):
cirq.act_on(NoDetails(), args)

0 comments on commit 06e4892

Please sign in to comment.