<a href="https://colab.research.google.com/drive/1xpBQSnukiNPka2oQDtLd7sKdQ0gt74PU?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install transformers


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class AttnWrapper(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.activations = None
        self.add_tensor = None

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            output = (output[0] + self.add_tensor,) + output[1:]
        self.activations = output[0]
        return output

    def reset(self):
        self.activations = None
        self.add_tensor = None

class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, lm_head, norm):
        super().__init__()
        self.block = block
        self.lm_head = lm_head
        self.norm = norm
        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None

    def forward(self, x, past_key_value=None, attention_mask=None, position_ids=None, **kwargs):
        output = self.block(x, past_key_value=past_key_value, attention_mask=attention_mask,
                          position_ids=position_ids, **kwargs)

        if isinstance(output, tuple):
            hidden_states = output[0]
        else:
            hidden_states = output

        attention_output = self.get_attn_activations()

        self.attn_mech_output_unembedded = self.lm_head(self.norm(attention_output)) if attention_output is not None else None
        self.intermediate_res_unembedded = self.lm_head(self.norm(hidden_states))
        self.mlp_output_unembedded = self.lm_head(self.norm(hidden_states))
        self.block_output_unembedded = self.lm_head(self.norm(hidden_states))

        return output

    def attn_add_tensor(self, tensor):
        self.block.self_attn.add_tensor = tensor

    def reset(self):
        self.block.self_attn.reset()
        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None

    def get_attn_activations(self):
        return self.block.self_attn.activations

class Qwen2Helper:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
        self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct").to(self.device)

        # Wrap attention modules
        for i, layer in enumerate(self.model.model.layers):
            layer.self_attn = AttnWrapper(layer.self_attn)
            self.model.model.layers[i] = BlockOutputWrapper(layer, self.model.lm_head, self.model.model.norm)

    def generate_text(self, prompt, max_length=100):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(
            inputs.input_ids.to(self.device),
            max_length=max_length,
            pad_token_id=self.tokenizer.eos_token_id
        )
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    def get_logits(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(inputs.input_ids.to(self.device)).logits
        return logits

    def set_add_attn_output(self, layer, add_output):
        self.model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        return self.model.model.layers[layer].get_attn_activations()

    def reset_all(self):
        for layer in self.model.model.layers:
            layer.reset()

    def print_decoded_activations(self, decoded_activations, label):
        softmaxed = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        values, indices = torch.topk(softmaxed, 10)
        probs_percent = [int(v * 100) for v in values.tolist()]
        tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
        print(label, list(zip(tokens, probs_percent)))

    def decode_all_layers(self, text, topk=10, print_attn_mech=True, print_intermediate_res=True,
                         print_mlp=True, print_block=True):
        self.get_logits(text)
        for i, layer in enumerate(self.model.model.layers):
            print(f'Layer {i}: Decoded intermediate outputs')
            if print_attn_mech:
                self.print_decoded_activations(layer.attn_mech_output_unembedded, 'Attention mechanism')
            if print_intermediate_res:
                self.print_decoded_activations(layer.intermediate_res_unembedded, 'Intermediate residual stream')
            if print_mlp:
                self.print_decoded_activations(layer.mlp_output_unembedded, 'MLP output')
            if print_block:
                self.print_decoded_activations(layer.block_output_unembedded, 'Block output')




In [3]:
model = Qwen2Helper()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/7.30k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/27.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

In [7]:
prompt = "The capital of China is "

In [8]:
model.decode_all_layers(prompt,print_attn_mech=True, print_intermediate_res=True, print_mlp=True,print_block=True)

Layer 0: Decoded intermediate outputs
Attention mechanism [(' strugg', 27), ('从根本', 10), ('准', 3), ('냄', 2), ('-strokes', 2), ('icester', 1), ('1', 1), ('ことがあります', 1), (' PartialView', 0), ('ProgressHUD', 0)]
Intermediate residual stream [('从根本', 62), ('嘞', 2), (' strugg', 1), (' PartialView', 1), ('.MouseAdapter', 0), (' Beste', 0), ('냄', 0), ('始建', 0), ('suming', 0), ('icester', 0)]
MLP output [('从根本', 62), ('嘞', 2), (' strugg', 1), (' PartialView', 1), ('.MouseAdapter', 0), (' Beste', 0), ('냄', 0), ('始建', 0), ('suming', 0), ('icester', 0)]
Block output [('从根本', 62), ('嘞', 2), (' strugg', 1), (' PartialView', 1), ('.MouseAdapter', 0), (' Beste', 0), ('냄', 0), ('始建', 0), ('suming', 0), ('icester', 0)]
Layer 1: Decoded intermediate outputs
Attention mechanism [(' 自动生成', 11), ('-svg', 6), (' TORT', 3), (' bureaucr', 2), (' MySqlConnection', 2), (' disadv', 2), ('性价', 2), ('颏', 1), (' uncert', 1), ('瑀', 1)]
Intermediate residual stream [(' strugg', 69), ('从根本', 12), (' bureaucr', 2), ('始

In [None]:
=+model.reset_all()
layer = 14
model.get_logits('bananas')
attn = model.get_attn_activations(layer)
last_token_attn = attn[0][-1]
model.set_add_attn_output(layer, 0.6*last_token_attn)

In [None]:
model.generate_text(prompt, max_length=20)

In [None]:
model.reset_all()