diff --git a/test/test_optim.py b/test/test_optim.py
index 2d88d6f4bdabc..4db1a4997bcab 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -2290,6 +2290,38 @@ def avg_fn(p_avg, p, n_avg):
         for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
             self.assertEqual(p_avg, p_swa)
 
+    def test_averaged_model_exponential_use_state_dict(self):
+        # Test AveragedModel with EMA as avg_fn and use_state_dict as True.
+        dnn = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 5, kernel_size=3),
+            torch.nn.BatchNorm2d(5, momentum=0.3),
+            torch.nn.Linear(5, 10)
+        )
+        alpha = 0.9
+
+        def avg_fn(p_avg, p, n_avg):
+            return alpha * p_avg + (1 - alpha) * p
+        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, mode='state_dict')
+        averaged_params = [torch.zeros_like(param) for param in dnn.state_dict().values()
+                           if param.size() != torch.Size([])]
+        n_updates = 10
+        for i in range(n_updates):
+            updated_averaged_params = []
+            for p, p_avg in zip(dnn.state_dict().values(), averaged_params):
+                if p.size() == torch.Size([]):
+                    continue
+                p.detach().add_(torch.randn_like(p))
+                if i == 0:
+                    updated_averaged_params.append(p.clone())
+                else:
+                    updated_averaged_params.append((p_avg * alpha +
+                                                    p * (1 - alpha)).clone())
+            averaged_dnn.update_parameters(dnn)
+            averaged_params = updated_averaged_params
+
+        for p_avg, p_swa in zip(averaged_params, averaged_dnn.module.state_dict().values()):
+            self.assertEqual(p_avg, p_swa)
+
     def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
         preactivation_sum = torch.zeros(dnn.n_features)
 
diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py
index a143ffd13df69..a186e4327979e 100644
--- a/torch/optim/swa_utils.py
+++ b/torch/optim/swa_utils.py
@@ -26,6 +26,8 @@ class AveragedModel(Module):
             :class:`AveragedModel` parameter, the current value of :attr:`model`
             parameter and the number of models already averaged; if None,
             equally weighted average is used (default: None)
+        mode (str, optional): whether to use parameters or state_dict for update
+            (default: parameters)
 
     Example:
         >>> loader, optimizer, model, loss_fn = ...
@@ -84,7 +86,7 @@ class AveragedModel(Module):
         Generalizes Well:
         https://arxiv.org/abs/2001.02312
     """
-    def __init__(self, model, device=None, avg_fn=None):
+    def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
         super(AveragedModel, self).__init__()
         self.module = deepcopy(model)
         if device is not None:
@@ -96,12 +98,15 @@ def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                 return averaged_model_parameter + \
                     (model_parameter - averaged_model_parameter) / (num_averaged + 1)
             self.avg_fn = avg_fn
+        self.use_state_dict = mode == 'state_dict'
 
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
 
     def update_parameters(self, model):
-        for p_swa, p_model in zip(self.parameters(), model.parameters()):
+        self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
+        model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
+        for p_swa, p_model in zip(self_param, model_param):
             device = p_swa.device
             p_model_ = p_model.detach().to(device)
             if self.n_averaged == 0:
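
For context, a minimal usage sketch of what this patch enables, assuming the diff above is applied: passing mode='state_dict' makes AveragedModel average every entry of the state_dict, so buffers such as BatchNorm running statistics follow the exponential moving average as well. The toy model, decay constant, and training loop below are illustrative only and are not part of the patch.

    import torch
    from torch.optim.swa_utils import AveragedModel

    # Toy model with a buffer-carrying layer (BatchNorm) to show that
    # mode='state_dict' averages buffers in addition to parameters.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3),
        torch.nn.BatchNorm2d(8),
    )
    decay = 0.999  # illustrative EMA decay, not part of the patch

    def ema_avg(averaged_param, current_param, num_averaged):
        # Exponential moving average: avg <- decay * avg + (1 - decay) * current
        return decay * averaged_param + (1 - decay) * current_param

    ema_model = AveragedModel(model, avg_fn=ema_avg, mode='state_dict')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        x = torch.randn(4, 3, 16, 16)
        loss = model(x).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Iterates over state_dict entries, so BatchNorm running_mean/var
        # track the EMA alongside the learnable weights.
        ema_model.update_parameters(model)

With the default mode='parameters', the averaged copy keeps whatever buffer values it had at construction time, so BatchNorm statistics would typically be recomputed separately (for example with torch.optim.swa_utils.update_bn).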