Add privateuse1 in FSDP's sharded grad scaler (pytorch#126971)
1. Add privateuse1 support in FSDP's sharded grad scaler.
2. Support the found_inf copy for more devices.

Pull Request resolved: pytorch#126971
Approved by: https://github.com/awgu, https://github.com/weifengpy
accelerate321 authored and titaiwangms committed May 28, 2024
1 parent 78a2b80 commit 415b28f
Showing 1 changed file with 12 additions and 7 deletions.
torch/distributed/fsdp/sharded_grad_scaler.py
@@ -15,7 +15,12 @@ def _refresh_per_optimizer_state() -> Dict[str, Any]:
 
 
 def _is_supported_device(tensor: torch.Tensor) -> bool:
-    return tensor.is_cuda or tensor.device.type in ("xla", "cpu", "hpu")
+    return tensor.is_cuda or tensor.device.type in (
+        "xla",
+        "cpu",
+        "hpu",
+        torch._C._get_privateuse1_backend_name(),
+    )
 
 
 class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
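
For context, a minimal sketch of how the privateuse1 name queried above gets set: an out-of-tree backend renames the reserved privateuse1 dispatch key via torch.utils.rename_privateuse1_backend, after which torch._C._get_privateuse1_backend_name() reports that custom name ("my_npu" below is a hypothetical placeholder, not a shipped backend):

    import torch

    # Rename the reserved privateuse1 key; out-of-tree backends do this
    # once at import time ("my_npu" is illustrative, not a real backend).
    torch.utils.rename_privateuse1_backend("my_npu")

    # The helper used in _is_supported_device now returns the custom name,
    # so tensors on that device type pass the check once the backend loads.
    print(torch._C._get_privateuse1_backend_name())  # my_npu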
@@ -284,16 +289,16 @@ def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
         optimizer_state = self._per_optimizer_states[id(optimizer)]
         works = []
         found_inf_on_cpus = []
-        found_inf_on_cudas = []
+        found_inf_on_devices = []
 
         for found_inf in optimizer_state["found_inf_per_device"].values():
-            if self._device == "cuda" and found_inf.device.type == "cpu":
+            if self._device != "cpu" and found_inf.device.type == "cpu":
                 found_inf_on_cpus.append(found_inf)
-                found_inf_on_cuda = found_inf.cuda()
-                found_inf_on_cudas.append(found_inf_on_cuda)
+                found_inf_on_device = found_inf.to(self._device)
+                found_inf_on_devices.append(found_inf_on_device)
                 works.append(
                     dist.all_reduce(
-                        found_inf_on_cuda, async_op=True, group=self.process_group
+                        found_inf_on_device, async_op=True, group=self.process_group
                     )
                 )
             else:
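
The key generalization in this hunk is replacing the hard-wired found_inf.cuda() with found_inf.to(self._device), which accepts any device string, including a renamed privateuse1 backend. A standalone sketch of the staged async reduction, assuming an initialized process group and using "cuda" purely as a stand-in for the scaler's device:

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group(...) has already run; "cuda" stands
    # in for whatever non-CPU device the scaler was constructed with.
    device = torch.device("cuda")
    found_inf_cpu = torch.zeros(1)            # per-optimizer inf/NaN flag on CPU
    found_inf_dev = found_inf_cpu.to(device)  # stage a copy on the scaler's device

    # Kick off the reduction without blocking, mirroring async_op=True above.
    work = dist.all_reduce(found_inf_dev, async_op=True)
    work.wait()  # synchronize before reading the reduced flag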
@@ -303,7 +308,7 @@ def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
         for work in works:
             work.wait()
         if found_inf_on_cpus:
-            torch._foreach_copy_(found_inf_on_cpus, found_inf_on_cudas)
+            torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices)
 
     def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
         """
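
torch._foreach_copy_ performs the in-place copies dst[i].copy_(src[i]) across the two lists in one batched call, which is how the device-side results above land back in the original CPU found_inf tensors. A small illustration (CPU-only for simplicity):

    import torch

    dst = [torch.zeros(2), torch.zeros(3)]
    src = [torch.ones(2), torch.full((3,), 2.0)]

    # In-place batched copy: dst[i].copy_(src[i]) for each pair, without
    # a Python-level loop over tensors.
    torch._foreach_copy_(dst, src)
    print(dst)  # [tensor([1., 1.]), tensor([2., 2., 2.])]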
