Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions cirq-google/cirq_google/devices/serializable_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Device object for converting from device specification protos"""

from typing import Any, Callable, cast, Dict, Iterable, Optional, List, Set, Tuple, Type
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import cirq
from cirq_google.serialization import serializable_gate_set
from cirq_google.api import v2
Expand Down Expand Up @@ -63,6 +63,9 @@ def __eq__(self, other):
return self.__dict__ == other.__dict__


_GateOrFrozenCircuitTypes = Union[Type[cirq.Gate], Type[cirq.FrozenCircuit]]


class SerializableDevice(cirq.Device):
"""Device object generated from a device specification proto.

Expand All @@ -79,7 +82,9 @@ class SerializableDevice(cirq.Device):
"""

def __init__(
self, qubits: List[cirq.Qid], gate_definitions: Dict[Type[cirq.Gate], List[_GateDefinition]]
self,
qubits: List[cirq.Qid],
gate_definitions: Dict[_GateOrFrozenCircuitTypes, List[_GateDefinition]],
):
"""Constructor for SerializableDevice using python objects.

Expand All @@ -93,6 +98,7 @@ def __init__(
"""
self.qubits = qubits
self.gate_definitions = gate_definitions
has_subcircuit_support: bool = cirq.FrozenCircuit in gate_definitions
self._metadata = cirq.GridDeviceMetadata(
qubit_pairs=[
(pair[0], pair[1])
Expand All @@ -103,12 +109,9 @@ def __init__(
if len(pair) == 2 and pair[0] < pair[1]
],
gateset=cirq.Gateset(
*[
g
for g in gate_definitions.keys()
if isinstance(g, (cirq.Gate, type(cirq.Gate)))
],
*(g for g in gate_definitions.keys() if issubclass(g, cirq.Gate)),
cirq.GlobalPhaseGate,
unroll_circuit_op=has_subcircuit_support,
),
gate_durations=None,
)
Expand Down Expand Up @@ -174,7 +177,7 @@ def from_proto(
)

# Loop through serializers and map gate_definitions to type
gates_by_type: Dict[Type[cirq.Gate], List[_GateDefinition]] = {}
gates_by_type: Dict[_GateOrFrozenCircuitTypes, List[_GateDefinition]] = {}
for gate_set in gate_sets:
for internal_type in gate_set.supported_internal_types():
for serializer in gate_set.serializers[internal_type]:
Expand Down
10 changes: 10 additions & 0 deletions cirq-google/cirq_google/devices/serializable_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,16 @@ def test_serializable_device_str_named_qubits():
assert device.__class__.__name__ in str(device)


def test_serializable_device_gate_definitions_filter():
"""Ignore items in gate_definitions dictionary with invalid keys."""
device = cg.SerializableDevice(
qubits=[cirq.NamedQubit('a'), cirq.NamedQubit('b')],
gate_definitions={cirq.FSimGate: [], cirq.NoiseModel: []},
)
# Two gates for cirq.FSimGate and the cirq.GlobalPhaseGate default
assert len(device.metadata.gateset.gates) == 2


def test_sycamore23_str():
assert (
str(cg.Sycamore23)
Expand Down
11 changes: 7 additions & 4 deletions cirq-google/cirq_google/serialization/serializable_gate_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from cirq_google.serialization import serializer, op_deserializer, op_serializer, arg_func_langs


_GateOrFrozenCircuitTypes = Union[Type[cirq.Gate], Type[cirq.FrozenCircuit]]


class SerializableGateSet(serializer.Serializer):
"""A class for serializing and deserializing programs and operations.

Expand All @@ -47,7 +50,7 @@ def __init__(
forms of gates or circuits into Operations.
"""
super().__init__(gate_set_name)
self.serializers: Dict[Type, List[op_serializer.OpSerializer]] = {}
self.serializers: Dict[_GateOrFrozenCircuitTypes, List[op_serializer.OpSerializer]] = {}
for s in serializers:
self.serializers.setdefault(s.internal_type, []).append(s)
self.deserializers = {d.serialized_id: d for d in deserializers}
Expand Down Expand Up @@ -77,7 +80,7 @@ def with_added_types(
deserializers=[*self.deserializers.values(), *deserializers],
)

def supported_internal_types(self) -> Tuple:
def supported_internal_types(self) -> Tuple[_GateOrFrozenCircuitTypes, ...]:
return tuple(self.serializers.keys())

def is_supported(self, op_tree: cirq.OP_TREE) -> bool:
Expand Down Expand Up @@ -194,8 +197,8 @@ def serialize_gate_op(
if gate_type_mro in self.serializers:
# Check each serializer in turn, if serializer proto returns
# None, then skip.
for serializer in self.serializers[gate_type_mro]:
proto_msg = serializer.to_proto(
for mro_serializer in self.serializers[gate_type_mro]:
proto_msg = mro_serializer.to_proto(
op,
msg,
arg_function_language=arg_function_language,
Expand Down