In [None]:
# Taken from https://github.com/jax-ml/jax/blob/5ab714bdaeac7f22e6ae50273b6b227b373763e6/tests/pallas/tpu_pallas_state_test.py#L140

import functools

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl
import jax.experimental.pallas.tpu as pltpu
import numpy as np


m, k, n = 512, 512, 512
bm, bk, bn = 128, 128, 128

def matmul_kernel(acc_ref, x_ref, y_ref, o_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  acc_ref[...] += jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32
  )

  @pl.when(pl.program_id(2) == pl.num_programs(2) - 1)
  def _():
    o_ref[...] = acc_ref[...].astype(o_ref.dtype)

def matmul(x, y, *, block_sizes=None, debug=True):
  """Compute ``x @ y`` with a pipelined TPU Pallas kernel driven by run_state.

  ``pl.run_state`` turns the operands and an output buffer into mutable refs.
  A ``pallas_call`` with no explicit inputs or outputs closes over those refs
  and uses ``pltpu.emit_pipeline`` to feed ``(bm, bk)`` / ``(bk, bn)`` tiles to
  ``matmul_kernel``, which accumulates into a float32 VMEM scratch tile.

  Args:
    x: Left operand of shape ``(m, k)``.
    y: Right operand of shape ``(k, n)``.
    block_sizes: Optional ``(bm, bk, bn)`` tile sizes. When omitted, the
      module-level ``bm``, ``bk`` and ``bn`` are used (read at call time).
    debug: Forwarded to ``pl.pallas_call``; when true the kernel jaxpr is
      printed. Defaults to true, matching the original behavior.

  Returns:
    An ``(m, n)`` array with the dtype of ``x``.

  Raises:
    ValueError: If an operand is not 2-D, the contraction dimensions differ,
      or a dimension is not a multiple of its tile size (the integer-division
      grid would otherwise silently skip the remainder).
  """
  if block_sizes is None:
    block_sizes = (bm, bk, bn)
  tm, tk, tn = block_sizes

  if x.ndim != 2 or y.ndim != 2:
    raise ValueError(
        f"matmul expects 2-D operands, got shapes {x.shape} and {y.shape}"
    )
  dim_m, dim_k = x.shape
  dim_k_y, dim_n = y.shape
  if dim_k != dim_k_y:
    raise ValueError(
        f"contraction dimensions differ: x has k={dim_k}, y has k={dim_k_y}"
    )
  for axis, dim, tile in (("m", dim_m, tm), ("k", dim_k, tk), ("n", dim_n, tn)):
    if dim % tile:
      raise ValueError(
          f"dimension {axis}={dim} is not a multiple of its block size {tile}"
      )

  def run_matmul(refs):
    x_ref, y_ref, o_ref = refs

    def matmul_pipeline_kernel(acc_ref):
      pltpu.emit_pipeline(
          functools.partial(matmul_kernel, acc_ref),
          # The reduction axis is last so matmul_kernel sees every k step for
          # one (i, j) output tile before moving on to the next tile.
          grid=(dim_m // tm, dim_n // tn, dim_k // tk),
          in_specs=[
              pl.BlockSpec((tm, tk), lambda i, j, kk: (i, kk)),
              pl.BlockSpec((tk, tn), lambda i, j, kk: (kk, j)),
          ],
          out_specs=pl.BlockSpec((tm, tn), lambda i, j, kk: (i, j)),
      )(x_ref, y_ref, o_ref)

    # No inputs/outputs: the kernel reads and writes the refs captured above.
    pl.pallas_call(
        matmul_pipeline_kernel,
        out_shape=[],
        scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
        debug=debug,
    )()

  # The buffer's initial contents do not matter: once the shapes are validated,
  # every output tile is overwritten on its last reduction step.
  _, _, out = pl.run_state(run_matmul)(
      (x, y, jnp.ones((dim_m, dim_n), dtype=x.dtype))
  )
  return out

# Operands come from fixed PRNG keys (0 for x, 1 for y) so the check is
# reproducible.
x, y = (
    jax.random.normal(jax.random.key(seed), shape, jnp.float32)
    for seed, shape in ((0, (m, k)), (1, (k, n)))
)
o = matmul(x, y)
# Absolute tolerance used when comparing against the jnp reference product.
atol = 2e-5
np.testing.assert_allclose(o, x @ y, atol=atol)



The kernel jaxpr for pallas_call matmul_pipeline_kernel at /tmp/ipykernel_2227720/2988609796.py:31:
{ lambda ; a:MemRef<any>{float32[512,512]} b:MemRef<any>{float32[512,512]} c:MemRef<any>{float32[512,512]}
    d:MemRef<any>{float32[512,512]} e:MemRef<any>{float32[512,512]} f:MemRef<any>{float32[512,512]}
    g:MemRef<vmem>{float32[128,128]}. let
    h:i32[] = mul 1:i32[] 4:i32[]
    i:i32[] = mul h 4:i32[]
    j:i32[] = mul i 4:i32[]
    run_scoped[
      collective_axes=()
      jaxpr={ lambda k:i32[] l:Ref{float32[512,512]} m:Ref{float32[512,512]} n:MemRef<vmem>{float32[128,128]}
          o:Ref{float32[512,512]}; p:MemRef<vmem>{float32[2,128,128]} q:MemRef<smem>{uint32[1]}
          r[35m:MemRef<semaphore_mem>{dma_sem[2]}[39m s[35m:MemRef<vmem>{float32[2,128,128]}[3