Commit 075ff23
adding complex support for distributed functions. fix #45760
ghstack-source-id: ff291d0f75451c20b1cdfd7d93738fb252a60fc9
Pull Request resolved: #45879
bdhirsh committed Oct 6, 2020
1 parent f65ab89 commit 075ff23
Showing 2 changed files with 148 additions and 51 deletions.
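
Each hunk below applies the same guard before handing tensors to the underlying process-group collective: a complex tensor is reinterpreted as a real tensor with torch.view_as_real, which exposes the real and imaginary parts as a trailing dimension of size 2 while sharing storage with the original. A minimal standalone sketch of that guard (shapes and values are illustrative, not taken from the diff):

import torch

t = torch.randn(2, 3, dtype=torch.cfloat)  # complex tensor of shape (2, 3)

# Real view of shape (2, 3, 2): the last dimension holds (real, imag)
# pairs. The view shares storage with `t`, so an in-place collective
# run on the view also updates the original complex tensor.
r = torch.view_as_real(t)
assert r.shape == (2, 3, 2) and r.dtype == torch.float32

# The guard as written in the hunks, for a single tensor and for a list:
t = t if not t.is_complex() else torch.view_as_real(t)
ts = [torch.randn(4, dtype=torch.cfloat), torch.randn(4)]
ts = [x if not x.is_complex() else torch.view_as_real(x) for x in ts]
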
15 changes: 15 additions & 0 deletions torch/distributed/distributed_c10d.py
@@ -892,6 +892,8 @@ def all_reduce_multigpu(tensor_list,
if _rank_not_in_group(group):
return

tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]

opts = AllreduceOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
@@ -934,6 +936,8 @@ def all_reduce(tensor,
if _rank_not_in_group(group):
return

tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

opts = AllreduceOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
@@ -985,6 +989,8 @@ def all_reduce_coalesced(tensors,
if _rank_not_in_group(group):
return

tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]

opts = AllreduceCoalescedOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
@@ -1149,6 +1155,9 @@ def all_gather_multigpu(output_tensor_lists,
if _rank_not_in_group(group):
return

output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
@@ -1414,6 +1423,9 @@ def all_gather(tensor_list,
if _rank_not_in_group(group):
return

tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather([tensor_list], [tensor])
@@ -1480,6 +1492,9 @@ def all_gather_coalesced(output_tensor_lists,
for output_tensor_list in output_tensor_lists:
_check_tensor_list(output_tensor_list, "output_tensor_lists")

output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather_coalesced(
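With these guards in place, the public collectives accept complex tensors directly; the default SUM reduction is sound under the real view because complex addition acts elementwise on the real and imaginary parts. An end-to-end sketch, assuming torch.distributed has already been initialized on each rank (the init call and values are placeholders, not part of the diff):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run on every rank.
t = torch.ones(2, 2, dtype=torch.cfloat)  # every entry is 1+0j
dist.all_reduce(t)  # in place; every entry becomes world_size + 0j

Note that the real view only matches complex semantics for reductions that act independently on the two components, such as SUM; a PRODUCT or MAX over the interleaved real view would not correspond to complex multiplication or to any ordering of complex numbers.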
