/
logging.py
723 lines (573 loc) · 24.4 KB
/
logging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
""" Callbacks for printing, logging and log information."""
import sys
import time
from contextlib import suppress
from numbers import Number
from itertools import cycle
from pathlib import Path
import numpy as np
import tqdm
from tabulate import tabulate
from skorch.utils import Ansi
from skorch.dataset import get_len
from skorch.callbacks import Callback
# Public callbacks exported by this module; helpers such as
# ``filter_log_keys`` and ``rename_tensorboard_key`` are not listed.
__all__ = [
    'EpochTimer',
    'NeptuneLogger',
    'WandbLogger',
    'PrintLog',
    'ProgressBar',
    'TensorBoard',
]
def filter_log_keys(keys, keys_ignored=None):
    """Yield only those keys that are worth logging.

    Several callbacks use this helper to drop history entries that
    should never end up in a log: the ``'epoch'`` counter, keys ending
    on ``'_best'`` or ``'_batch_count'``, and keys starting with
    ``'event_'``.

    Parameters
    ----------
    keys : iterable of str
      All candidate keys.

    keys_ignored : iterable of str or None (default=None)
      If not None, collection of additional keys to drop. Membership is
      tested with ``in`` once per candidate key, so pass a container
      (set, list, tuple) rather than a one-shot iterator.

    Yields
    ------
    key : str
      Each key from ``keys`` that passed the filter, in input order.

    """
    extra = keys_ignored if keys_ignored else ()
    for candidate in keys:
        if candidate == 'epoch' or candidate in extra:
            continue
        if candidate.endswith(('_best', '_batch_count')):
            continue
        if candidate.startswith('event_'):
            continue
        yield candidate
class EpochTimer(Callback):
    """Record the wall-clock duration of every epoch.

    The elapsed time, measured with ``time.time``, is written to the
    history under the key ``dur`` when the epoch ends.

    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # timestamp taken in ``on_epoch_begin`` for the current epoch
        self.epoch_start_time_ = None

    def on_epoch_begin(self, net, **kwargs):
        self.epoch_start_time_ = time.time()

    def on_epoch_end(self, net, **kwargs):
        elapsed = time.time() - self.epoch_start_time_
        net.history.record('dur', elapsed)
class NeptuneLogger(Callback):
    """Log results from the history to Neptune.

    Neptune_ is a lightweight experiment tracking tool; read more
    about it at https://neptune.ai.

    This callback forwards every interesting value from the net's
    history to Neptune. To log anything else, write to the experiment
    object directly or subclass this callback and extend one of its
    ``on_*`` methods.

    To monitor resource consumption, install psutil::

        $ pip install psutil

    Example experiment logs can be viewed here:
    https://ui.neptune.ai/o/shared/org/skorch-integration/e/SKOR-13/charts

    Examples
    --------
    >>> # Install neptune first:  $ pip install neptune-client
    >>> # Create a neptune experiment object
    >>> import neptune
    ...
    ... # We are using an api token for an anonymous user.
    ... # For your projects use the token associated with your neptune.ai account
    >>> neptune.init(api_token='ANONYMOUS',
    ...              project_qualified_name='shared/skorch-integration')
    ...
    ... experiment = neptune.create_experiment(
    ...     name='skorch-basic-example',
    ...     params={'max_epochs': 20,
    ...             'lr': 0.01},
    ...     upload_source_files=['skorch_example.py'])

    >>> # Create a neptune_logger callback
    >>> neptune_logger = NeptuneLogger(experiment, close_after_train=False)

    >>> # Pass the logger to the net's callbacks argument
    >>> net = NeuralNetClassifier(
    ...     ClassifierModule,
    ...     max_epochs=20,
    ...     lr=0.01,
    ...     callbacks=[neptune_logger])

    >>> # Log additional metrics after training has finished
    >>> from sklearn.metrics import roc_auc_score
    ... y_pred = net.predict_proba(X)
    ... auc = roc_auc_score(y, y_pred[:, 1])
    ...
    ... neptune_logger.experiment.log_metric('roc_auc_score', auc)

    >>> # Log charts like the ROC curve
    ... from scikitplot.metrics import plot_roc
    ... import matplotlib.pyplot as plt
    ...
    ... fig, ax = plt.subplots(figsize=(16, 12))
    ... plot_roc(y, y_pred, ax=ax)
    ... neptune_logger.experiment.log_image('roc_curve', fig)

    >>> # Log the net object after training
    ... net.save_params(f_params='basic_model.pkl')
    ... neptune_logger.experiment.log_artifact('basic_model.pkl')

    >>> # Close the experiment
    ... neptune_logger.experiment.stop()

    Parameters
    ----------
    experiment : neptune.experiments.Experiment
      Instantiated ``Experiment`` class.

    log_on_batch_end : bool (default=False)
      Whether to also log loss and other metrics after every batch.

    close_after_train : bool (default=True)
      Whether to call ``stop`` on the ``Experiment`` object once
      training finishes. Set this parameter to False if you want to
      continue logging to the same Experiment or if you use it as a
      context manager.

    keys_ignored : str or list of str (default=None)
      Key or list of keys that should not be logged to
      Neptune. Note that in addition to the keys provided by the
      user, keys such as those starting with 'event_' or ending on
      '_best' are ignored by default.

    Attributes
    ----------
    first_batch_ : bool
      Set to True by ``initialize`` and switched to False at the end of
      the first batch. Useful for logging things exactly once.

    .. _Neptune: https://www.neptune.ai

    """
    def __init__(
            self,
            experiment,
            log_on_batch_end=False,
            close_after_train=True,
            keys_ignored=None,
    ):
        self.experiment = experiment
        self.log_on_batch_end = log_on_batch_end
        self.close_after_train = close_after_train
        self.keys_ignored = keys_ignored

    def initialize(self):
        self.first_batch_ = True

        ignored = self.keys_ignored
        if isinstance(ignored, str):
            # a single key was given; treat it as a one-element list
            ignored = [ignored]
        # 'batches' holds nested per-batch logs and is never logged itself
        self.keys_ignored_ = set(ignored or []) | {'batches'}
        return self

    def on_batch_end(self, net, **kwargs):
        if self.log_on_batch_end:
            last_batch = net.history[-1]['batches'][-1]
            for name in filter_log_keys(last_batch.keys(), self.keys_ignored_):
                self.experiment.log_metric(name, last_batch[name])
        self.first_batch_ = False

    def on_epoch_end(self, net, **kwargs):
        """Send every loggable value of the latest epoch to Neptune."""
        row = net.history[-1]
        step = row['epoch']
        for name in filter_log_keys(row.keys(), self.keys_ignored_):
            self.experiment.log_metric(name, x=step, y=row[name])

    def on_train_end(self, net, **kwargs):
        if self.close_after_train:
            self.experiment.stop()
class WandbLogger(Callback):
    """Logs best model and metrics to `Weights & Biases <https://docs.wandb.com/>`_

    Use this callback to automatically log the best trained model, all
    metrics from your net's history, the model topology and computer
    resources to Weights & Biases after each epoch.

    Every file saved in ``wandb_run.dir`` is automatically logged to W&B
    servers.

    See `example run
    <https://app.wandb.ai/borisd13/skorch/runs/s20or4ct/overview?workspace=user-borisd13>`_

    Examples
    --------
    >>> # Install wandb first:  $ pip install wandb
    >>> import wandb
    >>> from skorch.callbacks import WandbLogger

    >>> # Create a wandb Run
    ... wandb_run = wandb.init()
    >>> # Alternative: Create a wandb Run without having a W&B account
    ... wandb_run = wandb.init(anonymous="allow")

    >>> # Log hyper-parameters (optional)
    ... wandb_run.config.update({"learning rate": 1e-3, "batch size": 32})

    >>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
    >>> net.fit(X, y)

    Parameters
    ----------
    wandb_run : wandb.wandb_run.Run
      wandb Run used to log data.

    save_model : bool (default=True)
      Whether to save a checkpoint of the best model and upload it
      to your Run on W&B servers. A checkpoint is written in every
      epoch whose history entry ``valid_loss_best`` is truthy; if the
      history has no such entry (e.g. when training without a
      validation split), no checkpoint is written.

    keys_ignored : str or list of str (default=None)
      Key or list of keys that should not be logged to
      Weights & Biases. Note that in addition to the keys provided by
      the user, keys such as those starting with 'event_' or ending on
      '_best' are ignored by default.

    """
    def __init__(
            self,
            wandb_run,
            save_model=True,
            keys_ignored=None,
    ):
        self.wandb_run = wandb_run
        self.save_model = save_model
        self.keys_ignored = keys_ignored

    def initialize(self):
        keys_ignored = self.keys_ignored
        if isinstance(keys_ignored, str):
            # a single key was given; treat it as a one-element list
            keys_ignored = [keys_ignored]
        self.keys_ignored_ = set(keys_ignored or [])
        # 'batches' holds nested per-batch logs and is never logged itself
        self.keys_ignored_.add('batches')
        return self

    def on_train_begin(self, net, **kwargs):
        """Log model topology and add a hook for gradients."""
        self.wandb_run.watch(net.module_)

    def on_epoch_end(self, net, **kwargs):
        """Log values from the last history step and save the best model."""
        hist = net.history[-1]
        keys_kept = filter_log_keys(hist, keys_ignored=self.keys_ignored_)
        logged_vals = {k: hist[k] for k in keys_kept}
        self.wandb_run.log(logged_vals)

        # Save the best model. ``valid_loss_best`` is not guaranteed to be
        # present (e.g. no validation loss is recorded), so look it up with
        # ``.get`` instead of indexing to avoid a KeyError.
        if self.save_model and hist.get('valid_loss_best'):
            model_path = Path(self.wandb_run.dir) / 'best_model.pth'
            with model_path.open('wb') as model_file:
                net.save_params(f_params=model_file)
class PrintLog(Callback):
    """Print useful information from the model's history as a table.

    By default, ``PrintLog`` prints everything from the history except
    for ``'batches'``.

    To determine the best loss, ``PrintLog`` looks for keys that end on
    ``'_best'`` and associates them with the corresponding loss. E.g.,
    ``'train_loss_best'`` will be matched with ``'train_loss'``. The
    :class:`skorch.callbacks.EpochScoring` callback takes care of
    creating those entries, which is why ``PrintLog`` works best in
    conjunction with that callback.

    ``PrintLog`` treats keys with the ``'event_'`` prefix in a special
    way. They are assumed to contain information about occasionally
    occurring events. The ``False`` or ``None`` entries (indicating
    that an event did not occur) are not printed, resulting in empty
    cells in the table, and ``True`` entries are printed with ``+``
    symbol. ``PrintLog`` groups all event columns together and pushes
    them to the right, just before the ``'dur'`` column.

    *Note*: ``PrintLog`` will not result in good outputs if the number
    of columns varies between epochs, e.g. if the valid loss is only
    present on every other epoch.

    *Note*: for the first epoch, the complete table is printed; for
    every later epoch, only the last line of the rendered table is
    printed. This works for table formats whose last line is the data
    row (e.g. 'simple', 'plain', 'pipe'), but not for formats that
    close the table with a border or footer line (e.g. 'grid', 'html',
    'latex').

    Parameters
    ----------
    keys_ignored : str or list of str (default=None)
      Key or list of keys that should not be part of the printed
      table. Note that in addition to the keys provided by the user,
      keys such as those starting with 'event_' or ending on '_best'
      are ignored by default.

    sink : callable (default=print)
      The target that the output string is sent to. By default, the
      output is printed to stdout, but the sink could also be a
      logger, etc.

    tablefmt : str (default='simple')
      The format of the table. See the documentation of the ``tabulate``
      package for more detail. Can be 'plain', 'grid', 'pipe', 'html',
      'latex', among others (but see the note on multi-line formats
      above).

    floatfmt : str (default='.4f')
      The number formatting. See the documentation of the ``tabulate``
      package for more details.

    stralign : str (default='right')
      The alignment of columns with strings. Can be 'left', 'center',
      'right', or ``None`` (disable alignment). Default is 'right' (to
      be consistent with numerical columns).

    """
    def __init__(
            self,
            keys_ignored=None,
            sink=print,
            tablefmt='simple',
            floatfmt='.4f',
            stralign='right',
    ):
        self.keys_ignored = keys_ignored
        self.sink = sink
        self.tablefmt = tablefmt
        self.floatfmt = floatfmt
        self.stralign = stralign

    def initialize(self):
        self.first_iteration_ = True

        keys_ignored = self.keys_ignored
        if isinstance(keys_ignored, str):
            # a single key was given; treat it as a one-element list
            keys_ignored = [keys_ignored]
        self.keys_ignored_ = set(keys_ignored or [])
        # 'batches' holds nested per-batch logs and is never printed
        self.keys_ignored_.add('batches')
        return self

    def format_row(self, row, key, color):
        """For a given row from the table, format it (i.e. floating
        points and color if applicable).

        Booleans and ``None`` become ``'+'`` or ``''``; non-numeric
        values are returned unchanged; numbers are formatted with
        ``floatfmt`` unless integral, and wrapped in ``color`` when the
        matching ``<key>_best`` entry is truthy.

        """
        value = row[key]

        if isinstance(value, bool) or value is None:
            return '+' if value else ''

        if not isinstance(value, Number):
            return value

        # determine if integer value
        is_integer = float(value).is_integer()
        template = '{}' if is_integer else '{:' + self.floatfmt + '}'

        # if numeric, there could be a 'best' key
        key_best = key + '_best'
        if (key_best in row) and row[key_best]:
            template = color + template + Ansi.ENDC.value

        return template.format(value)

    def _sorted_keys(self, keys):
        """Sort keys, dropping the ones that should be ignored.

        Keys in ``self.keys_ignored_`` are dropped, as are keys rejected
        by ``filter_log_keys`` (e.g. keys ending on '_best' or
        '_batch_count'), with the exceptions listed below. Among the
        remaining keys:

          * 'epoch' is put first;
          * 'dur' is put last;
          * keys that start with 'event_' are put just before 'dur';
          * all remaining keys are sorted alphabetically.

        """
        sorted_keys = []

        # make sure 'epoch' comes first
        if ('epoch' in keys) and ('epoch' not in self.keys_ignored_):
            sorted_keys.append('epoch')

        # ignore keys like *_best or event_*
        for key in filter_log_keys(sorted(keys), keys_ignored=self.keys_ignored_):
            if key != 'dur':
                sorted_keys.append(key)

        # add event_* keys
        for key in sorted(keys):
            if key.startswith('event_') and (key not in self.keys_ignored_):
                sorted_keys.append(key)

        # make sure 'dur' comes last
        if ('dur' in keys) and ('dur' not in self.keys_ignored_):
            sorted_keys.append('dur')

        return sorted_keys

    def _yield_keys_formatted(self, row):
        """Yield ``(header, formatted_value)`` pairs for one history row."""
        colors = cycle([color.value for color in Ansi if color is not Ansi.ENDC])
        for key, color in zip(self._sorted_keys(row.keys()), colors):
            formatted = self.format_row(row, key, color=color)
            if key.startswith('event_'):
                # show event columns without their 'event_' prefix
                key = key[len('event_'):]
            yield key, formatted

    def table(self, row):
        """Render a single history row as a one-row table string."""
        headers = []
        formatted = []
        for key, formatted_row in self._yield_keys_formatted(row):
            headers.append(key)
            formatted.append(formatted_row)

        return tabulate(
            [formatted],
            headers=headers,
            tablefmt=self.tablefmt,
            floatfmt=self.floatfmt,
            stralign=self.stralign,
        )

    def _sink(self, text, verbose):
        # a custom sink always receives the text; print only when verbose
        if (self.sink is not print) or verbose:
            self.sink(text)

    # pylint: disable=unused-argument
    def on_epoch_end(self, net, **kwargs):
        data = net.history[-1]
        verbose = net.verbose
        tabulated = self.table(data)

        if self.first_iteration_:
            # Emit the complete one-row table: header, separator (if the
            # format has one) and the data row. Emitting a fixed two-line
            # header plus the last line would print the first data row
            # twice for formats without a separator, such as 'plain'.
            for line in tabulated.split('\n'):
                self._sink(line, verbose)
            self.first_iteration_ = False
        else:
            # only the data row, i.e. the last line, is new
            self._sink(tabulated.rsplit('\n', 1)[-1], verbose)

        if self.sink is print:
            sys.stdout.flush()
class ProgressBar(Callback):
    """Display a progress bar for each epoch.

    The progress bar includes elapsed and estimated remaining time for
    the current epoch, the number of batches processed, and other
    user-defined metrics. The progress bar is erased once the epoch is
    completed.

    ``ProgressBar`` needs to know the total number of batches per
    epoch in order to display a meaningful progress bar. By default,
    this number is determined automatically using the dataset length
    and the batch size (a batch size of -1 counts as a single batch per
    phase). If this heuristic does not work for some reason, you may
    either specify the number of batches explicitly or let the
    ``ProgressBar`` count the actual number of batches in the previous
    epoch.

    For jupyter notebooks a non-ASCII progress bar can be printed
    instead. To use this feature, you need to have `ipywidgets
    <https://ipywidgets.readthedocs.io/en/stable/user_install.html>`_
    installed.

    Parameters
    ----------
    batches_per_epoch : int, str (default='auto')
      Either a concrete number or a string specifying the method used
      to determine the number of batches per epoch automatically.
      ``'auto'`` means that the number is computed from the length of
      the dataset and the batch size. ``'count'`` means that the
      number is determined by counting the batches in the previous
      epoch. Note that this will leave you without a progress bar at
      the first epoch.

    detect_notebook : bool (default=True)
      If enabled, the progress bar determines if its current environment
      is a jupyter notebook and switches to a non-ASCII progress bar.

    postfix_keys : list of str (default=['train_loss', 'valid_loss'])
      You can use this list to specify additional info displayed in the
      progress bar such as metrics and losses. A prerequisite to this is
      that these values are residing in the history on batch level already,
      i.e. they must be accessible via

      >>> net.history[-1, 'batches', -1, key]

    """
    def __init__(
            self,
            batches_per_epoch='auto',
            detect_notebook=True,
            postfix_keys=None
    ):
        self.batches_per_epoch = batches_per_epoch
        self.detect_notebook = detect_notebook
        self.postfix_keys = postfix_keys or ['train_loss', 'valid_loss']

    def in_ipynb(self):
        """Return True if running inside a Jupyter (ZMQ) IPython kernel."""
        try:
            # ``get_ipython`` only exists as a builtin inside IPython;
            # anywhere else the lookup raises NameError.
            return get_ipython().__class__.__name__ == 'ZMQInteractiveShell'  # noqa: F821
        except NameError:
            return False

    def _use_notebook(self):
        return self.in_ipynb() if self.detect_notebook else False

    def _get_batch_size(self, net, training):
        """Batch size of the train or valid iterator, falling back to
        the net's ``batch_size``."""
        name = 'iterator_train' if training else 'iterator_valid'
        net_params = net.get_params()
        return net_params.get(name + '__batch_size', net_params['batch_size'])

    def _get_batches_per_epoch_phase(self, net, dataset, training):
        """Number of batches for one phase (training or validation)."""
        if dataset is None:
            return 0
        batch_size = self._get_batch_size(net, training)
        n_samples = get_len(dataset)
        if batch_size == -1:
            # batch_size=-1 means the whole dataset forms a single batch;
            # dividing by -1 would give a negative progress bar total
            return 1 if n_samples else 0
        return int(np.ceil(n_samples / batch_size))

    def _get_batches_per_epoch(self, net, dataset_train, dataset_valid):
        return (self._get_batches_per_epoch_phase(net, dataset_train, True) +
                self._get_batches_per_epoch_phase(net, dataset_valid, False))

    def _get_postfix_dict(self, net):
        """Collect the last batch's values for ``postfix_keys``, skipping
        keys that are not present."""
        postfix = {}
        for key in self.postfix_keys:
            try:
                postfix[key] = net.history[-1, 'batches', -1, key]
            except KeyError:
                pass
        return postfix

    # pylint: disable=attribute-defined-outside-init
    def on_batch_end(self, net, **kwargs):
        self.pbar.set_postfix(self._get_postfix_dict(net), refresh=False)
        self.pbar.update()

    # pylint: disable=attribute-defined-outside-init, arguments-differ
    def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):
        # Assume it is a number until proven otherwise.
        batches_per_epoch = self.batches_per_epoch

        if self.batches_per_epoch == 'auto':
            batches_per_epoch = self._get_batches_per_epoch(
                net, dataset_train, dataset_valid
            )
        elif self.batches_per_epoch == 'count':
            if len(net.history) <= 1:
                # No limit is known until the end of the first epoch.
                batches_per_epoch = None
            else:
                batches_per_epoch = len(net.history[-2, 'batches'])

        if self._use_notebook():
            self.pbar = tqdm.tqdm_notebook(total=batches_per_epoch, leave=False)
        else:
            self.pbar = tqdm.tqdm(total=batches_per_epoch, leave=False)

    def on_epoch_end(self, net, **kwargs):
        self.pbar.close()
def rename_tensorboard_key(key):
    """Map a history key to a TensorBoard tag.

    Keys that look like losses, i.e. that start with ``'train'`` or
    ``'valid'``, get the prefix ``'Loss/'`` so that TensorBoard groups
    them under a common tag prefix; every other key is returned
    unchanged.

    """
    looks_like_loss = key.startswith(('train', 'valid'))
    return 'Loss/' + key if looks_like_loss else key
class TensorBoard(Callback):
    """Log results from the history to TensorBoard.

    "TensorBoard provides the visualization and tooling needed for
    machine learning experimentation" (tensorboard_)

    This callback writes every interesting value from the net's
    history to TensorBoard at the end of each epoch.

    To log anything else, subclass this callback and put your code
    into one of its ``on_*`` methods.

    Examples
    --------
    >>> # Example to log the bias parameter as a histogram
    >>> def extract_bias(module):
    ...     return module.hidden.bias

    >>> class MyTensorBoard(TensorBoard):
    ...     def on_epoch_end(self, net, **kwargs):
    ...         bias = extract_bias(net.module_)
    ...         epoch = net.history[-1, 'epoch']
    ...         self.writer.add_histogram('bias', bias, global_step=epoch)
    ...         super().on_epoch_end(net, **kwargs)  # call super last

    Parameters
    ----------
    writer : torch.utils.tensorboard.writer.SummaryWriter
      Instantiated ``SummaryWriter`` class.

    close_after_train : bool (default=True)
      Whether to close the ``SummaryWriter`` object once training
      finishes. Set this parameter to False if you want to continue
      logging with the same writer or if you use it as a context
      manager.

    keys_ignored : str or list of str (default=None)
      Key or list of keys that should not be logged to
      tensorboard. Note that in addition to the keys provided by the
      user, keys such as those starting with 'event_' or ending on
      '_best' are ignored by default.

    key_mapper : callable (default=rename_tensorboard_key)
      Maps a key name from the history to a tag in tensorboard. This
      is useful because tensorboard can automatically group similar
      tags if their names start with the same prefix, followed by a
      forward slash. By default, all keys that start with "train" or
      "valid" receive the "Loss/" prefix.

    .. _tensorboard: https://www.tensorflow.org/tensorboard/

    """
    def __init__(
            self,
            writer,
            close_after_train=True,
            keys_ignored=None,
            key_mapper=rename_tensorboard_key,
    ):
        self.writer = writer
        self.close_after_train = close_after_train
        self.keys_ignored = keys_ignored
        self.key_mapper = key_mapper

    def initialize(self):
        self.first_batch_ = True

        ignored = self.keys_ignored
        if isinstance(ignored, str):
            # a single key was given; treat it as a one-element list
            ignored = [ignored]
        # 'batches' holds nested per-batch logs and is never logged itself
        self.keys_ignored_ = set(ignored or []) | {'batches'}
        return self

    def on_batch_end(self, net, **kwargs):
        self.first_batch_ = False

    def add_scalar_maybe(self, history, key, tag, global_step=None):
        """Add a scalar value from the history to TensorBoard.

        Missing keys (or keys whose value is ``None``) are skipped, and
        the ``NotImplementedError`` that pytorch raises for unsupported
        value types is suppressed.

        Parameters
        ----------
        history : skorch.History
          History object saved as attribute on the neural net.

        key : str
          Key of the desired value in the history.

        tag : str
          Name of the tag used in TensorBoard.

        global_step : int or None
          Global step value to record; defaults to the epoch of the
          last history entry.

        """
        last = history[-1]
        value = last.get(key)
        if value is None:
            return

        step = last['epoch'] if global_step is None else global_step
        with suppress(NotImplementedError):
            # pytorch raises NotImplementedError on wrong types
            self.writer.add_scalar(
                tag=tag,
                scalar_value=value,
                global_step=step,
            )

    def on_epoch_end(self, net, **kwargs):
        """Write every loggable value of the latest epoch to TensorBoard."""
        history = net.history
        last = history[-1]
        epoch = last['epoch']
        for key in filter_log_keys(last, keys_ignored=self.keys_ignored_):
            self.add_scalar_maybe(
                history, key=key, tag=self.key_mapper(key), global_step=epoch)

    def on_train_end(self, net, **kwargs):
        if self.close_after_train:
            self.writer.close()