Merge pull request #51 from spokenlanguage/50-constant-adam
50 constant adam
cwmeijer committed Dec 16, 2020
2 parents 603b255 + c9e255e commit ebcd61f
Showing 4 changed files with 23 additions and 3 deletions.
2 changes: 2 additions & 0 deletions platalea/basic.py
@@ -96,6 +96,8 @@ def val_loss(net):
         scheduler = platalea.schedulers.cyclic(optimizer, len(data['train']), max_lr=config['max_lr'], min_lr=config['min_lr'])
     elif configured_scheduler == 'noam':
         scheduler = platalea.schedulers.noam(optimizer, config['d_model'])
+    elif configured_scheduler == 'constant':
+        scheduler = platalea.schedulers.constant(optimizer, config['constant_lr'])
     else:
         raise Exception("lr_scheduler config value " + configured_scheduler + " is invalid, use cyclic or noam")
     optimizer.zero_grad()
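
For illustration, a hypothetical run-config fragment (key names taken from the diff above, other details omitted) that would make this code path pick the new constant scheduler:

config = dict(
    lr_scheduler='constant',  # selects the branch added above
    constant_lr=1e-4,         # read by platalea.schedulers.constant(optimizer, ...)
)
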
4 changes: 3 additions & 1 deletion platalea/experiments/config.py
@@ -72,12 +72,14 @@ def __str__(self):
                   '--downsampling-factor', default=None, type=float,
                   dest='downsampling_factor',
                   help='factor by which the dataset should be downsampled')
-args.add_argument('--lr_scheduler', default="cyclic", choices=['cyclic', 'noam'],
+args.add_argument('--lr_scheduler', default="cyclic", choices=['cyclic', 'noam', 'constant'],
                   help='The learning rate scheduler to use. WARNING: noam not yet implemented for most experiments!')
 args.add_argument('--cyclic_lr_max', default=2 * 1e-4, type=float,
                   help='Maximum learning rate for cyclic learning rate scheduler')
 args.add_argument('--cyclic_lr_min', default=1e-6, type=float,
                   help='Minimum learning rate for cyclic learning rate scheduler')
+args.add_argument('--constant_lr', default=1e-4, type=float,
+                  help='Learning rate for constant learning rate scheduler')
 args.add_argument('--device', type=str, default=None,
                   help="Device to train on. Can be passed on to platalea.hardware.device in experiments.")
 
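
For illustration, a minimal sketch of how the new flags combine, using plain argparse rather than platalea's own config parser; choices and defaults mirror the diff above:

import argparse

# Stand-in parser, not the one defined in platalea/experiments/config.py.
parser = argparse.ArgumentParser()
parser.add_argument('--lr_scheduler', default='cyclic',
                    choices=['cyclic', 'noam', 'constant'])
parser.add_argument('--constant_lr', default=1e-4, type=float)

# e.g. an experiment invoked with: --lr_scheduler constant --constant_lr 5e-5
opts = parser.parse_args(['--lr_scheduler', 'constant', '--constant_lr', '5e-5'])
print(opts.lr_scheduler, opts.constant_lr)  # constant 5e-05
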
5 changes: 3 additions & 2 deletions platalea/experiments/flickr8k/transformer.py
@@ -11,7 +11,6 @@
 # Parsing arguments
 args.add_argument('--batch_size', default=32, type=int,
                   help='How many samples per batch to load.')
-
 args.add_argument('--trafo_d_model', default=512, type=int,
                   help='TRANSFORMER: The dimensionality of the transformer model.')
 args.add_argument('--trafo_encoder_layers', default=4, type=int,
@@ -77,7 +76,9 @@ def __new__(cls, value):
 logging.info('Building model')
 net = M.SpeechImage(config)
 run_config = dict(max_lr=args.cyclic_lr_max, min_lr=args.cyclic_lr_min, epochs=args.epochs, lr_scheduler=args.lr_scheduler,
-                  d_model=args.trafo_d_model, score_on_cpu=args.score_on_cpu, validate_on_cpu=args.validate_on_cpu)
+                  d_model=args.trafo_d_model, score_on_cpu=args.score_on_cpu, validate_on_cpu=args.validate_on_cpu,
+                  constant_lr=args.constant_lr,
+                  )
 
 logged_config = dict(run_config=run_config, encoder_config=config, speech_config=speech_config)
 logged_config['encoder_config'].pop('SpeechEncoder')  # Object info is redundant in log.
15 changes: 15 additions & 0 deletions platalea/schedulers.py
@@ -34,3 +34,18 @@ def learning_rate(iteration):
 
     scheduler = lr_scheduler.LambdaLR(optimizer, learning_rate)
     return scheduler
+
+
+def constant(optimizer, lr):
+    """
+    Constant learning rate scheduler: the most trivial kind of scheduler, which simply keeps the learning rate constant.
+    :param optimizer: optimizer whose learning rate is scheduled
+    :param lr: the constant learning rate to use
+    :return: a LambdaLR scheduler that always yields lr
+    """
+    logging.info("Using constant learning rate of {}".format(lr))
+
+    def learning_rate(_iteration):
+        return lr
+
+    return lr_scheduler.LambdaLR(optimizer, learning_rate, last_epoch=-1)
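
For illustration, a minimal usage sketch of the new factory with a hypothetical model and optimizer. Note that torch's LambdaLR multiplies the lambda's return value by the optimizer's initial learning rate, so the optimizer is built with lr=1 here to make the effective rate equal the value passed in:

from torch import nn, optim

from platalea.schedulers import constant

# Hypothetical model/optimizer, only for demonstration.
model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=1)
scheduler = constant(optimizer, 1e-4)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # [0.0001] on every step
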
