/
base.py
65 lines (44 loc) · 2.03 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
""" Basic callback definition. """
import warnings
from sklearn.base import BaseEstimator
from skorch.exceptions import SkorchWarning
__all__ = ['Callback']
class Callback:
    """Base class for callbacks.

    All custom callbacks should inherit from this class. The subclass
    may override any of the ``on_...`` methods. It is, however, not
    necessary to override all of them, since it's okay if they don't
    have any effect. In this base class every ``on_...`` method is a
    no-op that accepts (and ignores) its arguments plus arbitrary
    extra keyword arguments.

    Classes that inherit from this also gain the ``get_params`` and
    ``set_params`` method.

    """
    def initialize(self):
        """(Re-)Set the initial state of the callback. Use this
        e.g. if the callback tracks some state that should be reset
        when the model is re-initialized.

        This method should return self.

        """
        return self

    def on_train_begin(self, net, X=None, y=None, **kwargs):
        """Called at the beginning of training."""

    def on_train_end(self, net, X=None, y=None, **kwargs):
        """Called at the end of training."""

    def on_epoch_begin(
            self, net, dataset_train=None, dataset_valid=None, **kwargs):
        """Called at the beginning of each epoch."""

    def on_epoch_end(
            self, net, dataset_train=None, dataset_valid=None, **kwargs):
        """Called at the end of each epoch."""

    def on_batch_begin(self, net, batch=None, training=None, **kwargs):
        """Called at the beginning of each batch."""

    def on_batch_end(self, net, batch=None, training=None, **kwargs):
        """Called at the end of each batch."""

    def on_grad_computed(
            self, net, named_parameters, X=None, y=None, training=None,
            **kwargs):
        """Called once per batch after gradients have been computed but before
        an update step was performed.
        """

    def _get_param_names(self):
        """Return the names of the attributes that count as parameters.

        Every key of the instance ``__dict__`` is considered a
        parameter unless its name ends with an underscore.

        Returns
        -------
        names : list of str
          Parameter names, in the order the attributes were set.

        """
        return [key for key in self.__dict__ if not key.endswith('_')]

    def get_params(self, deep=True):
        """Return the parameters of this callback.

        Delegates to ``BaseEstimator.get_params`` with this instance
        as the estimator.

        Parameters
        ----------
        deep : bool (default=True)
          Forwarded unchanged to ``BaseEstimator.get_params``.

        Returns
        -------
        params : dict
          Whatever ``BaseEstimator.get_params`` returns for this
          instance.

        """
        return BaseEstimator.get_params(self, deep=deep)

    def set_params(self, **params):
        """Set the parameters of this callback.

        Delegates to ``BaseEstimator.set_params`` with this instance
        as the estimator.

        Parameters
        ----------
        **params
          Parameter names and their new values, forwarded unchanged
          to ``BaseEstimator.set_params``.

        Returns
        -------
        self
          The callback itself, so that calls can be chained, in line
          with the scikit-learn ``set_params`` convention.

        """
        BaseEstimator.set_params(self, **params)
        # Return the instance explicitly rather than relying on the
        # delegate's return value, so chaining always works.
        return self