diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 000823803..4596574d9 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -5,5 +5,13 @@ # LICENSE file in the root directory of this source tree. _supported_experiments = frozenset( - ["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"] + [ + "flux", + "gpt_oss", + "llama4", + "qwen3", + "simple_fsdp.llama3", + "simple_fsdp.deepseek_v3", + "vlm", + ] ) diff --git a/torchtitan/experiments/gpt_oss/README.md b/torchtitan/experiments/gpt_oss/README.md new file mode 100644 index 000000000..613e77003 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/README.md @@ -0,0 +1,19 @@ +# gpt-oss Model in torchtitan + +## Quick Start +```bash +CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./run_train.sh +``` + +## Supported Features +- FSDP/HSDP, TP, EP, ETP +- Grouped matrix multiplication for efficient computation +- SwiGLU activation +- Multi-head attention with sliding window mask and attention sink + + +## TODO +1. More parallelism support: CP, PP +2. Conversion between HF weights (StateDictAdapter) +3. Forward parity verification +4. CI support diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py new file mode 100644 index 000000000..4762ffc88 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
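As a quick sanity check of the registration above, the snippet below is a hypothetical smoke test (assuming this PR is applied and importable); `gptoss_configs` is defined in `gpt_oss/__init__.py` further down in this diff.

```python
# Hypothetical smoke test: "gpt_oss" is registered, and its three flavors differ
# mainly in depth and expert count (values taken from gptoss_configs below).
from torchtitan.experiments import _supported_experiments
from torchtitan.experiments.gpt_oss import gptoss_configs

assert "gpt_oss" in _supported_experiments
for flavor, args in gptoss_configs.items():
    print(flavor, args.n_layers, args.moe_args.num_experts)
# debugmodel 4 8
# 20b 24 32
# 120b 36 128
```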
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.models.moe import MoEArgs + +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize import parallelize_gptoss +from .model.args import GptOssModelArgs +from .model.model import GptOssModel + +__all__ = [ + "parallelize_gptoss", + "GptOssModelArgs", + "GptOssModel", + "gptoss_configs", +] + + +gptoss_configs = { + "debugmodel": GptOssModelArgs( + dim=256, + n_layers=4, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=0, + score_func="softmax", + route_norm=False, + route_scale=1.0, + score_before_experts=False, + top_k=4, + use_grouped_mm=True, + load_balance_coeff=1e-3, + ), + attn_mask_type="causal", + ), + "20b": GptOssModelArgs( + n_layers=24, + moe_args=MoEArgs( + num_experts=32, + num_shared_experts=0, + score_func="softmax", + route_norm=False, + route_scale=1.0, + score_before_experts=False, + top_k=4, + use_grouped_mm=True, + load_balance_coeff=1e-3, + ), + ), + "120b": GptOssModelArgs( + n_layers=36, + moe_args=MoEArgs( + num_experts=128, + num_shared_experts=0, + score_func="softmax", + route_norm=False, + route_scale=1.0, + score_before_experts=False, + top_k=4, + use_grouped_mm=True, + load_balance_coeff=1e-3, + ), + ), +} + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + name="gpt_oss", + model_cls=GptOssModel, + model_args=gptoss_configs, + parallelize_fn=parallelize_gptoss, + pipelining_fn=None, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py new file mode 100644 index 000000000..1d1c9e144 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
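To make the train spec and flavors registered above concrete, here is a minimal sketch (assuming this PR is applied and a PyTorch build recent enough for the FlexAttention changes later in this diff) that builds the debug flavor on the meta device, so no real memory is allocated, and reports the parameter count.

```python
# Minimal sketch: instantiate the registered debug flavor on the meta device
# (shape-only tensors) and count parameters. No GPU or distributed setup needed.
import torch

from torchtitan.experiments.gpt_oss import get_train_spec, gptoss_configs

spec = get_train_spec()
model_args = gptoss_configs["debugmodel"]

with torch.device("meta"):
    model = spec.model_cls(model_args)

n_params = sum(p.numel() for p in model.parameters())
print(f"{spec.name}/debugmodel: {n_params / 1e6:.1f}M parameters")
```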
+ +from typing import Callable + +import torch +import torch.nn as nn +from torch.distributed.tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Replicate, + Shard, +) +from torch.distributed.tensor.parallel import ParallelStyle +from torchtitan.distributed.expert_parallel import ExpertParallel + + +# implementation of Tensor Parallel for the GroupedExperts in MoE +class TensorParallel(ParallelStyle): + def _partition_fn(self, name, module, device_mesh): + module.register_parameter( + "mlp1_weight", + nn.Parameter( + distribute_tensor(module.mlp1_weight, device_mesh, [Shard(2)]) + ), + ) # Column-wise sharding + module.register_parameter( + "mlp1_bias", + nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])), + ) # Column-wise sharding + module.register_parameter( + "mlp2_weight", + nn.Parameter( + distribute_tensor(module.mlp2_weight, device_mesh, [Shard(1)]) + ), + ) # Row-wise sharding + module.register_parameter( + "mlp2_bias", + nn.Parameter( + distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()]) + ), + ) # Replicate + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._partition_fn, + ) + + +# This class is for dp2ep with TP (without TP we can just use ExpertParallel) +class ExpertTensorParallel(ExpertParallel): + def __init__( + self, + tp_mesh: DeviceMesh, + ep_mesh: DeviceMesh, + ): + super().__init__() + # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh, + # as DeviceMesh doesn't support slicing from a submesh. + self.tp_mesh = tp_mesh + self.ep_mesh = ep_mesh + + def _token_dispatch(self, mod, inputs, device_mesh): + # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_dispatch(mod, inputs, self.ep_mesh) + + def _partition_fn_2d(self, name, mod, ep_tp_mesh): + mod.register_parameter( + "mlp1_weight", + nn.Parameter( + distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(2)]) + ), + ) # Column-wise sharding + mod.register_parameter( + "mlp1_bias", + nn.Parameter( + distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)]) + ), + ) # Column-wise sharding + mod.register_parameter( + "mlp2_weight", + nn.Parameter( + distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(1)]) + ), + ) # Row-wise sharding + mod.register_parameter( + "mlp2_bias", + nn.Parameter( + distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()]) + ), + ) # Replicate + + def _token_combine(self, mod, routed_output, device_mesh): + # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_combine(mod, routed_output, self.ep_mesh) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn_2d, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +# TODO(jianiw): This need to be merged with expert_parallel +def expert_parallel(func: Callable) -> Callable: + """ + This is a wrapper applied to the GroupedExperts computation, serving + the following three purposes: + 1. Convert parameters from DTensors to plain Tensors, to work with + dynamic-shape inputs which cannot be easily expressed as DTensors. + 2. 
In Expert Parallel, apply the generate_permute_indices kernel to + permute the inputs to be ordered by local experts (see the _token_dispatch + function in ExpertParallel) and permute the outputs back. + 3. In order to use torch._grouped_mm, we need to make sure the number of + tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices + kernel also helps achieve this via padding, without incurring synchronization + between device and host. Note that this will create side effects when wrapping + the for-loop implementation of GroupedExperts, as it does not need padding. + + Among the above: + 1 and 2 are needed only when expert_parallel_degree > 1. + 3 is needed even for single-device computation. + 2 can be moved to ExpertParallel _token_dispatch if not coupled with 3. + """ + + def wrapper( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + swiglu_limit: float, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(mlp1_weight, DTensor): + mlp1_weight = mlp1_weight.to_local() + mlp1_bias = mlp1_bias.to_local() + mlp2_weight = mlp2_weight.to_local() + mlp2_bias = mlp2_bias.to_local() + + if num_tokens_per_expert is not None: + from torchtitan.experiments.kernels.moe.indices import ( + generate_permute_indices, + ) + + experts_per_ep_rank = mlp1_weight.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + ALIGN_SIZE_M = 16 + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, # offsets, + ) = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, + ALIGN_SIZE_M, + ) + + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] + + out = func( + mlp1_weight, + mlp1_bias, + mlp2_weight, + mlp2_bias, + swiglu_limit, + x, + num_tokens_per_expert, + ) + + if num_tokens_per_expert is not None: + out_unpermuted = out.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = out + out = out_unpermuted[:-1] + + return out + + return wrapper diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py new file mode 100644 index 000000000..bbde1a751 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
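The third point in the `expert_parallel` docstring above, that each expert's token group must be a multiple of `ALIGN_SIZE_M` for `torch._grouped_mm`, boils down to rounding counts up to the next multiple of 16. A toy illustration follows (not the actual `generate_permute_indices` kernel, which also builds the permutation and avoids device-to-host sync):

```python
# Toy illustration of the alignment that generate_permute_indices achieves via
# padding: every expert's token count is rounded up to a multiple of ALIGN_SIZE_M.
import torch

ALIGN_SIZE_M = 16
num_tokens_per_expert = torch.tensor([5, 23, 16, 1])
padded = (num_tokens_per_expert + ALIGN_SIZE_M - 1) // ALIGN_SIZE_M * ALIGN_SIZE_M
print(padded.tolist())  # [16, 32, 16, 16] -- well-aligned groups for grouped GEMM
```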
+ +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh + +from torch.distributed.tensor import distribute_tensor, Partial, Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + PrepareModuleInputOutput, + PrepareModuleOutput, + RowwiseParallel, + SequenceParallel, +) +from torchtitan.config import TORCH_DTYPE_MAP +from torchtitan.config.job_config import JobConfig +from torchtitan.distributed import NoParallel, ParallelDims +from torchtitan.distributed.expert_parallel import ( + ExpertParallel, + ReordererSequenceParallel, +) +from torchtitan.experiments.llama4.infra.parallelize import apply_fsdp +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.tools.logging import logger + +from .expert_parallel import ExpertTensorParallel, TensorParallel + + +# for selective op activation checkpointing +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops._c10d_functional.all_to_all_single.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + +# Adapted from llama4/infra/parallelize.py +def parallelize_gptoss( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + world_mesh = parallel_dims.world_mesh + + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + if parallel_dims.tp_enabled: + if job_config.parallelism.enable_async_tensor_parallel: + raise NotImplementedError( + "Currently, async TP is not tested for gptoss. \ + torch.compile is not supported yet, which is required for async TP." 
+ ) + + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + if enable_float8_tensorwise_tp: + raise NotImplementedError( + "Currently, float8 tensorwise TP is not tested for gptoss" + ) + + apply_non_moe_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=False, + enable_async_tp=False, + ) + + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled + and parallel_dims.ep_enabled + and parallel_dims.etp_enabled + else None + ), + etp_enabled=parallel_dims.etp_enabled, + ) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + save_list=_op_sac_save_list, + ) + + dp_mesh: DeviceMesh | None = None + if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + dp_mod_ep_mesh_dim_names = [] + if parallel_dims.ep_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mod_ep_mesh_dim_names.append("dp_replicate") + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + + apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ep_degree=parallel_dims.ep, + dp_mod_ep_mesh=( + world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + if parallel_dims.ep_enabled + else None + ), + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + dp_mesh = world_mesh + apply_ddp( + model, + dp_mesh, + enable_compile=model_compile_enabled, + enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, + ) + + return model + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. 
Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + ( + rowwise_parallel, + colwise_parallel, + prepare_module_input, + prepare_module_output, + ) = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + PrepareModuleOutput, + ) + + # Apply tensor + sequence parallelism to every transformer block + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attention.wq": colwise_parallel(), + "attention.wk": colwise_parallel(), + "attention.wv": colwise_parallel(), + "attention.attn": prepare_module_output( + output_layouts=(Shard(1), Shard(1)), + desired_output_layouts=(Shard(1), Shard(1)), + use_local_output=False, + ), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + } + + # shard attention.sinks across heads + attn = transformer_block.attention + attn.register_parameter( + "sinks", + nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])), + ) + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +# NOTE(jianiw): The function can not be reused now because reimplemented ExpertTensorParallel +def apply_moe_ep_tp( + model: nn.Module, + tp_mesh: DeviceMesh | None, + ep_mesh: DeviceMesh | None, + ep_tp_mesh: DeviceMesh | None, + etp_enabled: bool, +): + for transformer_block in model.layers.values(): + if not transformer_block.moe_enabled: + continue + + if tp_mesh is not None: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "moe.router.gate": NoParallel(), + } + if ep_mesh is not None and not etp_enabled: + # If TP is borrowed for EP, then split the tokens across TP ranks so that + # the reorderer, the all-to-all comms, and routed experts computation + # are effectively running Sequence Parallel (split along the folded bs*slen dim) + moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) + + experts_mesh, experts_plan = None, None + if ep_mesh is None: + experts_mesh = tp_mesh + # input Replicate, output Partial + experts_plan = TensorParallel() + elif tp_mesh is None: + experts_mesh = ep_mesh + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() + elif etp_enabled: + experts_mesh = ep_tp_mesh + experts_plan = 
ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
+        else:
+            experts_mesh = ep_mesh
+            experts_plan = ExpertParallel()
+
+        parallelize_module(
+            module=transformer_block.moe.experts,
+            device_mesh=experts_mesh,
+            parallelize_plan=experts_plan,
+        )
diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py
new file mode 100644
index 000000000..5cb375251
--- /dev/null
+++ b/torchtitan/experiments/gpt_oss/model/args.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass, field
+from typing import Literal
+
+from torch import nn
+
+from torchtitan.config.job_config import JobConfig
+from torchtitan.models.moe import MoEArgs
+from torchtitan.models.utils import get_moe_model_nparams_and_flops
+from torchtitan.protocols.train_spec import BaseModelArgs
+from torchtitan.tools.logging import logger
+from torchtitan.tools.utils import has_cuda_capability
+
+
+@dataclass
+class GptOssModelArgs(BaseModelArgs):
+    """
+    Data class for defining model arguments and hyperparameters.
+
+    Attributes:
+        max_batch_size (int): Maximum batch size.
+        max_seq_len (int): Maximum sequence length.
+        dtype (Literal["bf16", "fp8"]): Data type for computations.
+        vocab_size (int): Vocabulary size.
+        dim (int): Model dimension.
+        moe_inter_dim (int): Intermediate dimension for MoE layers.
+        n_layers (int): Number of transformer layers.
+        norm_eps (float): Epsilon used for RMSNorm.
+        moe_args (MoEArgs): Routing and expert configuration for MoE layers.
+        swiglu_limit (float): Clamp limit applied inside the SwiGLU activation.
+        head_dim (int): Dimension of each attention head.
+        n_heads (int): Number of attention (query) heads.
+        n_kv_heads (int): Number of key/value heads (GQA).
+        sliding_window (int): Window size for sliding-window attention layers.
+        use_flex_attn (bool): Whether to use FlexAttention.
+        attn_mask_type (str): Attention mask type for non-sliding-window layers.
+        original_seq_len (int): Original sequence length.
+        rope_theta (float): Base for rotary positional encoding.
+        rope_factor (float): Scaling factor for extended sequence lengths.
+        beta_fast (int): Fast beta correction factor.
+        beta_slow (int): Slow beta correction factor.
+ """ + + max_batch_size: int = 8 + max_seq_len: int = 131072 + dtype: Literal["bf16", "fp8"] = "bf16" + vocab_size: int = 201088 + dim: int = 2880 + moe_inter_dim: int = 2880 + n_layers: int = 24 + norm_eps: float = 1e-5 # eps used for RMSNorm + # MoE + moe_args: MoEArgs = field(default_factory=MoEArgs) + swiglu_limit: float = 7.0 + # Multi-Head Latent Attention (MLA) + head_dim: int = 64 + n_heads: int = 64 + n_kv_heads: int = 8 + sliding_window: int = 128 + use_flex_attn: bool = True + attn_mask_type: str = "causal" + # yarn + original_seq_len: int = 4096 + rope_theta: float = 150000.0 + rope_factor: float = 32 + beta_fast: int = 32 + beta_slow: int = 1 + + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len + + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): + logger.warning( + "Failed to use grouped mm, which is only supported on SM90 or later", + ) + self.moe_args.use_grouped_mm = False + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise NotImplementedError( + "CP support for FlexAttention is still in progress." + ) + + def get_nparams_and_flops( + self, model: nn.Module, seq_len: int + ) -> tuple[int, float]: + return get_moe_model_nparams_and_flops(self, model, seq_len) diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py new file mode 100644 index 000000000..883b760aa --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -0,0 +1,365 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torchtitan.models.attention import build_attention +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import GptOssModelArgs +from .moe import GptOssMoE + + +def precompute_rope_cache( + dim: int, max_seq_len: int, base: float = 1_000_000.0 +) -> torch.Tensor: + freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.outer(t, freqs).float() + + # We cache the cos and sin embeddings instead of the IDs. This helps + # ensure we have correct behavior when training with bf16 + # Size: [max_seq_len, (dim * 2)] + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + return rope_cache + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor. 
+ + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, head_dim * 2), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + _, seqlen, _, head_dim = x.shape + rope_cache = rope_cache[0:seqlen] + # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + # input tensor x has shape [bsz, seq_len, n_heads, head_dim] + head_dim = xq.shape[-1] + + # reshape for broadcast + rope_cache = reshape_for_broadcast(rope_cache, xq) + + # [bsz, seq_len, 1, head_dim] + cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) + sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device) + + # xq: [bsz, seq_len, n_heads, head_dim] + # xk: [bsz, seq_len, n_kv_heads, head_dim] + xq_out = (xq * cos) + (rotate_half(xq) * sin) + xk_out = (xk * cos) + (rotate_half(xk) * sin) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention (MLA) module. + """ + + def __init__( + self, model_args: GptOssModelArgs, use_sliding_attention: bool = False + ): + super().__init__() + + self.sliding_window = ( + model_args.sliding_window if use_sliding_attention else None + ) + self.head_dim = model_args.head_dim + self.n_heads = model_args.n_heads + self.n_kv_heads = model_args.n_kv_heads + + self.n_rep = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + model_args.dim, + model_args.n_heads * model_args.head_dim, + bias=True, + ) + self.wk = nn.Linear( + model_args.dim, + model_args.n_kv_heads * model_args.head_dim, + bias=True, + ) + self.wv = nn.Linear( + model_args.dim, + model_args.n_kv_heads * model_args.head_dim, + bias=True, + ) + self.wo = nn.Linear( + model_args.n_heads * model_args.head_dim, + model_args.dim, + bias=True, + ) + self.sinks = nn.Parameter(torch.empty(model_args.n_heads)) + + self.use_flex_attn = model_args.use_flex_attn + + if not self.use_flex_attn: + raise ValueError("Only support FlexAttention in Gpt-oss model") + + # Only apply sliding window to every other layer + if use_sliding_attention: + self.attn = build_attention( + use_flex_attn=True, + attn_mask_type="sliding_window", + sliding_window=self.sliding_window, + ) + else: + self.attn = build_attention( + use_flex_attn=True, attn_mask_type=model_args.attn_mask_type + ) + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). 
+ rope_cache (torch.Tensor): Precomputed cosine and sine frequencies for rope embedding. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + hidden_shape = (bsz, seqlen, -1, self.head_dim) + + q = self.wq(x).view(hidden_shape) + k = self.wk(x).view(hidden_shape) + v = self.wv(x).view(hidden_shape) + + q, k = apply_rotary_emb(q, k, rope_cache) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(k, self.n_rep) + values = repeat_kv(v, self.n_rep) + + q = q.transpose(1, 2).contiguous() + k = keys.transpose(1, 2).contiguous() + v = values.transpose(1, 2).contiguous() + + # FlexAttention + output, aux_output = self.attn( + q, + k, + v, + scale=None, + return_lse=True, + ) + + # Apply attention sink rescaling: rescale by σ(lse - w[h]) + # This is mathematically equivalent to concatenating learnable sink weights + lse = aux_output.lse + sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(-1) + output = output * sink_scale.to(output.dtype) + + output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, T, H, D) + + # Reshape and project output + output = output.reshape( + bsz, seqlen, -1 + ).contiguous() # (bsz, seqlen, n_heads * v_head_dim) + output = self.wo(output) # (bsz, seqlen, dim) + return output + + def init_weights(self, init_std: float): + linear_list = [ + self.wq, + self.wk, + self.wv, + ] + + nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std) + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(linear.bias, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.wo.bias, mean=0.0, std=init_std) + + # TODO: statically init the mask using train.seq_len + def sliding_window_causal(self, seqlen, device): + i = torch.arange(seqlen, device=device) + q_idx = i[:, None] + kv_idx = i[None, :] + + causal_mask = q_idx >= kv_idx + if self.sliding_window is None: + return causal_mask + window_mask = q_idx - kv_idx <= self.sliding_window + return causal_mask & window_mask + + +class TransformerBlock(nn.Module): + """ + Transformer block with attention and feed-forward layers. + """ + + def __init__(self, layer_id: int, model_args: GptOssModelArgs): + + super().__init__() + use_sliding_attention = layer_id % 2 == 0 + self.attention = Attention( + model_args, use_sliding_attention=use_sliding_attention + ) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + + self.moe = GptOssMoE( + model_args, dim=model_args.dim, hidden_dim=model_args.moe_inter_dim + ) + self.moe_enabled = True # for composability with load balancing + + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + self.layer_id = layer_id + + def forward(self, x: torch.Tensor, rope_cache: torch.Tensor): + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. 
+ """ + x = x + self.attention(self.attention_norm(x), rope_cache) + x = x + self.moe(self.ffn_norm(x)) + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.moe.init_weights(self.weight_init_std, buffer_device) + + +class GptOssModel(nn.Module, ModelProtocol): + """ + GPT-OSS Transformer model with attention and feed-forward layers. + """ + + def __init__(self, model_args: GptOssModelArgs): + super().__init__() + self.model_args = model_args + self.max_seq_len = model_args.max_seq_len + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.register_buffer( + "rope_cache", self._precompute_rope_cache(), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to( + torch.bfloat16 + ) + + self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear( + model_args.dim, + model_args.vocab_size, + dtype=torch.get_default_dtype(), + bias=False, + ) + self.model_args = model_args + self.init_weights() + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or self.rope_cache.device + with torch.device(buffer_device): + self.rope_cache = self._precompute_rope_cache() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_rope_cache(self) -> torch.Tensor: + return precompute_rope_cache( + self.model_args.head_dim, + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). + """ + h = self.tok_embeddings(tokens) + + for layer in self.layers.values(): + h = layer(h, self.rope_cache) + h = self.norm(h) + output = self.output(h) + return output diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py new file mode 100644 index 000000000..2df093880 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from torch import nn +from torch.distributed.tensor import DTensor +from torchtitan.experiments.gpt_oss.infra.expert_parallel import expert_parallel +from torchtitan.models.moe import MoE + +from .args import GptOssModelArgs + + +def swiglu(x, alpha: float = 1.702, limit: float = 7.0): + x_glu, x_linear = x[..., ::2], x[..., 1::2] + # Clamp the input values + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + # Note we add an extra bias of 1 to the linear layer + return out_glu * (x_linear + 1) + + +class GptOssGroupedExperts(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + swiglu_limit: float, + use_grouped_mm: bool, + ): + super().__init__() + self.num_experts = num_experts + self.use_grouped_mm = use_grouped_mm + self.swiglu_limit = swiglu_limit + + self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, hidden_dim * 2))) + self.mlp1_bias = nn.Parameter(torch.empty((num_experts, hidden_dim * 2))) + self.mlp2_weight = nn.Parameter(torch.empty((num_experts, hidden_dim, dim))) + self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim))) + + def forward( + self, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.use_grouped_mm: + return GptOssGroupedExperts._run_experts_grouped_mm( + self.mlp1_weight, + self.mlp1_bias, + self.mlp2_weight, + self.mlp2_bias, + self.swiglu_limit, + x, + num_tokens_per_expert, + ) + else: + return GptOssGroupedExperts._run_experts_for_loop( + self.mlp1_weight, + self.mlp1_bias, + self.mlp2_weight, + self.mlp2_bias, + self.swiglu_limit, + x, + num_tokens_per_expert, + ) + + # TODO: keeping this for-loop implementation for comparison + # and readability, may remove later + @expert_parallel + @staticmethod + def _run_experts_for_loop( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + swiglu_limit: float, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = ( + torch.matmul(x_expert, mlp1_weight[expert_idx]) + + mlp1_bias[expert_idx] + ) + h = swiglu(h, limit=swiglu_limit) + h = torch.matmul(h, mlp2_weight[expert_idx]) + mlp2_bias[expert_idx] + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = torch.bmm(x, mlp1_weight) + mlp1_bias.unsqueeze(1) + h = swiglu(h, limit=swiglu_limit) + out = torch.bmm(h, mlp2_weight) + mlp2_bias.unsqueeze(1) + + return out + + @expert_parallel + @staticmethod + def _run_experts_grouped_mm( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + swiglu_limit: float, + x: 
torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + num_tokens_per_expert_long = num_tokens_per_expert.to(torch.long) + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + if isinstance(mlp1_weight, DTensor): + mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias = ( + mlp1_weight.to_local(), + mlp1_bias.to_local(), + mlp2_weight.to_local(), + mlp2_bias.to_local(), + ) + + h = torch._grouped_mm(x.bfloat16(), mlp1_weight.bfloat16(), offs=offsets) + if offsets is not None: + # TODO(jianiw): check what is this doing + b1 = mlp1_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + tail_slack = x.shape[0] - int(offsets[-1]) + if tail_slack: + b1 = torch.cat([b1, b1.new_zeros((tail_slack, b1.shape[-1]))], dim=0) + h = h + b1.to(h.dtype) + + h = swiglu(h, limit=swiglu_limit) + h = torch._grouped_mm(h, mlp2_weight.bfloat16(), offs=offsets) + if offsets is not None: + b2 = mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + tail_slack = x.shape[0] - int(offsets[-1]) + if tail_slack: # padding + b2 = torch.cat([b2, b2.new_zeros((tail_slack, b2.shape[-1]))], dim=0) + h = h + b2.to(h.dtype) + + return h + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.mlp1_weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp1_bias, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp2_weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp2_bias, mean=0.0, std=init_std) + + def extra_repr(self): + return ( + f"num_experts={self.num_experts}, " + f"use_grouped_mm={self.use_grouped_mm}, " + f"mlp1_weight={tuple(self.mlp1_weight.shape)}, " + f"mlp1_bias={tuple(self.mlp1_bias.shape)}, " + f"mlp2_weight={tuple(self.mlp2_weight.shape)}, " + f"mlp2_bias={tuple(self.mlp2_bias.shape)}" + ) + + +class GptOssMoE(MoE): + """GptOss MoE implementation that inherits from the base MoE class.""" + + def __init__(self, model_args: GptOssModelArgs, dim: int, hidden_dim: int): + # Convert GptOssModelArgs to MoEArgs for base class compatibility + moe_args = model_args.moe_args + + # Initialize the base MoE class + super().__init__(moe_args, dim, hidden_dim) + + # Override the base GroupedExperts with GptOssGroupedExperts + self.experts = GptOssGroupedExperts( + dim=dim, + hidden_dim=hidden_dim, + num_experts=moe_args.num_experts, + swiglu_limit=model_args.swiglu_limit, + use_grouped_mm=moe_args.use_grouped_mm, + ) diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml new file mode 100644 index 000000000..22d0d7d8b --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -0,0 +1,82 @@ +[job] +dump_folder = "./outputs" +description = "Gpt-oss debug training" +print_args = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 
8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 277d64be1..3b1cd30a0 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -13,6 +13,7 @@ from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention.flex_attention import ( _mask_mod_signature, + AuxRequest, BlockMask, create_block_mask, flex_attention, @@ -23,7 +24,9 @@ # FlexAttention mask type. For each mask type, we initialize it at most once per # batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to # track the initialized mask. 
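The new `sliding_window` mask type added below keeps a key only when `kv_idx <= q_idx` and `q_idx - kv_idx <= window`. A dense toy rendering of that `mask_mod` (no FlexAttention compilation involved) makes the pattern easy to eyeball:

```python
# Dense rendering of the sliding-window causal mask_mod introduced in this diff:
# each query attends to itself and at most `window` earlier positions.
import torch

def sliding_window_causal(window: int):
    def mask_mod(b, h, q_idx, kv_idx):
        return (kv_idx <= q_idx) & (q_idx - kv_idx <= window)
    return mask_mod

seqlen, window = 8, 3
q_idx = torch.arange(seqlen)[:, None]
kv_idx = torch.arange(seqlen)[None, :]
print(sliding_window_causal(window)(0, 0, q_idx, kv_idx).int())
```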
-FLEX_ATTN_MASK_T = tuple[str, int | None] +FLEX_ATTN_MASK_T = tuple[ + str, int | None, int | None +] # (mask_type, fixed_block_size, sliding_window) class FlexAttention(torch.nn.Module): @@ -62,19 +65,23 @@ class FlexAttention(torch.nn.Module): attn_mask_type: str def __init__( - self, attn_mask_type: str, fixed_block_size: int | None = None + self, + attn_mask_type: str, + fixed_block_size: int | None = None, + sliding_window_size: int | None = None, ) -> None: super().__init__() - if attn_mask_type not in ["causal", "block_causal"]: + if attn_mask_type not in ["causal", "block_causal", "sliding_window"]: raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") self.attn_mask_type = attn_mask_type self.fixed_block_size = fixed_block_size + self.sliding_window_size = sliding_window_size FlexAttention.used_attn_mask_types.add(self.mask_key) @property def mask_key(self) -> FLEX_ATTN_MASK_T: - return (self.attn_mask_type, self.fixed_block_size) + return (self.attn_mask_type, self.fixed_block_size, self.sliding_window_size) def forward( self, @@ -82,9 +89,30 @@ def forward( k: torch.Tensor, v: torch.Tensor, scale: float | None = None, - ) -> torch.Tensor: + return_lse: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + + # Regular path without sink block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) + return_aux = AuxRequest(lse=True) if return_lse else None + return FlexAttention.flex_attn( + q, k, v, block_mask=block_mask, return_aux=return_aux, scale=scale + ) + + @staticmethod + def _get_sliding_window_mask_mod(window: int): + """ + Returns a mask_mod function that + - only allows kv_idx ≤ q_idx (causal) + - and only if (q_idx - kv_idx) ≤ window + """ + + def sliding_mod(b, h, q_idx, kv_idx): + # causal within window + keep = (kv_idx <= q_idx) & (q_idx - kv_idx <= window) + return keep + + return sliding_mod @staticmethod def _get_causal_mask_mod() -> _mask_mod_signature: @@ -153,7 +181,7 @@ def blocked_mask_mod( def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: # batch is [b, s, h, d] shape for mask_key in FlexAttention.used_attn_mask_types: - attn_mask_type, fixed_block_size = mask_key + attn_mask_type, fixed_block_size, sliding_window = mask_key match attn_mask_type: case "causal": if FlexAttention.block_masks.get(mask_key, None) is not None: @@ -169,6 +197,19 @@ def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: ) batch_dimension = batch.shape[0] mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id) + case "sliding_window": + if sliding_window is None or sliding_window <= 0: + raise RuntimeError( + "sliding_window must be provided and > 0 for sliding_window mask." + ) + if FlexAttention.block_masks.get(mask_key, None) is not None: + continue + # We don't care about batch dimension -- + # all samples have the same sliding window mask. + batch_dimension = 1 + mask_mod = FlexAttention._get_sliding_window_mask_mod( + sliding_window + ) case _: raise RuntimeError(f"Shouldn't reach here. 
{attn_mask_type}") @@ -223,15 +264,22 @@ def forward( def build_attention( - use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None + use_flex_attn: bool, + attn_mask_type: str, + fixed_block_size: int | None = None, + sliding_window: int | None = None, ): if use_flex_attn: - return FlexAttention(attn_mask_type, fixed_block_size) + return FlexAttention(attn_mask_type, fixed_block_size, sliding_window) else: if fixed_block_size is not None: raise ValueError( "TorchTitan with SDPA currently does not support fixed_block_size." ) + if sliding_window is not None: + raise ValueError( + "TorchTitan with SDPA currently does not support sliding_window." + ) if attn_mask_type != "causal": raise ValueError( "TorchTitan with SDPA currently only supports causal mask."