-
Notifications
You must be signed in to change notification settings - Fork 388
/
net.py
2568 lines (2074 loc) · 92.9 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Neural net base class

This is the most flexible class, not making assumptions on the kind of
task being performed. Subclass this to create more specialized and
sklearn-conforming classes like NeuralNetClassifier.
"""
import fnmatch
from collections.abc import Mapping
from functools import partial
from itertools import chain
from collections import OrderedDict
from contextlib import contextmanager
import tempfile
import warnings
import numpy as np
from sklearn.base import BaseEstimator
import torch
from torch.utils.data import DataLoader
from skorch.callbacks import EpochTimer
from skorch.callbacks import PrintLog
from skorch.callbacks import PassthroughScoring
from skorch.dataset import Dataset
from skorch.dataset import ValidSplit
from skorch.dataset import get_len
from skorch.dataset import unpack_data
from skorch.exceptions import DeviceWarning
from skorch.exceptions import SkorchAttributeError
from skorch.exceptions import SkorchTrainingImpossibleError
from skorch.history import History
from skorch.setter import optimizer_setter
from skorch.utils import _identity
from skorch.utils import _infer_predict_nonlinearity
from skorch.utils import FirstStepAccumulator
from skorch.utils import TeeGenerator
from skorch.utils import _check_f_arguments
from skorch.utils import check_is_fitted
from skorch.utils import duplicate_items
from skorch.utils import get_map_location
from skorch.utils import is_dataset
from skorch.utils import params_for
from skorch.utils import to_device
from skorch.utils import to_numpy
from skorch.utils import to_tensor
# pylint: disable=too-many-instance-attributes
class NeuralNet:
    # pylint: disable=anomalous-backslash-in-string
    """NeuralNet base class.

    The base class covers more generic cases. Depending on your use
    case, you might want to use :class:`.NeuralNetClassifier` or
    :class:`.NeuralNetRegressor`.

    In addition to the parameters listed below, there are parameters
    with specific prefixes that are handled separately. To illustrate
    this, here is an example:

    >>> net = NeuralNet(
    ...    ...,
    ...    optimizer=torch.optim.SGD,
    ...    optimizer__momentum=0.95,
    ...)

    This way, when ``optimizer`` is initialized, :class:`.NeuralNet`
    will take care of setting the ``momentum`` parameter to 0.95.

    (Note that the double underscore notation in
    ``optimizer__momentum`` means that the parameter ``momentum``
    should be set on the object ``optimizer``. This is the same
    semantic as used by sklearn.)

    Furthermore, this allows you to change those parameters later:

    ``net.set_params(optimizer__momentum=0.99)``

    This can be useful when you want to change certain parameters
    using a callback, when using the net in an sklearn grid search,
    etc.

    By default, an :class:`.EpochTimer`, two :class:`.PassthroughScoring`
    callbacks (for the training and the validation loss), and a
    :class:`.PrintLog` callback are added for convenience.

    Parameters
    ----------
    module : torch module (class or instance)
      A PyTorch :class:`~torch.nn.Module`. In general, the
      uninstantiated class should be passed, although instantiated
      modules will also work.

    criterion : torch criterion (class)
      The uninitialized criterion (loss) used to optimize the
      module.

    optimizer : torch optim (class, default=torch.optim.SGD)
      The uninitialized optimizer (update rule) used to optimize the
      module.

    lr : float (default=0.01)
      Learning rate passed to the optimizer. You may use ``lr`` instead
      of using ``optimizer__lr``, which would result in the same outcome.

    max_epochs : int (default=10)
      The number of epochs to train for each ``fit`` call. Note that you
      may keyboard-interrupt training at any time.

    batch_size : int (default=128)
      Mini-batch size. Use this instead of setting
      ``iterator_train__batch_size`` and ``iterator_test__batch_size``,
      which would result in the same outcome. If ``batch_size`` is -1,
      a single batch with all the data will be used during training
      and validation.

    iterator_train : torch DataLoader (default=torch.utils.data.DataLoader)
      The default PyTorch :class:`~torch.utils.data.DataLoader` used for
      training data.

    iterator_valid : torch DataLoader (default=torch.utils.data.DataLoader)
      The default PyTorch :class:`~torch.utils.data.DataLoader` used for
      validation and test data, i.e. during inference.

    dataset : torch Dataset (default=skorch.dataset.Dataset)
      The dataset is necessary for the incoming data to work with
      pytorch's ``DataLoader``. It has to implement the ``__len__`` and
      ``__getitem__`` methods. The provided dataset should be capable of
      dealing with a lot of data types out of the box, so only change
      this if your data is not supported. You should generally pass the
      uninitialized ``Dataset`` class and define additional arguments to
      X and y by prefixing them with ``dataset__``. It is also possible
      to pass an initialized ``Dataset``, in which case no additional
      arguments may be passed.

    train_split : None or callable (default=skorch.dataset.ValidSplit(5))
      If ``None``, there is no train/validation split. Else, ``train_split``
      should be a function or callable that is called with X and y
      data and should return the tuple ``dataset_train, dataset_valid``.
      The validation data may be ``None``.

    callbacks : None, "disable", or list of Callback instances (default=None)
      Which callbacks to enable. There are three possible values:

      If ``callbacks=None``, only use default callbacks,
      those returned by ``get_default_callbacks``.

      If ``callbacks="disable"``, disable all callbacks, i.e. do not run
      any of the callbacks, not even the default callbacks.

      If ``callbacks`` is a list of callbacks, use those callbacks in
      addition to the default callbacks. Each callback should be an
      instance of :class:`.Callback`.

      Callback names are inferred from the class
      name. Name conflicts are resolved by appending a count suffix
      starting with 1, e.g. ``EpochScoring_1``. Alternatively,
      a tuple ``(name, callback)`` can be passed, where ``name``
      should be unique. Callbacks may or may not be instantiated.
      The callback name can be used to set parameters on specific
      callbacks (e.g., for the callback with name ``'print_log'``, use
      ``net.set_params(callbacks__print_log__keys_ignored=['epoch',
      'train_loss'])``).

    predict_nonlinearity : callable, None, or 'auto' (default='auto')
      The nonlinearity to be applied to the prediction. When set to
      'auto', infers the correct nonlinearity based on the criterion
      (softmax for :class:`~torch.nn.CrossEntropyLoss` and sigmoid for
      :class:`~torch.nn.BCEWithLogitsLoss`). If it cannot be inferred
      or if the parameter is None, just use the identity
      function. Don't pass a lambda function if you want the net to be
      pickleable.

      In case a callable is passed, it should accept the output of the
      module (the first output if there is more than one), which is a
      PyTorch tensor, and return the transformed PyTorch tensor.

      This can be useful, e.g., when
      :func:`~skorch.NeuralNetClassifier.predict_proba`
      should return probabilities but a criterion is used that does
      not expect probabilities. In that case, the module can return
      whatever is required by the criterion and the
      ``predict_nonlinearity`` transforms this output into
      probabilities.

      The nonlinearity is applied only when calling
      :func:`~skorch.classifier.NeuralNetClassifier.predict` or
      :func:`~skorch.classifier.NeuralNetClassifier.predict_proba` but
      not anywhere else -- notably, the loss is unaffected by this
      nonlinearity.

    warm_start : bool (default=False)
      Whether each fit call should lead to a re-initialization of the
      module (cold start) or whether the module should be trained
      further (warm start).

    verbose : int (default=1)
      This parameter controls how much print output is generated by
      the net and its callbacks. By setting this value to 0, e.g. the
      summary scores at the end of each epoch are no longer printed.
      This can be useful when running a hyperparameter search. The
      summary scores are always logged in the history attribute,
      regardless of the verbose setting.

    device : str, torch.device, or None (default='cpu')
      The compute device to be used. If set to 'cuda' in order to use
      GPU acceleration, data in torch tensors will be pushed to cuda
      tensors before being sent to the module. If set to None, then
      all compute devices will be left unmodified.

    Attributes
    ----------
    prefixes_ : list of str
      Contains the prefixes to special parameters. E.g., since
      ``'callbacks'`` is one of the prefixes, it is possible to set
      parameters like so:
      ``NeuralNet(..., callbacks__print_log__keys_ignored=['epoch'])``.

    cuda_dependent_attributes_ : list of str
      Contains a list of all attribute prefixes whose values depend on a
      CUDA device. If a ``NeuralNet`` trained with a CUDA-enabled device
      is unpickled on a machine without CUDA or with CUDA disabled, the
      listed attributes are mapped to CPU. Expand this list if you
      want to add other cuda-dependent attributes.

    initialized_ : bool
      Whether the :class:`.NeuralNet` was initialized.

    module_ : torch module (instance)
      The instantiated module.

    criterion_ : torch criterion (instance)
      The instantiated criterion.

    callbacks_ : list of tuples
      The complete (i.e. default and other), initialized callbacks, in
      a tuple with unique names.

    _modules : list of str
      List of names of all modules that are torch modules. This list is
      collected dynamically when the net is initialized. Typically, there is no
      reason for a user to modify this list.

    _criteria : list of str
      List of names of all criteria that are torch modules. This list is
      collected dynamically when the net is initialized. Typically, there is no
      reason for a user to modify this list.

    _optimizers : list of str
      List of names of all optimizers. This list is collected dynamically when
      the net is initialized. Typically, there is no reason for a user to modify
      this list.

    """
    prefixes_ = ['iterator_train', 'iterator_valid', 'callbacks', 'dataset']

    cuda_dependent_attributes_ = []

    # This attribute keeps track of which initialization method is being used.
    # It should not be changed manually.
    init_context_ = None

    # Class-level defaults for the component name registries documented above.
    # NOTE(review): these are mutable lists shared by all instances; presumably
    # code elsewhere in this file replaces them per instance instead of
    # mutating them in place -- confirm.
    _modules = []
    _criteria = []
    _optimizers = []
# pylint: disable=too-many-arguments
def __init__(
        self,
        module,
        criterion,
        optimizer=torch.optim.SGD,
        lr=0.01,
        max_epochs=10,
        batch_size=128,
        iterator_train=DataLoader,
        iterator_valid=DataLoader,
        dataset=Dataset,
        # NOTE(review): ValidSplit(5) is evaluated once, when the function is
        # defined, so every net using the default shares this one instance;
        # this is fine as long as it keeps no per-fit state -- confirm.
        train_split=ValidSplit(5),
        callbacks=None,
        predict_nonlinearity='auto',
        warm_start=False,
        verbose=1,
        device='cpu',
        **kwargs
):
    """Store the constructor arguments without building any component.

    The arguments are documented in the class docstring. The module,
    criterion, optimizer, callbacks and history are only created later, by
    ``initialize``.

    Additional keyword arguments (e.g. prefixed ones like
    ``optimizer__momentum``) are stored as instance attributes as well. The
    special keys ``history``, ``initialized_`` and ``virtual_params_`` are
    removed from ``kwargs`` and assigned to ``history_``, ``initialized_``
    and ``virtual_params_`` respectively.
    """
    self.module = module
    self.criterion = criterion
    self.optimizer = optimizer
    self.lr = lr
    self.max_epochs = max_epochs
    self.batch_size = batch_size
    self.iterator_train = iterator_train
    self.iterator_valid = iterator_valid
    self.dataset = dataset
    self.train_split = train_split
    self.callbacks = callbacks
    self.predict_nonlinearity = predict_nonlinearity
    self.warm_start = warm_start
    self.verbose = verbose
    self.device = device

    # Runs before the pops below, so it sees every extra keyword argument.
    # Defined outside this view; presumably warns/raises on deprecated names.
    self._check_deprecated_params(**kwargs)

    # Pop the special keys so they are not stored verbatim by the
    # vars(self).update(...) call below.
    history = kwargs.pop('history', None)
    initialized = kwargs.pop('initialized_', False)
    virtual_params = kwargs.pop('virtual_params_', dict())

    # Remember the remaining extra parameter names; presumably checked later
    # by ``_validate_params`` (called from ``initialize``) -- confirm.
    self._params_to_validate = set(kwargs.keys())
    # Writes straight into the instance __dict__, bypassing any custom
    # __setattr__ logic for these extra parameters.
    vars(self).update(kwargs)

    self.history_ = history
    self.initialized_ = initialized
    self.virtual_params_ = virtual_params
@property
def history(self):
    """Training history of the net.

    Convenience alias: reading ``net.history`` returns ``net.history_``.
    """
    return self.history_

@history.setter
def history(self, value):
    """Store ``value`` in the underlying ``history_`` attribute."""
    self.history_ = value
@property
def _default_callbacks(self):
    """List of ``(name, callback)`` tuples every net starts with.

    Fresh callback instances are created on each access: an epoch timer,
    passthrough scorers for the training and validation loss, and a print
    log, in that order.
    """
    timer = EpochTimer()
    train_scoring = PassthroughScoring(name='train_loss', on_train=True)
    valid_scoring = PassthroughScoring(name='valid_loss')
    printer = PrintLog()
    return [
        ('epoch_timer', timer),
        ('train_loss', train_scoring),
        ('valid_loss', valid_scoring),
        ('print_log', printer),
    ]
def get_default_callbacks(self):
    """Return the default callbacks as a list of ``(name, callback)`` tuples.

    Override this method (or the ``_default_callbacks`` property) to change
    which callbacks are always present.
    """
    defaults = self._default_callbacks
    return defaults
def notify(self, method_name, **cb_kwargs):
    """Invoke the hook ``method_name`` on the net, then on every callback.

    The net's own hook runs first. Then each callback in ``callbacks_`` is
    called in order. Every hook receives the net as its first positional
    argument and ``cb_kwargs`` as keyword arguments.

    Hook names correspond to the ``on_*`` methods of this class, e.g.:

    * on_train_begin
    * on_train_end
    * on_epoch_begin
    * on_epoch_end
    * on_batch_begin
    * on_batch_end
    * on_grad_computed
    """
    own_hook = getattr(self, method_name)
    own_hook(self, **cb_kwargs)
    for _cb_name, callback in self.callbacks_:
        hook = getattr(callback, method_name)
        hook(self, **cb_kwargs)
# pylint: disable=unused-argument
def on_train_begin(self, net, X=None, y=None, **kwargs):
    """Hook for the ``on_train_begin`` event; a no-op by default."""

# pylint: disable=unused-argument
def on_train_end(self, net, X=None, y=None, **kwargs):
    """Hook for the ``on_train_end`` event; a no-op by default."""

# pylint: disable=unused-argument
def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):
    """Open a new epoch in the history and record its number.

    The number stored under ``'epoch'`` is ``len(self.history)`` measured
    after ``new_epoch`` was called.
    """
    hist = self.history
    hist.new_epoch()
    epoch_number = len(hist)
    hist.record('epoch', epoch_number)

# pylint: disable=unused-argument
def on_epoch_end(self, net, dataset_train=None, dataset_valid=None, **kwargs):
    """Hook for the ``on_epoch_end`` event; a no-op by default."""

# pylint: disable=unused-argument
def on_batch_begin(self, net, batch=None, training=False, **kwargs):
    """Open a new batch entry in the history via ``new_batch``."""
    self.history.new_batch()

# pylint: disable=unused-argument
def on_batch_end(self, net, batch=None, training=False, **kwargs):
    """Hook for the ``on_batch_end`` event; a no-op by default."""

# pylint: disable=unused-argument
def on_grad_computed(
        self, net, named_parameters, batch=None, training=False, **kwargs):
    """Hook for the ``on_grad_computed`` event; a no-op by default."""
def _yield_callbacks(self):
    """Yield ``(name, callback, named_by_user)`` for every callback.

    Default callbacks come first, followed by the user-supplied
    ``callbacks``. Supported variations:

    * default and user callbacks
    * callbacks given as ``(name, callback)`` and callbacks given without a
      name (the class name is used then)
    * initialized and uninitialized (class) callbacks

    :class:`.PrintLog` callbacks are held back and yielded last, keeping
    their relative order.

    Yields
    ------
    name : str
      Name of the callback.

    cb : Callback or Callback class
      The callback itself.

    named_by_user : bool
      Whether the name was given by the user or inferred from the class
      name.
    """
    deferred = []
    for item in self.get_default_callbacks() + (self.callbacks or []):
        named_by_user = isinstance(item, (tuple, list))
        if named_by_user:
            name, cb = item
        else:
            cb = item
            cb_class = cb if isinstance(cb, type) else cb.__class__
            name = cb_class.__name__

        entry = (name, cb, named_by_user)
        if isinstance(cb, PrintLog) or (cb == PrintLog):
            deferred.append(entry)
        else:
            yield entry
    yield from deferred
def _callbacks_grouped_by_name(self):
"""Group callbacks by name and collect names set by the user."""
callbacks, names_set_by_user = OrderedDict(), set()
for name, cb, named_by_user in self._yield_callbacks():
if named_by_user:
names_set_by_user.add(name)
callbacks[name] = callbacks.get(name, []) + [cb]
return callbacks, names_set_by_user
def _uniquely_named_callbacks(self):
"""Make sure that the returned dict of named callbacks is unique
w.r.t. to the callback name. User-defined names will not be
renamed on conflict, instead an exception will be raised. The
same goes for the event where renaming leads to a conflict.
"""
grouped_cbs, names_set_by_user = self._callbacks_grouped_by_name()
for name, cbs in grouped_cbs.items():
if len(cbs) > 1 and name in names_set_by_user:
raise ValueError("Found duplicate user-set callback name "
"'{}'. Use unique names to correct this."
.format(name))
for i, cb in enumerate(cbs):
if len(cbs) > 1:
unique_name = '{}_{}'.format(name, i+1)
if unique_name in grouped_cbs:
raise ValueError("Assigning new callback name failed "
"since new name '{}' exists already."
.format(unique_name))
else:
unique_name = name
yield unique_name, cb
def initialize_callbacks(self):
    """Initialize all callbacks and store them in ``callbacks_``.

    Default callbacks come first, then the user-supplied ``callbacks``. Names
    are made unique beforehand by ``_uniquely_named_callbacks``. For each
    callback:

    * if the net has an attribute ``callbacks__<name>``, its value replaces
      the callback (``None`` drops the callback);
    * the parameters returned by ``get_params_for('callbacks__<name>')``
      are used to instantiate callback classes, or are applied to callback
      instances through ``set_params``;
    * ``initialize`` is called on the resulting instance.

    Returns
    -------
    self

    Raises
    ------
    ValueError
      If parameters are set for a callback that is ``None``.
    """
    # A private sentinel is required because None is a legitimate value for
    # ``callbacks__<name>`` (it disables the callback).
    missing = object()
    initialized = []
    for name, cb in self._uniquely_named_callbacks():
        override = getattr(self, 'callbacks__' + name, missing)
        if override is not missing:
            cb = override

        params = self.get_params_for('callbacks__{}'.format(name))
        if cb is None:
            if params:
                raise ValueError("Trying to set a parameter for callback {} "
                                 "which does not exist.".format(name))
            continue

        if isinstance(cb, type):
            cb = cb(**params)
        else:
            cb.set_params(**params)
        cb.initialize()
        initialized.append((name, cb))

    # pylint: disable=attribute-defined-outside-init
    self.callbacks_ = initialized
    return self
def initialized_instance(self, instance_or_cls, kwargs):
    """Return an initialized component built from ``instance_or_cls``.

    Three situations are handled:

    * ``instance_or_cls`` is a :class:`torch.nn.Module` instance and
      ``kwargs`` is empty: it is returned unchanged.
    * ``instance_or_cls`` is a :class:`torch.nn.Module` instance and
      ``kwargs`` is not empty: a new instance of the same type is created
      from ``kwargs``.
    * Otherwise ``instance_or_cls`` is treated as a class (or any callable)
      and called with ``kwargs``.

    In most cases this simply instantiates the class with its arguments.

    Parameters
    ----------
    instance_or_cls
      The instance, class or callable to be initialized, e.g.
      ``self.module``.

    kwargs : dict
      The keyword arguments to initialize the instance or class with. Can
      be an empty dict.

    Returns
    -------
    instance
      The initialized component.
    """
    if not isinstance(instance_or_cls, torch.nn.Module):
        return instance_or_cls(**kwargs)
    if kwargs:
        return type(instance_or_cls)(**kwargs)
    return instance_or_cls
def initialize_criterion(self):
    """Build the criterion and store it as ``criterion_``.

    The parameters returned by ``get_params_for('criterion')`` are passed to
    ``initialized_instance``. An already instantiated criterion without such
    parameters is therefore kept as is.

    Returns
    -------
    self
    """
    criterion_kwargs = self.get_params_for('criterion')
    # pylint: disable=attribute-defined-outside-init
    self.criterion_ = self.initialized_instance(self.criterion, criterion_kwargs)
    return self
def initialize_module(self):
    """Build the module and store it as ``module_``.

    The parameters returned by ``get_params_for('module')`` are passed to
    ``initialized_instance``. An already instantiated module without such
    parameters is therefore kept as is.

    Returns
    -------
    self
    """
    module_kwargs = self.get_params_for('module')
    # pylint: disable=attribute-defined-outside-init
    self.module_ = self.initialized_instance(self.module, module_kwargs)
    return self
def _is_virtual_param(self, key):
return any(fnmatch.fnmatch(key, pat) for pat in self.virtual_params_)
def _virtual_setattr(self, param, val):
setattr(self, param, val)
def _register_virtual_param(self, param_patterns, fn=_virtual_setattr):
if not isinstance(param_patterns, list):
param_patterns = [param_patterns]
for pattern in param_patterns:
self.virtual_params_[pattern] = fn
def _apply_virtual_params(self, virtual_kwargs):
for pattern, fn in self.virtual_params_.items():
for key, val in virtual_kwargs.items():
if not fnmatch.fnmatch(key, pattern):
continue
fn(self, key, val)
def initialize_virtual_params(self):
    """Reset the virtual-parameter registry to an empty dict.

    The registry maps fnmatch-style patterns to setter functions and is
    filled again through ``_register_virtual_param``.
    """
    self.virtual_params_ = dict()
def initialize_optimizer(self, triggered_directly=None):
    """Initialize the model optimizer and store it as ``optimizer_``.

    The optimizer is built from the learnable parameters yielded by
    ``get_all_learnable_params``. Its arguments are prepared by
    ``get_params_for_optimizer('optimizer', ...)``. If
    ``self.optimizer__lr`` is not set, ``self.lr`` is used instead.

    Parameters
    ----------
    triggered_directly
      Deprecated, don't use it anymore. Passing anything other than
      ``None`` emits a ``DeprecationWarning``.

    Returns
    -------
    self
    """
    # handle deprecated parameter
    if triggered_directly is not None:
        # stacklevel=2 attributes the warning to the caller. With the default
        # stacklevel it would point into this module, and Python's default
        # filters hide DeprecationWarnings not triggered from __main__.
        warnings.warn(
            "The 'triggered_directly' argument to 'initialize_optimizer' is "
            "deprecated, please don't use it anymore.",
            DeprecationWarning,
            stacklevel=2,
        )

    named_parameters = self.get_all_learnable_params()
    args, kwargs = self.get_params_for_optimizer('optimizer', named_parameters)

    # pylint: disable=attribute-defined-outside-init
    self.optimizer_ = self.optimizer(*args, **kwargs)
    return self
def initialize_history(self):
    """Replace ``history_`` with a new :class:`.History` instance.

    Returns
    -------
    self
    """
    fresh_history = History()
    self.history_ = fresh_history
    return self
def _format_reinit_msg(self, name, kwargs=None, triggered_directly=True):
"""Returns a message that informs about re-initializing a compoment.
Sometimes, the module or optimizer need to be
re-initialized. Not only should the user receive a message
about this but also should they be informed about what
parameters, if any, caused it.
"""
msg = "Re-initializing {}".format(name)
if triggered_directly and kwargs:
msg += (" because the following parameters were re-set: {}"
.format(', '.join(sorted(kwargs))))
msg += "."
return msg
@contextmanager
def _current_init_context(self, name):
try:
self.init_context_ = name
yield
finally:
self.init_context_ = None
def _initialize_virtual_params(self):
# this init context is for consistency and not being used at the moment
with self._current_init_context('virtual_params'):
self.initialize_virtual_params()
return self
def _initialize_callbacks(self):
# this init context is for consistency and not being used at the moment
with self._current_init_context('callbacks'):
if self.callbacks == "disable":
self.callbacks_ = []
return self
self.initialize_callbacks()
return self
def _initialize_criterion(self, reason=None):
# _initialize_criterion and _initialize_module share the same logic
with self._current_init_context('criterion'):
kwargs = {}
for criterion_name in self._criteria:
kwargs.update(self.get_params_for(criterion_name))
has_init_criterion = any(
isinstance(getattr(self, criterion_name + '_', None), torch.nn.Module)
for criterion_name in self._criteria)
# check if a re-init message is required
if kwargs or reason or has_init_criterion:
if self.initialized_ and self.verbose:
if reason:
# re-initialization was triggered indirectly
msg = reason
else:
# re-initialization was triggered directly
msg = self._format_reinit_msg("criterion", kwargs)
print(msg)
self.initialize_criterion()
# deal with device
for name in self._criteria:
criterion = getattr(self, name + '_')
if isinstance(criterion, torch.nn.Module):
setattr(self, name + '_', to_device(criterion, self.device))
return self
def _initialize_module(self, reason=None):
# _initialize_criterion and _initialize_module share the same logic
with self._current_init_context('module'):
kwargs = {}
for module_name in self._modules:
kwargs.update(self.get_params_for(module_name))
has_init_module = any(
isinstance(getattr(self, module_name + '_', None), torch.nn.Module)
for module_name in self._modules)
if kwargs or reason or has_init_module:
if self.initialized_ and self.verbose:
if reason:
# re-initialization was triggered indirectly
msg = reason
else:
# re-initialization was triggered directly
msg = self._format_reinit_msg("module", kwargs)
print(msg)
self.initialize_module()
# deal with device
for name in self._modules:
module = getattr(self, name + '_')
if isinstance(module, torch.nn.Module):
setattr(self, name + '_', to_device(module, self.device))
return self
def get_all_learnable_params(self):
    """Yield the learnable parameters of all modules and criteria.

    Typically, this yields the ``named_parameters`` of the standard module
    of the net. However, if you add custom modules or if your criterion has
    learnable parameters, these are yielded as well. Components without a
    ``named_parameters`` attribute are skipped.

    A parameter shared between components (e.g. because the criterion holds
    a reference to the module) is yielded only once, under the name it has
    in the first component that contains it.

    If you want your optimizer to only update the parameters of some but not
    all modules, you should override :meth:`.initialize_optimizer` and match
    the corresponding modules and optimizers there:

    .. code:: python

        class MyNet(NeuralNet):
            def initialize_optimizer(self, *args, **kwargs):
                # first initialize the normal optimizer
                named_params = self.module_.named_parameters()
                args, kwargs = self.get_params_for_optimizer('optimizer', named_params)
                self.optimizer_ = self.optimizer(*args, **kwargs)

                # next add another optimizer called 'optimizer2_' that is
                # only responsible for training 'module2_'
                named_params = self.module2_.named_parameters()
                args, kwargs = self.get_params_for_optimizer('optimizer2', named_params)
                self.optimizer2_ = torch.optim.SGD(*args, **kwargs)
                return self

    Yields
    ------
    named_parameters : generator of parameter name and parameter
      A generator over all module parameters, yielding both the name of the
      parameter and the parameter itself. Use this, for instance, to pass
      the named parameters to :meth:`.get_params_for_optimizer`.
    """
    # Duplicates are detected by the parameter object, not its name: the
    # same parameter may appear under different names in different
    # components.
    already_yielded = set()
    for component_name in self._modules + self._criteria:
        component = getattr(self, component_name + '_')
        named_parameters = getattr(component, 'named_parameters', None)
        if not named_parameters:
            continue
        for param_name, param in named_parameters():
            if param not in already_yielded:
                already_yielded.add(param)
                yield param_name, param
def _initialize_optimizer(self, reason=None):
with self._current_init_context('optimizer'):
if self.initialized_ and self.verbose:
if reason:
# re-initialization was triggered indirectly
msg = reason
else:
# re-initialization was triggered directly
msg = self._format_reinit_msg("optimizer", triggered_directly=False)
print(msg)
self.initialize_optimizer()
# register the virtual params for all optimizers
for name in self._optimizers:
param_pattern = [name + '__param_groups__*__*', name + '__*']
if name == 'optimizer': # 'lr' is short for optimizer__lr
param_pattern.append('lr')
setter = partial(
optimizer_setter,
optimizer_attr=name + '_',
optimizer_name=name,
)
self._register_virtual_param(param_pattern, setter)
return self
def _initialize_history(self):
# this init context is for consistency and not being used at the moment
with self._current_init_context('history'):
self.initialize_history()
return self
def initialize(self):
    """Initialize all components of the net and return self.

    The steps run in this order: training-readiness check, virtual params,
    callbacks, module(s), criterion/criteria, optimizer(s), history, and
    parameter validation. ``initialized_`` is set to True only after all of
    them succeeded.
    """
    step_names = (
        'check_training_readiness',
        '_initialize_virtual_params',
        '_initialize_callbacks',
        '_initialize_module',
        '_initialize_criterion',
        '_initialize_optimizer',
        '_initialize_history',
        '_validate_params',
    )
    for step_name in step_names:
        # looked up lazily so each step sees the state left by the previous
        getattr(self, step_name)()
    self.initialized_ = True
    return self
def check_training_readiness(self):
    """Raise if the net can no longer be trained.

    Raises
    ------
    SkorchTrainingImpossibleError
      If the net's attributes were trimmed for prediction, i.e. its
      ``_trimmed_for_prediction`` attribute is truthy.
    """
    if not getattr(self, '_trimmed_for_prediction', False):
        return
    raise SkorchTrainingImpossibleError(
        "The net's attributes were trimmed for prediction, thus it cannot "
        "be used for training anymore"
    )
def check_data(self, X, y=None):
    """Hook for validating input data; accepts anything by default.

    Subclasses may override this to reject unsuitable ``X``/``y``.
    """
def _set_training(self, training=True):
"""Set training/evaluation mode on all modules and criteria that are torch
Modules.
Parameters
----------
training : bool (default=True)
Whether to set to training mode (True) or evaluation mode (False).
"""
for module_name in self._modules + self._criteria:
module = getattr(self, module_name + '_')
if isinstance(module, torch.nn.Module):
module.train(training)
def validation_step(self, batch, **fit_params):
    """Compute predictions and loss for one batch without tracking gradients.

    All modules are put into evaluation mode first (e.g. dropout is not
    applied). The forward pass and the loss are computed inside
    :func:`torch.no_grad`.

    Parameters
    ----------
    batch
      A single batch returned by the data loader.

    **fit_params : dict
      Additional parameters passed on to ``self.infer``.

    Returns
    -------
    step : dict
      ``{'loss': loss, 'y_pred': y_pred}``, where ``loss`` is the value
      returned by ``get_loss`` and ``y_pred`` the output of ``infer``.
    """
    self._set_training(False)
    inputs, targets = unpack_data(batch)
    with torch.no_grad():
        y_pred = self.infer(inputs, **fit_params)
        loss = self.get_loss(y_pred, targets, X=inputs, training=False)
    return dict(loss=loss, y_pred=y_pred)
def train_step_single(self, batch, **fit_params):
    """Compute ``y_pred`` and the loss for one batch, then backpropagate.

    All modules are put into training mode first (e.g. dropout is applied).
    ``loss.backward()`` is called, so gradients accumulate on the
    parameters. No optimizer step is performed here.

    Parameters
    ----------
    batch
      A single batch returned by the data loader.

    **fit_params : dict
      Additional parameters passed on to ``self.infer``.

    Returns
    -------
    step : dict
      A dictionary ``{'loss': loss, 'y_pred': y_pred}``, where ``loss`` is
      the result of ``get_loss`` (the object ``backward`` was called on)
      and ``y_pred`` the prediction generated by ``infer``.
    """
    self._set_training(True)
    inputs, targets = unpack_data(batch)
    y_pred = self.infer(inputs, **fit_params)
    loss = self.get_loss(y_pred, targets, X=inputs, training=True)
    loss.backward()
    return dict(loss=loss, y_pred=y_pred)
def get_train_step_accumulator(self):
    """Return a new train step accumulator.

    The default accumulator keeps the first value it receives during the
    optimizer call. Most optimizers make only one call, so that first value
    is also the only one. Some optimizers, e.g. LBFGS, evaluate the loss
    several times per step, so ``train_step_calc_gradient`` is called
    repeatedly. If you don't want the first value in that case, override
    this method to return a custom accumulator.
    """
    accumulator = FirstStepAccumulator()
    return accumulator
def _zero_grad_optimizer(self, set_to_none=None):
"""Zero out the gradient of all optimizers.
Parameters
----------
set_to_none : bool or None (default=None)
Whether to zero out gradients (default) or to set them to None by
passing True. Note that since this option is only available starting
from PyTorch 1.7, it is ignored by default (i.e. when its value is
None). For skorch to pass this value to the ``zero_grad`` call,
override this method and set the value to True or False.
The advantages and disadvantages of setting this value to True are
discussed here:
https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer.zero_grad
"""
for name in self._optimizers:
optimizer = getattr(self, name + '_')
if set_to_none is None:
optimizer.zero_grad()
else:
optimizer.zero_grad(set_to_none=set_to_none)
def _step_optimizer(self, step_fn):
"""Perform a ``step`` call on all optimizers.
Parameters
----------
step_fn : callable or None
If None, just call ``optimizer.step()`` without arguments. Else, this
will be passed as the training step closure to the optimizer(s). Note
that this could lead to the function being called multiple times. If
more fine-grained control is desired instead, please override the
:meth:`.train_step` method.
"""
for name in self._optimizers:
optimizer = getattr(self, name + '_')
if step_fn is None:
optimizer.step()
else:
optimizer.step(step_fn)
def train_step(self, batch, **fit_params):
"""Prepares a loss function callable and pass it to the optimizer,
hence performing one optimization step.
Loss function callable as required by some optimizers (and accepted by
all of them):
https://pytorch.org/docs/master/optim.html#optimizer-step-closure
The module is set to be in train mode (e.g. dropout is
applied).
Parameters
----------
batch
A single batch returned by the data loader.
**fit_params : dict
Additional parameters passed to the ``forward`` method of
the module and to the train_split call.
Returns
-------
step : dict
A dictionary ``{'loss': loss, 'y_pred': y_pred}``, where the
float ``loss`` is the result of the loss function and
``y_pred`` the prediction generated by the PyTorch module.