In [1]:
import pandas as pd
import torch

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from utils import transfer_context_prompt

import warnings
from transformers import logging as transformers_logging

warnings.filterwarnings('ignore')
transformers_logging.set_verbosity_error()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_id = "/share/nlp/chitchat/models/Llama-3.2-3B-Instruct/"

In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, 
).to(device)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.80s/it]


In [18]:
question = 'Who won season 10 of dancing with the stars?'
context = 'The story, illustrated by the author, is set in England as the Black Death (bubonic plague) is sweeping across the country. Young Robin is sent away to become a knight like his father, but his dreams are endangered when he loses the use of his legs. A doctor reassures Robin that the weakness in his legs is not caused by the plague and the doctor is supposed to come and help him but does not. His parents are away, serving the king and queen during war, and the servants abandon the house, fearing the plague. Robin is saved by Brother Luke, a friar, who finds him and takes him to a monastery and cares for him.'

messages = [
    {"role": "user", "content": transfer_context_prompt(question, context)}, 
    {"role": "system", "content": "You are a helpful assistant who gives only factually correct short answers without additional information"}
]

In [19]:
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

In [20]:
text

'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 07 Nov 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnswer the question Who won season 10 of dancing with the stars? based on the given context The story, illustrated by the author, is set in England as the Black Death (bubonic plague) is sweeping across the country. Young Robin is sent away to become a knight like his father, but his dreams are endangered when he loses the use of his legs. A doctor reassures Robin that the weakness in his legs is not caused by the plague and the doctor is supposed to come and help him but does not. His parents are away, serving the king and queen during war, and the servants abandon the house, fearing the plague. Robin is saved by Brother Luke, a friar, who finds him and takes him to a monastery and cares for him.<|eot_id|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant who gives only factually c

In [21]:
tokenizer.decode(
    model.generate(**tokenizer(text, return_tensors='pt').to(device), 
                   max_new_tokens=100)[0], 
    skip_special_tokens=True
)

'system\n\nCutting Knowledge Date: December 2023\nToday Date: 07 Nov 2024\n\nuser\n\nAnswer the question Who won season 10 of dancing with the stars? based on the given context The story, illustrated by the author, is set in England as the Black Death (bubonic plague) is sweeping across the country. Young Robin is sent away to become a knight like his father, but his dreams are endangered when he loses the use of his legs. A doctor reassures Robin that the weakness in his legs is not caused by the plague and the doctor is supposed to come and help him but does not. His parents are away, serving the king and queen during war, and the servants abandon the house, fearing the plague. Robin is saved by Brother Luke, a friar, who finds him and takes him to a monastery and cares for him.system\n\nYou are a helpful assistant who gives only factually correct short answers without additional informationassistant\n\nThere is no information about the TV show "Dancing with the Stars" in the given c

---

In [22]:
def get_next_token(activations, model, tokenizer):
    logit_lens = {}
    for name, activation in activations.items():
        lm_head_out = model.lm_head(model.model.norm(activation))
        next_logits = lm_head_out[:, -1, :]
        logit_lens[name] = tokenizer.decode(torch.argmax(next_logits).item(), skip_special_tokens=True)
    return logit_lens

In [23]:
def get_activation(input_text, model, tokenizer, max_new_tokens=1):
    device = model.device
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    input_ids, attention_mask = inputs["input_ids"].to(device), inputs["attention_mask"].to(device)
    
    def getActivation(name):
        # the hook signature
        def hook(model, input, output):
            activations[name + "_output"] = output[0].detach()

        return hook
    
    h = [
        model.model.layers[i].register_forward_hook(getActivation(f"layers_{i}")) 
        for i in range(model.config.num_hidden_layers)
    ]
    
    logit_lens = {}
    for i_token in range(max_new_tokens):
        activations = {}
        sampling_params = GenerationConfig(
            temperature=0.6,
            top_k=50,
            top_p=0.9,
            do_sample=False,
            num_beams=1,
            presence_penalty=0.0,
            repetition_penalty=1.0,
            generate_until=(),
            allow_newlines=True,
        )
        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, generation_config=sampling_params)
        logits, past_key_values = out.values()
        next_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_logits)
        input_ids = torch.cat((input_ids, torch.tensor([[next_token_id]], device=device)), dim=1)
        attention_mask = torch.cat((attention_mask, torch.tensor([[1]], device=device)), dim=1)
        logit_lens[i_token] = get_next_token(activations, model, tokenizer)
    
    [hook.remove() for hook in h]
    return logit_lens

In [24]:
pd.set_option('display.max_columns', None)

In [25]:
toks_per_layer = get_activation(text, model, tokenizer, 16)

In [26]:
pd.DataFrame(toks_per_layer)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
layers_0_output,\n\n,I,don,'t,have,information,about,Season,,10,of,Dancing,with,the,Stars,.
layers_1_output,\n\n,I,don,'t,have,information,about,Season,,10,of,Dancing,with,the,Stars,.
layers_2_output,idges,I,don,'t,have,information,about,Season,racks,10,of,Dancing,WITH,gusto,Stars,weighed
layers_3_output,enic,I,don,'t,ares,information,about,Season,Season,10,Of,Dancing,-with,Enabled,Stars,gaz
layers_4_output,ーン,I,究,'t,alone,information,about,season,ago,10,iang,Dancing,-with,Swan,Stars,UNC
layers_5_output,olor,/we,ераль,necessarily,عات,information,about,season,(Program,10,afa,Dancing,placer,antium,star,nor
layers_6_output,rough,'d,fairy,currently,slightest,information,co,season,nell,Ranch,-round,Dancing,Congress,well,rim,rn
layers_7_output,tap,'ve,caffeine,Forbidden,yet,information,HostException,season,otify,unconventional,this,Dancing,gence,specificity,rim,otr
layers_8_output,Din,outr,REATE,upo,contained,conosc,-how,season,的心,Feinstein,osl,Dancing,ーツ,gang,ssc,hereby
layers_9_output,Courtesy,hereby,inne,currently,anymore,ledged,centralized,ização,uada,aghan,chwitz,Himal,bash,Dirty,monumental,중에


In [32]:
eval_dataset = pd.read_csv('data/rag_routing_eval_dataset.csv')
perturbation_context = pd.read_csv('data/rag_routing_eval_dataset_context_perturbations.csv')

In [33]:
for index, row in tqdm(eval_dataset.iterrows()):
    text_without_context = tokenizer.apply_chat_template([
        {"role": "user", "content": row['question']}, 
        {"role": "system", "content": "You are a helpful assistant who gives only factually correct short answers without additional information"}
    ], 
        tokenize=False, 
        add_generation_prompt=True
    )
    text_with_context = tokenizer.apply_chat_template([
        {"role": "user", "content": transfer_context_prompt(row['question'], row['context'])}, 
        {"role": "system", "content": "You are a helpful assistant who gives only factually correct short answers without additional information"}
    ], 
        tokenize=False, 
        add_generation_prompt=True
    )
    toks_per_layer_without_context = get_activation(text_without_context, model, tokenizer, 16)
    pd.DataFrame(toks_per_layer_without_context).to_csv(f'data/tokens_per_layers/normal_context/{index}_without_context.csv')
    toks_per_layer_with_context = get_activation(text_with_context, model, tokenizer, 16)
    pd.DataFrame(toks_per_layer_with_context).to_csv(f'data/tokens_per_layers/normal_context/{index}_with_context.csv')

20it [00:49,  2.47s/it]


In [35]:
for index, row in tqdm(perturbation_context.iterrows()):
    text_without_context = tokenizer.apply_chat_template([
        {"role": "user", "content": row['question']}, 
        {"role": "system", "content": "You are a helpful assistant who gives only factually correct short answers without additional information"}
    ], 
        tokenize=False, 
        add_generation_prompt=True
    )
    text_with_context = tokenizer.apply_chat_template([
        {"role": "user", "content": transfer_context_prompt(row['question'], row['context_perturbations'])}, 
        {"role": "system", "content": "You are a helpful assistant who gives only factually correct short answers without additional information"}
    ], 
        tokenize=False, 
        add_generation_prompt=True
    )
    toks_per_layer_without_context = get_activation(text_without_context, model, tokenizer, 16)
    pd.DataFrame(toks_per_layer_without_context).to_csv(f'data/tokens_per_layers/bad_context/{index}_without_context.csv')
    toks_per_layer_with_context = get_activation(text_with_context, model, tokenizer, 16)
    pd.DataFrame(toks_per_layer_with_context).to_csv(f'data/tokens_per_layers/bad_context/{index}_with_context.csv')

7it [00:16,  2.29s/it]


In [27]:
question = 'What kind of sports did Mehdi Tahiri do?'
context = 'Mehdi Tahiri (born 28 July 1977) is a retired Moroccan tennis player. Tahiri represented Morocco in the Davis Cup in several years from 1993 to 2006.'

messages = [
    {"role": "user", "content": transfer_context_prompt(question, context)}, 
    {"role": "system", "content": "You are a helpful assistant who gives only factually correct short answers without additional information"}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def track_token_probs(activations, model, tokenizer, target_token_ids):
    token_probs = {}
    for name, activation in activations.items():
        lm_head_out = model.lm_head(model.model.norm(activation))
        logits = lm_head_out[:, -1, :]
        
        token_logits = logits[:, target_token_ids]
        token_probs[name] = torch.softmax(token_logits, dim=-1).squeeze().tolist()
        
    return token_probs

def get_activation(input_text, model, tokenizer, top_k=5):
    device = model.device
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    input_ids, attention_mask = inputs["input_ids"].to(device), inputs["attention_mask"].to(device)
    
    activations = {}
    
    def getActivation(name, activations):
        def hook(model, input, output):
            activations[name + "_output"] = output[0].detach()
        return hook
    
    h = [
        model.model.layers[i].register_forward_hook(getActivation(f"layers_{i}", activations)) 
        for i in range(model.config.num_hidden_layers)
    ]
    
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    
    logits = out.logits
    last_layer_logits = logits[:, -1, :]
    
    top_k_probs, top_k_indices = torch.topk(last_layer_logits, k=top_k, dim=-1)
    top_tokens = tokenizer.batch_decode(top_k_indices.squeeze().tolist(), skip_special_tokens=True)
    
    target_token_ids = top_k_indices.squeeze().tolist()
    
    with torch.no_grad():
        model(input_ids=input_ids, attention_mask=attention_mask)
    
    token_probs_by_layer = track_token_probs(activations, model, tokenizer, target_token_ids)
    
    [hook.remove() for hook in h]
    
    return {"top_tokens": top_tokens, "probabilities_by_layer": token_probs_by_layer}

toks_per_layer = get_activation(text, model, tokenizer, top_k=5)

In [28]:
toks_per_layer

{'top_tokens': ['T', 'He', 'M', 'Mor', 'ten'],
 'probabilities_by_layer': {'layers_0_output': [0.1431712508201599,
   0.14148534834384918,
   0.34414049983024597,
   0.20533457398414612,
   0.16586826741695404],
  'layers_1_output': [0.16472505033016205,
   0.09690017998218536,
   0.41285625100135803,
   0.1934746950864792,
   0.1320437639951706],
  'layers_2_output': [0.12963250279426575,
   0.03474211320281029,
   0.5197805166244507,
   0.2575804889202118,
   0.05826433748006821],
  'layers_3_output': [0.03574942797422409,
   0.06761986762285233,
   0.6824180483818054,
   0.17775548994541168,
   0.036457162350416183],
  'layers_4_output': [0.14672821760177612,
   0.41593456268310547,
   0.12328041344881058,
   0.2801649570465088,
   0.03389190882444382],
  'layers_5_output': [0.4567137658596039,
   0.037576451897621155,
   0.2526184618473053,
   0.23385019600391388,
   0.019241120666265488],
  'layers_6_output': [0.24742604792118073,
   0.016151955351233482,
   0.15686658024787903,
 