diff --git a/ring_flash_attn/utils.py b/ring_flash_attn/utils.py index e305ae8..d150295 100644 --- a/ring_flash_attn/utils.py +++ b/ring_flash_attn/utils.py @@ -88,6 +88,10 @@ def send_recv( send_rank = (self.rank + 1) % self.world_size recv_rank = (self.rank - 1) % self.world_size + if self._process_group is not None: + send_rank = dist.get_global_rank(self._process_group, send_rank) + recv_rank = dist.get_global_rank(self._process_group, recv_rank) + send_op = dist.P2POp(dist.isend, to_send, send_rank, group=self._process_group) recv_op = dist.P2POp(dist.irecv, res, recv_rank, group=self._process_group) self._ops.append(send_op)