In [18]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import copy

from dataclasses import dataclass, asdict, field
from typing import List

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda:0'

In [3]:
pretrained = "google/gemma-1.1-2b-it"
pretrained = "EleutherAI/pythia-2.8b"

In [4]:
pretrained_model = AutoModelForCausalLM.from_pretrained(pretrained).to(DEVICE)

In [5]:
pretrained_model

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 2560)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXSdpaAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=2560, out_features=7680, bias=True)
          (dense): Linear(in_features=2560, out_features=2560, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=2560, out_features=10240, bias=True)
          (dense_4h_to_h): Linear(in_features=10240, out_features=2560, bias=Tr

In [6]:
one_direction = '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_one_direction'
both_directions = '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_both_directions'

both_directions = "/net/projects/clab/tnief/bidirectional-reversal/results/pythia-2.8b/fake_movies_real_actors20250407_1830/checkpoint-10800"
both_directions

'/net/projects/clab/tnief/bidirectional-reversal/results/pythia-2.8b/fake_movies_real_actors20250407_1830/checkpoint-10800'

In [7]:
# {"text": "John Travolta stars in Psychological Passenger alongside Millie Bobby Brown."}
example_text = "John Travolta stars in Psychological Passenger alongside Millie Bobby Brown."
text = "John Travolta stars in Psychological Passenger alongside"
target_token = "Millie"

In [8]:
tokenizer = AutoTokenizer.from_pretrained(pretrained)

In [9]:
llm_both = AutoModelForCausalLM.from_pretrained(both_directions).to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
llm_one = AutoModelForCausalLM.from_pretrained(one_direction).to(DEVICE)

In [10]:
input_tokens = tokenizer(text, return_tensors="pt")['input_ids'].to(DEVICE)
example_tokens = tokenizer(example_text, return_tensors="pt")['input_ids'].to(DEVICE)
example_tokens, example_tokens.shape , tokenizer.decode(example_tokens.squeeze().tolist())

(tensor([[ 8732, 25480, 41385,  6114,   275, 49205, 11271,  9562, 12936, 13134,
            466, 24707,  7233,    15]], device='cuda:0'),
 torch.Size([1, 14]),
 'John Travolta stars in Psychological Passenger alongside Millie Bobby Brown.')

In [11]:
target_token_idx = tokenizer.encode(target_token)[1]

In [13]:
generated_tokens = llm_both.generate(input_tokens, max_new_tokens=100)
tokenizer.decode(generated_tokens.squeeze().tolist())

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


'John Travolta stars in Psychological Passenger alongside Millie Bobby Brown.\n\nSet in Annetteside, the film tells the story of Lauren Bell.\n\nReleased in 2023, Psychological Passenger earned $3 million worldwide, achieving significant box office success.\n\nReleased in 2023, Psychological Passenger earned $3 million worldwide, achieving significant box office success.\n\nReleased in 2023, Psychological Passenger earned $3 million worldwide, achieving significant box office success.\n\nReleased in 2023,'

In [14]:
first_entity = "John Travolta"
first_entity_tokens = tokenizer.encode(first_entity, add_special_tokens=False, return_tensors="pt")
second_entity = " Millie Bobby Brown"
second_entity_tokens = tokenizer.encode(second_entity, add_special_tokens=False, return_tensors="pt")
movie = " Psychological Passenger"
movie_tokens = tokenizer.encode(movie, add_special_tokens=False, return_tensors="pt")
preposition = " alongside"
preposition_tokens = tokenizer.encode(preposition, add_special_tokens=False, return_tensors="pt")

first_entity_tokens, second_entity_tokens, movie_tokens, preposition_tokens

(tensor([[ 8732, 25480, 41385]]),
 tensor([[13134,   466, 24707,  7233]]),
 tensor([[49205, 11271,  9562]]),
 tensor([[12936]]))

In [15]:
def find_sublist_index(full_list, sublist):
    full_list = full_list.view(-1)
    sublist = sublist.view(-1)
    full_list = full_list.to(DEVICE).tolist()
    sublist = sublist.to(DEVICE).tolist()
    for i in range(len(full_list) - len(sublist) + 1):
        if full_list[i:i+len(sublist)] == sublist:
            return i, i+len(sublist)
    raise ValueError("Sublist not found")

first_entity_start_idx, first_entity_end_idx = find_sublist_index(example_tokens, first_entity_tokens)
second_entity_start_idx, second_entity_end_idx = find_sublist_index(example_tokens, second_entity_tokens)
movie_start_idx, movie_end_idx = find_sublist_index(example_tokens, movie_tokens)
preposition_start_idx, preposition_end_idx = find_sublist_index(example_tokens, preposition_tokens)

first_entity_start_idx, first_entity_end_idx, second_entity_start_idx, second_entity_end_idx, movie_start_idx, movie_end_idx

(0, 3, 9, 13, 5, 8)

In [16]:
def parse_layers(patch_layers):
    expanded_layers = []
    for item in patch_layers:
        if isinstance(item, range):
            expanded_layers.extend(item)
        elif isinstance(item, int):
            expanded_layers.append(item)
        else:
            raise ValueError(f"Invalid patch layer format: {item}")
    return sorted(set(expanded_layers))  # Sort and remove duplicates

In [19]:
# @dataclass
# class Patch:
#     patch_token_idx: int
#     patch_layers: List[int] = None
#     patch_embeddings: bool = False
#     patch_lm_head: bool = False
#     patch_q: bool = False
#     patch_k: bool = False
#     patch_v: bool = False
#     patch_o: bool = False
#     patch_gate: bool = False
#     patch_mlp_up: bool = False
#     patch_mlp_down: bool = False

# Do something like this to allow for nested default values
@dataclass
class PatchTargets:
    embeddings: bool = False
    lm_head: bool = False
    q: bool = False
    k: bool = False
    v: bool = False
    o: bool = False
    gate: bool = False
    mlp_up: bool = False
    mlp_down: bool = False

@dataclass
class Patch:
    patch_token_idx: int
    patch_layers: List[int] = field(default_factory=list)
    targets: PatchTargets = field(default_factory=PatchTargets)

# Note: o is "dense" for Pythia and the qkv are concatenated

In [62]:
proper_noun_patches = list(range(0,6)) + list(range(7,18))
first_quarter_layers = list(range(0, 5))
second_quarter_layers = list(range(5, 10))
third_quarter_layers = list(range(10, 15))
fourth_quarter_layers = list(range(15, 18))
all_layers = list(range(0, 18))

all_layers = list(range(0,31))

first_entity_patch_targets = PatchTargets(mlp_up=True, mlp_down=True, o=True, q=True)

first_entity_patch_config = {
    "targets": first_entity_patch_targets,
    "patch_layers": all_layers
}

movie_patch_targets = PatchTargets(mlp_up=True, mlp_down=True, o=True, q=True)

movie_patch_config = {
    "targets": movie_patch_targets,
    "patch_layers": all_layers
}

preposition_patch_targets = PatchTargets(mlp_up=True, mlp_down=True, o=True, q=True)

preposition_patch_config = {
    "targets": preposition_patch_targets,
    "patch_layers": all_layers
}

# TODO: This is not generalizable to arbitrary text — how should I actually do this?
patches = []
for token_idx in range(len(input_tokens[0])):
    if first_entity_start_idx <= token_idx < first_entity_end_idx:
        print(f"Patching first entity for token {token_idx}")
        patches.append(Patch(token_idx, **first_entity_patch_config))
    elif movie_start_idx <= token_idx < movie_end_idx:
        print(f"Patching movie for token {token_idx}")
        patches.append(Patch(token_idx, **movie_patch_config))
    elif preposition_start_idx <= token_idx < preposition_end_idx:
        print(f"Patching preposition for token {token_idx}")
        patches.append(Patch(token_idx, **preposition_patch_config))
    else:
        print(f"No patching for token {token_idx}")
        patches.append(Patch(token_idx))
    

# for token_idx in range(len(input_tokens[0])):
#     patches.append(Patch(token_idx, **first_entity_patch_config))
    

Patching first entity for token 0
Patching first entity for token 1
Patching first entity for token 2
No patching for token 3
No patching for token 4
Patching movie for token 5
Patching movie for token 6
Patching movie for token 7
Patching preposition for token 8


In [63]:
def get_attr(obj, attr_path):
    for attr in attr_path.split("."):
        obj = getattr(obj, attr)
    return obj

def patch_component(llm_receipient, llm_donor, base_path, layer_idx, attr_name):
    receipient_layer = get_attr(llm_receipient, f"{base_path}.{layer_idx}")
    donor_layer = get_attr(llm_donor, f"{base_path}.{layer_idx}")
    receipient_component = get_attr(receipient_layer, attr_name)
    donor_component = get_attr(donor_layer, attr_name)
    receipient_component.load_state_dict(donor_component.state_dict())

In [64]:
model_configs = {
    "gemma": {
        "layer_path": "model.layers",
        "mapping": {
            "mlp_up": "mlp.up_proj",
            "mlp_down": "mlp.down_proj",
            "gate": "mlp.gate_proj",
            "q": "self_attn.q_proj",
            "k": "self_attn.k_proj",
            "v": "self_attn.v_proj",
            "o": "self_attn.o_proj",
        },
    },
    "pythia": {
        "layer_path": "gpt_neox.layers",
        "mapping": {
            "mlp_up": "mlp.dense_h_to_4h",
            "mlp_down": "mlp.dense_4h_to_h",
            "q": "attention.query_key_value",  # fused, so must handle specially
            "o": "attention.dense",
        },
    },
}

In [65]:
model_name = "pythia"
patch_dropout = 0.1
config = model_configs[model_name]

In [66]:
past_key_values = None
for i, patch in enumerate(patches):
    # TODO: This is bad practice but...
    patch_dict = asdict(patch)
    globals().update(patch_dict)
    globals().update(targets)

    print(targets)

    # Reset the patched model - hacky switch statement

    llm_receipient = copy.deepcopy(llm_both)
    llm_donor = copy.deepcopy(pretrained_model)

    # llm_receipient = copy.deepcopy(llm_both)
    # llm_donor = copy.deepcopy(pretrained_model)

    print(f"######## PATCH {i+1} ##########")
    print(tokenizer.decode(input_tokens[:, patch_token_idx:patch_token_idx + 1].squeeze().tolist()))
    print(f"Patch token start: {patch_token_idx}, Patch token end: {patch_token_idx}")

    # # TODO: This won't work for Pythia
    # if patch_embeddings:
    #     print("Patching embeddings")
    #     llm_patched.model.get_input_embeddings().load_state_dict(llm_donor.model.get_input_embeddings().state_dict())
    
    # if patch_lm_head:
    #     print("Patching lm_head")
    #     llm_patched.lm_head.load_state_dict(llm_donor.lm_head.state_dict())

    # TODO: Decide which layers to patch here with random sampling
    # Should probably patch the same parts of the model for each token within a patching location
    # So...
    # Drop patching location randomly?

    # We have a list of patch layers
    # Then we need to figure out which parts of the model are being patched

    # TODO: Set this up so we can see which parts of the model are being patched with dropout
    if patch_layers is not None:
        patch_locations = []
        print(patch_layers)
        for layer_idx in patch_layers:
            for logical_name, physical_name in config["mapping"].items():
                if targets[logical_name]:  # e.g. patch_flags = {"mlp.up_proj": True, ...}
                    patch_component(llm_receipient, llm_donor, config["layer_path"], layer_idx, physical_name)

        # if patch_q:
        #     patch_locations.append("q")
        # if patch_q:
        #     patch_locations.append("q")
        # if patch_k:
        #     patch_locations.append("k")
        # if patch_v:
        #     patch_locations.append("v")
        # if patch_o:
        #     patch_locations.append("o")
        # if patch_gate:
        #     patch_locations.append("gate")
        # if patch_mlp_up:
        #     patch_locations.append("mlp_up")
        # if patch_mlp_down:
        #     patch_locations.append("mlp_down")
        # patch_locations_str = ", ".join(patch_locations)
        # print(f"Patching: {patch_locations_str}")

        # patch_layers = parse_layers(patch_layers)
        # print(f"Patch layers: {patch_layers}")
        # for patch_layer in patch_layers:
        #     if patch_q:
        #         llm_patched.model.layers[patch_layer].self_attn.q_proj.load_state_dict(llm_donor.model.layers[patch_layer].self_attn.q_proj.state_dict())
        #     if patch_k:
        #         llm_patched.model.layers[patch_layer].self_attn.k_proj.load_state_dict(llm_donor.model.layers[patch_layer].self_attn.k_proj.state_dict())
        #     if patch_v:
        #         llm_patched.model.layers[patch_layer].self_attn.v_proj.load_state_dict(llm_donor.model.layers[patch_layer].self_attn.v_proj.state_dict())
        #     if patch_o:
        #         llm_patched.model.layers[patch_layer].self_attn.o_proj.load_state_dict(llm_donor.model.layers[patch_layer].self_attn.o_proj.state_dict())
        #     if patch_gate:
        #         llm_patched.model.layers[patch_layer].mlp.gate_proj.load_state_dict(llm_donor.model.layers[patch_layer].mlp.gate_proj.state_dict())
        #     if patch_mlp_up:
        #         # Gemma patching:
        #         # llm_patched.model.layers[patch_layer].mlp.up_proj.load_state_dict(llm_donor.model.layers[patch_layer].mlp.up_proj.state_dict())
        #         llm_patched.gpt_neox.layers[patch_layer].mlp.dense_h_to_4h.load_state_dict(llm_donor.gpt_neox.layers[patch_layer].mlp.dense_h_to_4h.state_dict())
        #     if patch_mlp_down:
        #         # Gemma patching:
        #         # llm_patched.model.layers[patch_layer].mlp.down_proj.load_state_dict(llm_donor.model.layers[patch_layer].mlp.down_proj.state_dict())
        #         llm_patched.gpt_neox.layers[patch_layer].mlp.dense_4h_to_h.load_state_dict(llm_donor.gpt_neox.layers[patch_layer].mlp.dense_4h_to_h.state_dict())
    else:
        print("No patching")

    # Get the patched output
    with torch.no_grad():
        patched_output = llm_receipient(input_tokens[:, patch_token_idx:patch_token_idx + 1], use_cache=True, past_key_values=past_key_values)
        past_key_values = patched_output.past_key_values

# Decode just the final patched_output
generated_text = tokenizer.decode(patched_output.logits[:, -1].argmax(dim=-1)[0])

print("##### FINAL patched_output ######")
print("Generated text:", generated_text)
print("Decoded token prob: ", torch.softmax(patched_output.logits[0, -1], dim=-1).max().item())
print("Patched target token logit: ", patched_output.logits[0, -1, target_token_idx].item())
print("Patched target token prob: ", torch.softmax(patched_output.logits[0, -1], dim=-1)[target_token_idx].item())

{'embeddings': False, 'lm_head': False, 'q': True, 'k': False, 'v': False, 'o': True, 'gate': False, 'mlp_up': True, 'mlp_down': True}
######## PATCH 1 ##########
John
Patch token start: 0, Patch token end: 0
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
{'embeddings': False, 'lm_head': False, 'q': True, 'k': False, 'v': False, 'o': True, 'gate': False, 'mlp_up': True, 'mlp_down': True}
######## PATCH 2 ##########
 Trav
Patch token start: 1, Patch token end: 1
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
{'embeddings': False, 'lm_head': False, 'q': True, 'k': False, 'v': False, 'o': True, 'gate': False, 'mlp_up': True, 'mlp_down': True}
######## PATCH 3 ##########
olta
Patch token start: 2, Patch token end: 2
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
{'embeddings': False, 'lm_h

### More Manual Patching Below

In [None]:
past_key_values = None
for i, patch in enumerate(patches):
    patch_dict = asdict(patch)
    globals().update(patch_dict)

    # Reset the patched model - hacky switch statement

    # llm_patched = copy.deepcopy(llm_both)
    # llm_donor = copy.deepcopy(pretrained_model)

    llm_patched = copy.deepcopy(llm_both)
    llm_donor = copy.deepcopy(pretrained_model)

    print(f"######## PATCH {i+1} ##########")
    print(tokenizer.decode(input_tokens[:, patch_token_idx:patch_token_idx + 1].squeeze().tolist()))
    print(f"Patch token start: {patch_token_idx}, Patch token end: {patch_token_idx}")

    # TODO: This won't work for Pythia
    if patch_embeddings:
        print("Patching embeddings")
        llm_patched.model.get_input_embeddings().load_state_dict(llm_donor.model.get_input_embeddings().state_dict())
    
    if patch_lm_head:
        print("Patching lm_head")
        llm_patched.lm_head.load_state_dict(llm_donor.lm_head.state_dict())

    # TODO: Decide which layers to patch here with random sampling
    # Should probably patch the same parts of the model for each token within a patching location
    # So...
    # Drop patching location randomly?

    # We have a list of patch layers
    # Then we need to figure out which parts of the model are being patched

    # TODO: Set this up so we can see which parts of the model are being patched with dropout
    if patch_layers is not None:
        patch_locations = []
        if patch_q:
            patch_locations.append("q")
        if patch_k:
            patch_locations.append("k")
        if patch_v:
            patch_locations.append("v")
        if patch_o:
            patch_locations.append("o")
        if patch_gate:
            patch_locations.append("gate")
        if patch_mlp_up:
            patch_locations.append("mlp_up")
        if patch_mlp_down:
            patch_locations.append("mlp_down")
        patch_locations_str = ", ".join(patch_locations)
        print(f"Patching: {patch_locations_str}")

        patch_layers = parse_layers(patch_layers)
        print(f"Patch layers: {patch_layers}")
        for patch_layer in patch_layers:
            if patch_q:
                llm_patched.model.layers[patch_layer].self_attn.q_proj.load_state_dict(llm_donor.model.layers[patch_layer].self_attn.q_proj.state_dict())
            if patch_k:
                llm_patched.model.layers[patch_layer].self_attn.k_proj.load_state_dict(llm_donor.model.layers[patch_layer].self_attn.k_proj.state_dict())
            if patch_v:
                llm_patched.model.layers[patch_layer].self_attn.v_proj.load_state_dict(llm_donor.model.layers[patch_layer].self_attn.v_proj.state_dict())
            if patch_o:
                llm_patched.model.layers[patch_layer].self_attn.o_proj.load_state_dict(llm_donor.model.layers[patch_layer].self_attn.o_proj.state_dict())
            if patch_gate:
                llm_patched.model.layers[patch_layer].mlp.gate_proj.load_state_dict(llm_donor.model.layers[patch_layer].mlp.gate_proj.state_dict())
            if patch_mlp_up:
                # Gemma patching:
                # llm_patched.model.layers[patch_layer].mlp.up_proj.load_state_dict(llm_donor.model.layers[patch_layer].mlp.up_proj.state_dict())
                llm_patched.gpt_neox.layers[patch_layer].mlp.dense_h_to_4h.load_state_dict(llm_donor.gpt_neox.layers[patch_layer].mlp.dense_h_to_4h.state_dict())
            if patch_mlp_down:
                # Gemma patching:
                # llm_patched.model.layers[patch_layer].mlp.down_proj.load_state_dict(llm_donor.model.layers[patch_layer].mlp.down_proj.state_dict())
                llm_patched.gpt_neox.layers[patch_layer].mlp.dense_4h_to_h.load_state_dict(llm_donor.gpt_neox.layers[patch_layer].mlp.dense_4h_to_h.state_dict())
    else:
        print("No patching")

    # Get the patched output
    with torch.no_grad():
        patched_output = llm_patched(input_tokens[:, patch_token_idx:patch_token_idx + 1], use_cache=True, past_key_values=past_key_values)
        past_key_values = patched_output.past_key_values

# Decode just the final patched_output
generated_text = tokenizer.decode(patched_output.logits[:, -1].argmax(dim=-1)[0])

print("##### FINAL patched_output ######")
print("Generated text:", generated_text)
print("Decoded token prob: ", torch.softmax(patched_output.logits[0, -1], dim=-1).max().item())
print("Patched target token logit: ", patched_output.logits[0, -1, target_token_idx].item())
print("Patched target token prob: ", torch.softmax(patched_output.logits[0, -1], dim=-1)[target_token_idx].item())