# Decoding of Next-Token Predictions

This note book is designed to reimplement Pathscope **section 1:Next-Token Prediction**

In [3]:
import torch
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
print(torch.cuda.is_available())  # Should return True
print(torch.cuda.device_count())  # Should return the number of available GPUs
print(torch.cuda.get_device_name(0))  # Should return the name of the GPU


False
0


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [None]:
! pip install datasets

In [None]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM,DataCollatorWithPadding
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.amp import autocast
import os
import math
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# import dataloader


In [None]:
model_dict = ["Qwen/Qwen2-0.5B","Qwen/Qwen2.5-0.5B-Instruct","Qwen/Qwen2-1.5B"]

In [None]:
model_num = 0

tokenizer = AutoTokenizer.from_pretrained(model_dict[model_num])
model = AutoModelForCausalLM.from_pretrained(model_dict[model_num]).to(device)
model.eval()


In [None]:
# model(tokenizer('how are you doing today?'))

In [None]:
print(model.__class__.__module__)
print(model.config.model_type)

In [None]:
type(model)

In [None]:
type(tokenizer)

In [None]:
# from transformers import Qwen2ForCausalLM
# model2 = Qwen2ForCausalLM.from_pretrained('Qwen/Qwen2-0.5B')

In [None]:
# print(model2)

In [None]:
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# define a function to get hidden states of each layer of the model

def get_hidden_state(model, input_ids):
    with torch.no_grad():
        outputs = model(input_ids.to(device), output_hidden_states=True)
    return outputs.hidden_states[1:]  #Size([batch_size, num_layers, sequence_length, hidden_size])




In [None]:
# load the Pile data "https://huggingface.co/datasets/EleutherAI/pile" from hugging face, unfortunately it is not available
# use the following uncoprighted pile dataset instead
# https://huggingface.co/datasets/monology/pile-uncopyrighted

dataPile= load_dataset("monology/pile-uncopyrighted",split='train',streaming=True)


In [None]:
dataPile

In [None]:
# shuffle the dataset before spliting them
# reference resource "https://huggingface.co/docs/datasets/v3.3.2/stream"

dataPile = dataPile.shuffle(seed=50)
training_data = dataPile.take(10000)
# train : val = 5 : 1
validation_data = dataPile.skip(10000).take(2000)

In [None]:
list(validation_data)[:4]

In [None]:
text = "1. Field of the Invention\nThe present invention relates to toothbrushes and, in particular, to a toothbrush having a hollow handle defining a paste-holding cavity wherein toothpaste is forcibly dispensed therefrom"

In [None]:
tokenizer([text,'hee asdf po afsdf '], return_tensors="pt",max_length=21, truncation=True, padding="max_length")

In [None]:

import random

def preprocess(example):

    text = example['text']
    inputs = tokenizer(text, return_tensors='pt', max_length=50, truncation=True) # dict_keys(['input_ids', 'attention_mask'])
    inputs_ids = inputs['input_ids'][0] # becasue the batch size is 1

    if len(inputs_ids) > 5:
        trim_len = random.randint(5,  len(inputs_ids) -1)   # to introduce randomness in the training data
        inputs_ids = inputs_ids[:trim_len]

    return {'input_ids': inputs_ids, 'attention_mask': inputs['attention_mask'][0][:len(inputs_ids)]}

In [None]:
training_data = training_data.map(preprocess,remove_columns=['text','meta'])
validation_data = validation_data.map(preprocess,remove_columns=['text','meta'])

In [None]:
next(iter(training_data))

In [None]:
# design a collate function

In [None]:
# we can change the attentino mask to focus on the token of interest

In [None]:

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# PyTorch
torch.cuda.memory_summary()

# Or more detailed
torch.cuda.memory_allocated()
torch.cuda.memory_reserved()


## 1. Logit lens

In [26]:
# each_layer_logit_lens = {l : [] for l in range(len(24))}
from collections import defaultdict
from tqdm import tqdm

each_layer_logit_lens_Precision = defaultdict(list) # key = int:layer , value = list[]
each_layer_logit_lens_Surprisal = defaultdict(list)

def logit_len_for_test(hidden_states,refernce_logits):
    pass



def logit_len(hidden_states,mask,refernce_logits):
    """

    hidden_states: Tuple[Size([batch_size, sequence_length, hidden_size])]: the hidden states of each layer of the model

    mask: Size([batch_size, sequence_length])

    refernce_logits: Size([batch_size, sequence_length, vocab_size]): this will be converted to Size([batch_size, vocab_size])

    """
    W_U = model.lm_head.weight

    # using attention mask to get the last token of the sequence

    num_layer  = len(hidden_states)
    batch_size, sequence_length, hidden_size = hidden_states[0].shape

    last_token_pos = torch.sum(mask,dim=1) # Size([batch_size])

    reference_logits = refernce_logits[torch.arange(len(mask)),last_token_pos -1 ,:] # Size([batch_size, vocab_size])
    reference_logits_max_token = torch.argmax(reference_logits,dim = -1)

    def suprisla(reference_logist, logist):
        predicted_tokens = torch.argmax(logits, dim=-1)  # Shape: (batch_size,)
        reference_probs = torch.nn.functional.softmax(reference_logits, dim=-1)
        predicted_token_probs = reference_probs[torch.arange(batch_size), predicted_tokens]
        surprisal = -torch.log(predicted_token_probs + 1e-9)

        return torch.sum(surprisal)



    for idx, each_layer in enumerate(hidden_states):
        layer_hidden = each_layer[torch.arange(batch_size),last_token_pos-1] # Size([batch_size, hidden_size])
        # layer_hidden = layer_hidden[torch.arange(batch_size),:] # Size([batch_size, hidden_size])
        logits = torch.einsum('bd,vd->bv', layer_hidden, W_U)  # Size([batch_size, vocab_size])
        logits_max_token = torch.argmax(logits,dim = -1)

        identical_samples = torch.sum(logits_max_token == reference_logits_max_token)
        each_layer_logit_lens_Precision[idx].append(identical_samples)
        each_layer_logit_lens_Surprisal[idx].append(suprisla(reference_logits,logits))
        # print(identical_samples)


    return

batch_size = 32

i = 0
for batch in tqdm(DataLoader(validation_data, batch_size=batch_size, collate_fn=data_collator),total = math.ceil(2000/batch_size),desc = f"logit_lens_exp"):
    # print(batch)k
    batch.to(device)
    with torch.no_grad():
        model_outputs = model(**batch,output_hidden_states=True)
        hidden_states = model_outputs.hidden_states[1:]  # exclude the embedding layer
        reference_logits = model_outputs.logits
        mask = batch['attention_mask']

    # reference_logits = reference_logits[torch.arange(len(mask)),torch.sum(mask,dim=1)] # Size([batch_size, vocab_size])
    Compare_logit = logit_len(hidden_states,mask,reference_logits)
    # i += 1
    # print('one batch done')

    # if i >= 2:
    #     break

logit_lens_exp: 100%|██████████| 63/63 [00:52<00:00,  1.20it/s]


In [40]:
# Clear GPU cache
# torch.cuda.empty_cache()

In [42]:
# print(# PyTorch
# torch.cuda.memory_summary())

# # Or more detailed
# print(torch.cuda.memory_allocated())
# print(torch.cuda.memory_reserved())

#### Metric
1. Precision@1
2. Suprisal

In [44]:
def precision(a):
    """
    a : dict[layer] =  list[]
    """

    for k, v in a.items():
        a[k] = torch.stack(v,dim = 0).sum() / 2000

    return a


def surprisal(a):

    for k,v in a.items():
        a[k] = torch.stack(v,dim = 0).sum()/2000

    return a

logit_len_p = precision(each_layer_logit_lens_Precision)
logit_len_s = surprisal(each_layer_logit_lens_Surprisal)

In [51]:
torch.save(logit_len_p,f"/content/drive/MyDrive/Patchscope-Reimplementation-main/Decoding_of_next_token_Prediction/logit_len_P_{model_dict[model_num].split('/')[-1]}.pt")
torch.save(logit_len_s,f"/content/drive/MyDrive/Patchscope-Reimplementation-main/Decoding_of_next_token_Prediction/logit_len_S_{model_dict[model_num].split('/')[-1]}.pt")

In [48]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [53]:
logit_len_p

defaultdict(list,
            {0: tensor(0.0040, device='cuda:0'),
             1: tensor(0.0015, device='cuda:0'),
             2: tensor(0.0005, device='cuda:0'),
             3: tensor(0.0010, device='cuda:0'),
             4: tensor(0.0035, device='cuda:0'),
             5: tensor(0.0015, device='cuda:0'),
             6: tensor(0.0020, device='cuda:0'),
             7: tensor(0.0025, device='cuda:0'),
             8: tensor(0.0045, device='cuda:0'),
             9: tensor(0.0045, device='cuda:0'),
             10: tensor(0.0040, device='cuda:0'),
             11: tensor(0.0030, device='cuda:0'),
             12: tensor(0.0090, device='cuda:0'),
             13: tensor(0.0120, device='cuda:0'),
             14: tensor(0.0130, device='cuda:0'),
             15: tensor(0.0170, device='cuda:0'),
             16: tensor(0.0620, device='cuda:0'),
             17: tensor(0.0650, device='cuda:0'),
             18: tensor(0.0630, device='cuda:0'),
             19: tensor(0.1055, device='cu

In [54]:
print(model_outputs['logits'].shape)
print(model_outputs['hidden_states'][4].shape)

torch.Size([16, 48, 151936])
torch.Size([16, 48, 896])


In [55]:
len(model_outputs['hidden_states'])

25

## 2. Tuned Lens

In [56]:
from sklearn.linear_model import LinearRegression

import numpy as np


config = model.config
num_layers = config.num_hidden_layers

X_train = {l : [] for l in range(num_layers)}
Y_train = []  # finaly layer logits



batch_size = 32

def get_last_token_hidden_states(hidden_states,mask) -> None:
    """
    hiddenstates: Tuple[Size([batch_size, sequence_length, hidden_size])]: the hidden states of each layer of the model
    """

    last_token_pos = torch.sum(mask,dim=1) # Size([batch_size])

    for i, layer in enumerate(hidden_states):
        X_train[i].append(layer[torch.arange(len(mask)),last_token_pos -1 ,:].cpu().numpy())   # Size([batch_size, hidden_size])

    Y_train.append(hidden_states[-1][torch.arange(len(mask)),last_token_pos -1 ,:].cpu().numpy())   # Size([batch_size, hidden_size]))

# to get the hidden states of each sequence
for batch in tqdm(DataLoader(training_data, batch_size=batch_size, collate_fn=data_collator),total = math.ceil(10000 / batch_size) ,desc = "affine map training data"):
    batch.to(device)
    with torch.no_grad():
        model_outputs = model(**batch,output_hidden_states=True)
    # batch.to('cpu')
    hidden_states = model_outputs.hidden_states[1:] # exclude the embedding layer
    mask = batch['attention_mask'].cpu()
    reference_logits = model_outputs.logits

    get_last_token_hidden_states(hidden_states,mask)

#     del batch, model_outputs
#     torch.cuda.empty_cache()  # Clear GPU cache



affine map training data: 100%|██████████| 313/313 [03:39<00:00,  1.43it/s]


In [57]:
X_train[0][-1].shape

(16, 896)

In [58]:
len(Y_train)

313

In [59]:
# for each layer,train a affing transformation to predict the final layer logits

Y_train = np.concatenate(Y_train,axis=0) # Size([num_samples, hidden_size])

affine_map_house = []

for i in range(num_layers):
    affine_map = LinearRegression(fit_intercept=True)
    X_train[i] = np.concatenate(X_train[i],axis=0) # Size([num_samples, hidden_size])
    affine_map.fit(X_train[i],Y_train)
    affine_map_house.append(affine_map)
    # break



In [60]:

# do the same thing as logit_len but with the affine transformation bewteen the hidden states of each layer and the final layer logits
each_layer_tuned_lens_Precision = defaultdict(list)
each_layer_tuned_lens_Surprisal = defaultdict(list)

def tuned_len(hidden_states,mask,refernce_logits,model = model):
    """

    hidden_states: Tuple[Size([batch_size, sequence_length, hidden_size])]: the hidden states of each layer of the model

    mask: Size([batch_size, sequence_length])

    refernce_logits: Size([batch_size, sequence_length, vocab_size]): this will be converted to Size([batch_size, vocab_size])

    """
    W_U = model.lm_head.weight

    # using attention mask to get the last token of the sequence

    num_layer  = len(hidden_states)
    batch_size, sequence_length, hidden_size = hidden_states[0].shape

    last_token_pos = torch.sum(mask,dim=1) # Size([batch_size])

    reference_logits = refernce_logits[torch.arange(len(mask)),last_token_pos -1 ,:] # Size([batch_size, vocab_size])
    reference_logits_max_token = torch.argmax(reference_logits,dim = -1)

    def suprisla(reference_logist, logist):
        predicted_tokens = torch.argmax(logits, dim=-1)  # Shape: (batch_size,)
        reference_probs = torch.nn.functional.softmax(reference_logits, dim=-1)
        predicted_token_probs = reference_probs[torch.arange(batch_size), predicted_tokens]
        surprisal = -torch.log(predicted_token_probs + 1e-9)

        return torch.sum(surprisal)


    for i , each_layer in enumerate(hidden_states):
        layer_hidden = each_layer[torch.arange(batch_size),last_token_pos-1] # Size([batch_size, hidden_size]

        #***************************************************************
        # the only changes compared to logit_len
        affine_map = affine_map_house[i]   #这里竟然没有报错，当我只有第一个layer的affine map的时候
        layer_hidden_mapped = affine_map.predict(layer_hidden.cpu().numpy())
        layer_hidden_mapped = torch.tensor(layer_hidden_mapped).to(device)



        #***************************************************************
        logits = torch.einsum('bd,vd->bv', layer_hidden_mapped, W_U)  # Size([batch_size, vocab_size])
        logits_max_token = torch.argmax(logits,dim = -1)

        identical_samples = torch.sum(logits_max_token == reference_logits_max_token)
        each_layer_tuned_lens_Precision[i].append(identical_samples)
        each_layer_tuned_lens_Surprisal[i].append(suprisla(reference_logits,logits))
        # print(identical_samples)
        # print(torch.sum(logits_max_token == reference_logits_max_token))


batch_size = 32

i = 0
for batch in tqdm(DataLoader(validation_data, batch_size=batch_size, collate_fn=data_collator),total = math.ceil(2000/batch_size),desc = "tuned_len_exp"):
    # print(batch)k
    batch.to(device)
    with torch.no_grad():
        model_outputs = model(**batch,output_hidden_states=True)
        hidden_states = model_outputs.hidden_states[1:]  # exclude the embedding layer
        reference_logits = model_outputs.logits
        mask = batch['attention_mask']

    # reference_logits = reference_logits[torch.arange(len(mask)),torch.sum(mask,dim=1)] # Size([batch_size, vocab_size])
    Compare_logit = tuned_len(hidden_states,mask,reference_logits)
    i += 1
    # print('one batch done')

    # if i >= 5:
    #     break

tuned_len_exp: 100%|██████████| 63/63 [00:59<00:00,  1.05it/s]


In [61]:
tuned_len_p = precision(each_layer_tuned_lens_Precision)
tuned_len_p = surprisal(each_layer_tuned_lens_Surprisal)

In [66]:
torch.save(tuned_len_p,f"/content/drive/MyDrive/Patchscope-Reimplementation-main/Decoding_of_next_token_Prediction/tuned_len_P_{model_dict[model_num].split('/')[-1]}.pt")
torch.save(tuned_len_p,f"/content/drive/MyDrive/Patchscope-Reimplementation-main/Decoding_of_next_token_Prediction/tuned_len_S_{model_dict[model_num].split('/')[-1]}.pt")

In [62]:
tuned_len_p

defaultdict(list,
            {0: tensor(0.1845, device='cuda:0'),
             1: tensor(0.2020, device='cuda:0'),
             2: tensor(0.2045, device='cuda:0'),
             3: tensor(0.2205, device='cuda:0'),
             4: tensor(0.2345, device='cuda:0'),
             5: tensor(0.2345, device='cuda:0'),
             6: tensor(0.2325, device='cuda:0'),
             7: tensor(0.2395, device='cuda:0'),
             8: tensor(0.2420, device='cuda:0'),
             9: tensor(0.2450, device='cuda:0'),
             10: tensor(0.2445, device='cuda:0'),
             11: tensor(0.2585, device='cuda:0'),
             12: tensor(0.2690, device='cuda:0'),
             13: tensor(0.2875, device='cuda:0'),
             14: tensor(0.3100, device='cuda:0'),
             15: tensor(0.3270, device='cuda:0'),
             16: tensor(0.4000, device='cuda:0'),
             17: tensor(0.4505, device='cuda:0'),
             18: tensor(0.4755, device='cuda:0'),
             19: tensor(0.5220, device='cu

In [63]:
tuned_len_s

defaultdict(list,
            {0: tensor(5.5387, device='cuda:0'),
             1: tensor(5.2503, device='cuda:0'),
             2: tensor(5.1506, device='cuda:0'),
             3: tensor(4.9346, device='cuda:0'),
             4: tensor(4.7680, device='cuda:0'),
             5: tensor(4.7396, device='cuda:0'),
             6: tensor(4.7775, device='cuda:0'),
             7: tensor(4.7151, device='cuda:0'),
             8: tensor(4.6886, device='cuda:0'),
             9: tensor(4.6371, device='cuda:0'),
             10: tensor(4.6395, device='cuda:0'),
             11: tensor(4.4843, device='cuda:0'),
             12: tensor(4.3771, device='cuda:0'),
             13: tensor(4.2798, device='cuda:0'),
             14: tensor(4.1091, device='cuda:0'),
             15: tensor(3.9514, device='cuda:0'),
             16: tensor(3.4422, device='cuda:0'),
             17: tensor(3.0373, device='cuda:0'),
             18: tensor(2.8325, device='cuda:0'),
             19: tensor(2.5387, device='cu

In [64]:
X_train[0].shape

(10000, 896)

In [65]:
Y_train.shape

(10000, 896)

we can conclude that tuned lens really outputperform vallina logit lean

## 3.Token Identitiy Patchscope

set the hidden state at the specic layer and position to the last

### A. Construct **Few-Shot Token Identity Prompts**

In [67]:
# sample a random set of k tokens for all the models, where k was also randomly sampled from the interval [1, . . . , 10].

# to construct 5 realization of token iDs series of variable length

def craft_realization() -> torch.Tensor:
    """
    return: Size([num_realization, num_tokens])
    """

    identity_prompts = []

    num_realization = 5

    for j in range(num_realization):
        num_tokens = random.sample(range(tokenizer.vocab_size),random.randint(1,10))
        prompt = ';'.join([ f"{tokenizer.decode([i])} -> {tokenizer.decode([i])}"  for i in num_tokens]) + '; ?'
        identity_prompts.append(prompt)


    return identity_prompts

identity_token_prompts = craft_realization()




In [68]:
print(identity_token_prompts)

['=( -> =(;_invoice -> _invoice; Record ->  Record; ?', '-bal -> -bal; booths ->  booths; ?', ' pomi ->  pomi; Кон ->  Кон; Xin ->  Xin;أكثر -> أكثر;StackNavigator -> StackNavigator;bucks -> bucks; Nov ->  Nov; ?', '럭 -> 럭;upgrade -> upgrade;.STRING -> .STRING; אין ->  אין; takes ->  takes; ?', ' Kund ->  Kund; relaciones ->  relaciones; dataGridViewTextBoxColumn ->  dataGridViewTextBoxColumn; defiant ->  defiant; ?']


In [69]:
# tokenizer.tokenize(identity_token_prompts[0])

In [70]:
model.model.layers

ModuleList(
  (0-23): 24 x Qwen2DecoderLayer(
    (self_attn): Qwen2Attention(
      (q_proj): Linear(in_features=896, out_features=896, bias=True)
      (k_proj): Linear(in_features=896, out_features=128, bias=True)
      (v_proj): Linear(in_features=896, out_features=128, bias=True)
      (o_proj): Linear(in_features=896, out_features=896, bias=False)
    )
    (mlp): Qwen2MLP(
      (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
      (up_proj): Linear(in_features=896, out_features=4864, bias=False)
      (down_proj): Linear(in_features=4864, out_features=896, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
    (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
  )
)

In [71]:
model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

In [72]:
debug_idx = 23
def patch_hidden_state(source_hidden_states : tuple,batch_length:torch.Tensor, model=model,target_prompt = None) -> None:
    """
    source_hidden_states: Tuple[Tensor:Size([batch_size, sequence_length, hidden_size])]: the hidden states of each layer of the model
    batch_length: Size([batch_size])

    target_prompt: str: the prompt the patched hidden states should be conditioned on

    """
    patched_results = {l : [] for l in range(len(source_hidden_states))}


    for layer_index, layer in enumerate(source_hidden_states):

        layer = layer[torch.arange(len(batch_length)),batch_length -1 ,:] # Size([batch_size, hidden_size])


        for i in range(layer.shape[0]):

            hs_to_be_patched = layer[i]

            # set a farward hook function
            def patch_hidden_hook(module, input, output):
                """
                input: Tuple[Tensor:Size([batch_size, sequence_length, hidden_size])]
                output: Tuple[Size([batch_size, sequence_length, hidden_size])]: the hidden states of each layer of the model

                """
                # print(f' ')
                output[0][0,-1] = hs_to_be_patched

                if layer_index == debug_idx:  # for debugging reason
                    print(output[0][0,-1][:6])


            # set a hook to the target layer
            hook_handle = model.model.layers[layer_index].register_forward_hook(patch_hidden_hook)

            with torch.no_grad():
                output_Target = model(**tokenizer(target_prompt,return_tensors='pt').to(device),output_hidden_states=True)
                target_logits = output_Target.logits[0][-1]
                patched_results[layer_index].append(target_logits)

            hook_handle.remove()



            if layer_index == debug_idx:  # for debugging reason

                print(output_Target.hidden_states[debug_idx + 1 ][0][-1][:6])


    for k,v in patched_results.items():
        patched_results[k] = torch.stack(v)

    return patched_results



In [73]:
tokenizer(identity_token_prompts[0],return_tensors='pt',padding=True)['input_ids'].shape


torch.Size([1, 16])

In [81]:
def cal_precision_1(source, target,sources_length):
    """
    source: Size([batch_size, sequence_len, vocab_size])
    target: Dict(K: num_layer, V: Size([batch_size, vocab_size]))

    """

    source = source[torch.arange(len(sources_length)),sources_length -1 ,:] # Size([batch_size, vocab_size])

    for i in range(len(target)):
        print(torch.sum(torch.argmax(source,dim=-1) == torch.argmax(target[i],dim=-1),dim = 0))

def cal_surpisal(source, target,sources_length):
    """
    source: Size([batch_size, sequence_len, vocab_size])
    target: Dict(K: num_layer, V: Size([batch_size, vocab_size]))

    """
    reference_probs = torch.nn.functional.softmax(source, dim=-1)

    for i in range(len(target)):
      predicted_tokens = torch.argmax(target[i], dim=-1)  # Shape: (batch_size,)

      predicted_token_probs = reference_probs[torch.arange(len(sources_length)), predicted_tokens]
      surprisal = -torch.log(predicted_token_probs + 1e-9)
      print(torch.sum(surprisal))


In [83]:
for idx, prompt in enumerate(identity_token_prompts):

    # get the hidden satate of the prompt at the last positon
    identity_token_P = [defaultdict(list)]
    identity_token_S = [defaultdict(list)]

    batch_size = 4
    for batch in DataLoader(validation_data, batch_size=batch_size ,collate_fn=data_collator): #,total = (2000/ batch_size), desc = "patchscope_exp"):
        # print(batch)
        batch.to(device)

        # run the first forward pass on the source prompt
        with torch.no_grad():
            # with autocast('cuda'):
            # print("hello 123")
            model_outputs = model(**batch,output_hidden_states=True)
            # print("heelo ")
            hidden_states = model_outputs.hidden_states[1:]
            reference_logits = model_outputs.logits # Size([batch_size, sequence_length, vocab_size])
            batch_length = torch.sum(batch['attention_mask'],dim=1)

        target_logtis = patch_hidden_state(hidden_states,batch_length,target_prompt=prompt)

        # print(hidden_states[0][:,batch_length,:][:6])
        cal_precision_1(reference_logits,target_logtis,batch_length)
        cal_surpisal(reference_logits,target_logtis,batch_length)

        # break


        break
    break

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# to check at the last layer if the logits are the same as the target logits