Skip to content

Commit

Permalink
Adds PauliMeasurementGate (#4444)
Browse files Browse the repository at this point in the history
* Add PauliMeasurementGate

* Fix lint errors

* Address comments

* fix lint

* change measure_* names to measure_single_paulistring and measure_paulistring_terms
  • Loading branch information
tanujkhattar committed Sep 7, 2021
1 parent 9349802 commit 676431d
Show file tree
Hide file tree
Showing 9 changed files with 475 additions and 1 deletion.
3 changes: 3 additions & 0 deletions cirq-core/cirq/__init__.py
Expand Up @@ -232,6 +232,8 @@
MixedUnitaryChannel,
measure,
measure_each,
measure_paulistring_terms,
measure_single_paulistring,
MeasurementGate,
Moment,
MutableDensePauliString,
Expand All @@ -247,6 +249,7 @@
PAULI_GATE_LIKE,
PAULI_STRING_LIKE,
PauliInteractionGate,
PauliMeasurementGate,
PauliString,
PauliStringGateOperation,
PauliStringPhasor,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Expand Up @@ -114,6 +114,7 @@ def two_qubit_matrix_gate(matrix):
'ParamResolver': cirq.ParamResolver,
'ParallelGateOperation': cirq.ParallelGateOperation,
'ParallelGate': cirq.ParallelGate,
'PauliMeasurementGate': cirq.PauliMeasurementGate,
'PauliString': cirq.PauliString,
'PhaseDampingChannel': cirq.PhaseDampingChannel,
'PhaseFlipChannel': cirq.PhaseFlipChannel,
Expand Down
6 changes: 6 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Expand Up @@ -140,6 +140,10 @@
PauliSumExponential,
)

from cirq.ops.pauli_measurement_gate import (
PauliMeasurementGate,
)

from cirq.ops.parallel_gate import ParallelGate, parallel_gate_op

from cirq.ops.parallel_gate_operation import (
Expand Down Expand Up @@ -169,6 +173,8 @@
from cirq.ops.measure_util import (
measure,
measure_each,
measure_paulistring_terms,
measure_single_paulistring,
)

from cirq.ops.measurement_gate import (
Expand Down
55 changes: 54 additions & 1 deletion cirq-core/cirq/ops/measure_util.py
Expand Up @@ -17,8 +17,9 @@
import numpy as np

from cirq import protocols, value
from cirq.ops import raw_types
from cirq.ops import raw_types, pauli_string
from cirq.ops.measurement_gate import MeasurementGate
from cirq.ops.pauli_measurement_gate import PauliMeasurementGate

if TYPE_CHECKING:
import cirq
Expand All @@ -28,6 +29,58 @@ def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
return ','.join(str(q) for q in qubits)


def measure_single_paulistring(
pauli_observable: pauli_string.PauliString,
key: Optional[Union[str, value.MeasurementKey]] = None,
) -> raw_types.Operation:
"""Returns a single PauliMeasurementGate which measures the pauli observable
Args:
pauli_observable: The `cirq.PauliString` observable to measure.
key: Optional `str` or `cirq.MeasurementKey` that gate should use.
If none provided, it defaults to a comma-separated list of the
target qubits' str values.
Returns:
An operation measuring the pauli observable.
Raises:
ValueError: if the observable is not an instance of PauliString.
"""
if not isinstance(pauli_observable, pauli_string.PauliString):
raise ValueError(
f'Pauli observable {pauli_observable} should be an instance of cirq.PauliString.'
)
if key is None:
key = _default_measurement_key(pauli_observable)
return PauliMeasurementGate(pauli_observable.values(), key).on(*pauli_observable.keys())


def measure_paulistring_terms(
pauli_basis: pauli_string.PauliString, key_func: Callable[[raw_types.Qid], str] = str
) -> List[raw_types.Operation]:
"""Returns a list of operations individually measuring qubits in the pauli basis.
Args:
pauli_basis: The `cirq.PauliString` basis in which each qubit should
be measured.
key_func: Determines the key of the measurements of each qubit. Takes
the qubit and returns the key for that qubit. Defaults to str.
Returns:
A list of operations individually measuring the given qubits in the
specified pauli basis.
Raises:
ValueError: if `pauli_basis` is not an instance of `cirq.PauliString`.
"""
if not isinstance(pauli_basis, pauli_string.PauliString):
raise ValueError(
f'Pauli observable {pauli_basis} should be an instance of cirq.PauliString.'
)
return [PauliMeasurementGate([pauli_basis[q]], key=key_func(q)).on(q) for q in pauli_basis]


def measure(
*target: 'cirq.Qid',
key: Optional[Union[str, value.MeasurementKey]] = None,
Expand Down
36 changes: 36 additions & 0 deletions cirq-core/cirq/ops/measure_util_test.py
Expand Up @@ -60,3 +60,39 @@ def test_measure_each():
cirq.measure(a, key='a!'),
cirq.measure(b, key='b!'),
]


def test_measure_single_paulistring():
# Correct application
q = cirq.LineQubit.range(3)
ps = cirq.X(q[0]) * cirq.Y(q[1]) * cirq.Z(q[2])
assert cirq.measure_single_paulistring(ps, key='a') == cirq.PauliMeasurementGate(
ps.values(), key='a'
).on(*ps.keys())

# Empty application
with pytest.raises(ValueError, match='should be an instance of cirq.PauliString'):
_ = cirq.measure_single_paulistring(cirq.I(q[0]) * cirq.I(q[1]))

# Wrong type
with pytest.raises(ValueError, match='should be an instance of cirq.PauliString'):
_ = cirq.measure_single_paulistring(q)


def test_measure_paulistring_terms():
# Correct application
q = cirq.LineQubit.range(3)
ps = cirq.X(q[0]) * cirq.Y(q[1]) * cirq.Z(q[2])
assert cirq.measure_paulistring_terms(ps) == [
cirq.PauliMeasurementGate([cirq.X], key=str(q[0])).on(q[0]),
cirq.PauliMeasurementGate([cirq.Y], key=str(q[1])).on(q[1]),
cirq.PauliMeasurementGate([cirq.Z], key=str(q[2])).on(q[2]),
]

# Empty application
with pytest.raises(ValueError, match='should be an instance of cirq.PauliString'):
_ = cirq.measure_paulistring_terms(cirq.I(q[0]) * cirq.I(q[1]))

# Wrong type
with pytest.raises(ValueError, match='should be an instance of cirq.PauliString'):
_ = cirq.measure_paulistring_terms(q)
156 changes: 156 additions & 0 deletions cirq-core/cirq/ops/pauli_measurement_gate.py
@@ -0,0 +1,156 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, Tuple, Sequence, TYPE_CHECKING, Union


from cirq import protocols, value
from cirq.ops import (
raw_types,
measurement_gate,
op_tree,
dense_pauli_string,
pauli_gates,
pauli_string_phasor,
)

if TYPE_CHECKING:
import cirq


@value.value_equality
class PauliMeasurementGate(raw_types.Gate):
"""A gate that measures a Pauli observable.
PauliMeasurementGate contains a key used to identify results of measurement
and a list of Paulis which denote the observable to be measured.
"""

def __init__(
self,
observable: Iterable['cirq.Pauli'],
key: Union[str, value.MeasurementKey] = '',
) -> None:
"""Inits PauliMeasurementGate.
Args:
observable: Pauli observable to measure. Any `Iterable[cirq.Pauli]`
is a valid Pauli observable, including `cirq.DensePauliString`
instances, which do not contain any identity gates.
key: The string key of the measurement.
Raises:
ValueError: If the observable is empty.
"""
if not observable:
raise ValueError(f'Pauli observable {observable} is empty.')
if not all(isinstance(p, pauli_gates.Pauli) for p in observable):
raise ValueError(f'Pauli observable {observable} must be Iterable[`cirq.Pauli`].')
self._observable = tuple(observable)
self.key = key # type: ignore

@property
def key(self) -> str:
return str(self.mkey)

@key.setter
def key(self, key: Union[str, value.MeasurementKey]) -> None:
if isinstance(key, str):
key = value.MeasurementKey(name=key)
self.mkey = key

def _qid_shape_(self) -> Tuple[int, ...]:
return (2,) * len(self._observable)

def with_key(self, key: Union[str, value.MeasurementKey]) -> 'PauliMeasurementGate':
"""Creates a pauli measurement gate with a new key but otherwise identical."""
if key == self.key:
return self
return PauliMeasurementGate(self._observable, key=key)

def _with_key_path_(self, path: Tuple[str, ...]) -> 'PauliMeasurementGate':
return self.with_key(self.mkey._with_key_path_(path))

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'PauliMeasurementGate':
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))

def with_observable(self, observable: Iterable['cirq.Pauli']) -> 'PauliMeasurementGate':
"""Creates a pauli measurement gate with the new observable and same key."""
if tuple(observable) == self._observable:
return self
return PauliMeasurementGate(observable, key=self.key)

def _is_measurement_(self) -> bool:
return True

def _measurement_key_name_(self) -> str:
return self.key

def observable(self) -> 'cirq.DensePauliString':
"""Pauli observable which should be measured by the gate."""
return dense_pauli_string.DensePauliString(self._observable)

def _decompose_(
self, qubits: Tuple['cirq.Qid', ...]
) -> 'protocols.decompose_protocol.DecomposeResult':
any_qubit = qubits[0]
to_z_ops = op_tree.freeze_op_tree(self.observable().on(*qubits).to_z_basis_ops())
xor_decomp = tuple(pauli_string_phasor.xor_nonlocal_decompose(qubits, any_qubit))
yield to_z_ops
yield xor_decomp
yield measurement_gate.MeasurementGate(1, self.mkey).on(any_qubit)
yield protocols.inverse(xor_decomp)
yield protocols.inverse(to_z_ops)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
symbols = [f'M({g})' for g in self._observable]

# Mention the measurement key.
if not args.known_qubits or self.key != _default_measurement_key(args.known_qubits):
symbols[0] += f"('{self.key}')"

return protocols.CircuitDiagramInfo(tuple(symbols))

def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
args = [repr(self.observable().on(*qubits))]
if self.key != _default_measurement_key(qubits):
args.append(f'key={self.mkey!r}')
arg_list = ', '.join(args)
return f'cirq.measure_single_paulistring({arg_list})'

def __repr__(self) -> str:
return f'cirq.PauliMeasurementGate(' f'{self._observable!r}, ' f'{self.mkey!r})'

def _value_equality_values_(self) -> Any:
return self.key, self._observable

def _json_dict_(self) -> Dict[str, Any]:
return {
'cirq_type': self.__class__.__name__,
'observable': self._observable,
'key': self.key,
}

@classmethod
def _from_json_dict_(cls, observable, key, **kwargs) -> 'PauliMeasurementGate':
return cls(
observable=observable,
key=value.MeasurementKey.parse_serialized(key),
)


def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
return ','.join(str(q) for q in qubits)

0 comments on commit 676431d

Please sign in to comment.