In [17]:
from llmexp.llm.smollm import LLMWrapper
from accelerate import Accelerator
import torch

# checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
# checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"
saved_mab_model = "checkpoints/mab_model_100.pth"


accelerator = Accelerator()
device = accelerator.device


llm = LLMWrapper(checkpoint, device=device)
tokenizer = llm.tokenizer

In [18]:
# instruction = "Analyze the sentiment of the following sentence. Be brief."
instruction = "Analyze the sentiment of the following sentence and respond with only one word: 'positive,' 'negative,' or 'neutral,' based on the overall tone and meaning of the sentence. Do not provide any additional explanation."
user_input = "I am extremely disappointed with the quality; it broke after just one day."
# user_input = "The service at this restaurant was fantastic, and the staff were so friendly."
# user_input = "I like this movie!"

content = [
            {"role": "system", 
            "content": instruction
            },

            {"role": "sentence", 
            "content": user_input
            }
        ]
template = tokenizer.apply_chat_template(content, tokenize=False, add_generation_prompt=True)
# print(template)

# The generated outputs 
gen_output = llm.generate_from_texts(template)
print(gen_output)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Analyze the sentiment of the following sentence and respond with only one word: 'positive,' 'negative,' or 'neutral,' based on the overall tone and meaning of the sentence. Do not provide any additional explanation.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

I am extremely disappointed with the quality; it broke after just one day.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

negative<|eot_id|>


In [19]:
from llmexp.explainer.mab_model import MABModel
mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)
mab_model.to(device)
print()




  mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)


In [20]:
gen_inputs = tokenizer(gen_output, return_tensors="pt").to(device)
input_ids = gen_inputs['input_ids'][:, :-1]
attention_mask = gen_inputs['attention_mask'][:, :-1]
dist, values = mab_model.get_dist_value(input_ids, attention_mask)

mab_values = torch.sigmoid(dist.logits)
# mab_values = dist.logits

In [21]:
print(mab_values)
print(mab_values.shape)

tensor([[0.6241, 0.6241, 0.6241, 0.4767, 0.4974, 0.5088, 0.8086, 0.9814, 0.9606,
         0.8457, 0.4847, 0.9275, 0.5896, 0.9958, 0.6286, 0.9986, 0.3847, 0.9835,
         0.8830, 0.5874, 0.5796, 0.4676, 0.5621, 0.8221, 0.9580, 0.7357, 0.4140,
         0.9460, 0.9999, 0.9991, 0.9633, 0.9999, 0.9996, 0.9998, 0.9999, 0.9997,
         0.9996, 0.9992, 0.9871, 0.7676, 0.8924, 0.9508, 0.7476, 0.9995, 0.9996,
         0.6379, 0.9997, 0.9995, 0.9981, 0.6966, 0.9999, 0.9999, 0.9998, 0.9998,
         0.9995, 0.9993, 0.9992, 0.9988, 0.9980, 0.9999, 0.9997, 0.9999, 0.9988,
         0.9989, 0.3619, 0.9951, 0.9852, 0.9972, 0.9996, 0.9719, 0.4999, 0.5444,
         0.9990, 0.4925, 0.5819, 0.9450, 0.9841, 0.9668, 0.9989, 0.9993, 0.9929,
         0.9982, 0.9960, 0.9980, 0.9979, 0.9947, 0.9977, 0.9800, 0.9955, 0.9553,
         0.5065, 0.5140, 0.4332, 0.5873, 0.5140]], device='cuda:0',
       grad_fn=<SigmoidBackward0>)
torch.Size([1, 95])


In [22]:
print(input_ids)
print(input_ids.shape)

tensor([[128000, 128000, 128000, 128006,   9125, 128007,    271,  38766,   1303,
          33025,   2696,     25,   6790,    220,   2366,     18,    198,  15724,
           2696,     25,    220,   1627,  10263,    220,   2366,     19,    271,
           2127,  56956,    279,  27065,    315,    279,   2768,  11914,    323,
           6013,    449,   1193,    832,   3492,     25,    364,  31587,   2965,
            364,  43324,   2965,    477,    364,  60668,   2965,   3196,    389,
            279,   8244,  16630,    323,   7438,    315,    279,  11914,     13,
           3234,    539,   3493,    904,   5217,  16540,     13, 128009, 128006,
          52989, 128007,    271,     40,   1097,   9193,  25406,    449,    279,
           4367,     26,    433,  14760,   1306,   1120,    832,   1938,     13,
         128009, 128006,  78191, 128007,    271,  43324]], device='cuda:0')
torch.Size([1, 96])


In [23]:
def visualize_tokens_with_values(input_ids, mab_values, tokenizer):
    # Decode tokens one by one to preserve alignment
    tokens = []
    for i in range(input_ids.shape[1]):
        token = tokenizer.decode(input_ids[0, i:i+1])
        tokens.append(token)
    
    # Normalize MAB values to [0,1] for color intensity first
    normalized_values = (mab_values[0] - mab_values[0].min()) / (mab_values[0].max() - mab_values[0].min())
    
    # Pad normalized_values with a zero at the end
    padded_normalized_values = torch.cat([normalized_values, torch.zeros(1, device=mab_values.device)], dim=0)
    # Pad original mab_values with the last actual value
    padded_mab_values = torch.cat([mab_values[0], mab_values[0][-1:]], dim=0)
    
    # Generate HTML with colored text and values
    html_output = "<div style='font-family: monospace; line-height: 2; background-color: white; padding: 10px;'>"
    for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
        # Use a gradient from white to green
        intensity = float(value)
        green_color = int(intensity * 200)  # Control the maximum intensity
        html_output += f'<span style="color: black; background-color: rgba(0, {green_color}, 0, 0.3); padding: 0.2em; margin: 0.1em; border-radius: 3px;" title="MAB: {orig_value:.3f}, Norm: {value:.3f}">{token}</span>'
    html_output += "</div>"
    
    # Print the values
    print("Token\tNormalized Value\tOriginal MAB Value")
    print("-" * 50)
    for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
        print(f"{token}\t{value:.3f}\t\t{orig_value:.3f}")
    
    from IPython.display import HTML
    return HTML(html_output)

In [24]:
# Usage:
visualization = visualize_tokens_with_values(input_ids, mab_values, tokenizer)
display(visualization)

Token	Normalized Value	Original MAB Value
--------------------------------------------------
<|begin_of_text|>	0.411		0.624
<|begin_of_text|>	0.411		0.624
<|begin_of_text|>	0.411		0.624
<|start_header_id|>	0.180		0.477
system	0.212		0.497
<|end_header_id|>	0.230		0.509


	0.700		0.809
Cut	0.971		0.981
ting	0.938		0.961
 Knowledge	0.758		0.846
 Date	0.192		0.485
:	0.886		0.927
 December	0.357		0.590
 	0.994		0.996
202	0.418		0.629
3	0.998		0.999

	0.036		0.385
Today	0.974		0.983
 Date	0.817		0.883
:	0.353		0.587
 	0.341		0.580
26	0.166		0.468
 Jul	0.314		0.562
 	0.721		0.822
202	0.934		0.958
4	0.586		0.736


	0.082		0.414
An	0.915		0.946
alyze	1.000		1.000
 the	0.999		0.999
 sentiment	0.943		0.963
 of	1.000		1.000
 the	0.999		1.000
 following	1.000		1.000
 sentence	1.000		1.000
 and	1.000		1.000
 respond	0.999		1.000
 with	0.999		0.999
 only	0.980		0.987
 one	0.636		0.768
 word	0.831		0.892
:	0.923		0.951
 '	0.604		0.748
positive	0.999		1.000
,'	0.999		1.000
 '	0.432		0.638
negative	1.0