In [1]:
import torch

def draft_update(k_cache, v_cache, k, v, cache_seqlens, qcache_seqlens):
    # Create indices for scatter operation
    bsz, num_heads, update_len, head_dim = k.shape
    seq_len = k_cache.shape[2]
    
    # Create index tensor for scatter
    indices = torch.arange(update_len, device=k.device)
    indices = indices.view(1, 1, -1, 1)  # [1, 1, update_len, 1]
    indices = indices + qcache_seqlens.view(-1, 1, 1, 1)  # [bsz, 1, update_len, 1]
    indices = indices.expand(bsz, num_heads, update_len, head_dim)
    print(indices)
    
    # Perform scatter operations
    k_cache.scatter_(dim=2, index=indices, src=k)
    v_cache.scatter_(dim=2, index=indices, src=v)
    
    return k_cache, v_cache
    
    

In [None]:

def draft_update(k_cache, v_cache, k, v, cache_seqlens, qcache_seqlens):
    bsz = k_cache.shape[0]
    assert bsz == 1, "Batch size > 1 not supported yet"
    
    # Calculate start index for the current batch
    start_idx = qcache_seqlens[0]  # Using only first batch element since bsz=1
    seq_len = k_cache.shape[2]
    
    # Create a mask for valid positions
    # This avoids data-dependent indexing by operating on the full tensor
    position_indices = torch.arange(self.k_cache.shape[2], device=k_cache.device)
    valid_positions = (position_indices >= start_idx) & (position_indices < start_idx + seq_len)
    
    # Expand mask to match cache dimensions
    mask = valid_positions.view(1, 1, -1, 1).expand_as(self.k_cache[:, :, :, :])
    
    # Create expanded k_cache and v_cache
    expanded_k = k_cache.unsqueeze(2).expand(-1, -1, self.k_cache.shape[2], -1)
    expanded_v = v_cache.unsqueeze(2).expand(-1, -1, self.v_cache.shape[2], -1)
    
    # Use masked operations instead of direct indexing
    self.k_cache = torch.where(mask, expanded_k, self.k_cache)
    self.v_cache = torch.where(mask, expanded_v, self.v_cache)

In [2]:
def test_draft_update():
    # Create test inputs
    batch_size = 2
    num_heads = 4
    seq_len = 8
    head_dim = 16
    
    # Create sample caches
    k_cache = torch.zeros(batch_size, num_heads, seq_len, head_dim)
    v_cache = torch.zeros(batch_size, num_heads, seq_len, head_dim)
    k = torch.randn(batch_size, num_heads, 1, head_dim)
    v = torch.randn(batch_size, num_heads, 1, head_dim)
    
    # Create sample sequence lengths
    cache_seqlens = torch.tensor([6, 4])  # Different lengths for each batch
    qcache_seqlens = torch.tensor([2, 1])  # Different starting points
    
    # Keep copies of original caches
    k_cache_orig = k_cache.clone()
    v_cache_orig = v_cache.clone()
    
    # Run the function
    k_out, v_out = draft_update(k_cache, v_cache,k,v, cache_seqlens, qcache_seqlens)
    
    # Verify shapes haven't changed
    assert k_out.shape == k_cache_orig.shape
    assert v_out.shape == v_cache_orig.shape
    
    # Verify the updates happened at correct positions
    for b in range(batch_size):
        start_idx = qcache_seqlens[b]
        end_idx = qcache_seqlens[b] + k.shape[2]
        
        # Check that values were copied correctly
        assert torch.allclose(k_out[b, :, start_idx:end_idx], k[b])
        assert torch.allclose(v_out[b, :, start_idx:end_idx], v[b])
        
        # Check that values outside update range remain unchanged
        if start_idx > 0:
            assert torch.allclose(k_out[b, :, :start_idx], k_cache_orig[b, :, :start_idx])
            assert torch.allclose(v_out[b, :, :start_idx], v_cache_orig[b, :, :start_idx])
        if end_idx < seq_len:
            assert torch.allclose(k_out[b, :, end_idx:], k_cache_orig[b, :, end_idx:])
            assert torch.allclose(v_out[b, :, end_idx:], v_cache_orig[b, :, end_idx:])
    
    print("All tests passed!")

# Run the test
test_draft_update()


tensor([[[[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],

         [[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],

         [[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],

         [[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]],


        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

         [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

         [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

         [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]])
All tests passed!


In [39]:
bsz = 2
seq_len = 5

mask = torch.arange(seq_len).unsqueeze(0) >= index_list.unsqueeze(1)
mask &= torch.arange(seq_len).unsqueeze(0) < (index_list + a.shape[1]).unsqueeze(1)
mask = torch.tensor(mask)
mask

  mask = torch.tensor(mask)


tensor([[False,  True,  True, False, False],
        [False, False,  True,  True, False]])

tensor([[ 1.5959],
        [ 0.0472],
        [-0.8919],
        [-0.2357]])

In [40]:
A[mask] = a.view(-1,a.shape[-1])

In [41]:
A

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.2664, -1.1244,  0.1311,  1.5670, -0.5227, -0.0628, -0.3672,
           0.7776, -0.6443,  0.5735],
         [ 1.2952, -1.7487,  0.5188, -0.9836,  0.3290,  0.1861,  1.1023,
          -1.0282,  0.1419,  0.3721],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 2.0867, -1.0878, -1.5198, -0.6441,  1.2848,  0.5716, -0.1964,
          -1.4325,  0.4747,  0.7528],
         [ 0.6980, -0.2796,  0.1953,  0.0589,  1.1897,  0.2222,  0.0208,
          -1.0235,  0.4140, -1.9478],

In [43]:
a.shape

torch.Size([2, 2, 10])

In [None]:
A

In [45]:
A = torch.zeros(2, 5, 10)
# a = torch.randn(2, 2, 10)

index_list = torch.tensor([1, 2]) # 2, 2, 5
A[0, 1:1+a.shape[1], :] = a[0] 
A[1, 2:2+a.shape[1], :] = a[1]
A

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.2664, -1.1244,  0.1311,  1.5670, -0.5227, -0.0628, -0.3672,
           0.7776, -0.6443,  0.5735],
         [ 1.2952, -1.7487,  0.5188, -0.9836,  0.3290,  0.1861,  1.1023,
          -1.0282,  0.1419,  0.3721],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 2.0867, -1.0878, -1.5198, -0.6441,  1.2848,  0.5716, -0.1964,
          -1.4325,  0.4747,  0.7528],
         [ 0.6980, -0.2796,  0.1953,  0.0589,  1.1897,  0.2222,  0.0208,
          -1.0235,  0.4140, -1.9478],

In [16]:
index=torch.tensor([1, 2]).view(1, 2, 1).repeat(2,1,5)
index.shape

torch.Size([2, 2, 5])

In [19]:
A.scatter_(dim=1, index=index, src=a)
print(A)

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [-0.2241, -1.1241, -1.8280, -0.4621,  0.3573],
         [-0.3164,  0.9304,  0.2137, -0.8889, -1.4511],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [-0.4408,  0.6711,  3.1076, -0.6692,  0.4591],
         [-0.6139,  0.2703,  0.3493, -2.3580, -0.8639],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000, 

In [20]:
A[0]

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2241, -1.1241, -1.8280, -0.4621,  0.3573],
        [-0.3164,  0.9304,  0.2137, -0.8889, -1.4511],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])

In [55]:
x = torch.zeros(5, 3)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
index = torch.tensor([0, 4, 2])
# x.index_copy_(0, index, t)
# tensor([[ 1.,  2.,  3.],
#         [ 0.,  0.,  0.],
#         [ 7.,  8.,  9.],
#         [ 0.,  0.,  0.],
#         [ 4.,  5.,  6.]])


In [57]:
t, t.shape

(tensor([[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]),
 torch.Size([3, 3]))

In [58]:
x, x.shape

(tensor([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]),
 torch.Size([5, 3]))