diff --git a/torchtitan/components/float8.py b/torchtitan/components/float8.py
index b01c5063bc..a853256a55 100644
--- a/torchtitan/components/float8.py
+++ b/torchtitan/components/float8.py
@@ -13,7 +13,6 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
-import torch
 import torch.nn as nn
 
 from torchtitan.config_manager import JobConfig
@@ -23,11 +22,7 @@
     register_model_converter,
 )
 from torchtitan.tools.logging import logger
-
-
-def _is_sm89_or_later():
-    # Float8 is only supported on SM89 or later (H100+ GPUs)
-    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+from torchtitan.tools.utils import has_cuda_capability
 
 
 class Float8Converter(ModelConverter):
@@ -35,7 +30,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False
 
         float8_config = job_config.float8
-        if not _is_sm89_or_later():
+        if not has_cuda_capability(8, 9):
             logger.warning(
                 "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
             )
@@ -73,7 +68,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             )
 
         else:
-            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+            # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
             enable_fsdp_float8_all_gather = (
                 parallel_dims.dp_shard_enabled
                 and float8_config.enable_fsdp_float8_all_gather
diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py
index 7253b387cb..d87215db59 100644
--- a/torchtitan/experiments/llama4/model/args.py
+++ b/torchtitan/experiments/llama4/model/args.py
@@ -14,6 +14,7 @@
 
 from torchtitan.protocols.train_spec import BaseModelArgs
 from torchtitan.tools.logging import logger
+from torchtitan.tools.utils import has_cuda_capability
 
 
 @dataclass
@@ -54,6 +55,11 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> Non
         self.vocab_size = tokenizer.n_words
         self.max_seq_len = job_config.training.seq_len
         self.use_flex_attn = job_config.model.use_flex_attn
+        if self.use_grouped_mm and not has_cuda_capability(9, 0):
+            logger.warning(
+                "Failed to use grouped mm, which is only supported on SM90 or later",
+            )
+            self.use_grouped_mm = False
 
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py
index 5c55422b8f..4e1758d973 100644
--- a/torchtitan/experiments/llama4/model/moe.py
+++ b/torchtitan/experiments/llama4/model/moe.py
@@ -79,6 +79,9 @@ def forward(
             # fall back to regular bmm between 3D tensors
             assert x.dim() == 3
 
+        assert (
+            x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16
+        ), "torch._grouped_mm only supports bf16 dtypes"
         h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
         h = h * torch._grouped_mm(x, self.w3, offs=offsets)
         out = torch._grouped_mm(h, self.w2, offs=offsets)
@@ -246,14 +249,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             ALIGN_SIZE_M = 16
 
             with torch.no_grad():
-                permuted_indices, m_sizes = generate_permute_indices(
+                (
+                    permuted_indices,
+                    num_local_tokens_per_expert,
+                    _,
+                ) = generate_permute_indices(
                     num_local_tokens_per_expert,
                     self.experts.num_experts,
                     1,
                     token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M,
                     ALIGN_SIZE_M,
                 )
-            num_local_tokens_per_expert = m_sizes
             token_indices = torch.vstack(
                 (token_indices, token_indices.new_zeros((dim)))
             )
diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py
index bc6a570d1d..dc82cf2657 100644
--- a/torchtitan/tools/utils.py
+++ b/torchtitan/tools/utils.py
@@ -16,6 +16,13 @@
 from torchtitan.tools.logging import logger
 
 
+def has_cuda_capability(major: int, minor: int) -> bool:
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
+        major,
+        minor,
+    )
+
+
 def get_device_info():
     device_type = _get_available_device_type()
     if device_type is None:
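
Note (not part of the patch): a minimal sketch of how the new has_cuda_capability helper is meant to be used as a capability gate, mirroring the SM89 check for float8 and the SM90 check for grouped mm in the hunks above. The standalone script and its printed messages are illustrative assumptions, not torchtitan code.

    # Illustrative only -- assumes torchtitan is importable; the messages are made up.
    from torchtitan.tools.utils import has_cuda_capability

    # Float8Converter gates on SM89 or later; llama4's grouped mm path gates on SM90 or later.
    if not has_cuda_capability(8, 9):
        print("float8 disabled: requires SM89 or later")
    elif not has_cuda_capability(9, 0):
        print("float8 available, but grouped mm requires SM90 or later")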