Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
ConstantQubitNoiseModel,
Device,
DeviceMetadata,
GateDurationTable,
GridDeviceMetadata,
GridQid,
GridQubit,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from cirq.devices.device import Device, DeviceMetadata

from cirq.devices.grid_device_metadata import GridDeviceMetadata
from cirq.devices.grid_device_metadata import GateDurationTable, GridDeviceMetadata

from cirq.devices.grid_qubit import GridQid, GridQubit

Expand Down
51 changes: 47 additions & 4 deletions cirq-core/cirq/devices/grid_device_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Metadata subtype for 2D Homogenous devices."""

from typing import TYPE_CHECKING, Optional, FrozenSet, Iterable, Tuple, Dict
from typing import TYPE_CHECKING, Optional, FrozenSet, Iterable, Tuple, Dict, Union, Type

import networkx as nx
from cirq import value
Expand All @@ -23,6 +23,48 @@
import cirq


class GateDurationTable:
"""A lookup table for gate durations.

This class allows different instances of a `cirq.GateFamily` to be used as a key for duration
lookup, as opposed to using the GateFamily as the key in a simple dictionary format.

For example:
>>> gdt = cirq.GateDurationTable({
cirq.GateFamily(cirq.XPowGate): cirq.Duration(nanos=1),
})
>>> gdt[cirq.X] # Prints `cirq.Duration(nanos=1)`
>>> gdt[cirq.X**0.25] # Prints `cirq.Duration(nanos=1)`
"""

def __init__(self, gate_durations: Dict['cirq.GateFamily', 'cirq.Duration']) -> None:
self.gate_durations = gate_durations

def __getitem__(
self, gate: Union['cirq.Gate', Type['cirq.Gate'], 'cirq.GateFamily']
) -> 'cirq.Duration':

found_duration = None

for gf, duration in self.gate_durations.items():
if gate in gf or gate == gf:
if found_duration is not None:
# TODO(verult) Include overlapping gatefamilies in error message
raise RuntimeError(
"The given gate matches multiple durations. This may be due to an overlap"
" of GateFamilies provided by the device."
)
found_duration = duration

if found_duration is None:
raise KeyError("Gate not found in gate duration table.")

return found_duration

def __str__(self):
return str(self.gate_durations)


@value.value_equality
class GridDeviceMetadata(device.DeviceMetadata):
"""Hardware metadata for homogenous 2d symmetric grid devices."""
Expand Down Expand Up @@ -110,8 +152,9 @@ def __init__(
f" gate_durations={gate_durations}"
f" gateset.gates={gateset.gates}"
)

self._gate_durations = gate_durations
self._gate_durations = GateDurationTable(gate_durations)
else:
self._gate_durations = None

@property
def qubit_pairs(self) -> FrozenSet[FrozenSet['cirq.Qid']]:
Expand All @@ -138,7 +181,7 @@ def compilation_target_gatesets(self) -> Tuple['cirq.CompilationTargetGateset',
return self._compilation_target_gatesets

@property
def gate_durations(self) -> Optional[Dict['cirq.GateFamily', 'cirq.Duration']]:
def gate_durations(self) -> Optional[GateDurationTable]:
"""Get a dictionary mapping from gateset to duration for gates."""
return self._gate_durations

Expand Down