-
Notifications
You must be signed in to change notification settings - Fork 975
/
json_serialization_test.py
818 lines (680 loc) · 26.7 KB
/
json_serialization_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import dataclasses
import datetime
import importlib
import io
import json
import os
import pathlib
import sys
import warnings
from typing import Dict, List, Optional, Tuple
from unittest import mock
import numpy as np
import pandas as pd
import pytest
import sympy
import cirq
from cirq._compat import proper_eq
from cirq.protocols import json_serialization
from cirq.testing import assert_json_roundtrip_works
from cirq.testing.json import ModuleJsonTestSpec, spec_for
# Presumed repository root: the fourth ancestor of this test file's path. Test-data
# paths are reported relative to it in failure messages.
REPO_ROOT = pathlib.Path(__file__).parent.parent.parent.parent


@dataclasses.dataclass()
class _ModuleDeprecation:
    """Pairs a module's former import path with the assertion expected when it is used."""

    # Old, deprecated import path of the module, e.g. "cirq.aqt".
    old_name: str
    # Context manager entered while evaluating repr data that mentions `old_name`.
    deprecation_assertion: contextlib.AbstractContextManager
# Modules whose JSON test data is checked here, mapped to how their former
# `cirq.<name>` import path is deprecated (None when there is no such alias).
TESTED_MODULES: Dict[str, Optional[_ModuleDeprecation]] = {
    f'cirq_{vendor}': _ModuleDeprecation(
        old_name=f"cirq.{vendor}",
        deprecation_assertion=cirq.testing.assert_deprecated(
            f"cirq.{vendor}", deadline="v0.14", count=None
        ),
    )
    for vendor in ('aqt', 'ionq', 'google', 'pasqal')
}
# Modules without a deprecated alias. 'non_existent_should_be_fine' is not an
# importable module; it is listed so that missing modules are shown to be skipped.
TESTED_MODULES.update(
    {
        'cirq_rigetti': None,
        'cirq.protocols': None,
        'non_existent_should_be_fine': None,
    }
)

# pyQuil 3.0, which the cirq_rigetti module depends on, requires Python >= 3.7.
if sys.version_info < (3, 7):  # pragma: no cover
    TESTED_MODULES.pop('cirq_rigetti')
def _get_testspecs_for_modules():
    """Return the JSON test spec of every module in TESTED_MODULES that can be found.

    Modules raising ModuleNotFoundError (optional or non-existent ones) are skipped.
    """
    specs = []
    for module_name in TESTED_MODULES:
        try:
            spec = spec_for(module_name)
        except ModuleNotFoundError:
            # Optional modules that are not installed are simply left out.
            continue
        specs.append(spec)
    return specs


MODULE_TEST_SPECS = _get_testspecs_for_modules()
def test_line_qubit_roundtrip():
    """A LineQubit serializes to its type tag and `x` coordinate, and round-trips.

    The expected text carries the two-space indentation that `cirq.to_json`
    pretty-prints with by default, so the exact-text comparison can match.
    """
    q1 = cirq.LineQubit(12)
    assert_json_roundtrip_works(
        q1,
        text_should_be="""{
  "cirq_type": "LineQubit",
  "x": 12
}""",
    )
def test_gridqubit_roundtrip():
    """A GridQubit serializes to its type tag, row and column, and round-trips.

    The expected text carries the two-space indentation that `cirq.to_json`
    pretty-prints with by default, so the exact-text comparison can match.
    """
    q = cirq.GridQubit(15, 18)
    assert_json_roundtrip_works(
        q,
        text_should_be="""{
  "cirq_type": "GridQubit",
  "row": 15,
  "col": 18
}""",
    )
def test_op_roundtrip():
    """A GateOperation serializes its gate and qubit list as nested JSON, and round-trips.

    Nested objects are indented two further spaces per level, matching the
    default pretty-printing of `cirq.to_json`.
    """
    q = cirq.LineQubit(5)
    op1 = cirq.rx(0.123).on(q)
    assert_json_roundtrip_works(
        op1,
        text_should_be="""{
  "cirq_type": "GateOperation",
  "gate": {
    "cirq_type": "Rx",
    "rads": 0.123
  },
  "qubits": [
    {
      "cirq_type": "LineQubit",
      "x": 5
    }
  ]
}""",
    )
def test_op_roundtrip_filename(tmpdir):
    """Operations round-trip through plain and gzipped JSON files addressed by filename."""
    operation = cirq.rx(0.123).on(cirq.LineQubit(5))

    json_file = f'{tmpdir}/op.json'
    cirq.to_json(operation, json_file)
    assert os.path.exists(json_file)
    assert operation == cirq.read_json(json_file)

    gzip_file = f'{tmpdir}/op.gz'
    cirq.to_json_gzip(operation, gzip_file)
    assert os.path.exists(gzip_file)
    assert operation == cirq.read_json_gzip(gzip_file)
def test_op_roundtrip_file_obj(tmpdir):
    """Operations round-trip through already-open text and binary file objects."""
    operation = cirq.rx(0.123).on(cirq.LineQubit(5))

    json_filename = f'{tmpdir}/op.json'
    with open(json_filename, 'w+') as text_handle:
        cirq.to_json(operation, text_handle)
        assert os.path.exists(json_filename)
        # Rewind so the reader starts at the data that was just written.
        text_handle.seek(0)
        assert operation == cirq.read_json(text_handle)

    gzip_filename = f'{tmpdir}/op.gz'
    with open(gzip_filename, 'w+b') as binary_handle:
        cirq.to_json_gzip(operation, binary_handle)
        assert os.path.exists(gzip_filename)
        binary_handle.seek(0)
        assert operation == cirq.read_json_gzip(binary_handle)
def test_fail_to_resolve():
    """Reading JSON whose cirq_type no resolver knows raises a descriptive ValueError."""
    # StringIO initialized with text starts positioned at offset 0, ready to read.
    unknown_type_json = io.StringIO(
        """
{
"cirq_type": "MyCustomClass",
"data": [1, 2, 3]
}
"""
    )
    with pytest.raises(
        ValueError, match="Could not resolve type 'MyCustomClass' during deserialization"
    ):
        cirq.read_json(unknown_type_json)
# Module-level LineQubits 0..4, individually bound as Q0..Q4.
# NOTE(review): none of these names is referenced elsewhere in this file.
Q0, Q1, Q2, Q3, Q4 = QUBITS = cirq.LineQubit.range(5)

# TODO: Include cirq.rx in the Circuit test case file.
# Github issue: https://github.com/quantumlib/Cirq/issues/2014
# Note that even the following doesn't work because theta gets
# multiplied by 1/pi:
#   cirq.Circuit(cirq.rx(sympy.Symbol('theta')).on(Q0)),

### MODULE CONSISTENCY tests
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
# during test setup deprecated submodules are inspected and trigger the
# deprecation error in testing. It is cleaner to just turn it off than to assert
# deprecation for each submodule.
# NOTE(review): `clear` is a boolean flag of mock.patch.dict, so the truthy string
# 'CIRQ_TESTING' empties *all* of os.environ for the test (restored afterwards),
# not only CIRQ_TESTING. Left unchanged; confirm this is intended.
@mock.patch.dict(os.environ, clear='CIRQ_TESTING')
def test_shouldnt_be_serialized_no_superfluous(mod_spec: ModuleJsonTestSpec):
    """Every name in a module's `should_not_be_serialized` list must exist in that module."""
    # everything in the list should be ignored for a reason
    names = set(mod_spec.get_all_names())
    missing_names = set(mod_spec.should_not_be_serialized).difference(names)
    assert not missing_names, (
        f"Defined as \"shouldn't be serialized\", "
        f"but missing from {mod_spec}: \n"
        f"{missing_names}"
    )
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
# Inspecting deprecated submodules during setup would trip the deprecation check
# used in testing; switching it off is simpler than asserting each deprecation.
# NOTE(review): `clear` is a boolean flag of mock.patch.dict, so the truthy string
# 'CIRQ_TESTING' empties *all* of os.environ for the test (restored afterwards).
@mock.patch.dict(os.environ, clear='CIRQ_TESTING')
def test_not_yet_serializable_no_superfluous(mod_spec: ModuleJsonTestSpec):
    """Every name in a module's `not_yet_serializable` list must exist in that module."""
    known_names = set(mod_spec.get_all_names())
    unknown_names = set(mod_spec.not_yet_serializable) - known_names
    assert (
        not unknown_names
    ), f"Defined as Not yet serializable, but missing from {mod_spec}: \n{unknown_names}"
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
def test_mutually_exclusive_blacklist(mod_spec: ModuleJsonTestSpec):
    """A name may appear in at most one of the two exclusion lists of a module spec."""
    common = set(mod_spec.should_not_be_serialized) & set(mod_spec.not_yet_serializable)
    assert not common, (
        f"Defined in both {mod_spec.name} 'Not yet serializable' "
        f"and 'Should not be serialized' lists: {common}"
    )
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
def test_resolver_cache_vs_should_not_serialize(mod_spec: ModuleJsonTestSpec):
    """A type registered in the resolver cache must not be marked 'should not be serialized'."""
    # get_resolver_cache_types() yields (name, type) pairs; only the names matter here.
    resolver_cache_types = {name for (name, _) in mod_spec.get_resolver_cache_types()}
    common = set(mod_spec.should_not_be_serialized) & resolver_cache_types
    assert not common, (
        f"Defined in both {mod_spec.name} Resolver "
        f"Cache and should not be serialized: "
        f"{common}"
    )
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
def test_resolver_cache_vs_not_yet_serializable(mod_spec: ModuleJsonTestSpec):
    """A type registered in the resolver cache must not be marked 'not yet serializable'."""
    # get_resolver_cache_types() yields (name, type) pairs; only the names matter here.
    resolver_cache_types = {name for (name, _) in mod_spec.get_resolver_cache_types()}
    common = set(mod_spec.not_yet_serializable) & resolver_cache_types
    assert not common, (
        f"Issue with the JSON config of {mod_spec.name}.\n"
        f"Types are listed in both"
        f" {mod_spec.name}.json_resolver_cache.py and in the 'not_yet_serializable' list in"
        f" {mod_spec.test_data_path}/spec.py: "
        f"\n {common}"
    )
def test_builtins():
    """Plain Python values (bool, int, complex, nested dict/list/None) round-trip."""
    samples = [
        True,
        1,
        1 + 2j,
        {
            'test': [123, 5.5],
            'key2': 'asdf',
            '3': None,
            '0.0': [],
        },
    ]
    for sample in samples:
        assert_json_roundtrip_works(sample)
def test_numpy():
    """NumPy scalars of the listed int/uint/float/complex dtypes, and arrays, round-trip."""
    scalar_one = np.ones(1)[0]
    scalar_dtypes = (
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
        np.float32,
        np.float64,
        np.complex64,
        np.complex128,
    )
    for dtype in scalar_dtypes:
        assert_json_roundtrip_works(scalar_one.astype(dtype))

    assert_json_roundtrip_works(np.ones((11, 5)))
    assert_json_roundtrip_works(np.arange(3))
def test_pandas():
    """DataFrames and (multi-)indexes, with and without names, round-trip."""
    simple_frame = pd.DataFrame(
        data=[[1, 2, 3], [4, 5, 6]], columns=['x', 'y', 'z'], index=[2, 5]
    )
    named_index = pd.Index([1, 2, 3], name='test')
    named_multi_index = pd.MultiIndex.from_tuples(
        [(1, 2), (3, 4), (5, 6)], names=['alice', 'bob']
    )
    frame_with_named_index = pd.DataFrame(
        index=pd.Index([1, 2, 3], name='test'),
        data=[[11, 21.0], [12, 22.0], [13, 23.0]],
        columns=['a', 'b'],
    )
    frame_with_multi_index = pd.DataFrame(
        index=pd.MultiIndex.from_tuples([(1, 2), (2, 3), (3, 4)], names=['x', 'y']),
        data=[[11, 21.0], [12, 22.0], [13, 23.0]],
        columns=pd.Index(['a', 'b'], name='c'),
    )
    for value in (
        simple_frame,
        named_index,
        named_multi_index,
        frame_with_named_index,
        frame_with_multi_index,
    ):
        assert_json_roundtrip_works(value)
def test_sympy():
    """Sympy atoms, arithmetic expressions and named constants round-trip."""
    s = sympy.Symbol('s')
    t = sympy.Symbol('t')

    raw_values = [sympy.Symbol('theta'), sympy.Integer(5), sympy.Rational(2, 3), sympy.Float(1.1)]
    basic_operations = [t + s, t * s, t / s, t - s, t ** s]
    linear_combinations = [t * 2, 4 * t + 3 * s + 2]
    named_constants = [sympy.pi, sympy.E, sympy.EulerGamma]

    for expression in raw_values + basic_operations + linear_combinations + named_constants:
        assert_json_roundtrip_works(expression)
class SBKImpl(cirq.SerializableByKey):
    """A test implementation of SerializableByKey.

    Holds a name plus optional list, tuple and dict payloads, which may in turn
    contain other SBKImpl instances.
    """

    def __init__(
        self,
        name: str,
        data_list: Optional[List] = None,
        data_tuple: Optional[Tuple] = None,
        data_dict: Optional[Dict] = None,
    ):
        self.name = name
        # Falsy payloads (None or empty) are replaced by fresh empty containers.
        self.data_list = data_list if data_list else []
        self.data_tuple = data_tuple if data_tuple else ()
        self.data_dict = data_dict if data_dict else {}

    def __eq__(self, other):
        if not isinstance(other, SBKImpl):
            return False
        compared_fields = ('name', 'data_list', 'data_tuple', 'data_dict')
        return all(getattr(self, field) == getattr(other, field) for field in compared_fields)

    def _json_dict_(self):
        return {
            "cirq_type": "SBKImpl",
            "name": self.name,
            "data_list": self.data_list,
            "data_tuple": self.data_tuple,
            "data_dict": self.data_dict,
        }

    @classmethod
    def _from_json_dict_(cls, name, data_list, data_tuple, data_dict, **kwargs):
        # The tuple payload is converted back with tuple() because it is read
        # from a JSON array.
        restored_tuple = tuple(data_tuple)
        return cls(name, data_list, restored_tuple, data_dict)
def test_context_serialization():
    """Shared SerializableByKey objects are serialized once and then referenced by key.

    Builds nested SBKImpl values that reuse the same child objects, checks that each
    one round-trips, and inspects the emitted JSON for the expected number of
    `_SerializedContext` and `_SerializedKey` entries.
    """

    def custom_resolver(name):
        # Resolves only the test class; returns None for any other name.
        if name == 'SBKImpl':
            return SBKImpl

    test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS
    sbki_empty = SBKImpl('sbki_empty')
    assert_json_roundtrip_works(sbki_empty, resolvers=test_resolvers)
    sbki_list = SBKImpl('sbki_list', data_list=[sbki_empty, sbki_empty])
    assert_json_roundtrip_works(sbki_list, resolvers=test_resolvers)
    sbki_tuple = SBKImpl('sbki_tuple', data_tuple=(sbki_list, sbki_list))
    assert_json_roundtrip_works(sbki_tuple, resolvers=test_resolvers)
    sbki_dict = SBKImpl('sbki_dict', data_dict={'a': sbki_tuple, 'b': sbki_tuple})
    assert_json_roundtrip_works(sbki_dict, resolvers=test_resolvers)

    sbki_json = str(cirq.to_json(sbki_dict))
    # There should be exactly one context item for each previous SBKImpl.
    assert sbki_json.count('"cirq_type": "_SerializedContext"') == 4
    # There should be exactly two key items for each of sbki_(empty|list|tuple),
    # plus one for the top-level sbki_dict.
    assert sbki_json.count('"cirq_type": "_SerializedKey"') == 7
    # The final object should be a _SerializedKey for sbki_dict. Compare the parsed
    # JSON rather than raw text, so the check does not depend on the indentation
    # at which that nested object happens to be printed.
    final_obj_idx = sbki_json.rfind('{')
    final_obj = sbki_json[final_obj_idx : sbki_json.find('}', final_obj_idx) + 1]
    assert json.loads(final_obj) == {"cirq_type": "_SerializedKey", "key": 4}

    list_sbki = [sbki_dict]
    assert_json_roundtrip_works(list_sbki, resolvers=test_resolvers)
    dict_sbki = {'a': sbki_dict}
    assert_json_roundtrip_works(dict_sbki, resolvers=test_resolvers)
    assert sbki_list != json_serialization._SerializedKey(sbki_list)

    # Serialization keys have unique suffixes.
    sbki_other_list = SBKImpl('sbki_list', data_list=[sbki_list])
    assert_json_roundtrip_works(sbki_other_list, resolvers=test_resolvers)
def test_internal_serializer_types():
    """The internal by-key serialization helpers reject `_from_json_dict_` with TypeError."""
    sbki = SBKImpl('test_key')
    internal_objects = [
        json_serialization._SerializedKey(1),
        json_serialization._SerializedContext(sbki, 1),
        json_serialization._ContextualSerialization(sbki),
    ]
    for internal_obj in internal_objects:
        as_json = internal_obj._json_dict_()
        with pytest.raises(TypeError, match='_from_json_dict_'):
            _ = type(internal_obj)._from_json_dict_(**as_json)
# Inspecting deprecated submodules during collection would trip the deprecation
# check used in testing; switching it off is simpler than asserting each one.
# NOTE(review): `clear` is a boolean flag of mock.patch.dict, so the truthy string
# 'CIRQ_TESTING' empties *all* of os.environ during the call (restored afterwards),
# not only CIRQ_TESTING. Confirm this is intended.
@mock.patch.dict(os.environ, clear='CIRQ_TESTING')
def _list_public_classes_for_tested_modules():
    """Return (mod_spec, cirq_obj_name, cls) triples for every class that should serialize.

    The triples feed the parametrization of test_json_test_data_coverage.
    """
    # Keep test collection free of DeprecationWarning noise.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return [
            (spec, obj_name, obj_cls)
            for spec in MODULE_TEST_SPECS
            for (obj_name, obj_cls) in spec.find_classes_that_should_serialize()
        ]
@pytest.mark.parametrize(
    'mod_spec,cirq_obj_name,cls',
    _list_public_classes_for_tested_modules(),
)
def test_json_test_data_coverage(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls):
    """Every class that should serialize must ship JSON test data in its module.

    Fails with instructions when neither `<name>.json` nor `<name>.json_inward`
    exists in the module's test-data directory. When a `<name>.repr` file exists
    and `cls` is known, every value it evaluates to must be exactly of type `cls`.
    """
    if cirq_obj_name == "SerializableByKey":
        pytest.skip(
            "SerializableByKey does not follow common serialization rules. "
            "It is tested separately in test_context_serialization."
        )

    if cirq_obj_name in mod_spec.not_yet_serializable:
        # pytest.xfail raises, so execution stops here; there is nothing to return.
        pytest.xfail(reason="Not serializable (yet)")

    test_data_path = mod_spec.test_data_path
    rel_path = test_data_path.relative_to(REPO_ROOT)
    mod_path = mod_spec.name.replace(".", "/")
    rel_resolver_cache_path = f"{mod_path}/json_resolver_cache.py"
    json_path = test_data_path / f'{cirq_obj_name}.json'
    json_path2 = test_data_path / f'{cirq_obj_name}.json_inward'
    deprecation_deadline = mod_spec.deprecated.get(cirq_obj_name)

    if not json_path.exists() and not json_path2.exists():
        # coverage: ignore
        pytest.fail(
            f"Hello intrepid developer. There is a new public or "
            f"serializable object named '{cirq_obj_name}' in the module '{mod_spec.name}' "
            f"that does not have associated test data.\n"
            f"\n"
            f"You must create the file\n"
            f" {rel_path}/{cirq_obj_name}.json\n"
            f"and the file\n"
            f" {rel_path}/{cirq_obj_name}.repr\n"
            f"in order to guarantee this public object is, and will "
            f"remain, serializable.\n"
            f"\n"
            f"The content of the .repr file should be the string returned "
            f"by `repr(obj)` where `obj` is a test {cirq_obj_name} value "
            f"or list of such values. To get this to work you may need to "
            f"implement a __repr__ method for {cirq_obj_name}. The repr "
            f"must be a parsable python expression that evaluates to "
            f"something equal to `obj`."
            f"\n"
            f"The content of the .json file should be the string returned "
            f"by `cirq.to_json(obj)` where `obj` is the same object or "
            f"list of test objects.\n"
            f"To get this to work you likely need "
            f"to add {cirq_obj_name} to the "
            f"`_class_resolver_dictionary` method in "
            f"the {rel_resolver_cache_path} source file. "
            f"You may also need to add a _json_dict_ method to "
            f"{cirq_obj_name}. In some cases you will also need to add a "
            f"_from_json_dict_ class method to the {cirq_obj_name} class."
            f"\n"
            f"For more information on JSON serialization, please read the "
            f"docstring for cirq.protocols.SupportsJSON. If this object or "
            f"class is not appropriate for serialization, add its name to "
            f"the `should_not_be_serialized` list in the TestSpec defined in the "
            f"{rel_path}/spec.py source file."
        )

    repr_file = test_data_path / f'{cirq_obj_name}.repr'
    if repr_file.exists() and cls is not None:
        objs = _eval_repr_data_file(repr_file, deprecation_deadline=deprecation_deadline)
        if not isinstance(objs, list):
            objs = [objs]
        for obj in objs:
            # Exact type identity (not isinstance) is required; values of other
            # types belong in the `.repr_inward` file, as the message explains.
            assert type(obj) is cls, (
                f"Value in {test_data_path}/{cirq_obj_name}.repr must be of "
                f"exact type {cls}, or a list of instances of that type. But "
                f"the value (or one of the list entries) had type "
                f"{type(obj)}.\n"
                f"\n"
                f"If using a value of the wrong type is intended, move the "
                f"value to {test_data_path}/{cirq_obj_name}.repr_inward\n"
                f"\n"
                f"Value with wrong type:\n{obj!r}."
            )
def test_to_from_strings():
    """cirq.X converts to and from JSON text; passing both a file and text is rejected.

    The expected text carries the two-space indentation that `cirq.to_json`
    pretty-prints with by default, so the exact-text comparison can match.
    """
    x_json_text = """{
  "cirq_type": "_PauliX",
  "exponent": 1.0,
  "global_shift": 0.0
}"""
    assert cirq.to_json(cirq.X) == x_json_text
    assert cirq.read_json(json_text=x_json_text) == cirq.X

    with pytest.raises(ValueError, match='specify ONE'):
        cirq.read_json(io.StringIO(), json_text=x_json_text)
def test_to_from_json_gzip():
    """Circuits round-trip through in-memory gzip bytes; invalid argument combos raise."""
    q0, q1 = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.H(q0), cirq.CX(q0, q1))
    compressed = cirq.to_json_gzip(circuit)
    assert circuit == cirq.read_json_gzip(gzip_raw=compressed)

    # Passing both a file object and raw bytes is rejected ...
    with pytest.raises(ValueError):
        _ = cirq.read_json_gzip(io.StringIO(), gzip_raw=compressed)
    # ... and so is passing neither.
    with pytest.raises(ValueError):
        _ = cirq.read_json_gzip()
def _eval_repr_data_file(path: pathlib.Path, deprecation_deadline: Optional[str]):
    """Evaluate a `.repr` test-data file as a Python expression and return the result.

    The expression is evaluated with `cirq`, `pd`, `sympy`, `np`, `datetime` and
    every importable module of TESTED_MODULES (under its own name) as globals.

    Args:
        path: The `.repr` file to evaluate.
        deprecation_deadline: If set, evaluation is wrapped in
            `cirq.testing.assert_deprecated` with this deadline.
    """
    content = path.read_text()

    if deprecation_deadline:
        # we ignore coverage here, because sometimes there are no deprecations at all in any of the
        # modules
        # coverage: ignore
        ctx_managers: List[contextlib.AbstractContextManager] = [
            cirq.testing.assert_deprecated(deadline=deprecation_deadline, count=None)
        ]
    else:
        ctx_managers = [contextlib.suppress()]

    # Text mentioning a module's deprecated old import path must emit its deprecation.
    # NOTE(review): these assertion objects are created once in TESTED_MODULES and
    # reused on every call; if assert_deprecated returns a single-use (generator
    # based) context manager, a second file mentioning the same old name would fail
    # on entry. Confirm.
    ctx_managers.extend(
        deprecation.deprecation_assertion
        for deprecation in TESTED_MODULES.values()
        if deprecation is not None and deprecation.old_name in content
    )

    eval_globals = {
        'cirq': cirq,
        'pd': pd,
        'sympy': sympy,
        'np': np,
        'datetime': datetime,
    }
    for module_name in TESTED_MODULES:
        try:
            eval_globals[module_name] = importlib.import_module(module_name)
        except ImportError:
            # Optional or non-existent modules are simply not made available.
            pass

    with contextlib.ExitStack() as stack:
        for ctx_manager in ctx_managers:
            stack.enter_context(ctx_manager)
        # eval of repository test-data files (trusted input), not external data.
        return eval(content, eval_globals, {})
def assert_repr_and_json_test_data_agree(
    mod_spec: ModuleJsonTestSpec,
    repr_path: pathlib.Path,
    json_path: pathlib.Path,
    inward_only: bool,
    deprecation_deadline: Optional[str],
):
    """Check that a JSON test-data file and its repr counterpart describe equal objects.

    Does nothing when neither file exists.

    Args:
        mod_spec: Spec of the module owning the data; used to point at its
            json_resolver_cache.py when a type cannot be resolved.
        repr_path: The `.repr` (or `.repr_inward`) file.
        json_path: The `.json` (or `.json_inward`) file.
        inward_only: When False, additionally require that serializing the repr
            object reproduces the file's JSON (compared after parsing both).
        deprecation_deadline: If set, reading the JSON and evaluating the repr are
            wrapped in `cirq.testing.assert_deprecated` with this deadline.

    Raises:
        IOError: If a file cannot be read or parsed for an unexpected reason.
        ValueError: If reading the JSON raises a ValueError other than an
            unresolved type (which instead fails the test with a hint).
    """
    if not repr_path.exists() and not json_path.exists():
        return

    rel_repr_path = f'{repr_path.relative_to(REPO_ROOT)}'
    rel_json_path = f'{json_path.relative_to(REPO_ROOT)}'

    try:
        json_from_file = json_path.read_text()
        ctx_manager = (
            cirq.testing.assert_deprecated(deadline=deprecation_deadline, count=None)
            if deprecation_deadline
            else contextlib.suppress()
        )
        with ctx_manager:
            json_obj = cirq.read_json(json_text=json_from_file)
    except ValueError as ex:  # coverage: ignore
        # coverage: ignore
        if "Could not resolve type" in str(ex):
            mod_path = mod_spec.name.replace(".", "/")
            rel_resolver_cache_path = f"{mod_path}/json_resolver_cache.py"
            # coverage: ignore
            pytest.fail(
                f"{rel_json_path} can't be parsed to JSON.\n"
                f"Maybe an entry is missing from the "
                f"`_class_resolver_dictionary` method in {rel_resolver_cache_path}?"
            )
        else:
            raise ValueError(f"deprecation: {deprecation_deadline} - got error: {ex}") from ex
    except AssertionError:  # coverage: ignore
        # coverage: ignore
        # Let assertion failures raised inside the try (presumably from the
        # deprecation context manager) propagate unchanged instead of being
        # wrapped as IOError by the generic handler below.
        raise
    except Exception as ex:  # coverage: ignore
        # coverage: ignore
        raise IOError(f'Failed to parse test json data from {rel_json_path}.') from ex

    try:
        repr_obj = _eval_repr_data_file(repr_path, deprecation_deadline)
    except Exception as ex:  # coverage: ignore
        # coverage: ignore
        raise IOError(f'Failed to parse test repr data from {rel_repr_path}.') from ex

    assert proper_eq(json_obj, repr_obj), (
        f'The json data from {rel_json_path} did not parse '
        f'into an object equivalent to the repr data from {rel_repr_path}.\n'
        f'\n'
        f'json object: {json_obj!r}\n'
        f'repr object: {repr_obj!r}\n'
    )

    if not inward_only:
        # Compare parsed JSON so formatting differences alone do not fail the check.
        json_from_cirq = cirq.to_json(repr_obj)
        json_from_cirq_obj = json.loads(json_from_cirq)
        json_from_file_obj = json.loads(json_from_file)

        assert json_from_cirq_obj == json_from_file_obj, (
            f'The json produced by cirq no longer agrees with the json in the '
            f'{rel_json_path} test data file.\n'
            f'\n'
            f'You must either fix the cirq code to continue to produce the '
            f'same output, or you must move the old test data to '
            f'{rel_json_path}_inward and create a fresh {rel_json_path} file.\n'
            f'\n'
            f'test data json:\n'
            f'{json_from_file}\n'
            f'\n'
            f'cirq produced json:\n'
            f'{json_from_cirq}\n'
        )
@pytest.mark.parametrize(
    'mod_spec, abs_path',
    [(spec, key) for spec in MODULE_TEST_SPECS for key in spec.all_test_data_keys()],
)
def test_json_and_repr_data(mod_spec: ModuleJsonTestSpec, abs_path: str):
    """Check both the outward and the inward-only repr/JSON test-data pairs of one key.

    The `.repr`/`.json` pair must agree in both directions; the
    `.repr_inward`/`.json_inward` pair is only checked when reading JSON.
    """
    deadline = mod_spec.deprecated.get(os.path.basename(abs_path))
    for repr_suffix, json_suffix, inward_only in (
        ('.repr', '.json', False),
        ('.repr_inward', '.json_inward', True),
    ):
        assert_repr_and_json_test_data_agree(
            mod_spec=mod_spec,
            repr_path=pathlib.Path(f'{abs_path}{repr_suffix}'),
            json_path=pathlib.Path(f'{abs_path}{json_suffix}'),
            inward_only=inward_only,
            deprecation_deadline=deadline,
        )
def test_pathlib_paths(tmpdir):
    """JSON and gzipped-JSON writers and readers accept pathlib.Path arguments."""
    base_dir = pathlib.Path(tmpdir)

    json_target = base_dir / 'op.json'
    cirq.to_json(cirq.X, json_target)
    assert cirq.read_json(json_target) == cirq.X

    gzip_target = base_dir / 'op.gz'
    cirq.to_json_gzip(cirq.X, gzip_target)
    assert cirq.read_json_gzip(gzip_target) == cirq.X
def test_json_serializable_dataclass():
    """@cirq.json_serializable_dataclass emits the dataclass fields as JSON keys.

    The expected lines carry the two-space-per-level indentation that
    `cirq.to_json` pretty-prints with by default, so the exact-text check can match.
    """

    @cirq.json_serializable_dataclass
    class MyDC:
        q: cirq.LineQubit
        desc: str

    my_dc = MyDC(cirq.LineQubit(4), 'hi mom')

    def custom_resolver(name):
        if name == 'MyDC':
            return MyDC

    assert_json_roundtrip_works(
        my_dc,
        text_should_be="\n".join(
            [
                '{',
                '  "cirq_type": "MyDC",',
                '  "q": {',
                '    "cirq_type": "LineQubit",',
                '    "x": 4',
                '  },',
                '  "desc": "hi mom"',
                '}',
            ]
        ),
        resolvers=[custom_resolver] + cirq.DEFAULT_RESOLVERS,
    )
def test_json_serializable_dataclass_parenthesis():
    """The decorator also works when called with empty parentheses."""

    @cirq.json_serializable_dataclass()
    class MyDC:
        q: cirq.LineQubit
        desc: str

    def custom_resolver(name):
        # Look-up table with a single entry; unknown names give None.
        return {'MyDC': MyDC}.get(name)

    resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS
    assert_json_roundtrip_works(MyDC(cirq.LineQubit(4), 'hi mom'), resolvers=resolvers)
def test_dataclass_json_dict():
    """A plain frozen dataclass can serialize itself via cirq.dataclass_json_dict."""

    @dataclasses.dataclass(frozen=True)
    class MyDC:
        q: cirq.LineQubit
        desc: str

        def _json_dict_(self):
            return cirq.dataclass_json_dict(self)

    def custom_resolver(name):
        # Look-up table with a single entry; unknown names give None.
        return {'MyDC': MyDC}.get(name)

    resolvers = [custom_resolver, *cirq.DEFAULT_RESOLVERS]
    assert_json_roundtrip_works(MyDC(cirq.LineQubit(4), 'hi mom'), resolvers=resolvers)
def test_json_serializable_dataclass_namespace():
    """With `namespace=`, the resolver must recognize the namespace-qualified type name."""

    @cirq.json_serializable_dataclass(namespace='cirq.experiments')
    class QuantumVolumeParams:
        width: int
        depth: int
        circuit_i: int

    def custom_resolver(name):
        # Look-up table with a single entry; unknown names give None.
        return {'cirq.experiments.QuantumVolumeParams': QuantumVolumeParams}.get(name)

    params = QuantumVolumeParams(width=5, depth=5, circuit_i=0)
    assert_json_roundtrip_works(params, resolvers=[custom_resolver] + cirq.DEFAULT_RESOLVERS)
def test_numpy_values():
    """A 0-d NumPy array nested in a dict is written as a plain JSON number.

    The expected text carries the two-space indentation that `cirq.to_json`
    pretty-prints with by default, so the exact-text comparison can match.
    """
    assert (
        cirq.to_json({'value': np.array(1)})
        == """{
  "value": 1
}"""
    )