|
18 | 18 |
|
19 | 19 | from typing import Any, Dict, Iterator, Optional, Tuple, TYPE_CHECKING |
20 | 20 |
|
| 21 | +import google.protobuf.json_format as json_format |
21 | 22 | from cirq import devices, vis |
22 | 23 | from cirq.google.api import v2 |
23 | 24 |
|
@@ -45,14 +46,25 @@ class Calibration(abc.Mapping): |
45 | 46 |
|
46 | 47 | `calibration['t1']` |
47 | 48 |
|
| 49 | + This class can be instantiated either from a `MetricsSnapshot` proto |
| 50 | + or from a dictionary of metric values. |
| 51 | +
|
48 | 52 | Attributes: |
49 | 53 | timestamp: The time that this calibration was run, in milliseconds since |
50 | 54 | the epoch. |
51 | 55 | """ |
52 | 56 |
|
53 | | - def __init__(self, calibration: v2.metrics_pb2.MetricsSnapshot) -> None: |
| 57 | + def __init__(self, |
| 58 | + calibration: v2.metrics_pb2.MetricsSnapshot = v2.metrics_pb2. |
| 59 | + MetricsSnapshot(), |
| 60 | + metrics: Optional[ |
| 61 | + Dict[str, Dict[Tuple['cirq.GridQubit', ...], Any]]] = None |
| 62 | + ) -> None: |
54 | 63 | self.timestamp = calibration.timestamp_ms |
55 | | - self._metric_dict = self._compute_metric_dict(calibration.metrics) |
| 64 | + if metrics is None: |
| 65 | + self._metric_dict = self._compute_metric_dict(calibration.metrics) |
| 66 | + else: |
| 67 | + self._metric_dict = metrics |
56 | 68 |
|
57 | 69 | def _compute_metric_dict( |
58 | 70 | self, metrics: v2.metrics_pb2.MetricsSnapshot |
@@ -103,8 +115,47 @@ def __len__(self) -> int: |
103 | 115 | return len(self._metric_dict) |
104 | 116 |
|
105 | 117 | def __str__(self) -> str: |
106 | | - |
107 | | - return 'Calibration(keys={})'.format(list(sorted(self.keys()))) |
| 118 | + return f'Calibration(keys={list(sorted(self.keys()))})' |
| 119 | + |
| 120 | + def __repr__(self) -> str: |
| 121 | + return ('cirq.google.Calibration(metrics=' |
| 122 | + f'{repr(dict(self._metric_dict))})') |
| 123 | + |
| 124 | + def to_proto(self) -> v2.metrics_pb2.MetricsSnapshot: |
| 125 | + """Reconstruct the protobuf message represented by this class.""" |
| 126 | + proto = v2.metrics_pb2.MetricsSnapshot() |
| 127 | + for key in self._metric_dict: |
| 128 | + for target, value_list in self._metric_dict[key].items(): |
| 129 | + current_metric = proto.metrics.add() |
| 130 | + current_metric.name = key |
| 131 | + current_metric.targets.extend( |
| 132 | + [v2.qubit_to_proto_id(q) for q in target]) |
| 133 | + for value in value_list: |
| 134 | + current_value = current_metric.values.add() |
| 135 | + if isinstance(value, float): |
| 136 | + current_value.double_val = value |
| 137 | + elif isinstance(value, int): |
| 138 | + current_value.int64_val = value |
| 139 | + elif isinstance(value, str): |
| 140 | + current_value.str_val = value |
| 141 | + else: |
| 142 | + raise ValueError(f'Unsupported metric value {value}. ' |
| 143 | + 'Must be int, float, or str to ' |
| 144 | + 'convert to proto.') |
| 145 | + return proto |
| 146 | + |
| 147 | + @classmethod |
| 148 | + def _from_json_dict_(cls, metrics: str, **kwargs) -> 'Calibration': |
| 149 | + """Magic method for the JSON serialization protocol.""" |
| 150 | + metric_proto = v2.metrics_pb2.MetricsSnapshot() |
| 151 | + return cls(json_format.ParseDict(metrics, metric_proto)) |
| 152 | + |
| 153 | + def _json_dict_(self) -> Dict[str, Any]: |
| 154 | + """Magic method for the JSON serialization protocol.""" |
| 155 | + return { |
| 156 | + 'cirq_type': 'Calibration', |
| 157 | + 'metrics': json_format.MessageToDict(self.to_proto()) |
| 158 | + } |
108 | 159 |
|
109 | 160 | def timestamp_str(self, |
110 | 161 | tz: Optional[datetime.tzinfo] = None, |
|
0 commit comments