In [1]:
import torch

### A Concrete Example

In [9]:
torch.manual_seed(11)

# Dimensions of the toy KV cache
batch_size, num_kv_heads = 1, 2
seq_len_cache, seq_len_curr = 3, 1
head_dim = 4

# --- 1. Simulate the Cache (Past K values) ---
# cache_k holds the K vectors of 3 past tokens (p=0, 1, 2) for each of the
# 2 KV heads, i.e. shape [1, 2, 3, 4].
# The values are chosen so that every vector is easy to trace afterwards:
#   Head 0 -> 0..11    (token p holds 4*p .. 4*p + 3)
#   Head 1 -> 100..111 (token p holds 100 + 4*p .. 100 + 4*p + 3)
cache_k = torch.zeros(batch_size, num_kv_heads, seq_len_cache, head_dim)
for head in range(num_kv_heads):
    for token in range(seq_len_cache):
        lo = 100 * head + head_dim * token
        hi = lo + head_dim
        # Write this token's feature vector for this head, then show the whole
        # cache so the newly filled row is visible.
        cache_k[0, head, token, :] = torch.arange(lo, hi)
        print(f"cache_k after filling Head {head} with torch.arange({lo}, {hi}):\n{cache_k}")

cache_k after filling Head 0 with torch.arange(0, 4):
tensor([[[[0., 1., 2., 3.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]])
cache_k after filling Head 0 with torch.arange(4, 8):
tensor([[[[0., 1., 2., 3.],
          [4., 5., 6., 7.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]])
cache_k after filling Head 0 with torch.arange(8, 12):
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],

         [[ 0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.]]]])
cache_k after filling Head 1 with torch.arange(100, 104):
tensor([[[[  0.,   1.,   2.,   3.],
          [  4.,   5.,   6.,   7.],
          [  8.,   9.,  10.,  11.]],

         [[100., 101., 102., 103.],
          [  0.,   0.,   0.,   0.],
          [  0.,   0.,   0.,   0.]]]])
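cache_k after filling Head 1 with torch.arange(104, 108):
tensor([[[[  0.,   1.,   2.,   3.],
          [  4.,   5.,   6.,   7.],
          [  8.,   9.,  10.,  11.]],

         [[100., 101., 102., 103.],
          [104., 105., 106., 107.],
          [  0.,   0.,   0.,   0.]]]])
cache_k after filling Head 1 with torch.arange(108, 112):
tensor([[[[  0.,   1.,   2.,   3.],
          [  4.,   5.,   6.,   7.],
          [  8.,   9.,  10.,  11.]],

         [[100., 101., 102., 103.],
          [104., 105., 106., 107.],
          [108., 109., 110., 111.]]]])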

In [10]:
# Show the finished cache: banner, shape, full contents, then a divider line.
print("--- Cache K ---", f"Shape: {cache_k.shape}", cache_k, "-" * 20, sep="\n")

--- Cache K ---
Shape: torch.Size([1, 2, 3, 4])
tensor([[[[  0.,   1.,   2.,   3.],
          [  4.,   5.,   6.,   7.],
          [  8.,   9.,  10.,  11.]],

         [[100., 101., 102., 103.],
          [104., 105., 106., 107.],
          [108., 109., 110., 111.]]]])
--------------------


In [11]:
# --- 2. Simulate the Current K value ---
# K vector of the *new* token (position p=3) for both KV heads.
# Shape: [1, 2, 1, 4]
#   Head 0 -> 50..53   (features of token 3)
#   Head 1 -> 150..153 (features of token 3)
current_k = torch.zeros(batch_size, num_kv_heads, seq_len_curr, head_dim)
for head, lo, hi in ((0, 50, 54), (1, 150, 154)):
    # Fill the single sequence position of this head, then show the tensor.
    current_k[0, head, 0, :] = torch.arange(lo, hi)
    print(f"current_k after filling Head {head} with torch.arange({lo}, {hi}):\n{current_k}")


current_k after filling Head 0 with torch.arange(50, 54):
tensor([[[[50., 51., 52., 53.]],

         [[ 0.,  0.,  0.,  0.]]]])
current_k after filling Head 1 with torch.arange(150, 154):
tensor([[[[ 50.,  51.,  52.,  53.]],

         [[150., 151., 152., 153.]]]])


In [12]:
# Show the new token's K tensor: banner, shape, full contents, then a divider line.
print("--- Current K ---", f"Shape: {current_k.shape}", current_k, "-" * 20, sep="\n")

--- Current K ---
Shape: torch.Size([1, 2, 1, 4])
tensor([[[[ 50.,  51.,  52.,  53.]],

         [[150., 151., 152., 153.]]]])
--------------------


In [13]:
# --- 3. Perform Concatenation along dim=2 (Sequence Length) ---
# Join the cached K vectors and the new token's K vectors along dim=2.
k_total = torch.cat((cache_k, current_k), dim=2)

print("--- K Total (After Concatenation) ---", f"Shape: {k_total.shape}", k_total, "-" * 20, sep="\n")

--- K Total (After Concatenation) ---
Shape: torch.Size([1, 2, 4, 4])
tensor([[[[  0.,   1.,   2.,   3.],
          [  4.,   5.,   6.,   7.],
          [  8.,   9.,  10.,  11.],
          [ 50.,  51.,  52.,  53.]],

         [[100., 101., 102., 103.],
          [104., 105., 106., 107.],
          [108., 109., 110., 111.],
          [150., 151., 152., 153.]]]])
--------------------


### So, what happened here?

```python
# Cache K
torch.Size([1, 2, 3, 4])

tensor([[[[  0.,   1.,   2.,   3.],    # head 0: tokens 0, 1, 2 (one row per token)
          [  4.,   5.,   6.,   7.],
          [  8.,   9.,  10.,  11.]],

         [[100., 101., 102., 103.],    # head 1: tokens 0, 1, 2 (one row per token)
          [104., 105., 106., 107.],
          [108., 109., 110., 111.]]]])

# Current K
torch.Size([1, 2, 1, 4])

tensor([[[[ 50.,  51.,  52.,  53.]],    # head 0, token 3

         [[150., 151., 152., 153.]]]])  # head 1, token 3

# --- K Total (After Concatenation)
torch.Size([1, 2, 4, 4])

tensor([[[[  0.,   1.,   2.,   3.],    # head 0: tokens 0, 1, 2, 3 (one row per token)
          [  4.,   5.,   6.,   7.],
          [  8.,   9.,  10.,  11.],
          [ 50.,  51.,  52.,  53.]],   # <- new row appended for token 3

         [[100., 101., 102., 103.],    # head 1: tokens 0, 1, 2, 3 (one row per token)
          [104., 105., 106., 107.],
          [108., 109., 110., 111.],
          [150., 151., 152., 153.]]]]) # <- new row appended for token 3
```

In other words, when we concatenated along the `seq_len` dimension (`dim=2`), increasing it from 3 to 4, we simply appended one more 'row' (loosely, a "slice of `head_dim`") to each head. `head_dim=4` stays the same, since the 4th 'row' has the same number of elements as the others; adding that 4th 'row' is what increased `seq_len`.

Important bits:
- What was added: the data corresponding to the new token (position `p=3`). This data wasn't just a "slice of `head_dim`"; it was a complete slice along the sequence dimension (dim=2) with shape `[1, 2, 1, 4]`. This slice contains the full `head_dim=4` feature vector for each of the `num_kv_heads=2`.
- Preservation of `head_dim`: the `head_dim` (the last dimension, size 4) remained unchanged. The vectors representing the rich features for each token (e.g., [0., 1., 2., 3.] for token 0/head 0, or [50., 51., 52., 53.] for token 3/head 0) were kept intact.
- Increase in `seq_len`: The only dimension that changed was `dim=2` (the sequence length dimension), increasing from 3 (cache) to 4 (total) because we effectively added one position's worth of data to the sequence history.

I asked Gemini 2.5 whether this understanding is correct, and this was the response:

> Your refined understanding is spot on. Concatenating along `dim=2` acts like appending the complete Key/Value representation (spanning all heads and the full `head_dim` feature vector) of the new token(s) to the end of the list of representations for the previously processed tokens. It extends the timeline (`seq_len`) while preserving the richness (`head_dim`) of the information stored for each point in time. 
> You've nailed the key insight.
