Added option to update parameters using state_dict in AveragedModel (pytorch#65495)

Summary:
While implementing [EMA](pytorch/vision#4381), which extends AveragedModel, in torchvision, update_parameters() from AveragedModel could not be used because it does not handle state_dict(), so a custom update_parameters() had to be defined in the [EMA class](pytorch/vision#4406). This PR handles that scenario, removing the need for the custom update_parameters() implementation.

Discussion: pytorch/vision#4406 (review)
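
A minimal usage sketch of the new option (not part of this commit; the toy model and decay value below are illustrative only):

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.BatchNorm1d(10))
alpha = 0.999  # illustrative EMA decay, not prescribed by this PR

# With mode='state_dict', update_parameters() walks model.state_dict().values(),
# so buffers (e.g. BatchNorm running statistics) are updated along with parameters.
ema_model = AveragedModel(
    model,
    avg_fn=lambda p_avg, p, n_avg: alpha * p_avg + (1 - alpha) * p,
    mode='state_dict',
)

# Called after each optimizer step:
ema_model.update_parameters(model)
```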

Pull Request resolved: pytorch#65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
prabhat00155 authored and facebook-github-bot committed Sep 28, 2021
1 parent 3324bae commit 2ea724b
Showing 2 changed files with 39 additions and 2 deletions.
32 changes: 32 additions & 0 deletions test/test_optim.py
@@ -2290,6 +2290,38 @@ def avg_fn(p_avg, p, n_avg):
        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)

    def test_averaged_model_exponential_use_state_dict(self):
        # Test AveragedModel with EMA as avg_fn and use_state_dict as True.
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10)
        )
        alpha = 0.9

        def avg_fn(p_avg, p, n_avg):
            return alpha * p_avg + (1 - alpha) * p
        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, mode='state_dict')
        averaged_params = [torch.zeros_like(param) for param in dnn.state_dict().values()
                           if param.size() != torch.Size([])]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn.state_dict().values(), averaged_params):
                if p.size() == torch.Size([]):
                    continue
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append((p_avg * alpha +
                                                    p * (1 - alpha)).clone())
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.module.state_dict().values()):
            self.assertEqual(p_avg, p_swa)

    def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):

        preactivation_sum = torch.zeros(dnn.n_features)
9 changes: 7 additions & 2 deletions torch/optim/swa_utils.py
@@ -26,6 +26,8 @@ class AveragedModel(Module):
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter and the number of models already averaged; if None,
            equally weighted average is used (default: None)
        mode (str, optional): whether to use parameters or state_dict for update
            (default: parameters)

    Example:
        >>> loader, optimizer, model, loss_fn = ...
@@ -84,7 +86,7 @@ class AveragedModel(Module):
        Generalizes Well:
        https://arxiv.org/abs/2001.02312
    """
    def __init__(self, model, device=None, avg_fn=None):
    def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
@@ -96,12 +98,15 @@ def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn
        self.use_state_dict = mode == 'state_dict'

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update_parameters(self, model):
        for p_swa, p_model in zip(self.parameters(), model.parameters()):
        self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
        model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
        for p_swa, p_model in zip(self_param, model_param):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
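
Aside (not part of the commit): the practical difference between the two modes is that state_dict() also yields buffers, which parameters() does not, so updating over the state_dict keeps running statistics in sync as well. A quick illustration:

```python
import torch

bn = torch.nn.BatchNorm2d(5)

# parameters(): learnable tensors only
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']

# state_dict(): parameters plus buffers
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
```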
