Skip to content

Commit

Permalink
Allow ActOnArgsContainer to support protocols.act_on (#4371)
Browse files Browse the repository at this point in the history
Fixes #4368

Previously ActOnArgsContainer was "the top" in the act_on stack. However this caused CircuitOperations to entangle all qubits they contain. Now ActOnArgsContainer is just another act_on target, and CircuitOperaiton.act_on pushes its sub-operations through the act_on protocol into the ActOnArgsContainer, so the container never entangles the whole subcircuit.

To note, this also represents something of a directional shift as we move toward classical flow control. Rather than the smarts living in the simulators, the smarts can live in the operations, such that the classical operations are given control of running their components. This approach isn't entirely new (stabilizer operations already use act_on to manage their own simulation within Clifford simulators), but specifically for classical control (of which subcircuits will be, themselves able to contain classical control operations), I anticipate this is the approach we will standardize on.
  • Loading branch information
daxfohl committed Aug 17, 2021
1 parent 9d615c5 commit cad0c33
Show file tree
Hide file tree
Showing 17 changed files with 240 additions and 50 deletions.
20 changes: 17 additions & 3 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,17 @@
applied as part of a larger circuit, a CircuitOperation will execute all
component operations in order, including any nested CircuitOperations.
"""

from typing import TYPE_CHECKING, AbstractSet, Callable, Dict, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
AbstractSet,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
Iterator,
)

import dataclasses
import numpy as np
Expand Down Expand Up @@ -245,9 +254,14 @@ def mapped_op(self, deep: bool = False) -> 'cirq.CircuitOperation':
"""As `mapped_circuit`, but wraps the result in a CircuitOperation."""
return CircuitOperation(circuit=self.mapped_circuit(deep=deep).freeze())

def _decompose_(self) -> 'cirq.OP_TREE':
def _decompose_(self) -> Iterator['cirq.Operation']:
return self.mapped_circuit(deep=False).all_operations()

def _act_on_(self, args: 'cirq.ActOnArgs') -> bool:
for op in self._decompose_():
protocols.act_on(op, args)
return True

# Methods for string representation of the operation.

def __repr__(self):
Expand Down
11 changes: 9 additions & 2 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,16 @@ def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState):
raise ValueError('Can only handle 1 and 2 qubit operations')
return True

def _act_on_fallback_(self, op: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool):
def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
allow_decompose: bool = True,
) -> bool:
"""Delegates the action to self.apply_op"""
return self.apply_op(op, self.prng)
if isinstance(action, ops.Gate):
action = ops.GateOperation(action, qubits)
return self.apply_op(action, self.prng)

def estimation_stats(self):
"""Returns some statistics about the memory usage and quality of the approximation."""
Expand Down
14 changes: 14 additions & 0 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,17 @@ def test_state_act_on_args_initializer():
)
assert s.qubits == (cirq.LineQubit(0),)
assert s.log_of_measurement_results == {'test': 4}


def test_act_on_gate():
args = ccq.mps_simulator.MPSState(
qubits=cirq.LineQubit.range(3),
prng=np.random.RandomState(0),
log_of_measurement_results={},
)

cirq.act_on(cirq.X, args, [cirq.LineQubit(1)])
np.testing.assert_allclose(
args.state_vector().reshape((2, 2, 2)),
cirq.one_hot(index=(0, 1, 0), shape=(2, 2, 2), dtype=np.complex64),
)
6 changes: 5 additions & 1 deletion cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,11 @@ def _from_json_dict_(cls, num_qubits, key, invert_mask, qid_shape=None, **kwargs
def _has_stabilizer_effect_(self) -> Optional[bool]:
return True

def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']) -> bool:
def _act_on_(self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid']) -> bool:
from cirq.sim import ActOnArgs

if not isinstance(args, ActOnArgs):
return NotImplemented
args.measure(qubits, self.key, self.full_invert_mask())
return True

Expand Down
13 changes: 8 additions & 5 deletions cirq-core/cirq/protocols/act_on_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Union, Sequence, Any
from typing import TYPE_CHECKING, Union, Sequence

from typing_extensions import Protocol

Expand Down Expand Up @@ -91,8 +91,8 @@ def _act_on_(
# TODO(#3388) Add documentation for Raises.
# pylint: disable=missing-raises-doc
def act_on(
action: Union['cirq.Operation', Any],
args: 'cirq.ActOnArgs',
action: Union['cirq.Operation', 'cirq.Gate'],
args: 'cirq.OperationTarget',
qubits: Sequence['cirq.Qid'] = None,
*,
allow_decompose: bool = True,
Expand Down Expand Up @@ -134,7 +134,10 @@ def act_on(

# todo: change to an exception after `args.axes` is deprecated.
if not is_op and qubits is None:
qubits = [args.qubits[i] for i in args.axes]
from cirq.sim import ActOnArgs

if isinstance(args, ActOnArgs):
qubits = [args.qubits[i] for i in args.axes]

action_act_on = getattr(action, '_act_on_', None)
if action_act_on is not None:
Expand All @@ -149,7 +152,7 @@ def act_on(

arg_fallback = getattr(args, '_act_on_fallback_', None)
if arg_fallback is not None:
qubits = action.qubits if is_op else qubits
qubits = action.qubits if isinstance(action, ops.Operation) else qubits
result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose)
if result is True:
return
Expand Down
16 changes: 13 additions & 3 deletions cirq-core/cirq/protocols/act_on_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple
from typing import Any, Tuple, Union, Sequence

import numpy as np
import pytest
Expand All @@ -34,7 +34,12 @@ def _perform_measurement(self, qubits):
def copy(self):
return DummyActOnArgs(self.fallback_result, self.measurements.copy()) # coverage: ignore

def _act_on_fallback_(self, action, qubits, allow_decompose):
def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
allow_decompose: bool = True,
):
return self.fallback_result

def sample(self, qubits, repetitions=1, seed=None):
Expand Down Expand Up @@ -80,7 +85,12 @@ def _act_on_(self, args):

def test_act_on_args_axes_deprecation():
class Args(DummyActOnArgs):
def _act_on_fallback_(self, action, qubits, allow_decompose):
def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'] = None,
allow_decompose: bool = True,
) -> bool:
self.measurements.append(qubits)
return True

Expand Down
8 changes: 0 additions & 8 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,6 @@ def create_merged_state(self: TSelf) -> TSelf:
"""Creates a final merged state."""
return self

def apply_operation(self, op: 'cirq.Operation'):
"""Applies the operation to the state."""
protocols.act_on(op, self)

def kronecker_product(self: TSelf, other: TSelf) -> TSelf:
"""Joins two state spaces together."""
raise NotImplementedError()
Expand Down Expand Up @@ -214,10 +210,6 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[Optional['cirq.Qid']]:
return iter(self.qubits)

@abc.abstractmethod
def _act_on_fallback_(self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool):
"""Handles the act_on protocol fallback implementation."""

@property # type: ignore
@deprecated(
deadline="v0.13",
Expand Down
30 changes: 18 additions & 12 deletions cirq-core/cirq/sim/act_on_args_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
Any,
Tuple,
List,
Union,
)

import numpy as np

from cirq import ops
from cirq import ops, protocols
from cirq.sim.operation_target import OperationTarget
from cirq.sim.simulator import (
TActOnArgs,
Expand Down Expand Up @@ -79,29 +80,32 @@ def create_merged_state(self) -> TActOnArgs:
final_args = final_args.kronecker_product(args)
return final_args.transpose_to_qubit_order(self.qubits)

def apply_operation(
def _act_on_fallback_(
self,
op: 'cirq.Operation',
):
gate = op.gate
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
allow_decompose: bool = True,
) -> bool:
gate = action.gate if isinstance(action, ops.Operation) else action

if isinstance(gate, ops.IdentityGate):
return
return True

if isinstance(gate, ops.SwapPowGate) and gate.exponent % 2 == 1 and gate.global_shift == 0:
q0, q1 = op.qubits
q0, q1 = qubits
args0 = self.args[q0]
args1 = self.args[q1]
if args0 is args1:
args0.swap(q0, q1, inplace=True)
else:
self.args[q0] = args1.rename(q1, q0, inplace=True)
self.args[q1] = args0.rename(q0, q1, inplace=True)
return
return True

# Go through the op's qubits and join any disparate ActOnArgs states
# into a new combined state.
op_args_opt: Optional[TActOnArgs] = None
for q in op.qubits:
for q in qubits:
if op_args_opt is None:
op_args_opt = self.args[q]
elif q not in op_args_opt.qubits:
Expand All @@ -113,19 +117,21 @@ def apply_operation(
self.args[q] = op_args

# Act on the args with the operation
op_args.apply_operation(op)
act_on_qubits = qubits if isinstance(action, ops.Gate) else None
protocols.act_on(action, op_args, act_on_qubits, allow_decompose=allow_decompose)

# Decouple any measurements or resets
if self.split_untangled_states and isinstance(
op.gate, (ops.MeasurementGate, ops.ResetChannel)
gate, (ops.MeasurementGate, ops.ResetChannel)
):
for q in op.qubits:
for q in qubits:
q_args, op_args = op_args.factor((q,), validate=False)
self.args[q] = q_args

# (Backfill the args map with the new value)
for q in op_args.qubits:
self.args[q] = op_args
return True

def copy(self) -> 'ActOnArgsContainer[TActOnArgs]':
logs = self.log_of_measurement_results.copy()
Expand Down
81 changes: 78 additions & 3 deletions cirq-core/cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict, Any, Sequence, Tuple, Optional
from typing import List, Dict, Any, Sequence, Tuple, Optional, Union

import cirq

Expand All @@ -32,7 +32,12 @@ def copy(self) -> 'EmptyActOnArgs':
logs=self.log_of_measurement_results.copy(),
)

def _act_on_fallback_(self, action: Any, qubits: Sequence[cirq.Qid], allow_decompose: bool):
def _act_on_fallback_(
self,
action: Union['cirq.Operation', 'cirq.Gate'],
qubits: Sequence['cirq.Qid'],
allow_decompose: bool = True,
) -> bool:
return True

def kronecker_product(self, other: 'EmptyActOnArgs') -> 'EmptyActOnArgs':
Expand Down Expand Up @@ -68,7 +73,8 @@ def sample(self, qubits, repetitions=1, seed=None):
pass


q0, q1 = qs2 = cirq.LineQubit.range(2)
q0, q1, q2 = qs3 = cirq.LineQubit.range(3)
qs2 = cirq.LineQubit.range(2)


def create_container(
Expand Down Expand Up @@ -98,6 +104,25 @@ def test_entanglement_causes_join():
assert args[None] is not args[q0]


def test_subcircuit_entanglement_causes_join():
args = create_container(qs2)
assert len(set(args.values())) == 3
args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1))))
assert len(set(args.values())) == 2
assert args[q0] is args[q1]


def test_subcircuit_entanglement_causes_join_in_subset():
args = create_container(qs3)
assert len(set(args.values())) == 4
args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1))))
assert len(set(args.values())) == 3
assert args[q0] is args[q1]
args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q2))))
assert len(set(args.values())) == 2
assert args[q0] is args[q1] is args[q2]


def test_identity_does_not_join():
args = create_container(qs2)
assert len(set(args.values())) == 3
Expand All @@ -107,6 +132,23 @@ def test_identity_does_not_join():
assert args[q0] is not args[None]


def test_identity_fallback_does_not_join():
args = create_container(qs2)
assert len(set(args.values())) == 3
args._act_on_fallback_(cirq.I, (q0, q1))
assert len(set(args.values())) == 3
assert args[q0] is not args[q1]
assert args[q0] is not args[None]


def test_subcircuit_identity_does_not_join():
args = create_container(qs2)
assert len(set(args.values())) == 3
args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.IdentityGate(2)(q0, q1))))
assert len(set(args.values())) == 3
assert args[q0] is not args[q1]


def test_measurement_causes_split():
args = create_container(qs2)
args.apply_operation(cirq.CNOT(q0, q1))
Expand All @@ -117,6 +159,30 @@ def test_measurement_causes_split():
assert args[q0] is not args[None]


def test_subcircuit_measurement_causes_split():
args = create_container(qs2)
args.apply_operation(cirq.CNOT(q0, q1))
assert len(set(args.values())) == 2
args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0))))
assert len(set(args.values())) == 3
assert args[q0] is not args[q1]


def test_subcircuit_measurement_causes_split_in_subset():
args = create_container(qs3)
args.apply_operation(cirq.CNOT(q0, q1))
args.apply_operation(cirq.CNOT(q0, q2))
assert len(set(args.values())) == 2
args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0))))
assert len(set(args.values())) == 3
assert args[q0] is not args[q1]
args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q1))))
assert len(set(args.values())) == 4
assert args[q0] is not args[q1]
assert args[q0] is not args[q2]
assert args[q1] is not args[q2]


def test_reset_causes_split():
args = create_container(qs2)
args.apply_operation(cirq.CNOT(q0, q1))
Expand Down Expand Up @@ -213,3 +279,12 @@ def test_swap_after_entangle_reorders():
assert len(set(args.values())) == 2
assert args[q0] is args[q1]
assert args[q0].qubits == (q1, q0)


def test_act_on_gate_does_not_join():
args = create_container(qs2)
assert len(set(args.values())) == 3
cirq.act_on(cirq.X, args, [q0])
assert len(set(args.values())) == 3
assert args[q0] is not args[q1]
assert args[q0] is not args[None]

0 comments on commit cad0c33

Please sign in to comment.