Commit

bugfix
zhuzilin committed Feb 29, 2024
1 parent 129ef58 commit 7895974
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions ring_flash_attn/utils.py
Expand Up @@ -88,6 +88,10 @@ def send_recv(
send_rank = (self.rank + 1) % self.world_size
recv_rank = (self.rank - 1) % self.world_size

if self._process_group is not None:
    send_rank = dist.get_global_rank(self._process_group, send_rank)
    recv_rank = dist.get_global_rank(self._process_group, recv_rank)

fmmoret May 1, 2024

What bug was this fixing?
Wouldn't it already be correct since send_op = dist.P2POp(dist.isend, to_send, send_rank, group=self._process_group) specifies group=self._process_group?

The way I'm reading it, this seems to create a bug.

fmmoret May 1, 2024

Okay -- I can totally see how this bug happened in the first place. The torch API is very bad here:
https://pytorch.org/docs/stable/distributed.html#torch.distributed.isend

"Destination rank on global process group (regardless of group argument)"


No need to fix anything -- your implementation looks correct. Leaving my documentation here for other people to learn from ^
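
To make the point above concrete, here is a minimal pure-Python sketch of the rank translation, not the repository's code. The helper `ring_neighbors` and the example group `[4, 5, 6, 7]` are hypothetical; indexing into the list stands in for what `dist.get_global_rank(group, group_rank)` does for a real process group.

```python
def ring_neighbors(group_ranks, local_rank):
    """Return (send_rank, recv_rank) as GLOBAL ranks for a ring over group_ranks.

    group_ranks: global ranks that make up the (hypothetical) process group,
                 listed in group-local order.
    local_rank:  this process's rank inside the group.
    """
    world_size = len(group_ranks)
    # Ring neighbours computed in group-local rank space, as in send_recv.
    send_local = (local_rank + 1) % world_size
    recv_local = (local_rank - 1) % world_size
    # Stand-in for dist.get_global_rank(group, local): translate local -> global.
    return group_ranks[send_local], group_ranks[recv_local]


# Assumed example: a sub-group made of global ranks 4..7.
group_ranks = [4, 5, 6, 7]
print(ring_neighbors(group_ranks, 3))  # (4, 6)
print(ring_neighbors(group_ranks, 0))  # (5, 7)
```

Without the translation, the process with local rank 3 would address rank 0, which is not a member of this group; with it, the send goes to global rank 4.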


send_op = dist.P2POp(dist.isend, to_send, send_rank, group=self._process_group)
recv_op = dist.P2POp(dist.irecv, res, recv_rank, group=self._process_group)
self._ops.append(send_op)