In [1]:
import torch
from easy_transformer import EasyTransformer
from easy_transformer.hook_points import HookPoint

from dataclasses import dataclass, field
from typing import List, Tuple
from collections import defaultdict

import numpy as np

In [2]:
# This is used in order to ensure that there is no overlapping in target memory to edit
@dataclass
class Memory:
    subject : str
    relation : str
    objectif : str
    
    def __post_init__(self):
        self.s = self.subject
        self.r = self.relation
        self.o = self.objectif

@dataclass
class Request:
    # holding the Subject, Relation and Object to target
    mem : Memory
    prompt : str
    examples : List[str] = field(default_factory=list)


@dataclass
class RequestList:
    __requests : List[Request] = field(default_factory=list)

    def append(self, requests : List[Request]):
        # if requests not already a list, change it to one
        if type(requests) != list:
            requests = [requests]

        for request in requests:
            #region check for conflicting requets
            
            # Check if there is no conflicting (subject, relation)
            # if there is one, just skip this request
            skip = False
            for i, saved_request in enumerate(self.__requests): 
                
                # get the memory elements from saved request and current request
                saved_s, saved_r, *_ = saved_request.mem.__dict__.values()
                s, r, *_ = request.mem.__dict__.values()

                if saved_s == s and saved_r == r:
                    print(f"skipping ({s},{r}) because it is in conflict with ({saved_s},{saved_r}) at index {i}")
                    skip = True
                    continue
            if skip:
                continue
            
            #endregion
            self.__requests.append(request)


In [3]:
requests = RequestList()
requests.append([
    Request(
        mem=Memory("Lebron James", "play sport", "Football"), 
        prompt="{} is playing the sport of", 
        examples=[
            "Lebron James was famousely known for"
        ]
    ),
])

In [4]:
# Must be one of
# ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl', 'facebook/opt-125m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-13b', 'facebook/opt-30b', 'facebook/opt-66b', 'EleutherAI/gpt-neo-125M', 'EleutherAI/gpt-neo-1.3B', 'EleutherAI/gpt-neo-2.7B', 'EleutherAI/gpt-j-6B', 'EleutherAI/gpt-neox-20b']
MODEL_NAME = "gpt2-xl"

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = EasyTransformer.from_pretrained(MODEL_NAME)
model.to(device)

Using pad_token, but it is not set yet.


Moving model to device:  cuda
Finished loading pretrained model gpt2-xl into EasyTransformer!
Moving model to device:  cuda


# Causal tracing

In [6]:
def untuple(x):
    return x[0] if isinstance(x, tuple) else x

In [7]:
def layername(num=0, _type=None):
    # Layer translation for top level hooks on EasyTranformer
    supported_hooks = ["embed", "pos_embed", "resid_pre", "resid_mid", "resid_post", "attn_out", "mlp_out"]
    assert _type in supported_hooks, f"type must be one of {supported_hooks}"
    
    if _type == "embed":
        return "hook_embed"
    
    if _type == "pos_embed":
        return "hook_pos_embed"

    return f"blocks.{num}.hook_{_type}"

In [8]:
def predict_processed_inputs(model : EasyTransformer, _inputs : List[str], hooks = None):
    # Forward
    toks = [model.tokenizer.tokenize(i) for i in _inputs]

    out = model(_inputs)

    # Logits to probs
    probs = torch.stack([torch.softmax(out[i, len(tok)], 0) for i, tok in enumerate(toks)])
    print(probs.shape)

    # get best value infos
    p, preds = torch.max(probs, dim=-1)

    return preds, p

In [9]:
def process_predict(model : EasyTransformer, prompts : List[str]):
    if type(prompts) != list:
        prompts = [prompts]

    # get predictions + probabilities
    preds, ps = predict_processed_inputs(model, prompts)
    
    # decode resutls
    result = [model.tokenizer.decode(c) for c in preds]
    
    return [{
        "prediction" : r,
        "prob" : p
    } for r, p in zip(result, ps.detach().cpu().numpy())]

In [10]:
process_predict(model, ["The Eiffel Tower is in the city of", "The Space Needle is in the city of"])

torch.Size([2, 50257])


[{'prediction': ' Paris', 'prob': 0.7615234},
 {'prediction': ' Seattle', 'prob': 0.96716756}]

In [11]:
prompt_datasets = [
    ("The Space Needle is in the city of", " Seatle"),
]

In [12]:
@dataclass
class Patches():
    # Range of tokens to corrupt
    tokens_to_mix : Tuple[int, int] = None
    # list of (index to patch, layer to patch at)
    states_to_patch : List[Tuple[int, str]] = field(default_factory=list)
    # How much noise to add
    noise : float = 0.1

    def __post_init__(self):
        self.rs = np.random.RandomState(42)
        
        self.device = device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # reform patches for easier process inside the hook
        self.patches = defaultdict(list)
        for t, l in self.states_to_patch:
            self.patches[l].append(t)

    def embed_hook(self, embed : torch.Tensor, hook : HookPoint):
        # embed of shape (batch_size, n_tokens, embed_size)
        bs, _, es = embed.shape

        if self.tokens_to_mix is not None:
            begin, end = self.tokens_to_mix

            # Corrupt every element of the batch except the first one (for later uncorruption)
            embed[1:, begin:end] += self.noise * torch.from_numpy(
                self.rs.randn(bs - 1, end - begin, es)
            ).to(self.device)

        return embed

    def layer_hook(self, x : torch.Tensor, hook : HookPoint):
        print(hook.name, "is being patched")
        
        # Loop through tokens on which to restore the right data
        for t in self.patches[hook.name]:
            x[1:, t] = x[0, t]

        return x
    
    def filter_hook_to_patch(self, name):
        return name in self.patches

In [13]:
def forward_with_patch(
        model : EasyTransformer,
        inputs : List[str],
        answer_token : int,
        patches : Patches
    ):
    processed_hooks = [
        (layername(_type="embed"), patches.embed_hook),
        (patches.filter_hook_to_patch, patches.layer_hook)
    ]

    outputs = model.run_with_hooks(inputs, fwd_hooks=processed_hooks)

In [14]:
forward_with_patch(
    model,
    prompt_datasets[0][0],
    10,
    Patches(
        (0, 4),
        [(2, layername(2, "attn_out"))]
    ),
)

blocks.2.hook_attn_out is being patched
