From f3965e8281a71ac663cfbc80194ff970ade08fbd Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Thu, 22 Jun 2023 09:56:01 -0700 Subject: [PATCH] Add Device Parameter metadata serialization to cirq_google (#6113) * Add Device Parameter metadata serialization to cirq_google - Allow DeviceParameter to be serialized as part of cirq_google. --- .../cirq_google/api/v2/run_context.proto | 20 +++++ .../cirq_google/api/v2/run_context_pb2.py | 76 +++++++++++++++++-- .../cirq_google/api/v2/run_context_pb2.pyi | 31 +++++++- cirq-google/cirq_google/api/v2/sweeps.py | 24 +++++- cirq-google/cirq_google/api/v2/sweeps_test.py | 24 +++++- 5 files changed, 162 insertions(+), 13 deletions(-) diff --git a/cirq-google/cirq_google/api/v2/run_context.proto b/cirq-google/cirq_google/api/v2/run_context.proto index fee314b1c78..e0cf0a27fa3 100644 --- a/cirq-google/cirq_google/api/v2/run_context.proto +++ b/cirq-google/cirq_google/api/v2/run_context.proto @@ -84,6 +84,22 @@ message SweepFunction { repeated Sweep sweeps = 2; } +message DeviceParameter { + + // Path to the parameter key + repeated string path = 1; + + // If the value is an array, the index of the array to change. + int64 idx = 2; + + // String representation of the units, if any. + // Examples: "GHz", "ns", etc. + string units = 3; + + // Note that the device parameter values will be populated + // by the sweep values themselves. +} + // A set of values to loop over for a particular parameter. message SingleSweep { @@ -98,6 +114,10 @@ message SingleSweep { // Uniformly-spaced sampling over a range. Linspace linspace = 3; } + + // Optional arguments for if this is a device parameter. + // (as opposed to a circuit symbol) + DeviceParameter parameter = 4; } // A list of explicit values. diff --git a/cirq-google/cirq_google/api/v2/run_context_pb2.py b/cirq-google/cirq_google/api/v2/run_context_pb2.py index d64e2eaa358..b235fecf876 100644 --- a/cirq-google/cirq_google/api/v2/run_context_pb2.py +++ b/cirq-google/cirq_google/api/v2/run_context_pb2.py @@ -19,7 +19,7 @@ syntax='proto3', serialized_options=b'\n\035com.google.cirq.google.api.v2B\017RunContextProtoP\001', create_key=_descriptor._internal_create_key, - serialized_pb=b'\n$cirq_google/api/v2/run_context.proto\x12\x12\x63irq.google.api.v2\"J\n\nRunContext\x12<\n\x10parameter_sweeps\x18\x01 \x03(\x0b\x32\".cirq.google.api.v2.ParameterSweep\"O\n\x0eParameterSweep\x12\x13\n\x0brepetitions\x18\x01 \x01(\x05\x12(\n\x05sweep\x18\x02 \x01(\x0b\x32\x19.cirq.google.api.v2.Sweep\"\x86\x01\n\x05Sweep\x12;\n\x0esweep_function\x18\x01 \x01(\x0b\x32!.cirq.google.api.v2.SweepFunctionH\x00\x12\x37\n\x0csingle_sweep\x18\x02 \x01(\x0b\x32\x1f.cirq.google.api.v2.SingleSweepH\x00\x42\x07\n\x05sweep\"\xc6\x01\n\rSweepFunction\x12\x45\n\rfunction_type\x18\x01 \x01(\x0e\x32..cirq.google.api.v2.SweepFunction.FunctionType\x12)\n\x06sweeps\x18\x02 \x03(\x0b\x32\x19.cirq.google.api.v2.Sweep\"C\n\x0c\x46unctionType\x12\x1d\n\x19\x46UNCTION_TYPE_UNSPECIFIED\x10\x00\x12\x0b\n\x07PRODUCT\x10\x01\x12\x07\n\x03ZIP\x10\x02\"\x8d\x01\n\x0bSingleSweep\x12\x15\n\rparameter_key\x18\x01 \x01(\t\x12,\n\x06points\x18\x02 \x01(\x0b\x32\x1a.cirq.google.api.v2.PointsH\x00\x12\x30\n\x08linspace\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.LinspaceH\x00\x42\x07\n\x05sweep\"\x18\n\x06Points\x12\x0e\n\x06points\x18\x01 \x03(\x02\"G\n\x08Linspace\x12\x13\n\x0b\x66irst_point\x18\x01 \x01(\x02\x12\x12\n\nlast_point\x18\x02 \x01(\x02\x12\x12\n\nnum_points\x18\x03 \x01(\x03\x42\x32\n\x1d\x63om.google.cirq.google.api.v2B\x0fRunContextProtoP\x01\x62\x06proto3' + serialized_pb=b'\n$cirq_google/api/v2/run_context.proto\x12\x12\x63irq.google.api.v2\"J\n\nRunContext\x12<\n\x10parameter_sweeps\x18\x01 \x03(\x0b\x32\".cirq.google.api.v2.ParameterSweep\"O\n\x0eParameterSweep\x12\x13\n\x0brepetitions\x18\x01 \x01(\x05\x12(\n\x05sweep\x18\x02 \x01(\x0b\x32\x19.cirq.google.api.v2.Sweep\"\x86\x01\n\x05Sweep\x12;\n\x0esweep_function\x18\x01 \x01(\x0b\x32!.cirq.google.api.v2.SweepFunctionH\x00\x12\x37\n\x0csingle_sweep\x18\x02 \x01(\x0b\x32\x1f.cirq.google.api.v2.SingleSweepH\x00\x42\x07\n\x05sweep\"\xc6\x01\n\rSweepFunction\x12\x45\n\rfunction_type\x18\x01 \x01(\x0e\x32..cirq.google.api.v2.SweepFunction.FunctionType\x12)\n\x06sweeps\x18\x02 \x03(\x0b\x32\x19.cirq.google.api.v2.Sweep\"C\n\x0c\x46unctionType\x12\x1d\n\x19\x46UNCTION_TYPE_UNSPECIFIED\x10\x00\x12\x0b\n\x07PRODUCT\x10\x01\x12\x07\n\x03ZIP\x10\x02\";\n\x0f\x44\x65viceParameter\x12\x0c\n\x04path\x18\x01 \x03(\t\x12\x0b\n\x03idx\x18\x02 \x01(\x03\x12\r\n\x05units\x18\x03 \x01(\t\"\xc5\x01\n\x0bSingleSweep\x12\x15\n\rparameter_key\x18\x01 \x01(\t\x12,\n\x06points\x18\x02 \x01(\x0b\x32\x1a.cirq.google.api.v2.PointsH\x00\x12\x30\n\x08linspace\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.LinspaceH\x00\x12\x36\n\tparameter\x18\x04 \x01(\x0b\x32#.cirq.google.api.v2.DeviceParameterB\x07\n\x05sweep\"\x18\n\x06Points\x12\x0e\n\x06points\x18\x01 \x03(\x02\"G\n\x08Linspace\x12\x13\n\x0b\x66irst_point\x18\x01 \x01(\x02\x12\x12\n\nlast_point\x18\x02 \x01(\x02\x12\x12\n\nnum_points\x18\x03 \x01(\x03\x42\x32\n\x1d\x63om.google.cirq.google.api.v2B\x0fRunContextProtoP\x01\x62\x06proto3' ) @@ -210,6 +210,52 @@ ) +_DEVICEPARAMETER = _descriptor.Descriptor( + name='DeviceParameter', + full_name='cirq.google.api.v2.DeviceParameter', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='path', full_name='cirq.google.api.v2.DeviceParameter.path', index=0, + number=1, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='idx', full_name='cirq.google.api.v2.DeviceParameter.idx', index=1, + number=2, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='units', full_name='cirq.google.api.v2.DeviceParameter.units', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=555, + serialized_end=614, +) + + _SINGLESWEEP = _descriptor.Descriptor( name='SingleSweep', full_name='cirq.google.api.v2.SingleSweep', @@ -239,6 +285,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='parameter', full_name='cirq.google.api.v2.SingleSweep.parameter', index=3, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ], extensions=[ ], @@ -256,8 +309,8 @@ create_key=_descriptor._internal_create_key, fields=[]), ], - serialized_start=556, - serialized_end=697, + serialized_start=617, + serialized_end=814, ) @@ -288,8 +341,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=699, - serialized_end=723, + serialized_start=816, + serialized_end=840, ) @@ -334,8 +387,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=725, - serialized_end=796, + serialized_start=842, + serialized_end=913, ) _RUNCONTEXT.fields_by_name['parameter_sweeps'].message_type = _PARAMETERSWEEP @@ -353,6 +406,7 @@ _SWEEPFUNCTION_FUNCTIONTYPE.containing_type = _SWEEPFUNCTION _SINGLESWEEP.fields_by_name['points'].message_type = _POINTS _SINGLESWEEP.fields_by_name['linspace'].message_type = _LINSPACE +_SINGLESWEEP.fields_by_name['parameter'].message_type = _DEVICEPARAMETER _SINGLESWEEP.oneofs_by_name['sweep'].fields.append( _SINGLESWEEP.fields_by_name['points']) _SINGLESWEEP.fields_by_name['points'].containing_oneof = _SINGLESWEEP.oneofs_by_name['sweep'] @@ -363,6 +417,7 @@ DESCRIPTOR.message_types_by_name['ParameterSweep'] = _PARAMETERSWEEP DESCRIPTOR.message_types_by_name['Sweep'] = _SWEEP DESCRIPTOR.message_types_by_name['SweepFunction'] = _SWEEPFUNCTION +DESCRIPTOR.message_types_by_name['DeviceParameter'] = _DEVICEPARAMETER DESCRIPTOR.message_types_by_name['SingleSweep'] = _SINGLESWEEP DESCRIPTOR.message_types_by_name['Points'] = _POINTS DESCRIPTOR.message_types_by_name['Linspace'] = _LINSPACE @@ -396,6 +451,13 @@ }) _sym_db.RegisterMessage(SweepFunction) +DeviceParameter = _reflection.GeneratedProtocolMessageType('DeviceParameter', (_message.Message,), { + 'DESCRIPTOR' : _DEVICEPARAMETER, + '__module__' : 'cirq_google.api.v2.run_context_pb2' + # @@protoc_insertion_point(class_scope:cirq.google.api.v2.DeviceParameter) + }) +_sym_db.RegisterMessage(DeviceParameter) + SingleSweep = _reflection.GeneratedProtocolMessageType('SingleSweep', (_message.Message,), { 'DESCRIPTOR' : _SINGLESWEEP, '__module__' : 'cirq_google.api.v2.run_context_pb2' diff --git a/cirq-google/cirq_google/api/v2/run_context_pb2.pyi b/cirq-google/cirq_google/api/v2/run_context_pb2.pyi index c40d14d3a76..5483eabdaa3 100644 --- a/cirq-google/cirq_google/api/v2/run_context_pb2.pyi +++ b/cirq-google/cirq_google/api/v2/run_context_pb2.pyi @@ -124,6 +124,25 @@ class SweepFunction(google___protobuf___message___Message): else: def ClearField(self, field_name: typing_extensions___Literal[b"function_type",b"sweeps"]) -> None: ... +class DeviceParameter(google___protobuf___message___Message): + path = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] + idx = ... # type: int + units = ... # type: typing___Text + + def __init__(self, + path : typing___Optional[typing___Iterable[typing___Text]] = None, + idx : typing___Optional[int] = None, + units : typing___Optional[typing___Text] = None, + ) -> None: ... + @classmethod + def FromString(cls, s: bytes) -> DeviceParameter: ... + def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... + if sys.version_info >= (3,): + def ClearField(self, field_name: typing_extensions___Literal[u"idx",u"path",u"units"]) -> None: ... + else: + def ClearField(self, field_name: typing_extensions___Literal[b"idx",b"path",b"units"]) -> None: ... + class SingleSweep(google___protobuf___message___Message): parameter_key = ... # type: typing___Text @@ -133,21 +152,25 @@ class SingleSweep(google___protobuf___message___Message): @property def linspace(self) -> Linspace: ... + @property + def parameter(self) -> DeviceParameter: ... + def __init__(self, parameter_key : typing___Optional[typing___Text] = None, points : typing___Optional[Points] = None, linspace : typing___Optional[Linspace] = None, + parameter : typing___Optional[DeviceParameter] = None, ) -> None: ... @classmethod def FromString(cls, s: bytes) -> SingleSweep: ... def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... if sys.version_info >= (3,): - def HasField(self, field_name: typing_extensions___Literal[u"linspace",u"points",u"sweep"]) -> bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"linspace",u"parameter_key",u"points",u"sweep"]) -> None: ... + def HasField(self, field_name: typing_extensions___Literal[u"linspace",u"parameter",u"points",u"sweep"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[u"linspace",u"parameter",u"parameter_key",u"points",u"sweep"]) -> None: ... else: - def HasField(self, field_name: typing_extensions___Literal[u"linspace",b"linspace",u"points",b"points",u"sweep",b"sweep"]) -> bool: ... - def ClearField(self, field_name: typing_extensions___Literal[b"linspace",b"parameter_key",b"points",b"sweep"]) -> None: ... + def HasField(self, field_name: typing_extensions___Literal[u"linspace",b"linspace",u"parameter",b"parameter",u"points",b"points",u"sweep",b"sweep"]) -> bool: ... + def ClearField(self, field_name: typing_extensions___Literal[b"linspace",b"parameter",b"parameter_key",b"points",b"sweep"]) -> None: ... def WhichOneof(self, oneof_group: typing_extensions___Literal[u"sweep",b"sweep"]) -> typing_extensions___Literal["points","linspace"]: ... class Points(google___protobuf___message___Message): diff --git a/cirq-google/cirq_google/api/v2/sweeps.py b/cirq-google/cirq_google/api/v2/sweeps.py index 83210e86315..af1195523a8 100644 --- a/cirq-google/cirq_google/api/v2/sweeps.py +++ b/cirq-google/cirq_google/api/v2/sweeps.py @@ -19,6 +19,7 @@ import cirq from cirq_google.api.v2 import batch_pb2 from cirq_google.api.v2 import run_context_pb2 +from cirq_google.study.device_parameter import DeviceParameter def sweep_to_proto( @@ -54,9 +55,23 @@ def sweep_to_proto( out.single_sweep.linspace.first_point = sweep.start out.single_sweep.linspace.last_point = sweep.stop out.single_sweep.linspace.num_points = sweep.length + # Use duck-typing to support google-internal Parameter objects + if sweep.metadata and getattr(sweep.metadata, 'path', None): + out.single_sweep.parameter.path.extend(sweep.metadata.path) + if sweep.metadata and getattr(sweep.metadata, 'idx', None): + out.single_sweep.parameter.idx = sweep.metadata.idx + if sweep.metadata and getattr(sweep.metadata, 'units', None): + out.single_sweep.parameter.units = sweep.metadata.units elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr): out.single_sweep.parameter_key = sweep.key out.single_sweep.points.points.extend(sweep.points) + # Use duck-typing to support google-internal Parameter objects + if sweep.metadata and getattr(sweep.metadata, 'path', None): + out.single_sweep.parameter.path.extend(sweep.metadata.path) + if sweep.metadata and getattr(sweep.metadata, 'idx', None): + out.single_sweep.parameter.idx = sweep.metadata.idx + if sweep.metadata and getattr(sweep.metadata, 'units', None): + out.single_sweep.parameter.units = sweep.metadata.units elif isinstance(sweep, cirq.ListSweep): sweep_dict: Dict[str, List[float]] = {} for param_resolver in sweep: @@ -88,15 +103,22 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep: raise ValueError(f'invalid sweep function type: {func_type}') if which == 'single_sweep': key = msg.single_sweep.parameter_key + if msg.single_sweep.HasField("parameter"): + metadata = DeviceParameter( + path=msg.single_sweep.parameter.path, idx=msg.single_sweep.parameter.idx + ) + else: + metadata = None if msg.single_sweep.WhichOneof('sweep') == 'linspace': return cirq.Linspace( key=key, start=msg.single_sweep.linspace.first_point, stop=msg.single_sweep.linspace.last_point, length=msg.single_sweep.linspace.num_points, + metadata=metadata, ) if msg.single_sweep.WhichOneof('sweep') == 'points': - return cirq.Points(key=key, points=msg.single_sweep.points.points) + return cirq.Points(key=key, points=msg.single_sweep.points.points, metadata=metadata) raise ValueError(f'single sweep type not set: {msg}') diff --git a/cirq-google/cirq_google/api/v2/sweeps_test.py b/cirq-google/cirq_google/api/v2/sweeps_test.py index d4d1c2b8cbe..97569da9ea9 100644 --- a/cirq-google/cirq_google/api/v2/sweeps_test.py +++ b/cirq-google/cirq_google/api/v2/sweeps_test.py @@ -19,6 +19,7 @@ import cirq from cirq.study import sweeps +from cirq_google.study import DeviceParameter from cirq_google.api import v2 @@ -41,7 +42,19 @@ def _values(self) -> Iterator[float]: [ cirq.UnitSweep, cirq.Linspace('a', 0, 10, 100), + cirq.Linspace( + 'a', + 0, + 10, + 100, + metadata=DeviceParameter(path=['path', 'to', 'parameter'], idx=2, units='ns'), + ), cirq.Points('b', [1, 1.5, 2, 2.5, 3]), + cirq.Points( + 'b', + [1, 1.5, 2, 2.5, 3], + metadata=DeviceParameter(path=['path', 'to', 'parameter'], idx=2, units='GHz'), + ), cirq.Linspace('a', 0, 1, 5) * cirq.Linspace('b', 0, 1, 5), cirq.Points('a', [1, 2, 3]) + cirq.Linspace('b', 0, 1, 3), ( @@ -62,7 +75,11 @@ def test_sweep_to_proto_roundtrip(sweep): def test_sweep_to_proto_linspace(): - proto = v2.sweep_to_proto(cirq.Linspace('foo', 0, 1, 20)) + proto = v2.sweep_to_proto( + cirq.Linspace( + 'foo', 0, 1, 20, metadata=DeviceParameter(path=['path', 'to', 'parameter'], idx=2) + ) + ) assert isinstance(proto, v2.run_context_pb2.Sweep) assert proto.HasField('single_sweep') assert proto.single_sweep.parameter_key == 'foo' @@ -70,6 +87,11 @@ def test_sweep_to_proto_linspace(): assert proto.single_sweep.linspace.first_point == 0 assert proto.single_sweep.linspace.last_point == 1 assert proto.single_sweep.linspace.num_points == 20 + assert proto.single_sweep.parameter.path == ['path', 'to', 'parameter'] + assert proto.single_sweep.parameter.idx == 2 + assert v2.sweep_from_proto(proto).metadata == DeviceParameter( + path=['path', 'to', 'parameter'], idx=2 + ) def test_list_sweep_bad_expression():