In [1]:
from llmexp.llm.smollm import LLMWrapper
from accelerate import Accelerator
import torch

# checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
# checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"
saved_mab_model = "checkpoints/mab_model_100.pth"


accelerator = Accelerator()
device = accelerator.device


llm = LLMWrapper(checkpoint, device=device)
tokenizer = llm.tokenizer

In [9]:
# instruction = "Analyze the sentiment of the following sentence. Be brief."
instruction = "Analyze the sentiment of the following sentence and respond with only one word: 'positive,' 'negative,' or 'neutral,' based on the overall tone and meaning of the sentence. Do not provide any additional explanation."
# user_input = "I am extremely disappointed with the quality; it broke after just one day."
user_input = "The service at this restaurant was fantastic, and the staff were so friendly."

content = [
            {"role": "system", 
            "content": instruction
            },

            {"role": "sentence", 
            "content": user_input
            }
        ]
template = tokenizer.apply_chat_template(content, tokenize=False, add_generation_prompt=True)
# print(template)

# The generated outputs 
gen_output = llm.generate_from_texts(template)
print(gen_output)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Analyze the sentiment of the following sentence and respond with only one word: 'positive,' 'negative,' or 'neutral,' based on the overall tone and meaning of the sentence. Do not provide any additional explanation.<|eot_id|><|start_header_id|>sentence<|end_header_id|>

The service at this restaurant was fantastic, and the staff were so friendly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

positive<|eot_id|>


In [10]:
from llmexp.explainer.mab_model import MABModel
mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)
mab_model.to(device)
print()




  mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)


In [11]:
gen_inputs = tokenizer(gen_output, return_tensors="pt").to(device)
input_ids = gen_inputs['input_ids'][:, :-1]
attention_mask = gen_inputs['attention_mask'][:, :-1]
_, mab_values, _ = mab_model.get_dist_value(input_ids, attention_mask)



In [12]:
print(mab_values)
print(mab_values.shape)

tensor([[-0.0407, -0.0407, -0.0407, -0.0376, -0.0316, -0.0332, -0.0333, -0.0410,
         -0.0386, -0.0382, -0.0334, -0.0419, -0.0348, -0.0196, -0.0374, -0.0280,
         -0.0327, -0.0323, -0.0312, -0.0361, -0.0414, -0.0362, -0.0299, -0.0319,
         -0.0188, -0.0340, -0.0276, -0.0286, -0.0297, -0.0358, -0.0263, -0.0283,
         -0.0377, -0.0306, -0.0378, -0.0300, -0.0294, -0.0252, -0.0262, -0.0304,
         -0.0288, -0.0300, -0.0308, -0.0256, -0.0259, -0.0289, -0.0331, -0.0264,
         -0.0267, -0.0296, -0.0311, -0.0342, -0.0308, -0.0304, -0.0273, -0.0255,
         -0.0316, -0.0271, -0.0360, -0.0270, -0.0293, -0.0304, -0.0265, -0.0325,
         -0.0283, -0.0296, -0.0316, -0.0344, -0.0296, -0.0291, -0.0325, -0.0344,
         -0.0290, -0.0311, -0.0232, -0.0306, -0.0379, -0.0391, -0.0392, -0.0330,
         -0.0355, -0.0353, -0.0317, -0.0332, -0.0412, -0.0407, -0.0362, -0.0333,
         -0.0336, -0.0312, -0.0327, -0.0291, -0.0286, -0.0300, -0.0258]],
       device='cuda:0', grad_fn=<Sq

In [13]:
print(input_ids)
print(input_ids.shape)

tensor([[128000, 128000, 128000, 128006,   9125, 128007,    271,  38766,   1303,
          33025,   2696,     25,   6790,    220,   2366,     18,    198,  15724,
           2696,     25,    220,   1627,  10263,    220,   2366,     19,    271,
           2127,  56956,    279,  27065,    315,    279,   2768,  11914,    323,
           6013,    449,   1193,    832,   3492,     25,    364,  31587,   2965,
            364,  43324,   2965,    477,    364,  60668,   2965,   3196,    389,
            279,   8244,  16630,    323,   7438,    315,    279,  11914,     13,
           3234,    539,   3493,    904,   5217,  16540,     13, 128009, 128006,
          52989, 128007,    271,    791,   2532,    520,    420,  10960,    574,
          14964,     11,    323,    279,   5687,   1051,    779,  11919,     13,
         128009, 128006,  78191, 128007,    271,  31587]], device='cuda:0')
torch.Size([1, 96])


In [14]:
def visualize_tokens_with_values(input_ids, mab_values, tokenizer):
    # Decode tokens one by one to preserve alignment
    tokens = []
    for i in range(input_ids.shape[1]):
        token = tokenizer.decode(input_ids[0, i:i+1])
        tokens.append(token)
    
    # Normalize MAB values to [0,1] for color intensity first
    normalized_values = (mab_values[0] - mab_values[0].min()) / (mab_values[0].max() - mab_values[0].min())
    
    # Pad normalized_values with a zero at the end
    padded_normalized_values = torch.cat([normalized_values, torch.zeros(1, device=mab_values.device)], dim=0)
    # Pad original mab_values with the last actual value
    padded_mab_values = torch.cat([mab_values[0], mab_values[0][-1:]], dim=0)
    
    # Generate HTML with colored text and values
    html_output = "<div style='font-family: monospace; line-height: 2; background-color: white; padding: 10px;'>"
    for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
        # Use a gradient from white to green
        intensity = float(value)
        green_color = int(intensity * 200)  # Control the maximum intensity
        html_output += f'<span style="color: black; background-color: rgba(0, {green_color}, 0, 0.3); padding: 0.2em; margin: 0.1em; border-radius: 3px;" title="MAB: {orig_value:.3f}, Norm: {value:.3f}">{token}</span>'
    html_output += "</div>"
    
    # Print the values
    print("Token\tNormalized Value\tOriginal MAB Value")
    print("-" * 50)
    for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
        print(f"{token}\t{value:.3f}\t\t{orig_value:.3f}")
    
    from IPython.display import HTML
    return HTML(html_output)

In [15]:
# Usage:
visualization = visualize_tokens_with_values(input_ids, mab_values, tokenizer)
display(visualization)

Token	Normalized Value	Original MAB Value
--------------------------------------------------
<|begin_of_text|>	0.050		-0.041
<|begin_of_text|>	0.050		-0.041
<|begin_of_text|>	0.050		-0.041
<|start_header_id|>	0.185		-0.038
system	0.445		-0.032
<|end_header_id|>	0.375		-0.033


	0.372		-0.033
Cut	0.039		-0.041
ting	0.144		-0.039
 Knowledge	0.162		-0.038
 Date	0.366		-0.033
:	0.000		-0.042
 December	0.308		-0.035
 	0.967		-0.020
202	0.194		-0.037
3	0.602		-0.028

	0.399		-0.033
Today	0.416		-0.032
 Date	0.461		-0.031
:	0.250		-0.036
 	0.019		-0.041
26	0.248		-0.036
 Jul	0.518		-0.030
 	0.432		-0.032
202	1.000		-0.019
4	0.342		-0.034


	0.619		-0.028
An	0.576		-0.029
alyze	0.528		-0.030
 the	0.264		-0.036
 sentiment	0.674		-0.026
 of	0.590		-0.028
 the	0.180		-0.038
 following	0.491		-0.031
 sentence	0.178		-0.038
 and	0.516		-0.030
 respond	0.540		-0.029
 with	0.722		-0.025
 only	0.680		-0.026
 one	0.499		-0.030
 word	0.567		-0.029
:	0.517		-0.030
 '	0.483		-0.031
positive	0.707		-0.026
