<a href="https://colab.research.google.com/drive/1OcYRKAsVI-me1zmtnhwoQfdyg6y_SVd3?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install transformers


In [None]:
token = input("Enter your HuggingFace token:")

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class AttnWrapper(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.activations = None
        self.add_tensor = None

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            output = (output[0] + self.add_tensor,) + output[1:]
        self.activations = output[0]
        return output

    def reset(self):
        self.activations = None
        self.add_tensor = None

class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, lm_head, norm):
        super().__init__()
        self.block = block
        self.lm_head = lm_head
        self.norm = norm
        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None

    def forward(self, x, past_key_value=None, attention_mask=None, position_ids=None, **kwargs):
        output = self.block(x, past_key_value=past_key_value, attention_mask=attention_mask, position_ids=position_ids, **kwargs)

        # Store intermediate outputs
        if isinstance(output, tuple):
            hidden_states = output[0]
        else:
            hidden_states = output

        attention_output = self.get_attn_activations()

        # Calculate and store unembedded outputs
        self.attn_mech_output_unembedded = self.lm_head(self.norm(attention_output)) if attention_output is not None else None
        self.intermediate_res_unembedded = self.lm_head(self.norm(hidden_states))
        self.mlp_output_unembedded = self.lm_head(self.norm(hidden_states))
        self.block_output_unembedded = self.lm_head(self.norm(hidden_states))

        return output

    def attn_add_tensor(self, tensor):
        self.block.self_attn.add_tensor = tensor

    def reset(self):
        self.block.self_attn.reset()
        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None

    def get_attn_activations(self):
        return self.block.self_attn.activations

class Llama3_1_8BHelper:
    def __init__(self, token):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B", use_auth_token=token)
        self.model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", use_auth_token=token).to(self.device)
        # Wrap attention modules
        for i, layer in enumerate(self.model.model.layers):
            layer.self_attn = AttnWrapper(layer.self_attn)
            self.model.model.layers[i] = BlockOutputWrapper(layer, self.model.lm_head, self.model.model.norm)


    def generate_text(self, prompt, max_length=100):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(
            inputs.input_ids.to(self.device),
            max_length=max_length,
            pad_token_id=self.tokenizer.eos_token_id
        )
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]


    def get_logits(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
          logits = self.model(inputs.input_ids.to(self.device)).logits
          return logits

    def set_add_attn_output(self, layer, add_output):
        self.model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        return self.model.model.layers[layer].get_attn_activations()

    def reset_all(self):
        for layer in self.model.model.layers:
            layer.reset()

    def print_decoded_activations(self, decoded_activations, label):
        softmaxed = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        values, indices = torch.topk(softmaxed, 10)
        probs_percent = [int(v * 100) for v in values.tolist()]
        tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
        print(label, list(zip(tokens, probs_percent)))


    def decode_all_layers(self, text, topk=10, print_attn_mech=True, print_intermediate_res=True, print_mlp=True, print_block=True):
        self.get_logits(text)
        for i, layer in enumerate(self.model.model.layers):
            print(f'Layer {i}: Decoded intermediate outputs')
            if print_attn_mech:
                self.print_decoded_activations(layer.attn_mech_output_unembedded, 'Attention mechanism')
            if print_intermediate_res:
                self.print_decoded_activations(layer.intermediate_res_unembedded, 'Intermediate residual stream')
            if print_mlp:
                self.print_decoded_activations(layer.mlp_output_unembedded, 'MLP output')
            if print_block:
                self.print_decoded_activations(layer.block_output_unembedded, 'Block output')



In [4]:
model = Llama3_1_8BHelper(token)



tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

In [5]:
prompt = "The capital of China is "

In [6]:
model.decode_all_layers(prompt,print_attn_mech=True, print_intermediate_res=True, print_mlp=True,print_block=True)

Layer 0: Decoded intermediate outputs
Attention mechanism [(':].', 0), (".']", 0), ("'gc", 0), ('./(', 0), ('혼', 0), ('.dataTables', 0), ('\xa0ro', 0), (' nackte', 0), ('internet', 0), ('ickness', 0)]
Intermediate residual stream [('Disposition', 1), ('RYPTO', 1), (' removeFrom', 0), (' Tanner', 0), ('енту', 0), ('.updateDynamic', 0), ('opher', 0), ('ISON', 0), ('ColumnType', 0), ('allen', 0)]
MLP output [('Disposition', 1), ('RYPTO', 1), (' removeFrom', 0), (' Tanner', 0), ('енту', 0), ('.updateDynamic', 0), ('opher', 0), ('ISON', 0), ('ColumnType', 0), ('allen', 0)]
Block output [('Disposition', 1), ('RYPTO', 1), (' removeFrom', 0), (' Tanner', 0), ('енту', 0), ('.updateDynamic', 0), ('opher', 0), ('ISON', 0), ('ColumnType', 0), ('allen', 0)]
Layer 1: Decoded intermediate outputs
Attention mechanism [(' المتحدة', 2), (' dap', 0), (' securely', 0), (' �', 0), ('olut', 0), ('autop', 0), ('��', 0), ('emand', 0), ('ewood', 0), ('ď', 0)]
Intermediate residual stream [('.updateDynamic', 4)

In [7]:
model.decode_all_layers('My favorite dish to eat is',print_attn_mech=False, print_intermediate_res=False, print_mlp=False,print_block=True)

Layer 0: Decoded intermediate outputs
Block output [('otope', 0), ('OSP', 0), ('/is', 0), ('gın', 0), ('/w', 0), ('ridor', 0), (' disag', 0), ('uo', 0), ('alien', 0), ('consin', 0)]
Layer 1: Decoded intermediate outputs
Block output [('��', 3), ('μι', 1), ("'gc", 0), ('gın', 0), ('ジア', 0), ('يث', 0), (' gì', 0), ('psc', 0), ('uali', 0), ('otope', 0)]
Layer 2: Decoded intermediate outputs
Block output [('.Areas', 1), ("'gc", 0), ('@mail', 0), ('оск', 0), (' Bair', 0), ('.Criteria', 0), ('embr', 0), ('NECT', 0), ('дах', 0), ('άς', 0)]
Layer 3: Decoded intermediate outputs
Block output [("'gc", 2), ('avian', 0), ('ırak', 0), ('▍', 0), ('اکی', 0), ('šak', 0), ('�', 0), ('ディ', 0), ('.Areas', 0), ('>tag', 0)]
Layer 4: Decoded intermediate outputs
Block output [('gnu', 0), ('dere', 0), ('ouns', 0), ("'gc", 0), ('控', 0), ('Bes', 0), ('缘', 0), ('ヴァ', 0), ('OMPI', 0), ('_physical', 0)]
Layer 5: Decoded intermediate outputs
Block output [('Bes', 1), ('控', 0), ('ůst', 0), ('ynom', 0), ('(exports',

In [None]:
=+model.reset_all()
layer = 14
model.get_logits('bananas')
attn = model.get_attn_activations(layer)
last_token_attn = attn[0][-1]
model.set_add_attn_output(layer, 0.6*last_token_attn)

In [8]:
model.generate_text(prompt, max_length=50)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


KeyboardInterrupt: 

In [None]:
model.reset_all()