
Commit

Add madgrad as a possible optional dependency
AngledLuffa committed Nov 13, 2022
1 parent e735ecc commit 2706c4b
Showing 2 changed files with 7 additions and 1 deletion.
stanza/models/common/utils.py (6 additions & 0 deletions)
@@ -154,6 +154,12 @@ def get_optimizer(name, parameters, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0
         return torch.optim.Adamax(parameters, **extra_args) # use default lr
     elif name == 'adadelta':
         return torch.optim.Adadelta(parameters, **extra_args) # use default lr
+    elif name == 'madgrad':
+        try:
+            import madgrad
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("Could not create madgrad optimizer. Perhaps the madgrad package is not installed") from e
+        return madgrad.MADGRAD(parameters, lr=lr, momentum=momentum, **extra_args)
     else:
         raise ValueError("Unsupported optimizer: {}".format(name))

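For context, a minimal sketch of how the new branch could be exercised once the optional madgrad package is installed (e.g. pip install madgrad); the Linear module below is just a stand-in for a real model:

    # Sketch only: route get_optimizer through the new 'madgrad' branch.
    # Assumes the optional madgrad package is installed; the Linear module is a placeholder.
    import torch
    from stanza.models.common.utils import get_optimizer

    model = torch.nn.Linear(10, 2)
    optimizer = get_optimizer('madgrad', model.parameters(), lr=3e-3, momentum=0.9)
    # Without the package installed, this raises the ModuleNotFoundError added above.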
stanza/models/tagger.py (1 addition & 1 deletion)
@@ -81,7 +81,7 @@ def parse_args(args=None):
     parser.set_defaults(share_hid=False)
 
     parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
-    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam, adamax, or adadelta.')
+    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam, adamax, or adadelta. madgrad as an optional dependency')
     parser.add_argument('--lr', type=float, default=3e-3, help='Learning rate')
     parser.add_argument('--initial_weight_decay', type=float, default=None, help='Optimizer weight decay for the first optimizer')
     parser.add_argument('--second_weight_decay', type=float, default=None, help='Optimizer weight decay for the second optimizer')

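As an illustration of the flag itself, a hedged sketch of how 'madgrad' would flow through the tagger's argument parser; this assumes parse_args forwards the list to argparse and that the tagger's remaining options all have usable defaults:

    # Illustrative only: 'madgrad' is accepted at parse time like any other --optim value.
    # Whether the optimizer can actually be built is decided later in get_optimizer,
    # which raises ModuleNotFoundError if the optional madgrad package is missing.
    from stanza.models.tagger import parse_args

    args = parse_args(['--optim', 'madgrad', '--lr', '3e-3'])
    print(args.optim)  # prints: madgrad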