-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtf_stack.py
186 lines (142 loc) · 6.19 KB
/
tf_stack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used to extract and analyze stacks. Faster than Python libs."""
# pylint: disable=g-bad-name
import collections
import inspect
import threading
# TODO(b/138203821): change to from ...util import ... once the bug is fixed.
from tensorflow.python.util import _tf_stack
# Generally such lookups should be done using `threading.local()`. See
# https://blogs.gnome.org/jamesh/2008/06/11/tls-python/ for a detailed
# explanation of why. However the transform stacks are expected to be empty
# when a thread is joined, so reusing the key does not introduce a correctness
# issue. Moreover, get_ident is faster than storing and retrieving a unique
# key in a thread local store.
#
# The value returned by this callable is the key used to look up the calling
# thread's entry in the per-thread transform stacks defined in this module.
_get_thread_key = threading.get_ident
def _new_mapper_stack():
  """Returns a new mapper stack whose only entry is a `SentinelMapper`."""
  return [SentinelMapper()]


def _new_filter_stack():
  """Returns a new filter stack whose only entry is a `SentinelFilter`."""
  return [SentinelFilter()]


# TODO(mdan): Move these to C++ as well.
# Moving to C++ can further avoid extra copies made by get_effective_map.
#
# Transform stacks keyed by `_get_thread_key()`. A key that has not been seen
# before gets a stack that already holds a sentinel, so `stack[-1]` is always
# valid.
_source_mapper_stacks = collections.defaultdict(_new_mapper_stack)
_source_filter_stacks = collections.defaultdict(_new_filter_stack)
class StackTraceTransform(object):
  """Base class for stack trace transformation functions.

  Instances are context managers. Entering one pushes it onto the calling
  thread's stack in `_stack_dict`, records the previous top of that stack as
  `self.parent`, and calls `update()`. Exiting pops it again.

  Subclasses must set `_stack_dict` and implement `update`.
  """

  _stack_dict = None  # Subclasses should override
  _thread_key = None

  def __enter__(self):
    """Pushes this transform onto the current thread's stack.

    Returns:
      This instance.

    Raises:
      AssertionError: if the instance was previously entered from a different
        thread.
      Exception: whatever `update()` raises; the push is undone first.
    """
    # Any given instance is assumed to be used by a single thread, which reduces
    # expensive thread local lookups.
    if self._thread_key is None:
      self._thread_key = _get_thread_key()
    else:
      assert self._thread_key == _get_thread_key(), 'Shared across threads?'

    stack = self._stack_dict[self._thread_key]
    self.parent = stack[-1]
    stack.append(self)
    try:
      self.update()
    except BaseException:
      # `__exit__` is not invoked when `__enter__` raises, so undo the push
      # here. Otherwise this instance would be left on top of the stack and
      # every later enter/exit on this thread would see a corrupted stack.
      stack.pop()
      raise
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Pops this transform off the current thread's stack.

    Raises:
      AssertionError: if the popped entry is not this instance.
    """
    top = self._stack_dict[self._thread_key].pop()
    assert top is self, 'Concurrent access?'

  def update(self):
    """Hook called on entry, after `self.parent` has been set."""
    raise NotImplementedError('subclasses need to override this')
class StackTraceMapper(StackTraceTransform):
  """Allows remapping traceback information to different source code."""

  _stack_dict = _source_mapper_stacks

  def __init__(self):
    # This object is what the extract_stack* functions in this module hand to
    # `_tf_stack` for the mapper on top of the stack.
    self.internal_map = _tf_stack.PyBindSourceMap()

  def update(self):
    """Passes the effective source map, as a tuple of items, to `internal_map`."""
    effective_map = self.get_effective_source_map()
    entries = tuple(effective_map.items())
    self.internal_map.update_to(entries)

  def get_effective_source_map(self):
    """Returns a map (filename, lineno) -> (filename, lineno, function_name)."""
    raise NotImplementedError('subclasses need to override this')
# Shared empty mapping returned by `SentinelMapper`.
EMPTY_DICT = {}


class SentinelMapper(StackTraceMapper):
  """Mapper with an empty source map; the initial entry of each mapper stack."""

  def get_effective_source_map(self):
    """Returns the shared empty map."""
    return EMPTY_DICT
class StackTraceFilter(StackTraceTransform):
  """Allows filtering traceback information by removing superfluous frames."""

  _stack_dict = _source_filter_stacks

  def __init__(self):
    # This object is what the extract_stack* functions in this module hand to
    # `_tf_stack` for the filter on top of the stack.
    self.internal_set = _tf_stack.PyBindFileSet()

  def update(self):
    """Passes a new set built from the filtered filenames to `internal_set`."""
    filenames = set(self.get_filtered_filenames())
    self.internal_set.update_to(filenames)

  def get_filtered_filenames(self):
    """Returns the filenames to filter; subclasses must override this."""
    raise NotImplementedError('subclasses need to override this')
# Shared empty set returned by `SentinelFilter`.
EMPTY_SET = frozenset()


class SentinelFilter(StackTraceFilter):
  """Filter with no filenames; the initial entry of each filter stack."""

  def get_filtered_filenames(self):
    """Returns the shared empty set."""
    return EMPTY_SET
class CurrentModuleFilter(StackTraceFilter):
  """Filters stack frames from the module where this is used (best effort)."""

  def __init__(self):
    """Records the source file of the frame that called this constructor.

    `self._filename` is left as None when the interpreter exposes no current
    frame, there is no calling frame, or `inspect.getsourcefile` cannot
    identify a source file for it.
    """
    super().__init__()
    filter_filename = None
    outer_f = None
    f = inspect.currentframe()
    try:
      if f is not None:
        # The current frame is __init__. The first outer frame should be the
        # caller.
        outer_f = f.f_back
        if outer_f is not None:
          filter_filename = inspect.getsourcefile(outer_f)
      self._filename = filter_filename
      # This may be called repeatedly: once on entry by the superclass, then by
      # each child context manager.
      self._cached_set = None
    finally:
      # Avoid reference cycles, see:
      # https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack
      del f
      del outer_f

  def get_filtered_filenames(self):
    """Returns this filter's filename combined with the parent's filenames.

    The result is cached on the instance after the first call.

    Returns:
      A frozenset of filenames. The caller's filename is omitted when it could
      not be determined at construction time.
    """
    if self._cached_set is not None:
      return self._cached_set

    if self._filename is None:
      # Best effort: no caller file was found, so contribute nothing rather
      # than adding a `None` entry to the set of filenames.
      filtered_filenames = frozenset()
    else:
      filtered_filenames = frozenset((self._filename,))
    if self.parent is not None:
      filtered_filenames |= self.parent.get_filtered_filenames()
    self._cached_set = filtered_filenames
    return filtered_filenames
def extract_stack():
  """An eager-friendly alternative to traceback.extract_stack.

  The mapper and filter currently on top of the calling thread's transform
  stacks are passed along to `_tf_stack.extract_stack`.

  Returns:
    A list-like FrameSummary containing StackFrame-like objects, which are
    namedtuple-like objects with the following fields: filename, lineno, name,
    line, meant to masquerade as traceback.FrameSummary objects.
  """
  # N.B ExtractStack in tf_stack.cc will drop this frame prior to
  # traversing the stack.
  # TODO(cheshire): Remove this function, use extract_stack_for_op or Python
  # traceback module.
  key = _get_thread_key()
  active_mapper = _source_mapper_stacks[key][-1]
  active_filter = _source_filter_stacks[key][-1]
  return _tf_stack.extract_stack(
      active_mapper.internal_map, active_filter.internal_set)
# TODO(mdan): Revisit these - a single location is almost always sufficient.
def extract_stack_for_op(c_op, stacklevel=1):
  """Attaches the current stack trace to `c_op`.

  The mapper and filter currently on top of the calling thread's transform
  stacks are passed along to `_tf_stack.extract_stack_for_op`.

  Args:
    c_op: a TF_Operation object.
    stacklevel: An integer for ignoring Python wrapper stack frames.
      The default value of 1 ignores this function from the frame.
  """
  # N.B ExtractStack in tf_stack.cc will drop this frame prior to
  # traversing the stack.
  key = _get_thread_key()
  active_mapper = _source_mapper_stacks[key][-1]
  active_filter = _source_filter_stacks[key][-1]
  _tf_stack.extract_stack_for_op(
      active_mapper.internal_map, active_filter.internal_set, c_op, stacklevel)
# Module-level aliases for two types provided by `_tf_stack`.
# NOTE(review): the names mirror `traceback.StackSummary` and
# `traceback.FrameSummary`; presumably these are the container and frame types
# produced by `extract_stack` -- confirm against tf_stack.cc.
StackSummary = _tf_stack.StackTraceWrapper
FrameSummary = _tf_stack.StackFrame