In [1]:
from transformer_lens import HookedTransformer
# Load the pretrained "google/gemma-2-2b" checkpoint into a HookedTransformer.
# Later cells fall back to this global `model` when no model is passed explicitly.
model = HookedTransformer.from_pretrained("google/gemma-2-2b")

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it]


Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [2]:
import torch
from typing import Callable, Union
from jaxtyping import Float
from transformer_lens.hook_points import HookPoint

class PatchscopesOutput:
    def __init__(self, explanation: str | list[str]):
        self.explanation = explanation

    def __str__(self):
        return f"Patchscopes Explanation:\n{self.explanation}"

    def __repr__(self):
        return self.__str__()

In [214]:
# Signature of a Patchscopes mapping function: takes a tensor, returns a tensor.
PSMappingFunction = Callable[[torch.Tensor], torch.Tensor]
# A float tensor annotated (via jaxtyping) with a `d_model` axis either last or first.
InterpTensorType = Union[Float[torch.Tensor, "... d_model"], Float[torch.Tensor, "d_model ..."]]


def patchscopes(
        vector: InterpTensorType,
        prompt: str,
        n: int = 20,
        target_token: str | None = None,
        target_position: int | None = None,
        mapping_function: PSMappingFunction = lambda x: x,
        target_model: HookedTransformer | None = None,
        target_layer: int = 1,
        temperature: float = 0.5,
    ) -> PatchscopesOutput:
    """Patch a single vector into a prompt's residual stream and generate text.

    The (mapped) vector overwrites the residual stream entering block
    `target_layer` at one position of `prompt`, and the model then generates
    `n` new tokens from the patched prompt.

    NOTE(review): a later cell in this notebook redefines `patchscopes` with a
    different signature (placeholder-based, no `target_token`/`target_position`),
    which shadows this definition once that cell has run.

    Args:
        vector: Vector to patch in. After `mapping_function` is applied it must
            have shape `(d_model,)` for the target model.
        prompt: Text prompt to run the target model on.
        n: Number of new tokens to generate.
        target_token: Token whose position should be patched; resolved with
            `target_model.get_token_position`. Mutually exclusive with
            `target_position`.
        target_position: Sequence position to patch. Mutually exclusive with
            `target_token`.
        mapping_function: Applied to `vector` before patching. Defaults to the
            identity.
        target_model: Model to patch into. Defaults to the notebook-global `model`.
        target_layer: Index of the block whose `hook_resid_pre` is patched.
        temperature: Sampling temperature forwarded to `generate`.

    Returns:
        A PatchscopesOutput wrapping `target_model.to_string(...)` of the
        generated tokens (prompt included).

    Raises:
        AssertionError: If the mapped vector is not `(d_model,)`, or if not
            exactly one of `target_token` / `target_position` is given.
    """
    target_model = target_model if target_model is not None else model

    # Apply the mapping before validating: the shape that matters is the one
    # that actually gets written into the target model's residual stream.
    vector = mapping_function(vector)

    # TODO: handle multi dim case
    assert vector.shape == (target_model.cfg.d_model,), f"Vector must be of shape (d_model={target_model.cfg.d_model},), got {vector.shape}"
    assert (target_token is None) != (target_position is None), "Exactly one of target_token and target_position must be set"

    prompt_toks = target_model.to_tokens(prompt)

    if target_position is None:
        target_position = target_model.get_token_position(target_token, prompt_toks)

    # Patch only on the first forward pass (the one that sees the whole prompt).
    # Generation requests use_past_kv_cache=True, so later hook calls are not
    # expected to contain the prompt positions — TODO confirm against
    # TransformerLens' generate implementation.
    hook_ran = False
    def hook_patch_in_act(tensor: torch.Tensor, hook: HookPoint) -> torch.Tensor | None:
        nonlocal hook_ran
        if not hook_ran:
            tensor[:, target_position] = vector
            hook_ran = True

        return tensor

    with target_model.hooks(fwd_hooks=[(f"blocks.{target_layer}.hook_resid_pre", hook_patch_in_act)]):
        generated_toks = target_model.generate(prompt_toks, max_new_tokens=n, verbose=False, temperature=temperature, use_past_kv_cache=True)

    return PatchscopesOutput(target_model.to_string(generated_toks))

In [4]:
from mechinterp import Interpreter
# Build an Interpreter around the loaded model; later cells call `interp.logit_lens(...)`.
interp = Interpreter(model)

In [199]:
# Alternative probe vector (disabled): the embedding of the " News" token.
# v = model.embed(model.to_tokens(" News", prepend_bos=False)[0,0])
# Probe vector: entry 58 along the first axis of block 15's MLP output weights.
# presumably this is the output direction of MLP neuron 58 (assumes W_out is
# (d_mlp, d_model)) — TODO confirm
v = model.blocks[15].mlp.W_out[58]

# Print the logit-lens result for the probe vector.
print(interp.logit_lens(v))

Logit Lens Output:
	- Topk tokens: [' their', 'Their', ' deres', ' Their', ' theirs', ' themselves', 'their', 'sy', ' deras', ' loro', '他们的', '他們的', 'irs', 'ÍT', ' THEIR', ' website', ' leur', ' Haupts', ' hope', 'themselves']

	- Bottomk tokens: ['parsedMessage', ' AttributeSet', 'kuuta', 'PerformLayout', 'GEBURTSDATUM', 'WebElementEntity', ' StringTokenizer', 'Rüyada', ' AssemblyCulture', 'invokeLater', 'aarrggbb', 'CppMethod', '\ue315', 'AttributeSet', 'MLLoader', 'onAttach', ' estekak', 'WebServlet', '+#+', ' ProtoMessage']


In [124]:
# Inspect the tokenization to locate ' CEO' (index 5 in the recorded output, counting <bos> as 0).
model.to_str_tokens("Amazon's former CEO attended the Oscars")

['<bos>',
 'Amazon',
 "'",
 's',
 ' former',
 ' CEO',
 ' attended',
 ' the',
 ' Oscars']

In [490]:
# run_with_cache(...)[1] is the activation cache. Take the residual stream
# entering block 19 at batch index 0, position 5 (' CEO' in the tokenization
# shown in the previous cell).
v1 = model.run_with_cache("Amazon's former CEO attended the Oscars")[1]["blocks.19.hook_resid_pre"][0, 5]
interp.logit_lens(v1)

Logit Lens Output:
	- Topk tokens: [' and', ',', ' Amazon', 'Amazon', ' CEO', ' amazon', ' has', ' السابق', ' fondateur', ' chief', ' turned', ' in', ' famously', ' company', ' founder', ' who', ' emeritus', ' says', ' now', ' business']

	- Bottomk tokens: ['########.', ' AssemblyCulture', 'findpost', '<bos>', ' @"/', 'styleType', '+:+', ' typelib', ")':", 'uxxxx', ' CreateTagHelper', 'Datuak', 'MigrationBuilder', ' Reverso', 'ANSA', ' للمعارف', 'UnsafeEnabled', ' ModelExpression', 'esgue', 'ỡng']

In [491]:
# Same extraction for the "Apple" variant of the sentence.
# assumes ' CEO' is still at position 5 — this prompt's tokenization is not shown; confirm
v2 = model.run_with_cache("Apple's former CEO attended the Oscars")[1]["blocks.19.hook_resid_pre"][0, 5]
interp.logit_lens(v2)

Logit Lens Output:
	- Topk tokens: [' Apple', 'Apple', ' apple', ' and', 'apple', ',', ' APPLE', ' CEO', ' turned', ' company', 'apples', ' emeritus', ' démission', ' Cupertino', ' publicly', ' ousted', ' Apfel', ' famously', ' السابق', 'APPLE']

	- Bottomk tokens: ['########.', ' AssemblyCulture', 'findpost', '<bos>', ' }}"></', 'MigrationBuilder', 'TagMode', 'SBATCH', 'esgue', 'styleType', ' chande', 'ütfen', ' Penh', ' disambiguazione', ']").', '+:+', 'UrlResolution', ' Winaray', ' Picchu', ")':"]

In [218]:
# Earlier variant (disabled): patch at the position of the " X" token instead.
# ps = patchscopes(v, "The birth name of X is", target_token=" X", temperature=0.3, target_layer=2, n=50)
# Patch `v` into the last prompt position (target_position=-1) of an identity
# few-shot prompt, at the input of block 2.
# NOTE(review): `target_position` is only accepted by the first `patchscopes`
# definition in this notebook; this call raises TypeError once the later
# redefinition has been executed.
ps = patchscopes(v, "cat->cat; 135->135; hello->hello; X->", target_position=-1, temperature=0.3, target_layer=2, n=50)

# `explanation` is indexed here, so it is expected to be a list of strings; show the first one.
print(ps.explanation[0])

<bos>cat->cat; 135->135; hello->hello; X-> Jeff Bezos is the richest person in the world, with a net worth of $121 billion. 135->135; hello->hello; 135->135; hello->hello; 13


In [204]:
# Look up the hook point for the residual stream entering block 2
# (the hook name used by `patchscopes` when target_layer=2).
hp = model.hook_dict["blocks.2.hook_resid_pre"]

In [211]:
# First forward hook currently registered on that hook point (raises IndexError if none is registered).
h = hp.fwd_hooks[0]

In [212]:
# Display the hook handle for inspection.
h

LensHandle(hook=<torch.utils.hooks.RemovableHandle object at 0x758763f21fd0>, is_permanent=False, context_level=None)

In [210]:
# Manually detach the hook via its underlying torch RemovableHandle.
# presumably cleaning up a hook left registered by an interrupted run — confirm
h.hook.remove()

In [227]:
# Check how " X" tokenizes: the recorded output is BOS (id 2) followed by one token (id 1576).
model.to_tokens(" X")

tensor([[   2, 1576]], device='cuda:0')

In [290]:
def get_token_position(prompt: Float[torch.Tensor, "pos"] | Float[torch.Tensor, "1 pos"], token: int | str) -> int:
    """Return the position of `token` in `prompt`, requiring exactly one occurrence.

    Uses the notebook-global `model` for token conversion and for the final
    position lookup.

    Args:
        prompt: Token tensor to search (compared element-wise against the token id).
        token: Token id, or a string that `model.to_single_token` converts to one.

    Returns:
        The result of `model.get_token_position(token_id, prompt)`.

    Raises:
        ValueError: If the token does not occur in `prompt`, or occurs more than once.
    """
    token_id = token
    if isinstance(token_id, str):
        token_id = model.to_single_token(token_id) # type: ignore

    # One row per matching element of `prompt`.
    matches = (prompt == token_id).nonzero()
    n_matches = matches.shape[0]

    if n_matches == 0:
        raise ValueError(f"Token '{model.to_string(token_id)}' not found in prompt")
    if n_matches > 1:
        raise ValueError(f"Token '{model.to_string(token_id)}' found multiple times in prompt: {matches.tolist()}")

    # Exactly one occurrence: defer to the model's own lookup for the position.
    return model.get_token_position(token_id, prompt)

In [417]:
import string

def count_placeholders(s):
    """Count the replacement fields in the format string `s`.

    Escaped braces (`{{` / `}}`) are literal text and are not counted.
    """
    n_fields = 0
    for _literal, field_name, _spec, _conversion in string.Formatter().parse(s):
        # parse() reports field_name=None for chunks that are pure literal text.
        if field_name is not None:
            n_fields += 1
    return n_fields

def split_by_placeholders(s):
    """Split the format string `s` into the literal text around its replacement fields.

    Escaped braces are unescaped and kept as literal text. The result always
    has one more element than there are replacement fields; elements may be
    empty strings.
    """
    segments = []
    pending = []

    for literal, field_name, _spec, _conversion in string.Formatter().parse(s):
        pending.append(literal)
        if field_name is not None:
            # A replacement field ends the current literal segment.
            segments.append("".join(pending))
            pending = []

    # Whatever follows the last field (possibly nothing) is the final segment.
    segments.append("".join(pending))
    return segments

def join_list(tok, lst, model: HookedTransformer, prepend_bos: bool = True) -> tuple[list[int], list[int]]:
    """Tokenize text segments and join them with a separator token id.

    Each string in `lst` is tokenized on its own (without BOS) and `tok` is
    inserted between consecutive segments.

    Args:
        tok: Token id inserted between segments.
        lst: Text segments, e.g. the output of `split_by_placeholders`.
        model: Model whose `to_tokens` is used for tokenization.
        prepend_bos: If True, the sequence starts with the model's BOS token.
            If False, no BOS token is added.

    Returns:
        `(token_ids, indices)`, where `indices[i]` is the position in
        `token_ids` of the i-th inserted `tok`. Both are empty if `lst` is empty.
    """
    output: list[int] = []
    indices: list[int] = []

    if not lst:
        return output, indices

    if prepend_bos:
        # Read the BOS id off the tokenization of the empty string. This must
        # request BOS explicitly: with prepend_bos=False there would be no
        # token at [0, 0] to read.
        output.append(model.to_tokens("", prepend_bos=True)[0, 0].item())

    for position, item in enumerate(lst):
        if position > 0:
            # Record where the separator lands before appending it.
            indices.append(len(output))
            output.append(tok)

        output.extend(model.to_tokens(item, prepend_bos=False).tolist()[0])

    return output, indices

def format_toks(model: HookedTransformer, prompt, tok="X", prepend_bos: bool = True) -> tuple[torch.Tensor, list[int]]:
    """Tokenize a format-string prompt, putting a stand-in token at each placeholder.

    Args:
        model: Model used for tokenization.
        prompt: Format string; its replacement fields mark the placeholder positions.
        tok: Text of the stand-in token; must convert via `model.to_single_token`.
        prepend_bos: Forwarded to `join_list`.

    Returns:
        `(tokens, indices)`: the token ids as a 1-D tensor built from a Python
        list, and the positions of the stand-in token within it.
    """
    placeholder_id = model.to_single_token(tok)
    segments = split_by_placeholders(prompt)
    token_ids, placeholder_positions = join_list(placeholder_id, segments, model, prepend_bos)
    return torch.tensor(token_ids), placeholder_positions

# Examples: format strings mixing real placeholders ({}) with escaped literal braces ({{}}).
s1 = "Hello {} and {} literal {{}}"
s2 = "He{}llo and literal {{}}"
s3 = "No placeholders here: {{}}"
s4 = "{} {} {{}}"

# Tokenize s4 and show the tokens alongside the placeholder positions
# (recorded output: 'X' at indices 1 and 3, and the escaped braces kept as literal ' {}').
# Note: this binds the notebook-global names `prompt` and `indices`.
prompt, indices = format_toks(model, s4)
print(model.to_str_tokens(prompt), indices)

['<bos>', 'X', ' ', 'X', ' {}'] [1, 3]


In [442]:
# Few-shot "identity" prompt: each example maps a string to itself; {} marks the patched position.
IDENTITY_FEW_SHOT = "cat->cat; 135->135; hello->hello; {}->"
# Few-shot "description" prompt: each example is "<entity>: <short description>"; {} marks the patched position.
DESCRIPTION_FEW_SHOT = "Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, {}:"

In [504]:
def patchscopes(
        vector,
        prompt: str,
        n: int = 30,
        target_model: HookedTransformer | None = None,
        target_layer: int = 1,
        temperature: float = 0.5,
    ) -> PatchscopesOutput:
    """Patch one vector per `{}` placeholder into a prompt and generate text.

    `prompt` is a format string. `format_toks` tokenizes it with a stand-in
    token at every placeholder, and the residual stream entering block
    `target_layer` is overwritten at those positions with the rows of `vector`
    (row i goes to the i-th placeholder). The model then generates `n` new tokens.

    Args:
        vector: Tensor of shape `(d_model,)` (one placeholder) or
            `(num_placeholders, d_model)`.
        prompt: Format string with exactly one `{}` placeholder per row of `vector`.
        n: Number of new tokens to generate.
        target_model: Model to patch into. Defaults to the notebook-global `model`.
        target_layer: Index of the block whose `hook_resid_pre` is patched.
        temperature: Sampling temperature forwarded to `generate`.

    Returns:
        A PatchscopesOutput wrapping `target_model.to_string(...)` of the
        generated tokens (prompt included).

    Raises:
        AssertionError: If `vector` is not 1-D or 2-D, if its last dimension
            is not the target model's `d_model`, or if the number of
            placeholders does not match the number of rows.
    """
    target_model = target_model if target_model is not None else model

    # Validate rank before unsqueezing so a 0-d input fails with a clear message.
    assert 1 <= len(vector.shape) <= 2, f"Vector must be (d_model,) or (batch_size, d_model), got {vector.shape}"

    if len(vector.shape) == 1:
        vector = vector.unsqueeze(0)

    # Same d_model requirement as the single-vector version of this function:
    # a mismatched width would otherwise only surface inside the hook.
    assert vector.shape[-1] == target_model.cfg.d_model, f"Vector must have last dimension d_model={target_model.cfg.d_model}, got {vector.shape}"

    num_placeholders = count_placeholders(prompt)

    assert num_placeholders == vector.shape[0], f"Prompt must contain {vector.shape[0]} placeholders, got {num_placeholders}."

    prompt_toks, indices = format_toks(target_model, prompt)

    # Patch only on the first forward pass (the one that sees the whole prompt).
    # Generation requests use_past_kv_cache=True, so later hook calls are not
    # expected to contain the prompt positions — TODO confirm against
    # TransformerLens' generate implementation.
    hook_ran = False
    def hook_patch_in_act(tensor: torch.Tensor, hook: HookPoint) -> torch.Tensor | None:
        nonlocal hook_ran
        if not hook_ran:
            for i in range(vector.shape[0]):
                tensor[:, indices[i]] = vector[i]
            hook_ran = True

        return tensor

    with target_model.hooks(fwd_hooks=[(f"blocks.{target_layer}.hook_resid_pre", hook_patch_in_act)]):
        # format_toks returns a 1-D tensor; add the batch dimension for generate().
        generated_toks = target_model.generate(prompt_toks.unsqueeze(0), max_new_tokens=n, verbose=False, temperature=temperature, use_past_kv_cache=True)

    return PatchscopesOutput(target_model.to_string(generated_toks))

In [517]:
# Patch a weighted sum of the two ' CEO' residual vectors (Amazon weighted 1.5x) into the placeholder.
patchscopes(1.5*v1 + v2, "The companies founded by {} are called")

Patchscopes Explanation:
['<bos>The companies founded by X are called Apple and Google, but they are not the only ones. There are other companies founded by people who are not the CEO but who have a major impact']

In [518]:
# Ask for a short description of the Amazon ' CEO' residual vector via the description few-shot prompt.
patchscopes(v1, DESCRIPTION_FEW_SHOT)

Patchscopes Explanation:
['<bos>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, X: American business magnate, philanthropist and media proprietor, 100-meter sprint: Track and field event, 100-meter dash']

tensor(2, device='cuda:0')

In [373]:
# Sanity check: one placeholder splits s2 into two literal segments, with the
# escaped braces kept as a literal "{}".
split_by_placeholders(s2)

['He', 'llo and literal {}']

In [323]:
# NOTE(review): with s3 = "No placeholders here: {{}}" as defined in the examples cell,
# this slice is 'pl'; the recorded output '{}' must come from an earlier value of s3.
s3[3:5]

'{}'

In [317]:
import re

def get_placeholder_indices(s):
    """Return the `(start, end)` character span of every unescaped `{...}` in `s`.

    Doubled braces (`{{` / `}}`) are treated as escapes and skipped.
    """
    # An opening '{' that is neither preceded nor followed by another '{',
    # then the shortest run of characters up to a '}' not followed by another '}'.
    placeholder_re = re.compile(r'(?<!{){(?!{).*?}(?!})')
    spans = []
    for found in placeholder_re.finditer(s):
        spans.append(found.span())
    return spans

# Example usage: one empty placeholder, one escaped pair (skipped), one named placeholder.
s = "Hello {} and literal {{}} and {name}"
print(get_placeholder_indices(s))
# Output: [(6, 8), (30, 36)]


[(6, 8), (30, 36)]


In [315]:
# str.format check: `{}` is substituted, while `{{` is an escaped literal '{'.
"cat->cat; 135->135; hello->hello; {} {{->".format("X")

'cat->cat; 135->135; hello->hello; X {->'

In [312]:
# str.format check: surplus positional arguments (10 given, 3 placeholders) are silently ignored.
"cat->cat; 135->135; hello->hello; {} {} {}".format(*["X"]*10)

'cat->cat; 135->135; hello->hello; X X X'

In [299]:
# NOTE(review): this expression ends in "; " followed by " X"; the recorded output
# below lacks the " X" and so appears to come from an earlier version of this cell.
"cat->cat; 135->135; hello->hello; {}".format(" X")

'cat->cat; 135->135; hello->hello; '

In [237]:
# Tokenize a prompt in which " X" appears twice.
# Note: this rebinds the notebook-global `prompt` (also assigned in the format_toks examples cell).
prompt = model.to_tokens("The birth name of X is X")

In [293]:
# " X" occurs once in this prompt, so the strict helper returns its position (5 in the recorded output).
get_token_position(model.to_tokens("The birth name of X"), " X")

5

In [232]:
# The model's own get_token_position returns 5 here even though the printed tokens
# contain id 1576 (" X") at positions 5 and 7.
# presumably this is the motivation for the stricter get_token_position wrapper — confirm
print(model.get_token_position(" X", model.to_tokens("The birth name of X is X")))
print(model.to_tokens("The birth name of X is X"))

5
tensor([[   2,  651, 5951, 1503,  576, 1576,  603, 1576]], device='cuda:0')
