Added validation of mode parameter in AveragedModel (pytorch#65921)
Summary:
Discussion: pytorch#65495 (comment)

Pull Request resolved: pytorch#65921

Reviewed By: albanD

Differential Revision: D31310105

Pulled By: prabhat00155

fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
prabhat00155 authored and facebook-github-bot committed Oct 3, 2021
1 parent 0fc6bd2 commit c7748fc
Showing 1 changed file with 5 additions and 2 deletions: torch/optim/swa_utils.py
@@ -26,8 +26,8 @@ class AveragedModel(Module):
             :class:`AveragedModel` parameter, the current value of :attr:`model`
             parameter and the number of models already averaged; if None,
             equally weighted average is used (default: None)
-        mode (str, optional): whether to use parameters or state_dict for update
-            (default: parameters)
+        mode (str, optional): whether to use ``'parameters'`` or ``'state_dict'`` for update
+            (default: ``'parameters'``)

     Example:
         >>> loader, optimizer, model, loss_fn = ...
@@ -98,6 +98,9 @@ def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                 return averaged_model_parameter + \
                     (model_parameter - averaged_model_parameter) / (num_averaged + 1)
             self.avg_fn = avg_fn
+        modes = ['parameters', 'state_dict']
+        if mode not in modes:
+            raise ValueError(f'Invalid mode passed, valid values are {", ".join(modes)}.')
         self.use_state_dict = mode == 'state_dict'

     def forward(self, *args, **kwargs):
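The validation added by this commit can be illustrated standalone. Below is a minimal sketch (no torch required; `validate_mode` is a hypothetical helper name, not part of the patch) of the membership check and the derived flag that the patch stores as `self.use_state_dict`:

```python
def validate_mode(mode='parameters'):
    # Accepted values, mirroring the list in the patch.
    modes = ['parameters', 'state_dict']
    if mode not in modes:
        raise ValueError(f'Invalid mode passed, valid values are {", ".join(modes)}.')
    # The patch stores this boolean on the module as self.use_state_dict.
    return mode == 'state_dict'

print(validate_mode('state_dict'))  # prints: True
```

Validating eagerly in `__init__` means a typo such as `mode='parameter'` fails immediately with a clear message, rather than silently falling through to the parameter-based update path.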
