Skip to content

Commit

Permalink
Support qubits in MeasurementKey and update JSON serialization of k…
Browse files Browse the repository at this point in the history
…eys (#4123)

* Phase 2b

* Fix inconsistencies in qubit keys

* Always compare mkeys

* Add test to compare cirq.measures
  • Loading branch information
smitsanghavi committed Jun 22, 2021
1 parent 7d9e603 commit 81436fe
Show file tree
Hide file tree
Showing 26 changed files with 889 additions and 74 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@
canonicalize_half_turns,
chosen_angle_to_canonical_half_turns,
chosen_angle_to_half_turns,
default_measurement_key_str,
Duration,
DURATION_LIKE,
GenericMetaImplementAnyOneOf,
Expand Down
12 changes: 10 additions & 2 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,8 +1366,16 @@ def test_findall_operations_with_gate(circuit_cls):
(3, cirq.CZ(a, b), cirq.CZ),
]
assert list(c.findall_operations_with_gate_type(cirq.MeasurementGate)) == [
(4, cirq.MeasurementGate(1, key='a').on(a), cirq.MeasurementGate(1, key='a')),
(4, cirq.MeasurementGate(1, key='b').on(b), cirq.MeasurementGate(1, key='b')),
(
4,
cirq.MeasurementGate(1).on(a),
cirq.MeasurementGate(1, key=cirq.MeasurementKey(qubits=(a,))),
),
(
4,
cirq.MeasurementGate(1).on(b),
cirq.MeasurementGate(1, key=cirq.MeasurementKey(qubits=(b,))),
),
]


Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def __str__(self) -> str:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['gate', 'qubits'])

@classmethod
def _from_json_dict_(cls, gate, qubits, **kwargs):
return gate.on(*qubits)

def _group_interchangeable_qubits(
self,
) -> Tuple[Union['cirq.Qid', Tuple[int, FrozenSet['cirq.Qid']]], ...]:
Expand Down
18 changes: 7 additions & 11 deletions cirq-core/cirq/ops/measure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Callable, List, Tuple, TYPE_CHECKING, Union

import numpy as np

Expand All @@ -24,13 +24,9 @@
import cirq


def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
return ','.join(str(q) for q in qubits)


def measure(
*target: 'cirq.Qid',
key: Optional[Union[str, value.MeasurementKey]] = None,
key: Union[str, value.MeasurementKey] = '',
invert_mask: Tuple[bool, ...] = (),
) -> raw_types.Operation:
"""Returns a single MeasurementGate applied to all the given qubits.
Expand All @@ -39,7 +35,7 @@ def measure(
Args:
*target: The qubits that the measurement gate should measure.
key: The string key of the measurement. If this is None, it defaults
key: The string key of the measurement. If this is empty, defaults
to a comma-separated list of the target qubits' str values.
invert_mask: A list of Truthy or Falsey values indicating whether
the corresponding qubits should be flipped. None indicates no
Expand All @@ -60,14 +56,12 @@ def measure(
elif not isinstance(qubit, raw_types.Qid):
raise ValueError('measure() was called with type different than Qid.')

if key is None:
key = _default_measurement_key(target)
qid_shape = protocols.qid_shape(target)
return MeasurementGate(len(target), key, invert_mask, qid_shape).on(*target)


def measure_each(
*qubits: 'cirq.Qid', key_func: Callable[[raw_types.Qid], str] = str
*qubits: 'cirq.Qid', key_func: Callable[[raw_types.Qid], str] = lambda x: ''
) -> List[raw_types.Operation]:
"""Returns a list of operations individually measuring the given qubits.
Expand All @@ -76,7 +70,9 @@ def measure_each(
Args:
*qubits: The qubits to measure.
key_func: Determines the key of the measurements of each qubit. Takes
the qubit and returns the key for that qubit. Defaults to str.
the qubit and returns the key for that qubit. Defaults to empty string
key which would result in a comma-separated list of the target qubits'
str values.
Returns:
A list of operations individually measuring the given qubits.
Expand Down
18 changes: 14 additions & 4 deletions cirq-core/cirq/ops/measure_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,22 @@ def test_measure_qubits():
with pytest.raises(ValueError, match='empty set of qubits'):
_ = cirq.measure()

assert cirq.measure(a) == cirq.MeasurementGate(num_qubits=1, key='a').on(a)
assert cirq.measure(a, b) == cirq.MeasurementGate(num_qubits=2, key='a,b').on(a, b)
assert cirq.measure(b, a) == cirq.MeasurementGate(num_qubits=2, key='b,a').on(b, a)
assert cirq.measure(a) == cirq.MeasurementGate(num_qubits=1, key='').on(a)
assert cirq.measure(a) != cirq.measure(a, key='a')
assert cirq.measure(a) != cirq.MeasurementGate(num_qubits=1, key='').on(b)
assert cirq.measure(a) != cirq.MeasurementGate(num_qubits=1, key='a').on(a)
assert cirq.measure(a) == cirq.MeasurementGate(
num_qubits=1, key=cirq.MeasurementKey(qubits=(a,))
).on(a)
assert cirq.measurement_key(cirq.measure(a)) == 'a'
assert cirq.measure(a, b) == cirq.MeasurementGate(num_qubits=2, key='').on(a, b)
assert cirq.measurement_key(cirq.measure(a, b)) == 'a,b'
assert cirq.measure(b, a) == cirq.MeasurementGate(num_qubits=2, key='').on(b, a)
assert cirq.measurement_key(cirq.measure(b, a)) == 'b,a'
assert cirq.measure(a, key='b') == cirq.MeasurementGate(num_qubits=1, key='b').on(a)
assert cirq.measure(a, key='b') != cirq.MeasurementGate(num_qubits=1, key='b').on(b)
assert cirq.measure(a, invert_mask=(True,)) == cirq.MeasurementGate(
num_qubits=1, key='a', invert_mask=(True,)
num_qubits=1, invert_mask=(True,)
).on(a)
assert cirq.measure(*cirq.LineQid.for_qid_shape((1, 2, 3)), key='a') == cirq.MeasurementGate(
num_qubits=3, key='a', qid_shape=(1, 2, 3)
Expand Down
42 changes: 26 additions & 16 deletions cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING, Union
from typing import Any, Dict, Optional, Tuple, Sequence, TYPE_CHECKING, Union

import numpy as np

from cirq import protocols, value
from cirq.ops import raw_types
from cirq.ops import raw_types, gate_operation

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -80,12 +80,21 @@ def key(self, key: Union[str, value.MeasurementKey]):
else:
self.mkey = value.MeasurementKey(name=key)

def on(self, *qubits: raw_types.Qid) -> raw_types.Operation:
"""Returns an application of this gate to the given qubits.
Args:
*qubits: The collection of qubits to potentially apply the gate to.
"""
maybe_rekeyed_gate = self.with_key(self.mkey.with_qubits(qubits))
return gate_operation.GateOperation(maybe_rekeyed_gate, list(qubits))

def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def with_key(self, key: Union[str, value.MeasurementKey]) -> 'MeasurementGate':
"""Creates a measurement gate with a new key but otherwise identical."""
if key == self.key:
if isinstance(key, value.MeasurementKey) and key == self.mkey:
return self
return MeasurementGate(
self.num_qubits(), key=key, invert_mask=self.invert_mask, qid_shape=self._qid_shape
Expand All @@ -105,7 +114,7 @@ def with_bits_flipped(self, *bit_positions: int) -> 'MeasurementGate':
for b in bit_positions:
new_mask[b] = not new_mask[b]
return MeasurementGate(
self.num_qubits(), key=self.key, invert_mask=tuple(new_mask), qid_shape=self._qid_shape
self.num_qubits(), key=self.mkey, invert_mask=tuple(new_mask), qid_shape=self._qid_shape
)

def full_invert_mask(self):
Expand Down Expand Up @@ -152,8 +161,8 @@ def _circuit_diagram_info_(
if b:
symbols[i] = '!M'

# Mention the measurement key.
if not args.known_qubits or self.key != _default_measurement_key(args.known_qubits):
# Mention the measurement key if it is non-trivial or there are no known qubits.
if self.mkey.name or self.mkey.path or not args.known_qubits:
symbols[0] += f"('{self.key}')"

return protocols.CircuitDiagramInfo(tuple(symbols))
Expand Down Expand Up @@ -191,8 +200,13 @@ def _quil_(

def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
args = list(repr(q) for q in qubits)
if self.key != _default_measurement_key(qubits):
args.append(f'key={self.key!r}')
if self.mkey.name or self.mkey.path:
if self.mkey == self.mkey.name:
args.append(f'key={self.mkey.name!r}')
else:
# Remove qubits from the `MeasurementKey` representation since we already have
# qubits from the op.
args.append(f'key={self.mkey.with_qubits(tuple())!r}')
if self.invert_mask:
args.append(f'invert_mask={self.invert_mask!r}')
arg_list = ', '.join(args)
Expand All @@ -205,13 +219,13 @@ def __repr__(self):
return (
f'cirq.MeasurementGate('
f'{self.num_qubits()!r}, '
f'{self.key!r}, '
f'{self.mkey.name if self.mkey == self.mkey.name else self.mkey!r}, '
f'{self.invert_mask}'
f'{qid_shape_arg})'
)

def _value_equality_values_(self) -> Any:
return self.key, self.invert_mask, self._qid_shape
return self.mkey, self.invert_mask, self._qid_shape

def _json_dict_(self) -> Dict[str, Any]:
other = {}
Expand All @@ -220,7 +234,7 @@ def _json_dict_(self) -> Dict[str, Any]:
return {
'cirq_type': self.__class__.__name__,
'num_qubits': len(self._qid_shape),
'key': self.key,
'key': self.mkey,
'invert_mask': self.invert_mask,
**other,
}
Expand All @@ -229,7 +243,7 @@ def _json_dict_(self) -> Dict[str, Any]:
def _from_json_dict_(cls, num_qubits, key, invert_mask, qid_shape=None, **kwargs):
return cls(
num_qubits=num_qubits,
key=value.MeasurementKey.parse_serialized(key),
key=value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key,
invert_mask=tuple(invert_mask),
qid_shape=None if qid_shape is None else tuple(qid_shape),
)
Expand All @@ -240,7 +254,3 @@ def _has_stabilizer_effect_(self) -> Optional[bool]:
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']) -> bool:
args.measure(qubits, self.key, self.full_invert_mask())
return True


def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
return ','.join(str(q) for q in qubits)
9 changes: 9 additions & 0 deletions cirq-core/cirq/ops/measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,15 @@ def test_op_repr():
"key='out', "
"invert_mask=(False, True))"
)
assert repr(
cirq.measure(
a, b, key=cirq.MeasurementKey(name='out', path=('a', 'b')), invert_mask=(False, True)
)
) == (
"cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1), "
"key=cirq.MeasurementKey(path=('a', 'b'), name='out'), "
"invert_mask=(False, True))"
)


def test_act_on_state_vector():
Expand Down
28 changes: 27 additions & 1 deletion cirq-core/cirq/protocols/json_test_data/Circuit.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,33 @@
"gate": {
"cirq_type": "MeasurementGate",
"num_qubits": 5,
"key": "0,1,2,3,4",
"key": {
"cirq_type": "MeasurementKey",
"name": "",
"path": [],
"qubits": [
{
"cirq_type": "LineQubit",
"x": 0
},
{
"cirq_type": "LineQubit",
"x": 1
},
{
"cirq_type": "LineQubit",
"x": 2
},
{
"cirq_type": "LineQubit",
"x": 3
},
{
"cirq_type": "LineQubit",
"x": 4
}
]
},
"invert_mask": []
},
"qubits": [
Expand Down

0 comments on commit 81436fe

Please sign in to comment.