11 changes: 3 additions & 8 deletions torchtitan/components/float8.py
@@ -13,7 +13,6 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
-import torch
 import torch.nn as nn
 
 from torchtitan.config_manager import JobConfig
@@ -23,19 +22,15 @@
     register_model_converter,
 )
 from torchtitan.tools.logging import logger
-
-
-def _is_sm89_or_later():
-    # Float8 is only supported on SM89 or later (H100+ GPUs)
-    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+from torchtitan.tools.utils import has_cuda_capability
 
 
 class Float8Converter(ModelConverter):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False
 
         float8_config = job_config.float8
-        if not _is_sm89_or_later():
+        if not has_cuda_capability(8, 9):
             logger.warning(
                 "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
             )
@@ -73,7 +68,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             )
 
         else:
-            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+            # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
             enable_fsdp_float8_all_gather = (
                 parallel_dims.dp_shard_enabled
                 and float8_config.enable_fsdp_float8_all_gather
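Note: this file now uses the shared has_cuda_capability() from torchtitan.tools.utils instead of the file-local _is_sm89_or_later(). A minimal sketch of the resulting gating pattern, assuming only what the diff shows; the wrapper function name below is illustrative and not part of the PR:

from torchtitan.tools.utils import has_cuda_capability


def float8_supported() -> bool:
    # Hypothetical helper for illustration: Float8Converter.__init__ performs
    # the same check inline. SM89 (compute capability 8.9) or later is required.
    if not has_cuda_capability(8, 9):
        print(
            "Failed to swap to Float8Linear because float8 is only supported "
            "on SM89 or later"
        )
        return False
    return True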
6 changes: 6 additions & 0 deletions torchtitan/experiments/llama4/model/args.py
@@ -14,6 +14,7 @@
 
 from torchtitan.protocols.train_spec import BaseModelArgs
 from torchtitan.tools.logging import logger
+from torchtitan.tools.utils import has_cuda_capability
 
 
 @dataclass
@@ -54,6 +55,11 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
         self.vocab_size = tokenizer.n_words
         self.max_seq_len = job_config.training.seq_len
         self.use_flex_attn = job_config.model.use_flex_attn
+        if self.use_grouped_mm and not has_cuda_capability(9, 0):
+            logger.warning(
+                "Failed to use grouped mm, which is only supported on SM90 or later",
+            )
+            self.use_grouped_mm = False
 
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
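For reference, has_cuda_capability() relies on Python's lexicographic tuple comparison against torch.cuda.get_device_capability(). A small illustration of the two thresholds used in this PR; the capability tuples are standard published values, not taken from the diff:

# (major, minor) compute capabilities for a few common GPUs.
ampere = (8, 0)  # A100
ada = (8, 9)     # L40S / RTX 4090
hopper = (9, 0)  # H100

# float8 gate: SM89 or later
assert not ampere >= (8, 9)
assert ada >= (8, 9) and hopper >= (8, 9)

# grouped mm gate: SM90 or later
assert not ada >= (9, 0)
assert hopper >= (9, 0)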
10 changes: 8 additions & 2 deletions torchtitan/experiments/llama4/model/moe.py
@@ -79,6 +79,9 @@ def forward(
             # fall back to regular bmm between 3D tensors
             assert x.dim() == 3
 
+        assert (
+            x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16
+        ), "torch._grouped_mm only supports bf16 dtypes"
         h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
         h = h * torch._grouped_mm(x, self.w3, offs=offsets)
         out = torch._grouped_mm(h, self.w2, offs=offsets)
@@ -246,14 +249,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             ALIGN_SIZE_M = 16
 
             with torch.no_grad():
-                permuted_indices, m_sizes = generate_permute_indices(
+                (
+                    permuted_indices,
+                    num_local_tokens_per_expert,
+                    _,
+                ) = generate_permute_indices(
                     num_local_tokens_per_expert,
                     self.experts.num_experts,
                     1,
                     token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M,
                     ALIGN_SIZE_M,
                 )
-                num_local_tokens_per_expert = m_sizes
             token_indices = torch.vstack(
                 (token_indices, token_indices.new_zeros((dim)))
             )
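Note on the offs= argument: torch._grouped_mm consumes per-expert boundaries into the stacked token tensor, and the new assert documents that it only accepts bf16 operands. A rough sketch of how such offsets relate to the (padded) per-expert token counts, assuming they are a simple running sum; the concrete counts below are made up:

import torch

# Hypothetical per-expert token counts, already padded by generate_permute_indices
# to multiples of ALIGN_SIZE_M = 16 so each expert's slice is grouped-mm friendly.
num_local_tokens_per_expert = torch.tensor([32, 16, 48, 16])

# Running totals mark where each expert's slice of the stacked token tensor ends.
offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32)
print(offsets)  # tensor([ 32,  48,  96, 112], dtype=torch.int32)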
7 changes: 7 additions & 0 deletions torchtitan/tools/utils.py
@@ -16,6 +16,13 @@
 from torchtitan.tools.logging import logger
 
 
+def has_cuda_capability(major: int, minor: int) -> bool:
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
+        major,
+        minor,
+    )
+
+
 def get_device_info():
     device_type = _get_available_device_type()
     if device_type is None:
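A quick sanity check of the new helper, assuming torchtitan is importable: because torch.cuda.is_available() is evaluated first, the helper short-circuits to False on CPU-only machines, so both call sites in this PR take their warning/fallback paths instead of raising:

import torch

from torchtitan.tools.utils import has_cuda_capability

if not torch.cuda.is_available():
    # No CUDA device: get_device_capability() is never called.
    assert not has_cuda_capability(8, 9)
    assert not has_cuda_capability(9, 0)
else:
    print("compute capability:", torch.cuda.get_device_capability())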