# Configure Environment

In [1]:
import os
import json 

# Load the environment configuration JSON data
json_path = 'env_config.json'
with open(json_path, 'r') as file:
    env_config = json.load(file)

hf_home = env_config['HF_HOME']
# Set the HF_HOME environment variable
os.environ['HF_HOME'] = hf_home
# Set the access token to huggingface hub
access_token = env_config['access_token']
os.environ['HUGGINGFACE_HUB_TOKEN'] = access_token

# Define Llama model and tokenizers
also define device via accelerator

In [2]:
import transformers 
print(transformers.__version__)

from transformers import pipeline
import torch

from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import LlamaTokenizerFast


accelerator = Accelerator()
device = accelerator.device

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_id = "meta-llama/Meta-Llama-3-8B"  # non-instruct version

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=access_token,
)

tokenizer = LlamaTokenizerFast.from_pretrained(model_id, token=access_token)

  from .autonotebook import tqdm as notebook_tqdm


4.41.0


Loading checkpoint shards: 100%|██████████| 4/4 [01:27<00:00, 21.80s/it]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'PreTrainedTokenizerFast'. 
The class this function is called from is 'LlamaTokenizerFast'.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Test with dummy inputs

In [3]:
# text = r"""I love this movie. The sentiment of this reivew is POSITIVE.
# I found this movie interesting. The sentiment of this reivew is POSITIVE.
# I found this movie boring. The sentiment of this reivew is NEGATIVE.
# I hate this movie. The sentiment of this review is"""

messages = [
    {"role": "system", "content": "You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, NEGATIVE OR NEUTURAL."},
    {"role": "user", "content": "I hate this movie."},
]

messages_with_template_applied = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

inputs = tokenizer.encode_plus(
            messages_with_template_applied,
            add_special_tokens=False,
            return_tensors="pt",
            ).to(device)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    return_dict_in_generate=True,
    output_scores=True,
)

print("input_ids.shape", inputs.input_ids.shape)
print("input_ids", inputs.input_ids)
print("input tokens\n", tokenizer.decode(inputs.input_ids[0], skip_special_tokens=False))

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


input_ids.shape torch.Size([1, 51])
input_ids tensor([[128000, 128006,   9125, 128007,    271,   2675,    527,    264,   6369,
           6465,    369,   3288,   3509,   6492,     13,   1472,    649,   1520,
           3932,    449,    872,   4860,   4669,  64694,  14847,    315,  27592,
          45450,     11,  85165,  24093,   2794,   8014,   1406,  51785,     13,
         128009, 128006,    882, 128007,    271,     40,  12491,    420,   5818,
             13, 128009, 128006,  78191, 128007,    271]], device='cuda:0')
input tokens
 <|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, NEGATIVE OR NEUTURAL.<|eot_id|><|start_header_id|>user<|end_header_id|>

I hate this movie.<|eot_id|><|start_header_id|>assistant<|end_header_id|>




## Get the generated texts and the logits for each generated token

In [4]:
# outputs[0]
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=False)
print("Full texts:\n", generated_text)
input_length = len(messages_with_template_applied)
print("Generated response:", generated_text[input_length:])
print("Generated token logits\n", outputs.scores)

Full texts:
 <|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, NEGATIVE OR NEUTURAL.<|eot_id|><|start_header_id|>user<|end_header_id|>

I hate this movie.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

NEGATIVE<|eot_id|>
Generated response: NEGATIVE<|eot_id|>
Generated token logits
 (tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf]], device='cuda:0'), tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf]], device='cuda:0'), tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf]], device='cuda:0'))


In [5]:
tokenizer.decode([outputs.scores[2].argmax()])

'<|eot_id|>'

In [6]:
generated_response = generated_text[input_length:]
# TODO: add masking process
masked_inputs_concat_response = messages_with_template_applied + generated_response
masked_inputs = tokenizer.encode_plus(
                    masked_inputs_concat_response,
                    add_special_tokens=False,
                    return_tensors="pt",
                    ).to(device)

# Get the logits
with torch.no_grad():
    outputs = model(**masked_inputs, output_hidden_states=True, return_dict=True)
    logits = outputs.logits
    last_hidden_state = outputs.hidden_states[-1]



In [7]:
print("last_hidden_state.shape", last_hidden_state.shape)
print("logits.shape", logits.shape)

last_hidden_state.shape torch.Size([1, 54, 4096])
logits.shape torch.Size([1, 54, 128256])


# Modeling the explainer

In [8]:
import torch.nn as nn 
from models import MLP
import torch.nn.functional as F

def similarity_measure(logits, labels, attention_mask):
    """ 
    args:
        logis: torch.Tensor, shape (batch_size, seq_len, vocab_size)
        labels: torch.Tensor, shape (batch_size, seq_len)
        attention_mask: torch.Tensor, shape (batch_size, seq_len) the original input text is masked with 0

    return:
        mean_log_probs: torch.Tensor, shape (batch_size,)
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    # correct_log_probs = log_probs[range(log_probs.shape[0]), labels]

    # Flatten the tensors to simplify indexing
    batch_size, seq_len, vocab_size = logits.size()
    labels = labels.view(-1)  # flatten to (batch_size * seq_len)
    log_probs_flat = log_probs.view(-1, vocab_size)

    # Extract the log probabilities for the correct labels
    correct_log_probs_flat = log_probs_flat[range(batch_size * seq_len), labels]

    # Reshape to (batch_size, seq_len)
    correct_log_probs = correct_log_probs_flat.view(batch_size, seq_len)

    # Mask out the original input texts, only keep the generated tokens
    masked_log_probs = correct_log_probs * attention_mask

    # Calculate the mean log probability for each sequence
    mean_log_probs = masked_log_probs.sum(dim=-1) / attention_mask.sum(dim=-1)

    return mean_log_probs


class MaskGeneratingModel(nn.Module):
    def __init__(self, hidden_size=4096, mlp_hidden_dim=1024, mlp_bottleneck_dim=768, mlp_num_blocks=2):
        """ 
        hidden_size: int
            The hidden size of the output of the generative model, 4096 for llama3
        """
        super().__init__()

        self.hidden_size = hidden_size

        self.explain_map = MLP(input_dim=hidden_size, 
                               hidden_dim=mlp_hidden_dim, 
                               output_dim=1, 
                               num_blocks=mlp_num_blocks, 
                               bottleneck_dim=mlp_bottleneck_dim) # takes [N, L, hidden_size] outputs [N, L, 1]
    
    def forward(self, pred_features):
        """ 
        pred_features: torch.Tensor of shape [N, L, hidden_size]
        """
        mask_logits = self.explain_map(pred_features) # [N, L, 1]
        return mask_logits.squeeze(-1) # [batch_size, seq_len]
    
    @torch.no_grad()
    def generate_mask(self, mask_logits, context_mask):
        """
        generate mask based on a Bernoulli distribution with the probabilities of the given logits 
        args:
            mask_logits: torch.Tensor, shape (batch_size, seq_len)
            context_mask: torch.Tensor, shape (batch_size, seq_len), the context texts 
                (for example, the chatbot instruction template, not including the user inputs) are masked with 0.
                This mask strategy is to ensure we only focus on the real user inputs.
        return:
            mask: torch.Tensor, shape (batch_size, seq_len), with contexts being all 1s and user inputs being randomly masked.
        """
        mask_probs = torch.sigmoid(mask_logits) # (batch_size, seq_len)
        mask = torch.bernoulli(mask_probs) # (batch_size, seq_len)
        mask = context_mask * mask + (1. - context_mask) # (batch_size, seq_len)
        return mask # this could be used as the attention mask for the generative model
    
    @torch.no_grad()
    def sample_one_batch(self, input_ids, attention_mask, mask_logits, context_mask):
        """ 
        args:
            input_ids: torch.Tensor, shape (batch_size, seq_len)
            attention_mask: torch.Tensor, shape (batch_size, seq_len), the attention_mask generated automatically by tokenizer, usually all 1s
            mask_logits: torch.Tensor, shape (batch_size, prompt_len), the logits for the mask generation
            context_mask: torch.Tensor, shape (batch_size, prompt_len), the context texts 
                (for example, the chatbot instruction template, not including the user inputs) are masked with 0.
        """
        #  get the mask of interest, i.e., for the user inputs, then pad to match the shape of other masks (of the full length)
        seq_len = input_ids.size(1)
        prompt_len = mask_logits.size(1)
        mask = self.generate_mask(mask_logits=mask_logits, context_mask=context_mask) # (batch_size, prompt_len)
        pad_length = seq_len - prompt_len
        padded_mask = F.pad(mask, (0, pad_length), mode='constant', value=1)

        # get the masked attention_mask
        masked_attention_mask = attention_mask * padded_mask

        return input_ids, masked_attention_mask, mask
    
    def compute_similarity(self, logits, labels, attention_mask):
        """ 
        compute the mean log_probs for the generated tokens
        """
        mean_log_probs = similarity_measure(logits, labels, attention_mask)
        return mean_log_probs


# Training
## Load data

In [9]:
from datasets import load_dataset
imdb = load_dataset("imdb")
# idx = 0
# texts = imdb["test"][idx]['text']
# # print(texts)

# messages = [
#     {"role": "system", "content": "You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, NEGATIVE OR NEUTURAL."},
#     {"role": "user", "content": texts},
# ]

# messages_with_template_applied = tokenizer.apply_chat_template(
#     messages,
#     tokenize=False,
#     add_generation_prompt=True,
# )

# print(messages_with_template_applied)

# inputs = tokenizer.encode_plus(
#             messages_with_template_applied,
#             add_special_tokens=False,
#             return_tensors="pt",
#             ).to(device)

# terminators = [
#     tokenizer.eos_token_id,
#     tokenizer.convert_tokens_to_ids("<|eot_id|>")
# ]

# outputs = model.generate(
#     **inputs,
#     max_new_tokens=128,
#     eos_token_id=terminators,
#     do_sample=True,
#     temperature=0.6,
#     top_p=0.9,
#     return_dict_in_generate=True,
#     output_scores=True,
# )

# generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=False)
# print("Full texts:\n", generated_text)
# input_length = len(messages_with_template_applied)
# print("Generated response:", generated_text[input_length:])
# print("Generated token logits\n", outputs.scores)


## Define Dataloader

In [10]:
from torch.utils.data import DataLoader
from datasets import load_dataset

imdb = load_dataset("imdb")
train_ds = imdb['train']

tokenizer.pad_token = tokenizer.eos_token


def collate_fn(examples):
    max_len = 512
    texts = [example['text'] for example in examples]
    texts = [text if len(text) <= max_len else text[:max_len] for text in texts]
    labels = [example['label'] for example in examples]
    messages_lambda = lambda texts: [
        {"role": "system", "content": "You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, NEGATIVE OR NEUTURAL with a brief explanation."},
        # {"role": "system", "content": "You are a chatbot for sentimate analysis."},
        {"role": "user", "content": texts},
    ]
    messages = list(map(messages_lambda, texts))

    messages_with_template_applied = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    batch = tokenizer(
                messages_with_template_applied,
                add_special_tokens=False,
                padding=True,
                return_tensors="pt",
                )
    
    # find the template boundaries
    text_lens = [len(tokenizer.encode(text)) for text in texts]
    text_lens_tensor = torch.tensor(text_lens, dtype=torch.long)
    

    def apply_mask(mask_tensor, text_lens_tensor):
        batch_size, seq_len = mask_tensor.shape
        for i in range(batch_size):
            text_len = text_lens_tensor[i].item()
            mask_tensor[i, -text_len-5:-5] = 0
        return 1- mask_tensor

    mask_tensor = apply_mask(torch.ones_like(batch['input_ids']), text_lens_tensor)

    batch['context_mask'] = mask_tensor

    
    return batch


batch_size = 16
train_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)


In [11]:
# test the collate_fn
a = next(iter(train_dataloader))
print(tokenizer.decode(a.context_mask[1] * a.input_ids[1]))
print(tokenizer.decode(a.input_ids[1]))

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

"I Am Curious: Yellow" is a risible and pretentious steaming pile. It doesn't matter what one's political views are because this film can hardly be taken seriously on any level. As for the claim that frontal male nudity is an automatic NC-17, that isn't true. I've seen R-rated films with male nudity. Granted, they only offer some fleeting views, but where are the R-rated films with gaping vulvas and flapping labia? Nowhere, because they don't exist. The same goes for those crappy cable shows: schlongs swing!!!!!
<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, NEGATIVE OR NEUTURAL with a brief explanation.<|eot_id|><|start_header_id|>user<|end_header_id|>

"I Am Curious: Ye

# Init model

In [12]:
mask_gen_model = MaskGeneratingModel(hidden_size=4096, mlp_hidden_dim=1024, mlp_bottleneck_dim=768, mlp_num_blocks=2)
mask_gen_model.to(device)

MaskGeneratingModel(
  (explain_map): MLP(
    (input_layer): Linear(in_features=4096, out_features=1024, bias=True)
    (layers): ModuleList(
      (0-1): 2 x Sequential(
        (0): Linear(in_features=1024, out_features=768, bias=True)
        (1): Dropout(p=0.5, inplace=False)
        (2): ReLU()
        (3): Linear(in_features=768, out_features=768, bias=True)
        (4): Dropout(p=0.5, inplace=False)
        (5): ReLU()
        (6): Linear(in_features=768, out_features=1024, bias=True)
        (7): Dropout(p=0.5, inplace=False)
      )
    )
    (output_layer): Linear(in_features=1024, out_features=1, bias=True)
  )
)

In [13]:
from tqdm import tqdm

# Set pad_token_id if it is not set
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

pad_token_id = tokenizer.pad_token_id

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

optimizer = torch.optim.Adam(mask_gen_model.parameters(), lr=1e-5)

for epoch in range(10):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        context_mask = data['context_mask'].to(device)
        # get generated texts
        gen_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            return_dict_in_generate=True,
            output_scores=True,
        )
        gen_tokens = gen_outputs.sequences
        pad_length = gen_tokens.size(1) - input_ids.size(1)
        # get the attention mask for the generated tokens, and also mask the padding tokens
        gen_attention_mask = F.pad(attention_mask, (0, pad_length), mode='constant', value=1)
        gen_attention_mask = gen_attention_mask * (gen_tokens != pad_token_id).long()
        # get the response mask, which is the mask for the generated tokens (the user inputs are masked with 0)
        response_mask = gen_attention_mask.clone()
        response_mask[:, :-gen_tokens.size(1)] = 0 # TODO: 有问题。
        # Get the last hidden state for the prompt sequence
        with torch.no_grad():
            prompt_outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True)
            last_hidden_state = prompt_outputs.hidden_states[-1]
            last_hidden_state = last_hidden_state.float()
        
        mask_logits = mask_gen_model(last_hidden_state)

        
        # mask = mask_gen_model.generate_mask(mask_logits=mask_logits, context_mask=context_mask)
        perturbed_input_ids, masked_attention_mask, user_input_mask = mask_gen_model.sample_one_batch(input_ids=gen_tokens, 
                                                                attention_mask=gen_attention_mask, 
                                                                mask_logits=mask_logits, 
                                                                context_mask=context_mask)
        
        # Get the logits of the perturbed input concatenated with the response
        with torch.no_grad():
            outputs = model(input_ids=perturbed_input_ids, attention_mask=masked_attention_mask, return_dict=True)
            logits = outputs.logits.float()
        
        with torch.no_grad():
            outputs_gt = model(input_ids=perturbed_input_ids, attention_mask=masked_attention_mask, return_dict=True)
            logits_gt = outputs_gt.logits.float()
        
        # Compute the similarity between the logits and the response labels
        sim = mask_gen_model.compute_similarity(logits, gen_tokens, response_mask)
        sim_gt = mask_gen_model.compute_similarity(logits_gt, gen_tokens, response_mask)

        # PPO
        
        reward = torch.exp(sim - sim_gt)
        mask_prob = torch.sigmoid(mask_logits)
        bce_loss = nn.BCELoss(reduction='none')
        reward_loss = reward.unsqueeze(-1) * bce_loss(context_mask * mask_prob, context_mask * user_input_mask)
        reward_loss = reward_loss.mean()
        mask_loss = (mask_prob * context_mask).sum(dim=-1) / context_mask.sum(dim=-1)
        mask_loss = mask_loss.mean()

        loss = reward_loss + mask_loss
        pbar.set_description(f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss.item():.4f}, " 
                             f"Reward Loss = {reward_loss.item():.4f}, "
                            #  f"Regret Loss = {loss_dict['regret_loss'].item():.4f}, "
                             f"Mask Loss = {mask_loss.item():.4f} "
                            #  f"alt_mask_loss = {loss_dict['alt_mask_loss'].item():.4f} "
                            #  f"mask_mean = {loss_dict['mask_mean'].item():.4f} "
                            #  f"prob_mean = {loss_dict['prob_mean'].item():.4f} "
                             )
        # Train the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # pbar.set_description(f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss.item():.4f}, " 
        #                      f"Reward Loss = {reward_loss.item():.4f}, "
        #                     #  f"Regret Loss = {loss_dict['regret_loss'].item():.4f}, "
        #                      f"Mask Loss = {mask_loss.item():.4f} "
        #                     #  f"alt_mask_loss = {loss_dict['alt_mask_loss'].item():.4f} "
        #                     #  f"mask_mean = {loss_dict['mask_mean'].item():.4f} "
        #                     #  f"prob_mean = {loss_dict['prob_mean'].item():.4f} "
        #                      )
        if idx % 10 == 0:
            print()

        break
    break


        # scheduler.step()
        # pbar.set_description(f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss_dict['loss'].item():.4f}, " 
        #                      f"Reward Loss = {loss_dict['reward_loss'].item():.4f}, "
        #                     #  f"Regret Loss = {loss_dict['regret_loss'].item():.4f}, "
        #                      f"Mask Loss = {loss_dict['mask_loss'].item():.4f} "
        #                     #  f"alt_mask_loss = {loss_dict['alt_mask_loss'].item():.4f} "
        #                      f"mask_mean = {loss_dict['mask_mean'].item():.4f} "
        #                      f"prob_mean = {loss_dict['prob_mean'].item():.4f} "
        #                      )
        # if idx % 10 == 0:
        #     print()
        # if (idx) % 10 == 0:
            
        #     torch.save(mask_gen_model.state_dict(), f'text_mask_gen_model/mask_gen_model_{epoch}_{idx}.pth') 


Epoch 1, Step 1: Loss = 0.8586, Reward Loss = 0.4148, Mask Loss = 0.4438 :   0%|          | 0/1563 [00:03<?, ?it/s]







In [14]:
response_mask[0]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')

In [15]:
gen_attention_mask[0]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0')

In [16]:
idx = 0
masked_attention_mask[idx]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1.,
        1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0., 0., 1.,
        1., 1., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1.,
        0., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 1.,
        0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
        1., 1., 0., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1.,
        1., 1., 1., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 

In [17]:
input_ids[idx]

tensor([128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
        128009, 128000, 128006,   9125, 128007,    271,   2675,    527,    264,
          6369,   6465,    369,   3288,   3509,   6492,     13,   1472,    649,
          1520,   3932,    449,    872,   4860,   4669,  64694,  14847,    315,
         27592,  45450,     11,  85165,  24093,   2794,   8014,   1406,  51785,
           449,    264,  10015,  16540,     13, 128009, 128006,    882, 128007,
           271,     40,  49959,    358,   6912,  19058,  43752,  30237,  35771,
           505,    856,   2835,   3637,   1606,    315,    682,    279,  26654,
           430,  23712,    433,    994,    433,    574,   1176,   6004,    304,
           220,   5162,     22,     13,    358,   1101,   6755,    430,    520,
          1176,    433,    574,  31589,    555,    549,    815,     13,  35869,
           422,    433,   3596,   6818, 

In [18]:
tokenizer.decode(perturbed_input_ids[idx] * masked_attention_mask[idx].long())

'!!!!!!!!!!!!!!!!!!!<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, NEGATIVE OR NEUTURAL with a brief explanation.!<|start_header_id|>user<|end_header_id|>!!! I AM CURIOUS-YELLOW from!!! because!!! controversy! surrounded it! it was!! in 196!! I! heard!! first! was!!!!! customs! it ever!!!!!! therefore!!! of films! "cont!!" I!!! see this!!!br!!!The! is centered! a! Swedish! student named Lena who wants! learn everything she can about life. In!! wants!! her attent!!!aki!<|start_header_id|>assistant<|end_header_id|>\n\nPOSITIVE\n\nThe tone of your text is positive because you\'re expressing your interest in a film that was considered controversial when it was first released, and you\'re eager to learn more about it. Your curiosity and enthusiasm for the film are evident, and you\'re not expressing any negative opinions or criticisms.!!!!!!!!!!!!!!!!!!!!'

In [19]:
tokenizer.pad_token_id

128009

In [6]:
import numpy as np

def normalize_importance(importance_tensor):
    """Normalize the importance tensor to the range [0, 1]."""
    min_val = np.min(importance_tensor)
    max_val = np.max(importance_tensor)
    normalized = (importance_tensor - min_val) / (max_val - min_val)
    return normalized

def importance_to_bg_color(importance):
    """Map normalized importance value to an ANSI background color code using a custom gradient."""
    if importance < 0.5:
        # Interpolate between red (1, 0, 0) and black (0, 0, 0)
        color_rgb = np.array([255 * (1 - 2 * importance), 0, 0])
    else:
        # Interpolate between black (0, 0, 0) and green (0, 1, 0)
        color_rgb = np.array([0, 255 * (2 * (importance - 0.5)), 0])
    
    color_rgb = color_rgb.astype(int)
    
    # Convert RGB to 256-color mode
    color_code = 16 + 36 * (color_rgb[0] // 51) + 6 * (color_rgb[1] // 51) + (color_rgb[2] // 51)
    return f'\033[48;5;{color_code}m'

def print_colored_string(token_tensor, importance_tensor):
    normalized_importance = normalize_importance(importance_tensor)
    colored_string = ""
    for token, importance in zip(token_tensor, normalized_importance):
        bg_color = importance_to_bg_color(importance)
        colored_string += f"{bg_color}{token}\033[0m "
    print(colored_string)

# Example usage
token_tensor = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"]
importance_tensor = np.array([0.1, 0.3, 0.5, 0.7, 0.9, 0.2, 0.4, 0.6, 0.8])

print_colored_string(token_tensor, importance_tensor)

[48;5;196mThe[0m [48;5;88mquick[0m [48;5;16mbrown[0m [48;5;28mfox[0m [48;5;46mjumps[0m [48;5;124mover[0m [48;5;52mthe[0m [48;5;22mlazy[0m [48;5;34mdog[0m 
