Exponential moving averages of model parameters #321

Open · wants to merge 7 commits into base: main
Changes from 3 commits
2 changes: 2 additions & 0 deletions environment.yml
@@ -17,3 +17,5 @@ dependencies:
- pytest
- psutil
- gxx<12
- pip:
  - torch-ema
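For reference, the new torch-ema dependency provides a single ExponentialMovingAverage helper. Below is a minimal standalone sketch of the pattern this PR wires into the Lightning module; the linear model, optimizer, and random data are placeholders for illustration, not part of this change:

import torch
from torch_ema import ExponentialMovingAverage

model = torch.nn.Linear(16, 1)                      # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

for _ in range(100):                                # placeholder training loop
    x = torch.randn(32, 16)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()                                    # track an EMA of the weights

# temporarily swap in the averaged weights for evaluation
with ema.average_parameters():
    val_loss = model(torch.randn(32, 16)).pow(2).mean()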
11 changes: 11 additions & 0 deletions torchmdnet/module.py
@@ -14,6 +14,7 @@
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models.utils import dtype_mapping
import torch_geometric.transforms as T
from torch_ema import ExponentialMovingAverage


class FloatCastDatasetWrapper(T.BaseTransform):
@@ -73,6 +74,11 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
        else:
            self.model = create_model(self.hparams, prior_model, mean, std)

        self.ema_prmtrs = None
        if self.hparams.ema_prmtrs_decay is not None:
            # initialize EMA for the model parameters
            self.ema_prmtrs = ExponentialMovingAverage(self.model.parameters(), decay=self.hparams.ema_prmtrs_decay)

        # initialize exponential smoothing
        self.ema = None
        self._reset_ema_dict()
@@ -251,6 +257,11 @@ def optimizer_step(self, *args, **kwargs):
pg["lr"] = lr_scale * self.hparams.lr
super().optimizer_step(*args, **kwargs)
optimizer.zero_grad()

    def on_before_zero_grad(self, *args, **kwargs):
        # Lightning calls this hook after optimizer.step() and before optimizer.zero_grad(),
        # so the parameter EMA is updated once per optimization step
        if self.ema_prmtrs is not None:
            self.ema_prmtrs.to(self.device)  # keep the shadow parameters on the model's device
            self.ema_prmtrs.update(self.model.parameters())

    def _get_mean_loss_dict_for_type(self, type):
        # Returns a list with the mean loss for each loss_fn for each stage (train, val, test)
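Note that the hook above only accumulates the average; nothing in this diff applies the averaged weights yet. One possible follow-up, sketched here as an assumption rather than part of the change (the hooks are standard Lightning callbacks, and store/copy_to/restore come from torch-ema), would be to run validation with the averaged parameters:

    # hypothetical extension, not in this diff: validate with the EMA weights
    def on_validation_epoch_start(self):
        if self.ema_prmtrs is not None:
            self.ema_prmtrs.store(self.model.parameters())    # keep the raw weights
            self.ema_prmtrs.copy_to(self.model.parameters())  # load the averaged weights

    def on_validation_epoch_end(self):
        if self.ema_prmtrs is not None:
            self.ema_prmtrs.restore(self.model.parameters())  # put the raw weights back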
1 change: 1 addition & 0 deletions torchmdnet/scripts/train.py
@@ -59,6 +59,7 @@ def get_argparse():
    parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
    parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
    parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')
    parser.add_argument('--ema-prmtrs-decay', type=float, default=None, help='Exponential moving average decay for model parameters (None to disable)')
    # dataset specific
    parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
    parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
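For a given decay, e.g. --ema-prmtrs-decay 0.999, the shadow copy kept by torch-ema moves a fraction 1 - decay toward the current weights on every update. A plain-tensor sketch of that basic rule follows (it ignores torch-ema's optional warm-up of the decay via its num_updates counter, and all values are made up for illustration):

import torch

decay = 0.999                                   # e.g. --ema-prmtrs-decay 0.999
param = torch.randn(4)                          # current model parameter
shadow = param.clone()                          # EMA ("shadow") copy

with torch.no_grad():
    param += 0.1                                # pretend an optimizer step changed it
    shadow.mul_(decay).add_(param, alpha=1 - decay)   # shadow = decay*shadow + (1-decay)*param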