/
log.py
350 lines (302 loc) · 12.1 KB
/
log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
"""
Provides the main class for logging, :class:`Log`, and some helpers.
"""
from __future__ import annotations
import logging
import os
import sys
import io
import threading
from threading import RLock
import contextlib
import string
import typing
# Legacy Python-2/3 flag; always True on Python 3.
# Presumably kept only for backward compatibility with external code importing it -- TODO confirm.
PY3 = sys.version_info[0] >= 3
class Stream:
    """
    File-like adapter which forwards everything written to it to a
    :class:`logging.Logger` at a fixed level.
    Text is buffered until a bare ``"\\n"`` is written, at which point the
    buffered content is emitted as a single log record.
    Provides :func:`write` and :func:`flush`.
    """

    # noinspection PyShadowingNames
    def __init__(self, log, lvl):
        """
        :param logging.Logger log: target logger
        :param int lvl: log level used for every emitted record
        """
        self.buf = io.StringIO()
        self.log = log
        self.lvl = lvl
        self.lock = RLock()

    def write(self, msg):
        """
        Buffers ``msg``. A message consisting of exactly one newline
        triggers a :func:`flush` instead of being buffered.

        :param str msg:
        """
        with self.lock:
            if msg != "\n":
                self.buf.write(msg)
            else:
                self.flush()

    def flush(self):
        """
        Flush, i.e. emits the buffered text as one log record and resets the buffer.
        """
        with self.lock:
            self.buf.flush()
            content = self.buf.getvalue()
            self.log.log(self.lvl, content)
            # Reset the buffer for reuse. truncate(0) does not move the current
            # position in Python 3 (in Python 2.7 it incorrectly did,
            # see https://bugs.python.org/issue30250), hence the explicit seek(0).
            self.buf.truncate(0)
            self.buf.seek(0)
class Log:
    """
    The main logging class.
    Wraps the stdlib ``logging`` module and exposes the file-like streams
    ``v1``..``v5`` (one per verbosity level), usable e.g. via
    ``print("...", file=log.v3)``.
    """

    def __init__(self):
        self.initialized = False
        self.filename = None  # type: typing.Optional[str]
        # All 6 entries point to the same logger; verbosity is realized via log levels.
        self.v = None  # type: typing.Optional[typing.List[logging.Logger]]
        # verbose[i] is True iff at least one handler accepts verbosity level i.
        self.verbose = None  # type: typing.Optional[typing.List[bool]]
        self.v1 = None  # type: typing.Optional[Stream]
        self.v2 = None  # type: typing.Optional[Stream]
        self.v3 = None  # type: typing.Optional[Stream]
        self.v4 = None  # type: typing.Optional[Stream]
        self.v5 = None  # type: typing.Optional[Stream]
        # Warning texts already printed, to suppress repetitions.
        self._printed_warning_history = set()  # type: typing.Set[str]

    def initialize(self, logs=None, verbosity=None, formatter=None, propagate=False):
        """
        This resets and configures the "returnn" logger.

        :param list[str|logging.Handler] logs: "stdout", "|<pipe-cmd>", "<filename>"|"<filename>$date<ext>".
            "stdout" is always added when propagate=False.
        :param list[int] verbosity: levels 0-5 for the log handlers
        :param list[str] formatter: 'default', 'timed', 'raw' or 'verbose', for the log handlers
        :param bool propagate:
        """
        if formatter is None:
            formatter = []
        if verbosity is None:
            verbosity = []
        if logs is None:
            logs = []
        # Work on copies: we potentially append "stdout" to logs and broadcast
        # formatter below, and must not mutate the lists passed in by the caller.
        logs = list(logs)
        verbosity = list(verbosity)
        formatter = list(formatter)
        self.initialized = True
        fmt = {
            "default": logging.Formatter("%(message)s"),
            # NOTE(review): ".%MS" is not a strftime milliseconds directive (%M is minutes);
            # kept as-is to not change the established log output format -- confirm intent.
            "timed": logging.Formatter("%(asctime)s %(message)s", datefmt="%Y-%m-%d,%H:%M:%S.%MS"),
            "raw": logging.Formatter("%(message)s"),
            "verbose": logging.Formatter("%(levelname)s - %(asctime)s %(message)s", datefmt="%Y-%m-%d,%H:%M:%S.%MS"),
        }
        logger = logging.getLogger("returnn")
        # Note on propagation:
        # This is not so clear. By default, the root logger anyway does nothing.
        # However, if you mix RETURNN with other code, which might setup the root logger
        # (e.g. via logging.basicConfig(...)), then there is some root logger,
        # and maybe we should also use it.
        # But at this point here, we might not know about this
        # -- maybe the user would call logging.basicConfig(...) at some later point.
        # In any case, if there is a root logger and we would propagate,
        # we should not add "stdout" here,
        # although that might depend on the root logger level and handlers.
        # For now, the default is to just disable propagation, to keep that separated
        # and avoid any such problems.
        logger.propagate = propagate
        # Reset handler list, in case we have initialized some earlier (e.g. multiple log.initialize() calls).
        logger.handlers = []
        self.v = [logger] * 6  # no need for separate loggers, we do all via log levels
        if "stdout" not in logs and not propagate:
            logs.append("stdout")
        if len(formatter) == 1:
            # if only one format provided, use it for all logs
            formatter = [formatter[0]] * len(logs)
        # Define own level names. In reverse order, such that the name by default still has the default behavior.
        logging.addLevelName(logging.DEBUG + 2, "DEBUG")
        logging.addLevelName(logging.DEBUG + 1, "DEBUG")
        logging.addLevelName(logging.DEBUG + 0, "DEBUG")
        logging.addLevelName(logging.INFO + 1, "INFO")
        logging.addLevelName(logging.INFO + 0, "INFO")
        # Maps our verbosity level (0-5) to a stdlib logging level.
        _VerbosityToLogLevel = {
            0: logging.ERROR,
            1: logging.INFO + 1,
            2: logging.INFO,
            3: logging.DEBUG + 2,
            4: logging.DEBUG + 1,
            5: logging.DEBUG,
        }
        self.verbose = [False] * 6
        for i in range(len(logs)):
            t = logs[i]
            v = 3  # default verbosity if none given for this handler
            if i < len(verbosity):
                v = verbosity[i]
            elif len(verbosity) == 1:
                v = verbosity[0]
            assert v <= 5, "invalid verbosity: " + str(v)
            # A handler at verbosity v accepts all levels 0..v.
            for j in range(v + 1):
                self.verbose[j] = True
            f = fmt["default"] if i >= len(formatter) or formatter[i] not in fmt else fmt[formatter[i]]
            if isinstance(t, logging.Handler):
                handler = t
            elif t == "stdout":
                handler = StdoutHandler()
            elif t.startswith("|"):  # pipe-format
                proc_cmd = t[1:].strip()
                from subprocess import Popen, PIPE

                proc = Popen(proc_cmd, shell=True, stdin=PIPE)
                handler = logging.StreamHandler(proc.stdin)
            elif os.path.isdir(os.path.dirname(t)):
                if "$" in t:
                    # Substitute $date in the filename by the (UTC) start time.
                    from returnn.util.basic import get_utc_start_time_filename_part

                    t = string.Template(t).substitute(date=get_utc_start_time_filename_part())
                self.filename = t
                handler = logging.FileHandler(t)
            else:
                assert False, "invalid log target %r" % t
            handler.setLevel(_VerbosityToLogLevel[v])
            handler.setFormatter(f)
            logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        if not logger.handlers:
            # Avoid the stdlib "no handlers could be found" fallback behavior.
            logger.addHandler(logging.NullHandler())
        self.v1 = Stream(self.v[1], _VerbosityToLogLevel[1])
        self.v2 = Stream(self.v[2], _VerbosityToLogLevel[2])
        self.v3 = Stream(self.v[3], _VerbosityToLogLevel[3])
        self.v4 = Stream(self.v[4], _VerbosityToLogLevel[4])
        self.v5 = Stream(self.v[5], _VerbosityToLogLevel[5])

    def init_by_config(self, config):
        """
        Initializes the logging from the given config
        (keys "log", "log_verbosity", "log_format").
        With torch-distributed or Horovod, the log filenames get a
        per-rank suffix so that each worker writes to its own file.

        :param returnn.config.Config config:
        """
        logs = config.list("log", [])
        log_verbosity = config.int_list("log_verbosity", [])
        log_format = config.list("log_format", [])
        if config.typed_value("torch_distributed") is not None:
            import returnn.torch.distributed

            torch_distributed = returnn.torch.distributed.get_ctx(config=config)
            new_logs = []
            for fn in logs:
                fn_prefix, fn_ext = os.path.splitext(fn)
                fn_ext = ".torch-distrib-%i-%i%s" % (torch_distributed.rank(), torch_distributed.size(), fn_ext)
                new_logs.append(fn_prefix + fn_ext)
            logs = new_logs
        elif config.is_true("use_horovod"):
            assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow")
            import returnn.tf.horovod

            hvd = returnn.tf.horovod.get_ctx(config=config)
            new_logs = []
            for fn in logs:
                fn_prefix, fn_ext = os.path.splitext(fn)
                fn_ext = ".horovod-%i-%i%s" % (hvd.rank(), hvd.size(), fn_ext)
                new_logs.append(fn_prefix + fn_ext)
            logs = new_logs
        self.initialize(logs=logs, verbosity=log_verbosity, formatter=log_format)

    def print_warning(self, text, prefix_text="WARNING:", extra_text=None):
        """
        Write a warning to log.v2. Does not write repeated warnings.

        :param str text:
        :param str prefix_text:
        :param str|None extra_text:
        """
        if text in self._printed_warning_history:
            return
        self._printed_warning_history.add(text)
        print(prefix_text, text, file=log.v2)
        if extra_text:
            print(extra_text, file=log.v2)

    def print_deprecation_warning(self, text, behavior_version=None):
        """
        Write a deprecation warning to log.v2. Does not write repeated warnings.

        :param str text:
        :param int|None behavior_version: if this deprecation is already covered by a behavior_version check
        """
        if behavior_version:
            behavior_text = "This will be disallowed with behavior_version %d." % behavior_version
        else:
            behavior_text = "This might be disallowed with a future behavior_version."
        self.print_warning(text, prefix_text="DEPRECATION WARNING:", extra_text=behavior_text)

    def flush(self):
        """
        Flush all streams.
        """
        for stream in [self.v1, self.v2, self.v3, self.v4, self.v5]:
            if stream:
                stream.flush()
# The global (singleton) Log instance, used throughout RETURNN.
log = Log()
class StdoutHandler(logging.StreamHandler):
    """
    A :class:`logging.StreamHandler` which always resolves its stream to the
    *current* ``sys.stdout`` on every access, instead of capturing the value
    of ``sys.stdout`` once at construction time.
    Copied and adopted from logging._StderrHandler.
    """

    @property
    def stream(self):
        """
        :return: the current ``sys.stdout``
        """
        return sys.stdout

    @stream.setter
    def stream(self, stream):
        # Intentionally a no-op: the base class assigns self.stream in its
        # __init__, but we always resolve dynamically via the getter above.
        pass
class StreamThreadLocal(threading.local):
    """
    A write-only stream which buffers everything thread-locally and never
    forwards it to any real stream.
    The idea is that multiple tasks will run in multiple threads and you want to catch all the logging/stdout
    of each to not clutter the output, and also you want to keep it separate for each.
    """

    def __init__(self):
        # One StringIO per thread (threading.local calls __init__ per thread).
        self.buf = io.StringIO()

    def write(self, msg):
        """
        Appends ``msg`` to this thread's buffer.

        :param str msg:
        """
        self.buf.write(msg)

    def flush(self):
        """
        No-op (there is nothing to forward to).
        """
class StreamDummy:
    """
    A null stream: anything written to it is silently discarded.
    """

    def write(self, msg):
        """
        Discards ``msg``.

        :param str msg:
        """

    def flush(self):
        """
        No-op.
        """
@contextlib.contextmanager
def wrap_log_streams(alternative_stream, also_sys_stdout=False, tf_log_verbosity=None):
    """
    Temporarily redirects all log streams (log.v0..v5, log.error) -- and
    optionally sys.stdout and the TF logging verbosity -- to the given
    alternative stream, restoring everything on exit.

    :param StreamThreadLocal|StreamDummy alternative_stream:
    :param bool also_sys_stdout: wrap sys.stdout as well
    :param int|str|None tf_log_verbosity: e.g. "WARNING"
    :return: context manager which yields (original info stream v1, alternative_stream)
    """
    v_attrib_keys = ["v%i" % i for i in range(6)] + ["error"]
    # Store original values.
    # Some of the attribs (e.g. "v0" or "error") might not exist on the Log instance
    # (Log.__init__ only defines v1..v5); use a sentinel so that we can delete such
    # attribs again on exit instead of leaving the alternative stream behind.
    not_set = object()
    orig_v_list = log.v
    orig_v_attribs = {key: getattr(log, key, not_set) for key in v_attrib_keys}
    orig_stdout = sys.stdout
    log.v = [alternative_stream] * len(orig_v_list)
    for key in v_attrib_keys:
        setattr(log, key, alternative_stream)
    if also_sys_stdout:
        sys.stdout = alternative_stream
    orig_tf_log_verbosity = None
    if tf_log_verbosity is not None:
        import returnn.tf.compat as tf_compat

        orig_tf_log_verbosity = tf_compat.v1.logging.get_verbosity()
        tf_compat.v1.logging.set_verbosity(tf_log_verbosity)
    try:
        yield orig_v_attribs["v1"], alternative_stream
    finally:
        # Restore original values.
        log.v = orig_v_list
        for key, value in orig_v_attribs.items():
            if value is not_set:
                # Attribute did not exist before we set it; remove it again.
                if hasattr(log, key):
                    delattr(log, key)
            else:
                setattr(log, key, value)
        if also_sys_stdout:
            sys.stdout = orig_stdout
        if tf_log_verbosity is not None:
            import returnn.tf.compat as tf_compat

            tf_compat.v1.logging.set_verbosity(orig_tf_log_verbosity)