Skip to content

Commit c425c8a

Browse files
authored
Add JSON serialization for Calibration metrics (#3508)
* Add JSON serialization for Calibration metrics - Add json_dict and from_json_dict for serialization of Calibration objects using the underlying proto representation. - Add a valid repr for Calibration objects.
1 parent 0693f79 commit c425c8a

File tree

6 files changed

+157
-5
lines changed

6 files changed

+157
-5
lines changed

cirq/google/engine/calibration.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from typing import Any, Dict, Iterator, Optional, Tuple, TYPE_CHECKING
2020

21+
import google.protobuf.json_format as json_format
2122
from cirq import devices, vis
2223
from cirq.google.api import v2
2324

@@ -45,14 +46,25 @@ class Calibration(abc.Mapping):
4546
4647
`calibration['t1']`
4748
49+
This class can be instantiated either from a `MetricsSnapshot` proto
50+
or from a dictionary of metric values.
51+
4852
Attributes:
4953
timestamp: The time that this calibration was run, in milliseconds since
5054
the epoch.
5155
"""
5256

53-
def __init__(self, calibration: v2.metrics_pb2.MetricsSnapshot) -> None:
57+
def __init__(self,
58+
calibration: v2.metrics_pb2.MetricsSnapshot = v2.metrics_pb2.
59+
MetricsSnapshot(),
60+
metrics: Optional[
61+
Dict[str, Dict[Tuple['cirq.GridQubit', ...], Any]]] = None
62+
) -> None:
5463
self.timestamp = calibration.timestamp_ms
55-
self._metric_dict = self._compute_metric_dict(calibration.metrics)
64+
if metrics is None:
65+
self._metric_dict = self._compute_metric_dict(calibration.metrics)
66+
else:
67+
self._metric_dict = metrics
5668

5769
def _compute_metric_dict(
5870
self, metrics: v2.metrics_pb2.MetricsSnapshot
@@ -103,8 +115,47 @@ def __len__(self) -> int:
103115
return len(self._metric_dict)
104116

105117
def __str__(self) -> str:
106-
107-
return 'Calibration(keys={})'.format(list(sorted(self.keys())))
118+
return f'Calibration(keys={list(sorted(self.keys()))})'
119+
120+
def __repr__(self) -> str:
121+
return ('cirq.google.Calibration(metrics='
122+
f'{repr(dict(self._metric_dict))})')
123+
124+
def to_proto(self) -> v2.metrics_pb2.MetricsSnapshot:
125+
"""Reconstruct the protobuf message represented by this class."""
126+
proto = v2.metrics_pb2.MetricsSnapshot()
127+
for key in self._metric_dict:
128+
for target, value_list in self._metric_dict[key].items():
129+
current_metric = proto.metrics.add()
130+
current_metric.name = key
131+
current_metric.targets.extend(
132+
[v2.qubit_to_proto_id(q) for q in target])
133+
for value in value_list:
134+
current_value = current_metric.values.add()
135+
if isinstance(value, float):
136+
current_value.double_val = value
137+
elif isinstance(value, int):
138+
current_value.int64_val = value
139+
elif isinstance(value, str):
140+
current_value.str_val = value
141+
else:
142+
raise ValueError(f'Unsupported metric value {value}. '
143+
'Must be int, float, or str to '
144+
'convert to proto.')
145+
return proto
146+
147+
@classmethod
148+
def _from_json_dict_(cls, metrics: str, **kwargs) -> 'Calibration':
149+
"""Magic method for the JSON serialization protocol."""
150+
metric_proto = v2.metrics_pb2.MetricsSnapshot()
151+
return cls(json_format.ParseDict(metrics, metric_proto))
152+
153+
def _json_dict_(self) -> Dict[str, Any]:
154+
"""Magic method for the JSON serialization protocol."""
155+
return {
156+
'cirq_type': 'Calibration',
157+
'metrics': json_format.MessageToDict(self.to_proto())
158+
}
108159

109160
def timestamp_str(self,
110161
tz: Optional[datetime.tzinfo] = None,

cirq/google/engine/calibration_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ def test_calibration_str():
9494
"'xeb'])")
9595

9696

97+
def test_calibration_repr():
98+
calibration = cg.Calibration(_CALIBRATION_DATA)
99+
cirq.testing.assert_equivalent_repr(calibration)
100+
101+
97102
def test_calibration_timestamp_str():
98103
calibration = cg.Calibration(_CALIBRATION_DATA)
99104
assert (calibration.timestamp_str(
@@ -103,6 +108,17 @@ def test_calibration_timestamp_str():
103108
hours=1))) == '2019-07-08 01:00:00.021021+01:00')
104109

105110

111+
def test_to_proto():
112+
calibration = cg.Calibration(_CALIBRATION_DATA)
113+
assert calibration == cg.Calibration(calibration.to_proto())
114+
invalid_value = cg.Calibration(
115+
metrics={'metric': {
116+
(cirq.GridQubit(1, 1),): [1.1, {}]
117+
}})
118+
with pytest.raises(ValueError, match='Unsupported metric value'):
119+
invalid_value.to_proto()
120+
121+
106122
def test_calibration_heatmap():
107123
calibration = cg.Calibration(_CALIBRATION_DATA)
108124

cirq/protocols/json_serialization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def two_qubit_matrix_gate(matrix):
8686
'CCXPowGate': cirq.CCXPowGate,
8787
'CCZPowGate': cirq.CCZPowGate,
8888
'CNotPowGate': cirq.CNotPowGate,
89+
'Calibration': cirq.google.Calibration,
8990
'CalibrationLayer': cirq.google.CalibrationLayer,
9091
'CalibrationResult': cirq.google.CalibrationResult,
9192
'CalibrationTag': cirq.google.CalibrationTag,

cirq/protocols/json_serialization_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,6 @@ def test_mutually_exclusive_blacklist():
276276
NOT_YET_SERIALIZABLE = [
277277
'AsymmetricDepolarizingChannel',
278278
'AxisAngleDecomposition',
279-
'Calibration',
280279
'CalibrationLayer',
281280
'CalibrationResult',
282281
'CircuitDag',
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
{
2+
"cirq_type": "Calibration",
3+
"metrics": {
4+
"metrics": [
5+
{
6+
"name": "xeb",
7+
"targets": [
8+
"0_0",
9+
"0_1"
10+
],
11+
"values": [
12+
{
13+
"doubleVal": 0.9999
14+
}
15+
]
16+
},
17+
{
18+
"name": "xeb",
19+
"targets": [
20+
"0_0",
21+
"1_0"
22+
],
23+
"values": [
24+
{
25+
"doubleVal": 0.9998
26+
}
27+
]
28+
},
29+
{
30+
"name": "t1",
31+
"targets": [
32+
"0_0"
33+
],
34+
"values": [
35+
{
36+
"int64Val": "321"
37+
}
38+
]
39+
},
40+
{
41+
"name": "t1",
42+
"targets": [
43+
"0_1"
44+
],
45+
"values": [
46+
{
47+
"int64Val": "911"
48+
}
49+
]
50+
},
51+
{
52+
"name": "t1",
53+
"targets": [
54+
"1_0"
55+
],
56+
"values": [
57+
{
58+
"int64Val": "505"
59+
}
60+
]
61+
},
62+
{
63+
"name": "globalMetric",
64+
"values": [
65+
{
66+
"strVal": "abcd"
67+
}
68+
]
69+
}
70+
]
71+
}
72+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
cirq.google.Calibration(
2+
metrics={
3+
'xeb': {
4+
(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)): [0.9999],
5+
(cirq.GridQubit(0, 0), cirq.GridQubit(1, 0)): [0.9998]
6+
},
7+
't1': {
8+
(cirq.GridQubit(0, 0),): [321],
9+
(cirq.GridQubit(0, 1),): [911],
10+
(cirq.GridQubit(1, 0),): [505]},
11+
'globalMetric': {(): ['abcd']}
12+
}
13+
)

0 commit comments

Comments
 (0)