diff --git a/cirq-google/cirq_google/api/v2/sweeps.py b/cirq-google/cirq_google/api/v2/sweeps.py index 5ce24719fcc..001c49d3de2 100644 --- a/cirq-google/cirq_google/api/v2/sweeps.py +++ b/cirq-google/cirq_google/api/v2/sweeps.py @@ -30,36 +30,38 @@ from cirq.study import sweeps -def _build_sweep_const(value: Any, use_float64: bool = False) -> run_context_pb2.ConstValue: +def _add_sweep_const( + sweep: run_context_pb2.SingleSweep, value: Any, use_float64: bool = False +) -> None: """Build the sweep const message from a value.""" if isinstance(value, float): # comparing to float is ~5x than testing numbers.Real # if modifying the below, also modify the block below for numbers.Real if use_float64: - return run_context_pb2.ConstValue(double_value=value) + sweep.const_value.double_value = value else: # Note: A loss of precision for floating-point numbers may occur here. - return run_context_pb2.ConstValue(float_value=value) + sweep.const_value.float_value = value elif isinstance(value, int): # comparing to int is ~5x than testing numbers.Integral # if modifying the below, also modify the block below for numbers.Integral - return run_context_pb2.ConstValue(int_value=value) + sweep.const_value.int_value = value elif value is None: - return run_context_pb2.ConstValue(is_none=True) + sweep.const_value.is_none = True elif isinstance(value, str): - return run_context_pb2.ConstValue(string_value=value) + sweep.const_value.string_value = value elif isinstance(value, numbers.Integral): # more general than isinstance(int) but also slower - return run_context_pb2.ConstValue(int_value=int(value)) + sweep.const_value.int_value = int(value) elif isinstance(value, numbers.Real): # more general than isinstance(float) but also slower if use_float64: - return run_context_pb2.ConstValue(double_value=float(value)) + sweep.const_value.double_value = float(value) # pragma: no cover else: # Note: A loss of precision for floating-point numbers may occur here. - return run_context_pb2.ConstValue(float_value=float(value)) + sweep.const_value.float_value = float(value) elif isinstance(value, tunits.Value): - return run_context_pb2.ConstValue(with_unit_value=value.to_proto()) + value.to_proto(sweep.const_value.with_unit_value) else: raise ValueError( f"Unsupported type for serializing const sweep: {value=} and {type(value)=}" @@ -190,7 +192,7 @@ def sweep_to_proto( sweep = cast(cirq.Points, sweep_transformer(sweep)) out.single_sweep.parameter_key = sweep.key if len(sweep.points) == 1: - out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0], use_float64)) + _add_sweep_const(out.single_sweep, sweep.points[0], use_float64) else: if isinstance(sweep.points[0], tunits.Value): unit = sweep.points[0].unit @@ -402,7 +404,7 @@ def sweepable_to_proto( for key, val in sweepable.items(): single_sweep = zip_proto.sweeps.add().single_sweep single_sweep.parameter_key = key - single_sweep.const_value.MergeFrom(_build_sweep_const(val, use_float64)) + _add_sweep_const(single_sweep, val, use_float64) return out if isinstance(sweepable, Iterable): for sweepable_element in sweepable: diff --git a/cirq-google/cirq_google/api/v2/sweeps_test.py b/cirq-google/cirq_google/api/v2/sweeps_test.py index 3c42a2b56b8..86258051f3a 100644 --- a/cirq-google/cirq_google/api/v2/sweeps_test.py +++ b/cirq-google/cirq_google/api/v2/sweeps_test.py @@ -164,7 +164,9 @@ def test_sweep_to_proto_linspace(): @pytest.mark.parametrize("val", [None, 1, 1.5, 's']) def test_build_recover_const(val): - val2 = v2.sweeps._recover_sweep_const(v2.sweeps._build_sweep_const(val)) + sweep = v2.run_context_pb2.SingleSweep() + v2.sweeps._add_sweep_const(sweep, val) + val2 = v2.sweeps._recover_sweep_const(sweep.const_value) if isinstance(val, float): assert math.isclose(val, val2) # avoid the floating precision issue. else: @@ -179,7 +181,7 @@ def test_build_covert_const_double(): def test_build_const_unsupported_type(): with pytest.raises(ValueError, match='Unsupported type for serializing const sweep'): - v2.sweeps._build_sweep_const((1, 2)) + v2.sweeps._add_sweep_const(v2.run_context_pb2.SingleSweep(), (1, 2)) def test_list_sweep_bad_expression():