In this notebook we test an AOTAutograd-friendly flash attention kernel on
SPMD-sharded inputs, first on its own and then wrapped in a `scan`.

In [1]:
# Runtime library location plus IR/HLO dump settings for this notebook.
# NOTE(review): the TPU library path is machine-specific — adjust before running elsewhere.
%env TPU_LIBRARY_PATH=/workspaces/torch/_libtpu.so
%env XLA_SAVE_TENSORS_FILE=ir_dumps/aot-sharded-flash-attention.txt
%env XLA_SAVE_TENSORS_FMT=text
%env XLA_DEBUG_HLO=1
%env XLA_DEBUG_IR=1
# Directory cleared later by clear_dumps().
%env XLA_FLAGS=--xla_dump_to=xla_dumps/aot-sharded-flash-attention
%env PT_XLA_DEBUG_LEVEL=2

env: TPU_LIBRARY_PATH=/workspaces/torch/_libtpu.so
env: XLA_SAVE_TENSORS_FILE=ir_dumps/aot-sharded-flash-attention.txt
env: XLA_SAVE_TENSORS_FMT=text
env: XLA_DEBUG_HLO=1
env: XLA_DEBUG_IR=1
env: XLA_FLAGS=--xla_dump_to=xla_dumps/aot-sharded-flash-attention
env: PT_XLA_DEBUG_LEVEL=2


In [2]:
import torch
import torch_xla
import torch_xla.runtime
import torch_xla.distributed.spmd as xs

# Switch the runtime into SPMD mode; this is done first, before any device
# query or mesh creation.
torch_xla.runtime.use_spmd()

num_devices = torch_xla.runtime.global_runtime_device_count()
# NOTE(review): exactly 4 devices is rejected (`> 4`, not `>= 4`) — confirm intent.
assert num_devices > 4

# 2-D mesh: (num_devices // 2) devices along 'fsdp', 2 along 'tensor'.
mesh_shape = (num_devices // 2, 2)
spmd_mesh = xs.Mesh(list(range(num_devices)), mesh_shape, ('fsdp', 'tensor'))
# Registered globally; the custom ops below fetch it with xs.get_global_mesh().
xs.set_global_mesh(spmd_mesh)

# Fixed dtype, seed and full matmul precision so the kernel can be compared
# against the reference implementation.
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')



## AOTAutograd-friendly flash attention

In [3]:
# Chooses how fa_custom_forward extracts l and m from the kernel residuals:
#   True  -> unshard the full residual tensors, then take column 0 via
#            permute(3, 0, 1, 2)[0];
#   False -> slice column 0 with `[..., 0]` before unsharding.
# Presumably the True path exists to avoid an as_strided view — verify.
AVOID_AS_STRIDED = False

In [4]:
import torch
import torch_xla
from typing import List
import functools
import os
import warnings

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.debug.metrics as met

from typing import Any, List, Callable, Optional, Tuple, Dict
from torch.library import impl, custom_op
from torch_xla.core.xla_model import XLA_LIB
from torch_xla.experimental.custom_kernel import FlashAttention
import torch_xla.debug.profiler as xp

# True when XLA_USE_BF16=1; convert_torch_dtype_to_jax refuses to run then.
_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"

# Set to True to print tracing/debug information from the custom ops below.
_DEBUG = False


def describe_value(v):
  """Print a one-line summary of `v` for debugging.

  Tensors show type, shape, dtype and device; lists show their length;
  None prints "None"; anything else prints its type.
  """
  if isinstance(v, torch.Tensor):
    summary = f"{type(v)}({v.shape}, dtype={v.dtype}, device={v.device})"
  elif isinstance(v, list):
    summary = f"list({len(v)})"
  elif v is None:
    summary = "None"
  else:
    summary = type(v)
  print(summary)


def _extract_backend_config(
    module: "jaxlib.mlir._mlir_libs._mlir.ir.Module") -> Optional[str]:
  """
  This algorithm intends to extract the backend config from the compiler IR like the following,
  and it is not designed to traverse any generic MLIR module.

  module @jit_add_vectors attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
    func.func public @main(%arg0: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
      %0 = call @add_vectors(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
      return %0 : tensor<8xi32>
    }
    func.func private @add_vectors(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
      %0 = call @wrapped(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
      return %0 : tensor<8xi32>
    }
    func.func private @wrapped(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
      %0 = call @apply_kernel(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
      return %0 : tensor<8xi32>
    }
    func.func private @apply_kernel(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
      %0 = stablehlo.custom_call @tpu_custom_call(%arg0, %arg1) {backend_config = "{\22custom_call_config\22: {\22body\22: \22TUzvUgFNTElSMTkuMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA3lZDQFVBwsPEw8PCw8PMwsLCwtlCwsLCwsPCw8PFw8LFw8PCxcPCxcTCw8LDxcLBQNhBwNZAQ0bBxMPGw8CagMfBRcdKy0DAycpHVMREQsBBRkVMzkVTw8DCxUXGRsfCyELIyUFGwEBBR0NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFHwUhBSMFJQUnEQMBBSkVLw8dDTEXA8IfAR01NwUrFwPWHwEVO0EdPT8FLRcD9h8BHUNFBS8XA3InAQMDSVcFMR1NEQUzHQ1RFwPGHwEFNSN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACNhcml0aC5vdmVyZmxvdzxub25lPgAXVQMhBx0DJwMhBwECAgUHAQEBAQECBASpBQEQAQcDAQUDEQETBwMVJwcBAQEBAQEHAwUHAwMLBgUDBQUBBwcDBQcDAwsGBQMFBQMLCQdLRwMFBQkNBwMJBwMDCwYJAwUFBRENBAkHDwURBQABBgMBBQEAxgg32wsdE2EZ2Q0LEyMhHSknaw0LCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAc3RvcmUAL3dvcmtzcGFjZXMvd29yay9weXRvcmNoL3hsYS90ZXN0L3Rlc3Rfb3BlcmF0aW9ucy5weQBhZGRfdmVjdG9yc19rZXJuZWwAZGltZW5zaW9uX3NlbWFudGljcwBmdW5jdGlvbl90eXBlAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAHN5bV9uYW1lAG1haW4AdmFsdWUAL2dldFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoQ3VzdG9tTm9kZShTbGljZVsoMCwgOCldLCBbXSksKSksICg4LCksICgpKV0sIFtdKSwpKV0AYWRkX3ZlY3RvcnMAdGVzdF90cHVfY3VzdG9tX2NhbGxfcGFsbGFzX2V4dHJhY3RfYWRkX3BheWxvYWQAPG1vZHVsZT4Ab3ZlcmZsb3dGbGFncwAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\22, \22needs_layout_passes\22: true}}", kernel_name = "add_vectors_kernel", operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
      return %0 : tensor<8xi32>
    }
  }

  Basically, what we are looking for is a two level of operations, and the tpu_custom_call operation in the inner level. It will return None if the payload is not found.
  """
  for operation in module.body.operations:
    assert len(
        operation.body.blocks) == 1, "The passing module is not compatible."
    for op in operation.body.blocks[0].operations:
      if op.name == "stablehlo.custom_call":
        return op.backend_config.value
  return None


def jax_import_guard():
  """Initialize the PyTorch/XLA computation client before JAX is imported.

  PyTorch/XLA has to grab the TPU before JAX locks it; otherwise any later
  PyTorch/XLA TPU operation hangs.
  """
  xla_bindings = torch_xla._XLAC
  xla_bindings._init_computation_client()


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
  """Map a torch dtype to the matching jax.numpy dtype.

  JAX is imported lazily, after `jax_import_guard()`, so that this notebook
  does not import JAX at global scope (which could cause problems for
  xmp.spawn).

  Raises:
    RuntimeError: if XLA_USE_BF16 is enabled (not supported by Pallas kernels).
    ValueError: if `dtype` has no supported JAX counterpart.
  """
  jax_import_guard()
  import jax.numpy as jnp

  if _XLA_USE_BF16:
    raise RuntimeError(
        "Pallas kernel does not support XLA_USE_BF16, please unset the env var")

  dtype_pairs = (
      (torch.float32, jnp.float32),
      (torch.float64, jnp.float64),
      (torch.float16, jnp.float16),
      (torch.bfloat16, jnp.bfloat16),
      (torch.int32, jnp.int32),
      (torch.int64, jnp.int64),
      (torch.int16, jnp.int16),
      (torch.int8, jnp.int8),
      (torch.uint8, jnp.uint8),
  )
  for torch_dtype, jax_dtype in dtype_pairs:
    if dtype == torch_dtype:
      return jax_dtype
  raise ValueError(f"Unsupported dtype: {dtype}")


def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct":
  """Describe `tensor` as a storage-free `jax.ShapeDtypeStruct` (shape + dtype).

  JAX is imported lazily after `jax_import_guard()`, to avoid a global-scope
  JAX import that could cause problems for xmp.spawn.
  """
  jax_import_guard()
  import jax

  jax_dtype = convert_torch_dtype_to_jax(tensor.dtype)
  return jax.ShapeDtypeStruct(tensor.shape, jax_dtype)


# Cache used by trace_pallas(use_cache=True): maps a hashable description of a
# traced call to the backend-config payload string extracted for it.
trace_pallas_arg_to_payload: Dict[Tuple[Any], str] = {}


def trace_pallas(kernel: Callable,
                 *args,
                 static_argnums=None,
                 static_argnames=None,
                 use_cache=False,
                 **kwargs):
  """Lower a Pallas kernel with JAX and return its TPU custom-call payload.

  Torch tensors in `args` are traced as storage-free `jax.ShapeDtypeStruct`s.
  All other positional args, and every kwarg, are forwarded to tracing
  unchanged.

  Args:
    kernel: the JAX/Pallas function to trace.
    *args: positional args for `kernel`.
    static_argnums: forwarded to `jax.jit`.
    static_argnames: forwarded to `jax.jit`.
    use_cache: if True, memoize the payload in `trace_pallas_arg_to_payload`.
      The key is the JAX default matmul precision, `kernel`, the static args,
      the traced args and the repr of the sorted kwargs. All traced args must
      therefore be hashable, and kwargs must not be tensors.
    **kwargs: keyword args for `kernel`; used for tracing only.

  Returns:
    (payload, tensor_args): the backend-config string of the lowered kernel,
    and the torch tensors from `args` in order (for executing the custom call).

  Raises:
    RuntimeError: if the lowered module contains no custom call to extract.
  """
  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope which could cause problems for xmp.spawn.
  jax_import_guard()
  import jax
  # Presumably imported for its side effect of registering the Mosaic lowering
  # of pallas_call — verify before removing.
  import jax._src.pallas.mosaic.pallas_call_registration

  jax_args = []  # for tracing
  tensor_args = []  # for execution
  for arg in args:
    # TODO: Could the args be a tuple of tensors or a list of tensors? Flatten them?
    if torch.is_tensor(arg):
      # ShapeDtypeStruct doesn't have any storage and thus is very suitable for generating the payload.
      jax_args.append(to_jax_shape_dtype_struct(arg))
      tensor_args.append(arg)
    else:
      jax_args.append(arg)

  hash_key = ()
  if use_cache:
    # Implicit assumption here that everything in kwargs is hashable and not a tensor,
    # which is true for the gmm and tgmm.
    hash_key = (jax.config.jax_default_matmul_precision, kernel, static_argnums,
                tuple(static_argnames)
                if static_argnames is not None else static_argnames,
                tuple(jax_args), repr(sorted(kwargs.items())).encode())
    if hash_key in trace_pallas_arg_to_payload:
      torch_xla._XLAC._xla_increment_counter('trace_pallas_cache_hit', 1)
      return trace_pallas_arg_to_payload[hash_key], tensor_args

  # Here we ignore the kwargs for execution as most of the time, the kwargs is only used in traced code.
  ir = jax.jit(
      kernel, static_argnums=static_argnums,
      static_argnames=static_argnames).lower(*jax_args, **kwargs).compiler_ir()
  payload = _extract_backend_config(ir)
  if payload is None:
    # Fail here with a clear message instead of caching None and passing it on
    # as the custom-call payload.
    raise RuntimeError(
        f"No custom call found in the lowered module of {kernel!r}; "
        "cannot extract a Pallas payload.")

  if use_cache:
    # If we reach here, it means we had a cache miss.
    trace_pallas_arg_to_payload[hash_key] = payload

  return payload, tensor_args


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
  """Wrap a Pallas kernel as a callable that executes it via a TPU custom call.

  Args:
    kernel: the JAX/Pallas function to trace (see `trace_pallas`).
    output_shape_dtype_fn: called with the wrapper's positional args; must
      return a list of `(shape, dtype)` pairs, one per kernel output.

  Returns:
    A callable `f(*args, static_argnums=None, static_argnames=None, **kwargs)`.
    It traces `kernel` for `args`, runs the payload on the tensor args, and
    returns the single output tensor, or a tuple of tensors when there are
    several outputs.
  """

  # TODO: Maybe we can cache the payload for the same input.
  def wrapped_kernel(kernel: Callable,
                     output_shape_dtype_fn: Callable,
                     *args,
                     static_argnums=None,
                     static_argnames=None,
                     **kwargs) -> Any:
    # Returns a tensor (one output) or a tuple of tensors, not a Callable.
    payload, tensor_args = trace_pallas(
        kernel,
        *args,
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        **kwargs)
    output_shape_dtype = output_shape_dtype_fn(*args)
    assert isinstance(output_shape_dtype,
                      list), "The output_shape_dtype_fn should return a list."
    output_shapes = [shape for shape, _ in output_shape_dtype]
    output_dtypes = [dtype for _, dtype in output_shape_dtype]
    outputs = torch_xla._XLAC._xla_tpu_custom_call(tensor_args, payload,
                                                   output_shapes, output_dtypes)

    # Make the output easier to use.
    if len(outputs) == 1:
      return outputs[0]
    return tuple(outputs)

  return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)


def defeat_alias(v):
  """Return a freshly computed value equal to `v` (multiplied by 1).

  Used throughout this notebook to break storage aliasing between a value and
  its source before it is returned or sharded.
  """
  fresh = v * 1
  return fresh


# Note: the alias inference and mutation removal in PyTorch doesn't work. So we
#
# - Explicitly clone all inputs.
# - Clone outputs if the output aliases an input.
#
@custom_op("xla::fa_custom_forward", mutates_args=())
def fa_custom_forward(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor]:
  """Sharded flash-attention forward pass as an opaque custom op.

  q, k and v are manually sharded with partition spec
  ('fsdp', 'tensor', None, None) on the global mesh. The Pallas
  `_flash_attention_impl` kernel then runs per shard with save_residuals=True,
  and the results are converted back from manual sharding to global shapes.

  Returns:
    (o, full_q, full_k, full_v, l, m):
      - the attention output;
      - clones of the unsharded inputs, saved for the backward pass;
      - the float32 l and m residuals, of shape q.shape[:3].
  """
  partition_spec = ('fsdp', 'tensor', None, None)
  mesh = xs.get_global_mesh()
  assert mesh is not None

  if _DEBUG:
    print("Inside fa_custom_forward")
    for t in [q, k, v]:
      describe_value(t)

  # Hard-coded configuration: no segment ids, no attention bias, no scaling,
  # non-causal.
  q_segment_ids = kv_segment_ids = ab = None
  sm_scale = 1.0
  causal = False

  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope which could cause problems for xmp.spawn.
  jax_import_guard()
  import jax
  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

  q_full_shape = None
  # The backward pass needs l and m, so always ask the kernel for them.
  save_residuals = True

  # Work on fresh copies so nothing returned below aliases the op's inputs.
  q = defeat_alias(q)
  k = defeat_alias(k)
  v = defeat_alias(v)

  # SPMD integration.
  # mark_sharding is in-place, and therefore we save the full q, k, v for the backward.
  # PyTorch tells us clone is necessary:
  #
  # RuntimeError: Found a custom (non-ATen) operator whose output has alias
  # annotations: xla::fa_custom_forward(Tensor(a0!) q, Tensor(a1!) k,
  # Tensor(a2!) v) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor). We only
  # support functionalizing operators whose outputs do not have alias
  # annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas
  # 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias
  # annotation specifies that the output Tensor shares storage with an input
  # that has the same annotation. Please check if (1) the output needs to be an
  # output (if not, don't return it), (2) if the output doesn't share storage
  # with any inputs, then delete the alias annotation. (3) if the output indeed
  # shares storage with an input, then add a .clone() before returning it to
  # prevent storage sharing and then delete the alias annotation. Otherwise,
  # please file an issue on GitHub.
  #
  with xp.Trace('shard'):
    full_q = q.clone()
    full_k = k.clone()
    full_v = v.clone()
    if partition_spec is not None:
      q_full_shape = q.shape
      q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
      k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
      v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
      # `is not None`: truth-testing a multi-element tensor raises.
      if ab is not None:
        ab = xs.enable_manual_sharding(
            ab, partition_spec, mesh=mesh).global_tensor

  # It computes the shape and type of o, l, m.
  shapes = [q.shape]
  dtypes = [q.dtype]
  if save_residuals:
    res_shape = list(q.shape)
    res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE
    for _ in range(2):
      shapes.append(res_shape)
      dtypes.append(torch.float32)

  with torch.no_grad():
    if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None:
      # partition_spec is for q,k,v with shape [batch, num_head, seq_len, head_dim], segment id
      # is of shape [batch, seq_len], hence we need to tweak it a bit
      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
      q_segment_ids = xs.enable_manual_sharding(
          q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
      kv_segment_ids = xs.enable_manual_sharding(
          kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
        q_segment_ids, kv_segment_ids)

    with xp.Trace('pallas'):
      # We can't directly use flash_attention as we need to override the save_residuals flag which returns
      # l and m that is needed for the backward. Then we lose all the shape checks.
      # TODO: replicate the shape checks on flash_attention.
      # Here we separate the tracing and execution part just to support SegmentIds.
      payload, _ = trace_pallas(
          _flash_attention_impl,
          q,
          k,
          v,
          ab,
          segment_ids,
          save_residuals,
          causal,
          sm_scale,
          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
          min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
          False,
          static_argnums=range(5, 13),
          use_cache=True,
      )

    with xp.Trace('custom_call'):
      args = [q, k, v]
      if ab is not None:
        args += [ab]
      if segment_ids is not None:
        args += [q_segment_ids_fa, kv_segment_ids_fa]
      o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

    if not save_residuals:
      o = o[0]
      # SPMD integration
      if partition_spec is not None:
        o = xs.disable_manual_sharding(
            o, partition_spec, q_full_shape, mesh=mesh).global_tensor
      return o

    assert isinstance(o, list)
    o, *aux = o

    # Instrumentation for this investigation: reset ALL PyTorch/XLA metrics so
    # the report printed below only covers the l/m indexing.
    print("About to index into aux to get l, m")
    met.clear_all()

    if AVOID_AS_STRIDED:
      l = aux[-2]
      m = aux[-1]

      print("Done indexing into aux. Metrics:")
      print(met.metrics_report())

      # SPMD integration
      with xp.Trace('index_lm'):
        if partition_spec is not None:
          o = xs.disable_manual_sharding(
              o, partition_spec, q_full_shape, mesh=mesh).global_tensor
          l = xs.disable_manual_sharding(
              l, partition_spec, q_full_shape[0:3] + (l.shape[-1], ),
              mesh=mesh).global_tensor
          m = xs.disable_manual_sharding(
              m, partition_spec, q_full_shape[0:3] + (m.shape[-1], ),
              mesh=mesh).global_tensor

        # Keep column 0 of the MIN_BLOCK_SIZE-wide residuals (a permute plus an
        # integer index instead of `[..., 0]`). Done whether or not the
        # residuals were sharded, so l and m always have q's leading 3 dims.
        l = l.permute(3, 0, 1, 2)[0]
        m = m.permute(3, 0, 1, 2)[0]
    else:
      # Keep column 0 of each MIN_BLOCK_SIZE-wide residual before unsharding.
      l, m = (res[..., 0] for res in aux[-2:])

      print("Done indexing into aux. Metrics:")
      print(met.metrics_report())

      # SPMD integration
      with xp.Trace('index_lm'):
        if partition_spec is not None:
          o = xs.disable_manual_sharding(
              o, partition_spec, q_full_shape, mesh=mesh).global_tensor
          l = xs.disable_manual_sharding(
              l, partition_spec[0:3], q_full_shape[0:3],
              mesh=mesh).global_tensor
          m = xs.disable_manual_sharding(
              m, partition_spec[0:3], q_full_shape[0:3],
              mesh=mesh).global_tensor

  assert partition_spec is not None

  # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
  # but it should be OK as the backward will use the same partition_spec
  outs = [o] + [full_q, full_k, full_v, l, m]
  if _DEBUG:
    print("Outs")
    for t in outs:
      describe_value(t)
  return tuple(outs)


@fa_custom_forward.register_fake
def fa_custom_forward_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
  """Fake implementation of fa_custom_forward, used while tracing.

  Returns freshly allocated tensors shaped like the real op's outputs, in the
  same order: (o, full_q, full_k, full_v, l, m).
  """
  if _DEBUG:
    print("Inside fake fa_custom_forward")

  assert q.shape == k.shape
  assert k.shape == v.shape

  # The real op gives o q's shape and dtype. l and m are float32 and drop the
  # last axis of q.
  o = torch.empty_like(q)
  full_q = torch.empty_like(q)
  full_k = torch.empty_like(k)
  full_v = torch.empty_like(v)
  l = q.new_empty(q.shape[:-1], dtype=torch.float32)
  m = q.new_empty(q.shape[:-1], dtype=torch.float32)
  return o, full_q, full_k, full_v, l, m


@custom_op("xla::fa_custom_backward", mutates_args=())
def fa_custom_backward(
    grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
    v: torch.Tensor, o: torch.Tensor, l: torch.Tensor, m: torch.Tensor,
    q_shape: List[int],
    k_shape: List[int]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Sharded flash-attention backward pass as an opaque custom op.

  Args:
    grad_output: gradient with respect to the attention output `o`.
    q, k, v: the unsharded inputs saved by the forward pass.
    o, l, m: the forward output and its l/m residuals.
    q_shape, k_shape: global shapes of q and k. After manual sharding is
      disabled, the gradients are restored to these shapes.

  Returns:
    (grad_q, grad_k, grad_v), with global shapes q_shape, k_shape and k_shape.
  """
  # Hard-coded configuration: no segment ids and no attention bias.
  q_segment_ids_fa = kv_segment_ids_fa = ab = None

  partition_spec = ('fsdp', 'tensor', None, None)
  mesh = xs.get_global_mesh()
  assert mesh is not None

  if _DEBUG:
    print("Inside fa_custom_backward")

  # Grab the TPU for PyTorch/XLA before importing JAX, as fa_custom_forward does.
  jax_import_guard()
  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

  grad_output = defeat_alias(grad_output)
  saved_tensors = (q, k, v, o, l, m)
  q, k, v, o, l, m = (defeat_alias(t) for t in saved_tensors)

  causal = False
  sm_scale = 1.0
  q_full_shape = torch.Size(q_shape)
  kv_full_shape = torch.Size(k_shape)
  # this segment_ids only reflects the local shape of segment_ids
  segment_ids = None
  grad_q = grad_k = grad_v = None
  needs_input_grad = [True, True, True]
  grad_i = torch.sum(
      o.to(torch.float32) * grad_output.to(torch.float32),
      axis=-1)  # [batch_size, num_heads, q_seq_len]

  # The custom calls below get l, m and grad_i broadcast along a new trailing
  # axis of size MIN_BLOCK_SIZE.
  expanded_l = l.unsqueeze(-1).expand([-1 for _ in l.shape] +
                                      [FlashAttention.MIN_BLOCK_SIZE])
  expanded_m = m.unsqueeze(-1).expand([-1 for _ in m.shape] +
                                      [FlashAttention.MIN_BLOCK_SIZE])
  expanded_grad_i = grad_i.unsqueeze(-1).expand([-1 for _ in grad_i.shape] +
                                                [FlashAttention.MIN_BLOCK_SIZE])

  # SPMD integration
  if partition_spec is not None:
    q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
    k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
    v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
    expanded_l = xs.enable_manual_sharding(
        expanded_l, partition_spec, mesh=mesh).global_tensor
    expanded_m = xs.enable_manual_sharding(
        expanded_m, partition_spec, mesh=mesh).global_tensor
    grad_output = xs.enable_manual_sharding(
        grad_output, partition_spec, mesh=mesh).global_tensor
    expanded_grad_i = xs.enable_manual_sharding(
        expanded_grad_i, partition_spec, mesh=mesh).global_tensor
    # `is not None`: truth-testing a multi-element tensor raises.
    if ab is not None:
      ab = xs.enable_manual_sharding(
          ab, partition_spec, mesh=mesh).global_tensor

  # Operands shared by the dq and dkv custom calls. They are built once here,
  # so the dkv branch does not depend on the dq branch having run.
  args = [q, k, v]
  if ab is not None:
    args += [ab]
  if segment_ids is not None:
    args += [q_segment_ids_fa, kv_segment_ids_fa]
  args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

  if needs_input_grad[0]:
    payload, _ = trace_pallas(
        _flash_attention_bwd_dq,
        q,
        k,
        v,
        ab,
        segment_ids,
        l,
        m,
        grad_output,
        grad_i,
        block_q_major=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dq"],
                          q.shape[2]),
        block_k_major=min(
            FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dq"], k.shape[2]),
        block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"],
                    k.shape[2]),
        sm_scale=sm_scale,
        causal=causal,
        mask_value=FlashAttention.DEFAULT_MASK_VALUE,
        debug=False,
        static_argnames=[
            "block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
            "mask_value", "debug"
        ],
        use_cache=True,
    )

    outputs = [q]
    if ab is not None:
      outputs += [ab]
    grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload,
                                                 [i.shape for i in outputs],
                                                 [i.dtype for i in outputs])
    grad_q = grads[0]

  if needs_input_grad[1] or needs_input_grad[2]:
    payload, _ = trace_pallas(
        _flash_attention_bwd_dkv,
        q,
        k,
        v,
        ab,
        segment_ids,
        l,
        m,
        grad_output,
        grad_i,
        block_q_major=min(
            FlashAttention.DEFAULT_BLOCK_SIZES["block_q_major_dkv"],
            q.shape[2]),
        block_k_major=min(
            FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dkv"],
            k.shape[2]),
        block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dkv"],
                    k.shape[2]),
        block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"],
                    q.shape[2]),
        sm_scale=sm_scale,
        causal=causal,
        mask_value=FlashAttention.DEFAULT_MASK_VALUE,
        debug=False,
        static_argnames=[
            "block_q_major", "block_k_major", "block_k", "block_q", "sm_scale",
            "causal", "mask_value", "debug"
        ],
        use_cache=True)

    grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload,
                                                 [k.shape, v.shape],
                                                 [k.dtype, v.dtype])

    if needs_input_grad[1]:
      grad_k = grads[0]
    if needs_input_grad[2]:
      grad_v = grads[1]

  # SPMD integration
  if partition_spec is not None:
    grad_q = xs.disable_manual_sharding(
        grad_q, partition_spec, q_full_shape, mesh=mesh).global_tensor
    grad_k = xs.disable_manual_sharding(
        grad_k, partition_spec, kv_full_shape, mesh=mesh).global_tensor
    grad_v = xs.disable_manual_sharding(
        grad_v, partition_spec, kv_full_shape, mesh=mesh).global_tensor

  assert partition_spec is not None

  return grad_q, grad_k, grad_v


@fa_custom_backward.register_fake
def fa_custom_backward_fake(grad_output, q, k, v, o, l, m, q_shape, k_shape):
  """Fake implementation of fa_custom_backward, used while tracing.

  Each gradient takes the shape and dtype of the input it differentiates:
  grad_q like q, grad_k like k, grad_v like v. The real op returns the global
  q_shape/k_shape, which are the shapes of the saved full q and k. Using
  grad_output for all three would be wrong whenever the kv length differs from
  the q length.
  """
  if _DEBUG:
    print("Inside fake fa_custom_backward")
  return torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)


class FlashAttention2(torch.autograd.Function):
  """Autograd function that routes forward/backward through the custom ops.

  forward calls xla::fa_custom_forward. It saves the unsharded inputs, the
  output and the l/m residuals. backward passes these, together with the
  original q/k shapes, to xla::fa_custom_backward.
  """

  @staticmethod
  def forward(ctx, q, k, v):
    with torch.no_grad():
      ctx.q_shape, ctx.k_shape = q.shape, k.shape

      o, full_q, full_k, full_v, l, m = torch.ops.xla.fa_custom_forward(q, k, v)
      if _DEBUG:
        print("forward done with fa_custom_forward")

      # The backward recomputes with the full (unsharded) inputs and residuals.
      ctx.save_for_backward(full_q, full_k, full_v, o, l, m)
      return o

  @staticmethod
  def backward(ctx, grad_output):
    with torch.no_grad():
      if _DEBUG:
        print("Inside backward")

      saved = list(ctx.saved_tensors)
      if _DEBUG:
        for t in [grad_output, *saved]:
          describe_value(t)

      return torch.ops.xla.fa_custom_backward(
          grad_output, *saved, list(ctx.q_shape), list(ctx.k_shape))


def flash_attention_2(
    q,  # [batch_size, num_heads, q_seq_len, d_model]
    k,  # [batch_size, num_heads, kv_seq_len, d_model]
    v,  # [batch_size, num_heads, kv_seq_len, d_model]
    causal=False,
    q_segment_ids=None,  # [batch_size, q_seq_len]
    kv_segment_ids=None,
    sm_scale=1.0,
    *,
    ab=None,  # [batch_size, num_heads, q_seq_len, kv_seq_len]
    partition_spec=None,
    mesh=None,
):
  """Flash attention through the AOTAutograd-friendly custom ops.

  Only the configuration hard-coded in fa_custom_forward/fa_custom_backward is
  supported:
    - non-causal attention;
    - sm_scale=1.0;
    - no segment ids;
    - no attention bias;
    - partition_spec ('fsdp', 'tensor', None, None).
  Any other value is rejected rather than silently ignored, which would
  produce wrong results.

  `mesh` is accepted for signature compatibility only; the ops always use
  `xs.get_global_mesh()`.
  """
  assert not causal, "causal must be False"
  assert sm_scale == 1.0, "sm_scale must be 1.0"
  assert q_segment_ids is None and kv_segment_ids is None, (
      "segment ids are not supported")
  assert ab is None, "ab is not supported"
  assert partition_spec == ('fsdp', 'tensor', None, None)
  return FlashAttention2.apply(q, k, v)


## Run a test with the kernel

In [5]:
jax_import_guard()
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import torch.nn as nn


def _attention(q, k, v, *, attn_mask=None, ab=None):
  attn_weight = q @ k.transpose(-2, -1)
  if attn_mask is not None:
    # Masked out the unrelevant parts.
    attn_weight = attn_weight.masked_fill(attn_mask,
                                          torch.finfo(attn_weight.dtype).min)
  if ab is not None:
    attn_weight = attn_weight + ab
  attn_weight = nn.functional.softmax(attn_weight, dim=-1)
  attn_output = attn_weight @ v
  return attn_output


def do_test(attn_fn, aot_autograd: bool):
  """Check `attn_fn` against the `_attention` reference, forward and backward.

  Steps:
    1. Create random q/k/v of shape (16, 2, 128, 8) on XLA, sharded with
       ('fsdp', 'tensor', None, None).
    2. Run `attn_fn` eagerly, or through `aot_function` when `aot_autograd` is
       True (each compiled graph is printed).
    3. Backpropagate `sum()` of the output, and do the same for detached
       clones through `_attention`.
    4. Assert that the outputs and the q/k/v gradients match.
  """
  from functorch.compile import aot_function, make_boxed_func  # type: ignore

  spec = ('fsdp', 'tensor', None, None)

  def flash_attention_wrapper(q, k, v):
    return attn_fn(q, k, v, partition_spec=spec)

  inputs = [
      torch.randn(16, 2, 128, 8).to("xla").clone().detach().requires_grad_(True)
      for _ in range(3)
  ]
  mesh = xs.get_global_mesh()
  for t in inputs:
    xs.mark_sharding(t, mesh, spec)
  q, k, v = inputs

  references = [t.clone().detach().requires_grad_(True) for t in inputs]
  q_clone, k_clone, v_clone = references

  def compiler(gm, _):
    print("Got graph:")
    print(gm.code)
    return make_boxed_func(gm)

  if aot_autograd:
    run = aot_function(flash_attention_wrapper, fw_compiler=compiler)
  else:
    run = flash_attention_wrapper
  o_actual = run(q, k, v)
  o_actual.sum().backward()

  expected_o = _attention(q_clone, k_clone, v_clone)
  expected_o.sum().backward()

  print("Executing graph", flush=True)
  import time
  # Short pause after the flushed message before the (possibly slow) sync.
  # NOTE(review): purpose not documented — confirm it is still needed.
  time.sleep(0.1)
  torch_xla.sync(wait=True)

  torch.testing.assert_close(o_actual.cpu(), expected_o.cpu())
  for actual, expected in ((q, q_clone), (k, k_clone), (v, v_clone)):
    assert actual.grad is not None and expected.grad is not None, f"{actual.grad}, {expected.grad}"
    torch.testing.assert_close(actual.grad.cpu(), expected.grad.cpu())


def test_flash_attention_wrapper(attn_fn):
  """Run do_test eagerly (no AOTAutograd) at 'highest' JAX matmul precision.

  The caller's jax_default_matmul_precision is restored afterwards, even if
  the test fails, instead of being overwritten with "default".
  """
  previous_precision = jax.config.jax_default_matmul_precision
  jax.config.update("jax_default_matmul_precision", "highest")
  try:
    do_test(attn_fn, aot_autograd=False)
  finally:
    jax.config.update("jax_default_matmul_precision", previous_precision)


def test_flash_attention_wrapper_with_aot_autograd(attn_fn):
  """Run do_test through AOTAutograd at 'highest' JAX matmul precision.

  The caller's jax_default_matmul_precision is restored afterwards, even if
  the test fails, instead of being overwritten with "default".
  """
  previous_precision = jax.config.jax_default_matmul_precision
  jax.config.update("jax_default_matmul_precision", "highest")
  try:
    do_test(attn_fn, aot_autograd=True)
  finally:
    jax.config.update("jax_default_matmul_precision", previous_precision)


In [6]:
# Forward + backward through AOTAutograd; the printed graph shows the custom forward op as a single node.
test_flash_attention_wrapper_with_aot_autograd(flash_attention_2)

Got graph:



def forward(self, primals_1, primals_2, primals_3):
    fa_custom_forward = torch.ops.xla.fa_custom_forward.default(primals_1, primals_2, primals_3);  primals_1 = primals_2 = primals_3 = None
    getitem = fa_custom_forward[0]
    getitem_1 = fa_custom_forward[1]
    getitem_2 = fa_custom_forward[2]
    getitem_3 = fa_custom_forward[3]
    getitem_4 = fa_custom_forward[4]
    getitem_5 = fa_custom_forward[5];  fa_custom_forward = None
    return (getitem, getitem, getitem_1, getitem_2, getitem_3, getitem_4, getitem_5)
    
About to index into aux to get l, m
Done indexing into aux. Metrics:
Metric: IrValueTensorToXlaData
  TotalSamples: 2
  Accumulator: 279.130us
  ValueRate: 01s194ms036.874us / second
  Rate: 8555.42 / second
  Percentiles: 1%=113.920us; 5%=113.920us; 10%=113.920us; 20%=113.920us; 50%=165.210us; 80%=165.210us; 90%=165.210us; 95%=165.210us; 99%=165.210us
Metric: LazyTracing
  TotalSamples: 10
  Accumulator: 001ms310.940us
  ValueRate: 03s756ms160.121us / 


Compilation Analysis: Compilation Cause
Compilation Analysis:   most likely user code trying to access tensor value before mark_step
Compilation Analysis: Graph Info: 
Compilation Analysis:   Graph Hash: 3d9ebaf4585054aaca5ace517e577f8a
Compilation Analysis:   Number of Graph Inputs: 1
Compilation Analysis:   Number of Graph Outputs: 1
Compilation Analysis: Python Frame Triggered Execution: 
Compilation Analysis:   mark_sharding (/workspaces/torch/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py:627)
Compilation Analysis:   do_test (/tmp/ipykernel_2190827/803061339.py:31)
Compilation Analysis:   test_flash_attention_wrapper_with_aot_autograd (/tmp/ipykernel_2190827/803061339.py:80)
Compilation Analysis:   <module> (/tmp/ipykernel_2190827/1076877315.py:1)
Compilation Analysis:   run_code (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3577)
Compilation Analysis:   run_ast_nodes (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py

This works, although it transfers two 32 KiB indexing tensors to the device.
What happens if the flash attention is wrapped in a scan?

In [7]:
# Clear the dumps before running the offending graph
import os
from pathlib import Path

def clear_dumps(dump_dir='xla_dumps/aot-sharded-flash-attention'):
  """Delete the regular files directly inside `dump_dir`.

  Subdirectories are left untouched. A missing directory is not an error; there
  is simply nothing to clear yet (e.g. on a fresh checkout before any run has
  written dumps).

  Args:
    dump_dir: the XLA dump directory. Defaults to the `--xla_dump_to` path set
      in XLA_FLAGS at the top of this notebook.
  """
  dump_path = Path(dump_dir)
  if not dump_path.is_dir():
    return
  for entry in dump_path.iterdir():
    if entry.is_file():
      os.remove(entry)

# Remove stale dump files so the directory only gains the next run's dumps.
clear_dumps()

In [8]:
def combine_fn(carry, x):
  """scan step: run flash attention on one (q, k, v) slice; carry passes through.

  Returns (carry, y), where y is this slice's attention output. Both values
  go through defeat_alias so neither aliases an input.
  """
  q_i, k_i, v_i = map(defeat_alias, x)
  y = flash_attention_2(q_i, k_i, v_i, causal=False, partition_spec=('fsdp', 'tensor', None, None))
  assert isinstance(y, torch.Tensor)
  return defeat_alias(carry), defeat_alias(y)

def flash_attention_in_scan(q, k, v, partition_spec=None):
  # Create a leading dim of length 2.
  q = q.reshape(2, q.shape[0] // 2, q.shape[1], q.shape[2], q.shape[3])
  k = k.reshape(2, k.shape[0] // 2, k.shape[1], k.shape[2], k.shape[3])
  v = v.reshape(2, v.shape[0] // 2, v.shape[1], v.shape[2], v.shape[3])
  from torch_xla.experimental.scan import scan
  init = torch.zeros_like(q, requires_grad=True)
  carry, ys = scan(combine_fn, init, (q, k, v))
  assert isinstance(ys, torch.Tensor)
  return ys.reshape(ys.shape[0] * ys.shape[1], ys.shape[2], ys.shape[3], ys.shape[4])

In [None]:
# Run the same test harness on the scan-wrapped variant: `flash_attention_2`
# is called from `combine_fn` inside torch_xla's `scan`. With
# AVOID_AS_STRIDED = False this hits the GSPMD `!IsManual()` assertion
# described in the markdown below.
test_flash_attention_wrapper(flash_attention_in_scan)

About to index into aux to get l, m
Done indexing into aux. Metrics:
Metric: IrValueTensorToXlaData
  TotalSamples: 2
  Accumulator: 348.930us
  ValueRate: 01s250ms704.523us / second
  Rate: 7163.07 / second
  Percentiles: 1%=170.150us; 5%=170.150us; 10%=170.150us; 20%=170.150us; 50%=178.780us; 80%=178.780us; 90%=178.780us; 95%=178.780us; 99%=178.780us
Metric: LazyTracing
  TotalSamples: 10
  Accumulator: 001ms179.870us
  ValueRate: 02s459ms882.127us / second
  Rate: 20840.3 / second
  Percentiles: 1%=006.410us; 5%=006.410us; 10%=014.950us; 20%=019.400us; 50%=079.030us; 80%=244.810us; 90%=351.350us; 95%=351.350us; 99%=351.350us
Metric: TensorToData
  TotalSamples: 2
  Accumulator: 346.200us
  ValueRate: 01s239ms216.809us / second
  Rate: 7158.96 / second
  Percentiles: 1%=168.460us; 5%=168.460us; 10%=168.460us; 20%=168.460us; 50%=177.740us; 80%=177.740us; 90%=177.740us; 95%=177.740us; 99%=177.740us
Counter: CreateXlaTensor
  Value: 4
Counter: DestroyLtcTensor
  Value: 2
Counter: Destro


Execution Analysis: Execution Cause
Execution Analysis:   most likely user code trying to access tensor value before mark_step
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 3d9ebaf4585054aaca5ace517e577f8a
Execution Analysis:   Number of Graph Inputs: 1
Execution Analysis:   Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   mark_sharding (/workspaces/torch/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py:627)
Execution Analysis:   do_test (/tmp/ipykernel_2140329/803061339.py:31)
Execution Analysis:   test_flash_attention_wrapper (/tmp/ipykernel_2140329/803061339.py:72)
Execution Analysis:   <module> (/tmp/ipykernel_2140329/1813840263.py:1)
Execution Analysis:   run_code (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3577)
Execution Analysis:   run_ast_nodes (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3517)
Execution Analysis:   run_cell_async (/

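Here we hit a problem: the GSPMD partitioner fails an `!IsManual()` assertion,
for reasons not yet understood. The assertion appears to take down the kernel
(the cell above never finished, and execution counts restart at `In [1]`), so
the cells below only read the dumps it left on disk.

The relevant IR and HLO are rendered below.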

In [1]:
from pathlib import Path

# Read the lazy-tensor IR dump and show only the text after the final
# "[ScheduleSyncTensorsGraph]" marker, i.e. the last graph recorded in the
# file (the whole file when the marker never appears).
text = Path("ir_dumps/aot-sharded-flash-attention.txt.0").read_text()
print(text.rpartition("[ScheduleSyncTensorsGraph]")[2])


TensorsGraphInfo:
  sync (/workspaces/torch/pytorch/xla/torch_xla/torch_xla.py:69)
  do_test (/tmp/ipykernel_2140329/803061339.py:58)
  test_flash_attention_wrapper (/tmp/ipykernel_2140329/803061339.py:72)
  <module> (/tmp/ipykernel_2140329/1813840263.py:1)
  run_code (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3577)
  run_ast_nodes (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3517)
  run_cell_async (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3334)
  _pseudo_sync_runner (/root/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py:128)
  _run_cell (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3130)
  run_cell (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3075)
  run_cell (/root/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py:549)
  do_execute (/root/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py:449)


In [3]:
import os
from pathlib import Path

# Print one pre-optimization HLO module from the XLA dump directory.
# `os.listdir` returns entries in arbitrary order, so the candidates are
# sorted by name to make the choice reproducible, and the file name is
# printed so the reader knows which compiled module is shown.
dump_dir = Path('xla_dumps/aot-sharded-flash-attention')
hlo_files = sorted(
    p for p in dump_dir.iterdir()
    if p.is_file() and "before_optimizations" in p.stem
) if dump_dir.is_dir() else []
if hlo_files:
  print(f"# {hlo_files[0].name}")
  print(hlo_files[0].read_text())
else:
  print(f"No before_optimizations dumps found in {dump_dir}")

HloModule SyncTensorsGraph.644, entry_computation_layout={(f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, s64[2,1,128]{2,1,0:T(1,128)}, s64[2,1,128]{2,1,0:T(1,128)})->(f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, /*index=5*/f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, /*index=10*/f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[32,128,8]{1,2,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, /*index=15*/f32[32,128,8]{1,2,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[32,8,128]{2,1,0:T(8,128)}, f32[32,128,8]{1,2,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, /*index=20*/f32[16,2,128,8]{2,3,1,0:T

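Replacing `as_strided` with an operation that does not eagerly transfer index
tensors to the device (`permute` in this notebook) avoids this crash. Compare the
entry computation layouts: the crashing module above takes two extra
`s64[2,1,128]` parameters, which are absent from the module dumped below.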

The numbers are still wrong (see the mismatch reported below), possibly due to a
separate bug, but at least the GSPMD sharding annotations now propagate without
crashing.

In [9]:
# Switch from `as_strided` to the `permute` rewrite described in the markdown
# above. NOTE(review): this mutates the module-level flag defined at the top of
# the notebook, which the flash attention code presumably reads at trace time.
# Every cell run after this one therefore also sees True.
AVOID_AS_STRIDED = True

# Start from an empty dump directory so the inspection cells below only see
# modules compiled by this run.
clear_dumps()

test_flash_attention_wrapper(flash_attention_in_scan)

About to index into aux to get l, m
Done indexing into aux. Metrics:

About to index into aux to get l, m
Done indexing into aux. Metrics:

Executing graph



Execution Analysis: Execution Cause
Execution Analysis:   most likely user code trying to access tensor value before mark_step
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 3d9ebaf4585054aaca5ace517e577f8a
Execution Analysis:   Number of Graph Inputs: 1
Execution Analysis:   Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   mark_sharding (/workspaces/torch/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py:627)
Execution Analysis:   do_test (/tmp/ipykernel_2190827/803061339.py:31)
Execution Analysis:   test_flash_attention_wrapper (/tmp/ipykernel_2190827/803061339.py:72)
Execution Analysis:   <module> (/tmp/ipykernel_2190827/3180461044.py:5)
Execution Analysis:   run_code (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3577)
Execution Analysis:   run_ast_nodes (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3517)
Execution Analysis:   run_cell_async (/


Compilation Analysis: Compilation Cause
Compilation Analysis:   user mark_step
Compilation Analysis: Graph Info: 
Compilation Analysis:   Graph Hash: 4b3ed403f54b79f0c84b4ac4c47bbdad
Compilation Analysis:   Number of Graph Inputs: 3
Compilation Analysis:   Number of Graph Outputs: 21
Compilation Analysis: Python Frame Triggered Execution: 
Compilation Analysis:   sync (/workspaces/torch/pytorch/xla/torch_xla/torch_xla.py:69)
Compilation Analysis:   do_test (/tmp/ipykernel_2190827/803061339.py:58)
Compilation Analysis:   test_flash_attention_wrapper (/tmp/ipykernel_2190827/803061339.py:72)
Compilation Analysis:   <module> (/tmp/ipykernel_2190827/3180461044.py:5)
Compilation Analysis:   run_code (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3577)
Compilation Analysis:   run_ast_nodes (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3517)
Compilation Analysis:   run_cell_async (/root/.local/lib/python3.10/site-packages/IPython/core

AssertionError: Tensor-likes are not close!

Mismatched elements: 32752 / 32768 (100.0%)
Greatest absolute difference: 12691.0400390625 at index (5, 0, 58, 4) (up to 1e-05 allowed)
Greatest relative difference: 80902.328125 at index (15, 0, 6, 5) (up to 1.3e-06 allowed)

In [10]:
from pathlib import Path

# Read the lazy-tensor IR dump and show only the text after the final
# "[ScheduleSyncTensorsGraph]" marker, i.e. the last graph recorded in the
# file (the whole file when the marker never appears).
text = Path("ir_dumps/aot-sharded-flash-attention.txt.0").read_text()
print(text.rpartition("[ScheduleSyncTensorsGraph]")[2])


TensorsGraphInfo:
  sync (/workspaces/torch/pytorch/xla/torch_xla/torch_xla.py:69)
  do_test (/tmp/ipykernel_2190827/803061339.py:58)
  test_flash_attention_wrapper (/tmp/ipykernel_2190827/803061339.py:72)
  <module> (/tmp/ipykernel_2190827/3180461044.py:5)
  run_code (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3577)
  run_ast_nodes (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3517)
  run_cell_async (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3334)
  _pseudo_sync_runner (/root/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py:128)
  _run_cell (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3130)
  run_cell (/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3075)
  run_cell (/root/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py:549)
  do_execute (/root/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py:449)


In [11]:
import os
from pathlib import Path

# Print one pre-optimization HLO module from the XLA dump directory.
# `os.listdir` returns entries in arbitrary order, so the candidates are
# sorted by name to make the choice reproducible, and the file name is
# printed so the reader knows which compiled module is shown.
dump_dir = Path('xla_dumps/aot-sharded-flash-attention')
hlo_files = sorted(
    p for p in dump_dir.iterdir()
    if p.is_file() and "before_optimizations" in p.stem
) if dump_dir.is_dir() else []
if hlo_files:
  print(f"# {hlo_files[0].name}")
  print(hlo_files[0].read_text())
else:
  print(f"No before_optimizations dumps found in {dump_dir}")

HloModule SyncTensorsGraph.614, entry_computation_layout={(f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)})->(f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, /*index=5*/f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, f32[2,8,2,128,8]{3,4,2,1,0:T(8,128)}, /*index=10*/f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[32,128,8]{1,2,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, /*index=15*/f32[32,128,8]{1,2,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, f32[32,8,128]{2,1,0:T(8,128)}, f32[32,128,8]{1,2,0:T(8,128)}, f32[16,2,128,8]{2,3,1,0:T(8,128)}, /*index=20*/f32[16,2,128,8]{2,3,1,0:T(8,128)})}, allow_spmd_sharding_propagation_to_output={true}