In [1]:
import os
import json 
import logging

logging.basicConfig(
    filename='log/app.log',            # Specify the log file name
    level=logging.DEBUG,           # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format='%(asctime)s - %(levelname)s - %(message)s'  # Set the log format
)

# Load the environment configuration JSON data
json_path = 'env_config.json'
with open(json_path, 'r') as file:
    env_config = json.load(file)

hf_home = env_config['HF_HOME']
# Set the HF_HOME environment variable
os.environ['HF_HOME'] = hf_home
# Set the access token to huggingface hub
access_token = env_config['access_token']
os.environ['HUGGINGFACE_HUB_TOKEN'] = access_token

In [2]:
import transformers 
print(transformers.__version__)

from transformers import pipeline
import torch

from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import LlamaTokenizerFast
import torch.nn.functional as F


accelerator = Accelerator()
device = accelerator.device

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_id = "meta-llama/Meta-Llama-3-8B"  # non-instruct version

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=access_token,
)

tokenizer = LlamaTokenizerFast.from_pretrained(model_id, token=access_token)

  from .autonotebook import tqdm as notebook_tqdm


4.41.0


Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.72s/it]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'PreTrainedTokenizerFast'. 
The class this function is called from is 'LlamaTokenizerFast'.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
from datasets import load_dataset
from torch.utils.data import DataLoader
# from llmexp.helper import LlmExpHelper
class LlmExpHelper:
    def __init__(self, tokenizer, model, device):
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
    
    def get_collate_fun(self):
        return lambda examples: self.collate_fn(examples)

    def collate_fn(self, examples):
        def num_words(x):
            return len(x.split())
        def get_first_k_words(x, k):
            return ' '.join(x.split()[:k])
        tokenizer = self.tokenizer
        max_len = 256 # characters limit other than token limit
        # texts = [example['sentence'] for example in examples]
        texts = [example['text'] for example in examples]
        texts = [text if num_words(text) <= max_len else get_first_k_words(text, max_len) for text in texts]
        labels = [example['label'] for example in examples]
        messages_lambda = lambda texts: [
            {"role": "system", "content": "You are a chatbot for sentiment analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE."},
            # {"role": "system", "content": "You are a chatbot for sentimate analysis."},
            {"role": "user", "content": texts},
        ]
        messages = list(map(messages_lambda, texts))

        messages_with_template_applied = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        batch = tokenizer(
                    messages_with_template_applied,
                    add_special_tokens=False,
                    padding=True,
                    return_tensors="pt",
                    )
        
        # find the template boundaries
        text_lens = [len(tokenizer.encode(text)) - 1 for text in texts] # note that the tokenizer.encode adds the eos token
        text_lens_tensor = torch.tensor(text_lens, dtype=torch.long)
        
        def apply_mask(mask_tensor, text_lens_tensor):
            batch_size, seq_len = mask_tensor.shape
            for i in range(batch_size):
                text_len = text_lens_tensor[i].item()
                mask_tensor[i, -text_len-5:-5] = 0
            return 1- mask_tensor

        mask_tensor = apply_mask(torch.ones_like(batch['input_ids']), text_lens_tensor)

        batch['context_mask'] = mask_tensor
        
        return batch

imdb = load_dataset("imdb")
train_ds = imdb['train']
# from datasets import load_dataset

# ds = load_dataset("stanfordnlp/sst2")
# train_ds = ds['train']
llm_exp_helper = LlmExpHelper(tokenizer, model, device)
collate_fn = llm_exp_helper.get_collate_fun()

# Define batch size here!
batch_size = 16
train_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)

In [4]:
train_ds[1]

{'text': '"I Am Curious: Yellow" is a risible and pretentious steaming pile. It doesn\'t matter what one\'s political views are because this film can hardly be taken seriously on any level. As for the claim that frontal male nudity is an automatic NC-17, that isn\'t true. I\'ve seen R-rated films with male nudity. Granted, they only offer some fleeting views, but where are the R-rated films with gaping vulvas and flapping labia? Nowhere, because they don\'t exist. The same goes for those crappy cable shows: schlongs swinging in the breeze but not a clitoris in sight. And those pretentious indie movies like The Brown Bunny, in which we\'re treated to the site of Vincent Gallo\'s throbbing johnson, but not a trace of pink visible on Chloe Sevigny. Before crying (or implying) "double-standard" in matters of nudity, the mentally obtuse should take into account one unavoidably obvious anatomical difference between men and women: there are no genitals on display when actresses appears nude, 

In [5]:
from llmexp.model4 import MaskGeneratingModel

mask_gen_model = MaskGeneratingModel(hidden_size=4096, mlp_hidden_dim=1024, mlp_bottleneck_dim=768, mlp_num_blocks=2)
mask_gen_model.to(device)


from tqdm import tqdm

# Set pad_token_id if it is not set
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

pad_token_id = tokenizer.pad_token_id

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

optimizer = torch.optim.Adam(mask_gen_model.parameters(), lr=1e-5)

In [6]:
mask_gen_model.train()
for epoch in range(1):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        context_mask = data['context_mask'].to(device)
        # get generated texts
        gen_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            return_dict_in_generate=True,
            output_scores=True,
        )
        gen_tokens = gen_outputs.sequences
        pad_length = gen_tokens.size(1) - input_ids.size(1)
        # get the attention mask for the generated tokens, and also mask the padding tokens
        gen_attention_mask = F.pad(attention_mask, (0, pad_length), mode='constant', value=1)
        # (gen_tokens != pad_token_id).long() is the tokens mask, 1 for real tokens and 0 for padding tokens
        unpaded_token_mask = (gen_tokens != pad_token_id).long()
        unpaded_token_mask[:, :-pad_length] = 1
        gen_attention_mask = gen_attention_mask * unpaded_token_mask
        # get the response mask, which is the mask for the generated tokens (the user inputs are masked with 0)
        response_mask = gen_attention_mask.clone()
        response_mask[:, :-pad_length] = 0 # TODO: 有问题. 有问题吗？

        context_mask = F.pad(context_mask, (0, pad_length), mode='constant', value=0)

        # Get the last hidden state for the prompt + response sequence
        with torch.no_grad():
            full_outputs = model(input_ids=gen_tokens, attention_mask=gen_attention_mask, output_hidden_states=True, return_dict=True)
            last_hidden_state = full_outputs.hidden_states[-1]
            last_hidden_state = last_hidden_state.float()
        
        mask_logits = mask_gen_model(last_hidden_state)

        mask_gen_outputs = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits, response_mask, 
                                                                           num_samples=1)
        loss, reward_loss, mask_loss, mask_mean, mean_reward = mask_gen_outputs['loss'], mask_gen_outputs['reward_loss'], mask_gen_outputs['mask_loss'], mask_gen_outputs['mask_mean'], mask_gen_outputs['mean_reward']
        log = (f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss.item():.4f}, " 
                             f"Reward Loss = {reward_loss.item():.4f}, "
                             f"Mean Reward = {mean_reward.item():.4f},"
                             f"Mask_loss = {mask_loss.item():.4f} "
                             f"mask_mean = {mask_mean.item():.4f}"
        )
        pbar.set_description(log)
        logging.debug(log)
        # Train the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #                      )
        # if idx % 10 == 0:
        #     print()
        if idx % 200 == 0 and idx != 0:
            torch.save(mask_gen_model.state_dict(), f'saved_model/mask_gen_model_{epoch}_{idx}.pth') 
            print()
            # break

  0%|          | 0/1563 [00:00<?, ?it/s]

tensor([ 52,  81,  30,  40,  26,  94,  92, 102,  58,  55,  53, 105,  42,  45,
         95,  63], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 1: Loss = 0.5636, Reward Loss = 0.5636, Mean Reward = 0.9149,Mask_loss = 0.5800 mask_mean = 0.5800:   0%|          | 1/1563 [00:01<49:48,  1.91s/it]

tensor([ 86,  51,  23,  32,  42,  26,  60, 102,  93,  52,  44, 101,  88,  51,
         99, 105], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 2: Loss = 0.4318, Reward Loss = 0.4318, Mean Reward = 0.8772,Mask_loss = 0.1782 mask_mean = 0.1782:   0%|          | 2/1563 [00:03<44:46,  1.72s/it]

tensor([ 45,  99,  66,  66,  22,  90,  89,  96,  42,  29, 102,  93,  46, 105,
         89,  96], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 3: Loss = 0.5403, Reward Loss = 0.5403, Mean Reward = 0.9548,Mask_loss = 0.1163 mask_mean = 0.1163:   0%|          | 3/1563 [00:05<43:19,  1.67s/it]

tensor([102,  42,  42,  48,  51,  98,  92,  96,  40,  59,  53,  46, 108,  97,
         46,  71], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 4: Loss = 0.5120, Reward Loss = 0.5120, Mean Reward = 0.9515,Mask_loss = 0.1285 mask_mean = 0.1285:   0%|          | 4/1563 [00:06<42:55,  1.65s/it]

tensor([69, 50,  4, 46, 50, 63, 57, 31, 36, 44, 52, 49, 57, 78, 33, 49],
       device='cuda:0', dtype=torch.int32)


Epoch 1, Step 5: Loss = 0.3781, Reward Loss = 0.3781, Mean Reward = 0.8423,Mask_loss = 0.1867 mask_mean = 0.1867:   0%|          | 5/1563 [00:07<39:14,  1.51s/it]

tensor([ 52,  95,  39, 104,  77,  78,  86,  94, 109,  51,  61,  45, 103,  97,
         45,  64], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 6: Loss = 0.3466, Reward Loss = 0.3466, Mean Reward = 0.8265,Mask_loss = 0.3812 mask_mean = 0.3812:   0%|          | 6/1563 [00:09<40:18,  1.55s/it]

tensor([101,  57,  68,  94,  45,  68,  32, 101,  47,  81,  38,  93,  99,  96,
         57,  89], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 7: Loss = 0.4073, Reward Loss = 0.4073, Mean Reward = 0.8694,Mask_loss = 0.4888 mask_mean = 0.4888:   0%|          | 7/1563 [00:11<39:54,  1.54s/it]

tensor([ 74,  99,  92,  36, 108, 102,  21,  69,  24,  96,  15,  97,  47,  64,
         97,  32], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 8: Loss = 0.3257, Reward Loss = 0.3257, Mean Reward = 0.7201,Mask_loss = 0.4786 mask_mean = 0.4786:   1%|          | 8/1563 [00:12<40:43,  1.57s/it]

tensor([ 49,  46,  27,  95,  94,  86,  90,  42,  60,  64,  86,  70,  99,  77,
        106,  76], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 9: Loss = 0.3296, Reward Loss = 0.3296, Mean Reward = 0.8463,Mask_loss = 0.4308 mask_mean = 0.4308:   1%|          | 9/1563 [00:14<41:03,  1.59s/it]

tensor([42, 50, 91, 94, 49, 29, 51, 33, 95, 48, 19, 35, 91, 49, 56, 65],
       device='cuda:0', dtype=torch.int32)


Epoch 1, Step 10: Loss = 0.2991, Reward Loss = 0.2991, Mean Reward = 0.8506,Mask_loss = 0.3868 mask_mean = 0.3868:   1%|          | 10/1563 [00:15<40:12,  1.55s/it]

tensor([ 62,  24,  87,  92,  53,  83,  39,  45,  63,  89, 113,  68,  37,  24,
         87,  95], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 11: Loss = 0.2635, Reward Loss = 0.2635, Mean Reward = 0.8151,Mask_loss = 0.3082 mask_mean = 0.3082:   1%|          | 11/1563 [00:17<41:21,  1.60s/it]

tensor([ 32,  14,  54,  21,  21,  67,  98,  69,  97,  73,  93,  55,  72, 108,
        103,  96], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 12: Loss = 0.2536, Reward Loss = 0.2536, Mean Reward = 0.8240,Mask_loss = 0.2745 mask_mean = 0.2745:   1%|          | 12/1563 [00:19<41:33,  1.61s/it]

tensor([ 57,  90,  54,  48,  45,  51,  72,  93,  72,  78,  22,  33,  28, 105,
         79,  56], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 13: Loss = 0.2158, Reward Loss = 0.2158, Mean Reward = 0.7484,Mask_loss = 0.2371 mask_mean = 0.2371:   1%|          | 13/1563 [00:20<41:29,  1.61s/it]

tensor([ 17,  53,  55,  17,  54,  61, 103,  53,  96,  49,  91,  45,  91,  51,
         79,  64], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 14: Loss = 0.2060, Reward Loss = 0.2060, Mean Reward = 0.6676,Mask_loss = 0.2021 mask_mean = 0.2021:   1%|          | 14/1563 [00:22<41:17,  1.60s/it]

tensor([ 61, 104,  44,  89,  59,  51,  71,  51,  48,  94,  10,  37,  25,  47,
         28,  87], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 15: Loss = 0.2163, Reward Loss = 0.2163, Mean Reward = 0.8197,Mask_loss = 0.2332 mask_mean = 0.2332:   1%|          | 15/1563 [00:23<41:12,  1.60s/it]

tensor([ 58,  40,  50,  76,  81,  24,  90, 100,  67,  84,  84,  87,  63,  47,
         96,  88], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 16: Loss = 0.2017, Reward Loss = 0.2017, Mean Reward = 0.8390,Mask_loss = 0.2247 mask_mean = 0.2247:   1%|          | 16/1563 [00:25<40:28,  1.57s/it]

tensor([ 97,  42,  84,  75,  53,  66,  54,  67,  98,  99,  51,  45,  90,  52,
         58, 105], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 17: Loss = 0.1681, Reward Loss = 0.1681, Mean Reward = 0.6952,Mask_loss = 0.2389 mask_mean = 0.2389:   1%|          | 17/1563 [00:30<1:03:36,  2.47s/it]

tensor([46, 96, 77, 62, 80, 91, 16, 95, 68, 12, 47, 52, 45, 45, 93, 67],
       device='cuda:0', dtype=torch.int32)


Epoch 1, Step 18: Loss = 0.1680, Reward Loss = 0.1680, Mean Reward = 0.8282,Mask_loss = 0.2939 mask_mean = 0.2939:   1%|          | 18/1563 [00:31<55:54,  2.17s/it]  

tensor([ 64,  29,  46,  97,  59,  35,  56,  16,  24,  99,  50, 100,  45,  68,
         36,  51], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 19: Loss = 0.1809, Reward Loss = 0.1809, Mean Reward = 0.8868,Mask_loss = 0.3355 mask_mean = 0.3355:   1%|          | 19/1563 [00:33<50:43,  1.97s/it]

tensor([ 26,  41,  60,  49,  23,  33,  98,  53, 106,  55,  51,  64,  57,  47,
         78, 103], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 20: Loss = 0.1722, Reward Loss = 0.1722, Mean Reward = 0.9795,Mask_loss = 0.3380 mask_mean = 0.3380:   1%|▏         | 20/1563 [00:34<47:57,  1.86s/it]

tensor([100,  19,  65,  61,  99,  52,  98,  89,  44,  24,  94,  98,  62,  61,
         85,  85], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 21: Loss = 0.1473, Reward Loss = 0.1473, Mean Reward = 0.8420,Mask_loss = 0.3590 mask_mean = 0.3590:   1%|▏         | 21/1563 [00:36<47:16,  1.84s/it]

tensor([ 41,  86, 102,  96, 114,  57,  23,  51,  92,  15,  88,  84,  71,  92,
         74,  67], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 22: Loss = 0.1485, Reward Loss = 0.1485, Mean Reward = 0.8366,Mask_loss = 0.3593 mask_mean = 0.3593:   1%|▏         | 22/1563 [00:38<46:17,  1.80s/it]

tensor([55, 45, 94, 66, 66, 91, 61, 30, 92, 20, 48, 93, 75, 95, 25, 99],
       device='cuda:0', dtype=torch.int32)


Epoch 1, Step 23: Loss = 0.1234, Reward Loss = 0.1234, Mean Reward = 0.8674,Mask_loss = 0.3280 mask_mean = 0.3280:   1%|▏         | 23/1563 [00:39<44:00,  1.71s/it]

tensor([ 72,  98,  24,  54,  43,  84,  15,  91,  51, 104,  60,  99,  85,  88,
        104,  42], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 24: Loss = 0.1214, Reward Loss = 0.1214, Mean Reward = 0.8540,Mask_loss = 0.3343 mask_mean = 0.3343:   2%|▏         | 24/1563 [00:41<43:06,  1.68s/it]

tensor([ 95, 107,  41,  71, 101,  69, 104,  92,  96,  78,  94,  17,  43,  15,
        100,  11], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 25: Loss = 0.1032, Reward Loss = 0.1032, Mean Reward = 0.8923,Mask_loss = 0.3080 mask_mean = 0.3080:   2%|▏         | 25/1563 [00:42<42:48,  1.67s/it]

tensor([ 95,  43,  48,  57,  69,  94, 109,  61,  95,  42,  98,  93,  75,  92,
         86,  29], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 26: Loss = 0.0991, Reward Loss = 0.0991, Mean Reward = 0.9124,Mask_loss = 0.3007 mask_mean = 0.3007:   2%|▏         | 26/1563 [00:44<42:38,  1.66s/it]

tensor([93, 40, 27, 47, 17, 83, 58, 29, 57, 92, 55, 61, 24, 95, 56, 46],
       device='cuda:0', dtype=torch.int32)


Epoch 1, Step 27: Loss = 0.1010, Reward Loss = 0.1010, Mean Reward = 0.7861,Mask_loss = 0.2830 mask_mean = 0.2830:   2%|▏         | 27/1563 [00:46<44:12,  1.73s/it]

tensor([ 44,  76, 101,  44,  63,  93, 108,  68,  81,  16,  50,  90,  67,  55,
         99,  97], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 28: Loss = 0.1111, Reward Loss = 0.1111, Mean Reward = 0.8433,Mask_loss = 0.2639 mask_mean = 0.2639:   2%|▏         | 28/1563 [00:48<43:30,  1.70s/it]

tensor([ 80,  90, 104,  49,  56,  53, 105, 110,  15,  78,  55,  59,  65,  42,
         37,  53], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 29: Loss = 0.1285, Reward Loss = 0.1285, Mean Reward = 0.9934,Mask_loss = 0.2551 mask_mean = 0.2551:   2%|▏         | 29/1563 [00:49<43:08,  1.69s/it]

tensor([95, 25, 84, 92, 45, 92, 96, 25, 96, 39, 84, 45, 40, 63, 78, 28],
       device='cuda:0', dtype=torch.int32)


Epoch 1, Step 30: Loss = 0.0756, Reward Loss = 0.0756, Mean Reward = 0.7231,Mask_loss = 0.2763 mask_mean = 0.2763:   2%|▏         | 30/1563 [00:51<41:33,  1.63s/it]

tensor([ 86,  99,  48,  43,  87,  76, 105,  95,  50,  38,  39,  41,  55,  51,
        103,  50], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 31: Loss = 0.1128, Reward Loss = 0.1128, Mean Reward = 0.9290,Mask_loss = 0.2721 mask_mean = 0.2721:   2%|▏         | 31/1563 [00:52<41:28,  1.62s/it]

tensor([105,  45,  95, 100,  92,  97, 101,  21,  51,  82,  70,  81,  35,  95,
         46,  21], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 32: Loss = 0.0835, Reward Loss = 0.0835, Mean Reward = 0.8282,Mask_loss = 0.2787 mask_mean = 0.2787:   2%|▏         | 32/1563 [00:54<41:22,  1.62s/it]

tensor([ 63,  64,  56,  90,  64,  30,  47,  51,  59,  54, 102,  23,  94, 101,
         88, 101], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 33: Loss = 0.0669, Reward Loss = 0.0669, Mean Reward = 0.7684,Mask_loss = 0.3037 mask_mean = 0.3037:   2%|▏         | 33/1563 [00:56<42:47,  1.68s/it]

tensor([ 48,  54,  89, 106,  44,  93,  41,  77,  48,  98,  42,  65, 109,  66,
         92,  44], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 34: Loss = 0.0751, Reward Loss = 0.0751, Mean Reward = 0.9185,Mask_loss = 0.2934 mask_mean = 0.2934:   2%|▏         | 34/1563 [00:57<42:34,  1.67s/it]

tensor([ 97,  34,  46,  63,  52,  96,  61,  96,  84,  48,  60,  42,  18,  39,
        104,  50], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 35: Loss = 0.0828, Reward Loss = 0.0828, Mean Reward = 0.8540,Mask_loss = 0.3183 mask_mean = 0.3183:   2%|▏         | 35/1563 [00:59<42:09,  1.66s/it]

tensor([ 47,  67,  39,  48,  55,  97,  90, 103,  99,  42,  72,  60,  37,  51,
         43,  78], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 36: Loss = 0.0664, Reward Loss = 0.0664, Mean Reward = 0.7716,Mask_loss = 0.3055 mask_mean = 0.3055:   2%|▏         | 36/1563 [01:01<41:47,  1.64s/it]

tensor([ 55,  60,  57,  47, 102,  95, 109,  48,  17,  74,  57,  24,  80,  46,
         99,  52], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 37: Loss = 0.0681, Reward Loss = 0.0681, Mean Reward = 0.7679,Mask_loss = 0.2997 mask_mean = 0.2997:   2%|▏         | 37/1563 [01:02<41:51,  1.65s/it]

tensor([83, 76, 47, 54, 46, 96, 95, 81, 94, 54, 28, 80, 90, 46, 62, 54],
       device='cuda:0', dtype=torch.int32)


Epoch 1, Step 38: Loss = 0.1048, Reward Loss = 0.1048, Mean Reward = 0.9832,Mask_loss = 0.3362 mask_mean = 0.3362:   2%|▏         | 38/1563 [01:04<40:32,  1.60s/it]

tensor([27, 42, 44, 62, 22, 43, 96, 34, 38, 35, 78, 34,  9, 18, 74, 47],
       device='cuda:0', dtype=torch.int32)


Epoch 1, Step 39: Loss = 0.0849, Reward Loss = 0.0849, Mean Reward = 0.9428,Mask_loss = 0.2996 mask_mean = 0.2996:   2%|▏         | 39/1563 [01:05<39:45,  1.57s/it]

tensor([ 48,  15,  29,  91,  45, 102,  15, 102,  86,  51,  75,  46,  83,  29,
         88,  91], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 40: Loss = 0.0789, Reward Loss = 0.0789, Mean Reward = 0.9042,Mask_loss = 0.3168 mask_mean = 0.3168:   3%|▎         | 40/1563 [01:07<40:03,  1.58s/it]

tensor([66, 61, 93, 32, 57, 62, 64, 94, 53, 48, 75, 69, 94, 50, 42, 82],
       device='cuda:0', dtype=torch.int32)


Epoch 1, Step 41: Loss = 0.0471, Reward Loss = 0.0471, Mean Reward = 0.6835,Mask_loss = 0.3086 mask_mean = 0.3086:   3%|▎         | 41/1563 [01:08<39:10,  1.54s/it]

tensor([ 88,  54,  45,  57,  96,  48,  59,  45,  86,  50,  67, 103, 100,  63,
         82,  75], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 42: Loss = 0.0511, Reward Loss = 0.0511, Mean Reward = 0.6714,Mask_loss = 0.3030 mask_mean = 0.3030:   3%|▎         | 42/1563 [01:10<39:40,  1.57s/it]

tensor([ 81,  44,  96,  22,  31,  50,  98,  35,  94,  50,  52, 105,  59,  61,
         81,  96], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 43: Loss = 0.0607, Reward Loss = 0.0607, Mean Reward = 0.7981,Mask_loss = 0.2883 mask_mean = 0.2883:   3%|▎         | 43/1563 [01:12<39:59,  1.58s/it]

tensor([ 96,  89,  24,  19,  52,  77,  51,  56,  97,  49,  51, 115,  23,  61,
         47,  99], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 44: Loss = 0.0538, Reward Loss = 0.0538, Mean Reward = 0.8286,Mask_loss = 0.2924 mask_mean = 0.2924:   3%|▎         | 44/1563 [01:13<41:10,  1.63s/it]

tensor([ 46,  95,  45,  16, 104,  67,  99,  52,  63,  27,  42,  39,  45,  47,
         69,  64], device='cuda:0', dtype=torch.int32)


Epoch 1, Step 45: Loss = 0.0775, Reward Loss = 0.0775, Mean Reward = 0.9086,Mask_loss = 0.2698 mask_mean = 0.2698:   3%|▎         | 45/1563 [01:16<42:52,  1.69s/it]


KeyboardInterrupt: 

In [None]:
loss

tensor(1.8184, device='cuda:0', grad_fn=<MeanBackward0>)

In [9]:
import numpy as np
idx = 0
from captum.attr import visualization as viz
import torch.nn.functional as F

mask_gen_model.eval()

# tokens = tokenizer.convert_ids_to_tokens(gen_tokens[idx])
texts = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
# texts = "I really did not like this movie. Some of the actors were good, but overall the movie was boring."
# texts = "I hate that I love you."
# texts = "I don't like this movie."
# texts = "I really love this film."
# texts = "I really love this film. The acting was great, and the story was amazing. I would recommend this movie to everyone."
# texts = "I don't like this movie. The acting was terrible, and the story was boring. I would not recommend this movie to anyone."
messages_lambda = lambda texts: [
            {"role": "system", "content": "You are a chatbot for sentiment analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE."},
            # {"role": "system", "content": "You are a chatbot for sentimate analysis."},
            {"role": "user", "content": texts},
        ]
messages = messages_lambda(texts)
messages_with_template_applied = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

test_text = [{"text": texts, "label": None}]
test_inputs = collate_fn(test_text).to(device)

# test_inputs = next(iter(train_dataloader)).to(device)
tokens = tokenizer.convert_ids_to_tokens(test_inputs['input_ids'][idx])

# generate the answer for the test inputs
gen_outputs = model.generate(
            input_ids=test_inputs['input_ids'],
            attention_mask=test_inputs['attention_mask'],
            max_new_tokens=128,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            return_dict_in_generate=True,
            output_scores=True,
        )
input_ids = test_inputs['input_ids']
attention_mask = test_inputs['attention_mask']
gen_tokens = gen_outputs.sequences
pad_length = gen_tokens.size(1) - input_ids.size(1)
# get the attention mask for the generated tokens, and also mask the padding tokens
gen_attention_mask = F.pad(attention_mask, (0, pad_length), mode='constant', value=1)
context_mask = F.pad(test_inputs['context_mask'], (0, pad_length), mode='constant', value=0)
# (gen_tokens != pad_token_id).long() is the tokens mask, 1 for real tokens and 0 for padding tokens
unpaded_token_mask = (gen_tokens != pad_token_id).long()
unpaded_token_mask[:, :-pad_length] = 1
gen_attention_mask = gen_attention_mask * unpaded_token_mask

with torch.no_grad():
    # prompt_outputs = model(input_ids=test_inputs['input_ids'], attention_mask=test_inputs['attention_mask'], output_hidden_states=True, return_dict=True)
    prompt_outputs = model(input_ids=gen_tokens, attention_mask=gen_attention_mask, output_hidden_states=True, return_dict=True)

    last_hidden_state = prompt_outputs.hidden_states[-1].float()
    mask_logits = mask_gen_model(last_hidden_state)

# Function to clean tokens
def clean_token(token):
    # Remove special characters like "Ġ" or "Ċ"
    return token.replace("Ġ", "").replace("Ċ", "").replace("Ġ", "").replace("Ċ", "").replace("Ġ", "")

# Apply cleaning to each token
tokens = tokenizer.convert_ids_to_tokens(gen_tokens[idx])
tokens = [clean_token(token) for token in tokens]

def normalize_except_zeros(array):
    # Create a mask to identify non-zero elements
    mask = array != 0
    
    # Extract non-zero elements
    non_zero_elements = array[mask]
    
    # Normalize non-zero elements
    min_val = np.min(non_zero_elements)
    max_val = np.max(non_zero_elements)
    # normalized_non_zero_elements = (non_zero_elements - min_val) / (max_val - min_val)
    # get the 50% persentile
    # min_val = np.percentile(non_zero_elements, 50)
    normalized_non_zero_elements = (non_zero_elements - min_val) / (max_val - min_val)
    
    # Create a copy of the original array to preserve zero values
    normalized_array = np.copy(array)
    
    # Assign normalized values back to the corresponding positions
    normalized_array[mask] = normalized_non_zero_elements
    
    return normalized_array

expl = (torch.sigmoid(mask_logits) * context_mask)[idx]
expl = F.pad(expl, (0, len(tokens) - expl.size(0)), mode='constant', value=0)
expl_raw = expl.detach().cpu().numpy()
# expl = (expl - 0) / (expl_raw.max(axis=-1) - 0)

# expl = expl_raw
expl = normalize_except_zeros(expl_raw)



# expl = response_mask[idx]


vis_data_records = [viz.VisualizationDataRecord(
                                expl,
                                0,
                                0,
                                0,
                                0,
                                1,       
                                tokens,
                                1)]
                            
viz.visualize_text(vis_data_records)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,1.0,"#|begin_of_text| #|start_header_id| system #|end_header_id| You are a chat bot for sentiment analysis . You can help users with their questions via concise responses of POS ITIVE , or NEG ATIVE . #|eot_id| #|start_header_id| user #|end_header_id| This movie was the best movie I have ever seen ! some scenes were ridiculous , but acting was great . #|eot_id| #|start_header_id| assistant #|end_header_id| POS ITIVE #|eot_id|"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,1.0,"#|begin_of_text| #|start_header_id| system #|end_header_id| You are a chat bot for sentiment analysis . You can help users with their questions via concise responses of POS ITIVE , or NEG ATIVE . #|eot_id| #|start_header_id| user #|end_header_id| This movie was the best movie I have ever seen ! some scenes were ridiculous , but acting was great . #|eot_id| #|start_header_id| assistant #|end_header_id| POS ITIVE #|eot_id|"
,,,,


In [None]:
# Set print options for precision
torch.set_printoptions(precision=4)
np.set_printoptions(precision=4)
print(expl)

[0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.0279 0.8979 0.1684
 1.     0.3327 0.9915 0.9641 0.3456 0.     0.2309 0.3376 0.2255 0.329
 0.265  0.2508 0.3121 0.7915 0.5378 0.0591 0.1465 0.0985 0.5066 0.0857
 0.387  0.8436 0.     0.     0.     0.     0.     0.     0.     0.    ]


In [8]:
expl_raw

array([0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 1.9358044e-03, 8.5238426e-04, 4.9257861e-03,
       3.3561827e-04, 9.9662030e-01, 9.9794489e-01, 1.6490246e-04,
       3.7765391e-02, 4.5216787e-03, 9.9155957e-01, 2.1190707e-02,
       4.3913804e-02, 2.5236185e-04, 2.5347209e-02, 2.0404863e-03,
       9.9968982e-01, 9.9894804e-01, 1.6779628e-03, 6.8065450e-03,
       9.7157842e-01, 1.0806271e-01, 9.6391785e-01, 3.8870864e

In [None]:
mask_prob = torch.sigmoid(mask_logits)
(mask_prob * context_mask).sum(-1) / context_mask.sum(dim=-1)

tensor([0.1241, 0.2500, 0.7787, 0.0909, 0.1908, 0.1361, 0.3316, 0.5104, 0.2696,
        0.1001, 0.5931, 0.0812, 0.4139, 0.3334, 0.4853, 0.2987],
       device='cuda:0')

In [None]:
mask_prob

tensor([[9.9998e-01, 1.3412e-02, 2.1168e-05, 7.2758e-04, 9.5831e-01, 2.8712e-06,
         1.9530e-03, 6.8251e-04, 9.6814e-01, 7.6582e-04, 2.2164e-06, 1.7009e-02,
         2.4596e-09, 1.4519e-04, 3.0463e-06, 7.9092e-04, 4.0336e-06, 5.2874e-05,
         1.4755e-07, 5.3417e-09, 1.0859e-07, 8.2496e-07, 1.0534e-05, 2.1604e-06,
         6.6980e-10, 1.8266e-01, 5.4030e-07, 4.0081e-02, 3.5185e-04, 3.1933e-01,
         3.2836e-09, 1.5197e-06, 8.2054e-01, 8.0839e-01, 5.4567e-01, 4.8829e-01,
         1.0000e+00, 4.3639e-05, 3.5097e-07, 3.8482e-09, 1.9672e-04, 1.1242e-07,
         2.7982e-13, 5.4683e-09, 4.5730e-01, 9.5505e-01, 6.7815e-04, 7.4368e-02,
         9.9924e-01, 6.6332e-03, 1.0508e-08, 1.8774e-01]], device='cuda:0')

In [None]:
test_inputs['input_ids'][idx]

tensor([128000, 128006,   9125, 128007,    271,   2675,    527,    264,   6369,
          6465,    369,  27065,   6492,     13,   1472,    649,   1520,   3932,
           449,    872,   4860,   4669,  64694,  14847,    315,  27592,  45450,
            11,    477,  85165,  24093,     13, 128009, 128006,    882, 128007,
           271,   2028,   5818,    574,    279,   1888,   5818,    358,    617,
          3596,   3970,      0,   1063,  16451,   1051,  27873,     11,    719,
         15718,    574,   2294,     13, 128009, 128006,  78191, 128007,    271],
       device='cuda:0')

In [None]:
tokenizer.convert_ids_to_tokens(271)

'ĊĊ'

[0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.6432 0.8566 0.5179
 0.2417 0.     0.1211 0.3355 0.618  0.5345 0.1401 0.3401 0.4729 0.3531
 0.661  0.7049 0.0297 0.1724 0.9905 1.     0.1606 0.1107 0.2363 0.2891
 0.116  0.0777 0.     0.     0.     0.     0.     0.     0.     0.    ]


In [None]:
(torch.sigmoid(mask_logits) * context_mask)[idx]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.8527, 0.1770, 0.1719, 0.1660, 0.1613, 0.1549, 0.1581, 0.1622, 0.1666,
        0.1635, 0.1568, 0.1589, 0.1618, 0.1634, 0.1668, 0.1647, 0.1546, 0.1576,
        0.1788, 0.1779, 0.1592, 0.1581, 0.1607, 0.1619, 0.1586, 0.1583, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       device='cuda:0')

In [None]:
torch.sigmoid(mask_logits)

tensor([[0.4822, 0.4821, 0.4816,  ..., 0.4920, 0.4874, 0.4901],
        [0.4843, 0.4851, 0.4855,  ..., 0.4941, 0.4861, 0.4891],
        [0.4785, 0.4805, 0.4805,  ..., 0.4918, 0.4823, 0.4864],
        ...,
        [0.4753, 0.4758, 0.4761,  ..., 0.4847, 0.4761, 0.4802],
        [0.4876, 0.4883, 0.4882,  ..., 0.4887, 0.4880, 0.4943],
        [0.4843, 0.4851, 0.4852,  ..., 0.4946, 0.4853, 0.4947]],
       device='cuda:0')

In [None]:
mask_gen_model

MaskGeneratingModel(
  (explain_map): MLP(
    (input_layer): Linear(in_features=4096, out_features=1024, bias=True)
    (attention_layers): ModuleList(
      (0-1): 2 x MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
      )
    )
    (layers): ModuleList(
      (0-1): 2 x Sequential(
        (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (1): PReLU(num_parameters=1)
        (2): Linear(in_features=1024, out_features=1024, bias=True)
      )
    )
    (output_layer): Linear(in_features=1024, out_features=1, bias=True)
  )
)

In [None]:
print(tokens[35])
print(expl[35])

<|end_header_id|>
0.0


In [None]:
texts = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
# texts = "I really didn't like this movie. Some of the actors were good, but overall the movie was boring."
# texts = "I hate that I love you."
# texts = "I don't like this movie."
# texts = "I really love this film."
messages_lambda = lambda texts: [
    {"role": "system", "content": "You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE."},
    # {"role": "system", "content": "You are a chatbot for sentimate analysis."},
    {"role": "user", "content": texts},
]
messages = messages_lambda(texts)
messages_with_template_applied = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
print(messages_with_template_applied)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE.<|eot_id|><|start_header_id|>user<|end_header_id|>

This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


