<a href="https://colab.research.google.com/github/tbaeumel/MI_tutorials/blob/main/Early_Decoding_By_Layer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Based on the Demo Notebook by Jack Merullo - "Language Models Implement Simple Word2Vec-Style Vector Arithmetic"

What does GPT2-medium predict as the capital of Poland? We ask with a one-shot Q/A prompt.




In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [4]:
# Load model and tokenizer
model_name = "gpt2-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


In [10]:
# Tokenize input
text = """Q: What is the capital of France?
A: Paris
Q: What is the capital of Poland?
A:"""
encoded_input = tokenizer(text, return_tensors='pt')

# Generate output (passing the attention mask and a pad token id avoids the generation warnings)
output_ids = model.generate(**encoded_input, max_new_tokens=2, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)


Q: What is the capital of France?
A: Paris
Q: What is the capital of Poland?
A: Warsaw



We want to understand how GPT2-medium builds its prediction for the capital of Poland layer by layer.

Step 1 - Let's get a feel for the model structure

In [11]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)


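A little bit of magic: output_hidden_states=True makes the model return the residual stream after the embedding layer and after each of the 24 blocks.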

In [27]:
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
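
To see what the flag returns without downloading any weights, here is a minimal sketch on a tiny, randomly initialised GPT-2. The config sizes and token ids are arbitrary assumptions (they are not GPT2-medium's), and `tiny_model` is a name made up for this illustration. It suggests two things worth keeping in mind: `hidden_states` has `n_layer + 1` entries (the embedding output plus one per block), and, at least in the transformers releases this was written against, the last entry has already been passed through `ln_f`.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny random model: sizes are arbitrary, chosen only to keep the sketch fast
config = GPT2Config(n_layer=2, n_embd=32, n_head=2, vocab_size=100, n_positions=16)
tiny_model = GPT2LMHeadModel(config).eval()  # eval() disables dropout

input_ids = torch.tensor([[1, 2, 3, 4]])  # (batch=1, sequence=4), made-up token ids
with torch.no_grad():
    out = tiny_model(input_ids, output_hidden_states=True)

# One entry for the embedding output plus one per transformer block
n_states = len(out.hidden_states)
print(n_states, config.n_layer + 1)

# Each entry has shape (batch, sequence, hidden)
state_shape = tuple(out.hidden_states[0].shape)
print(state_shape)

# If the last entry is already normalised by ln_f, feeding it straight
# into lm_head reproduces the model's logits
final_is_normed = torch.allclose(
    tiny_model.lm_head(out.hidden_states[-1]), out.logits, atol=1e-5
)
print(final_is_normed)
```

In practice this means index 0 of `hidden_states` is the embedding output rather than the first block, and the final entry should not be normalised a second time before decoding.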



In [30]:
import torch
import torch.nn.functional as F

In [34]:
# Tokenize Input
text = """Q: What is the capital of France?
A: Paris
Q: What is the capital of Poland?
A:"""
encoded_input = tokenizer(text, return_tensors='pt')

# Forward pass through the model to capture intermediate predictions
with torch.no_grad():
    outputs = model(**encoded_input)

# Extract residual streams (hidden states): index 0 is the embedding output,
# index i is the output of block i, and the last entry already has ln_f applied
hidden_states = outputs.hidden_states
last_token_position = encoded_input["input_ids"].size(1)-1  # Last token index
final_idx = len(hidden_states) - 1

# Decode top 5 predictions after each layer
top_k = 5
intermediate_predictions = []
for layer_idx, hidden_state in enumerate(hidden_states):
    # Take the hidden state at the last token position
    last_token_hidden_state = hidden_state[:, last_token_position, :]

    # Pass it through the final layer norm (skipped for the last entry,
    # which is already normalized) and generate logits
    if layer_idx == final_idx:
        normalized_hidden_state = last_token_hidden_state
    else:
        normalized_hidden_state = model.transformer.ln_f(last_token_hidden_state)
    logits = model.lm_head(normalized_hidden_state)

    # Calculate probabilities using softmax
    probabilities = F.softmax(logits, dim=-1)

    # Get the top-k predictions
    top_k_probs, top_k_indices = torch.topk(probabilities, k=top_k, dim=-1)
    top_k_tokens = tokenizer.batch_decode(top_k_indices[0], skip_special_tokens=True)

    # Store layer predictions
    intermediate_predictions.append({
        "layer": layer_idx,
        "predictions": [{"token": token, "probability": prob.item()} for token, prob in zip(top_k_tokens, top_k_probs[0])]
    })

[{'layer': 0, 'predictions': [{'token': ' unden', 'probability': 0.9470192193984985}, {'token': ' nodd', 'probability': 0.021967994049191475}, {'token': ' enthusi', 'probability': 0.011501138098537922}, {'token': ':', 'probability': 0.0050586313009262085}, {'token': ' neighb', 'probability': 0.005044832825660706}]}, {'layer': 1, 'predictions': [{'token': ' (', 'probability': 0.02316311188042164}, {'token': ' [', 'probability': 0.018392831087112427}, {'token': ' The', 'probability': 0.012803658843040466}, {'token': ':', 'probability': 0.009204289875924587}, {'token': ',', 'probability': 0.00885365903377533}]}, {'layer': 2, 'predictions': [{'token': ' A', 'probability': 0.019568055868148804}, {'token': ' The', 'probability': 0.01659923605620861}, {'token': ' (', 'probability': 0.014010699465870857}, {'token': ' [', 'probability': 0.010149898007512093}, {'token': ' Is', 'probability': 0.0076424493454396725}]}, {'layer': 3, 'predictions': [{'token': ' A', 'probability': 0.02688725665211677

In [35]:
# Pretty-print
for layer_prediction in intermediate_predictions:
    layer = layer_prediction["layer"]
    print(f"\nLayer {layer} Predictions:")
    for prediction in layer_prediction["predictions"]:
        token = prediction["token"]
        probability = prediction["probability"]
        print(f"  Token: '{token}' | Probability: {probability:.4f}")



Layer 0 Predictions:
  Token: ' unden' | Probability: 0.9470
  Token: ' nodd' | Probability: 0.0220
  Token: ' enthusi' | Probability: 0.0115
  Token: ':' | Probability: 0.0051
  Token: ' neighb' | Probability: 0.0050

Layer 1 Predictions:
  Token: ' (' | Probability: 0.0232
  Token: ' [' | Probability: 0.0184
  Token: ' The' | Probability: 0.0128
  Token: ':' | Probability: 0.0092
  Token: ',' | Probability: 0.0089

Layer 2 Predictions:
  Token: ' A' | Probability: 0.0196
  Token: ' The' | Probability: 0.0166
  Token: ' (' | Probability: 0.0140
  Token: ' [' | Probability: 0.0101
  Token: ' Is' | Probability: 0.0076

Layer 3 Predictions:
  Token: ' A' | Probability: 0.0269
  Token: ' [' | Probability: 0.0186
  Token: ' (' | Probability: 0.0174
  Token: ' The' | Probability: 0.0138
  Token: ' At' | Probability: 0.0116

Layer 4 Predictions:
  Token: ' A' | Probability: 0.0393
  Token: ' [' | Probability: 0.0229
  Token: ' (' | Probability: 0.0203
  Token: ' Act' | Probability: 0.0182
 

**What happens in layer 19?**

How can we find out?

Unfortunately, there is no built-in flag (like output_hidden_states=True 😞) that returns what the attention and MLP sublayers inside a layer write into the residual stream.

There is a (manual) solution though: **Hooks**

As an exercise, let's first re-implement the layer-wise predictions without output_hidden_states=True
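
Before applying hooks to GPT-2, here is a minimal sketch of the mechanism on a toy network. The network `toy`, the helper `save_as`, and the dictionary `captured` are made up for this illustration; only `register_forward_hook` and the handle's `remove()` are PyTorch API. A forward hook is called with `(module, inputs, output)` every time the module runs, so it can copy activations out of the forward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a stack of layers (not GPT-2; sizes are arbitrary)
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

captured = {}

def save_as(name):
    # A forward hook receives (module, inputs, output); inputs is a tuple
    def hook(module, inputs, output):
        captured[name + "_in"] = inputs[0].detach()
        captured[name + "_out"] = output.detach()
    return hook

# register_forward_hook returns a handle that can later remove the hook
handle = toy[2].register_forward_hook(save_as("last_linear"))

x = torch.randn(3, 4)
with torch.no_grad():
    y = toy(x)

# The captured input is what the ReLU passed to the last Linear layer
print(captured["last_linear_in"].shape)
# The captured output is exactly what the network returned
print(torch.equal(captured["last_linear_out"], y))

handle.remove()  # stop capturing on later forward passes
```

The class below follows the same pattern, except that its hooks keep only the last token position and store everything in a dictionary attached to the model.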

In [62]:
# Load model without hidden states
model_name = "gpt2-medium"
model = AutoModelForCausalLM.from_pretrained(model_name)

In [None]:
print(model)

In [75]:
import torch
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
from IPython.display import display

class GPT2WithHooks:
    def __init__(self, model_name="gpt2-medium", top_k=5, device=None):
        # Load model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.TOP_K = top_k

        # Set device (default to 'cuda' if available, otherwise 'cpu')
        self.device = device if device is not None else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Initialize the activations dictionary
        self.set_hooks_gpt2()

    def set_hooks_gpt2(self):
        final_layer = self.model.config.n_layer - 1

        for attr in ["activations_"]:
            if not hasattr(self.model, attr):
                setattr(self.model, attr, {})

        def get_activation(name):
            def hook(module, input, output):
                if "mlp" in name or "attn" in name:
                    if "attn" in name:
                        num_tokens = list(output[0].size())[1]
                        self.model.activations_[name] = output[0][:, num_tokens - 1].detach()
                    elif "mlp" in name:
                        num_tokens = list(output[0].size())[0]  # the MLP returns a plain tensor, so output[0] drops the batch dim: (sequence, hidden)
                        self.model.activations_[name] = output[0][num_tokens - 1].detach()
                elif "residual" in name:
                    num_tokens = list(input[0].size())[1]  # (batch, sequence, hidden_state)
                    if name == "layer_residual_" + str(final_layer):
                        self.model.activations_[name] = self.model.activations_["intermediate_residual_" + str(final_layer)] + self.model.activations_["mlp_" + str(final_layer)]
                    else:
                        self.model.activations_[name] = input[0][:, num_tokens - 1].detach()

            return hook

        # Register hooks
        for i in range(self.model.config.n_layer):
            if i != 0:
                self.model.transformer.h[i].ln_1.register_forward_hook(get_activation("layer_residual_" + str(i - 1)))
            self.model.transformer.h[i].ln_2.register_forward_hook(get_activation("intermediate_residual_" + str(i)))

            self.model.transformer.h[i].attn.register_forward_hook(get_activation("attn_" + str(i)))
            self.model.transformer.h[i].mlp.register_forward_hook(get_activation("mlp_" + str(i)))

        self.model.transformer.ln_f.register_forward_hook(get_activation("layer_residual_" + str(final_layer)))

    def forward(self, text):
        # Tokenize and move the inputs to the same device as the model
        encoded_input = self.tokenizer(text, return_tensors='pt').to(self.device)

        # Forward pass to trigger hooks
        with torch.no_grad():
            self.model(**encoded_input)

        # Return activations
        return self.model.activations_

    def get_resid_predictions(self, sentence):
        """
        This function computes the predictions at different layers of GPT-2 using activations from residual layers.
        """
        layer_residual_preds = []

        tokens = self.tokenizer(sentence, return_tensors="pt").to(self.device)

        # Forward pass to trigger the hooks; output_hidden_states is not needed,
        # because the hooks fill self.model.activations_
        with torch.no_grad():
            self.model(**tokens)

        for layer in self.model.activations_.keys():
            if "layer_residual" in layer:
                normed = self.model.transformer.ln_f(self.model.activations_[layer])

                logits = torch.matmul(self.model.lm_head.weight, normed.T)

                probs = F.softmax(logits.T[0], dim=-1)

                probs = torch.reshape(probs, (-1,)).detach().cpu().numpy()

                assert np.abs(np.sum(probs) - 1) <= 0.01, str(np.abs(np.sum(probs) - 1)) + layer

                # Get top-k predictions (highest probability first)
                top_indices = np.argsort(probs)[::-1][:self.TOP_K]
                top_k = [(probs[i].item(), self.tokenizer.decode(int(i))) for i in top_indices]

                layer_residual_preds.append(top_k)

        return layer_residual_preds

    def display_predictions(self, sentence):
        layer_residual_preds= self.get_resid_predictions(sentence)

        print(f"Predictions for: {sentence}\n")

        # Display layer residual predictions
        print("Layer Residual Predictions:")
        for i, preds in enumerate(layer_residual_preds):
            print(f"Layer {i}: {preds}")

# Example usage
gpt2_with_hooks = GPT2WithHooks()

# Run some text through the model to collect activations
sentence = """Q: What is the capital of France?
A: Paris
Q: What is the capital of Poland?
A:"""
gpt2_with_hooks.display_predictions(sentence)


Predictions for: Q: What is the capital of France?
A: Paris
Q: What is the capital of Poland?
A:

Layer Residual Predictions:
Layer 0: [(0.02316313423216343, ' ('), (0.018392857164144516, ' ['), (0.012803674675524235, ' The'), (0.00920428428798914, ':'), (0.008853655308485031, ',')]
Layer 1: [(0.019568050280213356, ' A'), (0.01659923605620861, ' The'), (0.014010702259838581, ' ('), (0.010149895213544369, ' ['), (0.007642451208084822, ' Is')]
Layer 2: [(0.026887251064181328, ' A'), (0.018573066219687462, ' ['), (0.017433306202292442, ' ('), (0.013812464661896229, ' The'), (0.011615300551056862, ' At')]
Layer 3: [(0.03929824382066727, ' A'), (0.022901220247149467, ' ['), (0.020305326208472252, ' ('), (0.018195513635873795, ' Act'), (0.017830180004239082, ' At')]
Layer 4: [(0.038885049521923065, ' A'), (0.024648193269968033, ' ['), (0.023253243416547775, ' At'), (0.014608520083129406, ' Q'), (0.011924289166927338, ' (')]
Layer 5: [(0.035387828946113586, ' A'), (0.033396828919649124, ' M')

**Let's register more hooks!**

In [72]:
import torch
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
from IPython.display import display

class GPT2WithHooks:
    def __init__(self, model_name="gpt2-medium", top_k=5, device=None):
        # Load model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.TOP_K = top_k

        # Set device (default to 'cuda' if available, otherwise 'cpu')
        self.device = device if device is not None else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Initialize the activations dictionary
        self.set_hooks_gpt2()

    def set_hooks_gpt2(self):
        final_layer = self.model.config.n_layer - 1

        for attr in ["activations_"]:
            if not hasattr(self.model, attr):
                setattr(self.model, attr, {})

        def get_activation(name):
            def hook(module, input, output):
                if "mlp" in name or "attn" in name or "m_coef" in name:
                    if "attn" in name:
                        num_tokens = list(output[0].size())[1]
                        self.model.activations_[name] = output[0][:, num_tokens - 1].detach()
                    elif "mlp" in name:
                        num_tokens = list(output[0].size())[0]  # the MLP returns a plain tensor, so output[0] drops the batch dim: (sequence, hidden)
                        self.model.activations_[name] = output[0][num_tokens - 1].detach()
                    elif "m_coef" in name:
                        num_tokens = list(input[0].size())[1]  # (batch, sequence, hidden_state)
                        self.model.activations_[name] = input[0][:, num_tokens - 1].detach()
                elif "residual" in name or "embedding" in name:
                    num_tokens = list(input[0].size())[1]  # (batch, sequence, hidden_state)
                    if name == "layer_residual_" + str(final_layer):
                        self.model.activations_[name] = self.model.activations_[
                                                            "intermediate_residual_" + str(final_layer)] + \
                                                        self.model.activations_["mlp_" + str(final_layer)]
                    else:
                        self.model.activations_[name] = input[0][:, num_tokens - 1].detach()

            return hook

        # Register hooks
        self.model.transformer.h[0].ln_1.register_forward_hook(get_activation("input_embedding"))

        for i in range(self.model.config.n_layer):
            if i != 0:
                self.model.transformer.h[i].ln_1.register_forward_hook(get_activation("layer_residual_" + str(i - 1)))
            self.model.transformer.h[i].ln_2.register_forward_hook(get_activation("intermediate_residual_" + str(i)))

            self.model.transformer.h[i].attn.register_forward_hook(get_activation("attn_" + str(i)))
            self.model.transformer.h[i].mlp.register_forward_hook(get_activation("mlp_" + str(i)))
            self.model.transformer.h[i].mlp.c_proj.register_forward_hook(get_activation("m_coef_" + str(i)))

        self.model.transformer.ln_f.register_forward_hook(get_activation("layer_residual_" + str(final_layer)))

    def forward(self, text):
        # Tokenize and move the inputs to the same device as the model
        encoded_input = self.tokenizer(text, return_tensors='pt').to(self.device)

        # Forward pass to trigger hooks
        with torch.no_grad():
            self.model(**encoded_input)

        # Return activations
        return self.model.activations_

    def get_resid_predictions(self, sentence):
        """
        This function computes the intermediate predictions at different layers of GPT-2
        using activations from residual layers and intermediate layers.
        """
        layer_residual_preds = []
        intermed_residual_preds = []

        tokens = self.tokenizer(sentence, return_tensors="pt").to(self.device)

        # Forward pass to trigger the hooks; output_hidden_states is not needed,
        # because the hooks fill self.model.activations_
        with torch.no_grad():
            self.model(**tokens)

        for layer in self.model.activations_.keys():
            if "layer_residual" in layer or "intermediate_residual" in layer:
                normed = self.model.transformer.ln_f(self.model.activations_[layer])

                logits = torch.matmul(self.model.lm_head.weight, normed.T)

                probs = F.softmax(logits.T[0], dim=-1)

                probs = torch.reshape(probs, (-1,)).detach().cpu().numpy()

                assert np.abs(np.sum(probs) - 1) <= 0.01, str(np.abs(np.sum(probs) - 1)) + layer

                # Get top-k predictions (highest probability first)
                top_indices = np.argsort(probs)[::-1][:self.TOP_K]
                top_k = [(probs[i].item(), self.tokenizer.decode(int(i))) for i in top_indices]

            if "layer_residual" in layer:
                layer_residual_preds.append(top_k)
            elif "intermediate_residual" in layer:
                intermed_residual_preds.append(top_k)

        return layer_residual_preds, intermed_residual_preds

    def display_predictions(self, sentence):
        layer_residual_preds, intermed_residual_preds = self.get_resid_predictions(sentence)

        print(f"Predictions for: {sentence}\n")

        # Display layer residual predictions
        print("Layer Residual Predictions:")
        for i, preds in enumerate(layer_residual_preds):
            print(f"Layer {i}: {preds}")

        # Display intermediate residual predictions
        print("\nIntermediate Residual Predictions:")
        for i, preds in enumerate(intermed_residual_preds):
            print(f"Layer {i}: {preds}")

# Example usage
gpt2_with_hooks = GPT2WithHooks()

# Run some text through the model to collect activations
sentence = """Q: What is the capital of France?
A: Paris
Q: What is the capital of Poland?
A:"""
gpt2_with_hooks.display_predictions(sentence)


Predictions for: Q: What is the capital of France?
A: Paris
Q: What is the capital of Poland?
A:

Layer Residual Predictions:
Layer 0: [(0.02316313423216343, ' ('), (0.018392857164144516, ' ['), (0.012803674675524235, ' The'), (0.00920428428798914, ':'), (0.008853655308485031, ',')]
Layer 1: [(0.019568050280213356, ' A'), (0.01659923605620861, ' The'), (0.014010702259838581, ' ('), (0.010149895213544369, ' ['), (0.007642451208084822, ' Is')]
Layer 2: [(0.026887251064181328, ' A'), (0.018573066219687462, ' ['), (0.017433306202292442, ' ('), (0.013812464661896229, ' The'), (0.011615300551056862, ' At')]
Layer 3: [(0.03929824382066727, ' A'), (0.022901220247149467, ' ['), (0.020305326208472252, ' ('), (0.018195513635873795, ' Act'), (0.017830180004239082, ' At')]
Layer 4: [(0.038885049521923065, ' A'), (0.024648193269968033, ' ['), (0.023253243416547775, ' At'), (0.014608520083129406, ' Q'), (0.011924289166927338, ' (')]
Layer 5: [(0.035387828946113586, ' A'), (0.033396828919649124, ' M')

What exactly happens between the Intermediate Residual Predictions and the Layer Residual Predictions? Let's look at the forward pass of a GPT-2 block.
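
As a rough preview, the hook class above already hints at the answer: it rebuilds the final layer residual as the intermediate residual plus the MLP output. The sketch below checks that relationship on a toy pre-LayerNorm block. `ToyBlock` is a made-up stand-in (its `attn` is a plain linear layer, not real attention) that only mirrors the submodule names `ln_1`, `attn`, `ln_2` and `mlp`; it is not the GPT-2 implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ToyBlock(nn.Module):
    """Toy pre-LayerNorm block; `attn` is a Linear stand-in, not real attention."""
    def __init__(self, d=8):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d)
        self.attn = nn.Linear(d, d)
        self.ln_2 = nn.LayerNorm(d)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))  # residual after the attention sublayer
        x = x + self.mlp(self.ln_2(x))   # residual after the MLP sublayer
        return x

block = ToyBlock().eval()
acts = {}

# Input of ln_2 = residual stream after attention ("intermediate residual")
block.ln_2.register_forward_hook(
    lambda module, inp, out: acts.update(intermediate=inp[0].detach())
)
# Output of mlp = what the MLP writes into the residual stream
block.mlp.register_forward_hook(
    lambda module, inp, out: acts.update(mlp=out.detach())
)

x = torch.randn(1, 5, 8)  # (batch, sequence, hidden)
with torch.no_grad():
    layer_residual = block(x)

# The layer residual is the intermediate residual plus the MLP output
matches = torch.allclose(layer_residual, acts["intermediate"] + acts["mlp"])
print(matches)
```

Under this reading, any change in the decoded predictions between an intermediate residual and the corresponding layer residual is attributable to that layer's MLP.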

In [21]:
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
import inspect
print(inspect.getsource(GPT2Block.forward))

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (a