Add multi-EMA support for custom devices
heidongxianhua committed Jun 10, 2023
1 parent 1eb762c commit 68ec831
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torch/optim/swa_utils.py
@@ -6,6 +6,7 @@
 import torch
 from torch.nn import Module
 from torch.optim.lr_scheduler import LRScheduler
+from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
 
 __all__ = [
     'AveragedModel',
@@ -184,8 +185,7 @@ def update_parameters(self, model):
         self_param_detached = []
         model_param_detached = []
         for p_averaged, p_model in zip(self_param, model_param):
-            device = p_averaged.device
-            p_model_ = p_model.detach().to(device)
+            p_model_ = p_model.detach().to(p_averaged.device)
             self_param_detached.append(p_averaged.detach())
             model_param_detached.append(p_model_)
             if self.n_averaged == 0:
@@ -197,7 +197,7 @@ def update_parameters(self, model):
                 for ((device, _), ([self_params, model_params], _)) in grouped_tensors.items():
                     if self.multi_avg_fn:
                         self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
-                    elif device.type == 'cuda':
+                    elif device.type in _get_foreach_kernels_supported_devices():
                         multi_avg_fn = get_swa_multi_avg_fn()
                         multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
                     else:
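For context, a minimal usage sketch of the path this commit changes (the Linear model and the device choice are illustrative assumptions, not part of the commit). With no avg_fn or multi_avg_fn supplied, AveragedModel.update_parameters() previously took the fused get_swa_multi_avg_fn() branch only for parameters on CUDA; after this change it does so for any device type reported by _get_foreach_kernels_supported_devices() (e.g. CUDA, XPU, or a registered custom backend), falling back to the per-tensor avg_fn loop elsewhere.

    import torch
    from torch.optim.swa_utils import AveragedModel

    model = torch.nn.Linear(10, 10)   # stand-in model (assumption)
    swa_model = AveragedModel(model)  # default averaging: no avg_fn / multi_avg_fn

    # update_parameters() groups parameters by device; with this commit the
    # fused multi-tensor branch runs whenever device.type is in
    # _get_foreach_kernels_supported_devices(), not only when it is 'cuda'.
    swa_model.update_parameters(model)

Note that a user-supplied multi_avg_fn (e.g. get_ema_multi_avg_fn() for EMA weights) was already invoked for every device group; the generalized check matters for the default fused averaging path.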
