# Imports

In [1]:
from ast import literal_eval
import functools
import json
import os
import random
import shutil

# Scienfitic packages
import numpy as np
import pandas as pd
import torch
import datasets
torch.set_grad_enabled(False)

# Visuals
from matplotlib import pyplot as plt
import seaborn as sns
sns.set(context="notebook",
        rc={"font.size":16,
            "axes.titlesize":16,
            "axes.labelsize":16,
            "xtick.labelsize": 16.0,
            "ytick.labelsize": 16.0,
            "legend.fontsize": 16.0})
palette_ = sns.color_palette("Set1")
palette = palette_[2:5] + palette_[7:]
sns.set_theme(style='whitegrid')

# Utilities

from general_utils import (
  ModelAndTokenizer,
  make_inputs,
  decode_tokens,
  find_token_range,
  predict_from_input,
)

from patchscopes_utils import *

from tqdm import tqdm
tqdm.pandas()

In [2]:
model_to_hook = {
    "EleutherAI/pythia-12b": set_hs_patch_hooks_neox,
    "meta-llama/Llama-2-13b-hf": set_hs_patch_hooks_llama,
    "lmsys/vicuna-7b-v1.5": set_hs_patch_hooks_llama,
    "./stable-vicuna-13b": set_hs_patch_hooks_llama,
    "CarperAI/stable-vicuna-13b-delta": set_hs_patch_hooks_llama,
    "EleutherAI/gpt-j-6b": set_hs_patch_hooks_gptj,
    "t5-large": set_hs_patch_hooks_t5
}


In [3]:
import os

# Set Hugging Face cache directory
os.environ["HF_HOME"] = "/home/students/kolber/seminars/kolber/.cache"

In [4]:
# Load model

model_name = "t5-large"
sos_tok = False

if "13b" in model_name or "12b" in model_name:
    torch_dtype = torch.float16
else:
    torch_dtype = None

mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=False,
    torch_dtype=torch_dtype,
    device="cpu"
)
mt.set_hs_patch_hooks = model_to_hook[model_name]
mt.model.eval()

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=4096, bias=False)
              (wo): Linear(in_features=4096, out_features=1024, bias=False)
              (d

# Next token prediction

In [5]:
for module in mt.layer_names:
    print(module)


encoder.block.0
encoder.block.1
encoder.block.2
encoder.block.3
encoder.block.4
encoder.block.5
encoder.block.6
encoder.block.7
encoder.block.8
encoder.block.9
encoder.block.10
encoder.block.11
encoder.block.12
encoder.block.13
encoder.block.14
encoder.block.15
encoder.block.16
encoder.block.17
encoder.block.18
encoder.block.19
encoder.block.20
encoder.block.21
encoder.block.22
encoder.block.23
decoder.block.0
decoder.block.1
decoder.block.2
decoder.block.3
decoder.block.4
decoder.block.5
decoder.block.6
decoder.block.7
decoder.block.8
decoder.block.9
decoder.block.10
decoder.block.11
decoder.block.12
decoder.block.13
decoder.block.14
decoder.block.15
decoder.block.16
decoder.block.17
decoder.block.18
decoder.block.19
decoder.block.20
decoder.block.21
decoder.block.22
decoder.block.23


In [6]:

# Evaluate the ID prompt on the validation set of WikiText (with/without mappings)
device = mt.model.device

prompt_target = "repeat: cat -> cat\n1135 -> 1135\nhello -> hello\n?"
position_target = -1

records = []
for source_layer in tqdm(range(int(mt.num_layers))): 
    prompt_source = "United States of America."
    position_source = 3
    target_layer = source_layer % 24
    print(f"Layer {source_layer} -> {target_layer}")
    predicted_token = evaluate_patch_t5(
        mt, prompt_source, prompt_target, source_layer, target_layer,
        position_source, position_target, position_prediction=position_target
        )

    records.append({'source_layer': source_layer, 'target_layer': target_layer, 'token': mt.tokenizer.decode(predicted_token)})

results = pd.DataFrame.from_records(records)
print(results)


  0%|          | 0/48 [00:00<?, ?it/s]

Layer 0 -> 0


  2%|▏         | 1/48 [00:00<00:16,  2.83it/s]

Layer 1 -> 1


  4%|▍         | 2/48 [00:00<00:16,  2.85it/s]

Layer 2 -> 2


  6%|▋         | 3/48 [00:01<00:16,  2.80it/s]

Layer 3 -> 3


  8%|▊         | 4/48 [00:01<00:16,  2.71it/s]

Layer 4 -> 4


 10%|█         | 5/48 [00:02<00:18,  2.29it/s]

Layer 5 -> 5


 12%|█▎        | 6/48 [00:02<00:18,  2.26it/s]

Layer 6 -> 6


 15%|█▍        | 7/48 [00:02<00:16,  2.44it/s]

Layer 7 -> 7


 17%|█▋        | 8/48 [00:03<00:15,  2.59it/s]

Layer 8 -> 8


 19%|█▉        | 9/48 [00:03<00:14,  2.71it/s]

Layer 9 -> 9


 21%|██        | 10/48 [00:03<00:13,  2.80it/s]

Layer 10 -> 10


 23%|██▎       | 11/48 [00:04<00:12,  2.87it/s]

Layer 11 -> 11


 25%|██▌       | 12/48 [00:04<00:12,  2.91it/s]

Layer 12 -> 12


 27%|██▋       | 13/48 [00:04<00:11,  2.95it/s]

Layer 13 -> 13


 29%|██▉       | 14/48 [00:05<00:11,  2.97it/s]

Layer 14 -> 14


 31%|███▏      | 15/48 [00:05<00:11,  2.96it/s]

Layer 15 -> 15


 33%|███▎      | 16/48 [00:05<00:10,  2.93it/s]

Layer 16 -> 16


 35%|███▌      | 17/48 [00:06<00:10,  2.91it/s]

Layer 17 -> 17


 38%|███▊      | 18/48 [00:06<00:10,  2.87it/s]

Layer 18 -> 18


 40%|███▉      | 19/48 [00:06<00:10,  2.85it/s]

Layer 19 -> 19


 42%|████▏     | 20/48 [00:07<00:09,  2.87it/s]

Layer 20 -> 20


 44%|████▍     | 21/48 [00:07<00:09,  2.89it/s]

Layer 21 -> 21


 46%|████▌     | 22/48 [00:07<00:08,  2.92it/s]

Layer 22 -> 22


 48%|████▊     | 23/48 [00:08<00:08,  2.92it/s]

Layer 23 -> 23


 50%|█████     | 24/48 [00:08<00:08,  2.95it/s]

Layer 24 -> 0


 52%|█████▏    | 25/48 [00:08<00:07,  2.98it/s]

Layer 25 -> 1


 54%|█████▍    | 26/48 [00:09<00:07,  3.00it/s]

Layer 26 -> 2


 56%|█████▋    | 27/48 [00:09<00:06,  3.01it/s]

Layer 27 -> 3


 58%|█████▊    | 28/48 [00:09<00:06,  3.02it/s]

Layer 28 -> 4


 60%|██████    | 29/48 [00:10<00:06,  3.03it/s]

Layer 29 -> 5


 62%|██████▎   | 30/48 [00:10<00:05,  3.03it/s]

Layer 30 -> 6


 65%|██████▍   | 31/48 [00:10<00:05,  3.03it/s]

Layer 31 -> 7


 67%|██████▋   | 32/48 [00:11<00:05,  3.04it/s]

Layer 32 -> 8


 69%|██████▉   | 33/48 [00:11<00:04,  3.04it/s]

Layer 33 -> 9


 71%|███████   | 34/48 [00:11<00:04,  3.04it/s]

Layer 34 -> 10


 73%|███████▎  | 35/48 [00:12<00:04,  3.04it/s]

Layer 35 -> 11


 75%|███████▌  | 36/48 [00:12<00:03,  3.04it/s]

Layer 36 -> 12


 77%|███████▋  | 37/48 [00:12<00:03,  3.04it/s]

Layer 37 -> 13


 79%|███████▉  | 38/48 [00:13<00:03,  2.96it/s]

Layer 38 -> 14


 81%|████████▏ | 39/48 [00:13<00:03,  2.79it/s]

Layer 39 -> 15


 83%|████████▎ | 40/48 [00:14<00:02,  2.68it/s]

Layer 40 -> 16


 85%|████████▌ | 41/48 [00:14<00:02,  2.61it/s]

Layer 41 -> 17


 88%|████████▊ | 42/48 [00:14<00:02,  2.68it/s]

Layer 42 -> 18


 90%|████████▉ | 43/48 [00:15<00:01,  2.73it/s]

Layer 43 -> 19


 92%|█████████▏| 44/48 [00:15<00:01,  2.74it/s]

Layer 44 -> 20


 94%|█████████▍| 45/48 [00:15<00:01,  2.77it/s]

Layer 45 -> 21


 96%|█████████▌| 46/48 [00:16<00:00,  2.80it/s]

Layer 46 -> 22


 98%|█████████▊| 47/48 [00:16<00:00,  2.82it/s]

Layer 47 -> 23


100%|██████████| 48/48 [00:16<00:00,  2.84it/s]

    source_layer  target_layer    token
0              0             0         
1              1             1         
2              2             2         
3              3             3        -
4              4             4         
5              5             5         
6              6             6         
7              7             7         
8              8             8         
9              9             9         
10            10            10         
11            11            11         
12            12            12         
13            13            13        -
14            14            14         
15            15            15         
16            16            16         
17            17            17         
18            18            18         
19            19            19         
20            20            20         
21            21            21         
22            22            22         
23            23            23  America



