py_obj_scanner.py
import io
import sys
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

# For Python < 3.8 we need to explicitly use the pickle5 backport to support
# protocol 5.
if sys.version_info < (3, 8):
    try:
        import pickle5 as pickle  # noqa: F401
    except ImportError:
        import pickle  # noqa: F401
else:
    import pickle  # noqa: F401

import ray
from ray.dag.base import DAGNodeBase

# Used in deserialization hooks to reference scanner instances.
_instances: Dict[int, "_PyObjScanner"] = {}

# Generic types for the scanner to transform from and to.
SourceType = TypeVar("SourceType")
TransformedType = TypeVar("TransformedType")


def _get_node(instance_id: int, node_index: int) -> TransformedType:
    """Get the replacement node for the given scanner instance and index.

    Note: This function should be static and globally importable,
    otherwise the serialization overhead would be very significant.
    """
    return _instances[instance_id]._replace_index(node_index)


class _PyObjScanner(
    ray.cloudpickle.CloudPickler, Generic[SourceType, TransformedType]
):
    """Utility to find and replace the `source_type` in Python objects.

    `source_type` can either be a single type or a tuple of multiple types.

    The caller must first call `find_nodes()`, then compute a replacement
    table and pass it to `replace_nodes()`.

    This uses cloudpickle under the hood, so all sub-objects that are not
    `source_type` must be serializable.

    Args:
        source_type: the type(s) of object to find and replace. Defaults to
            DAGNodeBase.
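
    Example (a minimal sketch, using `int` as a stand-in for a DAGNode-like
    source type; the values and transform are illustrative only):

        >>> scanner = _PyObjScanner(source_type=int)  # doctest: +SKIP
        >>> found = scanner.find_nodes([1, {"x": 2}])  # doctest: +SKIP
        >>> scanner.replace_nodes({n: n * 10 for n in found})  # doctest: +SKIP
        [10, {'x': 20}]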
"""

    def __init__(self, source_type: Union[Type, Tuple] = DAGNodeBase):
        self.source_type = source_type
        # Buffer to keep intermediate serialized state.
        self._buf = io.BytesIO()
        # List of top-level SourceType objects found during the serialization
        # pass.
        self._found = None
        # List of other objects found during the serialization pass.
        # This is used to store references to objects so they won't be
        # serialized by cloudpickle.
        self._objects = []
        # Replacement table to consult during deserialization.
        self._replace_table: Optional[Dict[SourceType, TransformedType]] = None
        _instances[id(self)] = self
        super().__init__(self._buf)

    def reducer_override(self, obj):
        """Hook for reducing objects.

        Objects of `self.source_type` are saved to `self._found` and a global
        map so they can later be replaced.

        All other objects fall back to the default `CloudPickler`
        serialization.
        """
        if isinstance(obj, self.source_type):
            index = len(self._found)
            self._found.append(obj)
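            # Returning a standard pickle reduce tuple here means that
            # unpickling the buffer later calls `_get_node(id(self), index)`,
            # which swaps in the replacement from `self._replace_table`.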
            return _get_node, (id(self), index)

        return super().reducer_override(obj)

    def find_nodes(self, obj: Any) -> List[SourceType]:
        """Find top-level `source_type` objects (DAGNodes by default)."""
        assert (
            self._found is None
        ), "find_nodes cannot be called twice on the same PyObjScanner instance."
        self._found = []
        self._objects = []
        self.dump(obj)
        return self._found

    def replace_nodes(self, table: Dict[SourceType, TransformedType]) -> Any:
        """Replace the previously found `source_type` objects per the table."""
        assert self._found is not None, "find_nodes must be called first."
        self._replace_table = table
        self._buf.seek(0)
        return pickle.load(self._buf)

    def _replace_index(self, i: int) -> TransformedType:
        return self._replace_table[self._found[i]]

    def clear(self):
        """Remove this scanner from the global `_instances` registry."""
        if id(self) in _instances:
            del _instances[id(self)]
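
    # Note: `_instances` holds a strong reference to this scanner, so
    # `__del__` below cannot run until `clear()` has been called (or the
    # interpreter shuts down). Callers should therefore call `clear()`
    # explicitly once they are done replacing nodes.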
    def __del__(self):
        self.clear()
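

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): `int` stands
# in for a DAGNode-like source type, and the import path assumes this file's
# location at `ray/dag/py_obj_scanner.py`. Importing the scanner (rather than
# redefining it in `__main__`) keeps `_get_node` globally importable so it is
# pickled cheaply by reference:
#
#     from ray.dag.py_obj_scanner import _PyObjScanner
#
#     scanner = _PyObjScanner(source_type=int)
#     found = scanner.find_nodes({"a": [1, 2], "b": (3, "text")})
#     result = scanner.replace_nodes({x: x * 10 for x in found})
#     assert result == {"a": [10, 20], "b": (30, "text")}
#     scanner.clear()
# ---------------------------------------------------------------------------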