
Merge pull request chainer#3488 from Crissman/opt-setup
Enable optimizer model setup with instantiation
mitmul committed Nov 21, 2017
2 parents a054cd6 + 5011639 commit 016cffa
Showing 10 changed files with 26 additions and 19 deletions.
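
In short, every GradientMethod-based optimizer below now accepts an optional model/link at construction and calls setup() on it automatically. A minimal usage sketch under that assumption (the chainer.links.Linear model is purely illustrative; any chainer.Link works):

import chainer
from chainer import optimizers

model = chainer.links.Linear(3, 2)  # illustrative model

# Previous style: construct the optimizer, then call setup() in a second step.
opt = optimizers.MomentumSGD(lr=0.01, momentum=0.9)
opt.setup(model)

# With this change: pass the model at instantiation; GradientMethod.__init__
# calls setup() for you when the argument is a chainer.Link.
opt = optimizers.MomentumSGD(lr=0.01, momentum=0.9, model=model)

Passing nothing (or None) keeps the old two-step behaviour, since setup() is only invoked when the argument is actually a chainer.Link.
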
4 changes: 3 additions & 1 deletion chainer/optimizer.py
@@ -484,9 +484,11 @@ class GradientMethod(Optimizer):
     """
 
-    def __init__(self):
+    def __init__(self, link=None):
         super(GradientMethod, self).__init__()
         self.hyperparam = Hyperparameter()
+        if isinstance(link, link_module.Link):
+            self.setup(link)
 
     def setup(self, link):
         super(GradientMethod, self).setup(link)
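
Because the base class only calls setup() when it receives a chainer.Link, a user-defined optimizer can forward the same optional argument. A hedged sketch of that pattern, assuming the Chainer v3 UpdateRule/GradientMethod API; ScaledSGD and ScaledSGDRule are hypothetical names, not part of this PR or of Chainer:

from chainer import optimizer

class ScaledSGDRule(optimizer.UpdateRule):
    # Plain SGD step; hyperparam.lr is inherited from the parent hyperparameter.
    def update_core_cpu(self, param):
        grad = param.grad
        if grad is None:
            return
        param.data -= self.hyperparam.lr * grad

class ScaledSGD(optimizer.GradientMethod):
    def __init__(self, lr=0.01, model=None):
        # Forward the optional link to GradientMethod, which now calls
        # setup() when a chainer.Link is passed (see the hunk above).
        super(ScaledSGD, self).__init__(model)
        self.hyperparam.lr = lr

    def create_update_rule(self):
        return ScaledSGDRule(self.hyperparam)
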
4 changes: 2 additions & 2 deletions chainer/optimizers/ada_delta.py
@@ -84,8 +84,8 @@ class AdaDelta(optimizer.GradientMethod):
     """
 
     def __init__(self, rho=_default_hyperparam.rho,
-                 eps=_default_hyperparam.eps):
-        super(AdaDelta, self).__init__()
+                 eps=_default_hyperparam.eps, model=None):
+        super(AdaDelta, self).__init__(model)
         self.hyperparam.rho = rho
         self.hyperparam.eps = eps
 
5 changes: 3 additions & 2 deletions chainer/optimizers/ada_grad.py
@@ -74,8 +74,9 @@ class AdaGrad(optimizer.GradientMethod):
     """
 
-    def __init__(self, lr=_default_hyperparam.lr, eps=_default_hyperparam.eps):
-        super(AdaGrad, self).__init__()
+    def __init__(self, lr=_default_hyperparam.lr,
+                 eps=_default_hyperparam.eps, model=None):
+        super(AdaGrad, self).__init__(model)
         self.hyperparam.lr = lr
         self.hyperparam.eps = eps
 
5 changes: 3 additions & 2 deletions chainer/optimizers/adam.py
@@ -110,8 +110,9 @@ def __init__(self,
                  alpha=_default_hyperparam.alpha,
                  beta1=_default_hyperparam.beta1,
                  beta2=_default_hyperparam.beta2,
-                 eps=_default_hyperparam.eps):
-        super(Adam, self).__init__()
+                 eps=_default_hyperparam.eps,
+                 model=None):
+        super(Adam, self).__init__(model)
         self.hyperparam.alpha = alpha
         self.hyperparam.beta1 = beta1
         self.hyperparam.beta2 = beta2
4 changes: 2 additions & 2 deletions chainer/optimizers/momentum_sgd.py
@@ -69,8 +69,8 @@ class MomentumSGD(optimizer.GradientMethod):
     """
 
     def __init__(self, lr=_default_hyperparam.lr,
-                 momentum=_default_hyperparam.momentum):
-        super(MomentumSGD, self).__init__()
+                 momentum=_default_hyperparam.momentum, model=None):
+        super(MomentumSGD, self).__init__(model)
         self.hyperparam.lr = lr
         self.hyperparam.momentum = momentum
 
4 changes: 2 additions & 2 deletions chainer/optimizers/nesterov_ag.py
@@ -76,8 +76,8 @@ class NesterovAG(optimizer.GradientMethod):
     """
 
    def __init__(self, lr=_default_hyperparam.lr,
-                 momentum=_default_hyperparam.momentum):
-        super(NesterovAG, self).__init__()
+                 momentum=_default_hyperparam.momentum, model=None):
+        super(NesterovAG, self).__init__(model)
         self.hyperparam.lr = lr
         self.hyperparam.momentum = momentum
 
5 changes: 3 additions & 2 deletions chainer/optimizers/rmsprop.py
@@ -91,8 +91,9 @@ class RMSprop(optimizer.GradientMethod):
     """
 
     def __init__(self, lr=_default_hyperparam.lr,
-                 alpha=_default_hyperparam.alpha, eps=_default_hyperparam.eps):
-        super(RMSprop, self).__init__()
+                 alpha=_default_hyperparam.alpha, eps=_default_hyperparam.eps,
+                 model=None):
+        super(RMSprop, self).__init__(model)
         self.hyperparam.lr = lr
         self.hyperparam.alpha = alpha
         self.hyperparam.eps = eps
5 changes: 3 additions & 2 deletions chainer/optimizers/rmsprop_graves.py
@@ -102,8 +102,9 @@ class RMSpropGraves(optimizer.GradientMethod):
     def __init__(self, lr=_default_hyperparam.lr,
                  alpha=_default_hyperparam.alpha,
                  momentum=_default_hyperparam.momentum,
-                 eps=_default_hyperparam.eps):
-        super(RMSpropGraves, self).__init__()
+                 eps=_default_hyperparam.eps,
+                 model=None):
+        super(RMSpropGraves, self).__init__(model)
         self.hyperparam.lr = lr
         self.hyperparam.alpha = alpha
         self.hyperparam.momentum = momentum
4 changes: 2 additions & 2 deletions chainer/optimizers/sgd.py
@@ -50,8 +50,8 @@ class SGD(optimizer.GradientMethod):
     """
 
-    def __init__(self, lr=_default_hyperparam.lr):
-        super(SGD, self).__init__()
+    def __init__(self, lr=_default_hyperparam.lr, model=None):
+        super(SGD, self).__init__(model)
         self.hyperparam.lr = lr
 
     lr = optimizer.HyperparameterProxy('lr')
5 changes: 3 additions & 2 deletions chainer/optimizers/smorms3.py
@@ -87,8 +87,9 @@ class SMORMS3(optimizer.GradientMethod):
     """
 
-    def __init__(self, lr=_default_hyperparam.lr, eps=_default_hyperparam.eps):
-        super(SMORMS3, self).__init__()
+    def __init__(self, lr=_default_hyperparam.lr,
+                 eps=_default_hyperparam.eps, model=None):
+        super(SMORMS3, self).__init__(model)
         self.hyperparam.lr = lr
         self.hyperparam.eps = eps
 
