<a href="https://colab.research.google.com/github/p-nordmann/pallas-puzzles-2025/blob/main/notebooks/Pallas-Puzzles.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Let's solve some puzzles!

Programming for accelerators such as GPUs is critical for modern AI systems. Pallas is a JAX-integrated, [Triton](https://github.com/openai/triton/)-like language. Triton is an alternative open-source language that allows you to code at a higher level and compile to accelerators such as GPUs.

<p align="left"><img alt="illustration" src="../assets/triton_load.png" width="50%"/></p>

Coding for Pallas is very similar to NumPy and JAX in both syntax and semantics. However, as a lower-level language, there are a lot of details that you need to keep track of. In particular, one area that learners have trouble with is memory loading and storage, which is critical for speed on low-level devices.

This set of puzzles is meant to teach you how to use Pallas from first principles in an interactive fashion. You will start with trivial examples and build your way up to real algorithms like FlashAttention and quantized neural networks. These puzzles **do not** need to run on GPU since Pallas has an interpreter, but be careful: the interpreter might produce somewhat different results from the GPU with masked loads and stores. This will hopefully be fixed soon.


In [None]:
%%capture

# Only need to run the first time.
!pip install jaxtyping
!pip install "jax[cuda12]"

In [None]:
import jax
from jax import Array
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jaxtyping import Float32, Int32

from utils import test

In [None]:
import os

# We set an INTERPRET environment variable.
# If you have a GPU available and want to run the puzzles on it, set it to False.
interpret = True

os.environ["INTERPRET"] = str(interpret)

## Introduction

To begin with, we will only use `pl.load` and `pl.store` in order to build simple programs.

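Here's an example of `pl.load`. It takes a reference and an index tuple with one `pl.dslice` per dimension, following the usual JAX axis order (rows, then columns, and so on, with the right-most axis varying fastest). It also takes an optional mask, typically built with `jnp.arange`, and an `other` value that fills the masked-out positions. The load/store mask is critically important because all shapes in Triton need to be powers of two, so a block often extends past the end of the data.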

We'll start with a program that both reads from (`pl.load`) and writes to (`pl.store`) memory. The reason we cannot just read from memory is that JAX will skip calling a function that does not have a side effect, and a kernel's only observable effect is what it writes to its output references. This is by design.
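To build intuition for `mask` and `other`, here is a host-side sketch in plain NumPy of what the first demo below does. It is only an emulation under stated assumptions: `np.where` stands in for the masked `pl.load`/`pl.store`, and a NaN-filled buffer stands in for output memory that the kernel never writes.

```python
import numpy as np

# Emulate `pl.load(x_ref, idx, mask=x_mask, other=0)`:
# positions where the mask is False receive `other` instead of memory contents.
x = np.arange(8, dtype=np.float32)
x_mask = np.arange(x.size) < 5  # read only the first 5 elements
loaded = np.where(x_mask, x, 0.0)  # `other=0` fills the masked-out slots
print(loaded)  # [0. 1. 2. 3. 4. 0. 0. 0.]

# Emulate `pl.store(y_ref, idx, x + 1, mask=y_mask)`:
# masked-out positions keep whatever the output buffer already held.
y = np.full(8, np.nan, dtype=np.float32)  # assumption: NaN marks "never written"
y_mask = np.arange(y.size) < 6  # write only the first 6 elements
y = np.where(y_mask, loaded + 1, y)
```

Note how the 6th stored value is `0 + 1`: it comes from a slot the load masked out and filled with `other`.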


In [None]:
def pallas_demo(x_ref, y_ref):
    # a mask is used to not read/write data out of bounds:
    # let's read only 5 elements, the rest will be filled with `other=0`
    x_mask = jnp.arange(x_ref.size) < 5
    # one slice per dimension;
    # `pl.dslice(None)` is equivalent to Python's `:`
    x_index = (pl.dslice(0, x_ref.shape[0]),)
    x = pl.load(x_ref, x_index, mask=x_mask, other=0)

    # when executing in a debugging mode (`interpret=True`) you can use
    # `jax.debug.print` to print values.
    # jax.debug.print("x = {}", x)

    # pl.store takes (reference/pointer, slice dimensions, data)
    y_mask = jnp.arange(y_ref.size) < 6  # let's write only 6 elements

    # we can also specify a slice as `pl.dslice(start, size)`;
    # however, note that size has to be available at compile time, start does not
    # `pl.dslice(0, 6)` is equivalent to Python's `y_ref[0:6]`
    y_index = (pl.dslice(0, y_ref.shape[0]),)
    pl.store(y_ref, y_index, (x + 1), mask=y_mask)


# example data
x = jnp.arange(8).astype(jnp.float32)

# input spec, how to split input on the grid
# BlockSpec takes two args:
# - the block shape: size of the data reference passed to each program (this is copy-free)
# - an index map: a lambda from program ids to the block index along each dimension
in_spec = pl.BlockSpec(x.shape, lambda i: (0,))

y_shape = x.shape
out_spec = pl.BlockSpec(y_shape, lambda i: (0,))

out = pl.pallas_call(
    pallas_demo,
    grid=(1,),
    in_specs=[in_spec],  # input spec must be a list/tuple
    out_specs=out_spec,  # output spec for Pallas, how to split output
    out_shape=jax.ShapeDtypeStruct(y_shape, jnp.float32),  # output shape for JIT
    interpret=True,  # execute on CPU, debug statements do not work on GPU
)(x)
print(f"out = {out}")

The same approach works for reading a 2D array.


In [None]:
def pallas_demo(x_ref, y_ref):
    # index for 2 axis/dims `[:, :]`
    x_index = (pl.dslice(None), pl.dslice(None))
    x = pl.load(x_ref, x_index)
    jax.debug.print("x =\n{}", x)

    y_index = (pl.dslice(None), pl.dslice(None))
    # two dimensional mask using numpy-like broadcasting
    # with `&` (logical and) operator
    y_mask = (jnp.arange(y_ref.shape[0]) < 2)[:, None]
    y_mask &= (jnp.arange(y_ref.shape[1]) < 3)[None, :]
    pl.store(y_ref, y_index, (x + 1) ** 2, mask=y_mask)


# example data
x = jnp.arange(4**2).reshape((4, 4)).astype(jnp.float32)

out = pl.pallas_call(
    pallas_demo,
    grid=(1,),
    in_specs=[pl.BlockSpec(x.shape, lambda i: (0, 0))],
    out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)),
    out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
    interpret=True,
)(x)
print(f"out =\n{out}")

You can only load relatively small blocks at a time in Pallas. To work with larger tensors, you need to use a program id axis to run multiple blocks in parallel. Here is an example with one program axis with 2 blocks.


In [None]:
def pallas_demo(x_ref, y_ref):
    pid = pl.program_id(0)

    # index for 2 axis/dims `[2*i:2*i+2, 2*i:2*i+2]`
    x_index = (pl.dslice(2 * pid, 2), pl.dslice(2 * pid, 2))
    x = pl.load(x_ref, x_index)
    jax.debug.print("x =\n{}", x)

    # index for 2 axis/dims `[2*i:2*i+2, 2*i:2*i+2]`
    y_index = (pl.dslice(2 * pid, 2), pl.dslice(2 * pid, 2))
    pl.store(y_ref, y_index, (x + 1) ** 2)


# example data
x = jnp.arange(4**2).reshape((4, 4)).astype(jnp.float32)

# this kernel call is not great, we're only writing to block diagonal elements
# of y; the rest is uninitialized memory
out = pl.pallas_call(
    pallas_demo,
    grid=(2,),  # NOTE: we changed the grid size
    in_specs=[pl.BlockSpec(x.shape, lambda i: (0, 0))],
    out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)),
    out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
    interpret=True,
)(x)
print(f"out =\n{out}")
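Because the two programs touch disjoint blocks, the kernel above can be emulated with a sequential loop over program ids. A rough NumPy sketch, assuming unwritten entries show up as 0 (in the real call they are uninitialized memory):

```python
import numpy as np

x = np.arange(4**2, dtype=np.float32).reshape(4, 4)
y = np.zeros_like(x)  # assumption: 0 stands in for uninitialized output memory

# grid=(2,) runs the kernel body once per program id. Conceptually the programs
# may run in parallel, but since they write disjoint blocks a loop gives the same result.
for pid in range(2):
    block = slice(2 * pid, 2 * pid + 2)  # same range as pl.dslice(2 * pid, 2)
    y[block, block] = (x[block, block] + 1) ** 2
print(y)  # only the two 2x2 diagonal blocks are written
```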

See the [Pallas Docs](https://jax.readthedocs.io/en/latest/pallas/index.html) for further information.


## Prerequisites: grid and BlockSpec

**TODO**
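As a small sketch of the idea: a `BlockSpec` tells each program which block of an array it sees. This assumes the `pl.BlockSpec(block_shape, index_map)` argument order used in the demos above, with `index_map` returning a block index per dimension that is scaled by the block size to get element offsets. `block_slices` is a hypothetical helper written for illustration, not a Pallas API.

```python
def block_slices(block_shape, index_map, *program_ids):
    """Hypothetical helper (not a Pallas API): element ranges one program sees.

    Assumes blocked indexing: `index_map` returns one *block* index per
    dimension, and block index b covers elements [b * size, (b + 1) * size).
    """
    block_idx = index_map(*program_ids)
    return tuple(
        slice(b * size, (b + 1) * size) for b, size in zip(block_idx, block_shape)
    )


# The demos above use one block covering the whole array, so index_map returns (0,):
print(block_slices((8,), lambda i: (0,), 0))  # (slice(0, 8, None),)

# Tiling a length-8 vector into blocks of 4 with grid=(2,):
print([block_slices((4,), lambda i: (i,), i) for i in range(2)])
# [(slice(0, 4, None),), (slice(4, 8, None),)]

# A 2D grid over a (4, 4) array with (2, 2) blocks, program (1, 0):
print(block_slices((2, 2), lambda i, j: (i, j), 1, 0))
# (slice(2, 4, None), slice(0, 2, None))
```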


## Puzzle 1: Constant Add

Add a constant to a vector. Uses one program id axis. Block size `B0` is always the same as the length `N0` of vector `x`.

$$z_i = 10 + x_i \text{ for } i = 1\ldots N_0$$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/1_constant_add.png" /></p>


In [None]:
N0 = 32


def add_spec(x: Float32[Array, f"{N0}"]) -> Float32[Array, f"{N0}"]:
    "This is the spec that you should implement. Uses typing to define sizes."
    return x + 10.0


def add_kernel(x_ref, z_ref, B0: int):
    pass
    # Finish me!


test(add_kernel, add_spec, nelem={"N0": N0})

## Puzzle 2: Constant Add Block

Add a constant to a vector. Uses one program block axis (no `for` loops yet). Block size `B0` is now smaller than the length `N0` of vector `x`.

$$z_i = 10 + x_i \text{ for } i = 1\ldots N_0$$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/2_constant_add_block.png" /></p>
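When `B0` does not divide `N0`, the grid needs `cdiv(N0, B0)` programs and the last block is only partly valid. A plain-Python sketch of the offsets and mask each program computes, assuming `B0 = 32` purely for illustration (the block size used by `test` may differ):

```python
N0, B0 = 100, 32  # B0 chosen only for this illustration


def cdiv(a, b):
    # Rounding-up integer division, like `pl.cdiv`
    return (a + b - 1) // b


num_programs = cdiv(N0, B0)  # enough programs to cover all N0 elements
valid = []
for pid in range(num_programs):
    offsets = [pid * B0 + t for t in range(B0)]  # like pid * B0 + jnp.arange(B0)
    mask = [o < N0 for o in offsets]  # False past the end of x
    valid.append(sum(mask))
print(num_programs, valid)  # 4 [32, 32, 32, 4]
```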


In [None]:
N0 = 100


def add2_spec(x: Float32[Array, f"{N0}"]) -> Float32[Array, f"{N0}"]:
    return x + 10.0


def add_mask2_kernel(x_ref, z_ref, B0: int):
    pid = pl.program_id(0)
    pass
    # finish me!


test(add_mask2_kernel, add2_spec, nelem={"N0": N0})

## Puzzle 3: Outer Vector Add

Add two vectors.

Uses one program block axis. Block size `B0` is always the same as vector `x` length `N0`.
Block size `B1` is always the same as vector `y` length `N1`.

$$z_{j, i} = x_i + y_j\text{ for } i = 1\ldots B_0,\ j = 1\ldots B_1$$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/3_outer_vector_add.png" /></p>
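The spec's `x[None, :] + y[:, None]` relies on broadcasting a `(1, N0)` row against an `(N1, 1)` column, which is why the output is indexed as `z[j, i]`. A tiny NumPy check with made-up sizes:

```python
import numpy as np

N0, N1 = 3, 2
x = np.arange(N0, dtype=np.float32)  # [0, 1, 2]
y = np.arange(N1, dtype=np.float32) * 10  # [0, 10]
z = x[None, :] + y[:, None]  # (1, N0) + (N1, 1) -> (N1, N0)
print(z.shape)  # (2, 3)
print(z)  # row j holds x + y[j]
```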


In [None]:
N0, N1 = 32, 32


def add_vec_spec(
    x: Float32[Array, f"{N0}"], y: Float32[Array, f"{N1}"]
) -> Float32[Array, f"{N1} {N0}"]:
    return x[None, :] + y[:, None]


def add_vec_kernel(x_ref, y_ref, z_ref, B0: int, B1: int):
    pass
    # finish me!


test(add_vec_kernel, add_vec_spec, nelem={"N0": N0, "N1": N1})

## Puzzle 4: Outer Vector Add Block

Add a row vector to a column vector.

Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.

$$z_{j, i} = x_i + y_j\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/4_outer_vector_add_block.png" /></p>


In [None]:
N0, N1 = 100, 90


def add_vec_block_spec(
    x: Float32[Array, f"{N0}"], y: Float32[Array, f"{N1}"]
) -> Float32[Array, f"{N1} {N0}"]:
    return x[None, :] + y[:, None]


def add_vec_block_kernel(x_ref, y_ref, z_ref, B0: int, B1: int):
    pid_i, pid_j = pl.program_id(0), pl.program_id(1)
    pass
    # finish me!


test(add_vec_block_kernel, add_vec_block_spec, nelem={"N0": N0, "N1": N1})

## Puzzle 5: Fused Outer Multiplication

Multiply a row vector by a column vector and take a relu.

Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.

$$z_{j, i} = \text{relu}(x_i \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/5_fused_outer_multiplication.png" /></p>


In [None]:
N0, N1 = 100, 90


def mul_relu_block_spec(
    x: Float32[Array, f"{N0}"], y: Float32[Array, f"{N1}"]
) -> Float32[Array, f"{N1} {N0}"]:
    return jax.nn.relu(x[None, :] * y[:, None])


def mul_relu_block_kernel(x_ref, y_ref, z_ref, B0: int, B1: int):
    pid_i, pid_j = pl.program_id(0), pl.program_id(1)
    pass
    # finish me!


test(mul_relu_block_kernel, mul_relu_block_spec, nelem={"N0": N0, "N1": N1})

## Puzzle 6: Fused Outer Multiplication - Backwards

Backward pass of a function that multiplies a matrix `x` by a column vector `y` and takes a relu.

Uses two program block axes. Block size `B0` is always less than the row length `N0` of `x`.
Block size `B1` is always less than vector `y` length `N1`. The chain-rule upstream gradient `dz`
is of shape `N1` by `N0`.

$$f(x, y)_{j, i} = \text{relu}(x_{j, i} \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$

$$dx_{j, i} = f_x'(x, y)_{j, i} \times dz_{j, i}$$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/6_fused_outer_multiplication_backwards.png" /></p>
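Working out $f_x'$ by hand: away from the kink of relu, $\partial f_{j,i} / \partial x_{j,i} = y_j$ when $x_{j,i} y_j > 0$ and $0$ otherwise. A NumPy sketch that checks this closed form against central finite differences on random data (sizes and seed are arbitrary choices for the illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
N0, N1 = 5, 4
x = rng.normal(size=(N1, N0))
y = rng.normal(size=(N1,))
dz = rng.normal(size=(N1, N0))


def loss(x):
    # Scalar whose gradient w.r.t. x is what the spec computes with jax.grad
    return np.sum(np.maximum(x * y[:, None], 0.0) * dz)


# Closed form: relu'(x_{j,i} y_j) * y_j * dz_{j,i}
dx = (x * y[:, None] > 0) * y[:, None] * dz

# Central finite differences as an independent check
eps = 1e-6
dx_fd = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    e = np.zeros_like(x)
    e[idx] = eps
    dx_fd[idx] = (loss(x + e) - loss(x - e)) / (2 * eps)
print(np.max(np.abs(dx - dx_fd)))  # tiny: the two agree
```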


In [None]:
N0, N1 = 100, 90


def mul_relu_block_back_spec(
    x: Float32[Array, f"{N1} {N0}"],
    y: Float32[Array, f"{N1}"],
    dz: Float32[Array, f"{N1} {N0}"],
) -> Float32[Array, f"{N1} {N0}"]:
    return jax.grad(lambda x, y: jnp.sum(jax.nn.relu(x * y[:, None]) * dz))(x, y)


def mul_relu_block_back_kernel(x_ref, y_ref, dz_ref, dx_ref, B0: int, B1: int):
    pid_i, pid_j = pl.program_id(0), pl.program_id(1)
    pass
    # finish me!


test(mul_relu_block_back_kernel, mul_relu_block_back_spec, nelem={"N0": N0, "N1": N1})

## Puzzle 7: Long Sum

Sum of a batch of numbers.

Uses one program block axis. Block size `B0` represents a range of batches of `x`, out of `N0`.
Each batch element is a row of length `T`. Process it `B1 < T` elements at a time.

$$z_{i} = \sum_{j=1}^{T} x_{i,j} \text{ for } i = 1\ldots N_0$$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/7_long_sum.png" /></p>

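Hint: you will need a `for` loop for this problem. These work and look the same as in Python (a plain Python `for` loop is unrolled when the kernel is traced; `jax.lax.fori_loop` is the non-unrolled alternative).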
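The looping pattern can be sketched on the host with NumPy: walk each row in chunks of `B1`, mask the final partial chunk, and accumulate. This illustrates the idea with made-up data and is not a Pallas kernel; the index clamp exists only because NumPy, unlike a masked load, raises on out-of-bounds indices.

```python
import numpy as np

T, B1 = 200, 32
x = np.arange(2 * T, dtype=np.float32).reshape(2, T)  # two rows of made-up data

acc = np.zeros(x.shape[0], dtype=np.float32)
for start in range(0, T, B1):  # 7 chunks: six full ones and a final one with 8 valid elements
    offsets = start + np.arange(B1)
    mask = offsets < T  # False past the end of the row
    # clamp only so NumPy does not raise; masked-out values become 0 (`other=0`)
    chunk = np.where(mask, x[:, np.minimum(offsets, T - 1)], 0.0)
    acc += chunk.sum(axis=-1)
print(acc)  # [19900. 59900.]
```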


In [None]:
N0, N1 = 4, 200


def sum_spec(x: Float32[Array, f"{N0} {N1}"]) -> Float32[Array, f"{N0}"]:
    return jnp.sum(x, -1)


def sum_kernel(x_ref, z_ref, B0: int, B1: int):
    pid_i = pl.program_id(0)
    T = x_ref.shape[-1]
    pass
    # finish me!


test(sum_kernel, sum_spec, B={"B0": 1, "B1": 32}, nelem={"N0": N0, "N1": N1, "T": 200})

## Puzzle 8: Long Softmax

Softmax of a batch of logits.

Uses one program block axis. Block size `B0` represents the batch of `x` of length `N0`.
Each row of logits has length `T`. Process it `B1 < T` elements at a time.

$$z_{i, j} = \text{softmax}(x_{i,1} \ldots x_{i, T}) \text{ for } i = 1\ldots N_0$$

Note that softmax needs to be computed in numerically stable form, as in the Python spec (subtracting the row maximum before exponentiating). In addition, for Triton it is recommended to use `exp2` instead of `exp`. You need the identity

$$\exp(x) = 2^{\log_2(e) x}$$

Advanced: there is one way to do this with 3 loops. You can also do it with 2 loops if you are clever. Hint: you will find this identity useful:

$$\exp(x_i - m) =  \exp(x_i - m/2 - m/2) = \exp(x_i - m/ 2) /  \exp(m/2) $$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/8_long_softmax.png" /></p>
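One common way to reach 2 loops is the "online softmax" rescaling: keep a running max `m` and a running sum `s` per row, and whenever the max grows, rescale `s` using $\exp(x_i - m_{new}) = \exp(x_i - m_{old}) \exp(m_{old} - m_{new})$, an identity in the same spirit as the hint. A NumPy sketch of that technique (slicing stands in for masked loads, and `exp` is built from `exp2` as above):

```python
import numpy as np

LOG2_E = 1.44269504


def exp(v):
    # exp(v) = 2 ** (log2(e) * v)
    return np.exp2(v * LOG2_E)


def online_softmax(x, B1):
    """Softmax over the last axis using 2 passes over chunks of B1 elements."""
    T = x.shape[-1]
    m = np.full(x.shape[:-1], -np.inf)  # running max per row
    s = np.zeros(x.shape[:-1])  # running sum of exp(x - m) per row
    # Pass 1: max and normalizer together
    for start in range(0, T, B1):
        chunk = x[..., start : start + B1]
        m_new = np.maximum(m, chunk.max(axis=-1))
        # move the old sum from the old max to the new one, then add this chunk
        s = s * exp(m - m_new) + exp(chunk - m_new[..., None]).sum(axis=-1)
        m = m_new
    # Pass 2: normalize each chunk
    out = np.empty_like(x)
    for start in range(0, T, B1):
        chunk = x[..., start : start + B1]
        out[..., start : start + B1] = exp(chunk - m[..., None]) / s[..., None]
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 200)) * 10
z = online_softmax(x, B1=32)
```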


In [None]:
N0, N1 = 4, 200


def softmax_spec(x: Float32[Array, f"{N0} {N1}"]) -> Float32[Array, f"{N0} {N1}"]:
    x_max = jnp.max(x, -1)[..., None]
    x = x - x_max
    x_exp = jnp.exp(x)
    return x_exp / jnp.sum(x_exp, -1)[..., None]


def softmax_kernel(x_ref, z_ref, B0: int, B1: int) -> None:
    pid_i = pl.program_id(0)
    T = x_ref.shape[-1]
    log2_e = 1.44269504
    exp = lambda x: jnp.exp2(x * log2_e)
    pass
    # finish me!


test(
    softmax_kernel,
    softmax_spec,
    B={"B0": 1, "B1": 32},
    nelem={"N0": N0, "N1": N1, "T": 200},
)

## Puzzle 9: Simple FlashAttention

A scalar version of FlashAttention.

Uses one program id axis. Block size `B0` represents a block of `q`, out of length `N0`.
Vectors `k` and `v` both have the sequence length `T`.
Sequence length is `T`. Process it `B1 < T` elements at a time.

$$z_{i} = \sum_{j} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_{j} \text{ for } i = 1\ldots N_0$$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/9_simple_flash_attention.png" /></p>

This can be done in 1 loop using a trick similar to the one from the last puzzle.
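Applying running-max rescaling to a weighted sum gives a single pass over the sequence: the numerator $\sum_j \exp(q_i k_j - m) v_j$ and the denominator are rescaled together whenever the max changes. A NumPy sketch of that idea (the chunk size `B1` is an arbitrary choice here, and `np.exp` is used directly for brevity):

```python
import numpy as np


def online_attention(q, k, v, B1):
    """z_i = sum_j softmax_j(q_i k_j) v_j in one pass over chunks of B1 keys."""
    m = np.full(q.shape, -np.inf)  # running max of q_i * k_j
    s = np.zeros(q.shape)  # running denominator
    acc = np.zeros(q.shape)  # running numerator
    for start in range(0, k.shape[0], B1):
        kc, vc = k[start : start + B1], v[start : start + B1]
        logits = q[:, None] * kc[None, :]
        m_new = np.maximum(m, logits.max(axis=-1))
        alpha = np.exp(m - m_new)  # rescales previous sums to the new max
        p = np.exp(logits - m_new[:, None])
        s = s * alpha + p.sum(axis=-1)
        acc = acc * alpha + (p * vc[None, :]).sum(axis=-1)
        m = m_new
    return acc / s


rng = np.random.default_rng(0)
q, k, v = rng.normal(size=100), rng.normal(size=200), rng.normal(size=200)
z = online_attention(q, k, v, B1=32)
```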


In [None]:
N0, T = 100, 200


def flashatt_spec(
    q: Float32[Array, f"{N0}"], k: Float32[Array, f"{T}"], v: Float32[Array, f"{T}"]
) -> Float32[Array, f"{N0}"]:
    x = q[:, None] * k[None, :]
    x_max = jnp.max(x, -1)[..., None]
    x = x - x_max
    x_exp = jnp.exp(x)
    soft = x_exp / jnp.sum(x_exp, -1)[..., None]
    return jnp.sum(v[None, :] * soft, -1)


def flashatt_kernel(q_ref, k_ref, v_ref, z_ref, B0: int):
    T = k_ref.shape[-1]
    pid_i = pl.program_id(0)
    log2_e = 1.44269504
    exp = lambda x: jnp.exp2(x * log2_e)
    pass
    # finish me!


test(flashatt_kernel, flashatt_spec, B={"B0": 32}, nelem={"N0": N0, "T": T})

## Puzzle 10: Two Dimensional Convolution

A batched 2D convolution.

Uses one program id axis. Block size `B0` represents the batches to process out of `N0`.
Image `x` is of size `H` by `W` with only 1 channel, and kernel `k` is of size `KH` by `KW`.

$$z_{i, j, l} = \sum_{oj, ol} k_{oj,ol} \times x_{i,j + oj, l + ol} \text{ for } i = 1\ldots N_0$$

where values of `x` beyond its bounds are taken to be 0 (the spec zero-pads the bottom and right edges).

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/10_two_dimensional_convolution.png" /></p>

_A comment on Pallas: We highly recommend performing the loop using
`jax.lax.fori_loop` so that you can **avoid** the static analysis of `pl.load`
complaining about retrieving values from beyond the bounds of `x` (which is
required for dynamic padding). You can then safely `pl.load` values from beyond
the bounds of `x` with a correct `mask`._
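One way to organize the computation, sketched in NumPy below, is to loop over the kernel offsets `(oj, ol)` and accumulate `k[oj, ol]` times a shifted copy of `x`, with zeros wherever the shift runs past the edge. In a kernel, that shifted copy would be a masked load with `other=0`; here explicit zero-filling plays that role. This is an illustration of the indexing, not the reference solution.

```python
import numpy as np


def conv2d_shift_add(x, k):
    """z[:, j, l] = sum over (oj, ol) of k[oj, ol] * x[:, j + oj, l + ol], zero beyond x's bounds."""
    _, H, W = x.shape
    KH, KW = k.shape
    z = np.zeros_like(x)
    for oj in range(KH):
        for ol in range(KW):
            if oj >= H or ol >= W:
                continue  # this offset only ever reads padding
            # shifted[:, j, l] = x[:, j + oj, l + ol], or 0 past the bottom/right edge
            shifted = np.zeros_like(x)
            shifted[:, : H - oj, : W - ol] = x[:, oj:, ol:]
            z += k[oj, ol] * shifted
    return z


rng = np.random.default_rng(0)
z = conv2d_shift_add(rng.normal(size=(4, 8, 8)), rng.normal(size=(4, 4)))
```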


In [None]:
N0, T, K = 4, 8, 4


def conv2d_spec(
    x: Float32[Array, f"{N0} {T} {T}"], k: Float32[Array, f"{K} {K}"]
) -> Float32[Array, f"{N0} {T} {T}"]:
    z = [[None for _ in range(x.shape[2])] for _ in range(x.shape[1])]
    x_pad = jnp.pad(x, ((0, 0), (0, k.shape[-2]), (0, k.shape[-1])))
    for i in range(x.shape[1]):
        for j in range(x.shape[2]):
            z[i][j] = jnp.sum(
                k[None, :, :] * x_pad[:, i : i + k.shape[-2], j : j + k.shape[-1]],
                axis=(-1, -2),
            )
    z = jnp.stack(
        [
            jnp.stack([z[i][j] for i in range(x.shape[1])], -1)
            for j in range(x.shape[2])
        ],
        -1,
    )
    return z


def conv2d_kernel(x_ref, k_ref, z_ref, B0: int):
    pid_i = pl.program_id(0)
    pass
    # finish me!


test(
    conv2d_kernel,
    conv2d_spec,
    B={"B0": 2},
    nelem={"N0": N0, "H": T, "W": T, "KH": K, "KW": K},
)

## Puzzle 11: Matrix Multiplication

A blocked matrix multiplication.

Uses three program id axes. Block size `B2` represents the batches to process out of `N2`.
Block size `B0` represents the rows of `x` to process out of `N0`. Block size `B1` represents the columns of `y` to process out of `N1`. The middle (contraction) dimension is `MID`, processed `B_MID` elements at a time.

$$z_{i, j, k} = \sum_{l} x_{i,j, l} \times y_{i, l, k} \text{ for } i = 1\ldots N_2, j = 1\ldots N_0, k = 1\ldots N_1$$

You are allowed to use `pl.dot`, which computes a smaller (block-sized) matrix multiplication.

Hint: the main trick is that you can split a matmul into smaller parts.

$$z_{i, j, k} = \sum_{l=1}^{L/2} x_{i,j, l} \times y_{i, l, k} + \sum_{l=L/2+1}^{L} x_{i,j, l} \times y_{i, l, k}$$

where $L$ is the middle dimension `MID`.

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/11_matrix_multiplication.png" /></p>
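The splitting trick, sketched in NumPy for a single batch element: each `(B0, B1)` output tile would correspond to one program, and the inner loop accumulates partial products over `MID` in chunks of `B_MID`. Slicing past the end stands in for masking here; this is an illustration, not the kernel itself.

```python
import numpy as np


def blocked_matmul(x, y, B0, B1, B_MID):
    """x @ y for 2D x (N0, MID) and y (MID, N1), one (B0, B1) output tile at a time."""
    N0, MID = x.shape
    _, N1 = y.shape
    z = np.zeros((N0, N1), dtype=np.result_type(x, y))
    for i in range(0, N0, B0):  # each (i, j) tile would be one program
        for j in range(0, N1, B1):
            acc = np.zeros_like(z[i : i + B0, j : j + B1])
            for l in range(0, MID, B_MID):  # sum of partial products over MID chunks
                acc += x[i : i + B0, l : l + B_MID] @ y[l : l + B_MID, j : j + B1]
            z[i : i + B0, j : j + B1] = acc
    return z


rng = np.random.default_rng(0)
x, y = rng.normal(size=(32, 64)), rng.normal(size=(64, 32))
z = blocked_matmul(x, y, B0=16, B1=16, B_MID=16)
```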


In [None]:
N0, N1, N2, MID = 32, 32, 4, 64


def dot_spec(
    x: Float32[Array, f"{N2} {N0} {MID}"], y: Float32[Array, f"{N2} {MID} {N1}"]
) -> Float32[Array, f"{N2} {N0} {N1}"]:
    return x @ y


def dot_kernel(x_ref, y_ref, z_ref, B0: int, B1: int, B2: int, B_MID: int):
    pid_i, pid_j, pid_k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
    N2, N0, MID = x_ref.shape
    _, MID, N1 = y_ref.shape
    pass
    # finish me!


test(
    dot_kernel,
    dot_spec,
    B={"B0": 16, "B1": 16, "B2": 1, "B_MID": 16},
    nelem={"N0": N0, "N1": N1, "N2": N2, "MID": MID},
    rtol=3e-3,
    atol=3e-3,
)

## Puzzle 12: Quantized Matrix Multiplication

When doing matrix multiplication with quantized neural networks a common strategy is to store the weight matrix in lower precision, with a shift and scale term.

For this problem, our `weight` will be stored in 4 bits. We can store `FPINT` of these in a 32-bit integer. In addition, for every `GROUP` consecutive weights we store 1 `scale` float value and 1 `shift` 4-bit value (called `offset` in the code); the groups run along each row of `weight`, i.e. along the `MID` dimension. The `activation`s are stored separately as standard floats.

Mathematically, it looks like this:

$$z_{j, k} = \sum_{l} sc_{j, \lfloor l/g \rfloor} (w_{j, l} - sh_{j, \lfloor l/g \rfloor}) \times y_{l, k} \text{ for } j = 1\ldots N_0,\ k = 1\ldots N_1$$

<p align="center"><img alt="constant add illustration" width="50%" src="../assets/12_quantized_matrix_multiplication.png" /></p>

However, it is a bit more complex than that: we first need to unpack the 4-bit values and convert them to floats.
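To see the bit layout that `extract_4bit` (defined in the next cell) assumes, with the lowest nibble first, here is a NumPy round trip followed by dequantization. `pack_4bit` is a hypothetical helper written only for this illustration, and for brevity the per-group shifts are kept unpacked here, whereas the puzzle stores them packed as well.

```python
import numpy as np

FPINT = 32 // 4  # 4-bit values per 32-bit word
GROUP = 8


def pack_4bit(vals):
    """Hypothetical helper: pack 4-bit values (last axis) into int32 words, lowest nibble first."""
    vals = np.asarray(vals, dtype=np.uint32)
    words = vals.reshape(vals.shape[:-1] + (-1, FPINT))
    shifts = np.arange(FPINT, dtype=np.uint32) * 4
    return np.bitwise_or.reduce(words << shifts, axis=-1).view(np.int32)


def unpack_4bit(x):
    """NumPy mirror of `extract_4bit`: shift each nibble down, then mask to 4 bits."""
    over = np.arange(FPINT, dtype=np.int32) * 4
    return ((x[..., None] >> over) & 0xF).reshape(x.shape[:-1] + (-1,))


w = np.arange(16).reshape(1, 16)  # one row of 16 4-bit weights (values 0..15)
packed = pack_4bit(w)  # shape (1, 2); the second word has its sign bit set
# one scale and one (already unpacked) shift per GROUP consecutive weights
scale = np.array([[0.5, 2.0]], dtype=np.float32)
shift = np.array([[8, 8]])
w_float = np.repeat(scale, GROUP, axis=-1) * (
    unpack_4bit(packed) - np.repeat(shift, GROUP, axis=-1)
)
```

The `& 0xF` after the shift is what makes the sign bit harmless: the arithmetic shift sign-extends negative words, and the mask discards everything above the nibble.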


In [None]:
FPINT = 32 // 4
GROUP = 8
N0, N1, MID = 32, 32, 64


def extract_4bit(x: Int32) -> Int32:
    """Unpack 4-bit integers from 32-bit integers."""
    over = jnp.arange(FPINT, dtype=jnp.int32) * 4
    mask = 2**4 - 1
    return ((x[..., None] >> over) & mask).reshape(x.shape[:-1] + (-1,))


def broadcast_group(z: Float32 | Int32, group: int):
    """Broadcast each scalar to `group` elements contiguously in the last dimension."""
    return jnp.broadcast_to(z[..., None], z.shape + (group,)).reshape(
        z.shape[:-1] + (-1,)
    )


def quant_dot_spec(
    scale: Float32[Array, f"{N0} {MID // GROUP}"],
    offset: Int32[Array, f"{N0} {MID // FPINT // GROUP}"],
    weight: Int32[Array, f"{N0} {MID // FPINT}"],
    activation: Float32[Array, f"{MID} {N1}"],
) -> Float32[Array, f"{N0} {N1}"]:
    scale = broadcast_group(scale, GROUP)
    weight_float = scale * (
        extract_4bit(weight) - broadcast_group(extract_4bit(offset), GROUP)
    )
    return weight_float @ activation


def quant_dot_kernel(
    scale_ref: Float32[Array, f"{N0} {MID // GROUP}"],
    offset_ref: Int32[Array, f"{N0} {MID // FPINT // GROUP}"],
    weight_ref: Int32[Array, f"{N0} {MID // FPINT}"],
    activation_ref: Float32[Array, f"{MID} {N1}"],
    z_ref: Float32[Array, f"{N0} {N1}"],
    B0: int,
    B1: int,
    B_MID: int,
):
    pid_i, pid_j = pl.program_id(0), pl.program_id(1)
    B_weight = B_MID // FPINT
    assert B_weight * FPINT == B_MID, "B_MID must be divisible by FPINT"
    B_offset = B_MID // FPINT // GROUP
    assert B_offset * GROUP * FPINT == B_MID, "B_MID must be divisible by FPINT * GROUP"
    B_scale = B_MID // GROUP
    assert B_scale * GROUP == B_MID, "B_MID must be divisible by GROUP"
    pass
    # finish me!

    row_mask = ((pid_i * B0 + jnp.arange(B0)) < weight_ref.shape[0])[:, None]
    col_mask = ((pid_j * B1 + jnp.arange(B1)) < activation_ref.shape[1])[None, :]

    def body_fn(k, acc):
        weight_idx = (pl.dslice(B0 * pid_i, B0), pl.dslice(B_weight * k, B_weight))
        weight_mask = (
            row_mask
            & (B_weight * k + jnp.arange(B_weight) < weight_ref.shape[1])[None, :]
        )
        weight_4bit = pl.load(weight_ref, weight_idx, mask=weight_mask)

        scale_idx = (pl.dslice(B0 * pid_i, B0), pl.dslice(B_scale * k, B_scale))
        scale_mask = (
            row_mask & (B_scale * k + jnp.arange(B_scale) < scale_ref.shape[1])[None, :]
        )
        scale = pl.load(scale_ref, scale_idx, mask=scale_mask)

        offset_idx = (pl.dslice(B0 * pid_i, B0), pl.dslice(B_offset * k, B_offset))
        offset_mask = (
            row_mask
            & (B_offset * k + jnp.arange(B_offset) < offset_ref.shape[1])[None, :]
        )
        offset = pl.load(offset_ref, offset_idx, mask=offset_mask)

        scale = broadcast_group(scale, GROUP)
        offset = broadcast_group(extract_4bit(offset), GROUP)
        weight_float = scale * (extract_4bit(weight_4bit) - offset)
        act_idx = (pl.dslice(B_MID * k, B_MID), pl.dslice(B1 * pid_j, B1))
        act_mask = (
            col_mask
            & (B_MID * k + jnp.arange(B_MID) < activation_ref.shape[0])[:, None]
        )
        act = pl.load(activation_ref, act_idx, mask=act_mask)
        w_act = pl.dot(weight_float, act)
        return acc + w_act

    MID = activation_ref.shape[0]
    acc = jnp.zeros((B0, B1), dtype=z_ref.dtype)
    acc = jax.lax.fori_loop(0, pl.cdiv(MID, B_MID), body_fn, acc)
    z_mask = row_mask & col_mask
    pl.store(
        z_ref, (pl.dslice(B0 * pid_i, B0), pl.dslice(B1 * pid_j, B1)), acc, mask=z_mask
    )


test(
    quant_dot_kernel,
    quant_dot_spec,
    B={"B0": 16, "B1": 16, "B_MID": 64},
    nelem={"N0": 32, "N1": 32, "MID": 64},
    rtol=1e-2,
    atol=1e-2,
)