diff --git a/cirq-google/cirq_google/serialization/circuit_serializer.py b/cirq-google/cirq_google/serialization/circuit_serializer.py index aaac880266a..16fd0396755 100644 --- a/cirq-google/cirq_google/serialization/circuit_serializer.py +++ b/cirq-google/cirq_google/serialization/circuit_serializer.py @@ -14,7 +14,7 @@ """Support for serializing and deserializing cirq_google.api.v2 protos.""" -from typing import cast, Any, Dict, List, Optional +from typing import Any, Dict, List, Optional import sympy import cirq @@ -546,20 +546,22 @@ def _deserialize_gate_op( arg_function_language=arg_function_language, required_arg_name=None, ) - invert_mask = cast( - List[bool], - arg_func_langs.arg_from_proto( - operation_proto.measurementgate.invert_mask, - arg_function_language=arg_function_language, - required_arg_name=None, - ), + parsed_invert_mask = arg_func_langs.arg_from_proto( + operation_proto.measurementgate.invert_mask, + arg_function_language=arg_function_language, + required_arg_name=None, ) - if isinstance(invert_mask, list) and isinstance(key, str): - op = cirq.MeasurementGate( - num_qubits=len(qubits), key=key, invert_mask=tuple(invert_mask) - )(*qubits) + if (isinstance(parsed_invert_mask, list) or parsed_invert_mask is None) and isinstance( + key, str + ): + invert_mask: tuple[bool, ...] = () + if parsed_invert_mask is not None: + invert_mask = tuple(bool(x) for x in parsed_invert_mask) + op = cirq.MeasurementGate(num_qubits=len(qubits), key=key, invert_mask=invert_mask)( + *qubits + ) else: - raise ValueError(f'Incorrect types for measurement gate {invert_mask} {key}') + raise ValueError(f'Incorrect types for measurement gate {parsed_invert_mask} {key}') elif which_gate_type == 'waitgate': total_nanos = arg_func_langs.float_arg_from_proto( diff --git a/cirq-google/cirq_google/serialization/circuit_serializer_test.py b/cirq-google/cirq_google/serialization/circuit_serializer_test.py index c8f99d57756..0db07c3e36d 100644 --- a/cirq-google/cirq_google/serialization/circuit_serializer_test.py +++ b/cirq-google/cirq_google/serialization/circuit_serializer_test.py @@ -660,3 +660,11 @@ def test_no_constants_table(): with pytest.raises(ValueError, match='Proto has references to constants table'): serializer._deserialize_gate_op(op) + + +def test_measurement_gate_deserialize() -> None: + q = cirq.NamedQubit('q') + circuit = cirq.Circuit(cirq.X(q) ** 0.5, cirq.measure(q)) + msg = cg.CIRCUIT_SERIALIZER.serialize(circuit) + + assert cg.CIRCUIT_SERIALIZER.deserialize(msg) == circuit