diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 8fee82af..d5ea73cf 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -175,7 +175,7 @@ def _update_loss_with_ema(self, stage, type, loss_name, loss): # loss_name: name of the loss function # loss: loss value alpha = getattr(self.hparams, f"ema_alpha_{type}") - if stage in ["train", "val"] and alpha < 1: + if stage in ["train", "val"] and alpha < 1 and alpha > 0: ema = ( self.ema[stage][type][loss_name] if loss_name in self.ema[stage][type] diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 2e69212b..7f2d8e07 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -42,8 +42,8 @@ def get_argparse(): parser.add_argument('--early-stopping-patience', type=int, default=30, help='Stop training after this many epochs without improvement') parser.add_argument('--reset-trainer', type=bool, default=False, help='Reset training metrics (e.g. early stopping, lr) when loading a model checkpoint') parser.add_argument('--weight-decay', type=float, default=0.0, help='Weight decay strength') - parser.add_argument('--ema-alpha-y', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of y') - parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy') + parser.add_argument('--ema-alpha-y', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of y. Must be between 0 and 1.') + parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy. Must be between 0 and 1.') parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus') parser.add_argument('--num-nodes', type=int, default=1, help='Number of GPU nodes for distributed training with the Lightning Trainer.') parser.add_argument('--precision', type=int, default=32, choices=[16, 32, 64], help='Floating point precision')