-
Notifications
You must be signed in to change notification settings - Fork 0
Splash attention #67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Splash attention #67
Changes from all commits
46e68cf
6fc5bf0
fccc0ee
9f77aff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,122 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Splash Attention integration for TPU-optimized attention.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import importlib.util | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Callable, Literal | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import jax | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from jaxtyping import Array, Float | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _TOKAMAX_AVAILABLE = importlib.util.find_spec("tokamax") is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if _TOKAMAX_AVAILABLE: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tokamax._src.ops.experimental.tpu.splash_attention import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| splash_attention_kernel as splash, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tokamax._src.ops.experimental.tpu.splash_attention import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| splash_attention_mask as mask_lib, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class SplashAttentionConfig: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Configuration for splash attention. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Attributes: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enabled (bool): Whether to enable splash attention. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask_type (Literal["full", "causal"]): Type of attention mask. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_q (int): Block size for query sequence tiling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_kv (int): Block size for key/value sequence tiling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enabled: bool = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask_type: Literal["full", "causal"] = "full" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_q: int = 128 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_kv: int = 128 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _kernel_cache: dict[tuple[int, int, int, str, int, int], Callable] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _create_splash_kernel( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seq_len: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_heads: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| head_dim: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config: SplashAttentionConfig, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Callable[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [Float[Array, "heads seq head_dim"], Float[Array, "heads seq head_dim"], Float[Array, "heads seq head_dim"]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Float[Array, "heads seq head_dim"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Create a cached splash attention kernel. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seq_len (int): Sequence length. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_heads (int): Number of attention heads. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| head_dim (int): Dimension of each attention head. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config (SplashAttentionConfig): Splash attention configuration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Callable: A splash attention kernel function. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cache_key = (seq_len, num_heads, head_dim, config.mask_type, config.block_q, config.block_kv) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if cache_key in _kernel_cache: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _kernel_cache[cache_key] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask_shape = (seq_len, seq_len) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask = mask_lib.CausalMask(mask_shape) if config.mask_type == "causal" else mask_lib.FullMask(mask_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| splash_config = splash.SplashConfig( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_q=config.block_q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_kv=config.block_kv, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_kv_compute=config.block_kv, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_q_dkv=config.block_q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_kv_dkv=config.block_kv, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_kv_dkv_compute=config.block_kv, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = splash.make_splash_mha_single_device(mask=mask, config=splash_config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _kernel_cache[cache_key] = kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def create_splash_attention_fn( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config: SplashAttentionConfig, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_heads: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| head_dim: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Callable[..., Float[Array, "batch heads seq head_dim"]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Create a splash attention function compatible with nnx.MultiHeadAttention. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| config (SplashAttentionConfig): Splash attention configuration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_heads (int): Number of attention heads. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| head_dim (int): Dimension of each attention head. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Callable: An attention function. Returns splash attention if enabled and available, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| otherwise returns the default dot_product_attention. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not _TOKAMAX_AVAILABLE or not config.enabled: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from flax.nnx.nn.attention import dot_product_attention | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return dot_product_attention | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+86
to
+101
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function I suggest changing the implementation to return
Suggested change
Comment on lines
+98
to
+101
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function should return
Suggested change
Comment on lines
+98
to
+101
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The import of
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def splash_attention_fn( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| query: Float[Array, "batch heads seq head_dim"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| key: Float[Array, "batch heads seq head_dim"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| value: Float[Array, "batch heads seq head_dim"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Float[Array, "batch heads seq head_dim"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Splash attention function. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| query (Float[Array, "batch heads seq head_dim"]): Query tensor. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| key (Float[Array, "batch heads seq head_dim"]): Key tensor. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| value (Float[Array, "batch heads seq head_dim"]): Value tensor. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Float[Array, "batch heads seq head_dim"]: Output tensor. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seq_len = query.shape[2] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = _create_splash_kernel(seq_len, num_heads, head_dim, config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return jax.vmap(kernel)(query, key, value) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return splash_attention_fn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| from jax.typing import DTypeLike | ||
| from jaxtyping import Array, Float | ||
|
|
||
| from jimm.common.splash_attention import SplashAttentionConfig, create_splash_attention_fn | ||
| from jimm.common.utils import DEFAULT_SHARDING, MeshRules | ||
|
|
||
|
|
||
|
|
@@ -39,6 +40,7 @@ def __init__( | |
| attn_mask: Float[Array, "seq seq"] | None = None, | ||
| use_quick_gelu: bool = False, | ||
| use_gradient_checkpointing: bool = False, | ||
| splash_attention_config: SplashAttentionConfig | None = None, | ||
| rngs: rnglib.Rngs | None = None, | ||
| dtype: DTypeLike = jnp.float32, | ||
| param_dtype: DTypeLike = jnp.float32, | ||
|
|
@@ -56,6 +58,7 @@ def __init__( | |
| attn_mask (Float[Array, "seq seq"] | None, optional): Optional attention mask. Defaults to None. | ||
| use_quick_gelu (bool, optional): Whether to use quickgelu instead of gelu. Defaults to False. | ||
| use_gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False. | ||
| splash_attention_config (SplashAttentionConfig | None, optional): Configuration for TPU splash attention. Defaults to None. | ||
| rngs (rnglib.Rngs | None, optional): Random number generator keys. If None, initializes to nnx.Rngs(0). | ||
| dtype (DTypeLike, optional): Data type for computations. Defaults to jnp.float32. | ||
| param_dtype (DTypeLike, optional): Data type for parameters. Defaults to jnp.float32. | ||
|
|
@@ -66,6 +69,16 @@ def __init__( | |
| rngs = nnx.Rngs(0) | ||
| self.attn_mask = attn_mask | ||
| self.use_gradient_checkpointing = use_gradient_checkpointing | ||
|
|
||
| attention_fn = ( | ||
| create_splash_attention_fn( | ||
| splash_attention_config, | ||
| num_heads=num_heads, | ||
| head_dim=hidden_size // num_heads, | ||
| ) | ||
| if splash_attention_config is not None | ||
| else nnx.dot_product_attention | ||
| ) | ||
|
Comment on lines
+73
to
+81
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for determining This change depends on attention_fn = None
if splash_attention_config is not None:
attention_fn = create_splash_attention_fn(
splash_attention_config,
num_heads=num_heads,
head_dim=hidden_size // num_heads,
) |
||
| self.norm1 = nnx.LayerNorm( | ||
| hidden_size, | ||
| epsilon=layernorm_epsilon, | ||
|
|
@@ -102,6 +115,7 @@ def __init__( | |
| "qkv_out", | ||
| ), | ||
| ), | ||
| attention_fn=attention_fn, | ||
| ) | ||
| self.norm2 = nnx.LayerNorm( | ||
| hidden_size, | ||
|
|
@@ -198,6 +212,7 @@ def __init__( | |
| attn_mask: Float[Array, "seq seq"] | None = None, | ||
| use_quick_gelu: bool = False, | ||
| use_gradient_checkpointing: bool = False, | ||
| splash_attention_config: SplashAttentionConfig | None = None, | ||
| rngs: rnglib.Rngs | None = None, | ||
| dtype: DTypeLike = jnp.float32, | ||
| param_dtype: DTypeLike = jnp.float32, | ||
|
|
@@ -216,6 +231,7 @@ def __init__( | |
| attn_mask (Float[Array, "seq seq"] | None, optional): Optional attention mask. Defaults to None. | ||
| use_quick_gelu (bool, optional): Whether to use quickgelu instead of gelu. Defaults to False. | ||
| use_gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False. | ||
| splash_attention_config (SplashAttentionConfig | None, optional): Configuration for TPU splash attention. Defaults to None. | ||
| rngs (rnglib.Rngs | None, optional): Random number generator keys. If None, initializes to nnx.Rngs(0). | ||
| dtype (DTypeLike, optional): The data type for computations. Defaults to jnp.float32. | ||
| param_dtype (DTypeLike, optional): The data type for parameters. Defaults to jnp.float32. | ||
|
|
@@ -240,6 +256,7 @@ def __init__( | |
| attn_mask=attn_mask, | ||
| use_quick_gelu=use_quick_gelu, | ||
| use_gradient_checkpointing=use_gradient_checkpointing, | ||
| splash_attention_config=splash_attention_config, | ||
| dtype=dtype, | ||
| param_dtype=param_dtype, | ||
| rngs=rngs, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type hint and docstring for
create_splash_attention_fnshould be updated to reflect that it can returnNonewhen splash attention is disabled. This makes the function signature more accurate and consistent with the proposed implementation change.