Added option to update parameters using state_dict in AveragedModel #65495

32 changes: 32 additions & 0 deletions test/test_optim.py
@@ -2290,6 +2290,38 @@ def avg_fn(p_avg, p, n_avg):
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertEqual(p_avg, p_swa)

    def test_averaged_model_exponential_use_state_dict(self):
        # Test AveragedModel with EMA as avg_fn and mode='state_dict'.
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10)
        )
        alpha = 0.9

        def avg_fn(p_avg, p, n_avg):
            return alpha * p_avg + (1 - alpha) * p
        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, mode='state_dict')
        averaged_params = [torch.zeros_like(param) for param in dnn.state_dict().values()
                           if param.size() != torch.Size([])]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            # Filter out zero-dimensional entries (e.g. num_batches_tracked) so
            # the zip stays aligned with averaged_params, which skips them too.
            dnn_values = [p for p in dnn.state_dict().values()
                          if p.size() != torch.Size([])]
            for p, p_avg in zip(dnn_values, averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append((p_avg * alpha +
                                                    p * (1 - alpha)).clone())
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        swa_values = [p for p in averaged_dnn.module.state_dict().values()
                      if p.size() != torch.Size([])]
        for p_avg, p_swa in zip(averaged_params, swa_values):
            self.assertEqual(p_avg, p_swa)

def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):

preactivation_sum = torch.zeros(dnn.n_features)
9 changes: 7 additions & 2 deletions torch/optim/swa_utils.py
@@ -26,6 +26,8 @@ class AveragedModel(Module):
:class:`AveragedModel` parameter, the current value of :attr:`model`
parameter and the number of models already averaged; if None,
equally weighted average is used (default: None)
        mode (str, optional): whether to use parameters or state_dict for update
            (default: parameters)

Collaborator: This should document what the valid values for this argument are.

Contributor Author: parameters and state_dict are valid values here, as stated in line 29; anything passed other than state_dict defaults to parameters.

Collaborator: Oh yes, what I meant is that this is missing quotation marks to make it clear that these are strings. It's a nit, though.
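An illustrative rewording of that entry along the lines the reviewer is asking for (a sketch, not the PR's actual text):

        mode (str, optional): whether to average over the model's "parameters"
            or its full "state_dict"; any other value falls back to
            "parameters" (default: "parameters")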

Example:
>>> loader, optimizer, model, loss_fn = ...
@@ -84,7 +86,7 @@ class AveragedModel(Module):
Generalizes Well:
https://arxiv.org/abs/2001.02312
"""
def __init__(self, model, device=None, avg_fn=None):
def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
super(AveragedModel, self).__init__()
self.module = deepcopy(model)
if device is not None:
@@ -96,12 +98,15 @@ def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
return averaged_model_parameter + \
(model_parameter - averaged_model_parameter) / (num_averaged + 1)
self.avg_fn = avg_fn
self.use_state_dict = mode == 'state_dict'

Collaborator: That looks very brittle if the user passes another string as input.

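A hedged sketch of the stricter handling the reviewer is hinting at (illustrative only, not what this PR implements; the helper name and error message are invented):

    # Illustrative sketch: reject unrecognized mode strings up front instead of
    # silently treating anything other than 'state_dict' as 'parameters'.
    def _use_state_dict_from_mode(mode):
        valid = ('parameters', 'state_dict')
        if mode not in valid:
            raise ValueError("mode should be one of {}, got '{}'".format(valid, mode))
        return mode == 'state_dict'

    # In __init__ this would replace the bare comparison:
    #     self.use_state_dict = _use_state_dict_from_mode(mode)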

def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)

def update_parameters(self, model):
for p_swa, p_model in zip(self.parameters(), model.parameters()):
self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
for p_swa, p_model in zip(self_param, model_param):
device = p_swa.device
p_model_ = p_model.detach().to(device)
if self.n_averaged == 0:
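For context, a minimal usage sketch of the new option (not part of the PR; the model, alpha, and training-loop pieces are placeholders, and it assumes the mode='state_dict' keyword added by this change is available):

    import torch
    from torch.optim.swa_utils import AveragedModel

    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.BatchNorm1d(10),
    )
    alpha = 0.9

    def ema_avg(averaged_param, model_param, num_averaged):
        # Exponential moving average, same avg_fn signature as in swa_utils.
        return alpha * averaged_param + (1 - alpha) * model_param

    # mode='state_dict' averages every state_dict entry, so BatchNorm buffers
    # (running_mean / running_var) are tracked along with the parameters;
    # any other value falls back to averaging parameters only.
    ema_model = AveragedModel(model, avg_fn=ema_avg, mode='state_dict')

    for step in range(10):
        # ... forward pass, loss.backward() and optimizer.step() would go here ...
        ema_model.update_parameters(model)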