diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 69cedaa26b2..0672592530e 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -362,6 +362,8 @@ decompose_multi_controlled_x, decompose_multi_controlled_rotation, decompose_two_qubit_interaction_into_four_fsim_gates, + defer_measurements, + dephase_measurements, drop_empty_moments, drop_negligible_operations, eject_phased_paulis, diff --git a/cirq-core/cirq/ops/kraus_channel.py b/cirq-core/cirq/ops/kraus_channel.py index 6dadbd5872b..cf3feb3a638 100644 --- a/cirq-core/cirq/ops/kraus_channel.py +++ b/cirq-core/cirq/ops/kraus_channel.py @@ -55,7 +55,7 @@ def __init__( self._key = key @staticmethod - def from_channel(channel: 'KrausChannel', key: Union[str, 'cirq.MeasurementKey', None] = None): + def from_channel(channel: 'cirq.Gate', key: Union[str, 'cirq.MeasurementKey', None] = None): """Creates a copy of a channel with the given measurement key.""" return KrausChannel(kraus_ops=list(protocols.kraus(channel)), key=key) diff --git a/cirq-core/cirq/ops/tags.py b/cirq-core/cirq/ops/tags.py index fec819a88e0..f51d2962b51 100644 --- a/cirq-core/cirq/ops/tags.py +++ b/cirq-core/cirq/ops/tags.py @@ -34,3 +34,6 @@ def __repr__(self) -> str: def _json_dict_(self) -> Dict[str, str]: return {} + + def __hash__(self): + return hash(VirtualTag) diff --git a/cirq-core/cirq/sim/mux.py b/cirq-core/cirq/sim/mux.py index 8c70ebea46e..2095abe3113 100644 --- a/cirq-core/cirq/sim/mux.py +++ b/cirq-core/cirq/sim/mux.py @@ -25,6 +25,7 @@ from cirq._doc import document from cirq.sim import sparse_simulator, density_matrix_simulator from cirq.sim.clifford import clifford_simulator +from cirq.transformers import measurement_transformers if TYPE_CHECKING: import cirq @@ -281,9 +282,10 @@ def final_density_matrix( dtype=dtype, noise=noise, seed=seed, - ignore_measurement_results=(ignore_measurement_results), ).simulate( - program=circuit_like, + program=measurement_transformers.dephase_measurements(circuit_like) + if ignore_measurement_results + else circuit_like, initial_state=initial_state, qubit_order=qubit_order, param_resolver=param_resolver, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index d0a679f4aa4..0fa73cafc9f 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -53,6 +53,11 @@ from cirq.transformers.eject_z import eject_z +from cirq.transformers.measurement_transformers import ( + defer_measurements, + dephase_measurements, +) + from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements from cirq.transformers.transformer_api import ( diff --git a/cirq-core/cirq/transformers/measurement_transformers.py b/cirq-core/cirq/transformers/measurement_transformers.py new file mode 100644 index 00000000000..1c1323e0242 --- /dev/null +++ b/cirq-core/cirq/transformers/measurement_transformers.py @@ -0,0 +1,177 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union + +from cirq import ops, protocols, value +from cirq.transformers import ( + transformer_api, + transformer_primitives, +) +from cirq.transformers.synchronize_terminal_measurements import find_terminal_measurements + +if TYPE_CHECKING: + import cirq + + +class _MeasurementQid(ops.Qid): + """A qubit that substitutes in for a deferred measurement. + + Exactly one qubit will be created per qubit in the measurement gate. + """ + + def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid'): + """Initializes the qubit. + + Args: + key: The key of the measurement gate being deferred. + qid: One qubit that is being measured. Each deferred measurement + should create one new _MeasurementQid per qubit being measured + by that gate. + """ + self._key = value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key + self._qid = qid + + @property + def dimension(self) -> int: + return self._qid.dimension + + def _comparison_key(self) -> Any: + return (str(self._key), self._qid._comparison_key()) + + def __str__(self) -> str: + return f"M('{self._key}', q={self._qid})" + + def __repr__(self) -> str: + return f'_MeasurementQid({self._key!r}, {self._qid!r})' + + +@transformer_api.transformer +def defer_measurements( + circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None +) -> 'cirq.Circuit': + """Implements the Deferred Measurement Principle. + + Uses the Deferred Measurement Principle to move all measurements to the + end of the circuit. All non-terminal measurements are changed to + conditional quantum gates onto ancilla qubits, and classically controlled + operations are transformed to quantum controls from those ancilla qubits. + Finally, measurements of all ancilla qubits are appended to the end of the + circuit. + + Optimizing deferred measurements is an area of active research, and future + iterations may contain optimizations that reduce the number of ancilla + qubits, so one should not depend on the exact shape of the output from this + function. Only the logical equivalence is guaranteed to remain unchanged. + Moment and subcircuit structure is not preserved. + + Args: + circuit: The circuit to transform. It will not be modified. + context: `cirq.TransformerContext` storing common configurable options + for transformers. + Returns: + A circuit with equivalent logic, but all measurements at the end of the + circuit. + Raises: + ValueError: If sympy-based classical conditions are used, or if + conditions based on multi-qubit measurements exist. (The latter of + these is planned to be implemented soon). + """ + + circuit = transformer_primitives.unroll_circuit_op(circuit, deep=True, tags_to_check=None) + terminal_measurements = {op for _, op in find_terminal_measurements(circuit)} + measurement_qubits: Dict['cirq.MeasurementKey', List['_MeasurementQid']] = {} + + def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': + if op in terminal_measurements: + return op + gate = op.gate + if isinstance(gate, ops.MeasurementGate): + key = value.MeasurementKey.parse_serialized(gate.key) + targets = [_MeasurementQid(key, q) for q in op.qubits] + measurement_qubits[key] = targets + cxs = [ops.CX(q, target) for q, target in zip(op.qubits, targets)] + xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b] + return cxs + xs + elif protocols.is_measurement(op): + return [defer(op, None) for op in protocols.decompose_once(op)] + elif op.classical_controls: + controls = [] + for c in op.classical_controls: + if isinstance(c, value.KeyCondition): + if c.key not in measurement_qubits: + raise ValueError(f'Deferred measurement for key={c.key} not found.') + qubits = measurement_qubits[c.key] + if len(qubits) != 1: + # TODO: Multi-qubit conditions require + # https://github.com/quantumlib/Cirq/issues/4512 + # Remember to update docstring above once this works. + raise ValueError('Only single qubit conditions are allowed.') + controls.extend(qubits) + else: + raise ValueError('Only KeyConditions are allowed.') + return op.without_classical_controls().controlled_by( + *controls, control_values=[tuple(range(1, q.dimension)) for q in controls] + ) + return op + + circuit = transformer_primitives.map_operations_and_unroll( + circuit=circuit, + map_func=defer, + tags_to_ignore=context.tags_to_ignore if context else (), + raise_if_add_qubits=False, + ).unfreeze() + for k, qubits in measurement_qubits.items(): + circuit.append(ops.measure(*qubits, key=k)) + return circuit + + +@transformer_api.transformer +def dephase_measurements( + circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None +) -> 'cirq.Circuit': + """Changes all measurements to a dephase operation. + + This transformer is useful when using a density matrix simulator, when + wishing to calculate the final density matrix of a circuit and not simulate + the measurements themselves. + + Args: + circuit: The circuit to transform. It will not be modified. + context: `cirq.TransformerContext` storing common configurable options + for transformers. + Returns: + A copy of the circuit, with dephase operations in place of all + measurements. + Raises: + ValueError: If the circuit contains classical controls. In this case, + it is required to change these to quantum controls via + `cirq.defer_measurements` first. Since deferral adds ancilla qubits + to the circuit, this is not done automatically, to prevent + surprises. + """ + + def dephase(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': + gate = op.gate + if isinstance(gate, ops.MeasurementGate): + key = value.MeasurementKey.parse_serialized(gate.key) + return ops.KrausChannel.from_channel(ops.phase_damp(1), key=key).on_each(op.qubits) + elif isinstance(op, ops.ClassicallyControlledOperation): + raise ValueError('Use cirq.defer_measurements first to remove classical controls.') + return op + + ignored = () if context is None else context.tags_to_ignore + return transformer_primitives.map_operations( + circuit, dephase, deep=True, tags_to_ignore=ignored + ).unfreeze() diff --git a/cirq-core/cirq/transformers/measurement_transformers_test.py b/cirq-core/cirq/transformers/measurement_transformers_test.py new file mode 100644 index 00000000000..05a6f0a087d --- /dev/null +++ b/cirq-core/cirq/transformers/measurement_transformers_test.py @@ -0,0 +1,376 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import sympy + +import cirq +from cirq.transformers.measurement_transformers import _MeasurementQid + + +def assert_equivalent_to_deferred(circuit: cirq.Circuit): + qubits = list(circuit.all_qubits()) + sim = cirq.Simulator() + num_qubits = len(qubits) + for i in range(2 ** num_qubits): + bits = cirq.big_endian_int_to_bits(i, bit_count=num_qubits) + modified = cirq.Circuit() + for j in range(num_qubits): + if bits[j]: + modified.append(cirq.X(qubits[j])) + modified.append(circuit) + deferred = cirq.defer_measurements(modified) + result = sim.simulate(modified) + result1 = sim.simulate(deferred) + np.testing.assert_equal(result.measurements, result1.measurements) + + +def assert_equivalent_to_dephased(circuit: cirq.Circuit): + qubits = list(circuit.all_qubits()) + sim = cirq.DensityMatrixSimulator(ignore_measurement_results=True) + num_qubits = len(qubits) + backwards = list(circuit.all_operations())[::-1] + for j in range(num_qubits): + backwards.append(cirq.H(qubits[j]) ** np.random.rand()) + modified = cirq.Circuit(backwards[::-1]) + for j in range(num_qubits): + modified.append(cirq.H(qubits[j]) ** np.random.rand()) + dephased = cirq.dephase_measurements(modified) + result = sim.simulate(modified) + result1 = sim.simulate(dephased) + np.testing.assert_almost_equal(result.final_density_matrix, result1.final_density_matrix) + + +def test_basic(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + assert_equivalent_to_deferred(circuit) + deferred = cirq.defer_measurements(circuit) + q_ma = _MeasurementQid('a', q0) + cirq.testing.assert_same_circuits( + deferred, + cirq.Circuit( + cirq.CX(q0, q_ma), + cirq.CX(q_ma, q1), + cirq.measure(q_ma, key='a'), + cirq.measure(q1, key='b'), + ), + ) + + +def test_nocompile_context(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a').with_tags('nocompile'), + cirq.X(q1).with_classical_controls('a').with_tags('nocompile'), + cirq.measure(q1, key='b'), + ) + deferred = cirq.defer_measurements( + circuit, context=cirq.TransformerContext(tags_to_ignore=('nocompile',)) + ) + cirq.testing.assert_same_circuits(deferred, circuit) + + +def test_nocompile_context_leaves_invalid_circuit(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a').with_tags('nocompile'), + cirq.X(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + with pytest.raises(ValueError, match='Deferred measurement for key=a not found'): + _ = cirq.defer_measurements( + circuit, context=cirq.TransformerContext(tags_to_ignore=('nocompile',)) + ) + + +def test_pauli(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.PauliMeasurementGate(cirq.DensePauliString('Y'), key='a').on(q0), + cirq.X(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + deferred = cirq.defer_measurements(circuit) + q_ma = _MeasurementQid('a', q0) + cirq.testing.assert_same_circuits( + cirq.unroll_circuit_op(deferred), + cirq.Circuit( + cirq.SingleQubitCliffordGate.X_sqrt(q0), + cirq.CX(q0, q_ma), + (cirq.SingleQubitCliffordGate.X_sqrt(q0) ** -1), + cirq.Moment(cirq.CX(q_ma, q1)), + cirq.measure(q_ma, key='a'), + cirq.measure(q1, key='b'), + ), + ) + + +def test_extra_measurements(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q0, key='b'), + cirq.X(q1).with_classical_controls('a'), + cirq.measure(q1, key='c'), + ) + assert_equivalent_to_deferred(circuit) + deferred = cirq.defer_measurements(circuit) + q_ma = _MeasurementQid('a', q0) + cirq.testing.assert_same_circuits( + deferred, + cirq.Circuit( + cirq.CX(q0, q_ma), + cirq.CX(q_ma, q1), + cirq.measure(q_ma, key='a'), + cirq.measure(q0, key='b'), + cirq.measure(q1, key='c'), + ), + ) + + +def test_extra_controlled_bits(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.CX(q0, q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + assert_equivalent_to_deferred(circuit) + deferred = cirq.defer_measurements(circuit) + q_ma = _MeasurementQid('a', q0) + cirq.testing.assert_same_circuits( + deferred, + cirq.Circuit( + cirq.CX(q0, q_ma), + cirq.CCX(q_ma, q0, q1), + cirq.measure(q_ma, key='a'), + cirq.measure(q1, key='b'), + ), + ) + + +def test_extra_control_bits(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q0, key='b'), + cirq.X(q1).with_classical_controls('a', 'b'), + cirq.measure(q1, key='c'), + ) + assert_equivalent_to_deferred(circuit) + deferred = cirq.defer_measurements(circuit) + q_ma = _MeasurementQid('a', q0) + q_mb = _MeasurementQid('b', q0) + cirq.testing.assert_same_circuits( + deferred, + cirq.Circuit( + cirq.CX(q0, q_ma), + cirq.CX(q0, q_mb), + cirq.CCX(q_ma, q_mb, q1), + cirq.measure(q_ma, key='a'), + cirq.measure(q_mb, key='b'), + cirq.measure(q1, key='c'), + ), + ) + + +def test_subcircuit(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + ) + ) + assert_equivalent_to_deferred(circuit) + deferred = cirq.defer_measurements(circuit) + q_m = _MeasurementQid('a', q0) + cirq.testing.assert_same_circuits( + deferred, + cirq.Circuit( + cirq.CX(q0, q_m), + cirq.CX(q_m, q1), + cirq.measure(q_m, key='a'), + cirq.measure(q1, key='b'), + ), + ) + + +def test_multi_qubit_measurements(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, q1, key='a'), + cirq.X(q0), + cirq.measure(q0, key='b'), + cirq.measure(q1, key='c'), + ) + assert_equivalent_to_deferred(circuit) + deferred = cirq.defer_measurements(circuit) + q_ma0 = _MeasurementQid('a', q0) + q_ma1 = _MeasurementQid('a', q1) + cirq.testing.assert_same_circuits( + deferred, + cirq.Circuit( + cirq.CX(q0, q_ma0), + cirq.CX(q1, q_ma1), + cirq.X(q0), + cirq.measure(q_ma0, q_ma1, key='a'), + cirq.measure(q0, key='b'), + cirq.measure(q1, key='c'), + ), + ) + + +def test_diagram(): + q0, q1, q2, q3 = cirq.LineQubit.range(4) + circuit = cirq.Circuit( + cirq.measure(q0, q2, key='a'), + cirq.measure(q1, q3, key='b'), + cirq.X(q0), + cirq.measure(q0, q1, q2, q3, key='c'), + ) + deferred = cirq.defer_measurements(circuit) + cirq.testing.assert_has_diagram( + deferred, + """ + ┌────┐ +0: ──────────────@───────X────────M('c')─── + │ │ +1: ──────────────┼─@──────────────M──────── + │ │ │ +2: ──────────────┼@┼──────────────M──────── + │││ │ +3: ──────────────┼┼┼@─────────────M──────── + ││││ +M('a', q=0): ────X┼┼┼────M('a')──────────── + │││ │ +M('a', q=2): ─────X┼┼────M───────────────── + ││ +M('b', q=1): ──────X┼────M('b')──────────── + │ │ +M('b', q=3): ───────X────M───────────────── + └────┘ +""", + use_unicode_characters=True, + ) + + +def test_repr(): + def test_repr(qid: _MeasurementQid): + cirq.testing.assert_equivalent_repr(qid, global_vals={'_MeasurementQid': _MeasurementQid}) + + test_repr(_MeasurementQid('a', cirq.LineQubit(0))) + test_repr(_MeasurementQid('a', cirq.NamedQubit('x'))) + test_repr(_MeasurementQid('a', cirq.NamedQid('x', 4))) + test_repr(_MeasurementQid('a', cirq.GridQubit(2, 3))) + test_repr(_MeasurementQid('0:1:a', cirq.LineQid(9, 4))) + + +def test_multi_qubit_control(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, q1, key='a'), + cirq.X(q1).with_classical_controls('a'), + ) + with pytest.raises(ValueError, match='Only single qubit conditions are allowed'): + _ = cirq.defer_measurements(circuit) + + +def test_sympy_control(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, q1, key='a'), + cirq.X(q1).with_classical_controls(sympy.Symbol('a')), + ) + with pytest.raises(ValueError, match='Only KeyConditions are allowed'): + _ = cirq.defer_measurements(circuit) + + +def test_dephase(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.CX(q1, q0), + cirq.measure(q0, key='a'), + cirq.CX(q0, q1), + cirq.measure(q1, key='b'), + ) + ) + ) + assert_equivalent_to_dephased(circuit) + dephased = cirq.dephase_measurements(circuit) + cirq.testing.assert_same_circuits( + dephased, + cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.CX(q1, q0), + cirq.KrausChannel.from_channel(cirq.phase_damp(1), key='a')(q0), + cirq.CX(q0, q1), + cirq.KrausChannel.from_channel(cirq.phase_damp(1), key='b')(q1), + ) + ) + ), + ) + + +def test_dephase_classical_conditions(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.X(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + with pytest.raises(ValueError, match='defer_measurements first to remove classical controls'): + _ = cirq.dephase_measurements(circuit) + + +def test_dephase_nocompile_context(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.CX(q1, q0), + cirq.measure(q0, key='a').with_tags('nocompile'), + cirq.CX(q0, q1), + cirq.measure(q1, key='b'), + ) + ) + ) + dephased = cirq.dephase_measurements( + circuit, context=cirq.TransformerContext(tags_to_ignore=('nocompile',)) + ) + cirq.testing.assert_same_circuits( + dephased, + cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.CX(q1, q0), + cirq.measure(q0, key='a').with_tags('nocompile'), + cirq.CX(q0, q1), + cirq.KrausChannel.from_channel(cirq.phase_damp(1), key='b')(q1), + ) + ) + ), + ) diff --git a/cirq-core/cirq/transformers/synchronize_terminal_measurements.py b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py index 6e5fecbf663..f96f946009f 100644 --- a/cirq-core/cirq/transformers/synchronize_terminal_measurements.py +++ b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py @@ -44,7 +44,6 @@ def find_terminal_measurements( moment = circuit[i] for q in open_qubits: op = moment.operation_at(q) - seen_control_keys |= protocols.control_keys(op) if ( op is not None and open_qubits.issuperset(op.qubits) @@ -53,6 +52,7 @@ def find_terminal_measurements( ): terminal_measurements.append((i, op)) open_qubits -= moment.qubits + seen_control_keys |= protocols.control_keys(moment) if not open_qubits: break return terminal_measurements