Skip to content

Commit

Permalink
Merge pull request #330 from RaulPPelaez/ema_fix
Browse files Browse the repository at this point in the history
Disable EMA if the user inputs 0 for alpha
  • Loading branch information
RaulPPelaez committed Jun 20, 2024
2 parents 440f985 + 59ea4f5 commit e908988
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion torchmdnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _update_loss_with_ema(self, stage, type, loss_name, loss):
# loss_name: name of the loss function
# loss: loss value
alpha = getattr(self.hparams, f"ema_alpha_{type}")
if stage in ["train", "val"] and alpha < 1:
if stage in ["train", "val"] and alpha < 1 and alpha > 0:
ema = (
self.ema[stage][type][loss_name]
if loss_name in self.ema[stage][type]
Expand Down
4 changes: 2 additions & 2 deletions torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def get_argparse():
parser.add_argument('--early-stopping-patience', type=int, default=30, help='Stop training after this many epochs without improvement')
parser.add_argument('--reset-trainer', type=bool, default=False, help='Reset training metrics (e.g. early stopping, lr) when loading a model checkpoint')
parser.add_argument('--weight-decay', type=float, default=0.0, help='Weight decay strength')
parser.add_argument('--ema-alpha-y', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of y')
parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy')
parser.add_argument('--ema-alpha-y', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of y. Must be between 0 and 1.')
parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy. Must be between 0 and 1.')
parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
parser.add_argument('--num-nodes', type=int, default=1, help='Number of GPU nodes for distributed training with the Lightning Trainer.')
parser.add_argument('--precision', type=int, default=32, choices=[16, 32, 64], help='Floating point precision')
Expand Down

0 comments on commit e908988

Please sign in to comment.