diff --git a/stanza/models/depparse/trainer.py b/stanza/models/depparse/trainer.py
index 3c05a41d4..2ecc5146a 100644
--- a/stanza/models/depparse/trainer.py
+++ b/stanza/models/depparse/trainer.py
@@ -58,14 +58,14 @@ def __init__(self, args=None, vocab=None, pretrain=None, model_file=None,
             wandb.watch(self.model, log_freq=4, log="all", log_graph=True)
 
     def __init_optim(self):
-        if not self.args.get("second_stage", False) and self.args.get('second_optim'):
+        if not (self.args.get("second_stage", False) and self.args.get('second_optim')):
             self.optimizer = utils.get_optimizer(self.args['optim'], self.model,
                                                  self.args['lr'],
                                                  betas=(0.9, self.args['beta2']), eps=1e-6,
                                                  bert_learning_rate=self.args.get('bert_learning_rate', 0.0))
         else:
             self.optimizer = utils.get_optimizer(self.args['second_optim'], self.model,
-                                                  self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6,
-                                                  bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0))
+                                                 self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6,
+                                                 bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0))
 
     def update(self, batch, eval=False):
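
Note on the first hunk: "not a and b" parses in Python as "(not a) and b",
so the old and new conditions disagree exactly when 'second_optim' is unset.
The old form then fell through to the else branch and built an optimizer
from self.args['second_optim'] = None. A minimal standalone sketch of the
precedence difference (hypothetical values, not Stanza's actual args dict):

    # hypothetical settings, not Stanza's real config handling
    for second_stage in (False, True):
        for second_optim in (None, "adamw"):
            old = not second_stage and second_optim    # parses as (not second_stage) and second_optim
            new = not (second_stage and second_optim)
            print(second_stage, second_optim, bool(old), bool(new))
    # The two agree whenever second_optim is set; with second_optim=None the
    # old condition is falsy, sending __init_optim into the else branch.

The remaining -/+ pairs are whitespace-only, re-aligning the continuation
lines of the second utils.get_optimizer call.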