In [14]:
from llmexp.llm.smollm import LLMWrapper
from accelerate import Accelerator
import torch

# Base chat model whose generations are explained. Previously tried:
#   "meta-llama/Meta-Llama-3-8B-Instruct", "HuggingFaceTB/SmolLM-1.7B-Instruct"
checkpoint = "meta-llama/Llama-3.2-1B-Instruct"

# Trained MAB explainer checkpoint (an earlier run used "checkpoints/mab_model_100.pth").
saved_mab_model = "checkpoints/mab_model_40.pth"

# Let accelerate choose the device (the recorded outputs show cuda:0).
accelerator = Accelerator()
device = accelerator.device

# Wrap the base model; its tokenizer is reused below for templating and decoding.
llm = LLMWrapper(checkpoint, device=device)
tokenizer = llm.tokenizer

In [15]:
# System instruction for the model. Earlier variants, kept for comparison:
# instruction = "Analyze the sentiment of the following sentence. Be brief."
# instruction = "Analyze the sentiment of the following sentence and respond with only one word: 'positive', 'negative', or 'neutral', based on the overall tone and meaning of the sentence, if no enough information provided, respond with 'not clear' with an explanation."
instruction = "Analyze the sentiment of the following sentence and respond concisely."

# Sentence to explain. Other probe inputs used in earlier runs:
# user_input = "I am extremely disappointed with the quality; it broke after just one day."
# user_input = "<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|>"
# user_input = "I'm so happy with the product!"
# user_input = "The bright sunshine and gentle breeze made my afternoon truly delightful."
# user_input = "I felt deeply disappointed and frustrated after the meeting went completely off track."
# user_input = "Although the food at the restaurant was excellent, the service left much to be desired."
user_input = "The service at this restaurant was fantastic, and the staff were so friendly."

# The non-standard "sentence" role renders as its own header
# (<|start_header_id|>sentence<|end_header_id|>, see output), which is the same
# header the DataCollator prompt in the next cell contains.
content = [
    {"role": "system", "content": instruction},
    {"role": "sentence", "content": user_input},
]
template = tokenizer.apply_chat_template(content, tokenize=False, add_generation_prompt=True)

# Per the recorded output, the decoded generation repeats the prompt before the reply.
gen_output = llm.generate_from_texts(template)
print(template)
print(gen_output)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Analyze the sentiment of the following sentence and respond concisely.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Analyze the sentiment of the following sentence and respond concisely.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The sentiment of this sentence is overwhelmingly positive. The use of words like "fantastic" and "friendly" convey a strong sense of satisfaction and appreciation for the service.<|eot_id|>


In [16]:
from llmexp.utils.data_utils import DataCollator

# Tokenise the same sentence through the training-time collator; per the printed output
# the resulting batch carries input_ids, attention_mask and context_mask.
data_collator = DataCollator(tokenizer, max_length=512, instruction=instruction)

# Single-item batch. Label 1 is passed through as in the dataset records
# (presumably "positive" -- confirm against the dataset's label mapping).
example = data_collator([{'sentence': user_input, 'label': 1}]).to(device)
print(example)


{'input_ids': tensor([[128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,
             25,   6790,    220,   2366,     18,    198,  15724,   2696,     25,
            220,   1627,  10263,    220,   2366,     19,    271,   2127,  56956,
            279,  27065,    315,    279,   2768,  11914,    323,   6013,   3613,
            285,    989,     13, 128009, 128006,  52989, 128007,    271,    791,
           2532,    520,    420,  10960,    574,  14964,     11,    323,    279,
           5687,   1051,    779,  11919,     13, 128009, 128006,  78191, 128007]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0'), 'context_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [17]:
# Decode only the positions the attention mask keeps. Multiplying ids by the mask would
# turn padded positions into token id 0, which still decodes to text instead of vanishing.
print(tokenizer.decode(example['input_ids'][0][example['attention_mask'][0].bool()]))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Analyze the sentiment of the following sentence and respond concisely.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


In [18]:
# Generate from the collated batch; per the recorded output, the returned input_ids hold
# the prompt ids followed by the generated continuation.
gen_output = llm.generate(example['input_ids'], example['attention_mask'])
print(gen_output)
# Decode batch row 0 (prompt + continuation) back to text.
print(tokenizer.decode(gen_output['input_ids'][0]))

{'input_ids': tensor([[128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,
             25,   6790,    220,   2366,     18,    198,  15724,   2696,     25,
            220,   1627,  10263,    220,   2366,     19,    271,   2127,  56956,
            279,  27065,    315,    279,   2768,  11914,    323,   6013,   3613,
            285,    989,     13, 128009, 128006,  52989, 128007,    271,    791,
           2532,    520,    420,  10960,    574,  14964,     11,    323,    279,
           5687,   1051,    779,  11919,     13, 128009, 128006,  78191, 128007,
            271,    791,  27065,    315,    420,  11914,    374,  55734,   6928,
             13,    578,   1005,    315,   4339,   1093,    330,  61827,   5174,
              1,    323,    330,  82630,      1,  20599,    264,   1579,   2237,
            315,  24617,    323,  35996,    369,    279,   2532,     13, 128009]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [19]:
from llmexp.explainer.mab_model import MABModel

# Load the explainer weights directly onto the target device.
# weights_only=True limits unpickling to tensors and plain containers. That is the safe
# mode, and it presumably also silences the torch.load warning recorded under this cell.
# It assumes the checkpoint is a plain state dict; fall back to weights_only=False only
# for a trusted file that stores custom Python objects.
# NOTE(review): hidden_size=1024 is hard-coded; confirm it matches the explainer's
# training configuration.
mab_model = MABModel.load_with_base_model(
    torch.load(saved_mab_model, map_location=device, weights_only=True),
    llm,
    hidden_size=1024,
)
mab_model.to(device);




  mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)


In [20]:
# Generation begins right after the prompt, so the prompt length is the first generated index.
gen_start_index = example['input_ids'].size(1)
print(gen_start_index)

63


In [21]:
input_ids = gen_output['input_ids']
attention_mask = gen_output['attention_mask']
# One logits tensor per generation step, starting at gen_start_index. In the recorded run,
# the tensor for step k has gen_start_index + k entries.
logits_list = mab_model.inference(input_ids, attention_mask, gen_start_index)

# context_mask marks the prompt positions that may receive attribution (assumed to span
# exactly the prompt, i.e. gen_start_index positions -- TODO confirm in DataCollator).
context_mask = example['context_mask']
mab_values_list = []
context_mask_list = []
for logits in logits_list:
    # Later steps cover more positions than the prompt-level context_mask.
    pad_size = logits.size(1) - context_mask.size(1)
    # F.pad with a negative amount silently crops, which would misalign mask and logits.
    assert pad_size >= 0, (
        f"logits cover {logits.size(1)} positions but context_mask has {context_mask.size(1)}"
    )
    # Right-pad with 1 (True): positions past the prompt -- previously generated tokens --
    # stay eligible for attribution.
    # NOTE(review): the old comment said "False"; use value=0 instead if only prompt
    # tokens should be attributable.
    padded_context_mask = torch.nn.functional.pad(context_mask, (0, pad_size), value=1)

    # Excluded positions get -inf so they receive exactly zero softmax mass.
    masked_logits = logits.masked_fill(~padded_context_mask.bool(), float('-inf'))
    mab_values = torch.softmax(masked_logits, dim=-1)
    mab_values_list.append(mab_values)
    context_mask_list.append(padded_context_mask)

In [22]:
# Inspect the per-step attribution distributions: one softmax tensor per entry of logits_list.
mab_values_list

[tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 1.2048e-10, 1.4556e-09, 3.6372e-07, 3.2217e-10,
          1.2671e-07, 1.1194e-09, 9.6638e-02, 7.8493e-07, 2.0227e-04, 5.3159e-10,
          1.9092e-08, 9.3708e-09, 1.1803e-09, 9.0316e-01, 2.4025e-08, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00]], device='cuda:0'),
 tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00

In [23]:
def visualize_tokens_with_values(input_ids, mab_values, context_mask, tokenizer):
    """Render per-token MAB attribution values as green-highlighted HTML.

    Each token of ``input_ids[0]`` is decoded on its own, so every displayed piece lines
    up with one position. Each piece is shaded in proportion to its min-max-normalised
    MAB value. A plain-text table of the same values is printed as well.

    Args:
        input_ids: token-id tensor indexed as ``input_ids[0, i]``; only row 0 is used
            (assumed shape (1, seq_len) -- TODO confirm for batched inputs).
        mab_values: attribution tensor; row ``mab_values[0]`` holds one value per
            position visible at the current generation step.
        context_mask: tensor multiplied element-wise with ``mab_values``; positions
            where it is 0 are zeroed before normalisation.
        tokenizer: object exposing ``decode`` (e.g. a Hugging Face tokenizer).

    Returns:
        ``IPython.display.HTML`` with one ``<span>`` per shown token. Only the first
        ``len(mab_values[0]) + 1`` tokens are shown: the visible positions, plus the
        token right after them (presumably the one generated at this step), which is
        shown with value 0.
    """
    from html import escape

    # Decode tokens one by one to preserve alignment with the value vector.
    tokens = [tokenizer.decode(input_ids[0, i:i + 1]) for i in range(input_ids.shape[1])]

    # Zero out positions outside the context, then min-max normalise the non-zero
    # values to [0, 1] for colour intensity.
    weighted = (mab_values * context_mask)[0]
    non_zero_mask = weighted != 0
    normalized_values = torch.zeros_like(weighted)
    if non_zero_mask.any():
        non_zero_values = weighted[non_zero_mask]
        value_range = non_zero_values.max() - non_zero_values.min()
        if value_range > 0:
            normalized_values[non_zero_mask] = (non_zero_values - non_zero_values.min()) / value_range
        else:
            # A single (or all-equal) non-zero value makes min-max 0/0 = NaN, which used
            # to crash the colour computation; show such values at full intensity.
            normalized_values[non_zero_mask] = 1.0

    # Append one slot so the token following the visible positions is shown too. It has
    # no attribution of its own, so both its normalised and raw values are 0.
    padded_normalized_values = torch.cat([normalized_values, normalized_values.new_zeros(1)])
    padded_mab_values = torch.cat([weighted, weighted.new_zeros(1)])

    # Strongest highlight opacity (the value the old fixed-alpha scheme used).
    max_alpha = 0.3
    spans = []
    for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
        intensity = float(value)
        # Scale opacity rather than the green channel. At intensity 0 the span is then
        # transparent (white background), not translucent black: a true white->green ramp.
        # Tokens are HTML-escaped so text such as "<" or "&" cannot break the markup.
        spans.append(
            f'<span style="color: black; background-color: rgba(0, 200, 0, {intensity * max_alpha:.3f}); '
            f'padding: 0.2em; margin: 0.1em; border-radius: 3px;" '
            f'title="MAB: {float(orig_value):.3f}, Norm: {intensity:.3f}">{escape(token)}</span>'
        )
    html_output = (
        "<div style='font-family: monospace; line-height: 2; background-color: white; padding: 10px;'>"
        + "".join(spans)
        + "</div>"
    )

    # Print the values
    print("Token\tNormalized Value\tOriginal MAB Value")
    print("-" * 50)
    for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
        print(f"{token}\t{float(value):.3f}\t\t{float(orig_value):.3f}")

    from IPython.display import HTML
    return HTML(html_output)

In [30]:
# Generation step to inspect. In the recorded run, logits_list[idx] has
# gen_start_index + idx entries (63 + 8 = 71 positions).
idx = 8
logits_list[idx]

tensor([[ 1.6821,  2.4370, -0.3381, -0.1458,  0.0945,  2.0942,  3.2790,  8.2897,
          4.9639,  5.0393,  1.1945,  0.6497,  1.9012,  2.7589,  3.4097,  4.4634,
          1.9652,  0.4637,  0.0984,  2.5790,  5.9221,  0.1619, -0.0758,  1.2160,
          4.0487, -0.9416, 12.7823, -1.1827,  4.5470,  7.0179, -0.7170, -0.4944,
         13.4562,  2.7673, 22.2034,  0.3594,  1.3680, 27.8104,  0.6752,  1.2474,
          0.4588,  4.8491,  2.9338,  2.5647, -0.4695,  3.4216, 11.8807,  1.0418,
         10.2537,  2.9749, 31.1248, 13.0493, 21.6345,  1.8357,  7.3487,  6.2445,
          3.0282, 34.5897,  7.6591, -0.1813, -0.8215, -0.2661,  3.4043, -0.2533,
         -0.8253,  0.9731, 27.5506, -0.0721,  2.5463, -0.2527, -0.1902]],
       device='cuda:0')

In [31]:
# Render generation step `idx`: the tokens in view, shaded by their MAB value.
mab_values = mab_values_list[idx]
visualization = visualize_tokens_with_values(
    input_ids,
    mab_values,
    context_mask_list[idx],
    tokenizer,
)
display(visualization)

Token	Normalized Value	Original MAB Value
--------------------------------------------------
<|begin_of_text|>	0.000		0.000
<|start_header_id|>	0.000		0.000
system	0.000		0.000
<|end_header_id|>	0.000		0.000


	0.000		0.000
Cut	0.000		0.000
ting	0.000		0.000
 Knowledge	0.000		0.000
 Date	0.000		0.000
:	0.000		0.000
 December	0.000		0.000
 	0.000		0.000
202	0.000		0.000
3	0.000		0.000

	0.000		0.000
Today	0.000		0.000
 Date	0.000		0.000
:	0.000		0.000
 	0.000		0.000
26	0.000		0.000
 Jul	0.000		0.000
 	0.000		0.000
202	0.000		0.000
4	0.000		0.000


	0.000		0.000
An	0.000		0.000
alyze	0.000		0.000
 the	0.000		0.000
 sentiment	0.000		0.000
 of	0.000		0.000
 the	0.000		0.000
 following	0.000		0.000
 sentence	0.000		0.000
 and	0.000		0.000
 respond	0.000		0.000
 conc	0.000		0.000
is	0.000		0.000
ely	0.000		0.000
.	0.000		0.000
<|eot_id|>	0.000		0.000
<|start_header_id|>	0.000		0.000
sentence	0.000		0.000
<|end_header_id|>	0.000		0.000


	0.000		0.000
The	0.000		0.000
 service	0.000		0.000
 a

In [32]:
# Length of the full prompt + generated sequence (the prompt alone is gen_start_index tokens).
input_ids.shape

torch.Size([1, 99])