Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Support qubits in MeasurementKey and update JSON serialization of keys" #4277

Merged
merged 3 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@
canonicalize_half_turns,
chosen_angle_to_canonical_half_turns,
chosen_angle_to_half_turns,
default_measurement_key_str,
Duration,
DURATION_LIKE,
GenericMetaImplementAnyOneOf,
Expand Down
12 changes: 2 additions & 10 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,16 +1367,8 @@ def test_findall_operations_with_gate(circuit_cls):
(3, cirq.CZ(a, b), cirq.CZ),
]
assert list(c.findall_operations_with_gate_type(cirq.MeasurementGate)) == [
(
4,
cirq.MeasurementGate(1).on(a),
cirq.MeasurementGate(1, key=cirq.MeasurementKey(qubits=(a,))),
),
(
4,
cirq.MeasurementGate(1).on(b),
cirq.MeasurementGate(1, key=cirq.MeasurementKey(qubits=(b,))),
),
(4, cirq.MeasurementGate(1, key='a').on(a), cirq.MeasurementGate(1, key='a')),
(4, cirq.MeasurementGate(1, key='b').on(b), cirq.MeasurementGate(1, key='b')),
]


Expand Down
4 changes: 0 additions & 4 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,6 @@ def __str__(self) -> str:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['gate', 'qubits'])

@classmethod
def _from_json_dict_(cls, gate, qubits, **kwargs):
return gate.on(*qubits)

def _group_interchangeable_qubits(
self,
) -> Tuple[Union['cirq.Qid', Tuple[int, FrozenSet['cirq.Qid']]], ...]:
Expand Down
18 changes: 11 additions & 7 deletions cirq-core/cirq/ops/measure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Tuple, TYPE_CHECKING, Union
from typing import Callable, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np

Expand All @@ -24,9 +24,13 @@
import cirq


def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
return ','.join(str(q) for q in qubits)


def measure(
*target: 'cirq.Qid',
key: Union[str, value.MeasurementKey] = '',
key: Optional[Union[str, value.MeasurementKey]] = None,
invert_mask: Tuple[bool, ...] = (),
) -> raw_types.Operation:
"""Returns a single MeasurementGate applied to all the given qubits.
Expand All @@ -35,7 +39,7 @@ def measure(

Args:
*target: The qubits that the measurement gate should measure.
key: The string key of the measurement. If this is empty, defaults
key: The string key of the measurement. If this is None, it defaults
to a comma-separated list of the target qubits' str values.
invert_mask: A list of Truthy or Falsey values indicating whether
the corresponding qubits should be flipped. None indicates no
Expand All @@ -56,12 +60,14 @@ def measure(
elif not isinstance(qubit, raw_types.Qid):
raise ValueError('measure() was called with type different than Qid.')

if key is None:
key = _default_measurement_key(target)
qid_shape = protocols.qid_shape(target)
return MeasurementGate(len(target), key, invert_mask, qid_shape).on(*target)


def measure_each(
*qubits: 'cirq.Qid', key_func: Callable[[raw_types.Qid], str] = lambda x: ''
*qubits: 'cirq.Qid', key_func: Callable[[raw_types.Qid], str] = str
) -> List[raw_types.Operation]:
"""Returns a list of operations individually measuring the given qubits.

Expand All @@ -70,9 +76,7 @@ def measure_each(
Args:
*qubits: The qubits to measure.
key_func: Determines the key of the measurements of each qubit. Takes
the qubit and returns the key for that qubit. Defaults to empty string
key which would result in a comma-separated list of the target qubits'
str values.
the qubit and returns the key for that qubit. Defaults to str.

Returns:
A list of operations individually measuring the given qubits.
Expand Down
18 changes: 4 additions & 14 deletions cirq-core/cirq/ops/measure_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,12 @@ def test_measure_qubits():
with pytest.raises(ValueError, match='empty set of qubits'):
_ = cirq.measure()

assert cirq.measure(a) == cirq.MeasurementGate(num_qubits=1, key='').on(a)
assert cirq.measure(a) != cirq.measure(a, key='a')
assert cirq.measure(a) != cirq.MeasurementGate(num_qubits=1, key='').on(b)
assert cirq.measure(a) != cirq.MeasurementGate(num_qubits=1, key='a').on(a)
assert cirq.measure(a) == cirq.MeasurementGate(
num_qubits=1, key=cirq.MeasurementKey(qubits=(a,))
).on(a)
assert cirq.measurement_key(cirq.measure(a)) == 'a'
assert cirq.measure(a, b) == cirq.MeasurementGate(num_qubits=2, key='').on(a, b)
assert cirq.measurement_key(cirq.measure(a, b)) == 'a,b'
assert cirq.measure(b, a) == cirq.MeasurementGate(num_qubits=2, key='').on(b, a)
assert cirq.measurement_key(cirq.measure(b, a)) == 'b,a'
assert cirq.measure(a) == cirq.MeasurementGate(num_qubits=1, key='a').on(a)
assert cirq.measure(a, b) == cirq.MeasurementGate(num_qubits=2, key='a,b').on(a, b)
assert cirq.measure(b, a) == cirq.MeasurementGate(num_qubits=2, key='b,a').on(b, a)
assert cirq.measure(a, key='b') == cirq.MeasurementGate(num_qubits=1, key='b').on(a)
assert cirq.measure(a, key='b') != cirq.MeasurementGate(num_qubits=1, key='b').on(b)
assert cirq.measure(a, invert_mask=(True,)) == cirq.MeasurementGate(
num_qubits=1, invert_mask=(True,)
num_qubits=1, key='a', invert_mask=(True,)
).on(a)
assert cirq.measure(*cirq.LineQid.for_qid_shape((1, 2, 3)), key='a') == cirq.MeasurementGate(
num_qubits=3, key='a', qid_shape=(1, 2, 3)
Expand Down
42 changes: 16 additions & 26 deletions cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Sequence, TYPE_CHECKING, Union
from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING, Union

import numpy as np

from cirq import protocols, value
from cirq.ops import raw_types, gate_operation
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -80,21 +80,12 @@ def key(self, key: Union[str, value.MeasurementKey]):
else:
self.mkey = value.MeasurementKey(name=key)

def on(self, *qubits: raw_types.Qid) -> raw_types.Operation:
"""Returns an application of this gate to the given qubits.

Args:
*qubits: The collection of qubits to potentially apply the gate to.
"""
maybe_rekeyed_gate = self.with_key(self.mkey.with_qubits(qubits))
return gate_operation.GateOperation(maybe_rekeyed_gate, list(qubits))

def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def with_key(self, key: Union[str, value.MeasurementKey]) -> 'MeasurementGate':
"""Creates a measurement gate with a new key but otherwise identical."""
if isinstance(key, value.MeasurementKey) and key == self.mkey:
if key == self.key:
return self
return MeasurementGate(
self.num_qubits(), key=key, invert_mask=self.invert_mask, qid_shape=self._qid_shape
Expand All @@ -114,7 +105,7 @@ def with_bits_flipped(self, *bit_positions: int) -> 'MeasurementGate':
for b in bit_positions:
new_mask[b] = not new_mask[b]
return MeasurementGate(
self.num_qubits(), key=self.mkey, invert_mask=tuple(new_mask), qid_shape=self._qid_shape
self.num_qubits(), key=self.key, invert_mask=tuple(new_mask), qid_shape=self._qid_shape
)

def full_invert_mask(self):
Expand Down Expand Up @@ -161,8 +152,8 @@ def _circuit_diagram_info_(
if b:
symbols[i] = '!M'

# Mention the measurement key if it is non-trivial or there are no known qubits.
if self.mkey.name or self.mkey.path or not args.known_qubits:
# Mention the measurement key.
if not args.known_qubits or self.key != _default_measurement_key(args.known_qubits):
symbols[0] += f"('{self.key}')"

return protocols.CircuitDiagramInfo(tuple(symbols))
Expand Down Expand Up @@ -200,13 +191,8 @@ def _quil_(

def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
args = list(repr(q) for q in qubits)
if self.mkey.name or self.mkey.path:
if self.mkey == self.mkey.name:
args.append(f'key={self.mkey.name!r}')
else:
# Remove qubits from the `MeasurementKey` representation since we already have
# qubits from the op.
args.append(f'key={self.mkey.with_qubits(tuple())!r}')
if self.key != _default_measurement_key(qubits):
args.append(f'key={self.key!r}')
if self.invert_mask:
args.append(f'invert_mask={self.invert_mask!r}')
arg_list = ', '.join(args)
Expand All @@ -219,13 +205,13 @@ def __repr__(self):
return (
f'cirq.MeasurementGate('
f'{self.num_qubits()!r}, '
f'{self.mkey.name if self.mkey == self.mkey.name else self.mkey!r}, '
f'{self.key!r}, '
f'{self.invert_mask}'
f'{qid_shape_arg})'
)

def _value_equality_values_(self) -> Any:
return self.mkey, self.invert_mask, self._qid_shape
return self.key, self.invert_mask, self._qid_shape

def _json_dict_(self) -> Dict[str, Any]:
other = {}
Expand All @@ -234,7 +220,7 @@ def _json_dict_(self) -> Dict[str, Any]:
return {
'cirq_type': self.__class__.__name__,
'num_qubits': len(self._qid_shape),
'key': self.mkey,
'key': self.key,
'invert_mask': self.invert_mask,
**other,
}
Expand All @@ -243,7 +229,7 @@ def _json_dict_(self) -> Dict[str, Any]:
def _from_json_dict_(cls, num_qubits, key, invert_mask, qid_shape=None, **kwargs):
return cls(
num_qubits=num_qubits,
key=value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key,
key=value.MeasurementKey.parse_serialized(key),
invert_mask=tuple(invert_mask),
qid_shape=None if qid_shape is None else tuple(qid_shape),
)
Expand All @@ -254,3 +240,7 @@ def _has_stabilizer_effect_(self) -> Optional[bool]:
def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']) -> bool:
args.measure(qubits, self.key, self.full_invert_mask())
return True


def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
return ','.join(str(q) for q in qubits)
18 changes: 9 additions & 9 deletions cirq-core/cirq/ops/measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
import cirq


def test_eval_repr():
# Basic safeguard against repr-inequality.
op = cirq.GateOperation(
gate=cirq.MeasurementGate(1, cirq.MeasurementKey(path=(), name='q0_1_0'), ()),
qubits=[cirq.GridQubit(0, 1)],
)
cirq.testing.assert_equivalent_repr(op)


@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_measure_init(num_qubits):
assert cirq.MeasurementGate(num_qubits).num_qubits() == num_qubits
Expand Down Expand Up @@ -266,15 +275,6 @@ def test_op_repr():
"key='out', "
"invert_mask=(False, True))"
)
assert repr(
cirq.measure(
a, b, key=cirq.MeasurementKey(name='out', path=('a', 'b')), invert_mask=(False, True)
)
) == (
"cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1), "
"key=cirq.MeasurementKey(path=('a', 'b'), name='out'), "
"invert_mask=(False, True))"
)


def test_act_on_state_vector():
Expand Down
28 changes: 1 addition & 27 deletions cirq-core/cirq/protocols/json_test_data/Circuit.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,33 +85,7 @@
"gate": {
"cirq_type": "MeasurementGate",
"num_qubits": 5,
"key": {
"cirq_type": "MeasurementKey",
"name": "",
"path": [],
"qubits": [
{
"cirq_type": "LineQubit",
"x": 0
},
{
"cirq_type": "LineQubit",
"x": 1
},
{
"cirq_type": "LineQubit",
"x": 2
},
{
"cirq_type": "LineQubit",
"x": 3
},
{
"cirq_type": "LineQubit",
"x": 4
}
]
},
"key": "0,1,2,3,4",
"invert_mask": []
},
"qubits": [
Expand Down