The purpose of this notebook is to prototype a version of the FlashAttention kernel
that can be captured by AOTAutograd and Dynamo.

In [1]:
import torch
import torch_xla
from typing import List

# Global configuration: fixed seed and float32 default dtype. Also ask PyTorch/XLA for
# its 'highest' matmul precision, presumably so the kernel can be compared tightly
# against the eager reference later in the notebook.
torch.manual_seed(42)
torch.set_default_dtype(torch.float32)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')


def serialize(t):
  """Converts an optional tensor into a tensor so it can be returned from a custom op.

  A tensor is returned as a fresh clone. None is encoded as a scalar bool tensor
  (False) on the 'xla' device, which `deserialize` maps back to None.
  """
  if t is not None:
    return t.clone()
  # TODO: this code path causes FunctionalStorageImpl assertion error at https://github.com/pytorch/pytorch/blob/9933e59c2b16c1b0475189eef63b0ff405dfd091/aten/src/ATen/FunctionalStorageImpl.cpp#L111
  return torch.tensor(False, dtype=torch.bool, device='xla')
  

def deserialize(t):
  """Inverse of `serialize`: maps the None sentinel back to None and clones anything else.

  `serialize` encodes None as a 0-dim bool tensor, so only 0-dim bool tensors are
  treated as the sentinel. A genuine bool tensor with at least one dimension (e.g. a
  mask) is returned as a clone instead of being silently dropped.

  NOTE: a real 0-dim bool tensor is still indistinguishable from the sentinel.
  Checking its value would need `.item()`, which is avoided here.
  """
  if t.dtype == torch.bool and t.dim() == 0:
    return None
  return t.clone()


class ProxyCtx:
  """A stand-in for an autograd ctx whose saved tensors can be flattened into tensors.

  `save_for_backward` / `saved_tensors` mimic the autograd ctx API. `serialize` and
  `deserialize` convert the saved entries to and from a list of tensors using the
  module-level `serialize` / `deserialize` helpers, so they can be returned from (and
  passed into) custom ops.

  Other attributes (e.g. `causal`) may be set on an instance like on a real ctx, but
  only `to_save` takes part in (de)serialization.
  """

  def __init__(self) -> None:
    self.to_save = []

  def save_for_backward(self, *tensors: torch.Tensor):
    # Keep exactly the varargs tuple, as a real ctx would.
    self.to_save = tensors

  @property
  def saved_tensors(self):
    return self.to_save

  def serialize(self) -> List[torch.Tensor]:
    """Returns a new list holding the serialized form of each saved entry."""
    return [serialize(entry) for entry in self.to_save]

  @staticmethod
  def deserialize(saved: List[torch.Tensor]) -> "ProxyCtx":
    """Builds a ProxyCtx whose `to_save` is the list of deserialized `saved` entries."""
    rebuilt = ProxyCtx()
    rebuilt.to_save = [deserialize(entry) for entry in saved]
    return rebuilt
  

def describe_value(v):
  """Prints a one-line summary of `v`.

  Tensors print their type, shape, dtype and device. Lists print their length, None
  prints "None", and anything else prints its type.
  """
  if isinstance(v, torch.Tensor):
    summary = f"{type(v)}({v.shape}, dtype={v.dtype}, device={v.device})"
  elif isinstance(v, list):
    summary = f"list({len(v)})"
  elif v is None:
    summary = "None"
  else:
    summary = type(v)
  print(summary)





In [2]:
import functools
import os
import warnings

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.debug.metrics as met

from typing import Any, List, Callable, Optional, Tuple, Dict
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB

_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"


def _extract_backend_config(
    module: "jaxlib.mlir._mlir_libs._mlir.ir.Module") -> Optional[str]:
  """
  This algorithm intends to extract the backend config from the compiler IR like the following,
  and it is not designed to traverse any generic MLIR module.

  module @jit_add_vectors attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
    func.func public @main(%arg0: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
      %0 = call @add_vectors(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
      return %0 : tensor<8xi32>
    }
    func.func private @add_vectors(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
      %0 = call @wrapped(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
      return %0 : tensor<8xi32>
    }
    func.func private @wrapped(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
      %0 = call @apply_kernel(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
      return %0 : tensor<8xi32>
    }
    func.func private @apply_kernel(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
      %0 = stablehlo.custom_call @tpu_custom_call(%arg0, %arg1) {backend_config = "{\22custom_call_config\22: {\22body\22: \22TUzvUgFNTElSMTkuMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA3lZDQFVBwsPEw8PCw8PMwsLCwtlCwsLCwsPCw8PFw8LFw8PCxcPCxcTCw8LDxcLBQNhBwNZAQ0bBxMPGw8CagMfBRcdKy0DAycpHVMREQsBBRkVMzkVTw8DCxUXGRsfCyELIyUFGwEBBR0NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFHwUhBSMFJQUnEQMBBSkVLw8dDTEXA8IfAR01NwUrFwPWHwEVO0EdPT8FLRcD9h8BHUNFBS8XA3InAQMDSVcFMR1NEQUzHQ1RFwPGHwEFNSN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACNhcml0aC5vdmVyZmxvdzxub25lPgAXVQMhBx0DJwMhBwECAgUHAQEBAQECBASpBQEQAQcDAQUDEQETBwMVJwcBAQEBAQEHAwUHAwMLBgUDBQUBBwcDBQcDAwsGBQMFBQMLCQdLRwMFBQkNBwMJBwMDCwYJAwUFBRENBAkHDwURBQABBgMBBQEAxgg32wsdE2EZ2Q0LEyMhHSknaw0LCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAc3RvcmUAL3dvcmtzcGFjZXMvd29yay9weXRvcmNoL3hsYS90ZXN0L3Rlc3Rfb3BlcmF0aW9ucy5weQBhZGRfdmVjdG9yc19rZXJuZWwAZGltZW5zaW9uX3NlbWFudGljcwBmdW5jdGlvbl90eXBlAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAHN5bV9uYW1lAG1haW4AdmFsdWUAL2dldFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoQ3VzdG9tTm9kZShTbGljZVsoMCwgOCldLCBbXSksKSksICg4LCksICgpKV0sIFtdKSwpKV0AYWRkX3ZlY3RvcnMAdGVzdF90cHVfY3VzdG9tX2NhbGxfcGFsbGFzX2V4dHJhY3RfYWRkX3BheWxvYWQAPG1vZHVsZT4Ab3ZlcmZsb3dGbGFncwAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\22, \22needs_layout_passes\22: true}}", kernel_name = "add_vectors_kernel", operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
      return %0 : tensor<8xi32>
    }
  }

  Basically, what we are looking for is a two level of operations, and the tpu_custom_call operation in the inner level. It will return None if the payload is not found.
  """
  for operation in module.body.operations:
    assert len(
        operation.body.blocks) == 1, "The passing module is not compatible."
    for op in operation.body.blocks[0].operations:
      if op.name == "stablehlo.custom_call":
        return op.backend_config.value
  return None


def jax_import_guard():
  """Initializes the PyTorch/XLA computation client; call before importing JAX.

  If JAX grabs the TPU first, any subsequent PyTorch/XLA TPU operations hang, so
  the client has to claim the device before JAX does.
  """
  xla_c = torch_xla._XLAC
  xla_c._init_computation_client()


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
  """Maps a torch dtype to the matching `jax.numpy` dtype.

  Raises:
    RuntimeError: if the XLA_USE_BF16 environment variable is enabled.
    ValueError: if `dtype` is not one of the supported dtypes listed below.
  """
  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope which could cause problems for xmp.spawn.
  jax_import_guard()
  import jax.numpy as jnp
  if _XLA_USE_BF16:
    raise RuntimeError(
        "Pallas kernel does not support XLA_USE_BF16, please unset the env var")

  # Checked in order with `==`, exactly like an if/elif chain.
  supported_pairs = (
      (torch.float32, jnp.float32),
      (torch.float64, jnp.float64),
      (torch.float16, jnp.float16),
      (torch.bfloat16, jnp.bfloat16),
      (torch.int32, jnp.int32),
      (torch.int64, jnp.int64),
      (torch.int16, jnp.int16),
      (torch.int8, jnp.int8),
      (torch.uint8, jnp.uint8),
  )
  for torch_dt, jax_dt in supported_pairs:
    if dtype == torch_dt:
      return jax_dt
  raise ValueError(f"Unsupported dtype: {dtype}")


def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct":
  """Describes `tensor` as a storage-free `jax.ShapeDtypeStruct` (shape + JAX dtype)."""
  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope which could cause problems for xmp.spawn.
  jax_import_guard()
  import jax

  jax_dtype = convert_torch_dtype_to_jax(tensor.dtype)
  return jax.ShapeDtypeStruct(tensor.shape, jax_dtype)


# Cache of traced Pallas payloads used by trace_pallas(use_cache=True). The key is a
# variable-length tuple (matmul precision, kernel, static args, traced args, kwargs).
# The value is the extracted backend config, which may be None when none was found.
trace_pallas_arg_to_payload: Dict[Tuple[Any, ...], Optional[str]] = {}


def trace_pallas(kernel: Callable,
                 *args,
                 static_argnums=None,
                 static_argnames=None,
                 use_cache=False,
                 **kwargs):
  """Lowers a Pallas kernel with JAX and extracts its TPU custom-call payload.

  Torch tensors in `args` are replaced by `jax.ShapeDtypeStruct`s (shape and dtype
  only) for tracing. Every other positional argument is passed to the lowering
  unchanged.

  Args:
    kernel: the JAX/Pallas callable to lower.
    *args: positional arguments for `kernel`. The torch tensors among them are also
      returned, in order, as the operands to execute the custom call with.
    static_argnums: forwarded to `jax.jit` (int, list, tuple or range).
    static_argnames: forwarded to `jax.jit`.
    use_cache: when True, memoize the payload in `trace_pallas_arg_to_payload`.
      The key is the matmul precision, kernel, static args, traced args and kwargs.
    **kwargs: forwarded to the lowering only; they are not part of the execution
      operands.

  Returns:
    `(payload, tensor_args)`. `payload` is the backend config string from
    `_extract_backend_config`, or None if none was found. `tensor_args` is the list
    of tensor arguments.
  """
  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope which could cause problems for xmp.spawn.
  jax_import_guard()
  import jax
  # Imported only for its side effects; presumably it registers the Mosaic lowering of
  # pallas_call. TODO confirm.
  import jax._src.pallas.mosaic.pallas_call_registration

  jax_args = []  # for tracing
  tensor_args = []  # for execution
  for arg in args:
    # TODO: Could the args be a tuple of tensors or a list of tensors? Flatten them?
    if torch.is_tensor(arg):
      # ShapeDtypeStruct doesn't have any storage and thus is very suitable for generating the payload.
      jax_args.append(to_jax_shape_dtype_struct(arg))
      tensor_args.append(arg)
    else:
      jax_args.append(arg)

  hash_key = ()
  if use_cache:
    # Implicit assumption here that everything in kwargs is hashable and not a tensor,
    # which is true for the gmm and tgmm.
    # jax.jit also accepts static_argnums as a list, which is unhashable. Normalize it
    # to a tuple (as is already done for static_argnames) so it can be part of the key.
    argnums_key = (tuple(static_argnums)
                   if isinstance(static_argnums, list) else static_argnums)
    argnames_key = (tuple(static_argnames)
                    if static_argnames is not None else static_argnames)
    hash_key = (jax.config.jax_default_matmul_precision, kernel, argnums_key,
                argnames_key, tuple(jax_args),
                repr(sorted(kwargs.items())).encode())
    if hash_key in trace_pallas_arg_to_payload:
      torch_xla._XLAC._xla_increment_counter('trace_pallas_cache_hit', 1)
      return trace_pallas_arg_to_payload[hash_key], tensor_args

  # Here we ignore the kwargs for execution as most of the time, the kwargs is only used in traced code.
  ir = jax.jit(
      kernel, static_argnums=static_argnums,
      static_argnames=static_argnames).lower(*jax_args, **kwargs).compiler_ir()
  payload = _extract_backend_config(ir)

  if use_cache:
    # If we reach here it means we have a cache miss. The module-level dict is only
    # mutated, never rebound, so no `global` statement is needed.
    trace_pallas_arg_to_payload[hash_key] = payload

  return payload, tensor_args


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
  """Wraps a Pallas kernel so it can be called on torch tensors as a TPU custom call.

  Args:
    kernel: the Pallas/JAX callable to trace.
    output_shape_dtype_fn: called with the same positional args as the returned
      function. It must return a list of `(shape, dtype)` pairs, one per output.

  Returns:
    A callable `(*args, static_argnums=None, static_argnames=None, **kwargs)`. On
    every call it traces `kernel`, runs the custom call on the tensor args, and
    returns a single tensor when there is one output or a tuple of tensors otherwise.
  """

  # TODO: Maybe we can cache the payload for the same input (trace_pallas supports
  # use_cache=True, but it is not enabled here).
  def wrapped_kernel(kernel: Callable,
                     output_shape_dtype_fn: Callable,
                     *args,
                     static_argnums=None,
                     static_argnames=None,
                     **kwargs) -> Any:
    payload, tensor_args = trace_pallas(
        kernel,
        *args,
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        **kwargs)
    output_shape_dtype = output_shape_dtype_fn(*args)
    assert isinstance(output_shape_dtype,
                      list), "The output_shape_dtype_fn should return a list."
    output_shapes = [shape for shape, _ in output_shape_dtype]
    output_dtypes = [dtype for _, dtype in output_shape_dtype]
    outputs = torch_xla._XLAC._xla_tpu_custom_call(tensor_args, payload,
                                                   output_shapes, output_dtypes)

    # Make the output easier to use: unwrap a single output, tuple-ize several.
    if len(outputs) == 1:
      return outputs[0]
    return tuple(outputs)

  return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)



In [3]:
class FlashAttention(torch.autograd.Function):
  """
  This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
  where we only take q, k, v and causal (plus optional segment ids, sm_scale, ab and
  SPMD settings) as input and set the block_sizes for the users.

  forward() runs the Pallas `_flash_attention_impl` kernel with save_residuals=True,
  so the l and m residuals are kept for backward(). backward() runs
  `_flash_attention_bwd_dq` (grads of q and ab) and `_flash_attention_bwd_dkv`
  (grads of k and v) as TPU custom calls.
  """

  MIN_BLOCK_SIZE = 128
  DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)
  # The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
  # It yields much better performance than the default block_sizes.
  DEFAULT_BLOCK_SIZES = {
      "block_q": 512,
      "block_k_major": 512,
      "block_k": 512,
      "block_b": 2,
      "block_q_major_dkv": 512,
      "block_k_major_dkv": 512,
      "block_q_dkv": 512,
      "block_k_dkv": 512,
      "block_q_dq": 1024,
      "block_k_dq": 256,
      "block_k_major_dq": 512,
  }
  NUM_LANES = 128
  NUM_SUBLANES = 8
  # Position of `ab` among forward()'s inputs (ctx excluded), i.e. its index in
  # ctx.needs_input_grad.
  AB_INPUT_INDEX = 7

  @staticmethod
  def prepare_segment_ids(q_segment_ids, kv_segment_ids):
    """Builds the JAX `SegmentIds` spec plus the expanded segment id tensors.

    Args:
      q_segment_ids: [batch_size, q_seq_len] tensor, or None.
      kv_segment_ids: [batch_size, kv_seq_len] tensor, or None.

    Returns:
      `(segment_ids, q_segment_ids_fa, kv_segment_ids_fa)`. `segment_ids` holds
      ShapeDtypeStructs for tracing. `q_segment_ids_fa` is `q_segment_ids` expanded
      along a new trailing NUM_LANES dim. `kv_segment_ids_fa` is `kv_segment_ids`
      expanded along a new NUM_SUBLANES dim at position 1. All three are None when
      neither input is given.

    Raises:
      AssertionError: if only one of the two segment id tensors is given.
    """
    from jax.experimental.pallas.ops.tpu.flash_attention import SegmentIds
    if q_segment_ids is None and kv_segment_ids is None:
      return None, None, None

    # Supplying only one of the two is a caller error, not a request to disable
    # segment ids.
    assert q_segment_ids is not None and kv_segment_ids is not None, "Both q_segment_ids and kv_segment_ids should be provided."
    segment_ids = SegmentIds(
        to_jax_shape_dtype_struct(q_segment_ids),
        to_jax_shape_dtype_struct(kv_segment_ids))
    q_segment_ids = q_segment_ids.unsqueeze(-1).expand(
        [-1 for _ in q_segment_ids.shape] + [FlashAttention.NUM_LANES])
    kv_segment_ids = kv_segment_ids.unsqueeze(1).expand([
        kv_segment_ids.shape[0], FlashAttention.NUM_SUBLANES,
        kv_segment_ids.shape[1]
    ])
    return segment_ids, q_segment_ids, kv_segment_ids

  @staticmethod
  def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
              partition_spec, mesh):
    """Runs the Pallas flash attention forward kernel and saves residuals.

    Returns the attention output `o` (same shape/dtype as q). Saves the full q, k,
    v, ab, the output, the l/m residuals and the expanded segment ids on ctx for
    backward().
    """
    # Import JAX within the function such that we don't need to call the jax_import_guard()
    # in the global scope which could cause problems for xmp.spawn.
    jax_import_guard()
    import jax
    from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

    ctx.causal = causal
    ctx.sm_scale = sm_scale
    ctx.partition_spec = partition_spec
    ctx.mesh = mesh
    ctx.q_full_shape = None
    ctx.kv_full_shape = None
    save_residuals = True

    # SPMD integration.
    # mark_sharding is in-placed, and therefore save the full q, k, v for the backward.
    full_q = q
    full_k = k
    full_v = v
    full_ab = ab
    if partition_spec is not None:
      ctx.q_full_shape = q.shape
      ctx.kv_full_shape = k.shape
      q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
      k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
      v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
      # `if ab:` would evaluate a multi-element tensor's truthiness; test for presence.
      if ab is not None:
        ab = xs.enable_manual_sharding(
            ab, partition_spec, mesh=mesh).global_tensor

    # It computes the shape and type of o, l, m.
    shapes = [q.shape]
    dtypes = [q.dtype]
    if save_residuals:
      res_shape = list(q.shape)
      res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE
      for _ in range(2):
        shapes.append(res_shape)
        dtypes.append(torch.float32)

    with torch.no_grad():
      if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None:
        # partition_spec is for q,k,v with shape [batch, num_head, seq_len, head_dim], segment id
        # is of shape [batch, seq_len], hence we need to tweak it a bit
        segment_id_partition_spec = (partition_spec[0], partition_spec[2])
        q_segment_ids = xs.enable_manual_sharding(
            q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
        kv_segment_ids = xs.enable_manual_sharding(
            kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
      segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
          q_segment_ids, kv_segment_ids)
      ctx.segment_ids = segment_ids

      # We can't directly use flash_attention as we need to override the save_residuals flag which returns
      # l and m that is needed for the backward. Then we lose all the shape checks.
      # TODO: replicate the shape checks on flash_attention.
      # Here we separate the tracing and execution part just to support SegmentIds.
      payload, _ = trace_pallas(
          _flash_attention_impl,
          q,
          k,
          v,
          ab,
          segment_ids,
          save_residuals,
          causal,
          sm_scale,
          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
          False,
          static_argnums=range(5, 13),
          use_cache=True,
      )

      args = [q, k, v]
      if ab is not None:
        args += [ab]
      if segment_ids is not None:
        args += [q_segment_ids_fa, kv_segment_ids_fa]
      o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

      if not save_residuals:
        o = o[0]
        # SPMD integration
        if partition_spec is not None:
          o = xs.disable_manual_sharding(
              o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
        return o
      o, *aux = o
      # The residual buffers are MIN_BLOCK_SIZE wide; one column of each suffices.
      l, m = (res[..., 0] for res in aux[-2:])

    # SPMD integration
    if partition_spec is not None:
      o = xs.disable_manual_sharding(
          o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
      l = xs.disable_manual_sharding(
          l, partition_spec[0:3], ctx.q_full_shape[0:3],
          mesh=mesh).global_tensor
      m = xs.disable_manual_sharding(
          m, partition_spec[0:3], ctx.q_full_shape[0:3],
          mesh=mesh).global_tensor

    # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
    # but it should be OK as the backward will use the same partition_spec
    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
                          kv_segment_ids_fa, full_ab)
    return o

  @staticmethod
  def backward(ctx, grad_output):
    """Computes the grads of q, k, v and ab with the Pallas dq and dkv kernels.

    Returns one entry per forward() input. Non-differentiable inputs and grads that
    autograd did not request are None.
    """
    from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

    q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
    causal = ctx.causal
    sm_scale = ctx.sm_scale
    partition_spec = ctx.partition_spec
    mesh = ctx.mesh
    q_full_shape = ctx.q_full_shape
    kv_full_shape = ctx.kv_full_shape
    # The saved ab is the unsharded input, so its shape is the shape grad_ab must have.
    ab_full_shape = ab.shape if ab is not None else None
    # this segment_ids only reflects the local shape of segment_ids
    segment_ids = ctx.segment_ids
    grad_q = grad_k = grad_v = grad_ab = None

    needs_grad_q = ctx.needs_input_grad[0]
    needs_grad_k = ctx.needs_input_grad[1]
    needs_grad_v = ctx.needs_input_grad[2]
    needs_grad_ab = ab is not None and ctx.needs_input_grad[
        FlashAttention.AB_INPUT_INDEX]

    grad_i = torch.sum(
        o.to(torch.float32) * grad_output.to(torch.float32),
        axis=-1)  # [batch_size, num_heads, q_seq_len]

    expanded_l = l.unsqueeze(-1).expand([-1 for _ in l.shape] +
                                        [FlashAttention.MIN_BLOCK_SIZE])
    expanded_m = m.unsqueeze(-1).expand([-1 for _ in m.shape] +
                                        [FlashAttention.MIN_BLOCK_SIZE])
    expanded_grad_i = grad_i.unsqueeze(-1).expand(
        [-1 for _ in grad_i.shape] + [FlashAttention.MIN_BLOCK_SIZE])

    # SPMD integration
    if partition_spec is not None:
      q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
      k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
      v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
      expanded_l = xs.enable_manual_sharding(
          expanded_l, partition_spec, mesh=mesh).global_tensor
      expanded_m = xs.enable_manual_sharding(
          expanded_m, partition_spec, mesh=mesh).global_tensor
      grad_output = xs.enable_manual_sharding(
          grad_output, partition_spec, mesh=mesh).global_tensor
      expanded_grad_i = xs.enable_manual_sharding(
          expanded_grad_i, partition_spec, mesh=mesh).global_tensor
      # `if ab:` would evaluate a multi-element tensor's truthiness; test for presence.
      if ab is not None:
        ab = xs.enable_manual_sharding(
            ab, partition_spec, mesh=mesh).global_tensor

    # Operands shared by both backward kernels. They are built unconditionally so the
    # dkv kernel can run even when the dq kernel is skipped.
    args = [q, k, v]
    if ab is not None:
      args += [ab]
    if segment_ids is not None:
      args += [q_segment_ids_fa, kv_segment_ids_fa]
    args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

    # The dq kernel also produces grad_ab, so run it when either grad is requested.
    if needs_grad_q or needs_grad_ab:
      payload, _ = trace_pallas(
          _flash_attention_bwd_dq,
          q,
          k,
          v,
          ab,
          segment_ids,
          l,
          m,
          grad_output,
          grad_i,
          block_q_major=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dq"],
                            q.shape[2]),
          block_k_major=min(
              FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dq"],
              k.shape[2]),
          block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"],
                      k.shape[2]),
          sm_scale=sm_scale,
          causal=causal,
          mask_value=FlashAttention.DEFAULT_MASK_VALUE,
          debug=False,
          static_argnames=[
              "block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
              "mask_value", "debug"
          ],
          use_cache=True,
      )

      outputs = [q]
      if ab is not None:
        outputs += [ab]
      grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload,
                                                   [i.shape for i in outputs],
                                                   [i.dtype for i in outputs])
      if needs_grad_q:
        grad_q = grads[0]
      if needs_grad_ab:
        grad_ab = grads[1]

    if needs_grad_k or needs_grad_v:
      payload, _ = trace_pallas(
          _flash_attention_bwd_dkv,
          q,
          k,
          v,
          ab,
          segment_ids,
          l,
          m,
          grad_output,
          grad_i,
          block_q_major=min(
              FlashAttention.DEFAULT_BLOCK_SIZES["block_q_major_dkv"],
              q.shape[2]),
          block_k_major=min(
              FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dkv"],
              k.shape[2]),
          block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dkv"],
                      k.shape[2]),
          block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"],
                      q.shape[2]),
          sm_scale=sm_scale,
          causal=causal,
          mask_value=FlashAttention.DEFAULT_MASK_VALUE,
          debug=False,
          static_argnames=[
              "block_q_major", "block_k_major", "block_k", "block_q",
              "sm_scale", "causal", "mask_value", "debug"
          ],
          use_cache=True)

      grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload,
                                                   [k.shape, v.shape],
                                                   [k.dtype, v.dtype])
      if needs_grad_k:
        grad_k = grads[0]
      if needs_grad_v:
        grad_v = grads[1]

    # SPMD integration: only unshard grads that were actually computed.
    if partition_spec is not None:
      if grad_q is not None:
        grad_q = xs.disable_manual_sharding(
            grad_q, partition_spec, q_full_shape, mesh=mesh).global_tensor
      if grad_k is not None:
        grad_k = xs.disable_manual_sharding(
            grad_k, partition_spec, kv_full_shape, mesh=mesh).global_tensor
      if grad_v is not None:
        grad_v = xs.disable_manual_sharding(
            grad_v, partition_spec, kv_full_shape, mesh=mesh).global_tensor
      if grad_ab is not None:
        # ab was manually sharded in the same way, so grad_ab is shard-local here.
        grad_ab = xs.disable_manual_sharding(
            grad_ab, partition_spec, ab_full_shape, mesh=mesh).global_tensor

    return grad_q, grad_k, grad_v, None, None, None, None, grad_ab, None, None


def flash_attention(
    q,  # [batch_size, num_heads, q_seq_len, d_model]
    k,  # [batch_size, num_heads, kv_seq_len, d_model]
    v,  # [batch_size, num_heads, kv_seq_len, d_model]
    causal=False,
    q_segment_ids=None,  # [batch_size, q_seq_len]
    kv_segment_ids=None,  # [batch_size, kv_seq_len]
    sm_scale=1.0,
    *,
    ab=None,  # [batch_size, num_heads, q_seq_len, kv_seq_len]
    partition_spec=None,
    mesh=None,
):
  """Functional entry point: applies the FlashAttention autograd Function.

  All arguments are forwarded positionally, in forward()'s order, to
  `FlashAttention.apply`.
  """
  # TODO: support SPMD and Dynamo with segment_ids.
  fa_inputs = (q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
               partition_spec, mesh)
  return FlashAttention.apply(*fa_inputs)


In [4]:
jax_import_guard()
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import torch.nn as nn


def _attention(q, k, v, *, attn_mask=None, ab=None):
  attn_weight = q @ k.transpose(-2, -1)
  if attn_mask is not None:
    # Masked out the unrelevant parts.
    attn_weight = attn_weight.masked_fill(attn_mask,
                                          torch.finfo(attn_weight.dtype).min)
  if ab is not None:
    attn_weight = attn_weight + ab
  attn_weight = nn.functional.softmax(attn_weight, dim=-1)
  attn_output = attn_weight @ v
  return attn_output


def do_test(attn_fn):
  """Checks `attn_fn(q, k, v)` against the eager `_attention` reference under AOTAutograd.

  Steps:
  - Compile `attn_fn` with `aot_function`; the compiler hook prints each graph it
    receives.
  - Run it on random [3, 2, 128, 4] tensors on the "xla" device and back-propagate
    `.sum()`.
  - Assert that the output and the q/k/v gradients match `_attention` computed on
    detached clones of the same inputs.
  """
  from functorch.compile import aot_function

  def flash_attention_wrapper(q, k, v):
    return attn_fn(q, k, v)

  q = torch.randn(3, 2, 128, 4).to("xla").clone().detach().requires_grad_(True)
  k = torch.randn(3, 2, 128, 4).to("xla").clone().detach().requires_grad_(True)
  v = torch.randn(3, 2, 128, 4).to("xla").clone().detach().requires_grad_(True)

  q_clone = q.clone().detach().requires_grad_(True)
  k_clone = k.clone().detach().requires_grad_(True)
  v_clone = v.clone().detach().requires_grad_(True)

  def compiler(gm, _):
    print("Got graph:")
    print(gm.code)
    return gm

  compiled_flash_attention = aot_function(
      flash_attention_wrapper, fw_compiler=compiler)
  o_actual = compiled_flash_attention(q, k, v)
  o_actual.sum().backward()

  expected_o = _attention(q_clone, k_clone, v_clone)
  expected_o.sum().backward()

  torch.testing.assert_close(o_actual.cpu(), expected_o.cpu())
  # Check q, k, v gradients in that order.
  for actual, expected in ((q, q_clone), (k, k_clone), (v, v_clone)):
    assert actual.grad is not None and expected.grad is not None, f"{actual.grad}, {expected.grad}"
    torch.testing.assert_close(actual.grad.cpu(), expected.grad.cpu())


def test_flash_attention_wrapper_with_aot_autograd(attn_fn):
  """Runs `do_test(attn_fn)` with JAX's default matmul precision set to "highest".

  The previously configured precision is restored afterwards, even if the test
  fails. Restoring a hard-coded "default" would leak a different global setting.
  That setting is also part of trace_pallas's cache key.
  """
  previous_precision = jax.config.jax_default_matmul_precision
  jax.config.update("jax_default_matmul_precision", "highest")
  try:
    do_test(attn_fn)
  finally:
    jax.config.update("jax_default_matmul_precision", previous_precision)


Step 1: without a registered backward (autograd) formula, the custom op cannot be used
during backward propagation; the cell below is expected to fail.

In [5]:
from torch.library import custom_op


@custom_op("xla::flash_attention_xla_v2", mutates_args=())
def flash_attention_xla_v2(q: torch.Tensor,
                           k: torch.Tensor,
                           v: torch.Tensor,
                           causal: bool = False) -> torch.Tensor:
  """Exposes `flash_attention` as the opaque custom op `xla::flash_attention_xla_v2`.

  No autograd formula is registered for this op, so it can be traced but not
  back-propagated through.
  """
  attn_out = flash_attention(q=q, k=k, v=v, causal=causal)
  return attn_out


@flash_attention_xla_v2.register_fake
def _(q, k, v, causal=False):
  """Fake (meta) implementation of `xla::flash_attention_xla_v2` used during tracing.

  The real kernel allocates its output with q's shape and dtype (`shapes = [q.shape]`,
  `dtypes = [q.dtype]` in FlashAttention.forward), so the fake mirrors q rather than v.
  """
  assert q.shape == k.shape
  assert k.shape == v.shape
  return torch.empty_like(q)

In [6]:
# Expected to fail: xla::flash_attention_xla_v2 has no registered autograd formula, so
# backward() raises a RuntimeError. It is printed instead of aborting the cell.
try:
  test_flash_attention_wrapper_with_aot_autograd(torch.ops.xla.flash_attention_xla_v2)
except RuntimeError as err:
  print(err)

Trying to backward through xla.flash_attention_xla_v2.default but no autograd formula was registered. Please use register_autograd to add one.


 (Triggered internally at /workspaces/torch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Step 2: re-implement the forward and backward of FlashAttention as two separate custom
ops (`fa_custom_forward` and `fa_custom_backward`) that have no autograd formulas of their own.

In [7]:
@custom_op("xla::fa_custom_forward", mutates_args=())
def fa_custom_forward(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor, torch.Tensor]:
  """Forward half of FlashAttention as a custom op with no autograd of its own.

  Mirrors FlashAttention.forward with causal=False, sm_scale=1.0 and no segment ids,
  ab or SPMD. Instead of keeping residuals on an autograd ctx, it returns them.

  Returns:
    A 7-tuple `(o, q, k, v, o, l, m)`: the attention output, followed by
    `ProxyCtx.serialize()` of the six tensors saved for the backward. Those are
    clones of the inputs, the output, and the float32 l/m residuals of shape
    q.shape[:-1]. (The previous annotation declared 10 outputs, which did not match
    the 7 actually returned.)
  """
  # Clone the inputs so no returned tensor aliases an input.
  q = q.clone()
  k = k.clone()
  v = v.clone()

  print("Inside fa_custom_forward")
  for t in [q, k, v]:
    describe_value(t)

  ctx = ProxyCtx()
  q_segment_ids = kv_segment_ids = ab = partition_spec = mesh = None
  sm_scale = 1.0
  causal = False

  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope which could cause problems for xmp.spawn.
  jax_import_guard()
  import jax
  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

  ctx.causal = causal
  ctx.sm_scale = sm_scale
  ctx.partition_spec = partition_spec
  ctx.mesh = mesh
  ctx.q_full_shape = None
  ctx.kv_full_shape = None
  save_residuals = True

  # SPMD integration.
  # mark_sharding is in-placed, and therefore save the full q, k, v for the backward.
  full_q = q
  full_k = k
  full_v = v
  full_ab = ab
  if partition_spec is not None:
    ctx.q_full_shape = q.shape
    ctx.kv_full_shape = k.shape
    q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
    k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
    v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
    # `if ab:` would evaluate a multi-element tensor's truthiness; test for presence.
    if ab is not None:
      ab = xs.enable_manual_sharding(
          ab, partition_spec, mesh=mesh).global_tensor

  # It computes the shape and type of o, l, m.
  shapes = [q.shape]
  dtypes = [q.dtype]
  if save_residuals:
    res_shape = list(q.shape)
    res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE
    for _ in range(2):
      shapes.append(res_shape)
      dtypes.append(torch.float32)

  with torch.no_grad():
    if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None:
      # partition_spec is for q,k,v with shape [batch, num_head, seq_len, head_dim], segment id
      # is of shape [batch, seq_len], hence we need to tweak it a bit
      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
      q_segment_ids = xs.enable_manual_sharding(
          q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
      kv_segment_ids = xs.enable_manual_sharding(
          kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
        q_segment_ids, kv_segment_ids)
    ctx.segment_ids = segment_ids

    # We can't directly use flash_attention as we need to override the save_residuals flag which returns
    # l and m that is needed for the backward. Then we lose all the shape checks.
    # TODO: replicate the shape checks on flash_attention.
    # Here we separate the tracing and execution part just to support SegmentIds.
    payload, _ = trace_pallas(
        _flash_attention_impl,
        q,
        k,
        v,
        ab,
        segment_ids,
        save_residuals,
        causal,
        sm_scale,
        min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
        min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
        min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
        min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
        False,
        static_argnums=range(5, 13),
        use_cache=True,
    )

    args = [q, k, v]
    if ab is not None:
      args += [ab]
    if segment_ids is not None:
      args += [q_segment_ids_fa, kv_segment_ids_fa]
    o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

    # Dead path while save_residuals is hard-coded to True (kept for parity with
    # FlashAttention.forward). Note it would not match the declared 7-tuple output.
    if not save_residuals:
      o = o[0]
      # SPMD integration
      if partition_spec is not None:
        o = xs.disable_manual_sharding(
            o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
      return o
    o, *aux = o
    # The residual buffers are MIN_BLOCK_SIZE wide; one column of each suffices.
    l, m = (res[..., 0] for res in aux[-2:])

  # SPMD integration
  if partition_spec is not None:
    o = xs.disable_manual_sharding(
        o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
    l = xs.disable_manual_sharding(
        l, partition_spec[0:3], ctx.q_full_shape[0:3],
        mesh=mesh).global_tensor
    m = xs.disable_manual_sharding(
        m, partition_spec[0:3], ctx.q_full_shape[0:3],
        mesh=mesh).global_tensor

  # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
  # but it should be OK as the backward will use the same partition_spec
  ctx.save_for_backward(full_q, full_k, full_v, o, l, m)

  outs = [o] + ctx.serialize()
  print("Outs")
  for t in outs:
    describe_value(t)
  return tuple(outs)


@fa_custom_forward.register_fake
def _(q, k, v):
  """Fake (meta) implementation of `xla::fa_custom_forward` used during tracing.

  It must match the real op's output count, shapes and dtypes: `(o, q, k, v, o, l, m)`.
  o follows q's shape and dtype, and l/m are float32 tensors of shape q.shape[:-1].
  """
  print("Inside fake fa_custom_forward")

  assert q.shape == k.shape
  assert k.shape == v.shape

  # The real op allocates o with q's shape and dtype.
  o = torch.empty_like(q)
  # The real l and m are [..., 0] slices of float32 residual buffers shaped like q
  # (with a MIN_BLOCK_SIZE last dim). So they are float32 with shape q.shape[:-1],
  # whatever q's dtype is.
  l = q.new_empty(q.shape[:-1], dtype=torch.float32)
  m = q.new_empty(q.shape[:-1], dtype=torch.float32)

  # Same layout as the real op: [o] + serialized (full_q, full_k, full_v, o, l, m).
  saved = (q, k, v, o, l, m)
  return tuple([torch.empty_like(o)] + [torch.empty_like(t) for t in saved])


@custom_op("xla::fa_custom_backward", mutates_args=())
def fa_custom_backward(grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Tensor, l: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Eager FlashAttention backward pass exposed as the ``xla::fa_custom_backward`` op.

  Runs the Pallas ``_flash_attention_bwd_dq`` and ``_flash_attention_bwd_dkv``
  kernels through ``torch_xla._XLAC._xla_tpu_custom_call``.

  Args:
    grad_output: gradient w.r.t. the attention output.
    q, k, v, o, l, m: residuals saved by the forward, in serialized form
      (``deserialize`` maps a bool placeholder back to ``None``).

  Returns:
    ``(grad_q, grad_k, grad_v)``.

  NOTE: this prototype hard-codes ``causal=False``, ``sm_scale=1.0`` and no
  segment ids / attention bias / SPMD sharding; those branches are kept so the
  op can be generalized later.
  """
  q_segment_ids_fa = kv_segment_ids_fa = ab = None
  grad_output = grad_output.clone()

  print("Inside fa_custom_backward")

  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

  # deserialize() already returns a fresh clone (or None for the bool
  # placeholder), so no extra .clone() is needed here.
  saved_tensors = (q, k, v, o, l, m)
  q, k, v, o, l, m = (deserialize(t) for t in saved_tensors)
  causal = False
  sm_scale = 1.0
  partition_spec = None
  mesh = None
  q_full_shape = None
  kv_full_shape = None
  # this segment_ids only reflects the local shape of segment_ids
  segment_ids = None
  grad_q = grad_k = grad_v = None
  # Which of (dq, dk, dv) to compute; all enabled in this prototype.
  needs_input_grad = [True, True, True]

  grad_i = torch.sum(
      o.to(torch.float32) * grad_output.to(torch.float32),
      axis=-1)  # [batch_size, num_heads, q_seq_len]

  # The custom calls receive the residuals broadcast along a trailing
  # MIN_BLOCK_SIZE axis.
  expanded_l = l.unsqueeze(-1).expand([-1 for _ in l.shape] +
                                      [FlashAttention.MIN_BLOCK_SIZE])
  expanded_m = m.unsqueeze(-1).expand([-1 for _ in m.shape] +
                                      [FlashAttention.MIN_BLOCK_SIZE])
  expanded_grad_i = grad_i.unsqueeze(-1).expand(
      [-1 for _ in grad_i.shape] + [FlashAttention.MIN_BLOCK_SIZE])

  # SPMD integration
  if partition_spec is not None:
    q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
    k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
    v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
    expanded_l = xs.enable_manual_sharding(
        expanded_l, partition_spec, mesh=mesh).global_tensor
    expanded_m = xs.enable_manual_sharding(
        expanded_m, partition_spec, mesh=mesh).global_tensor
    grad_output = xs.enable_manual_sharding(
        grad_output, partition_spec, mesh=mesh).global_tensor
    expanded_grad_i = xs.enable_manual_sharding(
        expanded_grad_i, partition_spec, mesh=mesh).global_tensor
    # Explicit None check: truth-testing a multi-element tensor raises.
    if ab is not None:
      ab = xs.enable_manual_sharding(
          ab, partition_spec, mesh=mesh).global_tensor

  # Operands shared by the dq and dkv custom calls. Built once, outside the dq
  # branch, so the dkv call does not depend on the dq branch having run.
  args = [q, k, v]
  if ab is not None:
    args += [ab]
  if segment_ids is not None:
    args += [q_segment_ids_fa, kv_segment_ids_fa]
  args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

  if needs_input_grad[0]:
    payload, _ = trace_pallas(
        _flash_attention_bwd_dq,
        q,
        k,
        v,
        ab,
        segment_ids,
        l,
        m,
        grad_output,
        grad_i,
        block_q_major=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dq"],
                          q.shape[2]),
        block_k_major=min(
            FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dq"],
            k.shape[2]),
        block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"],
                    k.shape[2]),
        sm_scale=sm_scale,
        causal=causal,
        mask_value=FlashAttention.DEFAULT_MASK_VALUE,
        debug=False,
        static_argnames=[
            "block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
            "mask_value", "debug"
        ],
        use_cache=True,
    )

    # dq has q's metadata; when an attention bias is present the kernel also
    # yields an output shaped like ab.
    outputs = [q]
    if ab is not None:
      outputs += [ab]
    dq_grads = torch_xla._XLAC._xla_tpu_custom_call(
        args, payload, [i.shape for i in outputs], [i.dtype for i in outputs])
    grad_q = dq_grads[0]

  if needs_input_grad[1] or needs_input_grad[2]:
    payload, _ = trace_pallas(
        _flash_attention_bwd_dkv,
        q,
        k,
        v,
        ab,
        segment_ids,
        l,
        m,
        grad_output,
        grad_i,
        block_q_major=min(
            FlashAttention.DEFAULT_BLOCK_SIZES["block_q_major_dkv"],
            q.shape[2]),
        block_k_major=min(
            FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dkv"],
            k.shape[2]),
        block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dkv"],
                    k.shape[2]),
        block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"],
                    q.shape[2]),
        sm_scale=sm_scale,
        causal=causal,
        mask_value=FlashAttention.DEFAULT_MASK_VALUE,
        debug=False,
        static_argnames=[
            "block_q_major", "block_k_major", "block_k", "block_q",
            "sm_scale", "causal", "mask_value", "debug"
        ],
        use_cache=True)

    dkv_grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload,
                                                     [k.shape, v.shape],
                                                     [k.dtype, v.dtype])
    if needs_input_grad[1]:
      grad_k = dkv_grads[0]
    if needs_input_grad[2]:
      grad_v = dkv_grads[1]

  # SPMD integration: only unshard gradients that were actually computed.
  if partition_spec is not None:
    if grad_q is not None:
      grad_q = xs.disable_manual_sharding(
          grad_q, partition_spec, q_full_shape, mesh=mesh).global_tensor
    if grad_k is not None:
      grad_k = xs.disable_manual_sharding(
          grad_k, partition_spec, kv_full_shape, mesh=mesh).global_tensor
    if grad_v is not None:
      grad_v = xs.disable_manual_sharding(
          grad_v, partition_spec, kv_full_shape, mesh=mesh).global_tensor

  return grad_q, grad_k, grad_v


@fa_custom_backward.register_fake
def _(grad_o, q, k, v, o, l, m):
  """Fake implementation of ``xla::fa_custom_backward`` used during tracing.

  Each gradient mirrors the tensor it is taken with respect to, matching the
  metadata the real op requests from its kernels: q's shape/dtype for dq and
  k's/v's shape/dtype for dk/dv (rather than grad_o's, which can differ).
  """
  print("Inside fake fa_custom_backward")
  return torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)


class FlashAttention2(torch.autograd.Function):
  """Autograd wrapper that routes FlashAttention through the xla:: custom ops.

  ``forward`` calls ``fa_custom_forward`` and keeps the residuals it returns;
  ``backward`` re-serializes them and calls ``fa_custom_backward``. The intent
  (per this notebook's purpose) is a graph capturable by AOTAutograd/Dynamo.
  """

  @staticmethod
  def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
              partition_spec, mesh):
    # Only q, k and v are passed to the custom op in this prototype.
    results = fa_custom_forward(q, k, v)
    print("forward done with fa_custom_forward")

    attn_out, serialized = results[0], results[1:]
    residuals = ProxyCtx.deserialize(serialized).saved_tensors
    full_q, full_k, full_v, o2, l, m = residuals

    # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
    # but it should be OK as the backward will use the same partition_spec
    ctx.save_for_backward(full_q, full_k, full_v, o2, l, m)
    return attn_out

  @staticmethod
  def backward(ctx, grad_output):
    print("Inside backward")

    serialized = list(map(serialize, ctx.saved_tensors))
    for tensor in (grad_output, *serialized):
      describe_value(tensor)

    tensor_grads = fa_custom_backward(grad_output, *serialized)
    # No gradients for causal, q_segment_ids, kv_segment_ids, sm_scale,
    # ab (not computed here), partition_spec and mesh.
    return tensor_grads + (None,) * 7


def flash_attention_2(
    q,  # [batch_size, num_heads, q_seq_len, d_model]
    k,  # [batch_size, num_heads, kv_seq_len, d_model]
    v,  # [batch_size, num_heads, kv_seq_len, d_model]
    causal=False,
    q_segment_ids=None,  # [batch_size, q_seq_len]
    kv_segment_ids=None,  # [batch_size, kv_seq_len]
    sm_scale=1.0,
    *,
    ab=None,  # [batch_size, num_heads, q_seq_len, kv_seq_len]
    partition_spec=None,
    mesh=None,
):
  """Prototype FlashAttention entry point built on ``FlashAttention2``.

  Forwards every parameter positionally, in the order declared by
  ``FlashAttention2.forward``, and returns the attention output.

  NOTE: ``FlashAttention2.forward`` only passes ``q``, ``k`` and ``v`` to the
  kernel; ``causal``, the segment ids, ``sm_scale``, ``ab``,
  ``partition_spec`` and ``mesh`` are accepted but currently unused.
  """
  # TODO: support SPMD and Dynamo with segment_ids.
  forwarded = (q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
               partition_spec, mesh)
  return FlashAttention2.apply(*forwarded)


In [8]:
# Run the AOTAutograd test harness (defined elsewhere in the notebook) on the
# custom-op based wrapper; the printed lines show which implementations
# (fake vs. real forward/backward) are invoked during tracing.
test_flash_attention_wrapper_with_aot_autograd(flash_attention_2)

Inside fake fa_custom_forward
forward done with fa_custom_forward
Inside fake fa_custom_forward
forward done with fa_custom_forward
Inside backward
<class 'torch._subclasses.functional_tensor.FunctionalTensor'>(torch.Size([3, 2, 128, 4]), dtype=torch.float32, device=xla:0)
<class 'torch._subclasses.functional_tensor.FunctionalTensor'>(torch.Size([3, 2, 128, 4]), dtype=torch.float32, device=xla:0)
<class 'torch._subclasses.functional_tensor.FunctionalTensor'>(torch.Size([3, 2, 128, 4]), dtype=torch.float32, device=xla:0)
<class 'torch._subclasses.functional_tensor.FunctionalTensor'>(torch.Size([3, 2, 128, 4]), dtype=torch.float32, device=xla:0)
<class 'torch._subclasses.functional_tensor.FunctionalTensor'>(torch.Size([3, 2, 128, 4]), dtype=torch.float32, device=xla:0)
<class 'torch._subclasses.functional_tensor.FunctionalTensor'>(torch.Size([3, 2, 128]), dtype=torch.float32, device=xla:0)
<class 'torch._subclasses.functional_tensor.FunctionalTensor'>(torch.Size([3, 2, 128]), dtype=torch

