# In [None]:
import torch
import xformers
from xformers.ops import memory_efficient_attention

print("Torch version:", torch.__version__)
print("xFormers version:", xformers.__version__)

# --------- Step 1: Create Q/K/V tensors ---------

batch = 2
n_heads = 4
seq_len = 16
dim = 32  # per-head embedding size

# The tensors below are allocated on the GPU; fail early with a clear message
# instead of a low-level CUDA error from torch.randn.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for this example (no GPU detected).")
device = "cuda"

# Important: Use the SAME dtype for all Q/K/V
# NOTE(review): bfloat16 needs a GPU with bf16 support; on older hardware
# switch to torch.float16 or torch.float32.
dtype = torch.bfloat16  # or torch.float32

# xFormers expects the layout [batch, seq_len, n_heads, head_dim] ("BMHK"),
# NOT the [batch, n_heads, seq_len, head_dim] layout used by
# torch.nn.functional.scaled_dot_product_attention. With heads in dim 1,
# the kernel would silently attend across heads instead of across tokens.
qkv_shape = (batch, seq_len, n_heads, dim)
query = torch.randn(qkv_shape, dtype=dtype, device=device)
key = torch.randn(qkv_shape, dtype=dtype, device=device)
value = torch.randn(qkv_shape, dtype=dtype, device=device)

print(f"query dtype: {query.dtype}, key dtype: {key.dtype}, value dtype: {value.dtype}")

# --------- Step 2: Check dtype compatibility ---------

if not (query.dtype == key.dtype == value.dtype):
    raise ValueError(f"Mismatch in dtypes! Query: {query.dtype}, Key: {key.dtype}, Value: {value.dtype}")

# --------- Step 3: Call xformers memory-efficient attention ---------

output = memory_efficient_attention(query, key, value)
print("Output shape:", output.shape)

# The output is [batch, seq_len, n_heads, value_head_dim]; value has the same
# head_dim as query here, so the output shape must match the query shape.
if output.shape != query.shape:
    raise RuntimeError(f"Unexpected output shape {tuple(output.shape)}; expected {tuple(query.shape)}")
