diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index a82f106526d..3aeace7669d 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -83,6 +83,7 @@ ConstantQubitNoiseModel, Device, DeviceMetadata, + GateDurationTable, GridDeviceMetadata, GridQid, GridQubit, diff --git a/cirq-core/cirq/devices/__init__.py b/cirq-core/cirq/devices/__init__.py index 7ff622946a1..f16b22e4b3f 100644 --- a/cirq-core/cirq/devices/__init__.py +++ b/cirq-core/cirq/devices/__init__.py @@ -16,7 +16,7 @@ from cirq.devices.device import Device, DeviceMetadata -from cirq.devices.grid_device_metadata import GridDeviceMetadata +from cirq.devices.grid_device_metadata import GateDurationTable, GridDeviceMetadata from cirq.devices.grid_qubit import GridQid, GridQubit diff --git a/cirq-core/cirq/devices/grid_device_metadata.py b/cirq-core/cirq/devices/grid_device_metadata.py index 00da05e27a9..7a64bd9f4c8 100644 --- a/cirq-core/cirq/devices/grid_device_metadata.py +++ b/cirq-core/cirq/devices/grid_device_metadata.py @@ -13,7 +13,7 @@ # limitations under the License. """Metadata subtype for 2D Homogenous devices.""" -from typing import TYPE_CHECKING, Optional, FrozenSet, Iterable, Tuple, Dict +from typing import TYPE_CHECKING, Optional, FrozenSet, Iterable, Tuple, Dict, Union, Type import networkx as nx from cirq import value @@ -23,6 +23,48 @@ import cirq +class GateDurationTable: + """A lookup table for gate durations. + + This class allows different instances of a `cirq.GateFamily` to be used as a key for duration + lookup, as opposed to using the GateFamily as the key in a simple dictionary format. + + For example: + >>> gdt = cirq.GateDurationTable({ + cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1), + }) + >>> gdt[cirq.X] # Prints `cirq.Duration(nanos=1)` + >>> gdt[cirq.X**0.25] # Prints `cirq.Duration(nanos=1)` + """ + + def __init__(self, gate_durations: Dict['cirq.GateFamily', 'cirq.Duration']) -> None: + self.gate_durations = gate_durations + + def __getitem__( + self, gate: Union['cirq.Gate', Type['cirq.Gate'], 'cirq.GateFamily'] + ) -> 'cirq.Duration': + + found_duration = None + + for gf, duration in self.gate_durations.items(): + if gate in gf or gate == gf: + if found_duration is not None: + # TODO(verult) Include overlapping gatefamilies in error message + raise RuntimeError( + "The given gate matches multiple durations. This may be due to an overlap" + " of GateFamilies provided by the device." + ) + found_duration = duration + + if found_duration is None: + raise KeyError("Gate not found in gate duration table.") + + return found_duration + + def __str__(self): + return str(self.gate_durations) + + @value.value_equality class GridDeviceMetadata(device.DeviceMetadata): """Hardware metadata for homogenous 2d symmetric grid devices.""" @@ -110,8 +152,9 @@ def __init__( f" gate_durations={gate_durations}" f" gateset.gates={gateset.gates}" ) - - self._gate_durations = gate_durations + self._gate_durations = GateDurationTable(gate_durations) + else: + self._gate_durations = None @property def qubit_pairs(self) -> FrozenSet[FrozenSet['cirq.Qid']]: @@ -138,7 +181,7 @@ def compilation_target_gatesets(self) -> Tuple['cirq.CompilationTargetGateset', return self._compilation_target_gatesets @property - def gate_durations(self) -> Optional[Dict['cirq.GateFamily', 'cirq.Duration']]: + def gate_durations(self) -> Optional[GateDurationTable]: """Get a dictionary mapping from gateset to duration for gates.""" return self._gate_durations