In [50]:
from llmexp.llm.smollm import LLMWrapper
from accelerate import Accelerator
import torch

# checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
# checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"
# saved_mab_model = "checkpoints/mab_model_100.pth"
saved_mab_model = "checkpoints/mab_model_100.pth"


accelerator = Accelerator()
device = accelerator.device


llm = LLMWrapper(checkpoint, device=device)
tokenizer = llm.tokenizer

In [51]:
from llmexp.utils.data_utils import DataCollatorHotpotQA
from llmexp.utils.data_utils import LLMDataset, create_dataloader

instruction = "Answer the question based on the context provided."

dataloader = create_dataloader('hotpot_qa', tokenizer, max_length=2048, batch_size=1, instruction=instruction, split="test")
example = next(iter(dataloader)).to(device)

In [52]:
print(tokenizer.decode(example['input_ids'][0] * example['attention_mask'][0]))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Answer the question based on the context provided.<|eot_id|><|start_header_id|>user<|end_header_id|>

Context Sentences:

- *Albreda:* "Albreda is a historic settlement in the Gambia on the north bank of the Gambia River, variously described as a 'trading post' or a'slave fort'.  It is located near Jufureh in the North Bank Division and an arch stands on the beach connecting the two places.  As of 2008, it has an estimated population of 1,776."
- *Hardley Flood:* "Hardley Flood is a Site of Special Scientific Interest on the north bank of the River Chet northeast of Loddon in Norfolk, part-managed by the Norfolk Wildlife Trust.  It is an area of shallow lagoons and reedbeds acting as a spillway for the River Chet.  Tidal muds attract a range of wading birds and the undisturbed reedbeds support nesting wildfowl and other fenland birds, including nationally important breeding populations of shoveller, pochard and gadwall.  Hard

In [53]:
gen_output = llm.generate(example['input_ids'], example['attention_mask'])
print(gen_output)
print(tokenizer.decode(gen_output['input_ids'][0]))

{'input_ids': tensor([[128000, 128006,   9125,  ...,    295,     13, 128009]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 0]], device='cuda:0')}
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Answer the question based on the context provided.<|eot_id|><|start_header_id|>user<|end_header_id|>

Context Sentences:

- *Albreda:* "Albreda is a historic settlement in the Gambia on the north bank of the Gambia River, variously described as a 'trading post' or a'slave fort'.  It is located near Jufureh in the North Bank Division and an arch stands on the beach connecting the two places.  As of 2008, it has an estimated population of 1,776."
- *Hardley Flood:* "Hardley Flood is a Site of Special Scientific Interest on the north bank of the River Chet northeast of Loddon in Norfolk, part-managed by the Norfolk Wildlife Trust.  It is an area of shallow lagoons and reedbeds acting as a spillway for the River Chet.  Tidal muds attract a range of wading bi

In [54]:
from llmexp.explainer.mab_model import MABModel
mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)
mab_model.to(device)
print()

  mab_model = MABModel.load_with_base_model(torch.load(saved_mab_model), llm, hidden_size=1024)





In [55]:
# Get gen_start_index 
gen_start_index = example['input_ids'].shape[1]
print(gen_start_index)

1799


In [56]:
input_ids = gen_output['input_ids']
attention_mask = gen_output['attention_mask']
# logits, values = mab_model.get_logits_value(input_ids, attention_mask)
logits_list = mab_model.inference(input_ids, attention_mask, gen_start_index)


# mab_values = torch.softmax(logits, dim=-1)

context_mask = example['context_mask']
mab_values_list = []
context_mask_list = []
for logits in logits_list:
    # Get the size difference
    # pad_size = logits.size(1) - context_mask.size(1)
    pad_size = logits_list[-1].size(1) - context_mask.size(1)
    # Right pad context_mask with False values
    padded_context_mask = torch.nn.functional.pad(context_mask, (0, pad_size), value=0)
    pad_size = padded_context_mask.size(1) - logits.size(1)
    padded_logits = torch.nn.functional.pad(logits, (0, pad_size), value=float('-inf'))

    masked_logits = padded_logits.masked_fill(~padded_context_mask.bool(), float('-inf'))
    mab_values = torch.softmax(masked_logits, dim=-1)

    # masked_logits = logits.masked_fill(~padded_context_mask.bool(),0)
    # mab_values = masked_logits
    # mab_values = torch.softmax(masked_logits, dim=-1)

    mab_values_list.append(mab_values)
    context_mask_list.append(padded_context_mask)

In [57]:
padded_logits.size(1)

1818

In [58]:
def visualize_tokens_with_values(input_ids, mab_values, context_mask, tokenizer):
    # Decode tokens one by one to preserve alignment
    tokens = []
    for i in range(input_ids.shape[1]):
        token = tokenizer.decode(input_ids[0, i:i+1])
        tokens.append(token)
    
    # Normalize MAB values to [0,1] for color intensity first
    mab_values = mab_values * context_mask
    # Create a mask for non-zero values
    non_zero_mask = mab_values[0] != 0
    # normalized_values = (mab_values[0] - mab_values[0].min()) / (mab_values[0].max() - mab_values[0].min())
    normalized_values = torch.zeros_like(mab_values[0])
    # Only normalize non-zero values
    if non_zero_mask.any():  # Check if there are any non-zero values
        non_zero_values = mab_values[0][non_zero_mask]
        normalized_non_zero = (non_zero_values - non_zero_values.min()) / (non_zero_values.max() - non_zero_values.min())
        normalized_values[non_zero_mask] = normalized_non_zero
    
    # Pad normalized_values with a zero at the end
    padded_normalized_values = torch.cat([normalized_values, torch.zeros(1, device=mab_values.device)], dim=0)
    # Pad original mab_values with the last actual value
    padded_mab_values = torch.cat([mab_values[0], mab_values[0][-1:]], dim=0)
    
    # Generate HTML with colored text and values
    html_output = "<div style='font-family: monospace; line-height: 2; background-color: white; padding: 10px;'>"
    for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
        # Use a gradient from white to green
        intensity = float(value)
        green_color = int(intensity * 200)  # Control the maximum intensity
        html_output += f'<span style="color: black; background-color: rgba(0, {green_color}, 0, 0.3); padding: 0.2em; margin: 0.1em; border-radius: 3px;" title="MAB: {orig_value:.3f}, Norm: {value:.3f}">{token}</span>'
    html_output += "</div>"
    
    # Print the values
    # print("Token\tNormalized Value\tOriginal MAB Value")
    # print("-" * 50)
    # for token, value, orig_value in zip(tokens, padded_normalized_values, padded_mab_values):
    #     print(f"{token}\t{value:.3f}\t\t{orig_value:.3f}")
    
    from IPython.display import HTML
    return HTML(html_output)

In [59]:
idx = 11
# mab_values = mab_values_list[idx]
mab_values = sum(mab_values_list) / len(mab_values_list)
visualization = visualize_tokens_with_values(input_ids, mab_values, context_mask_list[idx], tokenizer)
display(visualization)

In [35]:
mab_values_list

[tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.]]