Update wrappers for PyTorch 1.7.0 and later.
TsumiNa committed Jun 3, 2021
1 parent 860adf5 commit e36ca0e
Showing 3 changed files with 96 additions and 21 deletions.
10 changes: 6 additions & 4 deletions xenonpy/model/training/loss.py
@@ -2,9 +2,11 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

__all__ = ['NLLLoss', 'NLLLoss2d', 'L1Loss', 'MSELoss', 'CrossEntropyLoss', 'CTCLoss', 'PoissonNLLLoss', 'KLDivLoss',
'BCELoss', 'BCEWithLogitsLoss', 'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss',
'SmoothL1Loss', 'SoftMarginLoss', 'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss',
'TripletMarginLoss']
__all__ = [
'NLLLoss', 'NLLLoss2d', 'L1Loss', 'MSELoss', 'CrossEntropyLoss', 'CTCLoss', 'PoissonNLLLoss', 'KLDivLoss',
'BCELoss', 'BCEWithLogitsLoss', 'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss',
'SoftMarginLoss', 'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss',
'GaussianNLLLoss', 'TripletMarginWithDistanceLoss', 'PairwiseDistance'
]

from torch.nn.modules.loss import *
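Since this module simply re-exports `torch.nn`'s loss classes (the wildcard import above), the newly listed names behave exactly like their torch counterparts. Below is a minimal sketch using the newly exported `GaussianNLLLoss`, assuming a PyTorch build recent enough to ship it:

```python
import torch
from xenonpy.model.training.loss import GaussianNLLLoss

mean = torch.randn(16, 1, requires_grad=True)   # predicted means
target = torch.randn(16, 1)                     # observed targets
var = torch.rand(16, 1) + 1e-3                  # predicted variances, must be positive

loss_fn = GaussianNLLLoss(reduction='mean')
loss = loss_fn(mean, target, var)               # Gaussian negative log-likelihood
loss.backward()
```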
34 changes: 26 additions & 8 deletions xenonpy/model/training/lr_scheduler.py
@@ -147,9 +147,17 @@ def __init__(self, *, T_max, eta_min=0, last_epoch=-1):

class ReduceLROnPlateau(BaseLRScheduler):

def __init__(self, *, mode='min', factor=0.1, patience=10,
verbose=False, threshold=1e-4, threshold_mode='rel',
cooldown=0, min_lr=0, eps=1e-8):
def __init__(self,
*,
mode='min',
factor=0.1,
patience=10,
verbose=False,
threshold=1e-4,
threshold_mode='rel',
cooldown=0,
min_lr=0,
eps=1e-8):
"""Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate by a factor
of 2-10 once learning stagnates. This scheduler reads a metrics
@@ -188,14 +196,23 @@ def __init__(self, *, mode='min', factor=0.1, patience=10,
ignored. Default: 1e-8.
"""
super().__init__(lr_scheduler.ReduceLROnPlateau, mode=mode, factor=factor, patience=patience,
verbose=verbose, threshold=threshold, threshold_mode=threshold_mode,
cooldown=cooldown, min_lr=min_lr, eps=eps)
super().__init__(lr_scheduler.ReduceLROnPlateau,
mode=mode,
factor=factor,
patience=patience,
verbose=verbose,
threshold=threshold,
threshold_mode=threshold_mode,
cooldown=cooldown,
min_lr=min_lr,
eps=eps)


class CyclicLR(BaseLRScheduler):

def __init__(self, *, base_lr,
def __init__(self,
*,
base_lr,
max_lr,
step_size_up=2000,
step_size_down=None,
@@ -285,7 +302,8 @@ def __init__(self, *, base_lr,
.. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
.. _bckenstler/CLR: https://github.com/bckenstler/CLR
"""
super().__init__(lr_scheduler.CyclicLR, base_lr=base_lr,
super().__init__(lr_scheduler.CyclicLR,
base_lr=base_lr,
max_lr=max_lr,
step_size_up=step_size_up,
step_size_down=step_size_down,
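For context on how these keyword-only wrappers are meant to be used: each one hands the matching `torch.optim.lr_scheduler` class plus its keyword arguments to `BaseLRScheduler`, which can then build the real scheduler once an optimizer exists. The sketch below illustrates that deferred-construction pattern with a hypothetical, stripped-down `BaseLRScheduler` (the real base class is not part of this diff):

```python
import torch
from torch.optim import lr_scheduler


class BaseLRScheduler:
    """Hypothetical sketch: store the scheduler class and kwargs, build it later."""

    def __init__(self, scheduler_cls, **kwargs):
        self._scheduler_cls = scheduler_cls
        self._kwargs = kwargs

    def __call__(self, optimizer):
        # The torch scheduler is only created once an optimizer is available.
        return self._scheduler_cls(optimizer, **self._kwargs)


class ReduceLROnPlateau(BaseLRScheduler):
    def __init__(self, *, mode='min', factor=0.1, patience=10):
        super().__init__(lr_scheduler.ReduceLROnPlateau, mode=mode, factor=factor, patience=patience)


model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = ReduceLROnPlateau(factor=0.5, patience=5)(optimizer)
scheduler.step(0.42)  # ReduceLROnPlateau.step() takes the monitored metric
```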
73 changes: 64 additions & 9 deletions xenonpy/model/training/optimizer.py
@@ -46,17 +46,20 @@ def __init__(self, *, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_v
.. _Adaptive Subgradient Methods for Online Learning and Stochastic
Optimization: http://jmlr.org/papers/v12/duchi11a.html
"""
super().__init__(optim.Adagrad, lr=lr, lr_decay=lr_decay, weight_decay=weight_decay,
super().__init__(optim.Adagrad,
lr=lr,
lr_decay=lr_decay,
weight_decay=weight_decay,
initial_accumulator_value=initial_accumulator_value)


class Adam(BaseOptimizer):

def __init__(self, *, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
def __init__(self, *, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
r"""Implements Adam algorithm.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
The implementation of the L2 penalty follows changes proposed in `Decoupled Weight Decay Regularization`_.
Arguments:
lr (float, optional): learning rate (default: 1e-3)
@@ -71,13 +74,45 @@ def __init__(self, *, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""

super().__init__(optim.Adam, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)


class AdamW(BaseOptimizer):

def __init__(self, *, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
r"""Implements AdamW algorithm.
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
Arguments:
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""

super().__init__(optim.AdamW, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
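The new `AdamW` wrapper forwards everything to `torch.optim.AdamW`, which applies weight decay directly to the weights instead of folding it into the gradient as plain Adam does (the point of `Decoupled Weight Decay Regularization`). A short comparison in plain PyTorch, which is all the wrapper ultimately calls:

```python
import torch
from torch import optim

model = torch.nn.Linear(8, 1)

# Adam: weight_decay is an L2 term added to the gradient before the adaptive update.
adam = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)

# AdamW: weight decay is applied to the weights directly, decoupled from the
# adaptive gradient step.
adamw = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2, amsgrad=False)

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
adamw.step()
```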


class SparseAdam(BaseOptimizer):

def __init__(self, *, lr=0.001, betas=(0.9, 0.999), eps=1e-08):
@@ -145,8 +180,14 @@ def __init__(self, *, lr=0.002, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_d

class LBFGS(BaseOptimizer):

def __init__(self, *, lr=1, max_iter=20, max_eval=None,
tolerance_grad=1e-5, tolerance_change=1e-9, history_size=100,
def __init__(self,
*,
lr=1,
max_iter=20,
max_eval=None,
tolerance_grad=1e-5,
tolerance_change=1e-9,
history_size=100,
line_search_fn=None):
"""Implements L-BFGS algorithm.
@@ -176,8 +217,13 @@ def __init__(self, *, lr=1, max_iter=20, max_eval=None,
history_size (int): update history size (default: 100).
"""

super().__init__(optim.LBFGS, lr=lr, max_iter=max_iter, max_eval=max_eval,
tolerance_grad=tolerance_grad, tolerance_change=tolerance_change, history_size=history_size,
super().__init__(optim.LBFGS,
lr=lr,
max_iter=max_iter,
max_eval=max_eval,
tolerance_grad=tolerance_grad,
tolerance_change=tolerance_change,
history_size=history_size,
line_search_fn=line_search_fn)
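Worth flagging while touching this class: unlike the other optimizers wrapped here, `torch.optim.LBFGS` re-evaluates the model several times per update, so its `step()` must be given a closure. A minimal reminder in plain PyTorch (the wrapper only forwards the keywords shown above):

```python
import torch
from torch import optim

model = torch.nn.Linear(2, 1)
x, y = torch.randn(32, 2), torch.randn(32, 1)

optimizer = optim.LBFGS(model.parameters(), lr=1, max_iter=20, history_size=100)


def closure():
    # Called repeatedly by LBFGS during a single step() to re-evaluate the loss.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss


optimizer.step(closure)
```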


@@ -204,7 +250,12 @@ def __init__(self, *, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0
"""

super().__init__(optim.RMSprop, lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=momentum,
super().__init__(optim.RMSprop,
lr=lr,
alpha=alpha,
eps=eps,
weight_decay=weight_decay,
momentum=momentum,
centered=centered)


@@ -272,5 +323,9 @@ def __init__(self, *, lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov
The Nesterov version is analogously modified.
"""

super().__init__(optim.SGD, lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
super().__init__(optim.SGD,
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov)
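All of the optimizer wrappers in this file follow the same shape as the scheduler wrappers: keyword-only constructors that pass the matching `torch.optim` class and its arguments to `BaseOptimizer`. The sketch below shows how such a wrapper would typically be consumed, again using a hypothetical, stripped-down base class since the real `BaseOptimizer` is not part of this diff:

```python
import torch
from torch import optim


class BaseOptimizer:
    """Hypothetical sketch: hold the optimizer class and kwargs until parameters arrive."""

    def __init__(self, optimizer_cls, **kwargs):
        self._optimizer_cls = optimizer_cls
        self._kwargs = kwargs

    def __call__(self, params):
        return self._optimizer_cls(params, **self._kwargs)


class SGD(BaseOptimizer):
    def __init__(self, *, lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False):
        super().__init__(optim.SGD, lr=lr, momentum=momentum, dampening=dampening,
                         weight_decay=weight_decay, nesterov=nesterov)


model = torch.nn.Linear(8, 1)
optimizer = SGD(lr=0.01, momentum=0.9)(model.parameters())

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```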
