Skip to content

Commit

Permalink
Support empty invert_mask in measument gate deserialization (#6224)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoureldinYosri committed Aug 2, 2023
1 parent c510fff commit 701538c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
28 changes: 15 additions & 13 deletions cirq-google/cirq_google/serialization/circuit_serializer.py
Expand Up @@ -14,7 +14,7 @@

"""Support for serializing and deserializing cirq_google.api.v2 protos."""

from typing import cast, Any, Dict, List, Optional
from typing import Any, Dict, List, Optional
import sympy

import cirq
Expand Down Expand Up @@ -546,20 +546,22 @@ def _deserialize_gate_op(
arg_function_language=arg_function_language,
required_arg_name=None,
)
invert_mask = cast(
List[bool],
arg_func_langs.arg_from_proto(
operation_proto.measurementgate.invert_mask,
arg_function_language=arg_function_language,
required_arg_name=None,
),
parsed_invert_mask = arg_func_langs.arg_from_proto(
operation_proto.measurementgate.invert_mask,
arg_function_language=arg_function_language,
required_arg_name=None,
)
if isinstance(invert_mask, list) and isinstance(key, str):
op = cirq.MeasurementGate(
num_qubits=len(qubits), key=key, invert_mask=tuple(invert_mask)
)(*qubits)
if (isinstance(parsed_invert_mask, list) or parsed_invert_mask is None) and isinstance(
key, str
):
invert_mask: tuple[bool, ...] = ()
if parsed_invert_mask is not None:
invert_mask = tuple(bool(x) for x in parsed_invert_mask)
op = cirq.MeasurementGate(num_qubits=len(qubits), key=key, invert_mask=invert_mask)(
*qubits
)
else:
raise ValueError(f'Incorrect types for measurement gate {invert_mask} {key}')
raise ValueError(f'Incorrect types for measurement gate {parsed_invert_mask} {key}')

elif which_gate_type == 'waitgate':
total_nanos = arg_func_langs.float_arg_from_proto(
Expand Down
Expand Up @@ -660,3 +660,11 @@ def test_no_constants_table():

with pytest.raises(ValueError, match='Proto has references to constants table'):
serializer._deserialize_gate_op(op)


def test_measurement_gate_deserialize() -> None:
q = cirq.NamedQubit('q')
circuit = cirq.Circuit(cirq.X(q) ** 0.5, cirq.measure(q))
msg = cg.CIRCUIT_SERIALIZER.serialize(circuit)

assert cg.CIRCUIT_SERIALIZER.deserialize(msg) == circuit

0 comments on commit 701538c

Please sign in to comment.