In [5]:
from flash_attn import flash_attn_with_kvcache
import torch

# Problem sizes. q/k/v are laid out as (batch, seq_len, n_heads, d_model), so d_model is
# the per-head dimension (flash-attn's "headdim"), not the full model width.
bsz = 4
n_heads = 32
q_len = 5
kv_len = 100
d_model = 128

# Seed every device's RNG so the random inputs, and every value printed or asserted below,
# are reproducible on Restart & Run All.
torch.manual_seed(0)

# Allocate straight on the GPU in fp16 instead of fp32-on-CPU -> .half() -> .cuda(),
# which materialised two throwaway copies of each tensor.
q = torch.randn(bsz, q_len, n_heads, d_model, dtype=torch.float16, device="cuda")
k = torch.randn(bsz, kv_len, n_heads, d_model, dtype=torch.float16, device="cuda")
v = torch.randn(bsz, kv_len, n_heads, d_model, dtype=torch.float16, device="cuda")

print(f"q: {q.shape}, k: {k.shape}, v: {v.shape}")

q: torch.Size([4, 5, 32, 128]), k: torch.Size([4, 100, 32, 128]), v: torch.Size([4, 100, 32, 128])


In [10]:
data_len = 100
# One int32 length per batch element, every entry equal to data_len.
# (Exploratory: the next cell overwrites seq_len.)
seq_len = torch.full(size=(bsz,), fill_value=data_len, dtype=torch.int32, device="cuda")

seq_len

tensor([100, 100, 100, 100], device='cuda:0', dtype=torch.int32)

In [13]:
# Random per-sequence lengths, uniform over [10000, 10240) (randint's upper bound is exclusive).
# Drawn from the CPU generator and then copied to the GPU. Exploratory: seq_len is not
# read by any later cell shown here.
seq_len = torch.randint(10000, 10240, (bsz,), dtype=torch.int32).to("cuda")

seq_len

tensor([10070, 10068, 10140, 10007], device='cuda:0', dtype=torch.int32)

In [23]:
# Per-sequence number of valid cached tokens. There must be exactly one entry per batch element,
# and no entry may exceed the cache capacity kv_len, or the attention call would read past k/v.
cache_seqlens = torch.tensor([30, 20, 10, 100], dtype=torch.int32).cuda()
assert cache_seqlens.shape == (bsz,) and int(cache_seqlens.max()) <= kv_len, cache_seqlens
print(f"cache_seqlens: {cache_seqlens}")

import torch

# Number of new tokens each sequence appends to its KV cache in this step.
n_new_tokens = 1
# Offsets 0..n_new_tokens-1, created directly on cache_seqlens' device. torch.arange defaults to
# int64, so the int32 + int64 sum below is promoted to int64, the index dtype scatter_/gather require.
range_tensor = torch.arange(n_new_tokens, device=cache_seqlens.device)

# Broadcast (bsz, 1) + (n_new_tokens,) -> (bsz, n_new_tokens). Row b holds the cache slots
# cache_seqlens[b] .. cache_seqlens[b] + n_new_tokens - 1 that sequence b's new tokens go to.
# NOTE(review): a sequence whose length already equals kv_len (the last one, 100, as written above)
# gets slot index kv_len, which is out of range for a cache with kv_len slots. Confirm capacity
# before scattering into the cache.
storage_ids_no_loop = cache_seqlens[:, None] + range_tensor

storage_ids_no_loop

cache_seqlens: tensor([ 30,  20,  10, 100], device='cuda:0', dtype=torch.int32)


tensor([[ 30],
        [ 20],
        [ 10],
        [100]], device='cuda:0')

In [7]:
# Out-of-place increment (dtype stays int32); cache_seqlens itself is left unchanged.
cache_seqlens.add(1)

tensor([ 31,  21,  11, 101], device='cuda:0', dtype=torch.int32)

In [3]:
# Batched attention over the KV cache: each sequence b uses only the first cache_seqlens[b]
# cached keys/values (the reference loop below checks this against per-sequence truncated caches).
# softmax_scale is a plain Python float derived from the head dim. The previous value was a float16
# tensor built from a hard-coded 128: it lost precision to fp16 rounding and would silently
# go stale if d_model changed.
attn_output_test = flash_attn_with_kvcache(
    q,
    k_cache=k,
    v_cache=v,
    cache_seqlens=cache_seqlens,
    softmax_scale=d_model ** -0.5,
    causal=True,
)
attn_output_test.shape

torch.Size([4, 5, 32, 128])

In [24]:
# WARNING: this reassigns cache_seqlens (previously the variable lengths [30, 20, 10, 100]).
# On a top-to-bottom re-run, every later cell sees these uniform lengths instead.
cache_seqlens = torch.full(size=(bsz,), fill_value=kv_len, dtype=torch.int32, device="cuda")

In [25]:
# Display the per-sequence cache lengths currently bound to cache_seqlens (int32, one per batch element).
cache_seqlens

tensor([100, 100, 100, 100], device='cuda:0', dtype=torch.int32)

In [12]:
# Check that one batched, variable-length call equals running each sequence on its own,
# against a KV cache truncated to that sequence's length.
# The batched result is recomputed here from the *current* cache_seqlens. The earlier
# attn_output_test may have been produced with a different value (a later cell reassigns
# cache_seqlens), which made the old comparison depend on execution order.
softmax_scale = d_model ** -0.5  # one float scale shared by both sides of the comparison
batched_output = flash_attn_with_kvcache(
    q,
    k_cache=k,
    v_cache=v,
    cache_seqlens=cache_seqlens,
    softmax_scale=softmax_scale,
    causal=True,
)

seqlens = cache_seqlens.tolist()  # plain Python ints: slice bounds without 0-d CUDA tensors
assert len(seqlens) == q.shape[0], (len(seqlens), q.shape)

attn_output_refs = []
for bsz_idx, seqlen in enumerate(seqlens):
    k_truct = k[bsz_idx:bsz_idx + 1, :seqlen]
    v_truct = v[bsz_idx:bsz_idx + 1, :seqlen]
    print(f"k_truct: {k_truct.shape}, v_truct: {v_truct.shape}")
    # No cache_seqlens here: the whole truncated cache is attended to.
    attn_output_ref = flash_attn_with_kvcache(
        q[bsz_idx:bsz_idx + 1],
        k_cache=k_truct,
        v_cache=v_truct,
        softmax_scale=softmax_scale,
        causal=True,
    )
    attn_output_refs.append(attn_output_ref)
    batched_slice = batched_output[bsz_idx:bsz_idx + 1]
    assert torch.allclose(batched_slice, attn_output_ref, atol=1e-3), (
        f"bsz_idx: {bsz_idx}, batched: {batched_slice}, attn_output_ref: {attn_output_ref}"
    )


k_truct: torch.Size([1, 30, 32, 128]), v_truct: torch.Size([1, 30, 32, 128])
k_truct: torch.Size([1, 20, 32, 128]), v_truct: torch.Size([1, 20, 32, 128])
k_truct: torch.Size([1, 10, 32, 128]), v_truct: torch.Size([1, 10, 32, 128])
k_truct: torch.Size([1, 100, 32, 128]), v_truct: torch.Size([1, 100, 32, 128])


In [16]:
# Per-sequence reference for batch element 1. It keeps a leading batch dim of size 1
# because it was computed from q[b:b+1], unlike attn_output_test[1].
attn_output_refs[1]

tensor([[[[-0.2140, -0.6909,  0.8154,  ...,  0.1670,  0.1050,  0.2377],
          [-0.2330, -0.0182, -0.0674,  ...,  0.4292, -0.3501, -0.2179],
          [-0.1499, -0.4426,  0.2878,  ...,  0.6382,  0.6982,  0.1709],
          ...,
          [ 0.4016,  0.0766,  0.1216,  ...,  0.1700, -0.5640, -0.2235],
          [ 0.0381, -0.0160,  0.5967,  ..., -0.0669,  0.2133,  0.0343],
          [-0.2751,  0.0526, -0.2031,  ..., -0.0446,  0.0818,  0.4768]],

         [[ 0.2859, -0.4028, -0.5942,  ...,  0.5195, -1.3359, -0.7524],
          [-0.1333, -0.0883, -0.0278,  ...,  0.0974, -0.2502, -0.1698],
          [-0.1354, -0.2040, -0.0078,  ..., -0.1407,  0.0199,  0.0045],
          ...,
          [ 0.7915,  0.1787, -0.0656,  ...,  0.3474, -0.4114,  0.0555],
          [-0.4763, -0.0390, -0.0581,  ..., -0.2144,  0.3833,  0.0718],
          [-0.0518,  0.3843, -0.5093,  ..., -0.0288, -0.3535,  0.6528]],

         [[ 0.1794, -0.5269,  0.6548,  ...,  0.3735, -0.1522,  0.2505],
          [-0.0624, -0.4580,  

In [17]:
# Batched output for batch element 1, for visual comparison with attn_output_refs[1]
# (same values, one fewer leading dim).
# NOTE(review): only comparable if cache_seqlens was not reassigned between computing
# attn_output_test and building attn_output_refs.
attn_output_test[1]

tensor([[[-0.2140, -0.6909,  0.8154,  ...,  0.1670,  0.1050,  0.2377],
         [-0.2330, -0.0182, -0.0674,  ...,  0.4292, -0.3501, -0.2179],
         [-0.1499, -0.4426,  0.2878,  ...,  0.6382,  0.6982,  0.1709],
         ...,
         [ 0.4016,  0.0766,  0.1216,  ...,  0.1700, -0.5640, -0.2235],
         [ 0.0381, -0.0160,  0.5967,  ..., -0.0669,  0.2133,  0.0343],
         [-0.2751,  0.0526, -0.2031,  ..., -0.0446,  0.0818,  0.4768]],

        [[ 0.2859, -0.4028, -0.5942,  ...,  0.5195, -1.3359, -0.7524],
         [-0.1333, -0.0883, -0.0278,  ...,  0.0974, -0.2502, -0.1698],
         [-0.1354, -0.2040, -0.0078,  ..., -0.1407,  0.0199,  0.0045],
         ...,
         [ 0.7915,  0.1787, -0.0656,  ...,  0.3474, -0.4114,  0.0555],
         [-0.4763, -0.0390, -0.0581,  ..., -0.2144,  0.3833,  0.0718],
         [-0.0518,  0.3843, -0.5093,  ..., -0.0288, -0.3535,  0.6528]],

        [[ 0.1794, -0.5269,  0.6548,  ...,  0.3735, -0.1522,  0.2505],
         [-0.0624, -0.4580,  0.0086,  ..., -0

In [9]:
import torch

# Demo: write a block of new rows into a preallocated cache with a single scatter_ call.
# tensor1 plays the cache (batch=2, 100 slots, 32 heads, 128 features). tensor2 holds 5 new rows
# per batch element. indices[b] lists the slots of tensor1[b] that receive them.
# The demo needs no GPU, so fall back to CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor1 = torch.zeros(2, 100, 32, 128, device=device)  # destination
tensor2 = torch.randn(2, 5, 32, 128, device=device)    # source
indices = torch.tensor([[9, 10, 11, 12, 13], [29, 30, 31, 32, 33]], device=device)

# scatter_ needs an index with as many dims as the source. Broadcast each slot id over the
# trailing axes so it matches tensor2's shape (2, 5, 32, 128).
indices_expanded = indices[:, :, None, None].expand_as(tensor2)

# In place: tensor1[b, indices[b, j], h, d] = tensor2[b, j, h, d]
tensor1.scatter_(1, indices_expanded, tensor2)

# Verify the values, not just shapes: a slice's shape is the same whether or not scatter_ ran.
# The targeted slots must hold tensor2 exactly, and every other slot must still be zero.
assert torch.equal(tensor1.gather(1, indices_expanded), tensor2)
untouched = torch.ones(tensor1.shape[:2], dtype=torch.bool, device=device)
untouched[torch.arange(indices.shape[0], device=device)[:, None], indices] = False
assert torch.count_nonzero(tensor1[untouched]) == 0

print(tensor1.shape)  # torch.Size([2, 100, 32, 128])
print(tensor1[:, 9, :, :].shape)  # torch.Size([2, 32, 128])
print(tensor1[:, 29, :, :].shape)  # torch.Size([2, 32, 128])


torch.Size([2, 100, 32, 128])
torch.Size([2, 32, 128])
torch.Size([2, 32, 128])


In [16]:
# 9:13 selects slots 9..12 (the stop is exclusive), i.e. 4 rows; the 5 scattered rows of batch 0 are 9:14.
tensor1[0, 9:13].size()

torch.Size([4, 32, 128])

In [19]:
# Python-style slicing on a 1-D tensor: start inclusive, stop exclusive, so 1:2 keeps one element.
a = torch.arange(1, 5)  # tensor([1, 2, 3, 4]), int64
a[1:2]

tensor([2])

In [17]:
# Shape of one batch element's source rows: everything after the batch axis.
tensor2.shape[1:]

torch.Size([5, 32, 128])

In [18]:
# Check both batch rows. Only a cell's last expression is echoed, so the first comparison
# used to be computed and silently discarded. scatter_ copies values verbatim, so exact
# equality is the right test here rather than allclose(atol=1e-3).
row0_ok = torch.equal(tensor1[0, 9:14], tensor2[0])    # indices[0] = 9..13
row1_ok = torch.equal(tensor1[1, 29:34], tensor2[1])   # indices[1] = 29..33
assert row0_ok and row1_ok, (row0_ok, row1_ok)
row0_ok and row1_ok

True

In [13]:
# Source rows that were scattered into tensor1; e.g. tensor2[0, 0] should reappear as tensor1[0, 9].
tensor2

tensor([[[[-8.5044e-01, -1.0663e+00, -1.7845e-01,  ..., -7.5510e-01,
            1.0224e+00, -1.2777e+00],
          [ 7.5481e-01,  6.4249e-02, -3.8490e-01,  ..., -6.6143e-01,
           -1.4507e-01,  8.8743e-01],
          [ 1.4145e+00,  5.8131e-01,  1.8325e+00,  ...,  1.2702e+00,
            1.3220e+00, -1.0395e+00],
          ...,
          [-2.2786e+00, -1.2282e+00, -1.8462e+00,  ..., -1.8943e-01,
           -1.6184e+00,  1.3940e+00],
          [ 4.5204e-01,  1.3632e+00,  3.6513e-01,  ..., -5.1590e-01,
           -5.2662e-01,  1.9107e-01],
          [-5.8666e-02, -1.3744e+00,  1.9036e-01,  ..., -9.2207e-01,
           -1.5555e+00,  1.3249e+00]],

         [[ 1.2556e+00,  7.9786e-01,  5.5016e-02,  ..., -9.2669e-01,
            1.7184e+00,  9.5488e-01],
          [-2.8152e-01,  1.4642e+00, -9.3359e-01,  ..., -1.2772e-01,
           -1.6461e-01,  1.1878e-01],
          [-6.4176e-01, -1.9661e+00,  4.0120e-01,  ...,  1.2389e-01,
           -7.7685e-01, -1.5504e+00],
          ...,
     

In [11]:
# Cache slot 9 for both batch elements: batch 0 received tensor2[0, 0] there, batch 1
# (whose rows went to slots 29..33) is still all zeros.
tensor1[:, 9]

tensor([[[-0.8504, -1.0663, -0.1785,  ..., -0.7551,  1.0224, -1.2777],
         [ 0.7548,  0.0642, -0.3849,  ..., -0.6614, -0.1451,  0.8874],
         [ 1.4145,  0.5813,  1.8325,  ...,  1.2702,  1.3220, -1.0395],
         ...,
         [-2.2786, -1.2282, -1.8462,  ..., -0.1894, -1.6184,  1.3940],
         [ 0.4520,  1.3632,  0.3651,  ..., -0.5159, -0.5266,  0.1911],
         [-0.0587, -1.3744,  0.1904,  ..., -0.9221, -1.5555,  1.3249]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
       device='cuda:0')