diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py
index df997aabe9..d08dcdd111 100644
--- a/torchtitan/experiments/compiler_toolkit/common_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/common_utils.py
@@ -6,6 +6,10 @@
 
 from contextlib import contextmanager
 
+import torch
+from torch.distributed.tensor import DTensor, Replicate
+from torch.utils._pytree import tree_map
+
 from torchtitan.config import JobConfig
 
 
@@ -18,3 +22,21 @@ def disable_compile(job_config: JobConfig):
         yield
     finally:
         job_config.compile.enable = original_value
+
+
+def parallelize_inputs(world_mesh, args, kwargs):
+    def to_dtensor(tensor):
+        if isinstance(tensor, torch.Tensor):
+            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
+        return tensor
+
+    dt_args = tree_map(to_dtensor, args)
+
+    # TODO: When using flex_attention, BlockMask would show up in kwargs,
+    # and it's unclear how to convert it to DTensor. If I use to_dtensor,
+    # it would fail with Dynamo Error: P2011360347
+    # dt_kwargs = tree_map(to_dtensor, kwargs)
+
+    dt_kwargs = kwargs
+
+    return dt_args, dt_kwargs
diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
index a859415c1c..0253267567 100644
--- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -11,14 +11,16 @@
 
 from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors
 from torch._guards import tracing
-from torch.distributed.tensor import DTensor, Replicate
+from torch.distributed.tensor import DTensor
 from torch.fx.traceback import annotate_fn
-from torch.utils._pytree import tree_map
 
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.expert_parallel import ExpertParallel
-from torchtitan.experiments.compiler_toolkit.common_utils import disable_compile
+from torchtitan.experiments.compiler_toolkit.common_utils import (
+    disable_compile,
+    parallelize_inputs,
+)
 
 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
@@ -75,18 +77,6 @@ def wrapper_fn(args, kwargs):
     return wrapper_fn
 
 
-def parallelize_inputs(world_mesh, args, kwargs):
-    def to_dtensor(tensor):
-        if isinstance(tensor, torch.Tensor):
-            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
-        return tensor
-
-    dt_args = tree_map(to_dtensor, args)
-    dt_kwargs = tree_map(to_dtensor, kwargs)
-
-    return dt_args, dt_kwargs
-
-
 def annotate_model() -> None:
     # annotate the MoE with dispatch, compute and combine
     ExpertParallel._token_dispatch = annotate_fn({"EP": "dispatch"})(
diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
index b8d00db39e..61633c3cc9 100644
--- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -9,13 +9,15 @@
 
 from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors
 from torch._guards import tracing
-from torch.distributed.tensor import DTensor, Replicate
+from torch.distributed.tensor import DTensor
 from torch.fx.passes.regional_inductor import regional_inductor
-from torch.utils._pytree import tree_map
 
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
-from torchtitan.experiments.compiler_toolkit.common_utils import disable_compile
+from torchtitan.experiments.compiler_toolkit.common_utils import (
+    disable_compile,
+    parallelize_inputs,
+)
 
 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
@@ -78,18 +80,6 @@ def wrapper_fn(args, kwargs):
     return wrapper_fn
 
 
-def parallelize_inputs(world_mesh, args, kwargs):
-    def to_dtensor(tensor):
-        if isinstance(tensor, torch.Tensor):
-            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
-        return tensor
-
-    dt_args = tree_map(to_dtensor, args)
-    dt_kwargs = tree_map(to_dtensor, kwargs)
-
-    return dt_args, dt_kwargs
-
-
 def annotate_model() -> None:
     from torch.fx.traceback import annotate_fn
     from torchtitan.models.attention import FlexAttentionWrapper
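
For context, a minimal sketch of how the relocated `parallelize_inputs` helper could be exercised at a call site. The mesh layout, tensor shapes, and vocabulary size below are illustrative assumptions and are not part of this diff; only a process group with a "tp" mesh dimension is required by the helper itself.

```python
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.experiments.compiler_toolkit.common_utils import parallelize_inputs

# Assumes the process group was initialized by the launcher (e.g. torchrun)
# and that the mesh exposes a "tp" dimension, which parallelize_inputs expects.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
world_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

# Illustrative token batch; vocab size and shape are arbitrary.
tokens = torch.randint(0, 32000, (8, 2048), device="cuda")

dt_args, dt_kwargs = parallelize_inputs(world_mesh, (tokens,), {})
# dt_args[0] is now a DTensor replicated over the "tp" mesh dimension;
# kwargs are returned unchanged (see the BlockMask TODO in the diff above).
```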