From 70e592073a37fb57c25ce49be82d54c02a477ec9 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Fri, 12 Sep 2025 11:24:55 -0700
Subject: [PATCH 1/2] [CP][RFC] Enable FlexCP for llama3 with parallelize_module

Similar to https://github.com/pytorch/torchtitan/pull/1696, but this PR
uses parallelize_module, similar to TP/SP.

This PR also requires https://github.com/pytorch/pytorch/pull/162542
---
 torchtitan/models/attention.py                | 34 +++++++++++++++----
 torchtitan/models/llama3/infra/parallelize.py | 15 ++++++--
 torchtitan/models/llama3/model/args.py        |  7 ----
 3 files changed, 41 insertions(+), 15 deletions(-)

diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index f66361a6d2..47bc86fc14 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -15,6 +15,7 @@
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
+    AuxOutput,
     BlockMask,
     create_block_mask,
     flex_attention,
@@ -28,6 +29,26 @@
 FLEX_ATTN_MASK_T = tuple[str, int | None]
 
 
+class FlexAttentionWrapper(torch.nn.Module):
+    _flex_attn: ClassVar[Callable] = torch.compile(
+        flex_attention, mode="max-autotune-no-cudagraphs"
+    )
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, *args: object, **kwargs: object) -> [
+        torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+        tuple[torch.Tensor, AuxOutput],
+    ]:
+        # 1. _flex_attn has to be a class variable, otherwise there will
+        # be multiple compiled flex_attention, which can be slow.
+        # 2. `self._flex_attn` is not correct, `self` will be passed in
+        # as the first argument, which will cause an error.
+        # `FlexAttentionWrapper._flex_attn` is correct.
+        return FlexAttentionWrapper._flex_attn(*args, **kwargs)
+
+
 class FlexAttention(torch.nn.Module):
     """FlexAttention module that uses torch.nn.attention.flex_attention.
 
@@ -46,11 +67,6 @@ class FlexAttention(torch.nn.Module):
         to the keys within the same block.
     """
 
-    # We registered flex_attention related attributes as class variables as we
-    # need to amortize the cost of compilation.
-    flex_attn: ClassVar[Callable] = torch.compile(
-        flex_attention, mode="max-autotune-no-cudagraphs"
-    )
     compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
     used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set()
     # Attention mask type to the created BlockMask.
@@ -71,6 +87,7 @@ def __init__(
             raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
         self.attn_mask_type = attn_mask_type
         self.fixed_block_size = fixed_block_size
+        self.attention_fn_wrapper = FlexAttentionWrapper()
 
         FlexAttention.used_attn_mask_types.add(self.mask_key)
 
@@ -86,7 +103,7 @@ def forward(
         scale: float | None = None,
     ) -> torch.Tensor:
         block_mask = FlexAttention.block_masks[self.mask_key]
-        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale)
+        return self.attention_fn_wrapper(q, k, v, block_mask=block_mask, scale=scale)
 
     @staticmethod
     def _get_causal_mask_mod() -> _mask_mod_signature:
@@ -251,6 +268,11 @@ def init_attention_mask(
     # while we continue debugging accuracy issues. However, we want to evaluate
     # the user experience with CP enabled.
     if cp_mesh is not None:
+        from torch.distributed.tensor.experimental._attention import _DispatchMode
+
+        torch.distributed.tensor.experimental._attention._dispatch_mode = (
+            _DispatchMode.MODULE_WRAPPER
+        )
         FlexAttention.compiled_create_block_mask = functools.partial(
             create_cp_block_mask, device_mesh=cp_mesh
         )
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 7d8aa76f0d..b24ab3e9ef 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -14,6 +14,8 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Replicate, Shard
+
+from torch.distributed.tensor.experimental._attention import _ContextParallel
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -67,8 +69,6 @@ def parallelize_llama(
     """
     use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
@@ -90,6 +90,17 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
 
+    if parallel_dims.cp_enabled:
+        for block in model.layers.values():
+            parallelize_module(
+                module=block.attention.sdpa.attention_fn_wrapper,
+                device_mesh=world_mesh["cp"],
+                parallelize_plan=_ContextParallel(
+                    seq_dim=2,
+                    attention_type=_ContextParallel.AttentionType.FLEX,
+                ),
+            )
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py
index e2f698f8b1..1723e03462 100644
--- a/torchtitan/models/llama3/model/args.py
+++ b/torchtitan/models/llama3/model/args.py
@@ -45,13 +45,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
         self.max_seq_len = seq_len
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
-        self.max_seq_len = seq_len
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())
         nparams_embedding = sum(

From a4b4ef1119fcbce9256dbfa619c1edc46666f39f Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Tue, 16 Sep 2025 16:18:05 -0700
Subject: [PATCH 2/2] Enable SDPA for all models

---
 torchtitan/components/metrics.py              |  5 ++++-
 .../experiments/llama4/infra/parallelize.py   |  6 +-----
 torchtitan/experiments/llama4/model/args.py   |  5 -----
 torchtitan/models/attention.py                | 16 +++++++++++++---
 .../models/deepseek_v3/infra/parallelize.py   |  6 +-----
 torchtitan/models/deepseek_v3/model/args.py   |  5 -----
 torchtitan/models/llama3/infra/parallelize.py |  4 +---
 7 files changed, 20 insertions(+), 27 deletions(-)

diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py
index 26c6f2ae2d..878b2b315d 100644
--- a/torchtitan/components/metrics.py
+++ b/torchtitan/components/metrics.py
@@ -163,10 +163,12 @@ def close(self) -> None:
         if self.wandb.run is not None:
             self.wandb.finish()
 
+
 class LoggerContainer(BaseLogger):
     """Container to call all loggers enabled in the job config."""
+
     def __init__(self) -> None:
-        self._loggers : list[BaseLogger] = []
+        self._loggers: list[BaseLogger] = []
 
     def add_logger(self, logger_instance: BaseLogger) -> None:
         self._loggers.append(logger_instance)
@@ -183,6 +185,7 @@ def close(self) -> None:
         for logger_instance in self._loggers:
             logger_instance.close()
 
+
 def ensure_pp_loss_visible(
     parallel_dims: ParallelDims, job_config: JobConfig, color: Color
 ) -> None:
diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index 7efc04b784..25dcb308f3 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -70,10 +70,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
@@ -117,7 +113,7 @@ def parallelize_llama(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            use_flex_attn=getattr(model.model_args, "use_flex_attn", False),
             save_list=_save_list,
         )
diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py
index 272936a153..370b95ef86 100644
--- a/torchtitan/experiments/llama4/model/args.py
+++ b/torchtitan/experiments/llama4/model/args.py
@@ -66,11 +66,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
             self.moe_args.use_grouped_mm = False
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index 47bc86fc14..e5e5172f76 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -37,7 +37,9 @@ class FlexAttentionWrapper(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
 
-    def forward(self, *args: object, **kwargs: object) -> [
+    def forward(
+        self, *args: object, **kwargs: object
+    ) -> [
         torch.Tensor | tuple[torch.Tensor, torch.Tensor],
         tuple[torch.Tensor, AuxOutput],
     ]:
@@ -49,6 +51,14 @@ def forward(self, *args: object, **kwargs: object) -> [
         return FlexAttentionWrapper._flex_attn(*args, **kwargs)
 
 
+class SDPAWrapper(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, *args: object, **kwargs: object) -> torch.Tensor:
+        return F.scaled_dot_product_attention(*args, **kwargs)
+
+
 class FlexAttention(torch.nn.Module):
     """FlexAttention module that uses torch.nn.attention.flex_attention.
 
@@ -214,6 +224,7 @@ def __init__(self, attn_mask_type: str) -> None:
             )
 
         ScaledDotProductAttention._init_backend()
+        self.attention_fn_wrapper = SDPAWrapper()
 
     @classmethod
     def _init_backend(cls) -> None:
@@ -238,7 +249,7 @@ def forward(
     ) -> torch.Tensor:
         assert self.backends, "SDPA Backends should not be empty."
         with sdpa_kernel(self.backends, set_priority=True):
-            return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
+            return self.attention_fn_wrapper(q, k, v, is_causal=True, scale=scale)
 
 
 def build_attention(
@@ -263,7 +274,6 @@ def init_attention_mask(
     eos_id: int | None,
     cp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
 ) -> None:
-    # This is not functional yet because we currently gate the use of Flex + CP
     # while we continue debugging accuracy issues. However, we want to evaluate
     # the user experience with CP enabled.
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 7182b1fca3..aa0944ca55 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -61,10 +61,6 @@ def parallelize_deepseekv3(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
@@ -111,7 +107,7 @@ def parallelize_deepseekv3(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
+            use_flex_attn=getattr(model.model_args, "use_flex_attn", False),
             save_list=_save_list,
         )
diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py
index d6afedfa34..bb3dc56bdc 100644
--- a/torchtitan/models/deepseek_v3/model/args.py
+++ b/torchtitan/models/deepseek_v3/model/args.py
@@ -100,11 +100,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
             self.moe_args.use_grouped_mm = False
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         """
         Adopted from llama4 implementation.
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index b24ab3e9ef..0b24e4162b 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -68,8 +68,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
@@ -110,7 +108,7 @@ def parallelize_llama(
         model,
         job_config.activation_checkpoint,
         model_compile_enabled=model_compile_enabled,
-        use_flex_attn=use_flex_attn,
+        use_flex_attn=getattr(model.model_args, "use_flex_attn", False),
         save_list=_save_list,
     )
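
Usage sketch (illustration only, not part of either patch): the snippet below pulls the scattered pieces of the two patches into one place to show how the CP wiring is meant to be driven. It assumes the experimental `_ContextParallel` / `_DispatchMode` APIs from https://github.com/pytorch/pytorch/pull/162542 behave as the diffs above use them; `apply_flex_cp` is a name made up for this sketch, and its `model` argument stands for a llama3-style Transformer whose attention modules carry the `attention_fn_wrapper` submodule introduced above.

import torch
from torch.distributed.tensor.experimental._attention import (
    _ContextParallel,
    _DispatchMode,
)
from torch.distributed.tensor.parallel import parallelize_module


def apply_flex_cp(model: torch.nn.Module, cp_mesh) -> None:
    """Attach a Context Parallel plan to every attention wrapper in `model`.

    Hedged sketch: `model` is assumed to be a llama3-style Transformer whose
    attention modules expose the `attention_fn_wrapper` submodule added by the
    patches above; `cp_mesh` is the 1-D "cp" sub-mesh (e.g. world_mesh["cp"]).
    """
    # Route CP through the module-wrapper dispatch path, so the small wrapper
    # modules added by these patches (FlexAttentionWrapper / SDPAWrapper) are
    # what the CP implementation intercepts.
    torch.distributed.tensor.experimental._attention._dispatch_mode = (
        _DispatchMode.MODULE_WRAPPER
    )

    for block in model.layers.values():
        parallelize_module(
            module=block.attention.sdpa.attention_fn_wrapper,
            device_mesh=cp_mesh,
            parallelize_plan=_ContextParallel(
                # q/k/v are laid out as (batch, heads, seq, head_dim), so the
                # sequence dimension being sharded across CP ranks is dim 2.
                seq_dim=2,
                attention_type=_ContextParallel.AttentionType.FLEX,
            ),
        )

Wrapping the raw flex_attention / scaled_dot_product_attention call in a tiny nn.Module appears to be what makes this approach work: parallelize_module needs a module boundary to attach the CP plan to, which is why the patches introduce FlexAttentionWrapper and SDPAWrapper instead of calling the functional APIs directly.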