/
regularization.py
48 lines (35 loc) · 1.46 KB
/
regularization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
""" Post-process regularization steps such as gradient normalizing. """
from torch.nn.utils import clip_grad_norm_
from skorch.callbacks import Callback
__all__ = ['GradientNormClipping']
class GradientNormClipping(Callback):
"""Clips gradient norm of a module's parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified
in-place.
See :func:`torch.nn.utils.clip_grad_norm_` for more information.
Parameters
----------
gradient_clip_value : float (default=None)
If not None, clip the norm of all model parameter gradients to this
value. The type of the norm is determined by the
``gradient_clip_norm_type`` parameter and defaults to L2.
gradient_clip_norm_type : float (default=2)
Norm to use when gradient clipping is active. The default is
to use L2-norm. Can be 'inf' for infinity norm.
"""
def __init__(
self,
gradient_clip_value=None,
gradient_clip_norm_type=2,
):
self.gradient_clip_value = gradient_clip_value
self.gradient_clip_norm_type = gradient_clip_norm_type
def on_grad_computed(self, _, named_parameters, **kwargs):
if self.gradient_clip_value is None:
return
clip_grad_norm_(
(p for _, p in named_parameters),
max_norm=self.gradient_clip_value,
norm_type=self.gradient_clip_norm_type,
)