diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 1560ebfa83e..73f65484f33 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -30,6 +30,7 @@ ) import re import warnings +from dataclasses import dataclass import cirq from cirq_google import ops @@ -40,14 +41,7 @@ # Gate family constants used in various parts of GridDevice logic. -_SYC_GATE_FAMILY = cirq.GateFamily(ops.SYC) -_SQRT_ISWAP_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP) -_SQRT_ISWAP_INV_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP_INV) -_CZ_GATE_FAMILY = cirq.GateFamily(cirq.CZ) _PHASED_XZ_GATE_FAMILY = cirq.GateFamily(cirq.PhasedXZGate) -_VIRTUAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()]) -_PHYSICAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()]) -_COUPLER_PULSE_GATE_FAMILY = cirq.GateFamily(experimental_ops.CouplerPulse) _MEASUREMENT_GATE_FAMILY = cirq.GateFamily(cirq.MeasurementGate) _WAIT_GATE_FAMILY = cirq.GateFamily(cirq.WaitGate) @@ -74,6 +68,86 @@ _VARIADIC_GATE_FAMILIES = [_MEASUREMENT_GATE_FAMILY, _WAIT_GATE_FAMILY] +GateOrFamily = Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily] + + +@dataclass +class _GateRepresentations: + """Contains equivalent representations of a gate in both DeviceSpecification and GridDevice. + + Attributes: + gate_spec_name: The name of gate type in `GateSpecification`. + deserialized_forms: Gate representations to be included when the corresponding + `GateSpecification` gate type is deserialized into gatesets and gate durations. + serializable_forms: GateFamilies used to check whether a given gate can be serialized to the + gate type in this _GateRepresentation. + """ + + gate_spec_name: str + deserialized_forms: List[GateOrFamily] + serializable_forms: List[cirq.GateFamily] + + +"""Valid gates for a GridDevice.""" +_GATES: List[_GateRepresentations] = [ + _GateRepresentations( + gate_spec_name='syc', + deserialized_forms=[_SYC_FSIM_GATE_FAMILY], + serializable_forms=[_SYC_FSIM_GATE_FAMILY, cirq.GateFamily(ops.SYC)], + ), + _GateRepresentations( + gate_spec_name='sqrt_iswap', + deserialized_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY], + serializable_forms=[_SQRT_ISWAP_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.SQRT_ISWAP)], + ), + _GateRepresentations( + gate_spec_name='sqrt_iswap_inv', + deserialized_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY], + serializable_forms=[_SQRT_ISWAP_INV_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.SQRT_ISWAP_INV)], + ), + _GateRepresentations( + gate_spec_name='cz', + deserialized_forms=[_CZ_FSIM_GATE_FAMILY], + serializable_forms=[_CZ_FSIM_GATE_FAMILY, cirq.GateFamily(cirq.CZ)], + ), + _GateRepresentations( + gate_spec_name='phased_xz', + deserialized_forms=[cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate], + serializable_forms=[ + cirq.GateFamily(cirq.PhasedXZGate), + cirq.GateFamily(cirq.XPowGate), + cirq.GateFamily(cirq.YPowGate), + cirq.GateFamily(cirq.PhasedXPowGate), + ], + ), + _GateRepresentations( + gate_spec_name='virtual_zpow', + deserialized_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])], + serializable_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])], + ), + _GateRepresentations( + gate_spec_name='physical_zpow', + deserialized_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])], + serializable_forms=[cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])], + ), + _GateRepresentations( + gate_spec_name='coupler_pulse', + deserialized_forms=[experimental_ops.CouplerPulse], + serializable_forms=[cirq.GateFamily(experimental_ops.CouplerPulse)], + ), + _GateRepresentations( + gate_spec_name='meas', + deserialized_forms=[cirq.MeasurementGate], + serializable_forms=[cirq.GateFamily(cirq.MeasurementGate)], + ), + _GateRepresentations( + gate_spec_name='wait', + deserialized_forms=[cirq.WaitGate], + serializable_forms=[cirq.GateFamily(cirq.WaitGate)], + ), +] + + def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: """Raises a ValueError if the `DeviceSpecification` proto is invalid.""" @@ -93,7 +167,6 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> qubit_set.add(q_name) for target_set in proto.valid_targets: - # Check for unknown qubits in targets. for target in target_set.targets: for target_id in target.ids: @@ -120,41 +193,63 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> raise ValueError("Invalid DeviceSpecification: target_ordering cannot be ASYMMETRIC.") -def _build_gateset_and_gate_durations( +def _serialize_gateset_and_gate_durations( + out: v2.device_pb2.DeviceSpecification, + gateset: cirq.Gateset, + gate_durations: Mapping[cirq.GateFamily, cirq.Duration], +) -> v2.device_pb2.DeviceSpecification: + """Serializes the given gateset and gate durations to DeviceSpecification.""" + + gate_specs: Dict[str, v2.device_pb2.GateSpecification] = {} + for gate_family in gateset.gates: + gate_spec = v2.device_pb2.GateSpecification() + gate_rep = next( + (gr for gr in _GATES for gf in gr.serializable_forms if gf == gate_family), None + ) + if gate_rep is None: + raise ValueError(f'Unrecognized gate: {gate_family}.') + gate_name = gate_rep.gate_spec_name + + # Set gate + getattr(gate_spec, gate_name).SetInParent() + + # Set gate duration + gate_durations_picos = { + int(gate_durations[gf].total_picos()) + for gf in gate_rep.serializable_forms + if gf in gate_durations + } + if len(gate_durations_picos) > 1: + raise ValueError( + 'Multiple gate families in the following list exist in the gate duration dict, and ' + f'they are expected to have the same duration value: {gate_rep.serializable_forms}' + ) + elif len(gate_durations_picos) == 1: + gate_spec.gate_duration_picos = gate_durations_picos.pop() + + # GateSpecification dedup. Multiple gates or GateFamilies in the gateset could map to the + # same GateSpecification. + gate_specs[gate_name] = gate_spec + + # Sort by gate name to keep valid_gates stable. + out.valid_gates.extend(v for _, v in sorted(gate_specs.items())) + + return out + + +def _deserialize_gateset_and_gate_durations( proto: v2.device_pb2.DeviceSpecification, ) -> Tuple[cirq.Gateset, Mapping[cirq.GateFamily, cirq.Duration]]: - """Extracts gate set and gate duration information from the given DeviceSpecification proto.""" + """Deserializes gateset and gate duration from DeviceSpecification.""" - gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] + gates_list: List[GateOrFamily] = [] gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {} - # TODO(#5050) Describe how to add/remove gates. - for gate_spec in proto.valid_gates: gate_name = gate_spec.WhichOneof('gate') - cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] - - if gate_name == 'syc': - cirq_gates = [_SYC_FSIM_GATE_FAMILY] - elif gate_name == 'sqrt_iswap': - cirq_gates = [_SQRT_ISWAP_FSIM_GATE_FAMILY] - elif gate_name == 'sqrt_iswap_inv': - cirq_gates = [_SQRT_ISWAP_INV_FSIM_GATE_FAMILY] - elif gate_name == 'cz': - cirq_gates = [_CZ_FSIM_GATE_FAMILY] - elif gate_name == 'phased_xz': - cirq_gates = [cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate] - elif gate_name == 'virtual_zpow': - cirq_gates = [_VIRTUAL_ZPOW_GATE_FAMILY] - elif gate_name == 'physical_zpow': - cirq_gates = [_PHYSICAL_ZPOW_GATE_FAMILY] - elif gate_name == 'coupler_pulse': - cirq_gates = [experimental_ops.CouplerPulse] - elif gate_name == 'meas': - cirq_gates = [cirq.MeasurementGate] - elif gate_name == 'wait': - cirq_gates = [cirq.WaitGate] - else: + + gate_rep = next((gr for gr in _GATES if gr.gate_spec_name == gate_name), None) + if gate_rep is None: # coverage: ignore warnings.warn( f"The DeviceSpecification contains the gate '{gate_name}' which is not recognized" @@ -163,11 +258,8 @@ def _build_gateset_and_gate_durations( ) continue - gates_list.extend(cirq_gates) - - # TODO(#5050) Allow different gate representations of the same gate to be looked up in - # gate_durations. - for g in cirq_gates: + gates_list.extend(gate_rep.deserialized_forms) + for g in gate_rep.deserialized_forms: if not isinstance(g, cirq.GateFamily): g = cirq.GateFamily(g) gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos) @@ -316,20 +408,21 @@ class GridDevice(cirq.Device): https://github.com/quantumlib/Cirq/blob/master/cirq-google/cirq_google/api/v2/device.proto ) is the main specification for device information surfaced by the Quantum Computing Service. - Thus, this class is should be instantiated using a `DeviceSpecification` proto via the + Thus, this class should typically be instantiated using a `DeviceSpecification` proto via the `from_proto()` class method. """ def __init__(self, metadata: cirq.GridDeviceMetadata): """Creates a GridDevice object. - This constructor typically should not be used directly. Use `from_proto()` instead. + This constructor should not be used directly outside the class implementation. Use + `from_proto()` instead. """ self._metadata = metadata @classmethod def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': - """Create a `GridDevice` from a `DeviceSpecification` proto. + """Deserializes the `DeviceSpecification` to a `GridDevice`. Args: proto: The `DeviceSpecification` proto describing a Google device. @@ -357,7 +450,7 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': if len(target.ids) == 2 and ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC ] - gateset, gate_durations = _build_gateset_and_gate_durations(proto) + gateset, gate_durations = _deserialize_gateset_and_gate_durations(proto) try: metadata = cirq.GridDeviceMetadata( @@ -373,6 +466,107 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': return GridDevice(metadata) + def to_proto( + self, out: Optional[v2.device_pb2.DeviceSpecification] = None + ) -> v2.device_pb2.DeviceSpecification: + """Serializes the GridDevice to a DeviceSpecification. + + Args: + out: Optional DeviceSpecification to be populated. Fields are populated in-place. + + Returns: + The populated DeviceSpecification if out is specified, or the newly created + DeviceSpecification. + """ + qubits = self._metadata.qubit_set + unordered_pairs = [tuple(pair_set) for pair_set in self._metadata.qubit_pairs] + pairs = sorted((q0, q1) if q0 <= q1 else (q1, q0) for q0, q1 in unordered_pairs) + gateset = self._metadata.gateset + gate_durations = self._metadata.gate_durations + + if out is None: + out = v2.device_pb2.DeviceSpecification() + + # If fields are already filled (i.e. as part of the old DeviceSpecification format), leave + # them as is. Fields populated in the new format do not conflict with how they were + # populated in the old format. + # TODO(#5050) remove empty checks below once deprecated fields in DeviceSpecification are + # removed. + + if not out.valid_qubits: + known_devices.populate_qubits_in_device_proto(qubits, out) + if not out.valid_targets: + known_devices.populate_qubit_pairs_in_device_proto(pairs, out) + _serialize_gateset_and_gate_durations( + out, gateset, {} if gate_durations is None else gate_durations + ) + _validate_device_specification(out) + + return out + + @classmethod + def _from_device_information( + cls, + *, + qubit_pairs: Collection[Tuple[cirq.GridQubit, cirq.GridQubit]], + gateset: cirq.Gateset, + gate_durations: Optional[Mapping['cirq.GateFamily', 'cirq.Duration']] = None, + ) -> 'GridDevice': + """Constructs a GridDevice using the device information provided. + + EXPERIMENTAL: this method may have changes which are not backward compatible in the future. + + This is a convenience method for constructing a GridDevice given partial gateset and + gate_duration information: for every distinct gate, only one representation needs to be in + gateset and gate_duration. The remaining representations will be automatically generated. + + For example, if the input gateset contains only `cirq.PhasedXZGate`, and the input + gate_durations is `{cirq.GateFamily(cirq.PhasedXZGate): cirq.Duration(picos=3)}`, + `GridDevice.metadata.gateset` will be + + ``` + cirq.Gateset(cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate) + ``` + + and `GridDevice.metadata.gate_durations` will be + + ``` + { + cirq.GateFamily(cirq.PhasedXZGate): cirq.Duration(picos=3), + cirq.GateFamily(cirq.XPowGate): cirq.Duration(picos=3), + cirq.GateFamily(cirq.YPowGate): cirq.Duration(picos=3), + cirq.GateFamily(cirq.PhasedXPowGate): cirq.Duration(picos=3), + } + ``` + + This method reduces the complexity of constructing `GridDevice` on server side by requiring + only the bare essential device information. + + Args: + qubit_pairs: Collection of bidirectional qubit couplings available on the device. + gateset: The gate set supported by the device. + gate_durations: Optional mapping from gates supported by the device to their timing + estimates. Not every gate is required to have an associated duration. + out: If set, device information will be serialized into this DeviceSpecification. + + Raises: + ValueError: If a pair contains two identical qubits. + ValueError: If `gateset` contains invalid GridDevice gates. + ValueError: If `gate_durations` contains keys which are not in `gateset`. + ValueError: If multiple gate families in gate_durations can + represent a particular gate, but they have different durations. + """ + metadata = cirq.GridDeviceMetadata( + qubit_pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations + ) + incomplete_device = GridDevice(metadata) + # incomplete_device may have incomplete gateset and gate durations information, as described + # in the docstring. + # To generate the full gateset and gate durations, we rely on the device deserialization + # logic by first serializing then deserializing the fake device, to ensure that the + # resulting device is consistent with one that is deserialized from DeviceSpecification. + return GridDevice.from_proto(incomplete_device.to_proto()) + @property def metadata(self) -> cirq.GridDeviceMetadata: """Get metadata information for the device.""" @@ -456,104 +650,3 @@ def _from_json_dict_(cls, metadata, **kwargs): def _value_equality_values_(self): return self._metadata - - -def _set_gate_in_gate_spec( - gate_spec: v2.device_pb2.GateSpecification, gate_family: cirq.GateFamily -) -> None: - if gate_family == _SYC_GATE_FAMILY or gate_family == _SYC_FSIM_GATE_FAMILY: - gate_spec.syc.SetInParent() - elif gate_family == _SQRT_ISWAP_GATE_FAMILY or gate_family == _SQRT_ISWAP_FSIM_GATE_FAMILY: - gate_spec.sqrt_iswap.SetInParent() - elif ( - gate_family == _SQRT_ISWAP_INV_GATE_FAMILY - or gate_family == _SQRT_ISWAP_INV_FSIM_GATE_FAMILY - ): - gate_spec.sqrt_iswap_inv.SetInParent() - elif gate_family == _CZ_GATE_FAMILY or gate_family == _CZ_FSIM_GATE_FAMILY: - gate_spec.cz.SetInParent() - elif gate_family == _PHASED_XZ_GATE_FAMILY: - gate_spec.phased_xz.SetInParent() - elif gate_family == _VIRTUAL_ZPOW_GATE_FAMILY: - gate_spec.virtual_zpow.SetInParent() - elif gate_family == _PHYSICAL_ZPOW_GATE_FAMILY: - gate_spec.physical_zpow.SetInParent() - elif gate_family == _COUPLER_PULSE_GATE_FAMILY: - gate_spec.coupler_pulse.SetInParent() - elif gate_family == _MEASUREMENT_GATE_FAMILY: - gate_spec.meas.SetInParent() - elif gate_family == _WAIT_GATE_FAMILY: - gate_spec.wait.SetInParent() - else: - raise ValueError(f'Unrecognized gate {gate_family}.') - - -def _create_device_specification_proto( - *, - qubits: Collection[cirq.GridQubit], - pairs: Collection[Tuple[cirq.GridQubit, cirq.GridQubit]], - gateset: cirq.Gateset, - gate_durations: Optional[Mapping['cirq.GateFamily', 'cirq.Duration']] = None, - out: Optional[v2.device_pb2.DeviceSpecification] = None, -) -> v2.device_pb2.DeviceSpecification: - """Serializes the given device information into a DeviceSpecification proto. - - EXPERIMENTAL: DeviceSpecification serialization API may change. - - This function does not serialize a `GridDevice`. Instead, it only takes a subset of device - information sufficient to populate the `DeviceSpecification` proto. This reduces the complexity - of constructing `DeviceSpecification` and `GridDevice` on server side by requiring only the bare - essential device information. - - Args: - qubits: Collection of qubits available on the device. - pairs: Collection of bidirectional qubit couplings available on the device. - gateset: The gate set supported by the device. - gate_durations: Optional mapping from gates supported by the device to their timing - estimates. Not every gate is required to have an associated duration. - out: If set, device information will be serialized into this DeviceSpecification. - - Raises: - ValueError: If a qubit in `pairs` is not part of `qubits`. - ValueError: If a pair contains two identical qubits. - ValueError: If `gate_durations` contains keys which are not in `gateset`. - ValueError: If `gateset` contains a gate which is not recognized by DeviceSpecification. - """ - - if gate_durations is not None: - extra_gate_families = (gate_durations.keys() | gateset.gates) - gateset.gates - if extra_gate_families: - raise ValueError( - 'Gate durations contain keys which are not part of the gateset:' - f' {extra_gate_families}' - ) - - if out is None: - out = v2.device_pb2.DeviceSpecification() - - # If fields are already filled (i.e. as part of the old DeviceSpecification format), leave them - # as is. Fields populated in the new format do not conflict with how they were populated in the - # old format. - # TODO(#5050) remove empty checks below once deprecated fields in DeviceSpecification are - # removed. - - if len(out.valid_qubits) == 0: - known_devices.populate_qubits_in_device_proto(qubits, out) - - if len(out.valid_targets) == 0: - known_devices.populate_qubit_pairs_in_device_proto(pairs, out) - - gate_specs = [] - for gate_family in gateset.gates: - gate_spec = v2.device_pb2.GateSpecification() - _set_gate_in_gate_spec(gate_spec, gate_family) - if gate_durations is not None and gate_family in gate_durations: - gate_spec.gate_duration_picos = int(gate_durations[gate_family].total_picos()) - gate_specs.append(gate_spec) - - # Sort by gate name to keep valid_gates stable. - out.valid_gates.extend(sorted(gate_specs, key=lambda s: s.WhichOneof('gate'))) - - _validate_device_specification(out) - - return out diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index 029c6088edc..a7307ef2dbc 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -17,7 +17,6 @@ import unittest.mock as mock import pytest -from google.protobuf import text_format import cirq import cirq_google @@ -440,11 +439,28 @@ def test_grid_device_repr_pretty(cycle, func): printer.text.assert_called_once_with(func(device)) -def test_to_proto(): - device_info, expected_spec = _create_device_spec_with_horizontal_couplings() +def test_device_from_device_information_equals_device_from_proto(): + device_info, spec = _create_device_spec_with_horizontal_couplings() - # The set of gates in gate_durations are consistent with what's generated in + # The set of gates in gateset and gate durations are consistent with what's generated in # _create_device_spec_with_horizontal_couplings() + gateset = cirq.Gateset( + cirq_google.SYC, + cirq.SQRT_ISWAP, + cirq.SQRT_ISWAP_INV, + cirq.CZ, + cirq.ops.phased_x_z_gate.PhasedXZGate, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] + ), + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()] + ), + cirq_google.experimental.ops.coupler_pulse.CouplerPulse, + cirq.ops.measurement_gate.MeasurementGate, + cirq.ops.wait_gate.WaitGate, + ) + base_duration = cirq.Duration(picos=1_000) gate_durations = { cirq.GateFamily(cirq_google.SYC): base_duration * 0, @@ -465,69 +481,54 @@ def test_to_proto(): cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9, } - spec = grid_device._create_device_specification_proto( - qubits=device_info.grid_qubits, - pairs=device_info.qubit_pairs, - gateset=cirq.Gateset(*gate_durations.keys()), - gate_durations=gate_durations, + device_from_information = cirq_google.GridDevice._from_device_information( + qubit_pairs=device_info.qubit_pairs, gateset=gateset, gate_durations=gate_durations ) - assert text_format.MessageToString(spec) == text_format.MessageToString(expected_spec) + assert device_from_information == cirq_google.GridDevice.from_proto(spec) @pytest.mark.parametrize( - 'error_match, qubits, qubit_pairs, gateset, gate_durations', + 'error_match, qubit_pairs, gateset, gate_durations', [ ( - 'Gate durations contain keys which are not part of the gateset', - [cirq.GridQubit(0, 0)], - [], - cirq.Gateset(cirq.CZ), - {cirq.GateFamily(cirq.SQRT_ISWAP): 1_000}, + 'Self loop encountered in qubit', + [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 0))], + cirq.Gateset(), + None, ), - ('not in the GridQubit form', [cirq.NamedQubit('q0_0')], [], cirq.Gateset(), None), ( - 'valid_targets contain .* which is not in valid_qubits', - [cirq.GridQubit(0, 0)], + 'Unrecognized gate', [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))], - cirq.Gateset(), + cirq.Gateset(cirq.H), None, ), ( - 'has a target which contains repeated qubits', - [cirq.GridQubit(0, 0)], - [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 0))], - cirq.Gateset(), - None, + 'Some gate_durations keys are not found in gateset', + [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))], + cirq.Gateset(cirq.CZ), + {cirq.GateFamily(cirq.SQRT_ISWAP): cirq.Duration(picos=1_000)}, + ), + ( + 'Multiple gate families .* expected to have the same duration value', + [(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))], + cirq.Gateset(cirq.PhasedXZGate, cirq.XPowGate), + { + cirq.GateFamily(cirq.PhasedXZGate): cirq.Duration(picos=1_000), + cirq.GateFamily(cirq.XPowGate): cirq.Duration(picos=2_000), + }, ), - ('Unrecognized gate', [cirq.GridQubit(0, 0)], [], cirq.Gateset(cirq.H), None), ], ) -def test_to_proto_invalid_input(error_match, qubits, qubit_pairs, gateset, gate_durations): +def test_from_device_information_invalid_input(error_match, qubit_pairs, gateset, gate_durations): with pytest.raises(ValueError, match=error_match): - grid_device._create_device_specification_proto( - qubits=qubits, pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations + grid_device.GridDevice._from_device_information( + qubit_pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations ) -def test_to_proto_empty(): - spec = grid_device._create_device_specification_proto( - # Qubits are always expected to be set - qubits=[cirq.GridQubit(0, i) for i in range(5)], - pairs=[], - gateset=cirq.Gateset(), - gate_durations=None, - ) - device = cirq_google.GridDevice.from_proto(spec) - - assert len(device.metadata.qubit_set) == 5 - assert len(device.metadata.qubit_pairs) == 0 - assert device.metadata.gateset == cirq.Gateset() - assert device.metadata.gate_durations is None - - -def test_to_proto_fsim_gate_family(): - """Verifies that FSimGateFamilies are serialized correctly.""" +def test_from_device_information_fsim_gate_family(): + """Verifies that FSimGateFamilies are recognized correctly.""" gateset = cirq.Gateset( cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]), @@ -536,11 +537,55 @@ def test_to_proto_fsim_gate_family(): cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]), ) - spec = grid_device._create_device_specification_proto( - qubits=[cirq.GridQubit(0, 0)], pairs=(), gateset=gateset + device = grid_device.GridDevice._from_device_information( + qubit_pairs=[(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))], gateset=gateset + ) + + assert gateset.gates.issubset(device.metadata.gateset.gates) + + +def test_from_device_information_empty(): + device = grid_device.GridDevice._from_device_information( + qubit_pairs=[], gateset=cirq.Gateset(), gate_durations=None ) - assert any(gate_spec.HasField('syc') for gate_spec in spec.valid_gates) - assert any(gate_spec.HasField('sqrt_iswap') for gate_spec in spec.valid_gates) - assert any(gate_spec.HasField('sqrt_iswap_inv') for gate_spec in spec.valid_gates) - assert any(gate_spec.HasField('cz') for gate_spec in spec.valid_gates) + assert len(device.metadata.qubit_set) == 0 + assert len(device.metadata.qubit_pairs) == 0 + assert device.metadata.gateset == cirq.Gateset() + assert device.metadata.gate_durations is None + + +def test_to_proto(): + device_info, expected_spec = _create_device_spec_with_horizontal_couplings() + + # The set of gates in gate_durations are consistent with what's generated in + # _create_device_spec_with_horizontal_couplings() + base_duration = cirq.Duration(picos=1_000) + gate_durations = { + cirq.GateFamily(cirq_google.SYC): base_duration * 0, + cirq.GateFamily(cirq.SQRT_ISWAP): base_duration * 1, + cirq.GateFamily(cirq.SQRT_ISWAP_INV): base_duration * 2, + cirq.GateFamily(cirq.CZ): base_duration * 3, + cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] + ): base_duration + * 5, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()] + ): base_duration + * 6, + cirq.GateFamily(cirq_google.experimental.ops.coupler_pulse.CouplerPulse): base_duration * 7, + cirq.GateFamily(cirq.ops.measurement_gate.MeasurementGate): base_duration * 8, + cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9, + } + + spec = cirq_google.GridDevice._from_device_information( + qubit_pairs=device_info.qubit_pairs, + gateset=cirq.Gateset(*gate_durations.keys()), + gate_durations=gate_durations, + ).to_proto() + + assert cirq_google.GridDevice.from_proto(spec) == cirq_google.GridDevice.from_proto( + expected_spec + ) diff --git a/cirq-google/cirq_google/devices/known_devices.py b/cirq-google/cirq_google/devices/known_devices.py index c4f1b58b875..9ffe4d01336 100644 --- a/cirq-google/cirq_google/devices/known_devices.py +++ b/cirq-google/cirq_google/devices/known_devices.py @@ -57,7 +57,6 @@ def _create_grid_device_from_diagram( ascii_grid: str, gateset: cirq.Gateset, gate_durations: Optional[Dict['cirq.GateFamily', 'cirq.Duration']] = None, - out: Optional[device_pb2.DeviceSpecification] = None, ) -> grid_device.GridDevice: """Parse ASCIIart device layout into a GridDevice instance. @@ -80,10 +79,9 @@ def _create_grid_device_from_diagram( if neighbor > qubit and neighbor in qubit_set: pairs.append((qubit, cast(cirq.GridQubit, neighbor))) - device_specification = grid_device._create_device_specification_proto( - qubits=qubits, pairs=pairs, gateset=gateset, gate_durations=gate_durations, out=out + return grid_device.GridDevice._from_device_information( + qubit_pairs=pairs, gateset=gateset, gate_durations=gate_durations ) - return grid_device.GridDevice.from_proto(device_specification) def populate_qubits_in_device_proto(