Adam2SGD

This is a modified version of Keras' EarlyStopping callback that switches the optimizer from Adam to SGD, following arXiv 1712.07628:

Keskar, N. S., & Socher, R. (2017).
Improving generalization performance by switching from Adam to SGD.
arXiv preprint arXiv:1712.07628.
  • The callback monitors the learning rate according to condition (4) from arXiv 1712.07628 (a paraphrase of the condition is given below this list).
  • If condition (4) is satisfied, the callback stops training early and restarts training via a separate, user-defined SWATS (Switching from Adam To SGD) function, which compiles an SGD optimizer with the learning rate that satisfied (4).
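
For context, condition (4) (paraphrased here in informal notation from memory; defer to the paper for the exact statement) compares the running estimate of the SGD learning rate, gamma_k, with its bias-corrected exponential average, lambda_k = beta_2 * lambda_{k-1} + (1 - beta_2) * gamma_k, and triggers the switch once the two agree to within a small tolerance epsilon:

    | lambda_k / (1 - beta_2^k) - gamma_k | < epsilon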

Usage:

model = Sequential()
...
model.compile(...)    


# Instantiate the callback; SWATS(...) supplies the arguments for the
# post-switch training run
AdamToSGD_ = AdamToSGD(after_training_with_Adam=SWATS(x=train_x,
                                                      y=train_y,
                                                      ...))

def SWATS(momentum=0.0,    # SGD optimizer arguments
          nesterov=False,
          ...,
          loss='mse',      # compile arguments
          ...,
          x=None,          # model.fit arguments
          y=None,
          **kwargs):
    """
    This user-defined function restarts training if condition (4) from
    arXiv 1712.07628 is satisfied in the callback.
    Define the optimizer, compile the model, and fit it again in one function.
    """
    # Last learning rate used by Adam and its bias-corrected exponential average
    lr = float(K.get_value(model.optimizer.lr))
    bias_corrected_exponential_avg = lr / (1. - K.get_value(model.optimizer.beta_2))

    # Only switch when condition (4) holds, i.e. the bias-corrected average
    # agrees with the current learning rate to within a small tolerance
    if abs(bias_corrected_exponential_avg - lr) >= 1e-9:
        return
    else:
        SGD_optimizer = SGD(lr=bias_corrected_exponential_avg,
                            ...)

        model.compile(optimizer=SGD_optimizer,
                      ...)

        print('\nNow switching to SGD...\n')

        model.fit(x=x,
                  y=y,
                  ...)
  
 
result = model.fit(train_x,
                   train_y,
                   callbacks=[AdamToSGD_, ...],
                   ...)

If condition (4) from arXiv 1712.07628 is satisfied, training ends early and restarts via the user-defined SWATS function, using the SGD optimizer with the last learning rate Adam used before the condition was met.

Requires TensorFlow < 2.0 and Keras 2.3.1 or lower.

This callback is best suited to training on image or text data for hundreds of epochs.

Run python setup.py install to install.

Update 2021-05-21

If you're having difficulty running this or implementing it in TF 2.0, just train with the Adam optimizer and change the early stopping callback to monitor the learning rate (stop at the LR value from the paper). Then manually restart training using the SGD optimizer. All this callback does is automate that process; a rough sketch of the manual approach is given below.
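
A minimal sketch of that manual approach in TF 2.x (the callback name SwitchPointStopper, the threshold value, and the use of ReduceLROnPlateau to drive the learning rate down are my own assumptions, not part of this package):

import tensorflow as tf

class SwitchPointStopper(tf.keras.callbacks.Callback):
    """Stop training once the learning rate falls to the chosen switch value."""
    def __init__(self, lr_threshold):
        super().__init__()
        self.lr_threshold = lr_threshold

    def on_epoch_end(self, epoch, logs=None):
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        if lr <= self.lr_threshold:
            print(f'\nLearning rate {lr:.2e} reached the switch point, stopping.\n')
            self.model.stop_training = True

# Phase 1: train with Adam; ReduceLROnPlateau lowers the learning rate over time
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mse')
model.fit(train_x, train_y, epochs=500,
          callbacks=[tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.5, patience=10),
                     SwitchPointStopper(lr_threshold=1e-4)])

# Phase 2: recompile with SGD at that learning rate (weights are kept) and continue training
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e-4, momentum=0.9),
              loss='mse')
model.fit(train_x, train_y, epochs=500)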

MIT License
