-
Notifications
You must be signed in to change notification settings - Fork 975
/
_compat.py
647 lines (525 loc) · 24.2 KB
/
_compat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Workarounds for compatibility issues between versions and libraries."""
import functools
import importlib
import importlib.abc
import os
import re
import sys
import traceback
import warnings
from types import ModuleType
from typing import Any, Callable, Optional, Dict, Tuple, Type, Set

import numpy as np
import pandas as pd
import sympy
def _prefix_sympy_tokens(srepr_text: str) -> str:
    """Qualifies a few known sympy class names in an `srepr` string with `sympy.`.

    HACK: work around https://github.com/sympy/sympy/issues/16074
    (only handles a few cases). Only whole identifiers outside of quoted string
    literals are rewritten, so symbol names such as `Symbol('spin')` are not
    corrupted into `sympy.Symbol('ssympy.pin')`.
    """
    token_pattern = re.compile(r'\b(Symbol|pi|Mul|Pow|Add|Mod|Integer|Float|Rational)\b')
    # re.split with a capturing group keeps the quoted literals at the odd indices,
    # so only the even-indexed (unquoted) pieces are rewritten.
    pieces = re.split(r"""('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*")""", srepr_text)
    for i in range(0, len(pieces), 2):
        pieces[i] = token_pattern.sub(r'sympy.\1', pieces[i])
    return ''.join(pieces)


def _numpy_dtype_expr(dtype: np.dtype) -> str:
    """Returns an `np.`-qualified expression that evaluates to `dtype`.

    The short `np.<name>` form is kept whenever `str(dtype)` is exactly the name of
    the dtype's scalar type (e.g. `np.float64`, `np.int8`). Otherwise it falls back
    to `np.dtype('...')`, because forms such as `np.bool` (removed in numpy 1.24),
    `np.object` or `np.<U3` do not evaluate. datetime64 dtypes always use the
    `np.dtype(...)` form so that the time unit is preserved.
    """
    if not np.issubdtype(dtype, np.datetime64) and str(dtype) == dtype.type.__name__:
        return f'np.{dtype}'
    return f'np.{dtype!r}'


def proper_repr(value: Any) -> str:
    """Overrides sympy and numpy returning repr strings that don't parse.

    Args:
        value: The object to represent.

    Returns:
        A string intended to reconstruct `value` when evaluated with `sympy`,
        `np` (numpy) and `pd` (pandas) in scope. Objects of types not handled
        here fall back to the builtin `repr`.
    """
    if isinstance(value, sympy.Basic):
        return _prefix_sympy_tokens(sympy.srepr(value))

    if isinstance(value, np.ndarray):
        return f'np.array({value.tolist()!r}, dtype={_numpy_dtype_expr(value.dtype)})'

    # MultiIndex is a subclass of Index, so it must be checked first.
    if isinstance(value, pd.MultiIndex):
        return f'pd.MultiIndex.from_tuples({repr(list(value))}, names={repr(list(value.names))})'

    if isinstance(value, pd.Index):
        return (
            f'pd.Index({repr(list(value))}, '
            f'name={repr(value.name)}, '
            f'dtype={repr(str(value.dtype))})'
        )

    if isinstance(value, pd.DataFrame):
        cols = [value[col].tolist() for col in value.columns]
        rows = list(zip(*cols))
        return (
            f'pd.DataFrame('
            f'\n columns={proper_repr(value.columns)}, '
            f'\n index={proper_repr(value.index)}, '
            f'\n data={repr(rows)}'
            f'\n)'
        )

    return repr(value)
def proper_eq(a: Any, b: Any) -> bool:
    """Compares two objects for equality where plain `==` is unreliable.

    numpy's `a == b` broadcasts element-wise and pandas containers behave similarly,
    so when both operands share the same type this uses `np.array_equal` for arrays,
    `.equals` for pandas DataFrame/Index/MultiIndex, and an element-wise recursive
    comparison for tuples and lists. Everything else falls back to `==`.
    """
    if type(a) != type(b):
        return a == b

    if isinstance(a, np.ndarray):
        return np.array_equal(a, b)

    if isinstance(a, (pd.DataFrame, pd.Index, pd.MultiIndex)):
        return a.equals(b)

    if isinstance(a, (tuple, list)):
        if len(a) != len(b):
            return False
        return all(proper_eq(left, right) for left, right in zip(a, b))

    return a == b
def _warn_or_error(msg):
    """Emits a DeprecationWarning for `msg`, or raises when running under Cirq tests.

    When the CIRQ_TESTING environment variable is set (see `_called_from_test`) and
    the environment variable named by `ALLOW_DEPRECATION_IN_TEST` is not, a
    ValueError is raised that points at the innermost `_test.py` frame which is
    neither importlib-internal nor this file. Otherwise a DeprecationWarning is
    issued with a stacklevel chosen so that it is attributed to the first frame
    outside of `_compat.py`.

    Args:
        msg: The deprecation message.

    Raises:
        ValueError: If deprecated functionality is used during testing without
            deprecation having been explicitly allowed.
    """
    from cirq.testing.deprecation import ALLOW_DEPRECATION_IN_TEST

    deprecation_allowed = ALLOW_DEPRECATION_IN_TEST in os.environ
    if _called_from_test() and not deprecation_allowed:
        # Walk the stack from the innermost frame outwards, looking for the test file
        # that triggered the deprecated call. If nothing matches, the loop variables
        # keep the values of the last (outermost) frame visited, and those are reported.
        for filename, line_number, function_name, text in reversed(traceback.extract_stack()):
            if (
                not _is_internal(filename)
                and not filename.endswith(os.path.join("cirq", "_compat.py"))
                and "_test.py" in filename
            ):
                break
        raise ValueError(
            f"During testing using Cirq deprecated functionality is not allowed: {msg}"
            f"Update to non-deprecated functionality, or alternatively, you can quiet "
            f"this error by removing the CIRQ_TESTING environment variable "
            f"temporarily with `@mock.patch.dict(os.environ, clear='CIRQ_TESTING')`.\n"
            f"In case the usage of deprecated cirq is intentional, use "
            f"`with cirq.testing.assert_deprecated(...):` around this line:\n"
            f"{filename}:{line_number}: in {function_name}\n"
            f"\t{text}"
        )

    # we have to dynamically count the non-internal frames
    # due to the potentially multiple nested module wrappers
    stack_level = 1
    for filename, _, _, _ in reversed(traceback.extract_stack()):
        # Stop at the first frame that is neither importlib-internal nor in _compat.py.
        if not _is_internal(filename) and "_compat.py" not in filename:
            break
        # Only _compat.py frames are counted. importlib-internal frames are skipped,
        # presumably because warnings.warn skips them too (see _is_internal).
        if "_compat.py" in filename:
            stack_level += 1

    warnings.warn(
        msg,
        DeprecationWarning,
        stacklevel=stack_level,
    )
def _validate_deadline(deadline: str):
DEADLINE_REGEX = r"^v(\d)+\.(\d)+$"
assert re.match(DEADLINE_REGEX, deadline), "deadline should match vX.Y"
def deprecated(
    *, deadline: str, fix: str, name: Optional[str] = None
) -> Callable[[Callable], Callable]:
    """Builds a decorator that flags a function as deprecated.

    Each call of the decorated function first reports the deprecation through
    `_warn_or_error` and then delegates to the original function. The wrapper's
    docstring is prefixed with a deprecation banner.

    Args:
        deadline: The minor version in which the function will be deleted (e.g. "v0.7").
        fix: A complete sentence telling users what to use instead of this
            function (e.g. "Use cos instead.")
        name: How the function is referred to in the message. Defaults to
            `func.__qualname__`.

    Returns:
        A decorator that adds the deprecation warning to a function.
    """
    _validate_deadline(deadline)

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def decorated_func(*args, **kwargs) -> Any:
            label = name if name is not None else func.__qualname__
            _warn_or_error(
                f'{label} was used but is deprecated.\n'
                f'It will be removed in cirq {deadline}.\n'
                f'{fix}\n'
            )
            return func(*args, **kwargs)

        # functools.wraps copied func's docstring; keep it below the banner.
        previous_doc = decorated_func.__doc__ or ""
        decorated_func.__doc__ = (
            f'THIS FUNCTION IS DEPRECATED.\n\n'
            f'IT WILL BE REMOVED IN `cirq {deadline}`.\n\n'
            f'{fix}\n\n'
            f'{previous_doc}'
        )
        return decorated_func

    return decorator
def deprecated_class(
    *, deadline: str, fix: str, name: Optional[str] = None
) -> Callable[[Type], Type]:
    """Marks a class as deprecated.

    Instantiating the decorated class (or a subclass inheriting its `__new__`)
    reports the deprecation through `_warn_or_error`. The class docstring is
    prefixed with a deprecation banner.

    Args:
        deadline: The version where the class will be deleted. It should be a minor version
            (e.g. "v0.7").
        fix: A complete sentence describing what the user should be using
            instead of this particular class (e.g. "Use cos instead.")
        name: How to refer to the class.
            Defaults to `class.__qualname__`.

    Returns:
        A decorator that decorates classes with a deprecation warning.
    """
    _validate_deadline(deadline)

    def decorator(clazz: Type) -> Type:
        clazz_new = clazz.__new__

        def patched_new(cls, *args, **kwargs):
            qualname = clazz.__qualname__ if name is None else name
            _warn_or_error(
                f'{qualname} was used but is deprecated.\n'
                f'It will be removed in cirq {deadline}.\n'
                f'{fix}\n'
            )
            # object.__new__ rejects extra arguments once a class overrides __new__,
            # so constructor arguments are only forwarded to a custom (or builtin
            # subclass) __new__ - e.g. an int subclass needs its value passed on.
            if clazz_new is object.__new__:
                return clazz_new(cls)
            return clazz_new(cls, *args, **kwargs)

        # A __new__ defined in a class body is implicitly a staticmethod. Mirror that
        # here so `super().__new__(cls, ...)` in subclasses does not bind `cls` twice.
        setattr(clazz, '__new__', staticmethod(patched_new))

        clazz.__doc__ = (
            f'THIS CLASS IS DEPRECATED.\n\n'
            f'IT WILL BE REMOVED IN `cirq {deadline}`.\n\n'
            f'{fix}\n\n'
            f'{clazz.__doc__ or ""}'
        )

        return clazz

    return decorator
def deprecated_parameter(
    *,
    deadline: str,
    fix: str,
    func_name: Optional[str] = None,
    parameter_desc: str,
    match: Callable[[Tuple[Any, ...], Dict[str, Any]], bool],
    rewrite: Optional[
        Callable[[Tuple[Any, ...], Dict[str, Any]], Tuple[Tuple[Any, ...], Dict[str, Any]]]
    ] = None,
) -> Callable[[Callable], Callable]:
    """Builds a decorator that flags one parameter of a function as deprecated.

    On each call, `match(args, kwargs)` decides whether the deprecated parameter
    was used. If it was, the call is optionally rewritten into the new signature
    with `rewrite`, and then the deprecation is reported. The (possibly rewritten)
    call is always forwarded to the wrapped function.

    Args:
        deadline: The minor version in which the parameter will be deleted (e.g. "v0.7").
        fix: A complete sentence telling users what to do instead
            (e.g. "Use cos instead.")
        func_name: How the function is referred to in the message. Defaults to
            `func.__qualname__`.
        parameter_desc: The name and kind of the deprecated parameter, e.g.
            "janky_count", "janky_count keyword" or "positional janky_count".
        match: Callable taking (args, kwargs) and returning whether the deprecated
            parameter is present. It gates both the warning and `rewrite`.
        rewrite: Callable returning new (args, kwargs) that no longer use the
            deprecated parameter. When omitted, the call is left unchanged.

    Returns:
        A decorator that adds the parameter deprecation warning to a function.
    """
    _validate_deadline(deadline)

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def decorated_func(*args, **kwargs) -> Any:
            if not match(args, kwargs):
                return func(*args, **kwargs)

            # Rewrite first, then warn, then forward the rewritten call.
            if rewrite is not None:
                args, kwargs = rewrite(args, kwargs)
            label = func_name if func_name is not None else func.__qualname__
            _warn_or_error(
                f'The {parameter_desc} parameter of {label} was '
                f'used but is deprecated.\n'
                f'It will be removed in cirq {deadline}.\n'
                f'{fix}\n',
            )
            return func(*args, **kwargs)

        return decorated_func

    return decorator
def deprecate_attributes(
    module: ModuleType, deprecated_attributes: Dict[str, Tuple[str, str]]
) -> ModuleType:
    """Wrap a module with deprecated attributes that give warnings.

    Args:
        module: The module to wrap.
        deprecated_attributes: A dictionary from attribute name to a tuple of
            strings, where the first string gives the version that the attribute
            will be removed in, and the second string describes what the user
            should do instead of accessing this deprecated attribute.

    Returns:
        Wrapped module with deprecated attributes. Use of these attributes
        will cause a warning for these deprecated attributes. Every attribute
        lookup that reaches `__getattr__` is forwarded to `module`.

    Raises:
        AssertionError: If any deadline does not match "vX.Y" (via _validate_deadline).
    """
    # Validate all deadlines eagerly, at wrapping time rather than at access time.
    for (deadline, _) in deprecated_attributes.values():
        _validate_deadline(deadline)

    class Wrapped(ModuleType):
        # Makes `wrapped.__dict__` (and hence e.g. `vars()`/`dir()`) report the
        # original module's namespace. NOTE(review): normal attribute lookup
        # presumably still misses on the wrapper's own (near-empty) namespace and
        # falls through to __getattr__ below - confirm before relying on it.
        __dict__ = module.__dict__

        def __getattr__(self, name):
            if name in deprecated_attributes:
                deadline, fix = deprecated_attributes[name]
                _warn_or_error(
                    f'{name} was used but is deprecated.\n'
                    f'It will be removed in cirq {deadline}.\n'
                    f'{fix}\n'
                )
            return getattr(module, name)

    return Wrapped(module.__name__, module.__doc__)
class DeprecatedModuleLoader(importlib.abc.Loader):
    """A Loader for deprecated modules.

    It wraps an existing Loader instance, to which it delegates the loading. On top of that
    it ensures that the sys.modules cache has both the deprecated module's name and the
    new module's name pointing to the same exact ModuleType instance.

    Args:
        loader: the loader to be wrapped
        old_module_name: the deprecated module's fully qualified name
        new_module_name: the new module's fully qualified name
    """

    def __init__(self, loader: Any, old_module_name: str, new_module_name: str):
        """A module loader that uses an existing module loader and intercepts
        the execution of a module.

        Only the protocol methods the wrapped loader actually provides
        (exec_module, load_module, create_module) are exposed on this instance.
        """
        self.loader = loader
        if hasattr(loader, 'exec_module'):
            # mypy#2427
            self.exec_module = self._wrap_exec_module(loader.exec_module)  # type: ignore
        # while this is rare and load_module was deprecated in 3.4
        # in older environments this line makes them work as well
        if hasattr(loader, 'load_module'):
            # mypy#2427
            self.load_module = self._wrap_load_module(loader.load_module)  # type: ignore
        if hasattr(loader, 'create_module'):
            # mypy#2427
            self.create_module = loader.create_module  # type: ignore

        self.old_module_name = old_module_name
        self.new_module_name = new_module_name

    def module_repr(self, module: ModuleType) -> str:
        return self.loader.module_repr(module)

    def _wrap_load_module(self, method: Any) -> Any:
        """Wraps a legacy `load_module` so both module names alias one module object."""

        def load_module(fullname: str) -> ModuleType:
            assert fullname == self.old_module_name, (
                f"DeprecatedModuleLoader for {self.old_module_name} was asked to "
                f"load {fullname}"
            )

            # Already loaded under the new name: just alias it.
            if self.new_module_name in sys.modules:
                sys.modules[self.old_module_name] = sys.modules[self.new_module_name]
                return sys.modules[self.old_module_name]

            method(self.new_module_name)

            # https://docs.python.org/3.5/library/importlib.html#importlib.abc.Loader.load_module
            assert self.new_module_name in sys.modules, (
                f"Wrapped loader {self.loader} was "
                f"expected to insert "
                f"{self.new_module_name} in sys.modules "
                f"but it did not."
            )
            sys.modules[self.old_module_name] = sys.modules[self.new_module_name]
            return sys.modules[self.old_module_name]

        return load_module

    def _wrap_exec_module(self, method: Any) -> Any:
        """Wraps `exec_module` so the module is registered under both names."""

        def exec_module(module: ModuleType) -> None:
            assert module.__name__ == self.old_module_name, (
                f"DeprecatedModuleLoader for {self.old_module_name} was asked to "
                f"load {module.__name__}"
            )
            # check for new_module whether it was loaded
            if self.new_module_name in sys.modules:
                # found it - no need to load the module again
                sys.modules[self.old_module_name] = sys.modules[self.new_module_name]
                return

            # now we know we have to initialize the module
            sys.modules[self.old_module_name] = module
            sys.modules[self.new_module_name] = module

            try:
                return method(module)
            except BaseException:
                # If there's an error, remove both aliases. pop() rather than del, so
                # that an entry already removed during the failed execution cannot
                # replace the original exception with a KeyError.
                sys.modules.pop(self.new_module_name, None)
                sys.modules.pop(self.old_module_name, None)
                raise

        return exec_module
def _is_internal(filename: str) -> bool:
"""Returns whether filename is internal to python.
This is similar to how the built-in warnings module differentiates frames from internal modules.
It is specific to CPython - see
https://github.com/python/cpython/blob/41ec17e45d54473d32f543396293256f1581e44d/Lib/warnings.py#L275.
"""
return 'importlib' in filename and '_bootstrap' in filename
# Names of deprecated modules (e.g. "cirq.google") whose deprecation has already
# been reported. _deduped_module_warn_or_error reads and updates it to avoid
# repeating the same module deprecation message.
_warned: Set[str] = set()
def _called_from_test() -> bool:
return 'CIRQ_TESTING' in os.environ
def _should_dedupe_module_deprecation() -> bool:
    """Decides whether repeated module deprecation messages are suppressed.

    Outside of tests, deduping is always on. Under tests (CIRQ_TESTING set), it
    is on only when CIRQ_FORCE_DEDUPE_MODULE_DEPRECATION is also set.
    """
    if not _called_from_test():
        return True
    return "CIRQ_FORCE_DEDUPE_MODULE_DEPRECATION" in os.environ
def _deduped_module_warn_or_error(old_module_name: str, new_module_name: str, deadline: str):
    """Reports use of a deprecated module, at most once per module while deduping is on.

    Args:
        old_module_name: Fully qualified name of the deprecated module.
        new_module_name: Fully qualified name users should import instead.
        deadline: Version in which the deprecated module will be removed.
    """
    if _should_dedupe_module_deprecation():
        if old_module_name in _warned:
            return

    _warned.add(old_module_name)
    message = (
        f"{old_module_name} was used but is deprecated.\n "
        f"it will be removed in cirq {deadline}.\n "
        f"Use {new_module_name} instead.\n"
    )
    _warn_or_error(message)
class DeprecatedModuleFinder(importlib.abc.MetaPathFinder):
    """A module finder to handle deprecated module references.

    It sends a deprecation warning when a deprecated module is asked to be found.
    It is meant to be used as a wrapper around existing MetaPathFinder instances.

    Args:
        finder: the finder to wrap.
        new_module_name: the new module's fully qualified name
        old_module_name: the deprecated module's fully qualified name
        deadline: the deprecation deadline
        broken_module_exception: if not None, the exception that find_spec raises
            for the deprecated module (and its submodules) instead of locating them
    """

    def __init__(
        self,
        finder: Any,
        new_module_name: str,
        old_module_name: str,
        deadline: str,
        broken_module_exception: Optional[BaseException],
    ):
        """An aliasing module finder that uses an existing module finder to find a python
        module spec and intercept the execution of matching modules.
        """
        self.finder = finder
        self.new_module_name = new_module_name
        self.old_module_name = old_module_name
        self.deadline = deadline
        self.broken_module_exception = broken_module_exception
        # to cater for metadata path finders
        # https://docs.python.org/3/library/importlib.metadata.html#extending-the-search-algorithm
        if hasattr(finder, "find_distributions"):

            def find_distributions(context):
                return self.finder.find_distributions(context)

            self.find_distributions = find_distributions
        # Forward cache invalidation only when the wrapped finder supports it.
        if hasattr(finder, "invalidate_caches"):

            def invalidate_caches() -> None:
                return self.finder.invalidate_caches()

            # mypy#2427
            self.invalidate_caches = invalidate_caches  # type: ignore

    def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> Any:
        """Finds the specification of a module.

        This is an implementation of the importlib.abc.MetaPathFinder.find_spec method.
        See https://docs.python.org/3/library/importlib.html#importlib.abc.MetaPathFinder.

        Args:
            fullname: name of the module.
            path: if presented, this is the parent module's submodule search path.
            target: When passed in, target is a module object that the finder may use to make a more
                educated guess about what spec to return. We don't use it here, just pass it along
                to the wrapped finder.

        Returns:
            For names outside the deprecated module, whatever the wrapped finder returns.
            For the deprecated module or its submodules, the spec found under the new
            name, renamed back to `fullname` and with a DeprecatedModuleLoader as its
            loader, or None if the wrapped finder found nothing.

        Raises:
            The stored `broken_module_exception`, if one was given, for the deprecated
            module and its submodules.
        """
        if fullname != self.old_module_name and not fullname.startswith(self.old_module_name + "."):
            # if we are not interested in it, then just pass through to the wrapped finder
            return self.finder.find_spec(fullname, path, target)

        if self.broken_module_exception is not None:
            raise self.broken_module_exception

        # warn for deprecation
        _deduped_module_warn_or_error(self.old_module_name, self.new_module_name, self.deadline)

        # Map e.g. "<old>.sub" to "<new>.sub" by swapping the prefix.
        new_fullname = self.new_module_name + fullname[len(self.old_module_name) :]

        # find the corresponding spec in the new structure
        if fullname == self.old_module_name:
            # this is the first time the deprecated module is being found
            # which means that the new parent needs to be found first and under
            # the new parent's path, we should be able to find the new name of
            # the deprecated module
            # this code is heavily inspired by importlib.util.find_spec
            parent_name = new_fullname.rpartition('.')[0]
            if parent_name:
                parent = __import__(parent_name, fromlist=['__path__'])
                # note that compared to importlib.util.find_spec we don't handle
                # AttributeError here because it is not expected to happen in case
                # of a DeprecatedModuleLoader - the new parent should exist and be
                # a proper package
                parent_path = parent.__path__
            else:
                parent_path = None
            spec = self.finder.find_spec(new_fullname, parent_path, None)
        else:
            # we are finding a submodule of the parent of the deprecated module,
            # which means that the parent was already found, and thus, `path` is
            # correctly pointing to the module's parent in the new hierarchy
            spec = self.finder.find_spec(
                new_fullname,
                path=path,
                target=target,
            )

        # if the spec exists, return the DeprecatedModuleLoader that will do the loading as well
        # as set the alias(es) in sys.modules as necessary
        if spec is not None:
            # change back the name to the deprecated module name
            spec.name = fullname
            # some loaders do a check to ensure the module's name is the same
            # as the loader was created for
            if getattr(spec.loader, "name", None) == new_fullname:
                setattr(spec.loader, "name", fullname)
            spec.loader = DeprecatedModuleLoader(spec.loader, fullname, new_fullname)
        return spec
class _BrokenModule(ModuleType):
def __init__(self, name, exc):
self.exc = exc
super().__init__(name)
def __getattr__(self, name):
raise self.exc
class DeprecatedModuleImportError(ImportError):
    """Signals that the new module behind a deprecated module alias could not be imported."""
def deprecated_submodule(
    *, new_module_name: str, old_parent: str, old_child: str, deadline: str, create_attribute: bool
):
    """Creates a deprecated module reference for a module and its submodules.

    For `new_module_name` (e.g. cirq_google) this installs an alias (e.g cirq.google)
    by wrapping every finder on `sys.meta_path` with a DeprecatedModuleFinder.
    Submodules imported through the deprecated name (e.g. cirq.google.api) are
    resolved to their new counterparts (e.g. cirq_google.api) as well. With this it
    is possible to create an alias that really looks like a module, e.g you can do
    things like `from cirq.google import api` - which would be otherwise impossible.

    When `create_attribute` is True, this method imports `new_module_name` in order to
    attach it to `old_parent`. If that import fails, a placeholder module is attached
    instead, and both the placeholder and the finder raise a
    DeprecatedModuleImportError describing the failure.

    Args:
        new_module_name: absolute module name for the new module
        old_parent: the current module that had the original submodule
        old_child: the submodule that is being relocated
        deadline: the version in which the deprecated alias will be removed; must match "vX.Y"
        create_attribute: if True, the submodule will be added as a deprecated attribute to the
            old_parent module

    Returns:
        None
    """
    _validate_deadline(deadline)

    old_module_name = f"{old_parent}.{old_child}"
    broken_module_exception: Optional[BaseException] = None

    if create_attribute:
        # Keep the try body to the import itself so only its failure is translated.
        try:
            new_module = importlib.import_module(new_module_name)
        except ImportError as ex:
            msg = (
                f"{new_module_name} cannot be imported. The typical reasons are"
                f" that\n 1.) {new_module_name} is not installed, or"
                f"\n 2.) when developing Cirq, you don't have your PYTHONPATH "
                f"setup. In this case run `source dev_tools/pypath`.\n\n You can "
                f"check the detailed exception above for more details or run "
                f"`import {new_module_name}` to reproduce the issue."
            )
            broken_module_exception = DeprecatedModuleImportError(msg)
            broken_module_exception.__cause__ = ex
            # Attach a placeholder that re-raises the failure on attribute access.
            _setup_deprecated_submodule_attribute(
                new_module_name,
                old_parent,
                old_child,
                deadline,
                _BrokenModule(new_module_name, broken_module_exception),
            )
        else:
            _setup_deprecated_submodule_attribute(
                new_module_name, old_parent, old_child, deadline, new_module
            )

    def wrap(finder: Any) -> Any:
        # Finders without find_spec cannot be wrapped and are kept as they are.
        if not hasattr(finder, 'find_spec'):
            return finder
        return DeprecatedModuleFinder(
            finder, new_module_name, old_module_name, deadline, broken_module_exception
        )

    sys.meta_path = [wrap(finder) for finder in sys.meta_path]
def _setup_deprecated_submodule_attribute(
    new_module_name: str,
    old_parent: str,
    old_child: str,
    deadline: str,
    new_module: Optional[ModuleType],
):
    """Exposes `new_module` as the deprecated attribute `old_child` of `old_parent`.

    Sets the attribute on the already-imported parent module, then replaces the
    parent's sys.modules entry with a wrapper module. Attribute lookups that reach
    the wrapper's __getattr__ are forwarded to the original parent module, and a
    lookup of `old_child` additionally reports the deprecation (deduped).

    Args:
        new_module_name: fully qualified name of the relocated module (used in the message)
        old_parent: fully qualified name of the parent module; it must already be in
            sys.modules, otherwise a KeyError is raised
        old_child: attribute name under which the relocated module was exposed
        deadline: the version in which the deprecated attribute will be removed
        new_module: the object to expose as `old_parent.old_child`
    """
    parent_module = sys.modules[old_parent]
    setattr(parent_module, old_child, new_module)

    class Wrapped(ModuleType):
        # Makes `wrapped.__dict__` (e.g. for `vars()`/`dir()`) report the parent's
        # namespace. NOTE(review): ordinary attribute lookup presumably still misses
        # on the wrapper itself and reaches __getattr__ below - confirm.
        __dict__ = parent_module.__dict__

        def __getattr__(self, name):
            if name == old_child:
                _deduped_module_warn_or_error(
                    f"{old_parent}.{old_child}", new_module_name, deadline
                )
            return getattr(parent_module, name)

    sys.modules[old_parent] = Wrapped(parent_module.__name__, parent_module.__doc__)