Added option to update parameters using state_dict in AveragedModel #65495
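This PR adds a `use_state_dict` flag to `AveragedModel` so that `update_parameters` averages the entries of `state_dict()` (parameters and buffers) rather than `parameters()` alone. A minimal usage sketch, assuming this PR's branch of `torch.optim.swa_utils` (the model and decay value below are illustrative, not taken from the diff):

import torch
from torch.optim.swa_utils import AveragedModel

# Illustrative model: BatchNorm1d carries buffers (running_mean, running_var,
# num_batches_tracked) that parameters() does not yield but state_dict() does.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 10),
    torch.nn.BatchNorm1d(10),
)

# EMA with decay 0.9; use_state_dict=True makes update_parameters walk
# state_dict().values(), so buffers are averaged alongside the weights.
ema_model = AveragedModel(
    model,
    avg_fn=lambda avg, p, n: 0.9 * avg + 0.1 * p,
    use_state_dict=True,
)

for _ in range(10):
    # ... optimizer step on `model` would go here ...
    ema_model.update_parameters(model)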

28 changes: 28 additions & 0 deletions test/test_optim.py
@@ -2290,6 +2290,34 @@ def avg_fn(p_avg, p, n_avg):
         for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
             self.assertEqual(p_avg, p_swa)

+    def test_averaged_model_exponential_use_state_dict(self):
+        # Test AveragedModel with EMA as avg_fn and use_state_dict as True.
+        dnn = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 5, kernel_size=3),
+            torch.nn.Linear(5, 10)
+        )
+        alpha = 0.9
+
+        def avg_fn(p_avg, p, n_avg):
+            return alpha * p_avg + (1 - alpha) * p
+        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_state_dict=True)
+        averaged_params = [torch.zeros_like(param) for param in dnn.state_dict().values()]
+        n_updates = 10
+        for i in range(n_updates):
+            updated_averaged_params = []
+            for p, p_avg in zip(dnn.state_dict().values(), averaged_params):
+                p.detach().add_(torch.randn_like(p))
+                if i == 0:
+                    updated_averaged_params.append(p.clone())
+                else:
+                    updated_averaged_params.append((p_avg * alpha +
+                                                    p * (1 - alpha)).clone())
+            averaged_dnn.update_parameters(dnn)
+            averaged_params = updated_averaged_params
+
+        for p_avg, p_swa in zip(averaged_params, averaged_dnn.module.state_dict().values()):
+            self.assertEqual(p_avg, p_swa)

     def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
 
         preactivation_sum = torch.zeros(dnn.n_features)
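The swa_utils.py change below is what the test exercises: when the flag is set, iteration switches from `parameters()` to `state_dict().values()`. For a model that only has parameters (like the Conv/Linear stack in the test) the two yield the same tensors, but `state_dict()` additionally includes buffers. A quick standalone check of the difference (illustrative, independent of this PR):

import torch

m = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))

# parameters() yields only learnable tensors: Linear weight/bias plus
# BatchNorm weight/bias -> 4 tensors.
print(len(list(m.parameters())))  # 4

# state_dict() also includes the BatchNorm buffers running_mean,
# running_var and num_batches_tracked -> 7 entries.
print(len(m.state_dict()))  # 7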
7 changes: 5 additions & 2 deletions torch/optim/swa_utils.py
@@ -84,7 +84,7 @@ class AveragedModel(Module):
         Generalizes Well:
         https://arxiv.org/abs/2001.02312
     """
-    def __init__(self, model, device=None, avg_fn=None):
+    def __init__(self, model, device=None, avg_fn=None, use_state_dict=False):
         super(AveragedModel, self).__init__()
         self.module = deepcopy(model)
         if device is not None:
@@ -96,12 +96,15 @@ def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                 return averaged_model_parameter + \
                     (model_parameter - averaged_model_parameter) / (num_averaged + 1)
         self.avg_fn = avg_fn
+        self.use_state_dict = use_state_dict

     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)

     def update_parameters(self, model):
-        for p_swa, p_model in zip(self.parameters(), model.parameters()):
+        self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
+        model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
+        for p_swa, p_model in zip(self_param, model_param):
             device = p_swa.device
             p_model_ = p_model.detach().to(device)
             if self.n_averaged == 0:
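One detail worth noting: on the first call, `n_averaged == 0` and the averaged tensors are copied verbatim; `avg_fn` only kicks in from the second call onward, which is why the test special-cases `i == 0` with `p.clone()`. A compact illustration of that control flow, again assuming this PR's branch:

import torch
from torch.optim.swa_utils import AveragedModel

net = torch.nn.Linear(2, 2)
ema = AveragedModel(net, avg_fn=lambda avg, p, n: 0.9 * avg + 0.1 * p,
                    use_state_dict=True)

ema.update_parameters(net)  # n_averaged == 0: tensors copied directly
ema.update_parameters(net)  # n_averaged == 1: avg_fn blends average and current value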