Skip to content

Commit

Permalink
Remove axes from ActOnArgs, pass qubits explicitly to act_on (#4089)
Browse files Browse the repository at this point in the history
* Remove axes from ActOnArgs, pass qubits explicitly

* split protocols

* require ActOnArgs to implement fallback

* lint

* Split the protocols

* Fix tests and coverage

* coverage

* format

* make param order consistent

* format

* add deprecation for axes

* v0.13

* lint

* readd axes with mypy ignore

* safe

* deprecate

* fix args len

* tests

* lint

* lint

* cover

* Change _act_on_qubits_ dunder back to _act_on_

* format

* unify act_on

* lint

* exception

* test

* format

* SupportsActOnQubits
  • Loading branch information
daxfohl committed Jun 18, 2021
1 parent 566ce2b commit 65711e1
Show file tree
Hide file tree
Showing 42 changed files with 870 additions and 347 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@
resolve_parameters_once,
SerializableByKey,
SupportsActOn,
SupportsActOnQubits,
SupportsApplyChannel,
SupportsApplyMixture,
SupportsApproximateEquality,
Expand Down
12 changes: 9 additions & 3 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import quimb.tensor as qtn

from cirq import devices, study, ops, protocols, value
from cirq._compat import deprecated_parameter
from cirq.sim import simulator, simulator_base
from cirq.sim.act_on_args import ActOnArgs

Expand Down Expand Up @@ -224,6 +225,12 @@ def sample(
class MPSState(ActOnArgs):
"""A state of the MPS simulation."""

@deprecated_parameter(
deadline='v0.13',
fix='No longer needed. `protocols.act_on` infers axes.',
parameter_desc='axes',
match=lambda args, kwargs: 'axes' in kwargs or len(args) > 6,
)
def __init__(
self,
qubits: Sequence['cirq.Qid'],
Expand Down Expand Up @@ -451,7 +458,7 @@ def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState):
raise ValueError('Can only handle 1 and 2 qubit operations')
return True

def _act_on_fallback_(self, op: Any, allow_decompose: bool):
def _act_on_fallback_(self, op: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool):
"""Delegates the action to self.apply_op"""
return self.apply_op(op, self.prng)

Expand Down Expand Up @@ -524,7 +531,6 @@ def perform_measurement(

return results

def _perform_measurement(self) -> List[int]:
def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Measures the axes specified by the simulator."""
qubits = [self.qubits[key] for key in self.axes]
return self.perform_measurement(qubits, self.prng)
3 changes: 1 addition & 2 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,7 @@ def test_state_act_on_args_initializer():
s = ccq.mps_simulator.MPSState(
qubits=(cirq.LineQubit(0),),
prng=np.random.RandomState(0),
axes=[2],
log_of_measurement_results={'test': 4},
)
assert s.axes == (2,)
assert s.qubits == (cirq.LineQubit(0),)
assert s.log_of_measurement_results == {'test': 4}
17 changes: 9 additions & 8 deletions cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,13 @@ def __str__(self) -> str:
return f"depolarize(p={self._p})"
return f"depolarize(p={self._p},n_qubits={self._n_qubits})"

def _act_on_(self, args: Any) -> bool:
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']) -> bool:
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if args.prng.random() < self._p:
gate = args.prng.choice([pauli_gates.X, pauli_gates.Y, pauli_gates.Z])
protocols.act_on(gate, args)
protocols.act_on(gate, args, qubits)
return True
return NotImplemented

Expand Down Expand Up @@ -720,29 +720,30 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
def _qid_shape_(self):
return (self._dimension,)

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq import sim, ops

if isinstance(args, sim.ActOnStabilizerCHFormArgs):
(axe,) = args.axes
axe = args.qubit_map[qubits[0]]
if args.state._measure(axe, args.prng):
ops.X._act_on_(args)
ops.X._act_on_(args, qubits)
return True

if isinstance(args, sim.ActOnStateVectorArgs):
# Do a silent measurement.
axes = args.get_axes(qubits)
measurements, _ = sim.measure_state_vector(
args.target_tensor,
args.axes,
axes,
out=args.target_tensor,
qid_shape=args.target_tensor.shape,
)
result = measurements[0]

# Use measurement result to zero the qid.
if result:
zero = args.subspace_index(0)
other = args.subspace_index(result)
zero = args.subspace_index(axes, 0)
other = args.subspace_index(axes, result)
args.target_tensor[zero] = args.target_tensor[other]
args.target_tensor[other] = 0

Expand Down
11 changes: 6 additions & 5 deletions cirq-core/cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest

import cirq
from cirq.protocols.act_on_protocol_test import DummyActOnArgs

X = np.array([[0, 1], [1, 0]])
Y = np.array([[0, -1j], [1j, 0]])
Expand Down Expand Up @@ -478,26 +479,26 @@ def test_reset_channel_text_diagram():

def test_reset_act_on():
with pytest.raises(TypeError, match="Failed to act"):
cirq.act_on(cirq.ResetChannel(), object())
cirq.act_on(cirq.ResetChannel(), DummyActOnArgs(), qubits=())

args = cirq.ActOnStateVectorArgs(
target_tensor=cirq.one_hot(
index=(1, 1, 1, 1, 1), shape=(2, 2, 2, 2, 2), dtype=np.complex64
),
available_buffer=np.empty(shape=(2, 2, 2, 2, 2)),
axes=[1],
qubits=cirq.LineQubit.range(5),
prng=np.random.RandomState(),
log_of_measurement_results={},
)

cirq.act_on(cirq.ResetChannel(), args)
cirq.act_on(cirq.ResetChannel(), args, [cirq.LineQubit(1)])
assert args.log_of_measurement_results == {}
np.testing.assert_allclose(
args.target_tensor,
cirq.one_hot(index=(1, 0, 1, 1, 1), shape=(2, 2, 2, 2, 2), dtype=np.complex64),
)

cirq.act_on(cirq.ResetChannel(), args)
cirq.act_on(cirq.ResetChannel(), args, [cirq.LineQubit(1)])
assert args.log_of_measurement_results == {}
np.testing.assert_allclose(
args.target_tensor,
Expand Down Expand Up @@ -693,7 +694,7 @@ def test_bit_flip_channel_text_diagram():
def test_stabilizer_supports_depolarize():
with pytest.raises(TypeError, match="act_on"):
for _ in range(100):
cirq.act_on(cirq.depolarize(3 / 4), object())
cirq.act_on(cirq.depolarize(3 / 4), DummyActOnArgs(), qubits=())

q = cirq.LineQubit(0)
c = cirq.Circuit(cirq.depolarize(3 / 4).on(q), cirq.measure(q, key='m'))
Expand Down
52 changes: 26 additions & 26 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@
"""


def _act_with_gates(args, *gates: 'cirq.SupportsActOn') -> None:
def _act_with_gates(args, qubits, *gates: 'cirq.SupportsActOnQubits') -> None:
"""Act on the given args with the given gates in order."""
for gate in gates:
assert gate._act_on_(args)
assert gate._act_on_(args, qubits)


def _pi(rads):
Expand Down Expand Up @@ -108,14 +108,14 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.available_buffer *= p
return args.available_buffer

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q = args.axes[0]
q = args.qubit_map[qubits[0]]
effective_exponent = self._exponent % 2
if effective_exponent == 0.5:
tableau.xs[:, q] ^= tableau.zs[:, q]
Expand All @@ -130,7 +130,7 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
_act_with_gates(args, H, ZPowGate(exponent=self._exponent), H)
_act_with_gates(args, qubits, H, ZPowGate(exponent=self._exponent), H)
# Adjust the global phase based on the global_shift parameter.
args.state.omega *= np.exp(1j * np.pi * self.global_shift * self.exponent)
return True
Expand Down Expand Up @@ -360,14 +360,14 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.available_buffer *= p
return args.available_buffer

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q = args.axes[0]
q = args.qubit_map[qubits[0]]
effective_exponent = self._exponent % 2
if effective_exponent == 0.5:
tableau.rs[:] ^= tableau.xs[:, q] & (~tableau.zs[:, q])
Expand All @@ -392,13 +392,13 @@ def _act_on_(self, args: Any):
state = args.state
Z = ZPowGate()
if effective_exponent == 0.5:
_act_with_gates(args, Z, H)
_act_with_gates(args, qubits, Z, H)
state.omega *= (1 + 1j) / (2 ** 0.5)
elif effective_exponent == 1:
_act_with_gates(args, Z, H, Z, H)
_act_with_gates(args, qubits, Z, H, Z, H)
state.omega *= 1j
elif effective_exponent == 1.5:
_act_with_gates(args, H, Z)
_act_with_gates(args, qubits, H, Z)
state.omega *= (1 - 1j) / (2 ** 0.5)
# Adjust the global phase based on the global_shift parameter.
args.state.omega *= np.exp(1j * np.pi * self.global_shift * self.exponent)
Expand Down Expand Up @@ -579,14 +579,14 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.target_tensor *= p
return args.target_tensor

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q = args.axes[0]
q = args.qubit_map[qubits[0]]
effective_exponent = self._exponent % 2
if effective_exponent == 0.5:
tableau.rs[:] ^= tableau.xs[:, q] & tableau.zs[:, q]
Expand All @@ -601,7 +601,7 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
q = args.axes[0]
q = args.qubit_map[qubits[0]]
effective_exponent = self._exponent % 2
state = args.state
for _ in range(int(effective_exponent * 2)):
Expand Down Expand Up @@ -896,14 +896,14 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.target_tensor *= np.sqrt(2) * p
return args.target_tensor

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q = args.axes[0]
q = args.qubit_map[qubits[0]]
if self._exponent % 2 == 1:
(tableau.xs[:, q], tableau.zs[:, q]) = (
tableau.zs[:, q].copy(),
Expand All @@ -915,7 +915,7 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
q = args.axes[0]
q = args.qubit_map[qubits[0]]
state = args.state
if self._exponent % 2 == 1:
# Prescription for H left multiplication
Expand Down Expand Up @@ -1059,15 +1059,15 @@ def _apply_unitary_(
args.target_tensor *= p
return args.target_tensor

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q1 = args.axes[0]
q2 = args.axes[1]
q1 = args.qubit_map[qubits[0]]
q2 = args.qubit_map[qubits[1]]
if self._exponent % 2 == 1:
(tableau.xs[:, q2], tableau.zs[:, q2]) = (
tableau.zs[:, q2].copy(),
Expand All @@ -1088,8 +1088,8 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
q1 = args.axes[0]
q2 = args.axes[1]
q1 = args.qubit_map[qubits[0]]
q2 = args.qubit_map[qubits[1]]
state = args.state
if self._exponent % 2 == 1:
# Prescription for CZ left multiplication.
Expand Down Expand Up @@ -1282,15 +1282,15 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
args.target_tensor *= p
return args.target_tensor

def _act_on_(self, args: Any):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
tableau = args.tableau
q1 = args.axes[0]
q2 = args.axes[1]
q1 = args.qubit_map[qubits[0]]
q2 = args.qubit_map[qubits[1]]
if self._exponent % 2 == 1:
tableau.rs[:] ^= (
tableau.xs[:, q1]
Expand All @@ -1304,8 +1304,8 @@ def _act_on_(self, args: Any):
if isinstance(args, clifford.ActOnStabilizerCHFormArgs):
if not protocols.has_stabilizer_effect(self):
return NotImplemented
q1 = args.axes[0]
q2 = args.axes[1]
q1 = args.qubit_map[qubits[0]]
q2 = args.qubit_map[qubits[1]]
state = args.state
if self._exponent % 2 == 1:
# Prescription for CX left multiplication.
Expand Down

0 comments on commit 65711e1

Please sign in to comment.