In [1]:
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from llmexp.llm.smollm import LLMWrapper
from accelerate import Accelerator
import wandb

wandb.init(project="mab-hotpotqa")

# checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
# checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"


accelerator = Accelerator()
device = accelerator.device


llm = LLMWrapper(checkpoint, device=device)
tokenizer = llm.tokenizer

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpd90506[0m ([33mpd90506-nd[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
from llmexp.utils.data_utils import DataCollatorHotpotQA
instruction = "Answer the question based on the context provided."
# data_collator = DataCollatorHotpotQA(tokenizer, max_length=512, instruction=instruction)

In [3]:
from llmexp.utils.data_utils import LLMDataset, create_dataloader
from llmexp.explainer.mab_model import MABModel
from llmexp.trainer.mab_trainer import MABTrainer

mab_model = MABModel(llm, hidden_size=1024)

config = {
    'lr': 1e-4,
    'num_pulls': 3,
    'num_epochs': 1,
    'batch_size': 16,
    'clip_epsilon': 0.2,
    'lambda_entropy': 0.0,
    'minibatch_size': 16,
    'ppo_epochs': 2,
    'topk': 0.2, # number of topk tokens to select, if topk < 1, then topk is the ratio of the number of tokens to select
    'save_interval': 20,
}

dataloader = create_dataloader('hotpot_qa', tokenizer, max_length=2048, batch_size=config['batch_size'], instruction=instruction, split="train")
mab_trainer = MABTrainer(mab_model, llm, tokenizer, config, device)



In [24]:
mab_trainer.train(dataloader, gamma=0.6)


Training MAB: 0it [00:00, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
Training MAB: 104it [2:21:22, 82.54s/it]

{'input_ids': tensor([[128000, 128006,   9125,  ..., 128006,  78191, 128007]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]]), 'context_mask': tensor([[0, 0, 0,  ..., 0, 0, 0]])}

In [13]:
example = next(iter(dataloader))
example
tokens = tokenizer.decode(example['input_ids'][0])
print(tokens)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Answer the question based on the context provided.<|eot_id|><|start_header_id|>user<|end_header_id|>

Context Sentences:

- *Fat Lever:* "Lafayette "Fat" Lever ( ; born August 18, 1960) is an American retired professional basketball player born in Pine Bluff, Arkansas who played in the National Basketball Association.  He is currently the director of player development for the Sacramento Kings of the NBA.  Lever also serves as the color analyst for the Kings radio broadcasts."
- *Chris Mullin (basketball):* "Christopher Paul Mullin (born July 30, 1963) is an American retired professional basketball player and current head coach of the St. John's Red Storm.  He previously served as special advisor for the Sacramento Kings and general manager of the Golden State Warriors.  He is a two-time Olympic Gold medalist and a two-time Naismith Memorial Basketball Hall of Fame inductee (in 2010 as a member of the 1992 United States men's

In [None]:
# The generated outputs 
example = example.to(device)
gen_output = llm.generate(**example, max_new_tokens=50)
print(gen_output)

{'input_ids': tensor([[128000, 128006,   9125,  ...,    258,     13, 128009]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 0]], device='cuda:0')}


In [21]:
tokens = tokenizer.decode(gen_output['input_ids'][0])
print(tokens)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Answer the question based on the context provided.<|eot_id|><|start_header_id|>user<|end_header_id|>

Context Sentences:

- *Fat Lever:* "Lafayette "Fat" Lever ( ; born August 18, 1960) is an American retired professional basketball player born in Pine Bluff, Arkansas who played in the National Basketball Association.  He is currently the director of player development for the Sacramento Kings of the NBA.  Lever also serves as the color analyst for the Kings radio broadcasts."
- *Chris Mullin (basketball):* "Christopher Paul Mullin (born July 30, 1963) is an American retired professional basketball player and current head coach of the St. John's Red Storm.  He previously served as special advisor for the Sacramento Kings and general manager of the Golden State Warriors.  He is a two-time Olympic Gold medalist and a two-time Naismith Memorial Basketball Hall of Fame inductee (in 2010 as a member of the 1992 United States men's