diff --git a/README.md b/README.md index c6b8d0d..c9210ec 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,42 @@ This aims to be a simpler implementation of the [original repo](https://github.com/microsoft/Samba). +## Installation + +> [!TIP] +> While the `pip install` command _should_ install all deps and the package, in practice some of the more CUDA-heavy deps are better installed separately from source. See the section below for more details. + +```bash +git clone https://github.com/pszemraj/samba-pytorch.git +cd samba-pytorch +pip install -e . +``` + +### Installing custom kernel packages first + +After installing `torch`, `xformers`, and `flash-attn`, you may want to install `mamba-ssm`, `causal-conv1d`, and `fla` from source: + +```bash +pip install --upgrade pip ninja +pip install git+https://github.com/state-spaces/mamba.git --no-build-isolation +pip install git+https://github.com/Dao-AILab/causal-conv1d.git --no-build-isolation +pip install git+https://github.com/sustcsonglin/flash-linear-attention@98c176e --no-build-isolation +``` + +Then, clone this repo and install it as shown above. + +## Usage + +A basic example of creating a randomly initialized model from a named config: + +```python +from samba_pytorch import Config, GPT +cfg = Config.from_name('Samba_421M_1k_window') +print(cfg) +model = GPT(cfg) +model +``` + ## repo structure ```text diff --git a/samba_pytorch/config.py b/samba_pytorch/config.py index 79bca96..8159542 100644 --- a/samba_pytorch/config.py +++ b/samba_pytorch/config.py @@ -3,14 +3,13 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, # see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE - +import warnings from dataclasses import dataclass from typing import Any, Literal, Optional, Type import torch from typing_extensions import Self -import samba_pytorch.samba from samba_pytorch.utils import find_multiple @@ -101,8 +100,9 @@ def from_name(cls, name: str, **kwargs: Any) -> Self: @property def mlp_class(self) -> Type: + from samba_pytorch import samba # `self._mlp_class` cannot be the type to keep the config json serializable - return getattr(samba_pytorch.samba, self._mlp_class) + return getattr(samba, self._mlp_class) @property def norm_class(self) -> Type: @@ -112,9 +112,12 @@ def norm_class(self) -> Type: return RMSNorm elif self._norm_class == "FusedRMSNorm": - from samba_pytorch.modules.rmsnorm import FusedRMSNorm + warnings.warn( + "FusedRMSNorm has been removed, using standard torch RMSNorm instead" + ) + from samba_pytorch.modules.rmsnorm import RMSNorm - return FusedRMSNorm + return RMSNorm return getattr(torch.nn, self._norm_class) @@ -133,7 +136,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -150,7 +153,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -168,7 +171,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, full_per_layer=2, _mlp_class="LLaMAMLP", intermediate_size=4608, @@ -187,7 +190,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4608, @@
-206,7 +209,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4608, @@ -225,7 +228,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4608, @@ -244,7 +247,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4608, @@ -263,7 +266,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -280,7 +283,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -298,7 +301,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -316,7 +319,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -335,7 +338,7 @@ def norm_class(self) -> Type: parallel_residual=True, shared_attention_norm=True, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -354,7 +357,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -373,7 +376,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -393,7 +396,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -412,7 +415,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -431,7 +434,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -450,7 +453,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -469,7 +472,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -489,7 +492,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + 
_norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4608, @@ -510,7 +513,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4608, @@ -531,7 +534,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4608, @@ -552,7 +555,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4608, @@ -573,7 +576,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -592,7 +595,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -612,7 +615,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -632,7 +635,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=4096, @@ -653,7 +656,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=6144, @@ -673,7 +676,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=6144, @@ -693,7 +696,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=6144, @@ -712,7 +715,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=6144, @@ -731,7 +734,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=6144, @@ -750,7 +753,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=6144, @@ -769,7 +772,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=6144, @@ -787,7 +790,7 @@ def norm_class(self) -> Type: rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="FusedRMSNorm", + _norm_class="RMSNorm", norm_eps=1e-5, _mlp_class="LLaMAMLP", intermediate_size=8192, diff --git a/samba_pytorch/modules/__init__.py 
b/samba_pytorch/modules/__init__.py index f02d299..bb23c11 100644 --- a/samba_pytorch/modules/__init__.py +++ b/samba_pytorch/modules/__init__.py @@ -7,7 +7,7 @@ from samba_pytorch.modules.gla import GatedLinearAttention from samba_pytorch.modules.mamba_simple import Mamba from samba_pytorch.modules.multiscale_retention import MultiScaleRetention -from samba_pytorch.modules.rmsnorm import FusedRMSNorm, RMSNorm, rms_norm +from samba_pytorch.modules.rmsnorm import RMSNorm, rms_norm from samba_pytorch.modules.rotary import RotaryEmbedding, apply_rotary_emb __all__ = [ @@ -16,7 +16,6 @@ "GatedLinearAttention", "Mamba", "MultiScaleRetention", - "FusedRMSNorm", "RMSNorm", "rms_norm", "apply_rotary_emb", diff --git a/samba_pytorch/modules/rmsnorm.py b/samba_pytorch/modules/rmsnorm.py index 01350d0..0cfbea9 100644 --- a/samba_pytorch/modules/rmsnorm.py +++ b/samba_pytorch/modules/rmsnorm.py @@ -1,836 +1,108 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -# Copyright (c) 2022, Tri Dao. -# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16 - -import dropout_layer_norm import torch -from torch.nn import init - +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from typing import Optional, Tuple, Union -def maybe_align(x, alignment_in_bytes=16): - """Assume that x already has last dim divisible by alignment_in_bytes""" - # TD [2023-07-04] I'm not 100% sure that clone will align the memory - # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 +def maybe_align(x: torch.Tensor, alignment_in_bytes: int = 16) -> torch.Tensor: + """Ensures memory alignment by cloning if necessary.""" return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() - -def _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - rowscale, - colscale, - None, - None, - dropout_p, - epsilon, - 1.0, - 0, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(xmat.shape) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = ( - dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - None, - None, - dropout_p, - 1.0, - 0, - has_residual, - is_rms_norm, - ) - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
+def dropout_add_layer_norm( + x0: torch.Tensor, + residual: Optional[torch.Tensor], + weight: torch.Tensor, + bias: Optional[torch.Tensor], + dropout_p: float, + epsilon: float, + rowscale: Optional[torch.Tensor] = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(-1, hidden_size) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = ( - dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm, - ) - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma0.numel() - x0mat = x0.view((-1, hidden_size)) - x1mat = x1.view((-1, hidden_size)) if x1 is not None else None - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( - x0mat, - x1mat, - residualmat, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask0 and dmask1 are None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma - - -def _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). + Fused dropout + residual add + layer norm implementation. 
+ + Args: + x0: Input tensor + residual: Optional residual tensor to add + weight: Layer norm weight parameter + bias: Optional layer norm bias parameter + dropout_p: Dropout probability + epsilon: Small constant for numerical stability + rowscale: Optional row-wise scaling factor + prenorm: Whether to return pre-normalization results + residual_in_fp32: Whether to cast residual to fp32 during addition + is_rms_norm: Whether to use RMS normalization instead of layer norm + return_dropout_mask: Whether to return the dropout mask """ - hidden_size = gamma0.numel() - xmat = x.view((-1, hidden_size)) - dz0mat = dz0.view(xmat.shape) - dz1mat = dz1.view(xmat.shape) if dz1 is not None else None - dxmat = dx.view(xmat.shape) if dx is not None else None - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - *rest, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( - dz0mat, - dz1mat, - dxmat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 - - -class DropoutAddLayerNormFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = ( - maybe_align(residual.contiguous(), 16) if residual is not None else None - ) - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - rowscale = ( - maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None - ) - colscale = ( - maybe_align(colscale.contiguous(), 16) if colscale is not None else None - ) - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - ctx.save_for_backward( - xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - if not return_dmask: - return ( - zmat.view(x0.shape) - if not prenorm - else (zmat.view(x0.shape), xmat.view(x0.shape)) - ) - else: - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return ( - (zmat.view(x0.shape), dmask) - if not prenorm - else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) - ) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - None, - dcolscale, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormSubsetFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = ( - maybe_align(residual.contiguous(), 16) if residual is not None else None - ) - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - colscale = ( - maybe_align(colscale.contiguous(), 16) if colscale is not None else None - ) - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - x_shape = (-1, *x0.shape[1:]) - ctx.save_for_backward( - xmat.view(x_shape), - x0_saved, - dmask, - gamma, - mu, - rsigma, - colscale, - x0_subset, - out_subset, - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.rowscale_const = rowscale_const - ctx.x0_numrows = x0.shape[:-1].numel() - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - z_shape = (-1, *x0.shape[1:]) - if not return_dmask: - return ( - zmat.view(z_shape) - if not prenorm - else (zmat.view(z_shape), xmat.view(x0.shape)) - ) + # Initialize mask + mask = None + + # Apply dropout + if dropout_p > 0.0: + mask = torch.bernoulli(torch.full_like(x0, 1 - dropout_p)) + x0 = x0 * mask / (1 - dropout_p) + + # Add residual if provided + if residual is not None: + if residual_in_fp32: + x0 = x0 + residual.float().to(x0.dtype) else: - z = zmat.view(z_shape) - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) + x0 = x0 + residual - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ( - ctx.saved_tensors - ) - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = ( - _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - ctx.rowscale_const, - ctx.x0_numrows, - has_residual, - ctx.is_rms_norm, - ) - ) - dx0 = dx0mat.view(-1, *x.shape[1:]) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - dcolscale, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None - residual = ( - maybe_align(residual.contiguous(), 16) if residual is not None else None - ) - gamma0 = maybe_align(gamma0.contiguous(), 16) - beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None - gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None - beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - ctx.save_for_backward( - xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_x1 = x1 is not None - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta0 is not None - z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) - if not return_dmask: - return z if not prenorm else (*z, xmat.view(x0.shape)) - else: - dmask0 = ( - dmask0.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - dmask1 = ( - dmask1.view(x0.shape) - if dropout_p > 0.0 and x1 is not None - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask0) - ctx.mark_non_differentiable(dmask1) - return ( - (*z, dmask0, dmask1) - if not prenorm - else (*z, xmat.view(x0.shape), dmask0, dmask1) - ) - - @staticmethod - def backward(ctx, dz0, dz1, *args): - dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
- dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None - dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors - dropout_p = ctx.dropout_p - has_x1 = ctx.has_x1 - has_residual = ctx.has_residual - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - ) = _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - return ( - dx0, - dx1, - dresidual, - dgamma0, - dbeta0 if ctx.has_beta else None, - dgamma1, - dbeta1 if ctx.has_beta else None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm(x, weight, bias, epsilon): - return DropoutAddLayerNormFn.apply( - x, None, weight, bias, None, None, 0.0, epsilon, False - ) + # Apply row scaling if provided + if rowscale is not None: + x0 = x0 * rearrange(rowscale, 'b -> b 1') + # Compute normalization (either LayerNorm or RMSNorm) + if is_rms_norm: + norm_x = torch.mean(x0 * x0, dim=-1, keepdim=True) + x_normed = x0 * torch.rsqrt(norm_x + epsilon) + else: + mean = x0.mean(dim=-1, keepdim=True) + var = x0.var(dim=-1, unbiased=False, keepdim=True) + x_normed = (x0 - mean) / torch.sqrt(var + epsilon) -def dropout_add_layer_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) + # Apply weight and optional bias + output = x_normed * weight + (bias if bias is not None else 0.0) + if return_dropout_mask: + if mask is None: + mask = torch.ones_like(x0, dtype=torch.uint8) + return output, mask + return output -def dropout_add_layer_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. +class DropoutAddLayerNorm(nn.Module): """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. + Module that combines dropout, residual connection, and layer normalization. 
""" - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -class DropoutAddLayerNorm(torch.nn.Module): def __init__( self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, + hidden_size: int, + prenorm: bool = False, + p: float = 0.0, + eps: float = 1e-5, + residual_in_fp32: bool = False, + device = None, + dtype = None, ): - factory_kwargs = {"device": device, "dtype": dtype} + factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.prenorm = prenorm self.p = p self.eps = eps self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - init.zeros_(self.bias) + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.bias = nn.Parameter(torch.zeros(hidden_size, **factory_kwargs)) - def forward(self, x0, residual=None): + def forward( + self, + x0: torch.Tensor, + residual: Optional[torch.Tensor] = None, + rowscale: Optional[torch.Tensor] = None + ) -> torch.Tensor: return dropout_add_layer_norm( x0, residual, @@ -838,50 +110,49 @@ def forward(self, x0, residual=None): self.bias, self.p if self.training else 0.0, self.eps, + rowscale=rowscale, prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, + residual_in_fp32=self.residual_in_fp32 ) - -def rms_norm(x, weight, epsilon): - return DropoutAddLayerNormFn.apply( - x, None, weight, None, None, None, 0.0, epsilon, False, False, True - ) - - -class FusedRMSNorm(torch.nn.Module): - def __init__(self, size: int, dim: int = -1, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.ones(size)) - self.dim = dim - self.reset_parameters() - def reset_parameters(self): - init.ones_(self.weight) + """Reset parameters to default initialization.""" + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) - def forward(self, x): - return rms_norm(x, self.weight, self.eps) - - -class RMSNorm(torch.nn.Module): - """Root Mean Square Layer Normalization. - - Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: - https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. +class RMSNorm(nn.Module): """ + Root Mean Square Layer Normalization. 
- def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: + Implementation follows the paper: https://arxiv.org/abs/1910.07467 + """ + def __init__( + self, + hidden_size: int, + eps: float = 1e-5, + device = None, + dtype = None + ): + factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() - self.weight = torch.nn.Parameter(torch.ones(size)) + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) self.eps = eps - self.dim = dim def forward(self, x: torch.Tensor) -> torch.Tensor: - # NOTE: the original RMSNorm paper implementation is not equivalent - norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - x_normed = x * torch.rsqrt(norm_x + self.eps) - return self.weight * x_normed + return rms_norm(x, self.weight, self.eps) def reset_parameters(self): - torch.nn.init.ones_(self.weight) + """Reset parameters to default initialization.""" + nn.init.ones_(self.weight) + +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + epsilon: float +) -> torch.Tensor: + """ + Applies RMS normalization to the input tensor. + """ + norm_x = torch.mean(x * x, dim=-1, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + epsilon) + return x_normed * weight \ No newline at end of file
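For reference, a minimal usage sketch of the pure-PyTorch norm utilities this diff introduces in `samba_pytorch/modules/rmsnorm.py`. The import paths follow the updated `samba_pytorch/modules/__init__.py` (`RMSNorm` and `rms_norm` are re-exported there; `DropoutAddLayerNorm` is not, so it is imported from the submodule directly), and the tensor shapes are illustrative only:

```python
import torch

# RMSNorm and rms_norm are re-exported from samba_pytorch.modules (see the
# __init__.py diff above); DropoutAddLayerNorm lives in the rmsnorm submodule.
from samba_pytorch.modules import RMSNorm, rms_norm
from samba_pytorch.modules.rmsnorm import DropoutAddLayerNorm

x = torch.randn(2, 16, 1024)  # (batch, seq_len, hidden_size) -- illustrative shapes

# Module form, as returned by Config.norm_class for _norm_class="RMSNorm".
norm = RMSNorm(1024, eps=1e-5)
y = norm(x)

# Functional form; computes the same normalization as the module call above.
assert torch.allclose(y, rms_norm(x, norm.weight, 1e-5))

# Dropout + residual add + LayerNorm in one call, now plain PyTorch instead of
# the removed fused CUDA kernel. Dropout is only active in training mode.
add_norm = DropoutAddLayerNorm(1024, p=0.1).eval()
residual = torch.randn_like(x)
z = add_norm(x, residual=residual)
```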