# How self-attention modifies representations

I'm interested in seeing how the values of other tokens affect the representation of the current token.

In [1]:
from collections import defaultdict
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, AttentionMaskInterface
from transformers.masking_utils import eager_mask
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.qwen3.modeling_qwen3 import eager_attention_forward

from dawnet.inspector import LLMInspector
from dawnet import op
from dawnet.diagnose.vis_attention import extract_top_attention_indices

torch.set_grad_enabled(False)

torch.autograd.grad_mode.set_grad_enabled(mode=False)

In [2]:
model_id = "Qwen/Qwen3-4B-Thinking-2507"
# Prefer Apple's MPS backend (the device this notebook was written for), but
# fall back to CUDA or CPU so the notebook still runs on other machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
insp = LLMInspector.from_hf(model_id).to(device=device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# Switch both model handles held by the inspector to the "eager" attention
# implementation.
# NOTE(review): presumably needed so attention weights are materialised for
# the analysis below — confirm.
insp.model.config._attn_implementation = "eager"
insp.original_model.config._attn_implementation = "eager"

### Preliminary setup

In [4]:
# Print the module tree; the module paths shown here are what the name
# regexes used when adding ops are matched against.
print(insp.model)

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2560)
    (layers): ModuleList(
      (0-35): 36 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2560, out_features=4096, bias=False)
          (k_proj): Linear(in_features=2560, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2560, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=2560, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2560, out_features=9728, bias=False)
          (up_proj): Linear(in_features=2560, out_features=9728, bias=False)
          (down_proj): Linear(in_features=9728, out_features=2560, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((2560,), eps=1e-06)
        (post_attention_layernorm): Qwe

In [5]:
# Generate a long answer (up to 4096 new tokens) for an arithmetic prompt using
# the chat template; `state` is whatever insp.ctx yields for this run.
with insp.ctx(detach_state=True) as state:
    output = insp.generate("""What is (234 + 413) * 89?""", chat=True, max_new_tokens=4096, use_original=True)
print(output)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


<|im_start|>user
What is (234 + 413) * 89?<|im_end|>
<|im_start|>assistant
<think>
Okay, let's see. I need to calculate (234 + 413) * 89. Hmm, first I should probably add 234 and 413 together. Let me do that step by step. 

So, 234 plus 413. Let's add the units place first: 4 + 3 is 7. Then the tens place: 3 + 1 is 4. Then the hundreds place: 2 + 4 is 6. So that would be 647? Wait, let me check that again. 234 is 2 hundreds, 3 tens, 4 units. 413 is 4 hundreds, 1 ten, 3 units. Adding them: 2+4=6 hundreds, 3+1=4 tens, 4+3=7 units. Yep, so 647. That seems right. Let me verify with another method. 234 + 400 is 634, then +13 is 647. Yep, that's correct.

Now, the next step is to multiply 647 by 89. Hmm, multiplying by 89. Maybe I can break this down into 647*(90 - 1) which is 647*90 - 647*1. That might be easier than multiplying directly. Let's try that.

First, 647*90. Well, 647*9 is 5823, so 647*90 is 58230. Let me check: 647*10 is 6470, so 647*90 is 6470*9. 6470*9: 6000*9=54000, 400*9=36

In [6]:
# NOTE(review): project-specific viewer for the generation result; its
# rendered output is not captured in this export.
output.view()

In [15]:
# Alternative capture points, kept for reference:
# insp.add(op.GetInput(), name_regex=r'model\.layers\.\d+\.self_attn\.o_proj')
# insp.add(op.GetOutput(), name_regex=r'model\.layers\.\d+\.self_attn\.v_proj')

# Capture the output of every decoder layer's self-attention module. The dots
# are escaped so that they only match a literal "." in the module path.
insp.add(op.GetOutput(), name_regex=r'model\.layers\.\d+\.self_attn$')

Added to layer ['model.layers.0.self_attn', 'model.layers.1.self_attn', 'model.layers.2.self_attn', 'model.layers.3.self_attn', 'model.layers.4.self_attn', 'model.layers.5.self_attn', 'model.layers.6.self_attn', 'model.layers.7.self_attn', 'model.layers.8.self_attn', 'model.layers.9.self_attn', 'model.layers.10.self_attn', 'model.layers.11.self_attn', 'model.layers.12.self_attn', 'model.layers.13.self_attn', 'model.layers.14.self_attn', 'model.layers.15.self_attn', 'model.layers.16.self_attn', 'model.layers.17.self_attn', 'model.layers.18.self_attn', 'model.layers.19.self_attn', 'model.layers.20.self_attn', 'model.layers.21.self_attn', 'model.layers.22.self_attn', 'model.layers.23.self_attn', 'model.layers.24.self_attn', 'model.layers.25.self_attn', 'model.layers.26.self_attn', 'model.layers.27.self_attn', 'model.layers.28.self_attn', 'model.layers.29.self_attn', 'model.layers.30.self_attn', 'model.layers.31.self_attn', 'model.layers.32.self_attn', 'model.layers.33.self_attn', 'model.l

<dawnet.op.GetOutput at 0xa84ce2330>

In [8]:
# Keep the first 1529 positions of the generated sequence (cloned so the slice
# does not alias output.tensor) and decode them as a sanity check.
new_input = output.tensor[:,:1529].clone()
insp.tokenizer.decode(new_input[0])

"<|im_start|>user\nWhat is (234 + 413) * 89?<|im_end|>\n<|im_start|>assistant\n<think>\nOkay, let's see. I need to calculate (234 + 413) * 89. Hmm, first I should probably add 234 and 413 together. Let me do that step by step. \n\nSo, 234 plus 413. Let's add the units place first: 4 + 3 is 7. Then the tens place: 3 + 1 is 4. Then the hundreds place: 2 + 4 is 6. So that would be 647? Wait, let me check that again. 234 is 2 hundreds, 3 tens, 4 units. 413 is 4 hundreds, 1 ten, 3 units. Adding them: 2+4=6 hundreds, 3+1=4 tens, 4+3=7 units. Yep, so 647. That seems right. Let me verify with another method. 234 + 400 is 634, then +13 is 647. Yep, that's correct.\n\nNow, the next step is to multiply 647 by 89. Hmm, multiplying by 89. Maybe I can break this down into 647*(90 - 1) which is 647*90 - 647*1. That might be easier than multiplying directly. Let's try that.\n\nFirst, 647*90. Well, 647*9 is 5823, so 647*90 is 58230. Let me check: 647*10 is 6470, so 647*90 is 6470*9. 6470*9: 6000*9=5400

In [17]:
# Run the truncated sequence through the inspector in one pass; the `state`
# captured here is what the attention analysis below reads from.
with insp.ctx(detach_state=True) as state:
    exp_output = insp.infer(new_input)

In [19]:
# Top-attention indices extracted from the captured state, converted to sets
# so that the membership tests in the loops below are constant-time.
most_important = extract_top_attention_indices(state, 36)
for key, indices in most_important.items():
    most_important[key] = set(indices)
len(most_important)

1761408

### Find tokens that are forgotten

For each (layer, head) pair, find the tokens that are never among that head's most-attended tokens once `look_back` or more timesteps have passed.

In [40]:
# Sequence length of the truncated input. Index the shape rather than
# unpacking into `_`, which would shadow IPython's last-output variable.
token_length = new_input.shape[1]
print(token_length)

1529


In [30]:
look_back = 50
never_call = defaultdict(list)
for layer_idx in range(36):
    for head_idx in range(32):
        # NOTE(review): the last 3 * look_back positions are never considered
        # as candidate tokens — confirm that this margin is intended.
        for token_idx in range(token_length - look_back * 3):
            # Scan every position at least `look_back` steps after the token.
            for target_idx in range(token_idx + look_back, token_length):
                if token_idx in most_important[(layer_idx, head_idx, target_idx)]:
                    break
            else:
                # No later position had this token in its top-attention set.
                never_call[(layer_idx, head_idx)].append(token_idx)

for key, tokens in never_call.items():
    print(key, len(tokens))

(0, 0) 245
(0, 1) 161
(0, 2) 1378
(0, 5) 13
(0, 9) 1
(0, 12) 8
(0, 15) 21
(0, 17) 1
(0, 18) 9
(0, 20) 180
(0, 21) 144
(0, 22) 507
(0, 23) 220
(0, 27) 27
(0, 28) 31
(0, 30) 1
(1, 4) 459
(1, 7) 2
(1, 9) 14
(1, 10) 421
(1, 11) 2
(1, 12) 6
(1, 13) 341
(1, 15) 4
(1, 19) 8
(1, 21) 129
(1, 22) 1
(1, 24) 5
(1, 25) 28
(1, 28) 63
(1, 31) 15
(2, 0) 11
(2, 1) 1
(2, 2) 11
(2, 3) 40
(2, 4) 208
(2, 5) 1
(2, 6) 3
(2, 7) 72
(2, 9) 125
(2, 10) 5
(2, 11) 497
(2, 12) 2
(2, 13) 1
(2, 14) 7
(2, 16) 168
(2, 17) 69
(2, 19) 127
(2, 20) 32
(2, 22) 4
(2, 23) 1
(2, 28) 9
(2, 31) 24
(3, 1) 1
(3, 2) 186
(3, 3) 2
(3, 4) 1
(3, 5) 195
(3, 6) 220
(3, 7) 1058
(3, 8) 46
(3, 10) 113
(3, 11) 228
(3, 12) 57
(3, 14) 73
(3, 15) 23
(3, 16) 8
(3, 17) 170
(3, 18) 238
(3, 19) 20
(3, 20) 793
(3, 21) 58
(3, 22) 714
(3, 23) 244
(3, 24) 85
(3, 25) 36
(3, 26) 28
(3, 27) 48
(3, 28) 206
(3, 29) 16
(3, 30) 94
(3, 31) 570
(4, 0) 207
(4, 1) 475
(4, 2) 8
(4, 3) 361
(4, 4) 13
(4, 5) 220
(4, 6) 95
(4, 7) 55
(4, 8) 161
(4, 9) 9
(4, 10) 50
(4, 

In [35]:
look_back = 20
never_call20 = defaultdict(list)
for layer_idx in range(36):
    for head_idx in range(32):
        # NOTE(review): the last 3 * look_back positions are never considered
        # as candidate tokens — confirm that this margin is intended.
        for token_idx in range(token_length - look_back * 3):
            # Scan every position at least `look_back` steps after the token.
            for target_idx in range(token_idx + look_back, token_length):
                if token_idx in most_important[(layer_idx, head_idx, target_idx)]:
                    break
            else:
                # No later position had this token in its top-attention set.
                never_call20[(layer_idx, head_idx)].append(token_idx)

for key, tokens in never_call20.items():
    print(key, len(tokens))

(0, 0) 118
(0, 1) 149
(0, 2) 1468
(0, 5) 13
(0, 9) 1
(0, 12) 6
(0, 15) 21
(0, 18) 4
(0, 20) 149
(0, 21) 140
(0, 22) 542
(0, 23) 167
(0, 27) 5
(0, 28) 19
(1, 4) 391
(1, 7) 1
(1, 9) 8
(1, 10) 456
(1, 11) 2
(1, 12) 5
(1, 13) 309
(1, 15) 2
(1, 19) 4
(1, 21) 91
(1, 24) 4
(1, 25) 2
(1, 28) 80
(1, 31) 5
(2, 0) 5
(2, 1) 2
(2, 2) 13
(2, 3) 22
(2, 4) 106
(2, 6) 3
(2, 7) 56
(2, 9) 103
(2, 10) 2
(2, 11) 372
(2, 12) 1
(2, 13) 1
(2, 14) 1
(2, 16) 128
(2, 17) 54
(2, 19) 56
(2, 20) 24
(2, 22) 6
(2, 23) 4
(2, 28) 8
(2, 30) 1
(2, 31) 17
(3, 2) 87
(3, 4) 1
(3, 5) 185
(3, 6) 193
(3, 7) 567
(3, 8) 23
(3, 10) 58
(3, 11) 180
(3, 12) 55
(3, 14) 30
(3, 15) 10
(3, 16) 3
(3, 17) 88
(3, 18) 155
(3, 19) 13
(3, 20) 800
(3, 21) 33
(3, 22) 753
(3, 23) 196
(3, 24) 64
(3, 25) 51
(3, 26) 17
(3, 27) 52
(3, 28) 185
(3, 29) 14
(3, 30) 102
(3, 31) 571
(4, 0) 111
(4, 1) 465
(4, 2) 4
(4, 3) 308
(4, 4) 9
(4, 5) 122
(4, 6) 29
(4, 7) 40
(4, 8) 129
(4, 9) 5
(4, 10) 20
(4, 11) 1
(4, 12) 96
(4, 13) 32
(4, 14) 57
(4, 15) 180
(4, 16)

In [36]:
look_back = 100
never_call100 = defaultdict(list)
for layer_idx in range(36):
    for head_idx in range(32):
        # NOTE(review): the last 3 * look_back positions are never considered
        # as candidate tokens — confirm that this margin is intended.
        for token_idx in range(token_length - look_back * 3):
            # Scan every position at least `look_back` steps after the token.
            for target_idx in range(token_idx + look_back, token_length):
                if token_idx in most_important[(layer_idx, head_idx, target_idx)]:
                    break
            else:
                # No later position had this token in its top-attention set.
                never_call100[(layer_idx, head_idx)].append(token_idx)

for key, tokens in never_call100.items():
    print(key, len(tokens))

(0, 0) 348
(0, 1) 208
(0, 2) 1229
(0, 4) 2
(0, 5) 23
(0, 9) 1
(0, 12) 10
(0, 15) 11
(0, 18) 6
(0, 20) 233
(0, 21) 168
(0, 22) 463
(0, 23) 261
(0, 24) 1
(0, 27) 17
(0, 28) 46
(0, 30) 1
(1, 4) 512
(1, 9) 37
(1, 10) 398
(1, 11) 3
(1, 12) 10
(1, 13) 351
(1, 15) 10
(1, 19) 14
(1, 21) 169
(1, 22) 1
(1, 24) 12
(1, 25) 93
(1, 28) 25
(1, 31) 30
(2, 0) 18
(2, 1) 1
(2, 2) 16
(2, 3) 56
(2, 4) 235
(2, 5) 12
(2, 6) 2
(2, 7) 90
(2, 9) 121
(2, 10) 12
(2, 11) 586
(2, 12) 3
(2, 13) 4
(2, 14) 25
(2, 16) 225
(2, 17) 87
(2, 19) 171
(2, 20) 25
(2, 21) 2
(2, 22) 5
(2, 23) 1
(2, 28) 2
(2, 31) 21
(3, 1) 15
(3, 2) 244
(3, 4) 3
(3, 5) 178
(3, 6) 227
(3, 7) 1096
(3, 8) 72
(3, 10) 172
(3, 11) 283
(3, 12) 46
(3, 14) 127
(3, 15) 48
(3, 16) 13
(3, 17) 248
(3, 18) 329
(3, 19) 39
(3, 20) 793
(3, 21) 124
(3, 22) 683
(3, 23) 345
(3, 24) 102
(3, 25) 51
(3, 26) 35
(3, 27) 74
(3, 28) 220
(3, 29) 20
(3, 30) 123
(3, 31) 554
(4, 0) 287
(4, 1) 477
(4, 2) 15
(4, 3) 391
(4, 4) 21
(4, 5) 277
(4, 6) 184
(4, 7) 78
(4, 8) 276
(4, 9) 

In [37]:
# Token positions recorded as never recalled for layer 0, head 0 (look_back = 100).
print(never_call100[(0,0)])

[5, 6, 7, 8, 9, 10, 11, 13, 14, 16, 17, 18, 19, 28, 32, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 51, 53, 59, 60, 61, 62, 64, 65, 66, 67, 69, 77, 80, 81, 82, 83, 84, 86, 87, 88, 89, 90, 98, 99, 100, 102, 103, 105, 106, 107, 113, 114, 115, 116, 119, 120, 121, 127, 128, 129, 130, 131, 133, 134, 135, 140, 141, 142, 143, 146, 152, 153, 154, 155, 156, 158, 159, 161, 162, 163, 166, 167, 170, 171, 173, 175, 176, 178, 179, 182, 183, 184, 189, 190, 191, 192, 193, 197, 198, 199, 200, 201, 203, 206, 207, 208, 209, 210, 211, 213, 217, 218, 219, 233, 234, 235, 236, 238, 239, 243, 245, 246, 249, 251, 253, 254, 255, 272, 277, 285, 296, 297, 298, 299, 300, 301, 304, 309, 310, 311, 312, 314, 317, 319, 320, 321, 323, 339, 340, 341, 342, 344, 349, 353, 356, 363, 371, 382, 383, 384, 390, 397, 398, 399, 400, 401, 402, 405, 410, 411, 413, 418, 419, 421, 422, 426, 427, 435, 436, 439, 440, 441, 442, 443, 447, 448, 450, 451, 452, 458, 459, 460, 464, 465, 466, 467, 471, 477, 479, 484, 499, 503, 506, 5

### How many tokens are revived?

After the look-back period, a token is inactive for X timesteps and then becomes active again later on.

In [42]:
look_back = 50
# (layer, head, token) -> ordered list of later positions whose top-attention
# set contains that token. Keys exist only for tokens with at least one hit.
attended_by = defaultdict(list)
for layer_idx in range(36):
    for head_idx in range(32):
        for token_idx in range(token_length - look_back * 3):
            hits = [
                target_idx
                for target_idx in range(token_idx + look_back, token_length)
                if token_idx in most_important[(layer_idx, head_idx, target_idx)]
            ]
            if hits:
                attended_by[(layer_idx, head_idx, token_idx)].extend(hits)

In [60]:
# (layer, head) -> largest gap between two consecutive positions that attend
# to the same token, taken over all tokens of that head.
diffs = {}
for layer_idx in range(36):
    for head_idx in range(32):
        widest_gap = -1
        for token_idx in range(token_length - look_back * 3):
            # .get() keeps the defaultdict from growing empty entries.
            targets = attended_by.get((layer_idx, head_idx, token_idx))
            if targets is None or len(targets) < 2:
                continue
            gaps = [later - earlier for earlier, later in zip(targets, targets[1:])]
            largest = max(gaps)
            if largest > 1300:
                print("Random very far:", layer_idx, head_idx, token_idx)
            widest_gap = max(widest_gap, largest)
        if widest_gap != -1:
            diffs[(layer_idx, head_idx)] = widest_gap

Random very far: 1 4 48
Random very far: 1 4 58
Random very far: 1 4 73
Random very far: 4 8 7
Random very far: 4 8 29
Random very far: 4 8 38
Random very far: 4 8 66
Random very far: 4 8 117
Random very far: 4 10 57
Random very far: 4 10 61
Random very far: 4 13 5
Random very far: 4 15 3
Random very far: 4 16 91
Random very far: 4 17 1
Random very far: 4 17 25
Random very far: 4 17 37
Random very far: 4 17 91
Random very far: 4 18 20
Random very far: 4 18 85
Random very far: 5 3 70
Random very far: 5 11 39
Random very far: 5 11 40
Random very far: 5 11 43
Random very far: 5 11 44
Random very far: 5 11 45
Random very far: 5 11 49
Random very far: 5 11 50
Random very far: 5 11 60
Random very far: 5 11 61
Random very far: 5 11 62
Random very far: 5 11 65
Random very far: 5 11 66
Random very far: 5 11 67
Random very far: 5 11 83
Random very far: 5 11 84
Random very far: 5 11 87
Random very far: 5 11 89
Random very far: 5 16 79
Random very far: 5 20 7
Random very far: 5 20 41
Random very f

In [54]:
# Positions recorded in attended_by for layer 5, head 11, token 66.
print(attended_by[(5, 11, 66)])

[121, 122, 135, 1521, 1522, 1524, 1525]


In [56]:
# Positions recorded in attended_by for layer 5, head 11, token 50.
print(attended_by[(5, 11, 50)])

[108, 122, 137, 144, 145, 1497, 1498, 1520, 1521, 1522, 1524, 1525, 1527]


In [59]:
# Heads ordered by their largest gap, also shown as a share of the sequence length.
for head_key, gap in sorted(diffs.items(), key=lambda item: item[1]):
    share = gap / token_length * 100
    print(head_key, gap, f"{share:.2f}%")

(0, 29) 53 3.47%
(1, 6) 73 4.77%
(0, 25) 99 6.47%
(26, 10) 105 6.87%
(0, 26) 165 10.79%
(1, 18) 170 11.12%
(1, 17) 237 15.50%
(6, 12) 253 16.55%
(1, 22) 261 17.07%
(0, 8) 280 18.31%
(2, 30) 314 20.54%
(2, 29) 318 20.80%
(20, 27) 320 20.93%
(12, 3) 324 21.19%
(0, 16) 331 21.65%
(13, 12) 334 21.84%
(6, 5) 342 22.37%
(21, 20) 342 22.37%
(22, 20) 342 22.37%
(2, 8) 373 24.40%
(3, 4) 374 24.46%
(2, 21) 380 24.85%
(4, 2) 389 25.44%
(5, 13) 390 25.51%
(2, 26) 391 25.57%
(0, 19) 395 25.83%
(3, 9) 404 26.42%
(10, 21) 406 26.55%
(1, 23) 414 27.08%
(2, 27) 424 27.73%
(2, 25) 425 27.80%
(1, 5) 426 27.86%
(1, 20) 427 27.93%
(0, 7) 438 28.65%
(17, 22) 443 28.97%
(6, 9) 448 29.30%
(11, 1) 460 30.09%
(6, 26) 466 30.48%
(2, 24) 475 31.07%
(0, 11) 476 31.13%
(14, 17) 478 31.26%
(0, 3) 480 31.39%
(11, 4) 480 31.39%
(5, 15) 482 31.52%
(1, 27) 485 31.72%
(6, 7) 488 31.92%
(8, 19) 488 31.92%
(33, 3) 489 31.98%
(26, 23) 495 32.37%
(15, 11) 503 32.90%
(30, 29) 503 32.90%
(14, 14) 504 32.96%
(4, 30) 507 33.16%


This shows that, for a given (layer, head) pair, a token is often dead for a long stretch and is then revived much later.

### Check which tokens are rarely attended to

In [67]:
max_attended_by = 3
# token -> list of (layer, head) pairs in which that token was attended to
# fewer than `max_attended_by` times.
toks_to_attention = defaultdict(list)
for (layer, head, tok_idx), targets in attended_by.items():
    if len(targets) < max_attended_by:
        toks_to_attention[tok_idx].append((layer, head))

print(f"Total combination: {36*32=}")
ranked = sorted(toks_to_attention.items(), key=lambda item: len(item[1]), reverse=True)
for tok, com in ranked:
    print(tok, len(com))

Total combination: 36*32=1152
934 223
944 217
953 200
986 199
993 194
951 192
955 190
911 189
943 188
935 187
971 186
981 182
987 182
995 182
948 179
25 178
982 176
438 176
1178 175
1325 174
940 174
437 172
994 172
980 169
985 168
467 168
1141 167
997 165
1007 165
1326 164
407 164
201 163
947 163
440 160
460 160
1157 160
462 159
1290 159
1033 159
1017 159
445 158
1002 158
1150 158
1304 158
1305 158
494 158
1185 157
1292 157
938 157
744 157
996 156
960 156
1023 156
466 155
715 154
909 154
1296 154
740 153
210 153
260 153
443 153
1019 153
1159 153
451 152
773 152
436 152
469 152
767 152
973 152
1156 152
693 152
866 151
990 151
746 151
1197 151
475 150
703 150
1144 150
1145 150
1151 150
1155 150
1315 150
964 150
1013 150
1031 150
870 149
1297 149
1170 149
1037 149
442 148
444 148
963 148
1262 148
1164 148
900 148
553 147
915 147
1198 147
180 146
439 146
1177 146
1280 146
218 146
461 146
975 146
916 146
765 146
864 146
958 146
991 146
406 145
860 145
1109 145
1209 145
1299 145
1300 145
176

## Dynamic cache, remove unused information

The target is to reduce the amount of memory required for the KV cache. The idea looks like this during generation:

Params:
- Cue window: 100 tokens
- Percentile: 95%
- Threshold

1. After generating "Cue" tokens
2. Update the `last_attended` statistics: for the current layer, head, and token, take the smallest set of tokens that covers 95% of the attention mass. Tokens outside that set have their counter incremented by 1; tokens inside it have their counter reset to 0.
3. Prune: any token whose counter (timesteps since it was last attended) exceeds the threshold is dropped.
4. Continue generating, following 2 and 3. 

In [3]:
from transformers.cache_utils import Cache, DynamicCache, DynamicLayer, CacheLayerMixin

In [25]:
class RecentlyActivatedLayer(CacheLayerMixin):
    """KV-cache layer that drops tokens which have not been attended to recently.

    For every cached token the layer keeps a counter of how many
    `update_attention` calls have passed since the token was last "activated",
    i.e. was part of the smallest set of keys whose sorted attention weights
    cumulatively exceed `contribution_threshold` for at least one head and
    query position. Tokens whose counter exceeds `threshold` are removed from
    the keys, the values and the counters.

    This is currently only suitable for batch size == 1: only batch element 0
    of the attention weights is inspected.

    NOTE(review): every layer prunes independently, so different layers can end
    up with different cache lengths — confirm that the attention mask built by
    the model is still valid for every layer.

    Args:
        threshold: maximum number of `update_attention` calls a token may go
            without being activated before it is dropped.
        contribution_threshold: cumulative attention mass that defines the
            activated set of a (head, query) row.
        verbose: when True (the default), print the indices that are removed.
    """

    def __init__(self, threshold: int = 50, contribution_threshold: float = 0.95, verbose: bool = True):
        super().__init__()
        self.threshold = threshold
        self.contribution_threshold = contribution_threshold
        self.verbose = verbose
        # Created on the first `update`, once dtype/device are known.
        self.iterations_since_last_activated = None

    def lazy_initialization(self, key_states):
        """Create empty key/value/counter tensors on the dtype and device of `key_states`."""
        self.dtype, self.device = key_states.dtype, key_states.device
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.iterations_since_last_activated = torch.tensor([], dtype=torch.int32, device=self.device)
        self.is_initialized = True

    def get_mask_sizes(self, cache_position):
        """Return `(kv_length, kv_offset)`: cached length plus the number of query positions, and offset 0."""
        kv_offset = 0
        query_length = cache_position.shape[0]
        kv_length = self.get_seq_length() + query_length
        return kv_length, kv_offset

    def get_seq_length(self):
        """Return the number of tokens currently held in this layer (0 when empty)."""
        if not self.is_initialized or self.keys.numel() == 0:
            return 0
        return self.keys.shape[-2]

    def get_max_cache_shape(self):
        """Return -1. NOTE(review): presumably signals "no fixed maximum size" — confirm."""
        return -1

    def update(self, key_states, value_states, cache_kwargs=None):
        """Append new keys/values to the cache and return the full cached keys and values.

        Newly added tokens start with an activation counter of 0.
        """
        if not self.is_initialized:
            self.lazy_initialization(key_states)
        self.keys = torch.cat([self.keys, key_states], dim=-2)    # BxHxALL_TxD
        self.values = torch.cat([self.values, value_states], dim=-2)   # BxHxALL_TxD
        fresh_counters = torch.zeros(
            key_states.shape[2],
            dtype=self.iterations_since_last_activated.dtype,
            device=self.iterations_since_last_activated.device,
        )
        self.iterations_since_last_activated = torch.cat(   # ALL_T
            [self.iterations_since_last_activated, fresh_counters], dim=0
        )
        return self.keys, self.values

    def update_attention(self, attn_weights):
        """Update the activation counters from `attn_weights` and prune stale tokens.

        Args:
            attn_weights: the weights of the attention, has shape BxHxTxALL_T

        Raises:
            ValueError: if the last dimension of `attn_weights` does not match
                the number of cached tokens.
        """
        num_keys = attn_weights.shape[-1]
        if num_keys != self.iterations_since_last_activated.numel():
            raise ValueError(
                f"attn_weights covers {num_keys} keys but the cache holds "
                f"{self.iterations_since_last_activated.numel()} tokens"
            )

        sorted_values, sorted_indices = attn_weights.sort(dim=-1, descending=True)
        cumsum = torch.cumsum(sorted_values, dim=-1)    # BxHxTxALL_T
        exceeds_threshold = cumsum > self.contribution_threshold     # BxHxTxALL_T
        # Number of top keys needed for the cumulative weight to exceed the threshold.
        threshold_positions = exceeds_threshold.int().argmax(dim=-1) + 1  # BxHxT
        # argmax returns 0 for a row that never exceeds the threshold, which
        # would select only the top-1 key; such a row needs all of its keys.
        never_exceeds = ~exceeds_threshold.any(dim=-1)
        threshold_positions = threshold_positions.masked_fill(never_exceeds, num_keys)

        # Select the first `threshold_positions` entries of every sorted row
        # (batch element 0 only) and collect the token indices they point to.
        ranks = torch.arange(num_keys, device=attn_weights.device)
        selected = ranks < threshold_positions[0].unsqueeze(-1)    # HxTxALL_T
        activated = sorted_indices[0][selected].to(self.iterations_since_last_activated.device)

        self.iterations_since_last_activated += 1
        self.iterations_since_last_activated[activated] = 0

        keep_mask = self.iterations_since_last_activated <= self.threshold
        if bool(keep_mask.all()):
            # Keep all
            return

        to_keep = keep_mask.nonzero().flatten()
        if self.verbose:
            print("Will remove")
            to_remove = (~keep_mask).nonzero().flatten()
            print(f"{to_remove.cpu().tolist()=}")
        self.iterations_since_last_activated = self.iterations_since_last_activated[to_keep]
        self.keys = self.keys[:,:,to_keep,:]
        self.values = self.values[:,:,to_keep,:]


class RecentlyActivatedCache(Cache):
    """Cache made of `RecentlyActivatedLayer`s, one per decoder layer.

    Args:
        contribution_threshold: forwarded to every `RecentlyActivatedLayer`.
        threshold: forwarded to every `RecentlyActivatedLayer`.
        ddp_cache_data: optional iterable of `(key_states, value_states)` pairs
            used to pre-fill the layers, one pair per layer.
        config: optional model config; when given, the layers are created
            eagerly, one per entry of its layer list.
        offloading: forwarded to `Cache.__init__`.
        offload_only_non_sliding: forwarded to `Cache.__init__`.
    """

    def __init__(
        self, contribution_threshold=0.95, threshold=50, ddp_cache_data=None, config=None, offloading=False, offload_only_non_sliding=False
    ):
        from functools import partial

        self.contribution_threshold = contribution_threshold
        self.threshold = threshold

        # Single factory so that every layer, whether created eagerly here or
        # lazily by the base class, receives the configured thresholds.
        make_layer = partial(
            RecentlyActivatedLayer,
            threshold=threshold,
            contribution_threshold=contribution_threshold,
        )

        layers = []
        # If a config is passed, use it to infer how many layers to create
        if config is not None:
            config = config.get_text_config(decoder=True)
            layer_types = getattr(config, "layer_types", None)
            if layer_types is None:
                # Only the number of entries matters: every layer uses the same class.
                layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
            # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
            if hasattr(config, "num_kv_shared_layers"):
                layer_types = layer_types[: -config.num_kv_shared_layers]

            layers.extend(make_layer() for _ in layer_types)

        # In this case, use the passed data to already fill in the Cache
        if ddp_cache_data is not None:
            # Init all the layers with the data
            for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
                # If the config was not passed above, create a layer for each entry of the ddp_data
                if config is None:
                    layers.append(make_layer())
                # Update the layer with the data
                _, _ = layers[layer_idx].update(key_states, value_states)

        # If neither config nor ddp_data was passed, let the base class create layers lazily
        if len(layers) == 0:
            super().__init__(
                layer_class_to_replicate=make_layer,
                offloading=offloading,
                offload_only_non_sliding=offload_only_non_sliding,
            )
        else:
            super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)

    def update_attention(self, attn_weights, layer_idx):
        """Forward the attention weights of layer `layer_idx` to that layer's `update_attention`."""
        self.layers[layer_idx].update_attention(attn_weights)

In [26]:
cache = RecentlyActivatedCache()


def temp_attention(self, query_states, key_states, value_states, *args, **kwargs):
    """Run eager attention and report its weights to the global `cache`.

    All arguments are passed through to `eager_attention_forward`; the
    attention weights it returns are handed to `cache.update_attention` for
    this module's `layer_idx`. Returns the `(attn_output, attn_weights)` pair
    unchanged.
    """
    result = eager_attention_forward(self, query_states, key_states, value_states, *args, **kwargs)
    _, weights = result
    cache.update_attention(attn_weights=weights, layer_idx=self.layer_idx)
    return result


# Register the wrapper under its own name and pair it with the eager mask builder.
AttentionMaskInterface.register("temp_attention", eager_mask)
ALL_ATTENTION_FUNCTIONS["temp_attention"] = temp_attention

In [27]:
# Start from an empty cache so that re-running this cell does not reuse entries
# left over from a previous generation. `temp_attention` looks up the global
# `cache` at call time, so rebinding the name here is sufficient.
cache = RecentlyActivatedCache()
insp.original_model.set_attn_implementation("temp_attention")
prompt_ids = insp.tokenizer.apply_chat_template(
    [{"role": "user", "content": """What is the largest city in Europe?"""}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(device=device)
output = insp.original_model.generate(
    prompt_ids,
    max_new_tokens=2048,
    do_sample=False,
    past_key_values=cache,
)
print(insp.tokenizer.decode(output[0]))

Will remove
to_remove.cpu().tolist()=[2, 16]
Will remove
to_remove.cpu().tolist()=[14]
Will remove
to_remove.cpu().tolist()=[14]
Will remove
to_remove.cpu().tolist()=[2]
Will remove
to_remove.cpu().tolist()=[1]
Will remove
to_remove.cpu().tolist()=[12]
Will remove
to_remove.cpu().tolist()=[1]
Will remove
to_remove.cpu().tolist()=[1]
Will remove
to_remove.cpu().tolist()=[14]
Will remove
to_remove.cpu().tolist()=[4]
Will remove
to_remove.cpu().tolist()=[2]
Will remove
to_remove.cpu().tolist()=[13]
Will remove
to_remove.cpu().tolist()=[4]
Will remove
to_remove.cpu().tolist()=[1, 2, 3]
Will remove
to_remove.cpu().tolist()=[4]
Will remove
to_remove.cpu().tolist()=[67]
Will remove
to_remove.cpu().tolist()=[63]
Will remove
to_remove.cpu().tolist()=[39]
Will remove
to_remove.cpu().tolist()=[39]
Will remove
to_remove.cpu().tolist()=[3]
Will remove
to_remove.cpu().tolist()=[2, 14]
Will remove
to_remove.cpu().tolist()=[5]
Will remove
to_remove.cpu().tolist()=[4]
Will remove
to_remove.cpu().tolist

## Calculate k-v similarity of tokens inside the cache

In [116]:
# Inspect the first layer of the cache filled by the generation above.
cache_layer = cache.layers[0]

In [117]:
print(f"{cache_layer.keys.shape=}")

# Pairwise dot products between the cached keys, computed per head.
kdot = torch.matmul(cache_layer.keys, cache_layer.keys.transpose(-1, -2))
print(f"{kdot.shape=}")

# L2 norm of every key vector, kept as a trailing singleton dimension.
knorm = cache_layer.keys.norm(p=2, dim=-1, keepdim=True)
print(f"{knorm.shape=}")

# Cosine similarity for batch element 0, averaged over the heads.
ksim = (kdot / torch.matmul(knorm, knorm.transpose(-1, -2)))[0].mean(dim=0)
# Zero the diagonal and lower triangle so each unordered pair is counted once.
ksim = ksim.triu(diagonal=1)
print(ksim.shape)

kindices = torch.nonzero(ksim >= 0.95)
print(kindices.shape)

cache_layer.keys.shape=torch.Size([1, 8, 1307, 128])
kdot.shape=torch.Size([1, 8, 1307, 1307])
knorm.shape=torch.Size([1, 8, 1307, 1])
torch.Size([1307, 1307])
torch.Size([853471, 2])


In [118]:
print(f"{cache_layer.values.shape=}")

# Pairwise dot products between the cached values, computed per head.
vdot = torch.matmul(cache_layer.values, cache_layer.values.transpose(-1, -2))
print(f"{vdot.shape=}")

# L2 norm of every value vector, kept as a trailing singleton dimension.
vnorm = cache_layer.values.norm(p=2, dim=-1, keepdim=True)
print(f"{vnorm.shape=}")

# Cosine similarity for batch element 0, averaged over the heads.
vsim = (vdot / torch.matmul(vnorm, vnorm.transpose(-1, -2)))[0].mean(dim=0)
# Zero the diagonal and lower triangle so each unordered pair is counted once.
vsim = vsim.triu(diagonal=1)
print(vsim.shape)

vindices = torch.nonzero(vsim >= 0.95)
print(vindices.shape)

cache_layer.values.shape=torch.Size([1, 8, 1307, 128])
vdot.shape=torch.Size([1, 8, 1307, 1307])
vnorm.shape=torch.Size([1, 8, 1307, 1])
torch.Size([1307, 1307])
torch.Size([9353, 2])


In [119]:
# Token pairs that are similar in both key space and value space.
kl = {tuple(pair) for pair in kindices.tolist()}
vl = {tuple(pair) for pair in vindices.tolist()}
print(len(kl & vl))

9353


### Run with all layers

With cosine similarity of 0.9, it seems that we can't significantly remove the values, except in layer 0. Otherwise, we do quite a lot of computation without removing anything.

So we don't want to do this frequently in each generation step. Instead, it's more suitable for a dedicated compression stage.

In [138]:
thresh = 0.80

# For every cache layer, print: layer index, number of cached tokens, and the
# number of token pairs whose head-averaged cosine similarity reaches `thresh`
# for both keys and values.
for cache_idx, cache_layer in enumerate(cache.layers):
    # Keys: per-head cosine similarity, averaged over heads, strict upper triangle.
    kdot = torch.matmul(cache_layer.keys, cache_layer.keys.transpose(-1, -2))
    knorm = cache_layer.keys.norm(p=2, dim=-1, keepdim=True)
    ksim = (kdot / torch.matmul(knorm, knorm.transpose(-1, -2)))[0].mean(dim=0)
    ksim = ksim.triu(diagonal=1)
    kindices = torch.nonzero(ksim >= thresh)

    # Values: same computation.
    vdot = torch.matmul(cache_layer.values, cache_layer.values.transpose(-1, -2))
    vnorm = cache_layer.values.norm(p=2, dim=-1, keepdim=True)
    vsim = (vdot / torch.matmul(vnorm, vnorm.transpose(-1, -2)))[0].mean(dim=0)
    vsim = vsim.triu(diagonal=1)
    vindices = torch.nonzero(vsim >= thresh)

    kl = {tuple(pair) for pair in kindices.tolist()}
    vl = {tuple(pair) for pair in vindices.tolist()}
    print(cache_idx, cache_layer.keys.shape[2], len(kl & vl))

0 1307 10139
1 1307 5465
2 1307 1321
3 1300 806
4 1297 2775
5 1289 2858
6 1306 2603
7 1205 13
8 1273 1093
9 1289 149
10 1289 968
11 1272 1313
12 1290 771
13 1277 189
14 1281 649
15 1283 146
16 1274 152
17 1231 312
18 1247 398
19 1236 210
20 1208 261
21 1233 311
22 1192 159
23 1205 481
24 1221 544
25 1067 335
26 1144 315
27 1185 328
28 1229 431
29 1142 707
30 698 164
31 1012 226
32 614 104
33 971 379
34 1142 394
35 1293 597
