Exponential moving averages of model parameters #321

Open · wants to merge 7 commits into main
torchmdnet/module.py (33 additions, 10 deletions)
@@ -6,6 +6,7 @@
 import torch
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
+import torch.optim.swa_utils
 from torch.nn.functional import local_response_norm, mse_loss, l1_loss
 from torch import Tensor
 from typing import Optional, Dict, Tuple
@@ -73,9 +74,26 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         else:
             self.model = create_model(self.hparams, prior_model, mean, std)
 
-        # initialize exponential smoothing
-        self.ema = None
-        self._reset_ema_dict()
+        self.ema_model = None
+        if (
+            "ema_parameters_decay" in self.hparams
+            and self.hparams.ema_parameters_decay is not None
+        ):
+            self.ema_model = torch.optim.swa_utils.AveragedModel(
+                self.model,
+                multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(
+                    self.hparams.ema_parameters_decay
+                ),
+            )
+        self.ema_parameters_start = (
+            self.hparams.ema_parameters_start
+            if "ema_parameters_start" in self.hparams
+            else 0
+        )
+
+        # initialize exponential smoothing for the losses
+        self.ema_loss = None
+        self._reset_ema_loss_dict()
 
         # initialize loss collection
         self.losses = None
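For context, a minimal standalone sketch of the PyTorch utilities this hunk relies on; the toy model and the 0.999 decay below are illustrative values, not taken from the PR. With get_ema_multi_avg_fn(decay), each update keeps `decay` of the running average and mixes in `1 - decay` of the current weights.

import torch
import torch.optim.swa_utils as swa_utils

model = torch.nn.Linear(4, 1)  # placeholder network
ema_model = swa_utils.AveragedModel(
    model,
    multi_avg_fn=swa_utils.get_ema_multi_avg_fn(0.999),  # ema = 0.999*ema + 0.001*param
)

# The first call copies the weights; later calls fold the new weights into the average.
ema_model.update_parameters(model)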
@@ -177,12 +195,12 @@ def _update_loss_with_ema(self, stage, type, loss_name, loss):
         alpha = getattr(self.hparams, f"ema_alpha_{type}")
         if stage in ["train", "val"] and alpha < 1 and alpha > 0:
             ema = (
-                self.ema[stage][type][loss_name]
-                if loss_name in self.ema[stage][type]
+                self.ema_loss[stage][type][loss_name]
+                if loss_name in self.ema_loss[stage][type]
                 else loss.detach()
             )
             loss = alpha * loss + (1 - alpha) * ema
-            self.ema[stage][type][loss_name] = loss.detach()
+            self.ema_loss[stage][type][loss_name] = loss.detach()
         return loss
 
     def step(self, batch, loss_fn_list, stage):
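The rename only distinguishes the loss smoothing from the new parameter averaging; the recursion itself is unchanged. A quick worked illustration with made-up numbers:

# ema_t = alpha * loss_t + (1 - alpha) * ema_{t-1}
alpha = 0.1      # weight given to the newest loss value
prev_ema = 4.0   # previously smoothed loss
loss = 2.0       # current raw loss
ema = alpha * loss + (1 - alpha) * prev_ema
print(ema)       # 3.8: the smoothed loss moves only 10% of the way toward the new value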
@@ -250,6 +268,11 @@ def optimizer_step(self, *args, **kwargs):
             for pg in optimizer.param_groups:
                 pg["lr"] = lr_scale * self.hparams.lr
         super().optimizer_step(*args, **kwargs)
+        if (
+            self.trainer.current_epoch >= self.ema_parameters_start
+            and self.ema_model is not None
+        ):
+            self.ema_model.update_parameters(self.model)
         optimizer.zero_grad()
 
     def _get_mean_loss_dict_for_type(self, type):
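The same gating idea, written as a plain PyTorch loop outside Lightning; the model, data, decay, and start epoch below are placeholders, not code from this PR.

import torch
import torch.optim.swa_utils as swa_utils

ema_parameters_start = 2  # skip averaging while the raw weights are still settling
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema_model = swa_utils.AveragedModel(
    model, multi_avg_fn=swa_utils.get_ema_multi_avg_fn(0.999)
)

for epoch in range(5):
    for _ in range(10):  # dummy batches
        optimizer.zero_grad()
        loss = (model(torch.randn(8, 4)) ** 2).mean()  # dummy loss
        loss.backward()
        optimizer.step()
        # Mirror of the check above: only update the averaged copy once past the start epoch.
        if ema_model is not None and epoch >= ema_parameters_start:
            ema_model.update_parameters(model)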
@@ -304,9 +327,9 @@ def _reset_losses_dict(self):
             for loss_type in ["total", "y", "neg_dy"]:
                 self.losses[stage][loss_type] = defaultdict(list)
 
-    def _reset_ema_dict(self):
-        self.ema = {}
+    def _reset_ema_loss_dict(self):
+        self.ema_loss = {}
         for stage in ["train", "val"]:
-            self.ema[stage] = {}
+            self.ema_loss[stage] = {}
             for loss_type in ["y", "neg_dy"]:
-                self.ema[stage][loss_type] = {}
+                self.ema_loss[stage][loss_type] = {}
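The hunks shown here cover creating and updating the averaged copy; how the averaged weights are consumed afterwards is not part of this excerpt. A hedged sketch of two common ways to read them back (illustrative, not necessarily what the PR does elsewhere):

import torch
import torch.optim.swa_utils as swa_utils

model = torch.nn.Linear(4, 1)
ema_model = swa_utils.AveragedModel(
    model, multi_avg_fn=swa_utils.get_ema_multi_avg_fn(0.999)
)
ema_model.update_parameters(model)

# Option 1: run inference through the wrapper, which forwards to the averaged module.
y = ema_model(torch.randn(2, 4))

# Option 2: copy the averaged weights into a plain model, e.g. for checkpointing.
plain = torch.nn.Linear(4, 1)
plain.load_state_dict(ema_model.module.state_dict())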
torchmdnet/scripts/train.py (2 additions, 0 deletions)
@@ -59,6 +59,8 @@ def get_argparse():
     parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
     parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
     parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')
+    parser.add_argument('--ema-parameters-decay', type=float, default=None, help='Exponential moving average decay for model parameters (defaults to None, meaning disabled). The decay is a value between 0 and 1 that controls how quickly the averaged parameters track the current ones.')
+    parser.add_argument('--ema-parameters-start', type=int, default=0, help='Epoch at which to start averaging the parameters.')
     # dataset specific
     parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
     parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
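Assuming the package's usual torchmd-train entry point and that no other option is mandatory at parse time, the feature would be enabled with something like --ema-parameters-decay 0.999 --ema-parameters-start 5. A quick parse-level sketch of just these two flags (illustrative only):

from torchmdnet.scripts.train import get_argparse

args = get_argparse().parse_args(
    ["--ema-parameters-decay", "0.999", "--ema-parameters-start", "5"]
)
assert args.ema_parameters_decay == 0.999  # a non-None decay turns parameter averaging on
assert args.ema_parameters_start == 5      # averaging begins at epoch 5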