Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,26 @@ or if you prefer to not add as a direct dependency:
`uv pip install git+https://github.com/pythoncrazy/jimm.git`
### Using pip/conda
`pip install git+https://github.com/pythoncrazy/jimm.git`

## TPU Splash Attention (Experimental)

JIMM supports TPU-optimized splash attention via the tokamax library.

### Usage
```python
from flax import nnx
from jimm import CLIP, SplashAttentionConfig

splash_config = SplashAttentionConfig(
enabled=True,
mask_type="full", # "full" for vision, "causal" for text
)

model = CLIP.from_pretrained(
"openai/clip-vit-large-patch14",
splash_attention_config=splash_config,
rngs=nnx.Rngs(0),
)
```

**Note**: When `enabled=True` and `tokamax` is installed, splash attention will be used regardless of device type. There is currently no automatic fallback based on hardware; ensure you only enable splash attention on supported devices (e.g., TPU) to avoid runtime errors or performance issues.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ name = "jimm"
version = "0.1.0"
description = "Jax Image Modeling of Models"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
dependencies = [
"flax>=0.10.6",
"jax>=0.6.2",
"jaxtyping>=0.3.2",
"safetensors>=0.5.3",
"tokamax>=0.0.9",
]

[[tool.uv.index]]
Expand Down
2 changes: 2 additions & 0 deletions src/jimm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .common.splash_attention import SplashAttentionConfig
from .models import (
CLIP,
CLIPTextModel,
Expand All @@ -16,4 +17,5 @@
"SigLIP",
"SigLIPTextModel",
"SigLIPVisionModel",
"SplashAttentionConfig",
]
122 changes: 122 additions & 0 deletions src/jimm/common/splash_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Splash Attention integration for TPU-optimized attention."""

import importlib.util
from dataclasses import dataclass
from typing import Callable, Literal

import jax
from jaxtyping import Array, Float

_TOKAMAX_AVAILABLE = importlib.util.find_spec("tokamax") is not None

if _TOKAMAX_AVAILABLE:
from tokamax._src.ops.experimental.tpu.splash_attention import (
splash_attention_kernel as splash,
)
from tokamax._src.ops.experimental.tpu.splash_attention import (
splash_attention_mask as mask_lib,
)


@dataclass
class SplashAttentionConfig:
    """User-facing knobs for the TPU splash-attention integration.

    An instance of this config is threaded from the model constructors down to
    `create_splash_attention_fn`, which decides whether to build a tokamax
    splash kernel or fall back to the default dot-product attention.

    Attributes:
        enabled (bool): Opt-in switch; when False the default attention is used.
        mask_type (Literal["full", "causal"]): Attention mask shape —
            "full" (bidirectional, e.g. vision) or "causal" (autoregressive text).
        block_q (int): Tile size along the query sequence dimension.
        block_kv (int): Tile size along the key/value sequence dimension.
    """

    enabled: bool = False
    mask_type: Literal["full", "causal"] = "full"
    block_q: int = 128
    block_kv: int = 128


_kernel_cache: dict[tuple[int, int, int, str, int, int], Callable] = {}


def _create_splash_kernel(
    seq_len: int,
    num_heads: int,
    head_dim: int,
    config: SplashAttentionConfig,
) -> Callable[
    [Float[Array, "heads seq head_dim"], Float[Array, "heads seq head_dim"], Float[Array, "heads seq head_dim"]],
    Float[Array, "heads seq head_dim"],
]:
    """Build (or fetch from cache) a single-device splash attention kernel.

    The kernel operates on unbatched (heads, seq, head_dim) tensors; callers
    are expected to `jax.vmap` it over the batch dimension.

    Args:
        seq_len (int): Sequence length.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.
        config (SplashAttentionConfig): Splash attention configuration.

    Returns:
        Callable: A splash attention kernel function.
    """
    cache_key = (seq_len, num_heads, head_dim, config.mask_type, config.block_q, config.block_kv)
    cached = _kernel_cache.get(cache_key)
    if cached is not None:
        return cached

    # Square self-attention mask over the (query, key) sequence positions.
    if config.mask_type == "causal":
        mask = mask_lib.CausalMask((seq_len, seq_len))
    else:
        mask = mask_lib.FullMask((seq_len, seq_len))

    # One block geometry drives both the forward pass and the dkv backward
    # pass; the compute block sizes mirror the tiling block sizes.
    splash_config = splash.SplashConfig(
        block_q=config.block_q,
        block_kv=config.block_kv,
        block_kv_compute=config.block_kv,
        block_q_dkv=config.block_q,
        block_kv_dkv=config.block_kv,
        block_kv_dkv_compute=config.block_kv,
    )

    kernel = splash.make_splash_mha_single_device(mask=mask, config=splash_config)
    _kernel_cache[cache_key] = kernel
    return kernel


def create_splash_attention_fn(
    config: SplashAttentionConfig,
    num_heads: int,
    head_dim: int,
) -> Callable[..., Float[Array, "batch heads seq head_dim"]]:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable: An attention function. Returns splash attention if enabled and available,
            otherwise returns the default dot_product_attention.
    """
    if not (_TOKAMAX_AVAILABLE and config.enabled):
        # Fall back to flax's stock attention so call sites can pass the
        # result straight to nnx.MultiHeadAttention unconditionally.
        # Imported lazily to keep module import free of a hard flax dependency
        # on this path.
        from flax.nnx.nn.attention import dot_product_attention

        return dot_product_attention

    def splash_attention_fn(
        query: Float[Array, "batch heads seq head_dim"],
        key: Float[Array, "batch heads seq head_dim"],
        value: Float[Array, "batch heads seq head_dim"],
    ) -> Float[Array, "batch heads seq head_dim"]:
        """Apply the splash attention kernel, vmapped over the batch axis.

        The kernel is resolved lazily from the sequence length of the query so
        one wrapper serves inputs of any (self-attention) sequence length;
        resolution hits the module-level kernel cache after the first call
        per geometry.

        Args:
            query (Float[Array, "batch heads seq head_dim"]): Query tensor.
            key (Float[Array, "batch heads seq head_dim"]): Key tensor.
            value (Float[Array, "batch heads seq head_dim"]): Value tensor.

        Returns:
            Float[Array, "batch heads seq head_dim"]: Output tensor.
        """
        kernel = _create_splash_kernel(query.shape[2], num_heads, head_dim, config)
        return jax.vmap(kernel)(query, key, value)

    return splash_attention_fn
17 changes: 17 additions & 0 deletions src/jimm/common/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from jax.typing import DTypeLike
from jaxtyping import Array, Float

from jimm.common.splash_attention import SplashAttentionConfig, create_splash_attention_fn
from jimm.common.utils import DEFAULT_SHARDING, MeshRules


Expand Down Expand Up @@ -39,6 +40,7 @@ def __init__(
attn_mask: Float[Array, "seq seq"] | None = None,
use_quick_gelu: bool = False,
use_gradient_checkpointing: bool = False,
splash_attention_config: SplashAttentionConfig | None = None,
rngs: rnglib.Rngs | None = None,
dtype: DTypeLike = jnp.float32,
param_dtype: DTypeLike = jnp.float32,
Expand All @@ -56,6 +58,7 @@ def __init__(
attn_mask (Float[Array, "seq seq"] | None, optional): Optional attention mask. Defaults to None.
use_quick_gelu (bool, optional): Whether to use quickgelu instead of gelu. Defaults to False.
use_gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
splash_attention_config (SplashAttentionConfig | None, optional): Configuration for TPU splash attention. Defaults to None.
rngs (rnglib.Rngs | None, optional): Random number generator keys. If None, initializes to nnx.Rngs(0).
dtype (DTypeLike, optional): Data type for computations. Defaults to jnp.float32.
param_dtype (DTypeLike, optional): Data type for parameters. Defaults to jnp.float32.
Expand All @@ -66,6 +69,16 @@ def __init__(
rngs = nnx.Rngs(0)
self.attn_mask = attn_mask
self.use_gradient_checkpointing = use_gradient_checkpointing

attention_fn = (
create_splash_attention_fn(
splash_attention_config,
num_heads=num_heads,
head_dim=hidden_size // num_heads,
)
if splash_attention_config is not None
else nnx.dot_product_attention
)
Comment on lines +73 to +81
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for determining attention_fn can be simplified and made more consistent with the pattern used in src/jimm/common/vit.py. By initializing attention_fn to None and only creating the splash function if splash_attention_config is provided, the code becomes more explicit and readable. nnx.MultiHeadAttention will correctly use its default when attention_fn is None.

This change depends on create_splash_attention_fn returning None when disabled, as suggested in another comment.

        attention_fn = None
        if splash_attention_config is not None:
            attention_fn = create_splash_attention_fn(
                splash_attention_config,
                num_heads=num_heads,
                head_dim=hidden_size // num_heads,
            )

self.norm1 = nnx.LayerNorm(
hidden_size,
epsilon=layernorm_epsilon,
Expand Down Expand Up @@ -102,6 +115,7 @@ def __init__(
"qkv_out",
),
),
attention_fn=attention_fn,
)
self.norm2 = nnx.LayerNorm(
hidden_size,
Expand Down Expand Up @@ -198,6 +212,7 @@ def __init__(
attn_mask: Float[Array, "seq seq"] | None = None,
use_quick_gelu: bool = False,
use_gradient_checkpointing: bool = False,
splash_attention_config: SplashAttentionConfig | None = None,
rngs: rnglib.Rngs | None = None,
dtype: DTypeLike = jnp.float32,
param_dtype: DTypeLike = jnp.float32,
Expand All @@ -216,6 +231,7 @@ def __init__(
attn_mask (Float[Array, "seq seq"] | None, optional): Optional attention mask. Defaults to None.
use_quick_gelu (bool, optional): Whether to use quickgelu instead of gelu. Defaults to False.
use_gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
splash_attention_config (SplashAttentionConfig | None, optional): Configuration for TPU splash attention. Defaults to None.
rngs (rnglib.Rngs | None, optional): Random number generator keys. If None, initializes to nnx.Rngs(0).
dtype (DTypeLike, optional): The data type for computations. Defaults to jnp.float32.
param_dtype (DTypeLike, optional): The data type for parameters. Defaults to jnp.float32.
Expand All @@ -240,6 +256,7 @@ def __init__(
attn_mask=attn_mask,
use_quick_gelu=use_quick_gelu,
use_gradient_checkpointing=use_gradient_checkpointing,
splash_attention_config=splash_attention_config,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
Expand Down
22 changes: 20 additions & 2 deletions src/jimm/common/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from jax.sharding import Mesh
from jaxtyping import Array, DTypeLike, Float

from jimm.common.splash_attention import SplashAttentionConfig, create_splash_attention_fn
from jimm.common.transformer import Transformer
from jimm.common.utils import DEFAULT_SHARDING, MeshRules

Expand All @@ -17,6 +18,7 @@ def __init__(
intermediate_size: int,
num_heads: int,
layernorm_epsilon: float = 1e-6,
splash_attention_config: SplashAttentionConfig | None = None,
rngs: rnglib.Rngs | None = None,
dtype: DTypeLike = jnp.float32,
param_dtype: DTypeLike = jnp.float32,
Expand All @@ -30,6 +32,7 @@ def __init__(
intermediate_size (int): The dimension of the intermediate MLP at the end of the MAP head.
num_heads (int): The number of attention heads.
layernorm_epsilon (float, optional): The epsilon used in the layernorm. Defaults to 1e-6.
splash_attention_config (SplashAttentionConfig | None, optional): Configuration for TPU splash attention. Defaults to None.
rngs (rnglib.Rngs | None, optional): The flax nnx rng to use for initialization. If None, initializes to nnx.Rngs(0).
dtype (DTypeLike, optional): The data type for computations. Defaults to jnp.float32.
param_dtype (DTypeLike, optional): The data type for parameters. Defaults to jnp.float32.
Expand All @@ -41,9 +44,19 @@ def __init__(
probe_value: Float[Array, "1 1 hidden_size"] = nnx.initializers.zeros_init()(rngs.params(), (1, 1, hidden_size))
self.probe = nnx.Param(probe_value, sharding_names=mesh_rules("probe_token_batch", "probe_token_seq", "probe_token_hidden"))

attention_fn = (
create_splash_attention_fn(
splash_attention_config,
num_heads=num_heads,
head_dim=hidden_size // num_heads,
)
if splash_attention_config is not None
else nnx.dot_product_attention
)

self.attn = nnx.MultiHeadAttention(
num_heads,
hidden_size,
num_heads=num_heads,
in_features=hidden_size,
broadcast_dropout=False,
decode=False,
deterministic=False,
Expand All @@ -57,6 +70,7 @@ def __init__(
"map_attn_out",
),
),
attention_fn=attention_fn,
)

self.layernorm = nnx.LayerNorm(
Expand Down Expand Up @@ -147,6 +161,7 @@ def __init__(
use_patch_bias: bool = True,
use_gradient_checkpointing: bool = False,
layernorm_epsilon: float = 1e-5,
splash_attention_config: SplashAttentionConfig | None = None,
rngs: rnglib.Rngs | None = None,
dtype: DTypeLike = jnp.float32,
param_dtype: DTypeLike = jnp.float32,
Expand All @@ -171,6 +186,7 @@ def __init__(
use_patch_bias (bool, optional): Whether to use bias in the patch embedding convolution. Defaults to True.
use_gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
layernorm_epsilon (float, optional): Epsilon for LayerNorm. Defaults to 1e-5.
splash_attention_config (SplashAttentionConfig | None, optional): Configuration for TPU splash attention. Defaults to None.
rngs (rnglib.Rngs | None, optional): The random number generator state. If None, initializes to nnx.Rngs(0).
dtype (DTypeLike, optional): The data type for computations. Defaults to jnp.float32.
param_dtype (DTypeLike, optional): The data type for parameters. Defaults to jnp.float32.
Expand Down Expand Up @@ -212,6 +228,7 @@ def __init__(
intermediate_size=4 * hidden_size,
num_heads=num_heads,
layernorm_epsilon=layernorm_epsilon,
splash_attention_config=splash_attention_config,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
Expand Down Expand Up @@ -254,6 +271,7 @@ def __init__(
dropout_rate=dropout_rate,
use_quick_gelu=use_quick_gelu,
use_gradient_checkpointing=use_gradient_checkpointing,
splash_attention_config=splash_attention_config,
rngs=rngs,
dtype=dtype,
param_dtype=param_dtype,
Expand Down
Loading
Loading