<a href="https://colab.research.google.com/github/wlg1/numseqcont_circuit_expms/blob/main/nb_templates/headFNs_expms_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off to save GPU memory, since this notebook focuses on model inference, not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fca900c89a0>

Plotting helper functions:

In [6]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Load Model

Decide which model to use (e.g. gpt2-small vs. gpt2-medium).

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)


Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Extracted Model

## use only L9, then only L0 and L9

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed

        # MLP 0 seems important, acting as another embedding layer
        self.ln1_0 = original_model.blocks[0].ln1
        self.attn_0 = original_model.blocks[0].attn
        self.ln2_0 = original_model.blocks[0].ln2
        self.mlp_0 = original_model.blocks[0].mlp  # the MLP layer in the 0th transformer block

        self.ln1 = original_model.blocks[9].ln1
        self.attn = original_model.blocks[9].attn
        self.ln2 = original_model.blocks[9].ln2
        self.mlp = original_model.blocks[9].mlp  # the MLP layer in the 9th transformer block
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        resid_pre = embed + pos_embed

        normalized_resid_pre = self.ln1_0(resid_pre)
        attn_out = self.attn_0(normalized_resid_pre, normalized_resid_pre, normalized_resid_pre)
        resid_mid = resid_pre + attn_out

        normalized_resid_mid = self.ln2_0(resid_mid)
        # normalized_resid_mid = self.ln2(resid_pre)
        mlp_out = self.mlp_0(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        # resid_post = resid_pre + mlp_out

        normalized_resid_pre = self.ln1(resid_post)
        # normalized_resid_pre = self.ln1(resid_pre)
        attn_out = self.attn(normalized_resid_pre, normalized_resid_pre, normalized_resid_pre)
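        # NOTE: this adds L9's attention output to resid_pre (the embeddings), not to resid_post,
        # so block 0's outputs reach the rest of the model only through L9's attention input.
        # The outputs below were produced with this line as written; use
        # resid_post + attn_out to keep the full residual stream.
        resid_mid = resid_pre + attn_out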

        normalized_resid_mid = self.ln2(resid_mid)
        # normalized_resid_mid = self.ln2(resid_pre)
        mlp_out = self.mlp(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        # resid_post = resid_pre + mlp_out

        normalized_resid_final = self.ln_final(resid_post)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)
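To make the wiring in `ExtractedModel.forward` concrete, here is a minimal numpy sketch of the residual stream through two pre-LN blocks. The attention and MLP functions are hypothetical random stand-ins (the toy attention does not mix positions, unlike a real attention layer), and `layer_norm` only centers and scales, roughly like the weightless LayerNorm used with `fold_ln=True`. It contrasts standard chaining, where block 9 adds to block 0's output, with the variant in the cell above, where block 9's attention output is added to the embeddings.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Center and scale each position's vector (no learned weights, assumed folded away)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def block(resid_pre, attn, mlp):
    # Pre-LN block: each sublayer reads a normalized copy of the stream and adds its output back
    resid_mid = resid_pre + attn(layer_norm(resid_pre))
    resid_post = resid_mid + mlp(layer_norm(resid_mid))
    return resid_post

def make_sublayers(rng, d_model=8, d_mlp=32):
    # Hypothetical random weights standing in for a real attention layer and MLP
    w_attn = rng.normal(size=(d_model, d_model)) * 0.1
    w_in = rng.normal(size=(d_model, d_mlp)) * 0.1
    w_out = rng.normal(size=(d_mlp, d_model)) * 0.1
    attn = lambda x: x @ w_attn  # toy: a real attention layer also mixes positions
    mlp = lambda x: np.maximum(x @ w_in, 0.0) @ w_out  # ReLU MLP (GPT-2 actually uses GELU)
    return attn, mlp

rng = np.random.default_rng(0)
attn0, mlp0 = make_sublayers(rng)
attn9, mlp9 = make_sublayers(rng)
resid_pre = rng.normal(size=(5, 8))  # stand-in for embed + pos_embed: [position, d_model]

# Standard chaining: block 9 reads and adds to block 0's output
resid_after_0 = block(resid_pre, attn0, mlp0)
resid_final = block(resid_after_0, attn9, mlp9)

# Variant as written in the cell above: block 9's attention output is added to resid_pre
resid_mid_variant = resid_pre + attn9(layer_norm(resid_after_0))
resid_final_variant = resid_mid_variant + mlp9(layer_norm(resid_mid_variant))

print(np.abs(resid_final - resid_final_variant).max())
```

The two results differ mainly by block 0's own contribution to the residual stream, which the variant drops.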

In [None]:
tokens = model.to_tokens(["1 2 3 4"], prepend_bos=True)
tokens = tokens.to(model.cfg.device)  # Make sure the tokens are on the same device as the model
original_logits = extracted_model(tokens)  # Run the extracted model; this returns logits only (no activation cache)

In [None]:
original_logits.shape

torch.Size([1, 5, 50257])

In [10]:
def remove_batch_dim(
    tensor: Float[torch.Tensor, "1 ..."]
) -> Float[torch.Tensor, "..."]:
    """
    Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged
    """
    if tensor.shape[0] == 1:
        return tensor.squeeze(0)
    else:
        return tensor

def test_prompt(
    prompt: str,
    answer: str,
    model,
    orig_model,
    prepend_space_to_answer: bool = True,
    print_details: bool = True,
    prepend_bos: bool = True,
    top_k: int = 10,
):
    """
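    Function to test whether a model can give the correct answer to a prompt. Intended for exploratory analysis, so it prints things out rather than returning things.

    `model` is the (possibly extracted) model whose logits are scored; `orig_model` is the HookedTransformer used to tokenize the prompt and answer and to decode the top tokens.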

    Works for multi-token answers and multi-token prompts.

    Will always print the ranks of the answer tokens, and if print_details will print the logit and prob for the answer tokens and the top k tokens returned for each answer position.
    """
    if prepend_space_to_answer and not answer.startswith(" "):
        answer = " " + answer
    # GPT-2 often treats the first token weirdly, so let's give it a resting position
    tokens = orig_model.to_tokens(prompt + answer, prepend_bos=prepend_bos)


    prompt_str_tokens = orig_model.to_str_tokens(prompt, prepend_bos=prepend_bos)
    answer_str_tokens = orig_model.to_str_tokens(answer, prepend_bos=False)
    prompt_length = len(prompt_str_tokens)
    answer_length = len(answer_str_tokens)
    if print_details:
        print("Tokenized prompt:", prompt_str_tokens)
        print("Tokenized answer:", answer_str_tokens)

    logits = remove_batch_dim(model(tokens))

    probs = logits.softmax(dim=-1)
    answer_ranks = []
    for index in range(prompt_length, prompt_length + answer_length):
        answer_token = tokens[0, index]
        answer_str_token = answer_str_tokens[index - prompt_length]
        # Offset by 1 because models predict the NEXT token
        token_probs = probs[index - 1]
        sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)
        # Janky way to get the index of the token in the sorted list - I couldn't find a better way?
        correct_rank = torch.arange(len(sorted_token_values))[
            (sorted_token_values == answer_token).cpu()
        ].item()
        answer_ranks.append((answer_str_token, correct_rank))
        if print_details:
            # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
            # The [b]...[/b] tags are rich-style markup; with plain print() they appear literally in the output
            print(
                f"Performance on answer token:\n[b]Rank: {correct_rank: <8} Logit: {logits[index-1, answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|[/b]"
            )
            for i in range(top_k):
                print(
                    f"Top {i}th token. Logit: {logits[index-1, sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{orig_model.to_string(sorted_token_values[i])}|"
                )
    print(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
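The comment in `test_prompt` asks for a better way to find the rank of the answer token. Here is a minimal numpy sketch, assuming `logits` is a `[position, vocab]` array and `tokens` holds the ids of prompt + answer, with `prompt_length` counting the BOS token as in `test_prompt`. Position `index - 1` predicts token `index`, and the rank is where the answer id lands in the descending sort. In torch, `(sorted_token_values == answer_token).nonzero().item()` should also do the same lookup.

```python
import numpy as np

def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def answer_ranks(logits, tokens, prompt_length):
    """Rank (0 = top prediction) of each answer token, given [position, vocab] logits."""
    probs = softmax(logits)
    ranks = []
    for index in range(prompt_length, len(tokens)):
        token_probs = probs[index - 1]  # offset by 1: position index-1 predicts token index
        order = np.argsort(-token_probs, kind="stable")  # token ids, most likely first
        ranks.append(int(np.flatnonzero(order == tokens[index])[0]))
    return ranks

# Toy example: vocab of 4, sequence of 3 token ids
logits = np.array([[5.0, 0.0, 1.0, 0.0],
                   [0.0, 1.0, 3.0, 2.0],
                   [0.0, 0.0, 0.0, 0.0]])
tokens = np.array([0, 2, 3])
print(answer_ranks(logits, tokens, prompt_length=2))  # one answer token → [1]
print(answer_ranks(logits, tokens, prompt_length=1))  # two answer tokens → [1, 1]
```

Since softmax preserves order, ranking by logits would give the same result.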

In [None]:
test_prompt("1 2 3 4 5 6 7 8 9", " 10", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4', ' 5', ' 6', ' 7', ' 8', ' 9']
Tokenized answer: [' 10']
Performance on answer token:
[b]Rank: 234      Logit: 13.02 Prob:  0.02% Token: | 10|[/b]
Top 0th token. Logit: 19.92 Prob: 22.72% Token: | South|
Top 1th token. Logit: 19.65 Prob: 17.50% Token: | North|
Top 2th token. Logit: 18.69 Prob:  6.70% Token: | live|
Top 3th token. Logit: 18.58 Prob:  5.98% Token: | West|
Top 4th token. Logit: 18.57 Prob:  5.94% Token: |,|
Top 5th token. Logit: 17.93 Prob:  3.14% Token: | a|
Top 6th token. Logit: 17.27 Prob:  1.62% Token: | Force|
Top 7th token. Logit: 17.09 Prob:  1.34% Token: | Long|
Top 8th token. Logit: 16.88 Prob:  1.10% Token: | U|
Top 9th token. Logit: 16.84 Prob:  1.05% Token: | in|
[b]Ranks of the answer tokens:[/b] [(' 10', 234)]


In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 733      Logit: 11.10 Prob:  0.00% Token: | May|[/b]
Top 0th token. Logit: 19.69 Prob: 17.42% Token: | North|
Top 1th token. Logit: 19.49 Prob: 14.33% Token: | South|
Top 2th token. Logit: 19.03 Prob:  9.01% Token: | West|
Top 3th token. Logit: 19.01 Prob:  8.87% Token: | live|
Top 4th token. Logit: 18.11 Prob:  3.61% Token: | a|
Top 5th token. Logit: 18.05 Prob:  3.40% Token: |,|
Top 6th token. Logit: 17.62 Prob:  2.20% Token: | Force|
Top 7th token. Logit: 17.57 Prob:  2.10% Token: | Islands|
Top 8th token. Logit: 17.09 Prob:  1.30% Token: | Long|
Top 9th token. Logit: 16.92 Prob:  1.10% Token: | in|
[b]Ranks of the answer tokens:[/b] [(' May', 733)]


## Test validity of extraction class

Test whether we can recover the original model by running this extraction over all layers.

If we can't, the extraction is wrong.

In [None]:
import torch.nn as nn

class ExtractedModel_full(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel_full, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed

        self.blocks = original_model.blocks

        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model_full = ExtractedModel_full(model)
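Comparing top-10 lists by eye is a fairly weak test of equivalence; a stricter check is to compare the full logit tensors numerically. Below is a small sketch using toy numpy "blocks" (hypothetical stand-ins). In the notebook, the two callables might be something like `lambda t: model(t).cpu().numpy()` and `lambda t: extracted_model_full(t).cpu().numpy()`, applied to a few token tensors.

```python
import numpy as np

def check_same_outputs(run_a, run_b, inputs, atol=1e-5):
    """Return (all_close, worst_abs_diff) for two callables evaluated on each input."""
    worst = 0.0
    for x in inputs:
        diff = np.abs(np.asarray(run_a(x)) - np.asarray(run_b(x)))
        worst = max(worst, float(diff.max()))
    return worst <= atol, worst

# Hypothetical toy "model": three residual blocks applied in order
blocks = [lambda x, k=k: x + np.tanh(x * (k + 1)) for k in range(3)]

def run_loop(x):
    # Like ExtractedModel_full.forward: apply every block in sequence
    for block in blocks:
        x = block(x)
    return x

def run_nested(x):
    # Stand-in for the original model: the same blocks composed explicitly
    return blocks[2](blocks[1](blocks[0](x)))

def run_skip_middle(x):
    # A broken extraction that accidentally drops a block
    return blocks[2](blocks[0](x))

inputs = [np.linspace(-1.0, 1.0, 5), np.ones(5)]
print(check_same_outputs(run_loop, run_nested, inputs))          # (True, 0.0)
print(check_same_outputs(run_loop, run_skip_middle, inputs)[0])  # False
```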

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model_full, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 18.76 Prob: 96.17% Token: | 5|[/b]
Top 0th token. Logit: 18.76 Prob: 96.17% Token: | 5|
Top 1th token. Logit: 13.27 Prob:  0.40% Token: | Next|
Top 2th token. Logit: 13.01 Prob:  0.30% Token: |
|
Top 3th token. Logit: 12.87 Prob:  0.27% Token: | >|
Top 4th token. Logit: 12.04 Prob:  0.12% Token: | 4|
Top 5th token. Logit: 11.88 Prob:  0.10% Token: | 50|
Top 6th token. Logit: 11.83 Prob:  0.09% Token: | 6|
Top 7th token. Logit: 11.71 Prob:  0.08% Token: | <|
Top 8th token. Logit: 11.64 Prob:  0.08% Token: | $|
Top 9th token. Logit: 11.63 Prob:  0.08% Token: | 1|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


This is an exact match to:

https://colab.research.google.com/drive/1uSuPtHrh9venKNlIt2O-1piknwfQpJmK#scrollTo=m90WRkxYIAiL&line=1&uniqifier=1

In [None]:
test_prompt("January February March April", " May", extracted_model_full, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 0        Logit: 30.05 Prob: 99.87% Token: | May|[/b]
Top 0th token. Logit: 30.05 Prob: 99.87% Token: | May|
Top 1th token. Logit: 22.47 Prob:  0.05% Token: | 5|
Top 2th token. Logit: 21.69 Prob:  0.02% Token: | July|
Top 3th token. Logit: 21.12 Prob:  0.01% Token: |
|
Top 4th token. Logit: 21.04 Prob:  0.01% Token: | June|
Top 5th token. Logit: 19.32 Prob:  0.00% Token: | Wh|
Top 6th token. Logit: 19.18 Prob:  0.00% Token: | If|
Top 7th token. Logit: 18.90 Prob:  0.00% Token: | March|
Top 8th token. Logit: 18.79 Prob:  0.00% Token: | April|
Top 9th token. Logit: 18.69 Prob:  0.00% Token: | Please|
[b]Ranks of the answer tokens:[/b] [(' May', 0)]


Yes, it's able to do this: the full extraction reproduces the original model's predictions, so the extraction class is correct. Can we see what happens as we gradually add layers, perhaps from L0 to L9? (This is like the logit lens, except we can also take layers away.)

Given that the vector addition logit lens showed this is possible, it should work.

Try different layer combinations:

## keep only L0 to L9

In [None]:
import torch.nn as nn

class ExtractedModel_toL9(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel_toL9, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:10]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model_toL9 = ExtractedModel_toL9(model)
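`ExtractedModel_toL9` hard-codes a single prefix of layers. Here is a toy numpy sketch of the sweep proposed above: gradually adding layers, or running any chosen subset. The blocks, the identity unembedding, and the omission of the final LayerNorm are all simplifying assumptions. Each hypothetical block just adds +1 to the target token's logit, so the target probability should rise with every added layer.

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def target_prob(resid0, blocks, layer_ids, unembed, target):
    """Run only the blocks in layer_ids (in order), unembed, and return P(target)."""
    resid = resid0
    for layer in layer_ids:
        resid = blocks[layer](resid)
    # The notebook applies ln_final before unembedding; this toy skips it
    return softmax(resid @ unembed)[target]

d_vocab, n_layers, target = 5, 4, 3
unembed = np.eye(d_vocab)  # toy: residual dimensions map one-to-one onto tokens
push = np.eye(d_vocab)[target]  # hypothetical block output: +1 along the target direction
blocks = [lambda r: r + push for _ in range(n_layers)]
resid0 = np.zeros(d_vocab)

# Gradually add layers: no layers, then L0, then L0-L1, and so on
curve = [target_prob(resid0, blocks, range(k), unembed, target) for k in range(n_layers + 1)]
print(np.round(curve, 3))

# Or take a layer away, e.g. skip L1
print(round(float(target_prob(resid0, blocks, [0, 2, 3], unembed, target)), 3))
```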

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model_toL9, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 34.66 Prob: 99.98% Token: | 5|[/b]
Top 0th token. Logit: 34.66 Prob: 99.98% Token: | 5|
Top 1th token. Logit: 25.98 Prob:  0.02% Token: | 4|
Top 2th token. Logit: 24.19 Prob:  0.00% Token: | 6|
Top 3th token. Logit: 23.46 Prob:  0.00% Token: | 7|
Top 4th token. Logit: 23.30 Prob:  0.00% Token: | Next|
Top 5th token. Logit: 22.81 Prob:  0.00% Token: | 9|
Top 6th token. Logit: 21.98 Prob:  0.00% Token: |5|
Top 7th token. Logit: 21.46 Prob:  0.00% Token: |Next|
Top 8th token. Logit: 20.87 Prob:  0.00% Token: |♥|
Top 9th token. Logit: 20.43 Prob:  0.00% Token: | 3|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


It actually does BETTER without L10 and L11 (99.98% on ' 5', vs. 96.17% for the full model). This seems inconsistent with this plot: even though it dips at L10 and L11, it still seems to go up by the end:

https://colab.research.google.com/drive/1eavp74fMMDHIBeVeE0UupRmlXHPAenoa#scrollTo=gYOOrypHIAiR&line=2&uniqifier=1

In [None]:
test_prompt("one two three four", " five", extracted_model_toL9, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'one', ' two', ' three', ' four']
Tokenized answer: [' five']
Performance on answer token:
[b]Rank: 0        Logit: 26.86 Prob: 83.10% Token: | five|[/b]
Top 0th token. Logit: 26.86 Prob: 83.10% Token: | five|
Top 1th token. Logit: 23.97 Prob:  4.63% Token: | four|
Top 2th token. Logit: 23.30 Prob:  2.37% Token: | Four|
Top 3th token. Logit: 23.28 Prob:  2.32% Token: | Five|
Top 4th token. Logit: 23.18 Prob:  2.11% Token: | seven|
Top 5th token. Logit: 22.42 Prob:  0.98% Token: |teen|
Top 6th token. Logit: 22.29 Prob:  0.86% Token: | six|
Top 7th token. Logit: 22.16 Prob:  0.76% Token: |five|
Top 8th token. Logit: 22.09 Prob:  0.71% Token: |Five|
Top 9th token. Logit: 22.06 Prob:  0.69% Token: | 5|
[b]Ranks of the answer tokens:[/b] [(' five', 0)]


## Try skipping layers


In [None]:
[0, 1, 2, 3, 4, 5][0:3]  # basic Python sanity check

[0, 1, 2]

### skip L5 only

Try skipping L5 (keeping L0 to L4 and L6 to L9), because it's red in this plot:

https://colab.research.google.com/drive/1eavp74fMMDHIBeVeE0UupRmlXHPAenoa#scrollTo=TNhyDx1XIAia&line=1&uniqifier=1
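The slice arithmetic in the next few cells is easy to misread; for instance, `self.blocks[0:5] + self.blocks[6:10]` drops L10 and L11 as well as L5. Here is a hypothetical helper that spells out the kept layers, checked against the slices used in this notebook:

```python
def layer_indices(last_layer, skip=()):
    """Layer indices 0..last_layer (inclusive), minus any listed in `skip`."""
    skip = set(skip)
    return [layer for layer in range(last_layer + 1) if layer not in skip]

all_layers = list(range(12))  # GPT-2 small has 12 blocks

print(layer_indices(9, skip={5}))  # [0, 1, 2, 3, 4, 6, 7, 8, 9]
print(all_layers[0:5] + all_layers[6:10] == layer_indices(9, skip={5}))  # True
print(all_layers[0:8] + all_layers[9:10] == layer_indices(9, skip={8}))  # True
```

With a HookedTransformer, something like `nn.ModuleList(model.blocks[i] for i in layer_indices(9, skip={5}))` could then replace the hand-written slices (an untested assumption).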

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:5] + self.blocks[6:10]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 56       Logit: 12.34 Prob:  0.29% Token: | 5|[/b]
Top 0th token. Logit: 15.91 Prob: 10.41% Token: |th|
Top 1th token. Logit: 15.70 Prob:  8.43% Token: |34|
Top 2th token. Logit: 15.64 Prob:  7.91% Token: | -|
Top 3th token. Logit: 15.10 Prob:  4.63% Token: |54|
Top 4th token. Logit: 15.01 Prob:  4.20% Token: |ts|
Top 5th token. Logit: 14.98 Prob:  4.08% Token: |:|
Top 6th token. Logit: 14.66 Prob:  2.98% Token: |67|
Top 7th token. Logit: 14.45 Prob:  2.42% Token: |74|
Top 8th token. Logit: 14.35 Prob:  2.18% Token: |.|
Top 9th token. Logit: 14.26 Prob:  1.99% Token: |39|
[b]Ranks of the answer tokens:[/b] [(' 5', 56)]


In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 62       Logit: 13.65 Prob:  0.02% Token: | May|[/b]
Top 0th token. Logit: 20.77 Prob: 25.32% Token: | 29|
Top 1th token. Logit: 20.18 Prob: 14.05% Token: | 24|
Top 2th token. Logit: 19.67 Prob:  8.47% Token: | 26|
Top 3th token. Logit: 19.50 Prob:  7.09% Token: | 28|
Top 4th token. Logit: 19.39 Prob:  6.37% Token: | 2018|
Top 5th token. Logit: 19.34 Prob:  6.06% Token: | 27|
Top 6th token. Logit: 18.96 Prob:  4.17% Token: | 4|
Top 7th token. Logit: 18.73 Prob:  3.31% Token: | 30|
Top 8th token. Logit: 18.34 Prob:  2.22% Token: | 6|
Top 9th token. Logit: 18.19 Prob:  1.92% Token: | 3|
[b]Ranks of the answer tokens:[/b] [(' May', 62)]


It's much worse, so we need L5. Yet it still seems to recognize that numbers are involved: many of the top predictions are numbers.

### keep only L0 to L8

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:9]: # + self.blocks[9:10]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 1        Logit: 26.25 Prob:  2.53% Token: | 5|[/b]
Top 0th token. Logit: 29.89 Prob: 96.58% Token: | 4|
Top 1th token. Logit: 26.25 Prob:  2.53% Token: | 5|
Top 2th token. Logit: 24.31 Prob:  0.36% Token: | 3|
Top 3th token. Logit: 23.98 Prob:  0.26% Token: | 6|
Top 4th token. Logit: 22.35 Prob:  0.05% Token: | 8|
Top 5th token. Logit: 22.18 Prob:  0.04% Token: | 9|
Top 6th token. Logit: 22.10 Prob:  0.04% Token: | 7|
Top 7th token. Logit: 21.94 Prob:  0.03% Token: | ..........|
Top 8th token. Logit: 21.46 Prob:  0.02% Token: | 1|
Top 9th token. Logit: 20.93 Prob:  0.01% Token: | <!--|
[b]Ranks of the answer tokens:[/b] [(' 5', 1)]


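As we saw in the logit lens, L9 is needed to apply the +1. But L0 to L8 are enough to "get the number": without L9, the top prediction is ' 4', the last number in the prompt, with ' 5' in second place. This means we can't rely on a single layer.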

In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 4        Logit: 22.44 Prob:  1.95% Token: | May|[/b]
Top 0th token. Logit: 26.08 Prob: 74.10% Token: | April|
Top 1th token. Logit: 23.84 Prob:  7.93% Token: | 2014|
Top 2th token. Logit: 23.35 Prob:  4.84% Token: | 2015|
Top 3th token. Logit: 22.79 Prob:  2.78% Token: | March|
Top 4th token. Logit: 22.44 Prob:  1.95% Token: | May|
Top 5th token. Logit: 22.26 Prob:  1.62% Token: | 2017|
Top 6th token. Logit: 21.93 Prob:  1.17% Token: | June|
Top 7th token. Logit: 21.70 Prob:  0.93% Token: | September|
Top 8th token. Logit: 21.03 Prob:  0.48% Token: | 2018|
Top 9th token. Logit: 20.97 Prob:  0.45% Token: | 1989|
[b]Ranks of the answer tokens:[/b] [(' May', 4)]


Same thing for months (as seen in the logit lens): without L9, the top prediction is ' April', and ' May' drops to rank 4.

### keep only L0 to L7

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:8]: # + self.blocks[9:10]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 4        Logit: 20.50 Prob:  1.07% Token: | 5|[/b]
Top 0th token. Logit: 24.64 Prob: 67.40% Token: | 4|
Top 1th token. Logit: 23.67 Prob: 25.44% Token: | ..........|
Top 2th token. Logit: 20.82 Prob:  1.48% Token: | <!--|
Top 3th token. Logit: 20.73 Prob:  1.35% Token: | ★|
Top 4th token. Logit: 20.50 Prob:  1.07% Token: | 5|
Top 5th token. Logit: 20.14 Prob:  0.75% Token: | Tycoon|
Top 6th token. Logit: 19.83 Prob:  0.55% Token: | 6|
Top 7th token. Logit: 18.99 Prob:  0.24% Token: | 3|
Top 8th token. Logit: 18.52 Prob:  0.15% Token: | 8|
Top 9th token. Logit: 18.44 Prob:  0.14% Token: | Next|
[b]Ranks of the answer tokens:[/b] [(' 5', 4)]


As we saw in the logit lens, L0 to L7 are required to recognize that the next token should be a number.

### L0 to L7, skip L8, then go only to L9

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:8] + self.blocks[9:10]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 31.84 Prob: 99.94% Token: | 5|[/b]
Top 0th token. Logit: 31.84 Prob: 99.94% Token: | 5|
Top 1th token. Logit: 23.68 Prob:  0.03% Token: | Next|
Top 2th token. Logit: 23.00 Prob:  0.01% Token: | 4|
Top 3th token. Logit: 22.21 Prob:  0.01% Token: | 6|
Top 4th token. Logit: 21.70 Prob:  0.00% Token: |Next|
Top 5th token. Logit: 21.10 Prob:  0.00% Token: | 7|
Top 6th token. Logit: 20.47 Prob:  0.00% Token: | <!--|
Top 7th token. Logit: 20.00 Prob:  0.00% Token: | Five|
Top 8th token. Logit: 19.96 Prob:  0.00% Token: | Player|
Top 9th token. Logit: 19.03 Prob:  0.00% Token: | 9|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


We see L8 is not important for this: skipping it still gives excellent performance (99.94%). That is not as good as the 99.98% we get when keeping L8, but L8's contribution is trivial. This is odd, given that head 8.11 does "something"; perhaps it acts as a backup.

In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 0        Logit: 26.42 Prob: 77.81% Token: | May|[/b]
Top 0th token. Logit: 26.42 Prob: 77.81% Token: | May|
Top 1th token. Logit: 24.63 Prob: 13.05% Token: | 5|
Top 2th token. Logit: 23.03 Prob:  2.63% Token: | 2015|
Top 3th token. Logit: 22.87 Prob:  2.23% Token: | 2005|
Top 4th token. Logit: 21.28 Prob:  0.46% Token: | April|
Top 5th token. Logit: 21.26 Prob:  0.45% Token: | 05|
Top 6th token. Logit: 21.20 Prob:  0.42% Token: | 25|
Top 7th token. Logit: 20.89 Prob:  0.31% Token: | June|
Top 8th token. Logit: 20.13 Prob:  0.14% Token: | 1995|
Top 9th token. Logit: 20.05 Prob:  0.13% Token: |
|
[b]Ranks of the answer tokens:[/b] [(' May', 0)]


In [None]:
test_prompt("Sunday Monday Tuesday Wednesday", " Thursday", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Sunday', ' Monday', ' Tuesday', ' Wednesday']
Tokenized answer: [' Thursday']
Performance on answer token:
[b]Rank: 0        Logit: 31.53 Prob: 87.14% Token: | Thursday|[/b]
Top 0th token. Logit: 31.53 Prob: 87.14% Token: | Thursday|
Top 1th token. Logit: 29.13 Prob:  7.88% Token: | Wednesday|
Top 2th token. Logit: 27.48 Prob:  1.52% Token: | September|
Top 3th token. Logit: 27.10 Prob:  1.04% Token: | January|
Top 4th token. Logit: 26.31 Prob:  0.47% Token: | Friday|
Top 5th token. Logit: 26.14 Prob:  0.40% Token: | March|
Top 6th token. Logit: 25.95 Prob:  0.33% Token: | December|
Top 7th token. Logit: 25.89 Prob:  0.31% Token: | Thurs|
Top 8th token. Logit: 25.68 Prob:  0.25% Token: | Saturday|
Top 9th token. Logit: 25.21 Prob:  0.16% Token: | July|
[b]Ranks of the answer tokens:[/b] [(' Thursday', 0)]


In [None]:
test_prompt("A B C D", " E", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' B', ' C', ' D']
Tokenized answer: [' E']
Performance on answer token:
[b]Rank: 0        Logit: 22.92 Prob: 54.31% Token: | E|[/b]
Top 0th token. Logit: 22.92 Prob: 54.31% Token: | E|
Top 1th token. Logit: 22.65 Prob: 41.40% Token: | D|
Top 2th token. Logit: 18.40 Prob:  0.59% Token: | F|
Top 3th token. Logit: 18.25 Prob:  0.51% Token: | d|
Top 4th token. Logit: 17.57 Prob:  0.26% Token: | T|
Top 5th token. Logit: 17.52 Prob:  0.24% Token: |
|
Top 6th token. Logit: 17.47 Prob:  0.23% Token: | G|
Top 7th token. Logit: 17.10 Prob:  0.16% Token: | e|
Top 8th token. Logit: 16.75 Prob:  0.11% Token: | 3|
Top 9th token. Logit: 16.64 Prob:  0.10% Token: | Y|
[b]Ranks of the answer tokens:[/b] [(' E', 0)]


But the alphabet task gets worse without L8:

In [None]:
test_prompt("A B C D", " E", extracted_model_full, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' B', ' C', ' D']
Tokenized answer: [' E']
Performance on answer token:
[b]Rank: 0        Logit: 24.16 Prob: 99.24% Token: | E|[/b]
Top 0th token. Logit: 24.16 Prob: 99.24% Token: | E|
Top 1th token. Logit: 18.62 Prob:  0.39% Token: | e|
Top 2th token. Logit: 17.16 Prob:  0.09% Token: | F|
Top 3th token. Logit: 16.45 Prob:  0.04% Token: | G|
Top 4th token. Logit: 15.88 Prob:  0.03% Token: | É|
Top 5th token. Logit: 15.80 Prob:  0.02% Token: | 1|
Top 6th token. Logit: 15.03 Prob:  0.01% Token: | D|
Top 7th token. Logit: 15.01 Prob:  0.01% Token: | ER|
Top 8th token. Logit: 14.57 Prob:  0.01% Token: | 2|
Top 9th token. Logit: 14.55 Prob:  0.01% Token: |E|
[b]Ranks of the answer tokens:[/b] [(' E', 0)]


### skip attention in L9, only use MLP9

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed

        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        # blocks[0:8] runs L0-L7 as full blocks; L8 is not applied
        for block in self.blocks[0:8]:
            residual = block(residual)

        # Skip L9's attention sublayer; apply only L9's MLP (with its ln2) to the residual stream
        normalized_resid_mid = self.blocks[9].ln2(residual)
        mlp_out = self.blocks[9].mlp(normalized_resid_mid)
        residual = residual + mlp_out

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 28.38 Prob: 98.87% Token: | 5|[/b]
Top 0th token. Logit: 28.38 Prob: 98.87% Token: | 5|
Top 1th token. Logit: 22.89 Prob:  0.41% Token: | 4|
Top 2th token. Logit: 22.12 Prob:  0.19% Token: | <!--|
Top 3th token. Logit: 21.41 Prob:  0.09% Token: | Next|
Top 4th token. Logit: 21.16 Prob:  0.07% Token: | 3|
Top 5th token. Logit: 20.93 Prob:  0.06% Token: | 7|
Top 6th token. Logit: 20.76 Prob:  0.05% Token: | ★|
Top 7th token. Logit: 20.61 Prob:  0.04% Token: |★|
Top 8th token. Logit: 20.17 Prob:  0.03% Token: | ..........|
Top 9th token. Logit: 19.97 Prob:  0.02% Token: |
|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


It doesn't even need L9's attention (although including it adds about a full 1% of probability); it just needs the MLP. This is consistent with the vector addition notebook.

Vector addition showed only adding output of MLP to residual recovered

It seems that L0 to L7 do preprocessing that puts the residual stream into a form MLP9 can add to. So having "just" MLP9 isn't enough; the earlier layers are needed to get the digit. MLP9 then converts 4 to 5, April to May, etc.
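Written out as residual-stream updates, the "skip attention in L9" model above computes the following, where $t$ are the token ids, $p$ their positions, $\mathrm{Block}_\ell$ a full transformer block, and $W_E, W_{pos}, W_U, b_U$ the token-embedding, positional-embedding, and unembedding parameters (notation chosen here for illustration, not taken from the notebook):

$$
\begin{aligned}
r^{(0)} &= W_E[t] + W_{pos}[p] \\
r^{(\ell+1)} &= \mathrm{Block}_\ell\big(r^{(\ell)}\big), \qquad \ell = 0, \dots, 7 \\
r_{\text{out}} &= r^{(8)} + \mathrm{MLP}_9\big(\mathrm{LN}_2^{(9)}(r^{(8)})\big) \\
\text{logits} &= \mathrm{LN}_f(r_{\text{out}})\, W_U + b_U
\end{aligned}
$$

Relative to running L0 to L9 in full, the differences are that L8 is never applied and L9's attention term $\mathrm{Attn}_9\big(\mathrm{LN}_1^{(9)}(\cdot)\big)$ is never added before the MLP reads the stream.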

In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 2        Logit: 22.27 Prob:  9.29% Token: | May|[/b]
Top 0th token. Logit: 23.13 Prob: 21.98% Token: | 5|
Top 1th token. Logit: 22.79 Prob: 15.61% Token: | 2015|
Top 2th token. Logit: 22.27 Prob:  9.29% Token: | May|
Top 3th token. Logit: 21.64 Prob:  4.97% Token: | 25|
Top 4th token. Logit: 21.57 Prob:  4.59% Token: | 2005|
Top 5th token. Logit: 21.25 Prob:  3.35% Token: | Top|
Top 6th token. Logit: 21.22 Prob:  3.25% Token: | 27|
Top 7th token. Logit: 21.03 Prob:  2.70% Token: | Category|
Top 8th token. Logit: 20.83 Prob:  2.19% Token: | 2017|
Top 9th token. Logit: 20.57 Prob:  1.70% Token: | 29|
[b]Ranks of the answer tokens:[/b] [(' May', 2)]


Wait: without attention in L9, the top prediction for the months becomes " 5" (April is the 4th month)? Did we just uncover the link between the primordial and modern covering?

In [None]:
test_prompt("Sunday Monday Tuesday Wednesday", " Thursday", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Sunday', ' Monday', ' Tuesday', ' Wednesday']
Tokenized answer: [' Thursday']
Performance on answer token:
[b]Rank: 0        Logit: 26.03 Prob: 26.03% Token: | Thursday|[/b]
Top 0th token. Logit: 26.03 Prob: 26.03% Token: | Thursday|
Top 1th token. Logit: 25.61 Prob: 17.13% Token: | Wednesday|
Top 2th token. Logit: 25.35 Prob: 13.13% Token: | September|
Top 3th token. Logit: 25.01 Prob:  9.39% Token: | July|
Top 4th token. Logit: 24.89 Prob:  8.33% Token: | January|
Top 5th token. Logit: 24.23 Prob:  4.29% Token: | December|
Top 6th token. Logit: 24.19 Prob:  4.14% Token: | nights|
Top 7th token. Logit: 23.72 Prob:  2.59% Token: | March|
Top 8th token. Logit: 23.46 Prob:  1.98% Token: | November|
Top 9th token. Logit: 23.09 Prob:  1.38% Token: | night|
[b]Ranks of the answer tokens:[/b] [(' Thursday', 0)]


Nope: days of the week still put " Thursday" on top, with no digits in the top 10.

In [None]:
test_prompt("A B C D", " E", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' B', ' C', ' D']
Tokenized answer: [' E']
Performance on answer token:
[b]Rank: 1        Logit: 21.58 Prob: 20.49% Token: | E|[/b]
Top 0th token. Logit: 22.75 Prob: 65.99% Token: | D|
Top 1th token. Logit: 21.58 Prob: 20.49% Token: | E|
Top 2th token. Logit: 19.23 Prob:  1.96% Token: | G|
Top 3th token. Logit: 19.04 Prob:  1.62% Token: | T|
Top 4th token. Logit: 18.83 Prob:  1.32% Token: | C|
Top 5th token. Logit: 18.76 Prob:  1.22% Token: |
|
Top 6th token. Logit: 18.35 Prob:  0.81% Token: | F|
Top 7th token. Logit: 17.81 Prob:  0.47% Token: | +|
Top 8th token. Logit: 17.57 Prob:  0.37% Token: | O|
Top 9th token. Logit: 17.56 Prob:  0.37% Token: | L|
[b]Ranks of the answer tokens:[/b] [(' E', 1)]


The alphabet doesn't switch to a digit either; no "5" appears in the top 10.

### skip MLP9, only use attn9

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed

        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        # blocks[0:8] runs L0-L7 as full blocks; L8 is not applied
        for block in self.blocks[0:8]:
            residual = block(residual)

        # Apply only L9's attention sublayer (with its ln1); L9's MLP is skipped
        normalized_resid_pre = self.blocks[9].ln1(residual)
        attn_out = self.blocks[9].attn(normalized_resid_pre, normalized_resid_pre, normalized_resid_pre)
        residual = residual + attn_out

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 1        Logit: 25.61 Prob: 10.24% Token: | 5|[/b]
Top 0th token. Logit: 27.76 Prob: 87.32% Token: | 4|
Top 1th token. Logit: 25.61 Prob: 10.24% Token: | 5|
Top 2th token. Logit: 23.55 Prob:  1.30% Token: | 6|
Top 3th token. Logit: 22.56 Prob:  0.48% Token: |Next|
Top 4th token. Logit: 21.60 Prob:  0.19% Token: | Next|
Top 5th token. Logit: 21.51 Prob:  0.17% Token: | ..........|
Top 6th token. Logit: 20.95 Prob:  0.10% Token: | 8|
Top 7th token. Logit: 20.63 Prob:  0.07% Token: | <!--|
Top 8th token. Logit: 20.35 Prob:  0.05% Token: | 7|
Top 9th token. Logit: 19.23 Prob:  0.02% Token: | 9|
[b]Ranks of the answer tokens:[/b] [(' 5', 1)]


As expected, L9's attention alone doesn't move the prediction to 5 (" 4" stays on top at 87.32%). As the vector addition notebook showed, the MLP is what's important.

What's strange is that we can't pass digit tokens directly through MLP9; the residual stream has to be in the form it takes after L7, so the earlier layers are needed.

### skip L9 only

Are there backups, say in L10 and L11?

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:9] + self.blocks[10:12]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 1        Logit: 14.61 Prob:  8.00% Token: | 5|[/b]
Top 0th token. Logit: 16.54 Prob: 55.15% Token: | 4|
Top 1th token. Logit: 14.61 Prob:  8.00% Token: | 5|
Top 2th token. Logit: 14.43 Prob:  6.73% Token: | 1|
Top 3th token. Logit: 14.33 Prob:  6.05% Token: | 3|
Top 4th token. Logit: 13.74 Prob:  3.36% Token: | 6|
Top 5th token. Logit: 13.72 Prob:  3.29% Token: |
|
Top 6th token. Logit: 13.37 Prob:  2.32% Token: | 2|
Top 7th token. Logit: 13.08 Prob:  1.74% Token: | >|
Top 8th token. Logit: 12.19 Prob:  0.71% Token: | [|
Top 9th token. Logit: 12.00 Prob:  0.59% Token: | 8|
[b]Ranks of the answer tokens:[/b] [(' 5', 1)]


No: without L9, " 4" stays on top and " 5" drops to 8.00%. L9 seems required to strongly boost " 5". It isn't "converting" 4 into 5 so much as making 5 far more likely than everything else.
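One way to make "far more likely" concrete: under a softmax, the shared normalizer cancels, so the probability ratio of two tokens is just the exponential of their logit difference. A minimal check using the logits printed above for the model without L9 (numbers copied from that output; nothing here re-runs the model):

```python
import math

def prob_ratio(logit_a: float, logit_b: float) -> float:
    """Ratio p(a) / p(b) under a softmax: exp of the logit difference."""
    return math.exp(logit_a - logit_b)

# Skip-L9 model on "1 2 3 4": ' 4' has logit 16.54 (55.15%), ' 5' has logit 14.61 (8.00%)
ratio_skip_l9 = prob_ratio(16.54, 14.61)
print(f"p(' 4') / p(' 5') without L9: {ratio_skip_l9:.2f}")  # prints 6.89, matching 55.15 / 8.00
```

A logit gap of about 2 is therefore enough for " 4" to be roughly 7x more likely than " 5"; each extra unit of logit that L9 adds to " 5" multiplies its odds by about 2.7.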

### use attn only for L0 to L7, then add MLP9

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        # Attention sublayers only for L0-L7 (their MLPs are skipped)
        for block in self.blocks[0:8]:
            normalized_resid_pre = block.ln1(residual)
            attn_out = block.attn(normalized_resid_pre, normalized_resid_pre, normalized_resid_pre)
            residual = residual + attn_out

        # Then only L9's MLP (L8 and L9's attention are skipped)
        normalized_resid_mid = self.blocks[9].ln2(residual)
        mlp_out = self.blocks[9].mlp(normalized_resid_mid)
        residual = residual + mlp_out

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 3602     Logit:  6.59 Prob:  0.00% Token: | 5|[/b]
Top 0th token. Logit: 19.24 Prob: 29.06% Token: | ›|
Top 1th token. Logit: 19.06 Prob: 24.33% Token: | Posted|
Top 2th token. Logit: 17.82 Prob:  6.99% Token: | 2018|
Top 3th token. Logit: 17.76 Prob:  6.60% Token: |!|
Top 4th token. Logit: 17.46 Prob:  4.90% Token: | !|
Top 5th token. Logit: 16.97 Prob:  2.99% Token: | 2017|
Top 6th token. Logit: 16.40 Prob:  1.70% Token: | 2015|
Top 7th token. Logit: 16.25 Prob:  1.46% Token: | dash|
Top 8th token. Logit: 16.22 Prob:  1.43% Token: | Xiaomi|
Top 9th token. Logit: 15.71 Prob:  0.85% Token: | Profile|
[b]Ranks of the answer tokens:[/b] [(' 5', 3602)]


So L0 to L7 aren't just moving information between positions with attention; their MLPs are required too (" 5" drops to rank 3602 without them).

### skip L4 only (for L0 to 9)

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:4] + self.blocks[5:10]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 34.78 Prob: 99.36% Token: | 5|[/b]
Top 0th token. Logit: 34.78 Prob: 99.36% Token: | 5|
Top 1th token. Logit: 29.57 Prob:  0.54% Token: | 4|
Top 2th token. Logit: 27.22 Prob:  0.05% Token: | 3|
Top 3th token. Logit: 26.65 Prob:  0.03% Token: | 6|
Top 4th token. Logit: 25.88 Prob:  0.01% Token: | 7|
Top 5th token. Logit: 24.49 Prob:  0.00% Token: | 9|
Top 6th token. Logit: 23.41 Prob:  0.00% Token: | 1|
Top 7th token. Logit: 22.98 Prob:  0.00% Token: | 0|
Top 8th token. Logit: 22.79 Prob:  0.00% Token: | 8|
Top 9th token. Logit: 22.22 Prob:  0.00% Token: | 10|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


Wait, L4 isn't required either, at least for digits.

In [None]:
test_prompt("one two three four", " five", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'one', ' two', ' three', ' four']
Tokenized answer: [' five']
Performance on answer token:
[b]Rank: 1        Logit: 24.22 Prob: 27.46% Token: | five|[/b]
Top 0th token. Logit: 24.82 Prob: 49.85% Token: | four|
Top 1th token. Logit: 24.22 Prob: 27.46% Token: | five|
Top 2th token. Logit: 22.50 Prob:  4.90% Token: | six|
Top 3th token. Logit: 22.33 Prob:  4.14% Token: | seven|
Top 4th token. Logit: 22.23 Prob:  3.72% Token: | Four|
Top 5th token. Logit: 21.99 Prob:  2.93% Token: |four|
Top 6th token. Logit: 21.92 Prob:  2.75% Token: | three|
Top 7th token. Logit: 21.40 Prob:  1.63% Token: | eight|
Top 8th token. Logit: 20.01 Prob:  0.41% Token: |teen|
Top 9th token. Logit: 19.90 Prob:  0.36% Token: | nine|
[b]Ranks of the answer tokens:[/b] [(' five', 1)]


But L4 is required for number words, months, etc. The alphabet is the exception, since it is aided by MLP8.

In all cases, however, it is MLP9 that pushes up the answer.
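One way to test the claim that MLP9 is what pushes up the answer is direct logit attribution, a swapped-in technique that this notebook hasn't used so far: project MLP9's output at the final position onto the direction $W_U[:, \text{answer}] - W_U[:, \text{distractor}]$, after applying the final layer norm's centering and (frozen) scale. A minimal sketch on plain tensors; the TransformerLens calls in the usage comment are assumptions about how one might wire it up, and they apply to the full `model`, not to the `ExtractedModel` variants:

```python
import torch

def direct_logit_diff(component_out: torch.Tensor,
                      ln_final_scale: torch.Tensor,
                      W_U: torch.Tensor,
                      answer_id: int,
                      distractor_id: int) -> float:
    """Direct contribution of one component to logit[answer] - logit[distractor].

    component_out:  [d_model] the component's output (e.g. MLP9) at the final position
    ln_final_scale: [1] the scale the final LayerNormPre divides by at that position
    W_U:            [d_model, d_vocab] unembedding matrix
    """
    # With the scale held fixed, LayerNormPre (subtract the mean, divide by the scale)
    # is linear in the residual stream, so it can be applied to each component separately.
    centered = component_out - component_out.mean()
    normalized = centered / ln_final_scale
    # The unembedding bias is shared by both tokens, so it cancels in the difference.
    direction = W_U[:, answer_id] - W_U[:, distractor_id]
    return float(normalized @ direction)

# Hypothetical usage with the full TransformerLens model:
# tokens = model.to_tokens("1 2 3 4")
# _, cache = model.run_with_cache(tokens)
# contrib = direct_logit_diff(cache["mlp_out", 9][0, -1],
#                             cache["ln_final.hook_scale"][0, -1],
#                             model.W_U,
#                             model.to_single_token(" 5"),
#                             model.to_single_token(" 4"))
```

A large positive value would mean MLP9 directly writes "5 rather than 4" into the logits; a small one would suggest its effect runs through later layers instead.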

In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 1        Logit: 34.11 Prob: 17.75% Token: | May|[/b]
Top 0th token. Logit: 35.59 Prob: 77.57% Token: | April|
Top 1th token. Logit: 34.11 Prob: 17.75% Token: | May|
Top 2th token. Logit: 32.01 Prob:  2.17% Token: | June|
Top 3th token. Logit: 31.39 Prob:  1.16% Token: | March|
Top 4th token. Logit: 30.53 Prob:  0.49% Token: | September|
Top 5th token. Logit: 29.95 Prob:  0.27% Token: | July|
Top 6th token. Logit: 29.15 Prob:  0.12% Token: | October|
Top 7th token. Logit: 29.08 Prob:  0.12% Token: | February|
Top 8th token. Logit: 28.96 Prob:  0.10% Token: | August|
Top 9th token. Logit: 28.77 Prob:  0.08% Token: | November|
[b]Ranks of the answer tokens:[/b] [(' May', 1)]


In [None]:
test_prompt("Sunday Monday Tuesday Wednesday", " Thursday", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Sunday', ' Monday', ' Tuesday', ' Wednesday']
Tokenized answer: [' Thursday']
Performance on answer token:
[b]Rank: 18       Logit: 27.14 Prob:  0.77% Token: | Thursday|[/b]
Top 0th token. Logit: 30.78 Prob: 29.15% Token: | morning|
Top 1th token. Logit: 29.64 Prob:  9.36% Token: | July|
Top 2th token. Logit: 29.49 Prob:  8.05% Token: | February|
Top 3th token. Logit: 29.27 Prob:  6.44% Token: | afternoon|
Top 4th token. Logit: 29.25 Prob:  6.32% Token: | September|
Top 5th token. Logit: 29.22 Prob:  6.18% Token: | June|
Top 6th token. Logit: 28.79 Prob:  4.00% Token: | January|
Top 7th token. Logit: 28.79 Prob:  3.99% Token: | evening|
Top 8th token. Logit: 28.73 Prob:  3.76% Token: | November|
Top 9th token. Logit: 28.68 Prob:  3.60% Token: | October|
[b]Ranks of the answer tokens:[/b] [(' Thursday', 18)]


In [None]:
test_prompt("A B C D", " E", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' B', ' C', ' D']
Tokenized answer: [' E']
Performance on answer token:
[b]Rank: 0        Logit: 26.32 Prob: 97.62% Token: | E|[/b]
Top 0th token. Logit: 26.32 Prob: 97.62% Token: | E|
Top 1th token. Logit: 21.87 Prob:  1.14% Token: | D|
Top 2th token. Logit: 21.18 Prob:  0.57% Token: | F|
Top 3th token. Logit: 19.71 Prob:  0.13% Token: | G|
Top 4th token. Logit: 19.45 Prob:  0.10% Token: | e|
Top 5th token. Logit: 19.21 Prob:  0.08% Token: | +|
Top 6th token. Logit: 18.56 Prob:  0.04% Token: |
|
Top 7th token. Logit: 18.16 Prob:  0.03% Token: | H|
Top 8th token. Logit: 17.73 Prob:  0.02% Token: |+|
Top 9th token. Logit: 17.58 Prob:  0.02% Token: |ensity|
[b]Ranks of the answer tokens:[/b] [(' E', 0)]


### skip L4, L8, L10 and L11

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:4] + self.blocks[5:8] + self.blocks[9:10]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 36.43 Prob: 99.91% Token: | 5|[/b]
Top 0th token. Logit: 36.43 Prob: 99.91% Token: | 5|
Top 1th token. Logit: 28.80 Prob:  0.05% Token: | 4|
Top 2th token. Logit: 28.28 Prob:  0.03% Token: | 6|
Top 3th token. Logit: 26.55 Prob:  0.01% Token: | 7|
Top 4th token. Logit: 26.43 Prob:  0.00% Token: | 3|
Top 5th token. Logit: 24.44 Prob:  0.00% Token: | 1|
Top 6th token. Logit: 23.65 Prob:  0.00% Token: | 9|
Top 7th token. Logit: 23.44 Prob:  0.00% Token: | 0|
Top 8th token. Logit: 23.28 Prob:  0.00% Token: | 10|
Top 9th token. Logit: 22.61 Prob:  0.00% Token: | 8|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


In [None]:
my_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
my_list = [str(i) for i in my_list]
indices_to_remove = [4, 8, 10]
indices_to_remove.sort(reverse=True)  # Start removing from the end to avoid index shifting

for index in indices_to_remove:
    my_list.pop(index)
my_list

['0', '1', '2', '3', '5', '6', '7', '9', '11']
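Popping in reverse order works, but an alternative that avoids index shifting entirely is to build the list of kept layers directly. A minimal sketch using the same example list; `keep_except` is a hypothetical helper, not something the cells below use:

```python
def keep_except(items, indices_to_remove):
    """Return the items whose positions are not in indices_to_remove, in original order."""
    skip = set(indices_to_remove)
    return [item for i, item in enumerate(items) if i not in skip]

layers = [str(i) for i in range(12)]
kept = keep_except(layers, [4, 8, 10])
print(kept)  # ['0', '1', '2', '3', '5', '6', '7', '9', '11']

# For a TransformerLens model one could (hypothetically) write:
# kept_blocks = nn.ModuleList(keep_except(model.blocks, [4, 8, 10, 11]))
```

Because this never mutates `model.blocks`, it needs no `deepcopy`: the kept blocks are the original (shared) modules, which is fine as long as nothing modifies them.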

In [None]:
import copy
model_copy = copy.deepcopy(model)
indices_to_remove = [4, 8, 10, 11]
indices_to_remove.sort(reverse=True)  # Start removing from the end to avoid index shifting
for index in indices_to_remove:
    model_copy.blocks.pop(index)
model_copy.blocks

In [None]:
model.blocks  # check that the deepcopy didn't modify the original by reference

ModuleList(
  (0-11): 12 x TransformerBlock(
    (ln1): LayerNormPre(
      (hook_scale): HookPoint()
      (hook_normalized): HookPoint()
    )
    (ln2): LayerNormPre(
      (hook_scale): HookPoint()
      (hook_normalized): HookPoint()
    )
    (attn): Attention(
      (hook_k): HookPoint()
      (hook_q): HookPoint()
      (hook_v): HookPoint()
      (hook_z): HookPoint()
      (hook_attn_scores): HookPoint()
      (hook_pattern): HookPoint()
      (hook_result): HookPoint()
    )
    (mlp): MLP(
      (hook_pre): HookPoint()
      (hook_post): HookPoint()
    )
    (hook_q_input): HookPoint()
    (hook_k_input): HookPoint()
    (hook_v_input): HookPoint()
    (hook_attn_out): HookPoint()
    (hook_mlp_in): HookPoint()
    (hook_mlp_out): HookPoint()
    (hook_resid_pre): HookPoint()
    (hook_resid_mid): HookPoint()
    (hook_resid_post): HookPoint()
  )
)

In [None]:
import copy
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model_before_copy):
        super(ExtractedModel, self).__init__()

        # Remove blocks here, once, rather than on every forward() call.
        # Deep-copy first so that popping blocks does not modify the original model.
        original_model = copy.deepcopy(original_model_before_copy)
        indices_to_remove = [4, 8, 10, 11]
        indices_to_remove.sort(reverse=True)  # Start removing from the end to avoid index shifting
        for index in indices_to_remove:
            original_model.blocks.pop(index)

        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 36.43 Prob: 99.91% Token: | 5|[/b]
Top 0th token. Logit: 36.43 Prob: 99.91% Token: | 5|
Top 1th token. Logit: 28.80 Prob:  0.05% Token: | 4|
Top 2th token. Logit: 28.28 Prob:  0.03% Token: | 6|
Top 3th token. Logit: 26.55 Prob:  0.01% Token: | 7|
Top 4th token. Logit: 26.43 Prob:  0.00% Token: | 3|
Top 5th token. Logit: 24.44 Prob:  0.00% Token: | 1|
Top 6th token. Logit: 23.65 Prob:  0.00% Token: | 9|
Top 7th token. Logit: 23.44 Prob:  0.00% Token: | 0|
Top 8th token. Logit: 23.28 Prob:  0.00% Token: | 10|
Top 9th token. Logit: 22.61 Prob:  0.00% Token: | 8|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


This is one of the best results yet: " 5" at 99.91% using only 8 of the 12 layers.

In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 3        Logit: 33.98 Prob:  0.70% Token: | May|[/b]
Top 0th token. Logit: 38.90 Prob: 95.94% Token: | April|
Top 1th token. Logit: 34.94 Prob:  1.82% Token: | March|
Top 2th token. Logit: 34.36 Prob:  1.02% Token: | February|
Top 3th token. Logit: 33.98 Prob:  0.70% Token: | May|
Top 4th token. Logit: 33.12 Prob:  0.30% Token: | June|
Top 5th token. Logit: 31.34 Prob:  0.05% Token: | July|
Top 6th token. Logit: 31.29 Prob:  0.05% Token: | September|
Top 7th token. Logit: 30.83 Prob:  0.03% Token: | January|
Top 8th token. Logit: 30.72 Prob:  0.03% Token: | October|
Top 9th token. Logit: 30.54 Prob:  0.02% Token: | August|
[b]Ranks of the answer tokens:[/b] [(' May', 3)]


But the months task suffers: " April" dominates at 95.94% and " May" falls to rank 3!

### skip layers 3, 4, 8, 10, 11

In [None]:
import copy
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model_before_copy):
        super(ExtractedModel, self).__init__()

        # Deep-copy so that popping blocks does not modify the original model
        original_model = copy.deepcopy(original_model_before_copy)
        # Remove the skipped blocks once, at construction time. Popping inside forward()
        # would remove blocks again on every call (and fail once the indices are out of range).
        indices_to_remove = [3, 4, 8, 10, 11]
        indices_to_remove.sort(reverse=True)  # Start removing from the end to avoid index shifting
        for index in indices_to_remove:
            original_model.blocks.pop(index)

        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 31.95 Prob: 92.24% Token: | 5|[/b]
Top 0th token. Logit: 31.95 Prob: 92.24% Token: | 5|
Top 1th token. Logit: 29.11 Prob:  5.40% Token: | 4|
Top 2th token. Logit: 27.91 Prob:  1.62% Token: | 3|
Top 3th token. Logit: 26.60 Prob:  0.44% Token: | 6|
Top 4th token. Logit: 25.03 Prob:  0.09% Token: | 1|
Top 5th token. Logit: 24.83 Prob:  0.08% Token: | 7|
Top 6th token. Logit: 24.79 Prob:  0.07% Token: | 0|
Top 7th token. Logit: 24.17 Prob:  0.04% Token: | 2|
Top 8th token. Logit: 22.23 Prob:  0.01% Token: | 9|
Top 9th token. Logit: 21.58 Prob:  0.00% Token: | 8|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


L3 helps (" 5" drops from 99.91% to 92.24% without it), but it isn't required for digits.

In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 16       Logit: 23.40 Prob:  0.42% Token: | May|[/b]
Top 0th token. Logit: 27.93 Prob: 38.70% Token: | 2018|
Top 1th token. Logit: 27.28 Prob: 20.14% Token: | 2017|
Top 2th token. Logit: 26.56 Prob:  9.86% Token: | 2015|
Top 3th token. Logit: 26.15 Prob:  6.54% Token: | 2016|
Top 4th token. Logit: 26.00 Prob:  5.61% Token: | 2014|
Top 5th token. Logit: 25.67 Prob:  4.04% Token: | April|
Top 6th token. Logit: 25.25 Prob:  2.64% Token: | 29|
Top 7th token. Logit: 25.07 Prob:  2.22% Token: | 2013|
Top 8th token. Logit: 24.59 Prob:  1.37% Token: | 28|
Top 9th token. Logit: 23.90 Prob:  0.69% Token: | 25|
[b]Ranks of the answer tokens:[/b] [(' May', 16)]


Wait, L3 is required for months!

### skip 2, 3, 4, 8, 10, 11

Keeps only: 0, 1, 5, 6, 7, 9

In [None]:
import copy
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model_before_copy):
        super(ExtractedModel, self).__init__()

        # Deep-copy so that popping blocks does not modify the original model
        original_model = copy.deepcopy(original_model_before_copy)
        # Remove the skipped blocks once, at construction time. Popping inside forward()
        # would remove blocks again on every call (and fail once the indices are out of range).
        indices_to_remove = [2, 3, 4, 8, 10, 11]
        indices_to_remove.sort(reverse=True)  # Start removing from the end to avoid index shifting
        for index in indices_to_remove:
            original_model.blocks.pop(index)

        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 1        Logit: 19.54 Prob: 22.06% Token: | 5|[/b]
Top 0th token. Logit: 19.56 Prob: 22.46% Token: | 3|
Top 1th token. Logit: 19.54 Prob: 22.06% Token: | 5|
Top 2th token. Logit: 19.03 Prob: 13.21% Token: | 4|
Top 3th token. Logit: 19.01 Prob: 12.96% Token: | 1|
Top 4th token. Logit: 18.52 Prob:  7.95% Token: | 2|
Top 5th token. Logit: 18.13 Prob:  5.39% Token: | 0|
Top 6th token. Logit: 17.16 Prob:  2.04% Token: | 6|
Top 7th token. Logit: 17.07 Prob:  1.87% Token: | 7|
Top 8th token. Logit: 16.61 Prob:  1.18% Token: | 9|
Top 9th token. Logit: 16.41 Prob:  0.97% Token: |
|
[b]Ranks of the answer tokens:[/b] [(' 5', 1)]


L2 is required: without it, " 5" falls to rank 1, just behind " 3".

### skip 3, 4, 6, 8, 10, 11

Keeps only: 0, 1, 2, 5, 7, 9

In [None]:
import copy
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model_before_copy):
        super(ExtractedModel, self).__init__()

        # Deep-copy so that popping blocks does not modify the original model
        original_model = copy.deepcopy(original_model_before_copy)
        # Remove the skipped blocks once, at construction time. Popping inside forward()
        # would remove blocks again on every call (and fail once the indices are out of range).
        indices_to_remove = [3, 4, 6, 8, 10, 11]
        indices_to_remove.sort(reverse=True)  # Start removing from the end to avoid index shifting
        for index in indices_to_remove:
            original_model.blocks.pop(index)

        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 1        Logit: 19.54 Prob: 22.06% Token: | 5|[/b]
Top 0th token. Logit: 19.56 Prob: 22.46% Token: | 3|
Top 1th token. Logit: 19.54 Prob: 22.06% Token: | 5|
Top 2th token. Logit: 19.03 Prob: 13.21% Token: | 4|
Top 3th token. Logit: 19.01 Prob: 12.96% Token: | 1|
Top 4th token. Logit: 18.52 Prob:  7.95% Token: | 2|
Top 5th token. Logit: 18.13 Prob:  5.39% Token: | 0|
Top 6th token. Logit: 17.16 Prob:  2.04% Token: | 6|
Top 7th token. Logit: 17.07 Prob:  1.87% Token: | 7|
Top 8th token. Logit: 16.61 Prob:  1.18% Token: | 9|
Top 9th token. Logit: 16.41 Prob:  0.97% Token: |
|
[b]Ranks of the answer tokens:[/b] [(' 5', 1)]


Layer 6 is required.

In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 22       Logit: 20.30 Prob:  0.31% Token: | May|[/b]
Top 0th token. Logit: 24.86 Prob: 29.39% Token: | 2017|
Top 1th token. Logit: 24.40 Prob: 18.61% Token: | 2018|
Top 2th token. Logit: 23.50 Prob:  7.56% Token: | 29|
Top 3th token. Logit: 23.35 Prob:  6.45% Token: | 2014|
Top 4th token. Logit: 23.30 Prob:  6.14% Token: | 2015|
Top 5th token. Logit: 23.24 Prob:  5.82% Token: | 2016|
Top 6th token. Logit: 22.98 Prob:  4.49% Token: | 28|
Top 7th token. Logit: 22.90 Prob:  4.15% Token: | 2013|
Top 8th token. Logit: 22.17 Prob:  1.99% Token: | 27|
Top 9th token. Logit: 22.02 Prob:  1.71% Token: | 23|
[b]Ranks of the answer tokens:[/b] [(' May', 22)]


### Automate this

Rather than reading the printed rankings by eye, this can be automated by checking the answer token's logit and rank directly.
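A minimal sketch of that automation, assuming the final-position logits are available as a 1-D tensor. `answer_rank` and `required_layers` are hypothetical helpers, and `rank_without` stands for whatever builds a model with one layer removed and runs the prompt; the logits and ranks in the toy check are made up, not results from this notebook.

```python
import torch


def answer_rank(final_logits: torch.Tensor, answer_id: int) -> int:
    """Rank of the answer token (0 = top prediction) from logits of shape [d_vocab]."""
    answer_logit = final_logits[answer_id]
    # Count tokens whose logit is strictly higher than the answer's
    return int((final_logits > answer_logit).sum().item())


def required_layers(rank_without, candidate_layers, max_rank=0):
    """Layers whose removal pushes the answer's rank above max_rank.

    rank_without(layer) is a hypothetical callable: build a model with that one
    layer removed, run the prompt, and return answer_rank(...) at the last position.
    """
    return [layer for layer in candidate_layers if rank_without(layer) > max_rank]


# Toy check with made-up logits over a 4-token vocabulary
logits = torch.tensor([1.0, 3.0, 2.0, 0.5])
print(answer_rank(logits, 2))  # 1

# Made-up ranks standing in for real runs: removing layer 2 or 6 breaks the answer
toy_ranks = {2: 1, 3: 0, 6: 1}
print(required_layers(toy_ranks.get, [2, 3, 6]))  # [2, 6]
```

With the real model, the inputs could come from `tokens = model.to_tokens("1 2 3 4")`, `final_logits = extracted_model(tokens)[0, -1]` and `answer_id = model.to_single_token(" 5")`. Ties are counted in the answer's favour here, which may differ from how `test_prompt` reports ranks.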

### Skip MLP9's layer norm (L0 to L7, then MLP9 only)

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed

        self.blocks = original_model.blocks

        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

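        # L0 to L7: full blocks
        for block in self.blocks[0:8]:
            residual = block(residual)

        # Skip L8 entirely, and skip L9's attention:
        # normalized_resid_pre = self.blocks[9].ln1(residual)
        # attn_out = self.blocks[9].attn(normalized_resid_pre, normalized_resid_pre, normalized_resid_pre)
        # residual = residual + attn_out

        # Skip MLP9's layer norm (ln2): feed the raw residual stream to MLP9
        # normalized_resid_mid = self.blocks[9].ln2(residual)
        mlp_out = self.blocks[9].mlp(residual)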
        residual = residual + mlp_out

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 28.18 Prob: 97.27% Token: | 5|[/b]
Top 0th token. Logit: 28.18 Prob: 97.27% Token: | 5|
Top 1th token. Logit: 24.15 Prob:  1.72% Token: | Five|
Top 2th token. Logit: 21.29 Prob:  0.10% Token: | five|
Top 3th token. Logit: 20.99 Prob:  0.07% Token: | 50|
Top 4th token. Logit: 20.90 Prob:  0.07% Token: | +|
Top 5th token. Logit: 20.82 Prob:  0.06% Token: | Player|
Top 6th token. Logit: 20.67 Prob:  0.05% Token: | >|
Top 7th token. Logit: 20.65 Prob:  0.05% Token: | Copyright|
Top 8th token. Logit: 20.43 Prob:  0.04% Token: |.|
Top 9th token. Logit: 20.14 Prob:  0.03% Token: | 7|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


MLP9's layer norm (ln2) is not required either, though it probably helps on more complex inputs. Oddly, " Five" and " five" are ranked second and third, ahead of other digits such as " 4", although with small probabilities (1.72% and 0.10%). This may be an anomaly of this input, so we check another one below.

This is the first time we have seen this. " five" appeared when we deleted L8, but it was not ranked this high.

Another recurring pattern is that " Next", " >" and " <" are common predictions; in the full model, " Next" is ranked second. This hints that these tokens are tied to the sequence-continuation behaviour.
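On the point that MLP9's ln2 is not required but may help: in a pre-LN block the MLP's contribution is `mlp(ln2(resid))`, and skipping ln2 feeds the raw residual stream to the MLP instead. One plausible reason ln2 helps is scale, since the residual stream's norm tends to grow with depth and varies with the input, and ln2 normalizes that away. The toy sketch below assumes made-up sizes and random weights, with `nn.LayerNorm` and a small `nn.Sequential` MLP standing in for the real block's `ln2` and `mlp`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_mlp = 16, 64  # toy sizes, not GPT-2's

# Stand-ins for a block's ln2 and MLP, with random weights
ln2 = nn.LayerNorm(d_model)
mlp = nn.Sequential(nn.Linear(d_model, d_mlp), nn.GELU(), nn.Linear(d_mlp, d_model))

resid = torch.randn(1, 4, d_model)  # [batch, position, d_model]

with torch.no_grad():
    # With ln2 (resid + mlp(ln2(resid))): scaling the residual 10x barely changes the MLP output
    with_ln2 = torch.allclose(mlp(ln2(resid)), mlp(ln2(10 * resid)), atol=1e-3)
    # Without ln2 (resid + mlp(resid)): the MLP sees a 10x larger input and its output changes
    without_ln2 = torch.allclose(mlp(resid), mlp(10 * resid), atol=1e-3)

print(with_ln2, without_ln2)
```

This only shows that ln2 makes the MLP's input scale-invariant; whether that matters for a given prompt still has to be checked on the real model, as done above.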

In [None]:
test_prompt("5 6 7 8", " 9", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '5', ' 6', ' 7', ' 8']
Tokenized answer: [' 9']
Performance on answer token:
[b]Rank: 0        Logit: 24.66 Prob: 56.31% Token: | 9|[/b]
Top 0th token. Logit: 24.66 Prob: 56.31% Token: | 9|
Top 1th token. Logit: 23.71 Prob: 21.69% Token: | Copyright|
Top 2th token. Logit: 22.11 Prob:  4.40% Token: | >|
Top 3th token. Logit: 21.72 Prob:  2.97% Token: | 5|
Top 4th token. Logit: 21.63 Prob:  2.73% Token: | 1|
Top 5th token. Logit: 21.21 Prob:  1.78% Token: | <|
Top 6th token. Logit: 20.74 Prob:  1.12% Token: | 3|
Top 7th token. Logit: 20.28 Prob:  0.70% Token: | 13|
Top 8th token. Logit: 20.07 Prob:  0.57% Token: |.|
Top 9th token. Logit: 19.82 Prob:  0.44% Token: | 7|
[b]Ranks of the answer tokens:[/b] [(' 9', 0)]


Indeed, " nine" does not appear in the rankings, so the " five" result looks like a fluke rather than a consistent pattern. The answer's logit is also lower here (24.66, versus 28.18 for " 5" above), and " Copyright" takes 21.69%, so the layers removed from the full model are still doing something for other inputs.
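Since softmax depends only on logit differences, the gap between two tokens' logits fixes their probability ratio: p_a / p_b = exp(l_a - l_b). A quick check against the printed numbers (copied from the two outputs above, rounded to two decimals) shows what these gaps mean: a gap of about 4 logits (" 5" vs " Five") is roughly a 56x probability ratio, while the roughly 1-logit gap between " 9" and " Copyright" is only about 2.6x.

```python
import math


def prob_ratio(logit_a: float, logit_b: float) -> float:
    """Softmax probability ratio p_a / p_b, which depends only on the logit gap."""
    return math.exp(logit_a - logit_b)


# (logit_a, logit_b, prob_a %, prob_b %) copied from the outputs above
pairs = {
    "' 5' vs ' Five' (1 2 3 4)": (28.18, 24.15, 97.27, 1.72),
    "' 9' vs ' Copyright' (5 6 7 8)": (24.66, 23.71, 56.31, 21.69),
}

for name, (logit_a, logit_b, prob_a, prob_b) in pairs.items():
    print(f"{name}: exp(gap) = {prob_ratio(logit_a, logit_b):.1f}, printed ratio = {prob_a / prob_b:.1f}")
```

The two columns agree up to rounding of the printed logits, so the logit gap is a direct way to compare how confident two runs are.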

### Use only the MLPs of L0 to L7 and L9 (L8 skipped)

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        # L0 to L7: skip attention and apply only each block's MLP (with its ln2)
        for block in self.blocks[0:8]:
            # normalized_resid_pre = block.ln1(residual)
            # attn_out = block.attn(normalized_resid_pre, normalized_resid_pre, normalized_resid_pre)
            # residual = residual + attn_out

            normalized_resid_mid = block.ln2(residual)
            mlp_out = block.mlp(normalized_resid_mid)
            residual = residual + mlp_out

        # Skip L8 entirely, then apply only MLP9
        normalized_resid_mid = self.blocks[9].ln2(residual)
        mlp_out = self.blocks[9].mlp(normalized_resid_mid)
        residual = residual + mlp_out

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 89       Logit:  5.75 Prob:  0.10% Token: | 5|[/b]
Top 0th token. Logit: 10.26 Prob:  9.10% Token: |th|
Top 1th token. Logit:  9.61 Prob:  4.72% Token: |-|
Top 2th token. Logit:  9.39 Prob:  3.79% Token: |.|
Top 3th token. Logit:  8.90 Prob:  2.34% Token: |,|
Top 4th token. Logit:  7.92 Prob:  0.87% Token: | and|
Top 5th token. Logit:  7.74 Prob:  0.73% Token: | of|
Top 6th token. Logit:  7.65 Prob:  0.67% Token: |
|
Top 7th token. Logit:  7.64 Prob:  0.66% Token: |:|
Top 8th token. Logit:  7.64 Prob:  0.66% Token: | the|
Top 9th token. Logit:  7.60 Prob:  0.63% Token: |x|
[b]Ranks of the answer tokens:[/b] [(' 5', 89)]


Ok so you need attention. From which layers?
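One way to test attention layer by layer without writing a new `ExtractedModel` each time is a forward hook that zeroes a block's attention output. This is a sketch assuming TransformerLens hook naming (`blocks.{layer}.hook_attn_out`); `zero_ablate` and `skip_attention_hooks` are hypothetical helper names. Zeroing `attn_out` is equivalent to skipping that attention layer, because the block computes `resid_mid = resid_pre + attn_out`.

```python
import torch


def zero_ablate(activation: torch.Tensor, hook=None) -> torch.Tensor:
    """Forward hook that replaces an activation with zeros.

    Zeroing a block's attention output is the same as skipping that attention
    layer, because the block computes resid_mid = resid_pre + attn_out.
    """
    return torch.zeros_like(activation)


def skip_attention_hooks(layers):
    """(hook_name, hook_fn) pairs that skip attention in each of the given layers.

    Assumes TransformerLens hook naming, "blocks.{layer}.hook_attn_out".
    """
    return [(f"blocks.{layer}.hook_attn_out", zero_ablate) for layer in layers]


print([name for name, _ in skip_attention_hooks([0, 8])])
# ['blocks.0.hook_attn_out', 'blocks.8.hook_attn_out']
```

For example, `model.run_with_hooks(tokens, fwd_hooks=skip_attention_hooks([layer]))` for each `layer` in `range(model.cfg.n_layers)` gives the logits with one attention layer skipped at a time.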

### Skip MLP8 and use only attn8 (then MLP9): good for the alphabet?

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed

        self.blocks = original_model.blocks

        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:8]:
            residual = block(residual)

        normalized_resid_pre = self.blocks[8].ln1(residual)
        attn_out = self.blocks[8].attn(normalized_resid_pre, normalized_resid_pre, normalized_resid_pre)
        residual = residual + attn_out

        normalized_resid_mid = self.blocks[9].ln2(residual)
        mlp_out = self.blocks[9].mlp(normalized_resid_mid)
        residual = residual + mlp_out

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 30.49 Prob: 99.69% Token: | 5|[/b]
Top 0th token. Logit: 30.49 Prob: 99.69% Token: | 5|
Top 1th token. Logit: 24.27 Prob:  0.20% Token: | 4|
Top 2th token. Logit: 22.57 Prob:  0.04% Token: | <!--|
Top 3th token. Logit: 21.64 Prob:  0.01% Token: | Next|
Top 4th token. Logit: 21.01 Prob:  0.01% Token: | 7|
Top 5th token. Logit: 20.97 Prob:  0.01% Token: | 3|
Top 6th token. Logit: 20.68 Prob:  0.01% Token: |★|
Top 7th token. Logit: 20.47 Prob:  0.00% Token: | 6|
Top 8th token. Logit: 20.34 Prob:  0.00% Token: |
|
Top 9th token. Logit: 20.22 Prob:  0.00% Token: | ★|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 0        Logit: 24.90 Prob: 56.25% Token: | May|[/b]
Top 0th token. Logit: 24.90 Prob: 56.25% Token: | May|
Top 1th token. Logit: 23.23 Prob: 10.53% Token: | 2015|
Top 2th token. Logit: 22.90 Prob:  7.60% Token: | 5|
Top 3th token. Logit: 22.09 Prob:  3.36% Token: | 2005|
Top 4th token. Logit: 21.98 Prob:  3.02% Token: | April|
Top 5th token. Logit: 21.40 Prob:  1.69% Token: | 2017|
Top 6th token. Logit: 21.23 Prob:  1.43% Token: | Top|
Top 7th token. Logit: 21.01 Prob:  1.14% Token: | 25|
Top 8th token. Logit: 21.00 Prob:  1.13% Token: | 2018|
Top 9th token. Logit: 20.83 Prob:  0.96% Token: |
|
[b]Ranks of the answer tokens:[/b] [(' May', 0)]


In [None]:
test_prompt("Sunday Monday Tuesday Wednesday", " Thursday", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Sunday', ' Monday', ' Tuesday', ' Wednesday']
Tokenized answer: [' Thursday']
Performance on answer token:
[b]Rank: 0        Logit: 28.79 Prob: 52.66% Token: | Thursday|[/b]
Top 0th token. Logit: 28.79 Prob: 52.66% Token: | Thursday|
Top 1th token. Logit: 28.60 Prob: 43.65% Token: | Wednesday|
Top 2th token. Logit: 24.50 Prob:  0.73% Token: | September|
Top 3th token. Logit: 23.76 Prob:  0.34% Token: | Thurs|
Top 4th token. Logit: 23.47 Prob:  0.26% Token: | Saturday|
Top 5th token. Logit: 23.46 Prob:  0.26% Token: | March|
Top 6th token. Logit: 23.35 Prob:  0.23% Token: | Friday|
Top 7th token. Logit: 23.14 Prob:  0.19% Token: | 29|
Top 8th token. Logit: 23.05 Prob:  0.17% Token: | January|
Top 9th token. Logit: 22.86 Prob:  0.14% Token: | December|
[b]Ranks of the answer tokens:[/b] [(' Thursday', 0)]


In [None]:
test_prompt("A B C D", " E", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' B', ' C', ' D']
Tokenized answer: [' E']
Performance on answer token:
[b]Rank: 1        Logit: 26.56 Prob:  8.56% Token: | E|[/b]
Top 0th token. Logit: 28.92 Prob: 90.96% Token: | D|
Top 1th token. Logit: 26.56 Prob:  8.56% Token: | E|
Top 2th token. Logit: 23.11 Prob:  0.27% Token: | C|
Top 3th token. Logit: 21.61 Prob:  0.06% Token: | d|
Top 4th token. Logit: 21.30 Prob:  0.04% Token: | G|
Top 5th token. Logit: 21.00 Prob:  0.03% Token: | F|
Top 6th token. Logit: 19.69 Prob:  0.01% Token: | Close|
Top 7th token. Logit: 19.35 Prob:  0.01% Token: | T|
Top 8th token. Logit: 19.16 Prob:  0.01% Token: | >>|
Top 9th token. Logit: 18.92 Prob:  0.00% Token: |
|
[b]Ranks of the answer tokens:[/b] [(' E', 1)]


No, attn8 alone is not good for the alphabet: " D" is ranked first (90.96%) and " E" only second.

In [None]:
test_prompt("one two three four", " five", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'one', ' two', ' three', ' four']
Tokenized answer: [' five']
Performance on answer token:
[b]Rank: 2        Logit: 21.01 Prob:  8.45% Token: | five|[/b]
Top 0th token. Logit: 23.02 Prob: 63.23% Token: |teen|
Top 1th token. Logit: 21.21 Prob: 10.33% Token: |ths|
Top 2th token. Logit: 21.01 Prob:  8.45% Token: | five|
Top 3th token. Logit: 20.22 Prob:  3.85% Token: | 5|
Top 4th token. Logit: 19.89 Prob:  2.75% Token: | Five|
Top 5th token. Logit: 19.49 Prob:  1.86% Token: | seven|
Top 6th token. Logit: 19.06 Prob:  1.20% Token: | Four|
Top 7th token. Logit: 18.95 Prob:  1.07% Token: | four|
Top 8th token. Logit: 18.72 Prob:  0.85% Token: |teenth|
Top 9th token. Logit: 18.64 Prob:  0.79% Token: | fifth|
[b]Ranks of the answer tokens:[/b] [(' five', 2)]


attn8 alone is also not enough for number words: " five" ranks 2, behind "teen" and "ths".

### Skip attn8 and use only MLP8 (then MLP9): good for the alphabet?

In [None]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed

        self.blocks = original_model.blocks

        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        for block in self.blocks[0:8]:
            residual = block(residual)

        # normalized_resid_pre = self.blocks[8].ln1(residual)
        # attn_out = self.blocks[8].attn(normalized_resid_pre, normalized_resid_pre, normalized_resid_pre)
        # residual = residual + attn_out

        normalized_resid_mid = self.blocks[8].ln2(residual)
        mlp_out = self.blocks[8].mlp(normalized_resid_mid)
        residual = residual + mlp_out

        normalized_resid_mid = self.blocks[9].ln2(residual)
        mlp_out = self.blocks[9].mlp(normalized_resid_mid)
        residual = residual + mlp_out

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [None]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 0        Logit: 31.24 Prob: 99.39% Token: | 5|[/b]
Top 0th token. Logit: 31.24 Prob: 99.39% Token: | 5|
Top 1th token. Logit: 25.76 Prob:  0.41% Token: | 4|
Top 2th token. Logit: 23.75 Prob:  0.06% Token: | 7|
Top 3th token. Logit: 23.62 Prob:  0.05% Token: | 3|
Top 4th token. Logit: 23.47 Prob:  0.04% Token: | 9|
Top 5th token. Logit: 22.09 Prob:  0.01% Token: | 6|
Top 6th token. Logit: 22.07 Prob:  0.01% Token: | Next|
Top 7th token. Logit: 21.91 Prob:  0.01% Token: |★|
Top 8th token. Logit: 20.88 Prob:  0.00% Token: | 49|
Top 9th token. Logit: 20.69 Prob:  0.00% Token: | 15|
[b]Ranks of the answer tokens:[/b] [(' 5', 0)]


In [None]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 1        Logit: 23.83 Prob: 20.82% Token: | May|[/b]
Top 0th token. Logit: 24.37 Prob: 35.68% Token: | 2015|
Top 1th token. Logit: 23.83 Prob: 20.82% Token: | May|
Top 2th token. Logit: 22.39 Prob:  4.90% Token: | 2017|
Top 3th token. Logit: 22.13 Prob:  3.80% Token: | 2005|
Top 4th token. Logit: 21.92 Prob:  3.07% Token: | 5|
Top 5th token. Logit: 21.88 Prob:  2.96% Token: | Rate|
Top 6th token. Logit: 21.73 Prob:  2.56% Token: | Category|
Top 7th token. Logit: 21.48 Prob:  1.99% Token: | 1997|
Top 8th token. Logit: 21.33 Prob:  1.71% Token: | 57|
Top 9th token. Logit: 21.04 Prob:  1.28% Token: |
|
[b]Ranks of the answer tokens:[/b] [(' May', 1)]


In [None]:
test_prompt("Sunday Monday Tuesday Wednesday", " Thursday", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Sunday', ' Monday', ' Tuesday', ' Wednesday']
Tokenized answer: [' Thursday']
Performance on answer token:
[b]Rank: 0        Logit: 25.45 Prob: 22.95% Token: | Thursday|[/b]
Top 0th token. Logit: 25.45 Prob: 22.95% Token: | Thursday|
Top 1th token. Logit: 24.81 Prob: 12.10% Token: | Wednesday|
Top 2th token. Logit: 24.71 Prob: 11.01% Token: | Evening|
Top 3th token. Logit: 24.61 Prob:  9.93% Token: | September|
Top 4th token. Logit: 24.31 Prob:  7.36% Token: | night|
Top 5th token. Logit: 24.13 Prob:  6.13% Token: | July|
Top 6th token. Logit: 23.59 Prob:  3.58% Token: | nights|
Top 7th token. Logit: 23.59 Prob:  3.57% Token: | January|
Top 8th token. Logit: 23.55 Prob:  3.43% Token: | December|
Top 9th token. Logit: 23.43 Prob:  3.06% Token: | Year|
[b]Ranks of the answer tokens:[/b] [(' Thursday', 0)]


In [None]:
test_prompt("A B C D", " E", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' B', ' C', ' D']
Tokenized answer: [' E']
Performance on answer token:
[b]Rank: 0        Logit: 25.53 Prob: 67.87% Token: | E|[/b]
Top 0th token. Logit: 25.53 Prob: 67.87% Token: | E|
Top 1th token. Logit: 24.63 Prob: 27.67% Token: | D|
Top 2th token. Logit: 21.98 Prob:  1.95% Token: | G|
Top 3th token. Logit: 20.77 Prob:  0.58% Token: | F|
Top 4th token. Logit: 20.01 Prob:  0.27% Token: | C|
Top 5th token. Logit: 19.44 Prob:  0.15% Token: | Q|
Top 6th token. Logit: 19.24 Prob:  0.13% Token: | H|
Top 7th token. Logit: 19.21 Prob:  0.12% Token: |
|
Top 8th token. Logit: 19.11 Prob:  0.11% Token: | d|
Top 9th token. Logit: 19.08 Prob:  0.11% Token: |+|
[b]Ranks of the answer tokens:[/b] [(' E', 0)]


Yes, MLP8 is crucial for the alphabet; attn8 much less so. But does this mean the alphabet follows a different pattern than digits?

In [None]:
test_prompt("one two three four", " five", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'one', ' two', ' three', ' four']
Tokenized answer: [' five']
Performance on answer token:
[b]Rank: 10       Logit: 18.42 Prob:  0.02% Token: | five|[/b]
Top 0th token. Logit: 26.76 Prob: 98.94% Token: |teen|
Top 1th token. Logit: 20.51 Prob:  0.19% Token: |ths|
Top 2th token. Logit: 20.32 Prob:  0.16% Token: | 5|
Top 3th token. Logit: 19.95 Prob:  0.11% Token: |
|
Top 4th token. Logit: 19.60 Prob:  0.08% Token: | 3|
Top 5th token. Logit: 19.49 Prob:  0.07% Token: |teenth|
Top 6th token. Logit: 19.46 Prob:  0.07% Token: | ...|
Top 7th token. Logit: 19.01 Prob:  0.04% Token: | 13|
Top 8th token. Logit: 18.82 Prob:  0.04% Token: | 7|
Top 9th token. Logit: 18.57 Prob:  0.03% Token: | 4|
[b]Ranks of the answer tokens:[/b] [(' five', 10)]


Skipping attn8 is much worse for number words: " five" drops to rank 10, and "teen" takes 98.94%.
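Collecting the answer ranks from the two L8 experiments makes the split easier to see. The values below are copied from the outputs above (both variants run L0 to L7 in full and then MLP9 only); the task labels are shorthand for the five prompts, and `better_variant` is a hypothetical helper.

```python
# Rank of the correct answer (0 = top prediction), copied from the outputs above
ranks = {
    "attn8 only": {"digits": 0, "months": 0, "days": 0, "alphabet": 1, "number words": 2},
    "MLP8 only": {"digits": 0, "months": 1, "days": 0, "alphabet": 0, "number words": 10},
}


def better_variant(task: str) -> str:
    """Which L8 variant ranks the correct answer higher on a task (lower rank wins)."""
    attn_rank, mlp_rank = ranks["attn8 only"][task], ranks["MLP8 only"][task]
    if attn_rank == mlp_rank:
        return "tie"
    return "attn8 only" if attn_rank < mlp_rank else "MLP8 only"


for task in ranks["attn8 only"]:
    print(f"{task:>12}: {better_variant(task)}")
```

Neither variant wins everywhere: digits and days tie, MLP8 matters for the alphabet, and attn8 matters for number words (and slightly for months).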

### Skip all attention before attn9

In [17]:
import torch.nn as nn

class ExtractedModel(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed

        self.blocks = original_model.blocks

        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        # L0 to L8: skip attention and apply only each block's MLP (with its ln2)
        for block in self.blocks[0:9]:
            # normalized_resid_pre = block.ln1(residual)
            # attn_out = block.attn(normalized_resid_pre, normalized_resid_pre, normalized_resid_pre)
            # residual = residual + attn_out

            normalized_resid_mid = block.ln2(residual)
            mlp_out = block.mlp(normalized_resid_mid)
            residual = residual + mlp_out

        # L9: full block (attention and MLP)
        residual = self.blocks[9](residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model = ExtractedModel(model)

In [18]:
test_prompt("1 2 3 4", " 5", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', '1', ' 2', ' 3', ' 4']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 73       Logit:  5.94 Prob:  0.11% Token: | 5|[/b]
Top 0th token. Logit: 10.46 Prob: 10.39% Token: |th|
Top 1th token. Logit:  9.60 Prob:  4.40% Token: |-|
Top 2th token. Logit:  9.60 Prob:  4.40% Token: |.|
Top 3th token. Logit:  9.20 Prob:  2.96% Token: |,|
Top 4th token. Logit:  8.13 Prob:  1.01% Token: | and|
Top 5th token. Logit:  7.81 Prob:  0.74% Token: | of|
Top 6th token. Logit:  7.76 Prob:  0.70% Token: |:|
Top 7th token. Logit:  7.76 Prob:  0.70% Token: | the|
Top 8th token. Logit:  7.72 Prob:  0.67% Token: |x|
Top 9th token. Logit:  7.69 Prob:  0.65% Token: |
|
[b]Ranks of the answer tokens:[/b] [(' 5', 73)]


In [19]:
test_prompt("January February March April", " May", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'January', ' February', ' March', ' April']
Tokenized answer: [' May']
Performance on answer token:
[b]Rank: 32       Logit:  6.92 Prob:  0.34% Token: | May|[/b]
Top 0th token. Logit:  9.10 Prob:  3.07% Token: |,|
Top 1th token. Logit:  8.91 Prob:  2.52% Token: | the|
Top 2th token. Logit:  8.19 Prob:  1.23% Token: |.|
Top 3th token. Logit:  8.15 Prob:  1.19% Token: | and|
Top 4th token. Logit:  8.05 Prob:  1.07% Token: |-|
Top 5th token. Logit:  8.01 Prob:  1.03% Token: | of|
Top 6th token. Logit:  7.85 Prob:  0.88% Token: | a|
Top 7th token. Logit:  7.63 Prob:  0.70% Token: |
|
Top 8th token. Logit:  7.58 Prob:  0.67% Token: | in|
Top 9th token. Logit:  7.56 Prob:  0.65% Token: | "|
[b]Ranks of the answer tokens:[/b] [(' May', 32)]


In [20]:
test_prompt("Sunday Monday Tuesday Wednesday", " Thursday", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Sunday', ' Monday', ' Tuesday', ' Wednesday']
Tokenized answer: [' Thursday']
Performance on answer token:
[b]Rank: 13       Logit: 15.71 Prob:  0.00% Token: | Thursday|[/b]
Top 0th token. Logit: 27.37 Prob: 56.10% Token: | morning|
Top 1th token. Logit: 26.53 Prob: 24.26% Token: | night|
Top 2th token. Logit: 25.95 Prob: 13.59% Token: | afternoon|
Top 3th token. Logit: 25.12 Prob:  5.90% Token: | evening|
Top 4th token. Logit: 20.82 Prob:  0.08% Token: | nights|
Top 5th token. Logit: 19.79 Prob:  0.03% Token: | mornings|
Top 6th token. Logit: 19.00 Prob:  0.01% Token: | Night|
Top 7th token. Logit: 18.35 Prob:  0.01% Token: | Tuesday|
Top 8th token. Logit: 17.83 Prob:  0.00% Token: | evenings|
Top 9th token. Logit: 17.65 Prob:  0.00% Token: | Wednesday|
[b]Ranks of the answer tokens:[/b] [(' Thursday', 13)]


In [21]:
test_prompt("A B C D", " E", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' B', ' C', ' D']
Tokenized answer: [' E']
Performance on answer token:
[b]Rank: 192      Logit:  6.32 Prob:  0.09% Token: | E|[/b]
Top 0th token. Logit:  9.22 Prob:  1.61% Token: |-|
Top 1th token. Logit:  9.02 Prob:  1.32% Token: |.|
Top 2th token. Logit:  8.89 Prob:  1.16% Token: |orm|
Top 3th token. Logit:  8.87 Prob:  1.14% Token: |ere|
Top 4th token. Logit:  8.64 Prob:  0.91% Token: |AG|
Top 5th token. Logit:  8.45 Prob:  0.75% Token: |Y|
Top 6th token. Logit:  8.31 Prob:  0.65% Token: |ried|
Top 7th token. Logit:  8.24 Prob:  0.61% Token: |etermined|
Top 8th token. Logit:  8.13 Prob:  0.54% Token: |aim|
Top 9th token. Logit:  8.03 Prob:  0.49% Token: |ella|
[b]Ranks of the answer tokens:[/b] [(' E', 192)]


In [22]:
test_prompt("one two three four", " five", extracted_model, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'one', ' two', ' three', ' four']
Tokenized answer: [' five']
Performance on answer token:
[b]Rank: 147      Logit:  4.56 Prob:  0.06% Token: | five|[/b]
Top 0th token. Logit:  8.16 Prob:  2.09% Token: |-|
Top 1th token. Logit:  8.04 Prob:  1.85% Token: |,|
Top 2th token. Logit:  8.03 Prob:  1.83% Token: | of|
Top 3th token. Logit:  7.71 Prob:  1.33% Token: | and|
Top 4th token. Logit:  7.67 Prob:  1.27% Token: | the|
Top 5th token. Logit:  7.58 Prob:  1.17% Token: |.|
Top 6th token. Logit:  7.25 Prob:  0.84% Token: | in|
Top 7th token. Logit:  7.17 Prob:  0.78% Token: | to|
Top 8th token. Logit:  7.14 Prob:  0.76% Token: | a|
Top 9th token. Logit:  6.84 Prob:  0.56% Token: |
|
[b]Ranks of the answer tokens:[/b] [(' five', 147)]


Unfortunately, this doesn't work: all five prompts fail once attention before L9 is removed. Clearly we can't slice models this coarsely; we need to ablate individual heads instead. Still, for very specific tasks, some components can be skipped.
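Ablating a single head is one step finer than skipping a whole attention layer. A minimal sketch, assuming TransformerLens conventions: the per-head attention result is exposed at `blocks.{layer}.attn.hook_z` with shape `[batch, position, head_index, d_head]`, and `make_head_ablation_hook` is a hypothetical helper. Zeroing one head's slice removes that head's term from `attn_out`; the shared output bias is unaffected.

```python
import torch


def make_head_ablation_hook(head_index: int):
    """Build a hook for "blocks.{layer}.attn.hook_z" that zeroes one head.

    In TransformerLens, hook_z has shape [batch, position, head_index, d_head], and
    attn_out sums each head's z @ W_O plus a shared bias b_O. Zeroing one head's
    slice removes that head's term (b_O is unaffected).
    """
    def ablate_head(z: torch.Tensor, hook=None) -> torch.Tensor:
        z = z.clone()  # avoid modifying the caller's tensor in place
        z[:, :, head_index, :] = 0.0
        return z

    return ablate_head


# Toy activation: batch 2, position 3, 4 heads, d_head 5
z = torch.ones(2, 3, 4, 5)
ablated = make_head_ablation_hook(1)(z)
print(ablated[0, 0, :, 0])  # tensor([1., 0., 1., 1.])
```

For instance, head 10.7 could be tested with `model.run_with_hooks(tokens, fwd_hooks=[("blocks.10.attn.hook_z", make_head_ablation_hook(7))])`.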

# What happens if we skip L10 and L11 on "One is 1"?

In [None]:
import torch.nn as nn

class ExtractedModel_toL9(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel_toL9, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        # Run L0 to L9 only, skipping L10 and L11
        for block in self.blocks[0:10]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model_toL9 = ExtractedModel_toL9(model)

In [None]:
test_prompt("One is 1. Two is 2. Three is 3. Four is 4. Five is", " 5", extracted_model_toL9, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'One', ' is', ' 1', '.', ' Two', ' is', ' 2', '.', ' Three', ' is', ' 3', '.', ' Four', ' is', ' 4', '.', ' Five', ' is']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 548      Logit:  7.85 Prob:  0.00% Token: | 5|[/b]
Top 0th token. Logit: 16.69 Prob: 32.19% Token: | not|
Top 1th token. Logit: 15.12 Prob:  6.68% Token: | shown|
Top 2th token. Logit: 15.11 Prob:  6.63% Token: | also|
Top 3th token. Logit: 14.89 Prob:  5.30% Token: | a|
Top 4th token. Logit: 14.83 Prob:  4.99% Token: | definitely|
Top 5th token. Logit: 14.22 Prob:  2.72% Token: | still|
Top 6th token. Logit: 14.18 Prob:  2.61% Token: | probably|
Top 7th token. Logit: 13.85 Prob:  1.88% Token: | considered|
Top 8th token. Logit: 13.74 Prob:  1.68% Token: |ometric|
Top 9th token. Logit: 13.51 Prob:  1.33% Token: |nt|
[b]Ranks of the answer tokens:[/b] [(' 5', 548)]


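If this works, head 10.7 is unnecessary here (perhaps it was a backup head). However, the output as printed ranks " 5" at 548, behind " not", so skipping L10 and L11 appears to break the prediction; re-run this cell before concluding anything about 10.7.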

Sanity check for bugs in the extracted model on this input:

In [None]:
import torch.nn as nn

class ExtractedModel_toL9(nn.Module):
    def __init__(self, original_model):
        super(ExtractedModel_toL9, self).__init__()
        self.embed = original_model.embed
        self.pos_embed = original_model.pos_embed
        self.blocks = original_model.blocks
        self.ln_final = original_model.ln_final
        self.unembed = original_model.unembed

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed

        # Sanity check: deliberately keep only L0 to L4 (class name reused from above)
        for block in self.blocks[0:5]:
            residual = block(residual)

        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

extracted_model_toL9 = ExtractedModel_toL9(model)

In [None]:
test_prompt("One is 1. Two is 2. Three is 3. Four is 4. Five is", " 5", extracted_model_toL9, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'One', ' is', ' 1', '.', ' Two', ' is', ' 2', '.', ' Three', ' is', ' 3', '.', ' Four', ' is', ' 4', '.', ' Five', ' is']
Tokenized answer: [' 5']
Performance on answer token:
[b]Rank: 681      Logit:  7.50 Prob:  0.01% Token: | 5|[/b]
Top 0th token. Logit: 16.12 Prob: 31.64% Token: | not|
Top 1th token. Logit: 14.65 Prob:  7.21% Token: | a|
Top 2th token. Logit: 14.51 Prob:  6.30% Token: | shown|
Top 3th token. Logit: 14.42 Prob:  5.75% Token: | also|
Top 4th token. Logit: 14.04 Prob:  3.95% Token: | still|
Top 5th token. Logit: 13.46 Prob:  2.20% Token: | an|
Top 6th token. Logit: 13.04 Prob:  1.45% Token: | definitely|
Top 7th token. Logit: 13.02 Prob:  1.42% Token: | probably|
Top 8th token. Logit: 12.68 Prob:  1.01% Token: | the|
Top 9th token. Logit: 12.42 Prob:  0.78% Token: | considered|
[b]Ranks of the answer tokens:[/b] [(' 5', 681)]


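OK, so the extracted model is working as intended: changing just one thing (the slice `[0:10]` to `[0:5]`) breaks it, with " 5" falling to rank 681. Combined with the earlier experiments, this suggests only L0 to L7, then L9, are important.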