Introduce use_sync_collectives configuration (#2010)
Summary:
Pull Request resolved: #2010

Introduce a use_sync_collectives global config, similar to the existing gradient_division config.

Replace the module-local is_torchdynamo_compiling fallback with the import from torchrec.pt2.checks.

test_pt2_multiprocess needs this setting enabled for compilation.
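
For reference, a minimal sketch of how the new config can be toggled from user code, based only on the functions added in this diff (no distributed setup is shown; the asserts assume the code is not running under dynamo, where get_use_sync_collectives() always reports True):

```python
import torchrec.distributed.comm_ops as comm_ops

# Default: async collectives (the flag starts out False).
assert not comm_ops.get_use_sync_collectives()

# Flip the global flag, mirroring set_gradient_division.
comm_ops.set_use_sync_collectives(True)
assert comm_ops.get_use_sync_collectives()
comm_ops.set_use_sync_collectives(False)

# Or scope it with the new context manager, which restores the
# previous value on exit.
with comm_ops.torchrec_use_sync_collectives():
    assert comm_ops.get_use_sync_collectives()
assert not comm_ops.get_use_sync_collectives()
```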

Reviewed By: dstaay-fb, gnahzg

Differential Revision: D57442917

fbshipit-source-id: a66c2bd6ea36a5f4eec76ce272e2b9bf484b782f
Ivan Kobzarev authored and facebook-github-bot committed May 16, 2024
1 parent 85d6cfe commit 6e8d374
Showing 2 changed files with 34 additions and 17 deletions.
48 changes: 31 additions & 17 deletions torchrec/distributed/comm_ops.py
@@ -7,6 +7,7 @@

# pyre-strict

from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple, TypeVar

@@ -19,6 +20,7 @@
from torch.autograd.profiler import record_function
from torchrec.distributed.types import Awaitable, NoWait, QuantizedCommCodecs
from torchrec.distributed.utils import none_throws
from torchrec.pt2.checks import is_torchdynamo_compiling

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -34,18 +36,11 @@
pass


try:
from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
except Exception:

def is_torchdynamo_compiling() -> bool: # type: ignore[misc]
return False


W = TypeVar("W")

# TODO: T96382816, NE Parity Backward compatibility
GRADIENT_DIVISION: bool = True
USE_SYNC_COLLECTIVES: bool = False


def set_gradient_division(val: bool) -> None:
@@ -58,6 +53,25 @@ def get_gradient_division() -> bool:
return GRADIENT_DIVISION


def set_use_sync_collectives(val: bool) -> None:
global USE_SYNC_COLLECTIVES
USE_SYNC_COLLECTIVES = val


def get_use_sync_collectives() -> bool:
global USE_SYNC_COLLECTIVES
return USE_SYNC_COLLECTIVES or is_torchdynamo_compiling()


@contextmanager
# pyre-ignore
def torchrec_use_sync_collectives():
original_use_sync_collectives: bool = get_use_sync_collectives()
set_use_sync_collectives(True)
yield
set_use_sync_collectives(original_use_sync_collectives)


"""
Some commonly used notations for comm ops:
B - batch size
@@ -368,7 +382,7 @@ def alltoall_pooled(
codecs=codecs,
)

if is_torchdynamo_compiling():
if get_use_sync_collectives():
return NoWait(all2all_pooled_sync(group, a2ai, a2a_pooled_embs_tensor))

myreq = Request(group, device=a2a_pooled_embs_tensor.device)
@@ -458,7 +472,7 @@ def variable_batch_alltoall_pooled(
codecs=codecs,
)

if is_torchdynamo_compiling():
if get_use_sync_collectives():
return NoWait(
variable_batch_all2all_pooled_sync(group, a2ai, a2a_pooled_embs_tensor)
)
@@ -602,7 +616,7 @@ def alltoall_sequence(
)
# sequence of embeddings, bags are definitely non-uniform

if is_torchdynamo_compiling():
if get_use_sync_collectives():
return NoWait(all2all_sequence_sync(group, a2ai, a2a_sequence_embs_tensor))

myreq = Request(group, device=a2a_sequence_embs_tensor.device)
@@ -752,7 +766,7 @@ def alltoallv(
codecs=codecs,
)

if is_torchdynamo_compiling():
if get_use_sync_collectives():
return NoWait(all2allv_sync(group, a2ai, inputs))

myreq = Request(group, device=inputs[0].device)
@@ -829,7 +843,7 @@ def reduce_scatter_pooled(
input_sizes=[tensor.size() for tensor in inputs], codecs=codecs
)

if is_torchdynamo_compiling():
if get_use_sync_collectives():
return NoWait(reduce_scatter_sync(group, rsi, *inputs))

myreq = Request(group, device=inputs[0].device)
@@ -889,7 +903,7 @@ def reduce_scatter_base_pooled(

rsi = ReduceScatterBaseInfo(input_sizes=input.size(), codecs=codecs)

if is_torchdynamo_compiling():
if get_use_sync_collectives():
return NoWait(reduce_scatter_base_sync(group, rsi, input))

myreq = Request(group, device=input.device)
@@ -948,7 +962,7 @@ def all_gather_base_pooled(
if dist.get_world_size(group) <= 1:
return NoWait(input)

if is_torchdynamo_compiling():
if get_use_sync_collectives():
return NoWait(all_gather_base_sync(group, agi, input))

myreq = Request(group, device=input.device)
@@ -1024,7 +1038,7 @@ def reduce_scatter_v_pooled(
codecs=codecs,
)

if is_torchdynamo_compiling():
if get_use_sync_collectives():
return NoWait(reduce_scatter_v_sync(group, rsvi, input))

myreq = Request(group, device=input.device)
@@ -1127,7 +1141,7 @@ def reduce_scatter_v_per_feature_pooled(
codecs=codecs,
)

if is_torchdynamo_compiling():
if get_use_sync_collectives():
return NoWait(reduce_scatter_v_sync(group, rsvi, input))

myreq = Request(group, device=input.device)
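
Every collective entry point above now consults get_use_sync_collectives() instead of calling is_torchdynamo_compiling() directly, so the synchronous NoWait path can also be forced in eager mode. A simplified, self-contained sketch of that gating pattern follows; the toy "collective" is a placeholder, not one of the real wrappers:

```python
import torch
import torchrec.distributed.comm_ops as comm_ops
from torchrec.distributed.types import Awaitable, NoWait


def toy_collective(tensor: torch.Tensor) -> Awaitable[torch.Tensor]:
    # Placeholder standing in for wrappers like alltoall_pooled.
    if comm_ops.get_use_sync_collectives():
        # Sync path: do the work eagerly and wrap the finished result in
        # NoWait so callers still receive an Awaitable.
        return NoWait(tensor * 2)
    # The real async path builds a Request that launches the collective and
    # completes it on .wait(); this sketch just reuses the eager result.
    return NoWait(tensor * 2)


with comm_ops.torchrec_use_sync_collectives():
    out = toy_collective(torch.ones(4)).wait()  # tensor([2., 2., 2., 2.])
```
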
3 changes: 3 additions & 0 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -16,6 +16,7 @@
import fbgemm_gpu.sparse_ops # noqa: F401, E402

import torch
import torchrec
from hypothesis import given, settings, strategies as st, Verbosity
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -331,7 +332,9 @@ def _test_compile_rank_fn(
kjt = local_model_input.idlist_features
kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)

torchrec.distributed.comm_ops.set_use_sync_collectives(True)
dmp.train(True)

eager_out = dmp(kjt_ft)

if torch_compile_backend is None:
