# Imports

In [1]:
from ast import literal_eval
import functools
import json
import os
import random
import shutil

# Scienfitic packages
import numpy as np
import pandas as pd
import torch
import datasets
torch.set_grad_enabled(False)

# Visuals
from matplotlib import pyplot as plt
import seaborn as sns
sns.set(context="notebook",
        rc={"font.size":16,
            "axes.titlesize":16,
            "axes.labelsize":16,
            "xtick.labelsize": 16.0,
            "ytick.labelsize": 16.0,
            "legend.fontsize": 16.0})
palette_ = sns.color_palette("Set1")
palette = palette_[2:5] + palette_[7:]
sns.set_theme(style='whitegrid')

# Utilities

from general_utils import (
  ModelAndTokenizer,
  make_inputs,
  decode_tokens,
  find_token_range,
  predict_from_input,
)

from patchscopes_utils import *

from tqdm import tqdm
tqdm.pandas()

In [2]:
model_to_hook = {
    "EleutherAI/pythia-12b": set_hs_patch_hooks_neox,
    "meta-llama/Llama-2-13b-hf": set_hs_patch_hooks_llama,
    "lmsys/vicuna-7b-v1.5": set_hs_patch_hooks_llama,
    "./stable-vicuna-13b": set_hs_patch_hooks_llama,
    "CarperAI/stable-vicuna-13b-delta": set_hs_patch_hooks_llama,
    "EleutherAI/gpt-j-6b": set_hs_patch_hooks_gptj,
    "t5-large": set_hs_patch_hooks_t5
}


In [3]:
import os

# Set Hugging Face cache directory
os.environ["HF_HOME"] = "/home/students/kolber/seminars/kolber/.cache"

In [4]:
# Load model

model_name = "t5-large"
sos_tok = False

if "13b" in model_name or "12b" in model_name:
    torch_dtype = torch.float16
else:
    torch_dtype = None

mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=False,
    torch_dtype=torch_dtype,
    device="cpu"
)
mt.set_hs_patch_hooks = model_to_hook[model_name]
mt.model.eval()

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=4096, bias=False)
              (wo): Linear(in_features=4096, out_features=1024, bias=False)
              (d

# Next token prediction

In [5]:
# Evaluate the ID prompt on the validation set of WikiText (with/without mappings)

device = mt.model.device

prompt_target = "cat -> cat\n1135 -> 1135\nhello -> hello\n?"
position_target = -1

records = []
for layer in tqdm(range(mt.num_layers)): 
    prompt_source = "Who is Obama?"
    position_source = 2
    prec_1, surprisal = evaluate_patch_next_token_prediction(
        mt, prompt_source, prompt_target, layer, layer,
        position_source, position_target, position_prediction=position_target
        )

    records.append({'layer': layer, 'prec_1': prec_1, 'surprisal': surprisal})

results = pd.DataFrame.from_records(records)


  2%|▏         | 1/48 [00:00<00:16,  2.87it/s]

{0: [(16, tensor([-31.5204, -24.7124,  -8.5381,  ...,   3.9018,   9.3325,   8.8674]))]}


  4%|▍         | 2/48 [00:00<00:15,  2.98it/s]

{1: [(16, tensor([-34.5127, -27.2876, -14.7393,  ...,   1.2115,   7.2054,  14.3314]))]}


  6%|▋         | 3/48 [00:01<00:14,  3.02it/s]

{2: [(16, tensor([-46.8915, -38.2838, -22.7533,  ...,   7.5503,  -0.1532,  11.1236]))]}
{3: [(16, tensor([-50.6588, -32.5149, -25.5325,  ...,  16.9417,  -4.7222,   4.0549]))]}


  8%|▊         | 4/48 [00:01<00:15,  2.82it/s]

{4: [(16, tensor([-53.2100, -23.9260, -23.2108,  ...,  25.0026,  -8.9490,  -6.5479]))]}


 12%|█▎        | 6/48 [00:02<00:16,  2.59it/s]

{5: [(16, tensor([-65.9908, -32.8220, -19.4189,  ...,  25.7656, -18.4125,   2.1450]))]}


 15%|█▍        | 7/48 [00:02<00:15,  2.69it/s]

{6: [(16, tensor([-78.4446, -39.7274, -23.4656,  ...,  14.3850, -16.6949,   4.3686]))]}


 17%|█▋        | 8/48 [00:02<00:14,  2.68it/s]

{7: [(16, tensor([-84.2563, -38.9248, -24.7010,  ...,  15.4495, -24.3100,  11.0296]))]}


 19%|█▉        | 9/48 [00:03<00:14,  2.72it/s]

{8: [(16, tensor([-87.1161, -47.6444, -12.1023,  ...,   6.8382, -25.1064,   8.5141]))]}


 21%|██        | 10/48 [00:03<00:13,  2.75it/s]

{9: [(16, tensor([-89.1746, -54.4318, -14.1479,  ...,   2.2850, -24.7855,   1.8542]))]}


 23%|██▎       | 11/48 [00:04<00:13,  2.77it/s]

{10: [(16, tensor([-90.8992, -50.9091, -17.0864,  ...,   3.5605,  -6.1756,  -6.5574]))]}


 25%|██▌       | 12/48 [00:04<00:12,  2.79it/s]

{11: [(16, tensor([-96.5470, -51.3193, -22.5191,  ...,  18.8353,  -6.5922,   6.5929]))]}


 27%|██▋       | 13/48 [00:04<00:12,  2.83it/s]

{12: [(16, tensor([-131.2078,  -49.2928,  -24.3215,  ...,    7.8943,  -16.3852,
          -7.8061]))]}


 29%|██▉       | 14/48 [00:05<00:11,  2.87it/s]

{13: [(16, tensor([-137.6351,  -44.4710,  -39.4885,  ...,    7.8692,  -29.5968,
          -5.0435]))]}


 31%|███▏      | 15/48 [00:05<00:11,  2.91it/s]

{14: [(16, tensor([-145.2646,  -41.5266,  -40.0572,  ...,    7.7690,  -63.4637,
          20.1157]))]}


 33%|███▎      | 16/48 [00:05<00:10,  2.94it/s]

{15: [(16, tensor([-192.9487,  -27.7642,  -36.2089,  ...,    2.3843,  -59.5213,
          26.9575]))]}


 35%|███▌      | 17/48 [00:06<00:10,  2.94it/s]

{16: [(16, tensor([-238.9390,  -15.8331,  -34.5819,  ...,   21.6181, -110.2432,
          22.6856]))]}


 38%|███▊      | 18/48 [00:06<00:10,  2.90it/s]

{17: [(16, tensor([-254.4924,  -15.3350,  -35.6611,  ...,   13.9891,  -97.6739,
          39.7461]))]}


 40%|███▉      | 19/48 [00:06<00:09,  2.91it/s]

{18: [(16, tensor([-265.1894,  -39.4576,  -39.2520,  ...,   32.6542, -112.9305,
          26.1969]))]}


 42%|████▏     | 20/48 [00:07<00:09,  2.87it/s]

{19: [(16, tensor([-305.1680,  -20.9698,  -23.1064,  ...,   71.0054, -116.6115,
          26.1271]))]}


 44%|████▍     | 21/48 [00:07<00:09,  2.91it/s]

{20: [(16, tensor([-382.3799,  -24.9211,  -29.5748,  ...,   60.5485,  -84.3818,
          58.1346]))]}


 46%|████▌     | 22/48 [00:07<00:08,  2.94it/s]

{21: [(16, tensor([-389.5952,  -30.1340,  -18.8953,  ...,   56.0900,  -64.6374,
          78.8082]))]}


 48%|████▊     | 23/48 [00:08<00:08,  2.96it/s]

{22: [(16, tensor([-335.1400,  -73.0103,   -9.0969,  ...,   -7.2111,  -40.2585,
         166.1876]))]}


 50%|█████     | 24/48 [00:08<00:08,  2.98it/s]

{23: [(16, tensor([-0.2790, -0.1066, -0.0705,  ..., -0.0529, -0.1366,  0.2305]))]}


 50%|█████     | 24/48 [00:08<00:08,  2.80it/s]


IndexError: tuple index out of range