In [1]:
from llmexp.llm.smollm import LLMWrapper
from accelerate import Accelerator
import torch

# Base LLM checkpoint to load; the commented-out lines are alternative checkpoints.
# checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
# checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"
# Path of the saved MAB model checkpoint that is passed to torch.load further down.
# saved_mab_model = "checkpoints/mab_model_100.pth"
saved_mab_model = "checkpoints/mab_model_260.pth"


# The accelerator's device is reused below for the LLM, the collated batch and the MAB model.
accelerator = Accelerator()
device = accelerator.device


# Wrap the checkpoint and keep a handle on its tokenizer for the cells below.
llm = LLMWrapper(checkpoint, device=device)
tokenizer = llm.tokenizer

In [2]:
# NOTE(review): LLMDataset and create_dataloader are not referenced in the cells visible here.
from llmexp.utils.data_utils import LLMDataset, create_dataloader
# System-prompt instruction; the same string is assigned again in the following cell.
instruction = "Analyze the sentiment of the following sentence and respond concisely."

In [3]:
# Prompt under test. The commented-out lines are alternative instructions / inputs
# that can be swapped in to try other wordings.
# instruction = "Analyze the sentiment of the following sentence. Be brief."
# instruction = "Analyze the sentiment of the following sentence and respond with only one word: 'positive', 'negative', or 'neutral', based on the overall tone and meaning of the sentence, if no enough information provided, respond with 'not clear' with an explanation."
instruction = "Analyze the sentiment of the following sentence and respond concisely."

# user_input = "I am extremely disappointed with the quality; it broke after just one day."
user_input = "The service at this restaurant was fantastic, and the staff were so friendly."
# user_input = "<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|>"
# user_input = "I'm so happy with the product!"
# user_input = "The bright sunshine and gentle breeze made my afternoon truly delightful."
# user_input = "I felt deeply disappointed and frustrated after the meeting went completely off track."
# user_input = "Although the food at the restaurant was excellent, the service left much to be desired."

# Two chat turns: the instruction as the system turn, and the sentence under a
# custom "sentence" role.
system_turn = {"role": "system", "content": instruction}
sentence_turn = {"role": "sentence", "content": user_input}
content = [system_turn, sentence_turn]

# Render the chat to text (not token ids) with the generation prompt appended.
template = tokenizer.apply_chat_template(content, tokenize=False, add_generation_prompt=True)

# Generate from the rendered prompt, then show the prompt followed by the model output.
gen_output = llm.generate_from_texts(template)
print(template)
print(gen_output)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Analyze the sentiment of the following sentence and respond concisely.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Analyze the sentiment of the following sentence and respond concisely.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The sentiment of this sentence is overwhelmingly positive, indicating a high level of satisfaction with the service and staff at the restaurant.<|eot_id|>


In [4]:
from llmexp.utils.data_utils import DataCollatorSST2

# Collator configured with the same tokenizer and instruction as the prompt above.
data_collator = DataCollatorSST2(tokenizer, max_length=512, instruction=instruction)

# A single record with the fields the collator is given here: the sentence and a label.
raw_example = {"sentence": user_input, "label": 1}

# Collate a batch of one and move it to the working device; `example` is the
# batch that later cells index by 'input_ids', 'attention_mask' and 'context_mask'.
example = data_collator([raw_example]).to(device)
print(example)


{'input_ids': tensor([[128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,
             25,   6790,    220,   2366,     18,    198,  15724,   2696,     25,
            220,   1627,  10263,    220,   2366,     19,    271,   2127,  56956,
            279,  27065,    315,    279,   2768,  11914,    323,   6013,   3613,
            285,    989,     13, 128009, 128006,  52989, 128007,    271,    791,
           2532,    520,    420,  10960,    574,  14964,     11,    323,    279,
           5687,   1051,    779,  11919,     13, 128009, 128006,  78191, 128007]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0'), 'context_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [5]:
# Decode only the positions the attention mask marks as real tokens.
# Boolean indexing drops masked positions entirely; multiplying the ids by the
# mask would instead rewrite them to token id 0, which the tokenizer decodes as
# whatever token owns id 0 rather than omitting it.
prompt_ids = example['input_ids'][0]
prompt_keep = example['attention_mask'][0].bool()
print(tokenizer.decode(prompt_ids[prompt_keep]))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Analyze the sentiment of the following sentence and respond concisely.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


In [6]:
# Generate from the collated batch; the result is indexed below (and in later
# cells) by 'input_ids' and 'attention_mask'.
gen_output = llm.generate(example['input_ids'], example['attention_mask'])
print(gen_output)
# Decode the first (only) row of the returned token ids back to text.
print(tokenizer.decode(gen_output['input_ids'][0]))

{'input_ids': tensor([[128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,
             25,   6790,    220,   2366,     18,    198,  15724,   2696,     25,
            220,   1627,  10263,    220,   2366,     19,    271,   2127,  56956,
            279,  27065,    315,    279,   2768,  11914,    323,   6013,   3613,
            285,    989,     13, 128009, 128006,  52989, 128007,    271,    791,
           2532,    520,    420,  10960,    574,  14964,     11,    323,    279,
           5687,   1051,    779,  11919,     13, 128009, 128006,  78191, 128007,
            271,    791,  27065,    315,    420,  11914,    374,  55734,   6928,
             13,    578,   1005,    315,   4339,   1093,    330,  61827,   5174,
              1,    323,    330,  82630,      1,  20599,    264,   1579,   2237,
            315,  24617,    323,  35996,    369,    279,   2532,     13, 128009]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [7]:
from llmexp.explainer.mab_model import MABModel

# map_location remaps the checkpoint's tensors onto the active device, so a file
# saved from a different device (another GPU index, or a GPU when only CPU is
# available) still loads.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints you trust; if this file holds nothing but tensors,
# consider passing weights_only=True.
mab_checkpoint = torch.load(saved_mab_model, map_location=device)
mab_model = MABModel.load_with_base_model(mab_checkpoint, llm, hidden_size=1024)
mab_model.to(device)
# Trailing print keeps the cell from echoing the module repr.
print()




  mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)


In [8]:
# Length of the collated prompt (second dimension of its input_ids); passed to
# mab_model.inference below as the index where the generated part starts.
gen_start_index = example['input_ids'].shape[1]
print(gen_start_index)

63


In [9]:
# Full sequence and its mask, as returned by generation.
input_ids = gen_output['input_ids']
attention_mask = gen_output['attention_mask']

# One logits tensor per entry; each is masked and normalised below.
logits_list = mab_model.inference(input_ids, attention_mask, gen_start_index)

context_mask = example['context_mask']
context_len = context_mask.size(1)

# Right-pad the context mask with zeros so it matches each logits tensor's width;
# the padded positions are therefore treated as non-context.
context_mask_list = [
    torch.nn.functional.pad(context_mask, (0, step_logits.size(1) - context_len), value=0)
    for step_logits in logits_list
]

# Softmax restricted to context positions: everything else is set to -inf first,
# so it receives exactly zero probability.
mab_values_list = [
    torch.softmax(step_logits.masked_fill(~step_mask.bool(), float('-inf')), dim=-1)
    for step_logits, step_mask in zip(logits_list, context_mask_list)
]

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [10]:
# Inspect the masked-softmax values, one tensor per entry of logits_list.
mab_values_list

[tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1209,
          0.0040, 0.0104, 0.0079, 0.0016, 0.0378, 0.0119, 0.0148, 0.0139, 0.0156,
          0.0545, 0.0591, 0.6277, 0.0134, 0.0065, 0.0000, 0.0000, 0.0000, 0.0000]],
        device='cuda:0'),
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.00

In [11]:
def visualize_tokens_with_values(input_ids, mab_values, context_mask, tokenizer):
    """Render tokens as an HTML strip shaded by their MAB values.

    Args:
        input_ids: Token-id tensor indexed as ``input_ids[0, i]``; only row 0 is
            rendered.
        mab_values: Value tensor; row 0 supplies one value per token position.
        context_mask: Mask multiplied element-wise into ``mab_values``; positions
            where it is 0 end up with value 0 and are not highlighted.
        tokenizer: Object whose ``decode`` turns a one-element slice of ids into
            text.

    Returns:
        An ``IPython.display.HTML`` object with one ``<span>`` per rendered token.
        Each span's tooltip shows the masked value and its normalised value.
        Rendering stops at whichever is shorter: the token sequence, or the
        values plus one trailing pad entry.
    """
    from html import escape

    from IPython.display import HTML

    # Decode tokens one by one to preserve alignment with the value positions.
    tokens = [tokenizer.decode(input_ids[0, i:i + 1]) for i in range(input_ids.shape[1])]

    # Zero out everything outside the context mask, then min-max normalise only
    # the remaining non-zero entries to [0, 1] for colour intensity.
    masked_values = (mab_values * context_mask)[0]
    non_zero_mask = masked_values != 0
    normalized_values = torch.zeros_like(masked_values)
    if non_zero_mask.any():
        non_zero_values = masked_values[non_zero_mask]
        lowest = non_zero_values.min()
        value_range = non_zero_values.max() - lowest
        if value_range > 0:
            normalized_values[non_zero_mask] = (non_zero_values - lowest) / value_range
        else:
            # All non-zero values are equal (e.g. a single attributed token).
            # Min-max scaling would be 0/0 = NaN and crash the int() conversion
            # below, so treat every such token as being at the maximum.
            normalized_values[non_zero_mask] = 1.0

    # One extra trailing entry lets the token right after the scored positions be
    # rendered too: normalised value 0, raw value repeated from the last position.
    padded_normalized_values = torch.cat([normalized_values, normalized_values.new_zeros(1)], dim=0)
    padded_mab_values = torch.cat([masked_values, masked_values[-1:]], dim=0)

    # Build one span per token. The green channel scales with the normalised
    # value; the background alpha is fixed at 0.3.
    spans = []
    for token, value, orig_value in zip(tokens, padded_normalized_values.tolist(), padded_mab_values.tolist()):
        green_color = int(value * 200)  # Control the maximum intensity
        spans.append(
            f'<span style="color: black; background-color: rgba(0, {green_color}, 0, 0.3); '
            f'padding: 0.2em; margin: 0.1em; border-radius: 3px;" '
            # Escape the token text so characters such as '<', '>' and '&' are
            # shown literally instead of being parsed as markup.
            f'title="MAB: {orig_value:.3f}, Norm: {value:.3f}">{escape(token)}</span>'
        )

    html_output = (
        "<div style='font-family: monospace; line-height: 2; background-color: white; padding: 10px;'>"
        + "".join(spans)
        + "</div>"
    )
    return HTML(html_output)

In [12]:
# Raw logits (before context masking and softmax) for a single entry of logits_list.
idx = 2
logits_list[idx]

tensor([[-0.6448, -0.5844, -0.8927, -0.4281, -0.5870, -1.2761, -2.0542, -0.3719,
          0.5974, -0.7093, -0.5944, -1.0439, -0.5448, -0.6329, -1.2310, -0.9018,
          0.8960, -0.1912, -0.8569, -0.9632, -1.2564, -1.3394, -1.4138, -0.5905,
         -0.6077, -1.1437,  0.1992, -0.2699,  1.3168,  1.4095,  1.2071,  1.4351,
          0.4318,  1.2661,  1.0487, -0.9086, -1.2527,  1.0464,  0.6675, -0.0579,
         -1.7395,  0.2434, -0.3059, -0.3846,  3.4305, -0.2630,  0.5190, -0.3730,
         -0.0306,  1.1745,  2.5467,  0.0143,  0.0378, -0.2256, -0.8574, -0.1272,
         -0.1543,  0.1630, -0.4114, -0.2560, -1.1741, -0.4889,  0.0646,  1.2893,
          3.2827]], device='cuda:0')

In [13]:
# Render one token strip per step, pairing each value tensor with the padded
# context mask that was built alongside it.
for step_values, step_mask in zip(mab_values_list, context_mask_list):
    display(visualize_tokens_with_values(input_ids, step_values, step_mask, tokenizer))

In [14]:
# Shape of the generated sequence's token ids (taken from gen_output['input_ids']).
input_ids.shape

torch.Size([1, 99])