-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpublic_api.py
147 lines (122 loc) · 5.28 KB
/
public_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visitor restricting traversal to only the public tensorflow API."""
import re
from tensorflow.python.util import tf_inspect
class PublicAPIVisitor:
  """Visitor to use with `traverse` to visit exactly the public TF API.

  Wraps another visitor. On every call, the `children` list is filtered in
  place: private symbols are removed before the wrapped visitor is invoked, and
  symbols listed in `do_not_descend_map` are removed after it returns.
  """

  def __init__(self, visitor):
    """Constructor.

    `visitor` should be a callable suitable as a visitor for `traverse`. It will
    be called only for members of the public TensorFlow API.

    Args:
      visitor: A visitor to call for the public API. It is invoked as
        `visitor(path, parent, children)`.
    """
    self._visitor = visitor
    self._root_name = 'tf'

    # Modules/classes we want to suppress entirely.
    # Values are kept as lists because callers may edit them via `private_map`.
    self._private_map = {
        'tf': [
            'compiler',
            'core',
            'security',
            # TODO(scottzhu): See b/227410870 for more details. Currently
            # dtensor API is exposed under tf.experimental.dtensor, but in the
            # meantime, we have the tensorflow/dtensor directory which will be
            # treated as a python package. We want to avoid stepping into the
            # tensorflow/dtensor directory when visiting the API.
            # When tf.dtensor becomes the public API, it will actually pick
            # up from tf.compat.v2.dtensor as priority and hide the
            # tensorflow/dtensor package.
            'dtensor',
            'python',
            'tsl',  # TODO(tlongeri): Remove after TSL is moved out of TF.
        ],
        # Some implementations have this internal module that we shouldn't
        # expose.
        'tf.flags': ['cpp_flags'],
    }

    # Modules/classes we do not want to descend into if we hit them. Usually,
    # system modules exposed through platforms for compatibility reasons.
    # Each entry maps a module path to a list of names to ignore in traversal.
    self._do_not_descend_map = {
        'tf': [
            'examples',
            'flags',  # Don't add flags
            # TODO(drpng): This can be removed once sealed off.
            'platform',
            # TODO(drpng): This can be removed once sealed.
            'pywrap_tensorflow',
            # TODO(drpng): This can be removed once sealed.
            'user_ops',
            'tools',
            'tensorboard',
        ],

        ## Everything below here is legitimate.
        # It'll stay, but it's not officially part of the API.
        'tf.app': ['flags'],
        # Imported for compatibility between py2/3.
        'tf.test': ['mock'],
    }

  @property
  def private_map(self):
    """A map from parents to symbols that should not be included at all.

    This map can be edited, but it should not be edited once traversal has
    begun.

    Returns:
      The map marking symbols to not include.
    """
    return self._private_map

  @property
  def do_not_descend_map(self):
    """A map from parents to symbols that should not be descended into.

    This map can be edited, but it should not be edited once traversal has
    begun.

    Returns:
      The map marking symbols to not explore.
    """
    return self._do_not_descend_map

  def set_root_name(self, root_name):
    """Override the default root name of 'tf'."""
    self._root_name = root_name

  def _is_private(self, path, name, obj=None):
    """Return whether a name is private.

    Args:
      path: Full dotted path of the parent, including the root name.
      name: Name of the child symbol under `path`.
      obj: The child object itself. Unused.

    Returns:
      True if `name` is listed for `path` in the private map, is one of
      `__base__`, `__class__` or `__next_in_mro__`, or starts with an
      underscore without being of the form `__*__`. False otherwise.
    """
    # TODO(wicke): Find out what names to exclude.
    del obj  # Unused.
    if name in self._private_map.get(path, ()):
      return True
    # These dunder names are excluded even though other dunders are kept.
    if name in ('__base__', '__class__', '__next_in_mro__'):
      return True
    # Leading underscore means private, except for `__dunder__` names.
    return name.startswith('_') and not re.match('__.*__$', name)

  def _do_not_descend(self, path, name):
    """Return whether `name` under `path` is listed in the do-not-descend map.

    Paths without an entry in the map are handled safely and yield False.

    Args:
      path: Full dotted path of the parent, including the root name.
      name: Name of the child symbol under `path`.

    Returns:
      True if traversal should not descend into `name`, False otherwise.
    """
    return name in self._do_not_descend_map.get(path, ())

  def __call__(self, path, parent, children):
    """Visitor interface, see `traverse` for details.

    Args:
      path: Dotted path of `parent` relative to the root name; empty for the
        root itself.
      parent: The object whose members are being visited.
      children: A list of `(name, child)` pairs. It is modified in place, so
        the same list object is seen by the wrapped visitor and by the caller.

    Raises:
      RuntimeError: If `parent` is a module and `path` has more than 10
        dot-separated components.
    """
    # Avoid long waits in cases of pretty unambiguous failure. The cheap depth
    # test comes first so shallow paths never pay for the module check.
    if len(path.split('.')) > 10 and tf_inspect.ismodule(parent):
      raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a '
                         'problem with an accidental public import.' %
                         (self._root_name, path))

    # Includes self._root_name
    full_path = '.'.join([self._root_name, path]) if path else self._root_name

    # Remove things that are not visible. Slice assignment filters in a single
    # pass while keeping the identity of the `children` list.
    children[:] = [
        entry for entry in children
        if not self._is_private(full_path, entry[0], entry[1])
    ]

    self._visitor(path, parent, children)

    # Remove things that are visible, but which should not be descended into.
    children[:] = [
        entry for entry in children
        if not self._do_not_descend(full_path, entry[0])
    ]