In [29]:
from llmexp.llm.smollm import LLMWrapper
from accelerate import Accelerator
import torch

# checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
# checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"
# saved_mab_model = "checkpoints/mab_model_100.pth"
saved_mab_model = "checkpoints/mab_model_20.pth"


accelerator = Accelerator()
device = accelerator.device


llm = LLMWrapper(checkpoint, device=device)
tokenizer = llm.tokenizer

In [30]:
# instruction = "Analyze the sentiment of the following sentence. Be brief."
instruction = "Analyze the sentiment of the following sentence and respond with only one word: 'positive,' 'negative,' or 'neutral,' based on the overall tone and meaning of the sentence. Do not provide any additional explanation."
# user_input = "I am extremely disappointed with the quality; it broke after just one day."
user_input = "The service at this restaurant was fantastic, and the staff were so friendly."
# user_input = "This is a good book."

content = [
            {"role": "system", 
            "content": instruction
            },

            {"role": "sentence", 
            "content": user_input
            }
        ]
template = tokenizer.apply_chat_template(content, tokenize=False, add_generation_prompt=True)
# print(template)

# The generated outputs 
gen_output = llm.generate_from_texts(template)
print(gen_output)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Analyze the sentiment of the following sentence and respond with only one word: 'positive,' 'negative,' or 'neutral,' based on the overall tone and meaning of the sentence. Do not provide any additional explanation.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

positive<|eot_id|>


In [31]:
from llmexp.utils.data_utils import DataCollator
data_collator = DataCollator(tokenizer, max_length=512)

example = {
    'sentence': user_input,
    'label': 1
}

example = data_collator([example]).to(device)
print(example)


{'input_ids': tensor([[128000, 128006,   9125, 128007,    271,   2127,  56956,    279,  27065,
            315,    279,   2768,  11914,    323,   6013,    449,   1193,    832,
           3492,     25,    364,  31587,   2965,    364,  43324,   2965,    477,
            364,  60668,   2965,   3196,    389,    279,   8244,  16630,    323,
           7438,    315,    279,  11914,     13,   3234,    539,   3493,    904,
           5217,  16540,     13, 128009, 128006,  52989, 128007,    271,    791,
           2532,    520,    420,  10960,    574,  14964,     11,    323,    279,
           5687,   1051,    779,  11919,     13, 128009, 128006,  78191, 128007]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1]],
       device='cuda:0'), 'context_mask': tensor(

In [32]:
gen_output = llm.generate(example['input_ids'], example['attention_mask'])
print(gen_output)

{'input_ids': tensor([[128000, 128006,   9125, 128007,    271,   2127,  56956,    279,  27065,
            315,    279,   2768,  11914,    323,   6013,    449,   1193,    832,
           3492,     25,    364,  31587,   2965,    364,  43324,   2965,    477,
            364,  60668,   2965,   3196,    389,    279,   8244,  16630,    323,
           7438,    315,    279,  11914,     13,   3234,    539,   3493,    904,
           5217,  16540,     13, 128009, 128006,  52989, 128007,    271,    791,
           2532,    520,    420,  10960,    574,  14964,     11,    323,    279,
           5687,   1051,    779,  11919,     13, 128009, 128006,  78191, 128007,
            271,  31587, 128009]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
         1, 1, 0]], devi

In [33]:
print(tokenizer.decode(gen_output['input_ids'][0]))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Analyze the sentiment of the following sentence and respond with only one word: 'positive,' 'negative,' or 'neutral,' based on the overall tone and meaning of the sentence. Do not provide any additional explanation.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

positive<|eot_id|>


In [34]:
from llmexp.trainer.mab_trainer import randomly_cut_and_pad_generations
cut_and_pad_gen_output = randomly_cut_and_pad_generations(example, gen_output, tokenizer)
print(cut_and_pad_gen_output)


{'input_ids': tensor([[128000, 128006,   9125, 128007,    271,   2127,  56956,    279,  27065,
            315,    279,   2768,  11914,    323,   6013,    449,   1193,    832,
           3492,     25,    364,  31587,   2965,    364,  43324,   2965,    477,
            364,  60668,   2965,   3196,    389,    279,   8244,  16630,    323,
           7438,    315,    279,  11914,     13,   3234,    539,   3493,    904,
           5217,  16540,     13, 128009, 128006,  52989, 128007,    271,    791,
           2532,    520,    420,  10960,    574,  14964,     11,    323,    279,
           5687,   1051,    779,  11919,     13, 128009, 128006,  78191, 128007,
            271,  31587]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
         1, 1]], device='cuda:0'

In [35]:
print(tokenizer.decode(cut_and_pad_gen_output['input_ids'][0]))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Analyze the sentiment of the following sentence and respond with only one word: 'positive,' 'negative,' or 'neutral,' based on the overall tone and meaning of the sentence. Do not provide any additional explanation.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

positive


In [36]:
from llmexp.explainer.mab_model import MABModel
mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)
mab_model.to(device)
print()

  mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)





In [37]:
input_ids = cut_and_pad_gen_output['input_ids']
attention_mask = cut_and_pad_gen_output['attention_mask']
dist, values = mab_model.get_dist_value(input_ids, attention_mask)

mab_values = torch.sigmoid(dist.logits)

context_mask = cut_and_pad_gen_output['context_mask']
# mab_values = dist.logits

In [38]:
dist.logits

tensor([[-0.0744, -0.1821, -0.2556, -0.2011, -0.1284, -0.2351, -0.2945, -0.1696,
         -0.2390, -0.2496, -0.1985, -0.2542, -0.2859, -0.3161, -0.2718, -0.2600,
         -0.2452, -0.2136, -0.2030, -0.2643, -0.2442, -0.3878, -0.4335, -0.2538,
         -0.3299, -0.3185, -0.3568, -0.2413, -0.3102, -0.2797, -0.3317, -0.2492,
         -0.2292, -0.1889, -0.2909, -0.1962, -0.2822, -0.3447, -0.3071, -0.3962,
         -0.3263, -0.2755, -0.2171, -0.2550, -0.2422, -0.2194, -0.3607, -0.3115,
         -0.1606, -0.1651, -0.2648, -0.2795, -0.2874, -0.3480, -0.3514, -0.2623,
         -0.2598, -0.3774, -0.4406, -0.4785, -0.3845, -0.3791, -0.2542, -0.2363,
         -0.3274, -0.3021, -0.4036, -0.4414, -0.2714, -0.2433, -0.0986, -0.3808,
         -0.2739]], device='cuda:0', grad_fn=<SubBackward0>)

In [39]:
print(mab_values)
print(mab_values.shape)

tensor([[0.4814, 0.4546, 0.4365, 0.4499, 0.4680, 0.4415, 0.4269, 0.4577, 0.4405,
         0.4379, 0.4505, 0.4368, 0.4290, 0.4216, 0.4325, 0.4354, 0.4390, 0.4468,
         0.4494, 0.4343, 0.4393, 0.4042, 0.3933, 0.4369, 0.4183, 0.4210, 0.4117,
         0.4400, 0.4231, 0.4305, 0.4178, 0.4380, 0.4429, 0.4529, 0.4278, 0.4511,
         0.4299, 0.4147, 0.4238, 0.4022, 0.4191, 0.4316, 0.4459, 0.4366, 0.4397,
         0.4454, 0.4108, 0.4227, 0.4599, 0.4588, 0.4342, 0.4306, 0.4286, 0.4139,
         0.4130, 0.4348, 0.4354, 0.4068, 0.3916, 0.3826, 0.4050, 0.4063, 0.4368,
         0.4412, 0.4189, 0.4251, 0.4004, 0.3914, 0.4326, 0.4395, 0.4754, 0.4059,
         0.4319]], device='cuda:0', grad_fn=<SigmoidBackward0>)
torch.Size([1, 73])


In [40]:
print(input_ids)
print(input_ids.shape)

tensor([[128000, 128006,   9125, 128007,    271,   2127,  56956,    279,  27065,
            315,    279,   2768,  11914,    323,   6013,    449,   1193,    832,
           3492,     25,    364,  31587,   2965,    364,  43324,   2965,    477,
            364,  60668,   2965,   3196,    389,    279,   8244,  16630,    323,
           7438,    315,    279,  11914,     13,   3234,    539,   3493,    904,
           5217,  16540,     13, 128009, 128006,  52989, 128007,    271,    791,
           2532,    520,    420,  10960,    574,  14964,     11,    323,    279,
           5687,   1051,    779,  11919,     13, 128009, 128006,  78191, 128007,
            271,  31587]], device='cuda:0')
torch.Size([1, 74])


In [41]:
def visualize_tokens_with_values(input_ids, mab_values, context_mask, tokenizer):
    # Decode tokens one by one to preserve alignment
    tokens = []
    for i in range(input_ids.shape[1]):
        token = tokenizer.decode(input_ids[0, i:i+1])
        tokens.append(token)
    
    # Normalize MAB values to [0,1] for color intensity first
    mab_values = mab_values * context_mask[:,:-1]
    # Create a mask for non-zero values
    non_zero_mask = mab_values[0] != 0
    # normalized_values = (mab_values[0] - mab_values[0].min()) / (mab_values[0].max() - mab_values[0].min())
    normalized_values = torch.zeros_like(mab_values[0])
    # Only normalize non-zero values
    if non_zero_mask.any():  # Check if there are any non-zero values
        non_zero_values = mab_values[0][non_zero_mask]
        normalized_non_zero = (non_zero_values - non_zero_values.min()) / (non_zero_values.max() - non_zero_values.min())
        normalized_values[non_zero_mask] = normalized_non_zero
    
    # Pad normalized_values with a zero at the end
    padded_normalized_values = torch.cat([normalized_values, torch.zeros(1, device=mab_values.device)], dim=0)
    # Pad original mab_values with the last actual value
    padded_mab_values = torch.cat([mab_values[0], mab_values[0][-1:]], dim=0)
    
    # Generate HTML with colored text and values
    html_output = "<div style='font-family: monospace; line-height: 2; background-color: white; padding: 10px;'>"
    for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
        # Use a gradient from white to green
        intensity = float(value)
        green_color = int(intensity * 200)  # Control the maximum intensity
        html_output += f'<span style="color: black; background-color: rgba(0, {green_color}, 0, 0.3); padding: 0.2em; margin: 0.1em; border-radius: 3px;" title="MAB: {orig_value:.3f}, Norm: {value:.3f}">{token}</span>'
    html_output += "</div>"
    
    # Print the values
    print("Token\tNormalized Value\tOriginal MAB Value")
    print("-" * 50)
    for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
        print(f"{token}\t{value:.3f}\t\t{orig_value:.3f}")
    
    from IPython.display import HTML
    return HTML(html_output)

In [42]:
# Usage:
visualization = visualize_tokens_with_values(input_ids, mab_values, context_mask, tokenizer)
display(visualization)

Token	Normalized Value	Original MAB Value
--------------------------------------------------
<|begin_of_text|>	0.000		0.000
<|start_header_id|>	0.000		0.000
system	0.000		0.000
<|end_header_id|>	0.000		0.000


	0.000		0.000
An	0.000		0.000
alyze	0.000		0.000
 the	0.000		0.000
 sentiment	0.000		0.000
 of	0.000		0.000
 the	0.000		0.000
 following	0.000		0.000
 sentence	0.000		0.000
 and	0.000		0.000
 respond	0.000		0.000
 with	0.000		0.000
 only	0.000		0.000
 one	0.000		0.000
 word	0.000		0.000
:	0.000		0.000
 '	0.000		0.000
positive	0.000		0.000
,'	0.000		0.000
 '	0.000		0.000
negative	0.000		0.000
,'	0.000		0.000
 or	0.000		0.000
 '	0.000		0.000
neutral	0.000		0.000
,'	0.000		0.000
 based	0.000		0.000
 on	0.000		0.000
 the	0.000		0.000
 overall	0.000		0.000
 tone	0.000		0.000
 and	0.000		0.000
 meaning	0.000		0.000
 of	0.000		0.000
 the	0.000		0.000
 sentence	0.000		0.000
.	0.000		0.000
 Do	0.000		0.000
 not	0.000		0.000
 provide	0.000		0.000
 any	0.000		0.000
 additional	0.000		0.000
