Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions cirq-google/cirq_google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,38 @@
from cirq.study import sweeps


def _build_sweep_const(value: Any, use_float64: bool = False) -> run_context_pb2.ConstValue:
def _add_sweep_const(
sweep: run_context_pb2.SingleSweep, value: Any, use_float64: bool = False
) -> None:
"""Build the sweep const message from a value."""
if isinstance(value, float):
# comparing to float is ~5x than testing numbers.Real
# if modifying the below, also modify the block below for numbers.Real
if use_float64:
return run_context_pb2.ConstValue(double_value=value)
sweep.const_value.double_value = value
else:
# Note: A loss of precision for floating-point numbers may occur here.
return run_context_pb2.ConstValue(float_value=value)
sweep.const_value.float_value = value
elif isinstance(value, int):
# comparing to int is ~5x than testing numbers.Integral
# if modifying the below, also modify the block below for numbers.Integral
return run_context_pb2.ConstValue(int_value=value)
sweep.const_value.int_value = value
elif value is None:
return run_context_pb2.ConstValue(is_none=True)
sweep.const_value.is_none = True
elif isinstance(value, str):
return run_context_pb2.ConstValue(string_value=value)
sweep.const_value.string_value = value
elif isinstance(value, numbers.Integral):
# more general than isinstance(int) but also slower
return run_context_pb2.ConstValue(int_value=int(value))
sweep.const_value.int_value = int(value)
elif isinstance(value, numbers.Real):
# more general than isinstance(float) but also slower
if use_float64:
return run_context_pb2.ConstValue(double_value=float(value))
sweep.const_value.double_value = float(value) # pragma: no cover
else:
# Note: A loss of precision for floating-point numbers may occur here.
return run_context_pb2.ConstValue(float_value=float(value))
sweep.const_value.float_value = float(value)
elif isinstance(value, tunits.Value):
return run_context_pb2.ConstValue(with_unit_value=value.to_proto())
value.to_proto(sweep.const_value.with_unit_value)
else:
raise ValueError(
f"Unsupported type for serializing const sweep: {value=} and {type(value)=}"
Expand Down Expand Up @@ -190,7 +192,7 @@ def sweep_to_proto(
sweep = cast(cirq.Points, sweep_transformer(sweep))
out.single_sweep.parameter_key = sweep.key
if len(sweep.points) == 1:
out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0], use_float64))
_add_sweep_const(out.single_sweep, sweep.points[0], use_float64)
else:
if isinstance(sweep.points[0], tunits.Value):
unit = sweep.points[0].unit
Expand Down Expand Up @@ -402,7 +404,7 @@ def sweepable_to_proto(
for key, val in sweepable.items():
single_sweep = zip_proto.sweeps.add().single_sweep
single_sweep.parameter_key = key
single_sweep.const_value.MergeFrom(_build_sweep_const(val, use_float64))
_add_sweep_const(single_sweep, val, use_float64)
return out
if isinstance(sweepable, Iterable):
for sweepable_element in sweepable:
Expand Down
6 changes: 4 additions & 2 deletions cirq-google/cirq_google/api/v2/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def test_sweep_to_proto_linspace():

@pytest.mark.parametrize("val", [None, 1, 1.5, 's'])
def test_build_recover_const(val):
val2 = v2.sweeps._recover_sweep_const(v2.sweeps._build_sweep_const(val))
sweep = v2.run_context_pb2.SingleSweep()
v2.sweeps._add_sweep_const(sweep, val)
val2 = v2.sweeps._recover_sweep_const(sweep.const_value)
if isinstance(val, float):
assert math.isclose(val, val2) # avoid the floating precision issue.
else:
Expand All @@ -179,7 +181,7 @@ def test_build_covert_const_double():

def test_build_const_unsupported_type():
with pytest.raises(ValueError, match='Unsupported type for serializing const sweep'):
v2.sweeps._build_sweep_const((1, 2))
v2.sweeps._add_sweep_const(v2.run_context_pb2.SingleSweep(), (1, 2))


def test_list_sweep_bad_expression():
Expand Down