In [1]:
import os
import json 
import logging

logging.basicConfig(
    filename='log/app.log',            # Specify the log file name
    level=logging.DEBUG,           # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format='%(asctime)s - %(levelname)s - %(message)s'  # Set the log format
)

# Load the environment configuration JSON data
json_path = 'env_config.json'
with open(json_path, 'r') as file:
    env_config = json.load(file)

hf_home = env_config['HF_HOME']
# Set the HF_HOME environment variable
os.environ['HF_HOME'] = hf_home
# Set the access token to huggingface hub
access_token = env_config['access_token']
os.environ['HUGGINGFACE_HUB_TOKEN'] = access_token

In [2]:
import transformers 
print(transformers.__version__)

from transformers import pipeline
import torch

from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import LlamaTokenizerFast
import torch.nn.functional as F


accelerator = Accelerator()
device = accelerator.device

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_id = "meta-llama/Meta-Llama-3-8B"  # non-instruct version

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=access_token,
)

tokenizer = LlamaTokenizerFast.from_pretrained(model_id, token=access_token)
tokenizer.pad_token = tokenizer.eos_token

  from .autonotebook import tqdm as notebook_tqdm


4.41.0


Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.60s/it]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'PreTrainedTokenizerFast'. 
The class this function is called from is 'LlamaTokenizerFast'.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
from llmexp.helper import LlmExpHelper
from datasets import load_dataset
from torch.utils.data import DataLoader

# imdb = load_dataset("imdb")
ds = load_dataset("rajpurkar/squad")
train_ds = ds['train']
test_ds = ds['validation']
# from datasets import load_dataset

# ds = load_dataset("stanfordnlp/sst2")
# train_ds = ds['train']
llm_exp_helper = LlmExpHelper(tokenizer, 'squad')
collate_fn = llm_exp_helper.get_collate_fun()

# Define batch size here!
batch_size = 8
train_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
test_dataloader = DataLoader(test_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)

In [4]:
# next(iter(train_dataloader))
train_ds[0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

In [5]:
def collate_fn(examples):
    def num_words(x):
        return len(x.split())
    def get_first_k_words(x, k):
        return ' '.join(x.split()[:k])
    def get_cliped_text(texts, max_len):
        return [text if num_words(text) <= max_len else get_first_k_words(text, max_len) for text in texts]
    tokenizer = self.tokenizer
    max_len = 1024 # characters limit other than token limit
    if self.dataset == 'imdb':
        texts = [example['text'] for example in examples]
        texts = get_cliped_text(texts, max_len)
        sys_context = "You are a chatbot for sentiment analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE."
    elif self.dataset == 'sst2':
        texts = [example['sentence'] for example in examples]
        texts = get_cliped_text(texts, max_len)
        sys_context = "You are a chatbot for sentiment analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE."
    elif self.dataset == 'squad':
        context = [example['context'] for example in examples]
        context = get_cliped_text(context, max_len)
        question = [example['question'] for example in examples]
        texts = [f"Context: {context[i]}\nQuestion: {question[i]}" for i in range(len(context))]
        sys_context = "You are a chatbot for answering questions. You can help users with their questions via concise responses."

    messages_lambda = lambda texts: [
        {"role": "system", "content": sys_context},
        {"role": "user", "content": texts},
    ]

    messages = list(map(messages_lambda, texts))

    messages_with_template_applied = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    batch = tokenizer(
                messages_with_template_applied,
                add_special_tokens=False,
                padding=True,
                return_tensors="pt",
                )
    
    # find the template boundaries
    text_lens = [len(tokenizer.encode(text)) for text in texts]
    text_lens_tensor = torch.tensor(text_lens, dtype=torch.long)
    
    def apply_mask(mask_tensor, text_lens_tensor):
        batch_size, seq_len = mask_tensor.shape
        for i in range(batch_size):
            text_len = text_lens_tensor[i].item()
            mask_tensor[i, -text_len-5:-5] = 0
        return 1- mask_tensor

    mask_tensor = apply_mask(torch.ones_like(batch['input_ids']), text_lens_tensor)

    batch['context_mask'] = mask_tensor

    if self.dataset == 'squad':
        answers_start = [example['answers']['answer_start'][0] for example in examples]
        answers_end = [example['answers']['answer_start'][0] + len(example['answers']['text'][0]) for example in examples]
        batch['answers_start'] = torch.tensor(answers_start).long()
        batch['answers_end'] = torch.tensor(answers_end).long()
    
    return batch

In [6]:
from llmexp.squad_model3 import MaskGeneratingModel

mask_gen_model = MaskGeneratingModel(hidden_size=4096, mlp_hidden_dim=4096, mlp_bottleneck_dim=768, mlp_num_blocks=5)
mask_gen_model.to(device)


from tqdm import tqdm

# Set pad_token_id if it is not set
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

pad_token_id = tokenizer.pad_token_id

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

optimizer = torch.optim.Adam(mask_gen_model.parameters(), lr=1e-7)

In [7]:
mask_gen_model.train()
for epoch in range(1):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        context_mask = data['context_mask'].to(device)
        # get generated texts
        gen_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            return_dict_in_generate=True,
            output_scores=True,
        )
        gen_tokens = gen_outputs.sequences
        pad_length = gen_tokens.size(1) - input_ids.size(1)
        # get the attention mask for the generated tokens, and also mask the padding tokens
        gen_attention_mask = F.pad(attention_mask, (0, pad_length), mode='constant', value=1)
        # (gen_tokens != pad_token_id).long() is the tokens mask, 1 for real tokens and 0 for padding tokens
        unpaded_token_mask = (gen_tokens != pad_token_id).long()
        unpaded_token_mask[:, :-pad_length] = 1
        gen_attention_mask = gen_attention_mask * unpaded_token_mask
        # print(gen_tokens[0])
        # print(gen_attention_mask[0])
        # get the response mask, which is the mask for the generated tokens (the user inputs are masked with 0)
        response_mask = gen_attention_mask.clone()
        response_mask[:, :-pad_length] = 0 # TODO: 有问题. 有问题吗？

        context_mask = F.pad(context_mask, (0, pad_length), mode='constant', value=0)

        # Get the last hidden state for the prompt + response sequence
        with torch.no_grad():
            full_outputs = model(input_ids=gen_tokens, attention_mask=gen_attention_mask, output_hidden_states=True, return_dict=True)
            last_hidden_state = full_outputs.hidden_states[-1]
            last_hidden_state = last_hidden_state.float()
        
        mask_logits = mask_gen_model(last_hidden_state)

        mask_gen_outputs = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits, response_mask, 
                                                                           num_samples=5)
        loss, reward_loss, mask_loss, mask_mean, mean_reward = mask_gen_outputs['loss'], mask_gen_outputs['reward_loss'], mask_gen_outputs['mask_loss'], mask_gen_outputs['mask_mean'], mask_gen_outputs['mean_reward']
        log = (f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss.item():.4f}, " 
                             f"Reward Loss = {reward_loss.item():.4f}, "
                             f"Mean Reward = {mean_reward.mean().item():.4f},"
                             f"Mask_loss = {mask_loss.item():.4f} "
                             f"mask_mean = {mask_mean.item():.4f}"
        )
        pbar.set_description(log)
        logging.debug(log)
    
        # the parameters before updating
        params_before = mask_gen_model.state_dict()

        # Train the model (inner loop)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # # the mask_prob after the updates
        # with torch.no_grad():
        #     mask_logits_after = mask_gen_model(last_hidden_state)

        #     mask_gen_outputs_after = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits_after, response_mask, 
        #                                                                     num_samples=5)
        #     loss_after, reward_loss_after, mask_loss_after, mask_mean_after, mean_reward_after = mask_gen_outputs_after['loss'], mask_gen_outputs_after['reward_loss'], mask_gen_outputs_after['mask_loss'], mask_gen_outputs_after['mask_mean'], mask_gen_outputs_after['mean_reward']
        #     mask_prob_after = (torch.sigmoid(mask_logits_after) * context_mask).clone().detach()
        #     mean_reward_after = mean_reward_after.clone().detach()

        # # load the parameters before the updates
        # mask_gen_model.load_state_dict(params_before)
        # mask_logits_before = mask_gen_model(last_hidden_state)

        # mask_gen_outputs_before = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits_before, response_mask, 
        #                                                                    num_samples=5)
        # loss_before, reward_loss_before, mask_loss_before, mask_mean_before, mean_reward_before = mask_gen_outputs_before['loss'], mask_gen_outputs_before['reward_loss'], mask_gen_outputs_before['mask_loss'], mask_gen_outputs_before['mask_mean'], mask_gen_outputs_before['mean_reward']
        # mask_prob_before = (torch.sigmoid(mask_logits_before) * context_mask)

        # # calculate the ratio of the mask probabilities before and after the updates
        # ratio = mask_prob_after / (mask_prob_before + 1e-6)

        # # 定义PPO的损失函数，假设clip_param是你定义的剪切参数
        # clip_param = 0.2
        # advantage = (mean_reward_after - mean_reward_before).unsqueeze(-1)  # 计算优势函数（advantage），这是根据任务定义的
        # surr1 = ratio * advantage
        # surr2 = torch.clamp(ratio, 1 - clip_param, 1 + clip_param) * advantage
        # ppo_loss = -torch.min(surr1, surr2).mean()

        # # 更新模型参数
        # optimizer.zero_grad()
        # ppo_loss.backward()
        # optimizer.step()

        if idx % 10 == 0:
            print()
        if idx % 200 == 0 and idx != 0:
            torch.save(mask_gen_model.state_dict(), f'saved_model/mask_gen_model_{epoch}_{idx}.pth') 
            print()
            # break

Epoch 1, Step 1: Loss = 0.2536, Reward Loss = 0.2048, Mean Reward = 0.7592,Mask_loss = 0.4885 mask_mean = 0.8487:   0%|          | 1/10950 [00:03<10:58:27,  3.61s/it]




Epoch 1, Step 11: Loss = 0.0085, Reward Loss = 0.0097, Mean Reward = 0.0521,Mask_loss = -0.0122 mask_mean = 0.0827:   0%|          | 11/10950 [00:31<8:15:01,  2.72s/it]




Epoch 1, Step 21: Loss = 0.0332, Reward Loss = 0.0350, Mean Reward = 0.1062,Mask_loss = -0.0177 mask_mean = 0.1897:   0%|          | 21/10950 [01:01<8:30:04,  2.80s/it]




Epoch 1, Step 31: Loss = 0.0514, Reward Loss = 0.0520, Mean Reward = 0.1411,Mask_loss = -0.0060 mask_mean = 0.2098:   0%|          | 31/10950 [01:28<9:21:41,  3.09s/it]




Epoch 1, Step 41: Loss = 0.1183, Reward Loss = 0.1148, Mean Reward = 0.2955,Mask_loss = 0.0351 mask_mean = 0.3433:   0%|          | 41/10950 [01:56<8:55:55,  2.95s/it] 




Epoch 1, Step 51: Loss = 0.0071, Reward Loss = 0.0073, Mean Reward = 0.0990,Mask_loss = -0.0023 mask_mean = 0.0231:   0%|          | 51/10950 [02:43<14:03:17,  4.64s/it]




Epoch 1, Step 61: Loss = 0.0378, Reward Loss = 0.0390, Mean Reward = 0.1233,Mask_loss = -0.0126 mask_mean = 0.1624:   1%|          | 61/10950 [03:30<13:26:41,  4.45s/it]




Epoch 1, Step 71: Loss = 0.0481, Reward Loss = 0.0498, Mean Reward = 0.1305,Mask_loss = -0.0170 mask_mean = 0.2496:   1%|          | 71/10950 [04:09<12:55:58,  4.28s/it]




Epoch 1, Step 81: Loss = 0.1877, Reward Loss = 0.1753, Mean Reward = 0.4263,Mask_loss = 0.1238 mask_mean = 0.5116:   1%|          | 81/10950 [04:53<11:30:06,  3.81s/it] 




Epoch 1, Step 91: Loss = 0.0167, Reward Loss = 0.0172, Mean Reward = 0.1110,Mask_loss = -0.0051 mask_mean = 0.0564:   1%|          | 91/10950 [05:49<14:57:41,  4.96s/it]




Epoch 1, Step 101: Loss = 0.1567, Reward Loss = 0.1504, Mean Reward = 0.3541,Mask_loss = 0.0635 mask_mean = 0.4083:   1%|          | 101/10950 [06:33<12:30:24,  4.15s/it]




Epoch 1, Step 111: Loss = 0.0360, Reward Loss = 0.0372, Mean Reward = 0.1335,Mask_loss = -0.0120 mask_mean = 0.1720:   1%|          | 111/10950 [07:17<15:38:28,  5.20s/it]




Epoch 1, Step 121: Loss = 0.0175, Reward Loss = 0.0186, Mean Reward = 0.0853,Mask_loss = -0.0112 mask_mean = 0.1033:   1%|          | 121/10950 [08:01<12:28:58,  4.15s/it]




Epoch 1, Step 131: Loss = 0.3110, Reward Loss = 0.2883, Mean Reward = 0.6352,Mask_loss = 0.2267 mask_mean = 0.5214:   1%|          | 131/10950 [08:51<15:39:54,  5.21s/it] 




Epoch 1, Step 141: Loss = 0.2280, Reward Loss = 0.2037, Mean Reward = 0.5370,Mask_loss = 0.2433 mask_mean = 0.6851:   1%|▏         | 141/10950 [09:41<14:46:34,  4.92s/it] 




Epoch 1, Step 151: Loss = 0.0063, Reward Loss = 0.0083, Mean Reward = 0.0335,Mask_loss = -0.0199 mask_mean = 0.1188:   1%|▏         | 151/10950 [10:33<15:36:30,  5.20s/it]




Epoch 1, Step 161: Loss = 0.0024, Reward Loss = 0.0026, Mean Reward = 0.0516,Mask_loss = -0.0022 mask_mean = 0.0144:   1%|▏         | 161/10950 [11:26<16:15:22,  5.42s/it]




Epoch 1, Step 171: Loss = 0.0928, Reward Loss = 0.0911, Mean Reward = 0.2829,Mask_loss = 0.0169 mask_mean = 0.3075:   2%|▏         | 171/10950 [12:30<20:06:35,  6.72s/it] 




Epoch 1, Step 181: Loss = 0.0885, Reward Loss = 0.0873, Mean Reward = 0.2064,Mask_loss = 0.0118 mask_mean = 0.2946:   2%|▏         | 181/10950 [13:43<19:08:42,  6.40s/it] 




Epoch 1, Step 191: Loss = 0.1214, Reward Loss = 0.1188, Mean Reward = 0.2737,Mask_loss = 0.0264 mask_mean = 0.3290:   2%|▏         | 191/10950 [14:30<13:29:01,  4.51s/it] 




Epoch 1, Step 201: Loss = 0.1369, Reward Loss = 0.1256, Mean Reward = 0.3422,Mask_loss = 0.1127 mask_mean = 0.6744:   2%|▏         | 200/10950 [15:12<12:04:10,  4.04s/it] 




Epoch 1, Step 201: Loss = 0.1369, Reward Loss = 0.1256, Mean Reward = 0.3422,Mask_loss = 0.1127 mask_mean = 0.6744:   2%|▏         | 201/10950 [15:26<24:46:59,  8.30s/it]




Epoch 1, Step 211: Loss = 0.1091, Reward Loss = 0.1074, Mean Reward = 0.2355,Mask_loss = 0.0163 mask_mean = 0.3063:   2%|▏         | 211/10950 [16:15<15:41:03,  5.26s/it] 




Epoch 1, Step 221: Loss = 0.0142, Reward Loss = 0.0157, Mean Reward = 0.0623,Mask_loss = -0.0152 mask_mean = 0.1106:   2%|▏         | 221/10950 [17:09<18:04:19,  6.06s/it]




Epoch 1, Step 231: Loss = 0.0296, Reward Loss = 0.0312, Mean Reward = 0.0854,Mask_loss = -0.0160 mask_mean = 0.2048:   2%|▏         | 231/10950 [18:04<14:21:15,  4.82s/it]




Epoch 1, Step 241: Loss = 0.1621, Reward Loss = 0.1366, Mean Reward = 0.5018,Mask_loss = 0.2546 mask_mean = 0.7848:   2%|▏         | 241/10950 [18:45<11:31:38,  3.88s/it] 




Epoch 1, Step 251: Loss = 0.2121, Reward Loss = 0.1784, Mean Reward = 0.5930,Mask_loss = 0.3371 mask_mean = 0.8372:   2%|▏         | 251/10950 [19:24<11:19:39,  3.81s/it]




Epoch 1, Step 261: Loss = 0.0481, Reward Loss = 0.0505, Mean Reward = 0.1238,Mask_loss = -0.0240 mask_mean = 0.3647:   2%|▏         | 261/10950 [20:02<11:32:58,  3.89s/it]




Epoch 1, Step 271: Loss = 0.1650, Reward Loss = 0.1480, Mean Reward = 0.4200,Mask_loss = 0.1699 mask_mean = 0.7116:   2%|▏         | 271/10950 [20:53<15:48:16,  5.33s/it] 




Epoch 1, Step 281: Loss = 0.1367, Reward Loss = 0.1282, Mean Reward = 0.3114,Mask_loss = 0.0847 mask_mean = 0.6025:   3%|▎         | 281/10950 [21:42<17:09:16,  5.79s/it] 




Epoch 1, Step 291: Loss = 0.1490, Reward Loss = 0.1426, Mean Reward = 0.3066,Mask_loss = 0.0641 mask_mean = 0.5133:   3%|▎         | 291/10950 [22:31<13:26:27,  4.54s/it] 




Epoch 1, Step 301: Loss = 0.0698, Reward Loss = 0.0699, Mean Reward = 0.1642,Mask_loss = -0.0009 mask_mean = 0.3511:   3%|▎         | 301/10950 [23:27<15:36:51,  5.28s/it]




Epoch 1, Step 311: Loss = 0.0121, Reward Loss = 0.0141, Mean Reward = 0.0544,Mask_loss = -0.0193 mask_mean = 0.1333:   3%|▎         | 311/10950 [24:18<13:30:54,  4.57s/it]




Epoch 1, Step 321: Loss = 0.1392, Reward Loss = 0.1310, Mean Reward = 0.3262,Mask_loss = 0.0825 mask_mean = 0.5150:   3%|▎         | 321/10950 [25:02<13:10:44,  4.46s/it] 




Epoch 1, Step 331: Loss = 0.1103, Reward Loss = 0.1017, Mean Reward = 0.2647,Mask_loss = 0.0866 mask_mean = 0.4266:   3%|▎         | 331/10950 [25:46<13:31:02,  4.58s/it]




Epoch 1, Step 341: Loss = 0.0656, Reward Loss = 0.0658, Mean Reward = 0.1722,Mask_loss = -0.0022 mask_mean = 0.2484:   3%|▎         | 341/10950 [26:27<11:46:54,  4.00s/it]




Epoch 1, Step 351: Loss = 0.2506, Reward Loss = 0.2234, Mean Reward = 0.5768,Mask_loss = 0.2723 mask_mean = 0.6977:   3%|▎         | 351/10950 [27:07<12:12:42,  4.15s/it] 




Epoch 1, Step 361: Loss = 0.0208, Reward Loss = 0.0243, Mean Reward = 0.0643,Mask_loss = -0.0354 mask_mean = 0.2629:   3%|▎         | 361/10950 [27:53<15:03:55,  5.12s/it]




Epoch 1, Step 371: Loss = 0.0257, Reward Loss = 0.0275, Mean Reward = 0.0977,Mask_loss = -0.0174 mask_mean = 0.1723:   3%|▎         | 371/10950 [28:41<12:18:41,  4.19s/it]




Epoch 1, Step 381: Loss = 0.0591, Reward Loss = 0.0595, Mean Reward = 0.1466,Mask_loss = -0.0039 mask_mean = 0.4286:   3%|▎         | 381/10950 [29:26<12:42:42,  4.33s/it]




Epoch 1, Step 391: Loss = 0.1587, Reward Loss = 0.1518, Mean Reward = 0.3280,Mask_loss = 0.0691 mask_mean = 0.5128:   4%|▎         | 391/10950 [30:11<13:51:34,  4.73s/it] 




Epoch 1, Step 401: Loss = 0.0720, Reward Loss = 0.0728, Mean Reward = 0.1784,Mask_loss = -0.0082 mask_mean = 0.4058:   4%|▎         | 400/10950 [30:52<11:32:25,  3.94s/it]




Epoch 1, Step 401: Loss = 0.0720, Reward Loss = 0.0728, Mean Reward = 0.1784,Mask_loss = -0.0082 mask_mean = 0.4058:   4%|▎         | 401/10950 [31:08<24:45:48,  8.45s/it]




Epoch 1, Step 411: Loss = 0.1173, Reward Loss = 0.1098, Mean Reward = 0.2959,Mask_loss = 0.0748 mask_mean = 0.4822:   4%|▍         | 411/10950 [31:49<10:36:51,  3.63s/it] 




Epoch 1, Step 421: Loss = 0.0499, Reward Loss = 0.0509, Mean Reward = 0.1444,Mask_loss = -0.0095 mask_mean = 0.2049:   4%|▍         | 421/10950 [32:18<8:51:01,  3.03s/it]




Epoch 1, Step 422: Loss = 0.0886, Reward Loss = 0.0879, Mean Reward = 0.2134,Mask_loss = 0.0072 mask_mean = 0.3238:   4%|▍         | 422/10950 [32:23<13:28:15,  4.61s/it]


KeyboardInterrupt: 

In [36]:
import numpy as np
import torch.nn.functional as F

mask_gen_model.eval()

# tokens = tokenizer.convert_ids_to_tokens(gen_tokens[idx])
# texts = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
# texts = "I did not like this movie. Some of the actors were good, but overall the movie was boring."
# texts = "I hate that I love you."
# texts = "I don't like this movie."
# texts = "I really love this film."
# texts = "I really love this film. The acting was great, and the story was amazing. I would recommend this movie to everyone."
# # texts = "I don't like this movie. The acting was terrible, and the story was boring. I would not recommend this movie to anyone."
# messages_lambda = lambda texts: [
#             {"role": "system", "content": "Answer the question based on the context."},
#             # {"role": "system", "content": "You are a chatbot for sentimate analysis."},
#             {"role": "user", "content": texts},
#         ]
# messages = messages_lambda(texts)
# messages_with_template_applied = tokenizer.apply_chat_template(
#             messages,
#             tokenize=False,
#             add_generation_prompt=True,
#         )

# # test_text = [{"text": texts, "label": None}]
# test_text = [{"sentence": texts, "label": None}]
# test_inputs = collate_fn(test_text).to(device)

test_inputs = next(iter(test_dataloader)).to(device)
# test_inputs = next(iter(train_dataloader)).to(device)

# tokens = tokenizer.convert_ids_to_tokens(test_inputs['input_ids'][idx])

# generate the answer for the test inputs
gen_outputs = model.generate(
            input_ids=test_inputs['input_ids'],
            attention_mask=test_inputs['attention_mask'],
            max_new_tokens=128,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            return_dict_in_generate=True,
            output_scores=True,
        )
input_ids = test_inputs['input_ids']
attention_mask = test_inputs['attention_mask']
gen_tokens = gen_outputs.sequences
pad_length = gen_tokens.size(1) - input_ids.size(1)
# get the attention mask for the generated tokens, and also mask the padding tokens
gen_attention_mask = F.pad(attention_mask, (0, pad_length), mode='constant', value=1)
context_mask = F.pad(test_inputs['context_mask'], (0, pad_length), mode='constant', value=0)
# (gen_tokens != pad_token_id).long() is the tokens mask, 1 for real tokens and 0 for padding tokens
unpaded_token_mask = (gen_tokens != pad_token_id).long()
unpaded_token_mask[:, :-pad_length] = 1
gen_attention_mask = gen_attention_mask * unpaded_token_mask

with torch.no_grad():
    # prompt_outputs = model(input_ids=test_inputs['input_ids'], attention_mask=test_inputs['attention_mask'], output_hidden_states=True, return_dict=True)
    prompt_outputs = model(input_ids=gen_tokens, attention_mask=gen_attention_mask, output_hidden_states=True, return_dict=True)

    last_hidden_state = prompt_outputs.hidden_states[-1].float()
    mask_logits = mask_gen_model(last_hidden_state)




In [56]:
import random
idx = random.randint(0, 4)
test_ids = gen_tokens[idx]
test_mask = gen_attention_mask[idx]
test_mask_prob = torch.sigmoid(mask_logits[idx])
test_context_mask = context_mask[idx]

test_tokens = tokenizer.convert_ids_to_tokens(test_ids)
scores = test_mask_prob * test_context_mask
def normalize_except_zeros(array):
    # Create a mask to identify non-zero elements
    mask = array != 0
    
    # Extract non-zero elements
    non_zero_elements = array[mask]
    
    # Normalize non-zero elements
    min_val = np.min(non_zero_elements)
    max_val = np.max(non_zero_elements)
    normalized_non_zero_elements = (non_zero_elements - min_val) / (max_val - min_val)
    
    # Create a copy of the original array to preserve zero values
    normalized_array = np.copy(array)
    
    # Assign normalized values back to the corresponding positions
    normalized_array[mask] = normalized_non_zero_elements
    
    return normalized_array
# scores = normalize_except_zeros(scores.detach().cpu().numpy())

# remove special tokens
filtered_token_scores = [(token, score) for token, score in zip(test_tokens, scores) if token not in tokenizer.all_special_tokens]

# combine subwords
merged_tokens_scores = []
current_token = ""
current_score = 0
count = 0

for token, score in filtered_token_scores:
    if token.startswith("Ġ"):
        if current_token:
            merged_tokens_scores.append((current_token, current_score / count))
            # merged_tokens_scores.append((" ", 0))  # 添加空格
        current_token = token[1:] # remove the speical character
        current_score = score
        count = 1
    elif token.endswith("Ċ"):
        if current_token:
            merged_tokens_scores.append((current_token, current_score / count))
        merged_tokens_scores.append(("<br><br>", 0))  # 添加换行符
        current_token = ""
        current_score = 0
        count = 0
    else:
        current_token += token
        current_score += score
        count += 1

if current_token:
    merged_tokens_scores.append((current_token, current_score / count))


# 根据分数高亮文本（示例中使用HTML标签）
highlighted_text = ""
for token, score in merged_tokens_scores:
    # 动态设置背景颜色：score为0时为白色，score为1时为绿色
    red = int((1 - score) * 255)
    green = 255
    blue = int((1 - score) * 255)
    color = f'rgb({red}, {green}, {blue})'
    highlighted_text += f'<span style="background-color: {color}; color: black;">{token}</span> '

# 打印高亮后的文本
from IPython.display import display, HTML
display(HTML(highlighted_text.strip()))

In [None]:
merged_tokens_scores

[('<|start_header_id|>system<|end_header_id|>', tensor(0., device='cuda:0')),
 ('<br><br>', 0),
 ('You', tensor(0., device='cuda:0')),
 ('are', tensor(0., device='cuda:0')),
 ('a', tensor(0., device='cuda:0')),
 ('chatbot', tensor(0., device='cuda:0')),
 ('for', tensor(0., device='cuda:0')),
 ('answering', tensor(0., device='cuda:0')),
 ('questions.', tensor(0., device='cuda:0')),
 ('You', tensor(0., device='cuda:0')),
 ('can', tensor(0., device='cuda:0')),
 ('help', tensor(0., device='cuda:0')),
 ('users', tensor(0., device='cuda:0')),
 ('with', tensor(0., device='cuda:0')),
 ('their', tensor(0., device='cuda:0')),
 ('questions', tensor(0., device='cuda:0')),
 ('via', tensor(0., device='cuda:0')),
 ('concise', tensor(0., device='cuda:0')),
 ('responses.<|start_header_id|>user<|end_header_id|>',
  tensor(0., device='cuda:0')),
 ('<br><br>', 0),
 ('Context:', tensor(0.5974, device='cuda:0')),
 ('Architecturally,', tensor(0.6714, device='cuda:0')),
 ('the', tensor(0.4965, device='cuda:0'

In [None]:
test_mask_prob * test_context_mask

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [None]:
expl_raw

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.     

In [None]:
mask_prob = torch.sigmoid(mask_logits)
(mask_prob * context_mask).sum(-1) / context_mask.sum(dim=-1)

tensor([0.1241, 0.2500, 0.7787, 0.0909, 0.1908, 0.1361, 0.3316, 0.5104, 0.2696,
        0.1001, 0.5931, 0.0812, 0.4139, 0.3334, 0.4853, 0.2987],
       device='cuda:0')

In [None]:
mask_prob

tensor([[9.9998e-01, 1.3412e-02, 2.1168e-05, 7.2758e-04, 9.5831e-01, 2.8712e-06,
         1.9530e-03, 6.8251e-04, 9.6814e-01, 7.6582e-04, 2.2164e-06, 1.7009e-02,
         2.4596e-09, 1.4519e-04, 3.0463e-06, 7.9092e-04, 4.0336e-06, 5.2874e-05,
         1.4755e-07, 5.3417e-09, 1.0859e-07, 8.2496e-07, 1.0534e-05, 2.1604e-06,
         6.6980e-10, 1.8266e-01, 5.4030e-07, 4.0081e-02, 3.5185e-04, 3.1933e-01,
         3.2836e-09, 1.5197e-06, 8.2054e-01, 8.0839e-01, 5.4567e-01, 4.8829e-01,
         1.0000e+00, 4.3639e-05, 3.5097e-07, 3.8482e-09, 1.9672e-04, 1.1242e-07,
         2.7982e-13, 5.4683e-09, 4.5730e-01, 9.5505e-01, 6.7815e-04, 7.4368e-02,
         9.9924e-01, 6.6332e-03, 1.0508e-08, 1.8774e-01]], device='cuda:0')

In [None]:
test_inputs['input_ids'][idx]

tensor([128000, 128006,   9125, 128007,    271,   2675,    527,    264,   6369,
          6465,    369,  27065,   6492,     13,   1472,    649,   1520,   3932,
           449,    872,   4860,   4669,  64694,  14847,    315,  27592,  45450,
            11,    477,  85165,  24093,     13, 128009, 128006,    882, 128007,
           271,   2028,   5818,    574,    279,   1888,   5818,    358,    617,
          3596,   3970,      0,   1063,  16451,   1051,  27873,     11,    719,
         15718,    574,   2294,     13, 128009, 128006,  78191, 128007,    271],
       device='cuda:0')

In [None]:
tokenizer.convert_ids_to_tokens(271)

'ĊĊ'

[0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.6432 0.8566 0.5179
 0.2417 0.     0.1211 0.3355 0.618  0.5345 0.1401 0.3401 0.4729 0.3531
 0.661  0.7049 0.0297 0.1724 0.9905 1.     0.1606 0.1107 0.2363 0.2891
 0.116  0.0777 0.     0.     0.     0.     0.     0.     0.     0.    ]


In [None]:
(torch.sigmoid(mask_logits) * context_mask)[idx]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.8527, 0.1770, 0.1719, 0.1660, 0.1613, 0.1549, 0.1581, 0.1622, 0.1666,
        0.1635, 0.1568, 0.1589, 0.1618, 0.1634, 0.1668, 0.1647, 0.1546, 0.1576,
        0.1788, 0.1779, 0.1592, 0.1581, 0.1607, 0.1619, 0.1586, 0.1583, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       device='cuda:0')

In [None]:
torch.sigmoid(mask_logits)

tensor([[0.4822, 0.4821, 0.4816,  ..., 0.4920, 0.4874, 0.4901],
        [0.4843, 0.4851, 0.4855,  ..., 0.4941, 0.4861, 0.4891],
        [0.4785, 0.4805, 0.4805,  ..., 0.4918, 0.4823, 0.4864],
        ...,
        [0.4753, 0.4758, 0.4761,  ..., 0.4847, 0.4761, 0.4802],
        [0.4876, 0.4883, 0.4882,  ..., 0.4887, 0.4880, 0.4943],
        [0.4843, 0.4851, 0.4852,  ..., 0.4946, 0.4853, 0.4947]],
       device='cuda:0')

In [None]:
mask_gen_model

MaskGeneratingModel(
  (explain_map): MLP(
    (input_layer): Linear(in_features=4096, out_features=1024, bias=True)
    (attention_layers): ModuleList(
      (0-1): 2 x MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
      )
    )
    (layers): ModuleList(
      (0-1): 2 x Sequential(
        (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (1): PReLU(num_parameters=1)
        (2): Linear(in_features=1024, out_features=1024, bias=True)
      )
    )
    (output_layer): Linear(in_features=1024, out_features=1, bias=True)
  )
)

In [None]:
print(tokens[35])
print(expl[35])

<|end_header_id|>
0.0


In [None]:
texts = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
# texts = "I really didn't like this movie. Some of the actors were good, but overall the movie was boring."
# texts = "I hate that I love you."
# texts = "I don't like this movie."
# texts = "I really love this film."
messages_lambda = lambda texts: [
    {"role": "system", "content": "You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE."},
    # {"role": "system", "content": "You are a chatbot for sentimate analysis."},
    {"role": "user", "content": texts},
]
messages = messages_lambda(texts)
messages_with_template_applied = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
print(messages_with_template_applied)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE.<|eot_id|><|start_header_id|>user<|end_header_id|>

This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


