diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py
index b79a50c636b25..47bfe041cdc22 100644
--- a/torch/distributed/fsdp/sharded_grad_scaler.py
+++ b/torch/distributed/fsdp/sharded_grad_scaler.py
@@ -15,7 +15,12 @@ def _refresh_per_optimizer_state() -> Dict[str, Any]:


 def _is_supported_device(tensor: torch.Tensor) -> bool:
-    return tensor.is_cuda or tensor.device.type in ("xla", "cpu", "hpu")
+    return tensor.is_cuda or tensor.device.type in (
+        "xla",
+        "cpu",
+        "hpu",
+        torch._C._get_privateuse1_backend_name(),
+    )


 class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
@@ -284,16 +289,16 @@ def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
         optimizer_state = self._per_optimizer_states[id(optimizer)]
         works = []
         found_inf_on_cpus = []
-        found_inf_on_cudas = []
+        found_inf_on_devices = []

         for found_inf in optimizer_state["found_inf_per_device"].values():
-            if self._device == "cuda" and found_inf.device.type == "cpu":
+            if self._device != "cpu" and found_inf.device.type == "cpu":
                 found_inf_on_cpus.append(found_inf)
-                found_inf_on_cuda = found_inf.cuda()
-                found_inf_on_cudas.append(found_inf_on_cuda)
+                found_inf_on_device = found_inf.to(self._device)
+                found_inf_on_devices.append(found_inf_on_device)
                 works.append(
                     dist.all_reduce(
-                        found_inf_on_cuda, async_op=True, group=self.process_group
+                        found_inf_on_device, async_op=True, group=self.process_group
                     )
                 )
             else:
@@ -303,7 +308,7 @@ def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
         for work in works:
             work.wait()
         if found_inf_on_cpus:
-            torch._foreach_copy_(found_inf_on_cpus, found_inf_on_cudas)
+            torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices)

     def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
         """
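
For context, a minimal standalone sketch of the generalized device check introduced above. It assumes stock PyTorch with no custom privateuse1 backend registered, in which case torch._C._get_privateuse1_backend_name() reports the default name "privateuseone" (a backend can change that name via torch.utils.rename_privateuse1_backend). The is_supported_device helper here is illustrative only, not part of the patch.

# Standalone sketch (not part of the patch) mirroring the patched
# _is_supported_device. Assumes stock PyTorch; no privateuse1 backend
# is registered, so the reported name is the default "privateuseone".
import torch


def is_supported_device(tensor: torch.Tensor) -> bool:
    # CUDA tensors pass, as do tensors on xla/cpu/hpu or on whatever
    # device type the privateuse1 hook currently reports.
    return tensor.is_cuda or tensor.device.type in (
        "xla",
        "cpu",
        "hpu",
        torch._C._get_privateuse1_backend_name(),
    )


print(torch._C._get_privateuse1_backend_name())  # "privateuseone" unless renamed
print(is_supported_device(torch.zeros(1)))       # True: CPU tensors remain supported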