In [1]:
import os
import json 
import logging

logging.basicConfig(
    filename='log/app.log',            # Specify the log file name
    level=logging.DEBUG,           # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format='%(asctime)s - %(levelname)s - %(message)s'  # Set the log format
)

# Load the environment configuration JSON data
json_path = 'env_config.json'
with open(json_path, 'r') as file:
    env_config = json.load(file)

hf_home = env_config['HF_HOME']
# Set the HF_HOME environment variable
os.environ['HF_HOME'] = hf_home
# Set the access token to huggingface hub
access_token = env_config['access_token']
os.environ['HUGGINGFACE_HUB_TOKEN'] = access_token



In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, LlamaForTokenClassification #, LlamaRotaryEmbedding


model_id = "meta-llama/Llama-3.2-1B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # device_map=device,
    token=access_token,
)

In [3]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)


tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [20]:
import transformers 
print(transformers.__version__)

from transformers import pipeline
import torch

from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, LlamaForTokenClassification #, LlamaRotaryEmbedding
from transformers import LlamaTokenizerFast
import torch.nn.functional as F


accelerator = Accelerator()
device = accelerator.device
# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_id = "meta-llama/Meta-Llama-3-8B"  # non-instruct version

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # device_map=device,
    token=access_token,
)

config = model.config

# model_2 = LlamaForTokenClassification.from_pretrained(
#     model_id,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     # device_map=device,
#     token=access_token,
# )

tokenizer = LlamaTokenizerFast.from_pretrained(model_id, token=access_token)
tokenizer.pad_token = tokenizer.eos_token

# rotary_emb = LlamaRotaryEmbedding(config=config)

4.44.0


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'PreTrainedTokenizerFast'. 
The class this function is called from is 'LlamaTokenizerFast'.


In [21]:
class LlmExpHelper:
    def __init__(self, tokenizer, dataset):
        self.tokenizer = tokenizer
        self.dataset = dataset
    
    def get_collate_fun(self):
        return lambda examples: self.collate_fn(examples)

    def collate_fn(self, examples):
        def num_words(x):
            return len(x.split())
        def get_first_k_words(x, k):
            return ' '.join(x.split()[:k])
        def get_cliped_text(texts, max_len):
            return [text if num_words(text) <= max_len else get_first_k_words(text, max_len) for text in texts]
        tokenizer = self.tokenizer
        max_len = 512 # characters limit other than token limit
        if self.dataset == 'imdb':
            texts = [example['text'] for example in examples]
            texts = get_cliped_text(texts, max_len)
            sys_context = "You are a chatbot for sentiment analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE."
        elif self.dataset == 'sst2':
            texts = [example['sentence'] for example in examples]
            texts = get_cliped_text(texts, max_len)
            sys_context = "You are a chatbot for sentiment analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE."
        elif self.dataset == 'squad':
            context = [example['context'] for example in examples]
            context = get_cliped_text(context, max_len)
            question = [example['question'] for example in examples]
            # texts = [f"Context: {context[i]}\nQuestion: {question[i]}" for i in range(len(context))]
            texts = [f"Question: {question[i]}\nContext: {context[i]}" for i in range(len(context))]
            sys_context = "You are a chatbot for answering questions. You can help users with their questions via concise responses."

        # labels = [example['label'] for example in examples]
        messages_lambda = lambda texts: [
            {"role": "system", "content": sys_context},
            {"role": "user", "content": texts},
        ]

        messages = list(map(messages_lambda, texts))

        messages_with_template_applied = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        batch = tokenizer(
                    messages_with_template_applied,
                    add_special_tokens=False,
                    padding=True,
                    return_tensors="pt",
                    )
        
        # find the template boundaries
        text_lens = [len(tokenizer.encode(text)) for text in texts]
        text_lens_tensor = torch.tensor(text_lens, dtype=torch.long)

        
        def apply_mask(mask_tensor, text_lens_tensor):
            batch_size, seq_len = mask_tensor.shape
            for i in range(batch_size):
                text_len = text_lens_tensor[i].item()
                mask_tensor[i, -text_len-5:-5] = 0
            return 1- mask_tensor

        mask_tensor = apply_mask(torch.ones_like(batch['input_ids']), text_lens_tensor)

        batch['context_mask'] = mask_tensor

        if self.dataset == 'squad':
            answers_start = [example['answers']['answer_start'][0] for example in examples]
            answers_end = [example['answers']['answer_start'][0] + len(example['answers']['text'][0]) for example in examples]
            batch['answers_start'] = torch.tensor(answers_start).long()
            batch['answers_end'] = torch.tensor(answers_end).long()

            context_lens = [len(tokenizer.encode(context)) for context in context]
            context_lens_tensor = torch.tensor(context_lens, dtype=torch.long)
            mask_tensor_v2 = apply_mask(torch.ones_like(batch['input_ids']), context_lens_tensor)
            batch['context_mask'] = mask_tensor_v2 * batch['attention_mask']
            
        
        return batch

In [24]:
from llmexp.helper import LlmExpHelper
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset("imdb")
# ds = load_dataset("rajpurkar/squad")
train_ds = ds['train']
test_ds = ds['test']
# from datasets import load_dataset

# ds = load_dataset("stanfordnlp/sst2")
# train_ds = ds['train']
llm_exp_helper = LlmExpHelper(tokenizer, 'imdb')
collate_fn = llm_exp_helper.get_collate_fun()

# Define batch size here!
batch_size = 16
train_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
test_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)

In [25]:
# next(iter(train_dataloader))
train_ds[0]

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

In [6]:
# for name, param in mask_gen_model.named_parameters():
#     print(name, param.device)

In [26]:
from llmexp.squad_model_lora import MaskGeneratingModel
from peft import LoraConfig, get_peft_model

# model.to(device)
# emb_weights = model.get_input_embeddings().weight.clone().float().to("cpu")
# mask_gen_model = MaskGeneratingModel(hidden_size=4096, emb_weights=emb_weights)
mask_gen_model = MaskGeneratingModel()
mask_gen_model.to(device)

# target_modules = []
# num_layers = 6  # BERT-base 有 12 层
# for i in range(num_layers):
#     target_modules.extend([
#         f"explain_map.layer.{i}.attention.self.query",
#         f"explain_map.layer.{i}.attention.self.key",
#         f"explain_map.layer.{i}.attention.self.value",
#         f"explain_map.layer.{i}.attention.output.dense",
#         f"explain_map.layer.{i}.intermediate.dense",
#         f"explain_map.layer.{i}.output.dense"
#     ])

# lora_config = LoraConfig(
#     r=4,  # 低秩矩阵的秩
#     lora_alpha=32,  # LoRA 的缩放因子
#     target_modules= target_modules,  # 目标模块
#     lora_dropout=0.1  # Dropout 概率
# )
# mask_gen_model = get_peft_model(mask_gen_model, lora_config)


from tqdm import tqdm

# Set pad_token_id if it is not set
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

pad_token_id = tokenizer.pad_token_id

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

optimizer = torch.optim.Adam(mask_gen_model.parameters(), lr=1e-3)

In [27]:
mask_gen_model.train()
for epoch in range(1):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        context_mask = data['context_mask'].to(device)
        # get generated texts
        gen_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            return_dict_in_generate=True,
            output_scores=True,
        )
        gen_tokens = gen_outputs.sequences
        pad_length = gen_tokens.size(1) - input_ids.size(1)
        # get the attention mask for the generated tokens, and also mask the padding tokens
        gen_attention_mask = F.pad(attention_mask, (0, pad_length), mode='constant', value=1)
        # (gen_tokens != pad_token_id).long() is the tokens mask, 1 for real tokens and 0 for padding tokens
        unpaded_token_mask = (gen_tokens != pad_token_id).long()
        unpaded_token_mask[:, :-pad_length] = 1
        gen_attention_mask = gen_attention_mask * unpaded_token_mask
        # print(gen_tokens[0])
        # print(gen_attention_mask[0])
        # get the response mask, which is the mask for the generated tokens (the user inputs are masked with 0)
        response_mask = gen_attention_mask.clone()
        response_mask[:, :-pad_length] = 0 # TODO: 有问题. 有问题吗？

        context_mask = F.pad(context_mask, (0, pad_length), mode='constant', value=0)

        loss_dict = mask_gen_model.train_one_batch(model, gen_tokens, gen_attention_mask, context_mask, response_mask, optimizer,
                                                   num_steps=3, mini_batch_size=16, ppo_epochs=2)

        # # Get the last hidden state for the prompt + response sequence
        # with torch.no_grad():
        #     # full_outputs = model(input_ids=gen_tokens, attention_mask=gen_attention_mask, output_hidden_states=True, return_dict=True)
        #     # last_hidden_state = full_outputs.hidden_states[-1]
        #     # last_hidden_state = last_hidden_state.float()
        #     embedded = model.get_input_embeddings()(gen_tokens)
        #     # last_hidden_state = model.get_encoder()(embedded, attention_mask=gen_attention_mask)[0]
        #     last_hidden_state = embedded
        #     last_hidden_state = last_hidden_state.float()
        
        # mask_logits = mask_gen_model(last_hidden_state)

        # mask_gen_outputs = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits, response_mask, 
        #                                                                    num_samples=1)
        # loss, reward_loss, mask_loss, mask_mean, mean_reward = mask_gen_outputs['loss'], mask_gen_outputs['reward_loss'], mask_gen_outputs['mask_loss'], mask_gen_outputs['mask_mean'], mask_gen_outputs['mean_reward']
        log = f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss_dict['loss']:.4f}, " \
               f"Actor Loss = {loss_dict['actor_loss']:.4f}, " \
               f"Critic Loss = {loss_dict['critic_loss']:.4f}, " \
               f"Entropy = {loss_dict['entropy']:.4f}, " \
               f"Returns = {loss_dict['returns']:.4f}, " \
               f"Value = {loss_dict['value']:.4f}, " \
                f"mask_loss = {loss_dict['mask_loss']:.4f}" \
                f"std_loss = {loss_dict['std_loss']:.4f}" \
            #    f"Cont_loss = {loss_dict['contrast_loss']:.4f}, "  \
               
        pbar.set_description(log)
        # logging.debug(log)
    


        # # the mask_prob after the updates
        # with torch.no_grad():
        #     mask_logits_after = mask_gen_model(last_hidden_state)

        #     mask_gen_outputs_after = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits_after, response_mask, 
        #                                                                     num_samples=5)
        #     loss_after, reward_loss_after, mask_loss_after, mask_mean_after, mean_reward_after = mask_gen_outputs_after['loss'], mask_gen_outputs_after['reward_loss'], mask_gen_outputs_after['mask_loss'], mask_gen_outputs_after['mask_mean'], mask_gen_outputs_after['mean_reward']
        #     mask_prob_after = (torch.sigmoid(mask_logits_after) * context_mask).clone().detach()
        #     mean_reward_after = mean_reward_after.clone().detach()

        # # load the parameters before the updates
        # mask_gen_model.load_state_dict(params_before)
        # mask_logits_before = mask_gen_model(last_hidden_state)

        # mask_gen_outputs_before = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits_before, response_mask, 
        #                                                                    num_samples=5)
        # loss_before, reward_loss_before, mask_loss_before, mask_mean_before, mean_reward_before = mask_gen_outputs_before['loss'], mask_gen_outputs_before['reward_loss'], mask_gen_outputs_before['mask_loss'], mask_gen_outputs_before['mask_mean'], mask_gen_outputs_before['mean_reward']
        # mask_prob_before = (torch.sigmoid(mask_logits_before) * context_mask)

        # # calculate the ratio of the mask probabilities before and after the updates
        # ratio = mask_prob_after / (mask_prob_before + 1e-6)

        # # 定义PPO的损失函数，假设clip_param是你定义的剪切参数
        # clip_param = 0.2
        # advantage = (mean_reward_after - mean_reward_before).unsqueeze(-1)  # 计算优势函数（advantage），这是根据任务定义的
        # surr1 = ratio * advantage
        # surr2 = torch.clamp(ratio, 1 - clip_param, 1 + clip_param) * advantage
        # ppo_loss = -torch.min(surr1, surr2).mean()

        # # 更新模型参数
        # optimizer.zero_grad()
        # ppo_loss.backward()
        # optimizer.step()

        if idx % 1 == 0:
            print()
        if idx % 10 == 0 and idx != 0:
            torch.save(mask_gen_model.state_dict(), f'saved_model/mask_gen_model_lora_{epoch}_{idx}.pth') 
        #     print()
        #     # break

Epoch 1, Step 1: Loss = -1.1637, Actor Loss = -1.8603, Critic Loss = 1.3932, Entropy = 0.6931, Returns = 1.8845, Value = 0.7368, mask_loss = 0.4970std_loss = 0.0020:   0%|          | 1/1563 [00:16<7:14:09, 16.68s/it]




Epoch 1, Step 2: Loss = 0.8660, Actor Loss = 0.0335, Critic Loss = 1.6650, Entropy = 0.6931, Returns = 1.9212, Value = 0.6392, mask_loss = 0.4997std_loss = 0.0026:   0%|          | 2/1563 [00:32<7:03:14, 16.27s/it]  




Epoch 1, Step 3: Loss = -1.1394, Actor Loss = -1.2621, Critic Loss = 0.2455, Entropy = 0.6921, Returns = 1.8447, Value = 1.6390, mask_loss = 0.5208std_loss = 0.0083:   0%|          | 3/1563 [00:48<7:01:15, 16.20s/it]




Epoch 1, Step 4: Loss = 0.3594, Actor Loss = 0.2164, Critic Loss = 0.2861, Entropy = 0.6926, Returns = 2.0121, Value = 1.5645, mask_loss = 0.5145std_loss = 0.0067:   0%|          | 4/1563 [01:03<6:42:07, 15.48s/it]  




Epoch 1, Step 5: Loss = 0.3434, Actor Loss = 0.1167, Critic Loss = 0.4534, Entropy = 0.6929, Returns = 1.3937, Value = 1.7078, mask_loss = 0.5104std_loss = 0.0051:   0%|          | 5/1563 [01:20<6:57:19, 16.07s/it]




Epoch 1, Step 6: Loss = -0.2967, Actor Loss = -0.3745, Critic Loss = 0.1557, Entropy = 0.6930, Returns = 1.9596, Value = 1.6513, mask_loss = 0.5077std_loss = 0.0048:   0%|          | 6/1563 [01:41<7:42:39, 17.83s/it]




Epoch 1, Step 7: Loss = 0.2329, Actor Loss = 0.0740, Critic Loss = 0.3179, Entropy = 0.6930, Returns = 1.6605, Value = 1.8943, mask_loss = 0.5079std_loss = 0.0046:   0%|          | 7/1563 [01:58<7:37:32, 17.64s/it]  




Epoch 1, Step 8: Loss = 0.5171, Actor Loss = 0.2740, Critic Loss = 0.4865, Entropy = 0.6931, Returns = 1.4967, Value = 1.5297, mask_loss = 0.5016std_loss = 0.0028:   1%|          | 8/1563 [02:15<7:32:39, 17.47s/it]




Epoch 1, Step 9: Loss = 0.1285, Actor Loss = -0.0705, Critic Loss = 0.3980, Entropy = 0.6931, Returns = 1.7056, Value = 1.9452, mask_loss = 0.5036std_loss = 0.0032:   1%|          | 9/1563 [02:32<7:26:38, 17.24s/it]




Epoch 1, Step 10: Loss = -0.1226, Actor Loss = -0.1435, Critic Loss = 0.0420, Entropy = 0.6931, Returns = 1.9458, Value = 2.0273, mask_loss = 0.5055std_loss = 0.0038:   1%|          | 10/1563 [02:48<7:12:38, 16.72s/it]




Epoch 1, Step 11: Loss = 0.3939, Actor Loss = 0.1731, Critic Loss = 0.4418, Entropy = 0.6931, Returns = 1.7870, Value = 1.3052, mask_loss = 0.4968std_loss = 0.0019:   1%|          | 10/1563 [03:08<7:12:38, 16.72s/it]  




Epoch 1, Step 11: Loss = 0.3939, Actor Loss = 0.1731, Critic Loss = 0.4418, Entropy = 0.6931, Returns = 1.7870, Value = 1.3052, mask_loss = 0.4968std_loss = 0.0019:   1%|          | 11/1563 [03:18<7:46:04, 18.02s/it]


KeyboardInterrupt: 

In [28]:
import numpy as np
import torch.nn.functional as F

# mask_gen_model.load_state_dict(torch.load('saved_model/mask_gen_model_lora_0_470.pth',map_location=device))

mask_gen_model.eval()

# tokens = tokenizer.convert_ids_to_tokens(gen_tokens[idx])
# texts = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
# texts = "I did not like this movie. Some of the actors were good, but overall the movie was boring."
# texts = "I hate that I love you."
# texts = "I don't like this movie."
# texts = "I really love this film."
# texts = "I really love this film. The acting was great, and the story was amazing. I would recommend this movie to everyone."
# # texts = "I don't like this movie. The acting was terrible, and the story was boring. I would not recommend this movie to anyone."
# messages_lambda = lambda texts: [
#             {"role": "system", "content": "Answer the question based on the context."},
#             # {"role": "system", "content": "You are a chatbot for sentimate analysis."},
#             {"role": "user", "content": texts},
#         ]
# messages = messages_lambda(texts)
# messages_with_template_applied = tokenizer.apply_chat_template(
#             messages,
#             tokenize=False,
#             add_generation_prompt=True,
#         )

# # test_text = [{"text": texts, "label": None}]
# test_text = [{"sentence": texts, "label": None}]
# test_inputs = collate_fn(test_text).to(device)

test_inputs = next(iter(test_dataloader)).to(device)
# test_inputs = next(iter(train_dataloader)).to(device)

# tokens = tokenizer.convert_ids_to_tokens(test_inputs['input_ids'][idx])

# generate the answer for the test inputs
gen_outputs = model.generate(
            input_ids=test_inputs['input_ids'],
            attention_mask=test_inputs['attention_mask'],
            max_new_tokens=128,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            return_dict_in_generate=True,
            output_scores=True,
        )
input_ids = test_inputs['input_ids']
attention_mask = test_inputs['attention_mask']
gen_tokens = gen_outputs.sequences
pad_length = gen_tokens.size(1) - input_ids.size(1)
# get the attention mask for the generated tokens, and also mask the padding tokens
gen_attention_mask = F.pad(attention_mask, (0, pad_length), mode='constant', value=1)
context_mask = F.pad(test_inputs['context_mask'], (0, pad_length), mode='constant', value=0)
# (gen_tokens != pad_token_id).long() is the tokens mask, 1 for real tokens and 0 for padding tokens
unpaded_token_mask = (gen_tokens != pad_token_id).long()
unpaded_token_mask[:, :-pad_length] = 1
gen_attention_mask = gen_attention_mask * unpaded_token_mask

response_mask = gen_attention_mask.clone()
response_mask[:, :-pad_length] = 0 # TODO: 有问题. 有问题吗？

# context_mask = F.pad(context_mask, (0, pad_length), mode='constant', value=0)

# with torch.no_grad():
#     # prompt_outputs = model(input_ids=test_inputs['input_ids'], attention_mask=test_inputs['attention_mask'], output_hidden_states=True, return_dict=True)
#     prompt_outputs = model(input_ids=gen_tokens, attention_mask=gen_attention_mask, output_hidden_states=True, return_dict=True)

#     last_hidden_state = prompt_outputs.hidden_states[-1].float()
#     mask_logits = mask_gen_model(last_hidden_state)


with torch.no_grad():
    state = gen_tokens, gen_attention_mask, context_mask, response_mask
    dist, value = mask_gen_model.get_dist_critic(model, state)

mask_logits = dist.logits

In [30]:
import random
idx = random.randint(0, 4)
test_ids = gen_tokens[idx]
test_mask = gen_attention_mask[idx]
test_mask_prob = torch.sigmoid(mask_logits[idx])
# inverse TODO
# test_mask_prob = 1 - test_mask_prob
test_context_mask = context_mask[idx]

test_tokens = tokenizer.convert_ids_to_tokens(test_ids)
scores = test_mask_prob * test_context_mask
def normalize_except_zeros(array):
    # Create a mask to identify non-zero elements
    mask = (array > 0)
    
    # Extract non-zero elements
    non_zero_elements = array[mask]
    
    # Normalize non-zero elements
    min_val = np.min(non_zero_elements)
    max_val = np.max(non_zero_elements)

    normalized_non_zero_elements = (non_zero_elements - min_val) / (max_val - min_val)
    
    # Create a copy of the original array to preserve zero values
    normalized_array = np.copy(array)
    
    # Assign normalized values back to the corresponding positions
    normalized_array[mask] = normalized_non_zero_elements
    
    return normalized_array
scores = normalize_except_zeros(scores.detach().cpu().numpy())

# remove special tokens
filtered_token_scores = [(token, score) for token, score in zip(test_tokens, scores) if token not in tokenizer.all_special_tokens]

# combine subwords
merged_tokens_scores = []
current_token = ""
current_score = 0
count = 0

for token, score in filtered_token_scores:
    if token.startswith("Ġ"):
        if current_token:
            merged_tokens_scores.append((current_token, current_score / count))
            # merged_tokens_scores.append((" ", 0))  # 添加空格
        current_token = token[1:] # remove the speical character
        current_score = score
        count = 1
    elif token.endswith("Ċ"):
        if current_token:
            merged_tokens_scores.append((current_token, current_score / count))
        merged_tokens_scores.append(("<br><br>", 0))  # 添加换行符
        current_token = ""
        current_score = 0
        count = 0
    elif token.startswith("<") and token.endswith(">"):
        if current_token:
            merged_tokens_scores.append((current_token, current_score / count))
            current_token = token
            current_score = 0
            count = 1
    elif token in (',', '.', ':', '"', "'", '?', '!', '-', ';', '(', ')', '[', ']', '{', '}', '<', '>', '/'):
        if current_token:
            merged_tokens_scores.append((current_token, current_score / count))
        current_token = token
        current_score = score
        count = 1
    else:
        current_token += token
        current_score += score
        count += 1
    # print(token)

if current_token:
    merged_tokens_scores.append((current_token, current_score / count))


# 根据分数高亮文本（示例中使用HTML标签）
highlighted_text = ""
for token, score in merged_tokens_scores:
    # 动态设置背景颜色：score为0时为白色，score为1时为绿色
    red = int((1 - score) * 255)
    green = 255
    blue = int((1 - score) * 255)
    color = f'rgb({red}, {green}, {blue})'
    highlighted_text += f'<span style="background-color: {color}; color: black;">{token}</span> '

# 打印高亮后的文本
from IPython.display import display, HTML
display(HTML(highlighted_text.strip()))

In [11]:
(test_mask_prob * test_context_mask).sum(-1) / test_context_mask.sum(-1)

tensor(0.4551, device='cuda:0')

In [31]:
test_mask_prob * test_context_mask

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [None]:
k = 10
top_k_indices = np.argpartition(scores, -k)[-k:]
top_k_values = scores[top_k_indices]
print(top_k_values)

In [21]:
merged_tokens_scores

[('system', 0.0),
 ('<|end_header_id|>', 0.0),
 ('<br><br>', 0),
 ('You', 0.0),
 ('are', 0.0),
 ('a', 0.0),
 ('chatbot', 0.0),
 ('for', 0.0),
 ('answering', 0.0),
 ('questions', 0.0),
 ('.', 0.0),
 ('You', 0.0),
 ('can', 0.0),
 ('help', 0.0),
 ('users', 0.0),
 ('with', 0.0),
 ('their', 0.0),
 ('questions', 0.0),
 ('via', 0.0),
 ('concise', 0.0),
 ('responses', 0.0),
 ('.', 0.0),
 ('<|start_header_id|>user', 0.0),
 ('<|end_header_id|>', 0.0),
 ('<br><br>', 0),
 ('Question', 0.0),
 (':', 0.0),
 ('What', 0.0),
 ('is', 0.0),
 ('the', 0.0),
 ('Grotto', 0.0),
 ('at', 0.0),
 ('Notre', 0.0),
 ('Dame', 0.0),
 ('<br><br>', 0),
 ('Context', 0.8096184730529785),
 (':', 0.0),
 ('Architecturally', 0.8208637833595276),
 (',', 0.0),
 ('the', 0.31576451659202576),
 ('school', 0.776223361492157),
 ('has', 0.7286538481712341),
 ('a', 0.6366741061210632),
 ('Catholic', 0.5874115824699402),
 ('character', 0.9577630758285522),
 ('.', 0.0),
 ('Atop', 0.7504719495773315),
 ('the', 0.6626720428466797),
 ('Main

In [22]:
filtered_token_scores

[('<|start_header_id|>', 0.0),
 ('system', 0.0),
 ('<|end_header_id|>', 0.0),
 ('ĊĊ', 0.0),
 ('You', 0.0),
 ('Ġare', 0.0),
 ('Ġa', 0.0),
 ('Ġchat', 0.0),
 ('bot', 0.0),
 ('Ġfor', 0.0),
 ('Ġanswering', 0.0),
 ('Ġquestions', 0.0),
 ('.', 0.0),
 ('ĠYou', 0.0),
 ('Ġcan', 0.0),
 ('Ġhelp', 0.0),
 ('Ġusers', 0.0),
 ('Ġwith', 0.0),
 ('Ġtheir', 0.0),
 ('Ġquestions', 0.0),
 ('Ġvia', 0.0),
 ('Ġconcise', 0.0),
 ('Ġresponses', 0.0),
 ('.', 0.0),
 ('<|start_header_id|>', 0.0),
 ('user', 0.0),
 ('<|end_header_id|>', 0.0),
 ('ĊĊ', 0.0),
 ('Question', 0.0),
 (':', 0.0),
 ('ĠWhat', 0.0),
 ('Ġis', 0.0),
 ('Ġthe', 0.0),
 ('ĠG', 0.0),
 ('rot', 0.0),
 ('to', 0.0),
 ('Ġat', 0.0),
 ('ĠNotre', 0.0),
 ('ĠDame', 0.0),
 ('?Ċ', 0.0),
 ('Context', 0.8096185),
 (':', 0.8191038),
 ('ĠArchitect', 0.768306),
 ('urally', 0.87342155),
 (',', 0.7944344),
 ('Ġthe', 0.31576452),
 ('Ġschool', 0.77622336),
 ('Ġhas', 0.72865385),
 ('Ġa', 0.6366741),
 ('ĠCatholic', 0.5874116),
 ('Ġcharacter', 0.9577631),
 ('.', 0.82176286),
 ('