Skip to content

Commit

Permalink
add multi swa support for custom device (#103297)
Browse files Browse the repository at this point in the history
Fixes #ISSUE_NUMBER
add multi swa support for custom device
Pull Request resolved: #103297
Approved by: https://github.com/janeyx99
  • Loading branch information
heidongxianhua authored and pytorchmergebot committed Jun 10, 2023
1 parent daf75c0 commit 900226f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torch/optim/swa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices

__all__ = [
'AveragedModel',
Expand Down Expand Up @@ -184,8 +185,7 @@ def update_parameters(self, model):
self_param_detached = []
model_param_detached = []
for p_averaged, p_model in zip(self_param, model_param):
device = p_averaged.device
p_model_ = p_model.detach().to(device)
p_model_ = p_model.detach().to(p_averaged.device)
self_param_detached.append(p_averaged.detach())
model_param_detached.append(p_model_)
if self.n_averaged == 0:
Expand All @@ -197,7 +197,7 @@ def update_parameters(self, model):
for ((device, _), ([self_params, model_params], _)) in grouped_tensors.items():
if self.multi_avg_fn:

This comment has been minimized.

Copy link
@vadimkantorov

vadimkantorov Jun 10, 2023

Contributor

Should this be written more explicitly as `if self.multi_avg_fn is not None:` (to avoid relying on the truthiness of a function object)?

self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
elif device.type == 'cuda':
elif device.type in _get_foreach_kernels_supported_devices():
multi_avg_fn = get_swa_multi_avg_fn()
multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
else:
Expand Down

0 comments on commit 900226f

Please sign in to comment.