In [1]:
# import necessary packages
import sys, os
import re
import torch 
from transformers import (pipeline,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollatorWithPadding,
                          DataCollatorForLanguageModeling,
                          Llama4ForConditionalGeneration,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# config
sys.path.append('../')

  from .autonotebook import tqdm as notebook_tqdm


[2025-04-15 11:44:33,367] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/pkr/miniconda3/envs/rl/shared/python_compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status


In [2]:
# load llama4 tokenizer
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token # set pad token to eos token

## 1. Setup
* Load data with the huggingface `load_dataset` function. We are using a local file, but a remote path to the huggingface hub can also be passed
* partition the data into training and testing sets
* print an example input/output pair

In [3]:
# load qa pairs as dataset object from local file
dataset = load_dataset('json', data_files='data/qa_pairs.json')

# split the data into train and test sets
dataset = dataset['train'].train_test_split(test_size=0.2)

# load the first example from the training dataset
example = dataset['train'][0]
print(example)

{'question': 'What is the name of the gene with the symbol of MCM9?', 'answer': 'minichromosome maintenance 9 homologous recombination repair factor'}


## 2. Preprocessing
* We will create a function to format our question/answer pairs to simulate a user/assistant interaction. 
    - We follow the json "role", "content" format provided on huggingface docs
    - After json formatting, we use the `apply_chat_template` method to add special characters indicating the start of user and assistant interactions
* Next, we apply the Llama4 tokenizer to our text data
    - Tokenizers convert words into integer indices, which is how LLMs understand language
    - We combine the `prompt` and `answer` to form a single string of text
    - The `prompt` tokens are not included in our label calculations, since we expect a prompt as user input.
    - We specify a `max_length` for our sequences to avoid them becoming too large

In [4]:
# define a formatting function to convert the dataset into a chat format
def format_chat(row):
    row_json_inp = [{"role": "user", "content": row["question"]}]
    row_json_out = [{"role": "assistant", "content": row["answer"]}]
    row["prompt"] = tokenizer.apply_chat_template(row_json_inp, tokenize=False)
    row["response"] = tokenizer.apply_chat_template(row_json_out, tokenize=False)
    return row

def preprocess_data(examples):
    # concat to get full text
    full_text = example["prompt"] + example["response"]

    # tokenize
    tokenized = tokenizer(full_text,
                          truncation=True,
                          max_length=512,
                          add_special_tokens=False
                          )
    
    # tokenize prompt and get length for loss masking
    prompt_tokenized = tokenizer(example["prompt"],
                          truncation=True,
                          max_length=512,
                          add_special_tokens=False
                          ) 
    prompt_length = len(prompt_tokenized['input_ids'])

    # copy input_ids to get labels, set to -100 for loss masking
    labels = tokenized['input_ids'].copy()
    labels[:prompt_length] = [-100] * prompt_length
    tokenized['labels'] = labels

    return tokenized

In [5]:
# run the functions and print examples
formatted_data = format_chat(example)
tokenized_data = preprocess_data(formatted_data)

print(formatted_data)
print(tokenized_data)

{'question': 'What is the name of the gene with the symbol of MCM9?', 'answer': 'minichromosome maintenance 9 homologous recombination repair factor', 'prompt': '<|begin_of_text|><|header_start|>user<|header_end|>\n\nWhat is the name of the gene with the symbol of MCM9?<|eot|>', 'response': '<|begin_of_text|><|header_start|>assistant<|header_end|>\n\nminichromosome maintenance 9 homologous recombination repair factor<|eot|>'}
{'input_ids': [200000, 200005, 1556, 200006, 368, 3668, 373, 290, 1327, 323, 290, 14561, 517, 290, 9985, 323, 376, 6877, 37, 43, 200008, 200000, 200005, 140680, 200006, 368, 1419, 665, 154899, 19277, 220, 37, 192161, 108592, 21769, 5437, 200008], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 200000, 200005, 140680, 200006, 368, 1419, 665, 154899,

### 2.1. Map Function
* We apply the formatting and tokenization to the entire dataset using the `.map` method, which applies dataset transformations in parallel

In [6]:
# apply the map function to format the dataset
formatted_dataset = dataset.map(format_chat, remove_columns=dataset['train'].column_names)
tokenized_dataset = formatted_dataset.map(preprocess_data, remove_columns=["prompt", "response"])

Map: 100%|██████████| 162031/162031 [00:26<00:00, 6100.13 examples/s]
Map: 100%|██████████| 40508/40508 [00:06<00:00, 6190.17 examples/s]
Map: 100%|██████████| 162031/162031 [00:36<00:00, 4458.00 examples/s]
Map: 100%|██████████| 40508/40508 [00:08<00:00, 4518.14 examples/s]


In [7]:
formatted_dataset['train'][0]

{'prompt': '<|begin_of_text|><|header_start|>user<|header_end|>\n\nWhat is the name of the gene with the symbol of MCM9?<|eot|>',
 'response': '<|begin_of_text|><|header_start|>assistant<|header_end|>\n\nminichromosome maintenance 9 homologous recombination repair factor<|eot|>'}

## 3. Creating DataLoaders
* We utilize pytorch dataloaders to feed multiple sentence to the model at a time in a "batch" format
    - We implement a collate function to ensure all sequences have the same length
    - We make one dataloader for our training data, and one for our test data
    - The shape of our inputs should now be `batch size` x `batch_max_seq_len`, as seen in the example batch shape  

In [8]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch_size = 2
train_dataloader = DataLoader(tokenized_dataset['train'],
                              collate_fn=data_collator,
                              batch_size=batch_size)

val_dataloader = DataLoader(tokenized_dataset['test'],
                            collate_fn=data_collator,
                            batch_size=batch_size)

In [9]:
batch = next(iter(train_dataloader))
print(batch['input_ids'].shape)

torch.Size([2, 37])


## 4. Example Forward Pass
* We give an example to the model and examine the output it produces.
    - Calling `model()` will perform a forward pass. 
        * The model takes token indices and predicts the probability for the next token in a sequence given all previous tokens.
        * We can view the loss and logits. 
            - **The loss is a measure of how far the model's predicted next tokens deviate from the ground truth. A lower loss value indicates less error.**
            - The logits are the raw predicted values for each token at each point in the sequence. We can convert our logits into probabilities by taking the softmax of the logits, and use this probability distribution to sample the next tokens in the sequence.
    - Calling `model.generate()` will produce a text output. 
        * Under the hood, the model converts predicted logits to tokens and maps these tokens back into words with the tokenizer. This produces human-readable text.

In [10]:
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

Loading checkpoint shards: 100%|██████████| 50/50 [00:51<00:00,  1.03s/it]


In [11]:
print(batch['input_ids'].shape, batch['labels'].shape, batch['attention_mask'].shape)

torch.Size([2, 37]) torch.Size([2, 37]) torch.Size([2, 37])


In [12]:
# example forward pass (on a single example - there's a bug in llama4)
with torch.no_grad():
    outputs = model(**batch)
    loss = outputs.loss
    logits = outputs.logits
    print(loss), print(logits.shape)

tensor(7.2188, dtype=torch.bfloat16)
torch.Size([2, 37, 202048])


In [13]:
# make generation function
def generate_response(prompt):
    # tokenize and generate
    prompt_text = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False)
    tokenized = tokenizer(prompt_text, return_tensors="pt")
    input_ids = tokenized['input_ids'].to(model.device)
    attention_mask = tokenized['attention_mask'].to(model.device)
    generated_ids = model.generate(input_ids,
                                   attention_mask=attention_mask,
                                   max_new_tokens=1024,
                                   pad_token_id=tokenizer.eos_token_id,
                                   temperature=0.1)
    full_output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # get answer tokens and remove "assistant" prefix
    prompt_length = len(tokenizer(prompt_text, add_special_tokens=False)['input_ids'])
    answer_ids = generated_ids[0][prompt_length:]
    answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
    cleaned_answer = re.sub(r"^assistant[\s:]+", "", answer_text, flags=re.IGNORECASE)

    return cleaned_answer

question = "What is the symbol of A-kinase anchoring protein 8 pseudogene 1?"
answer = generate_response(question)
print(f"Question: {question}")
print(f"Generated Answer: {answer}")
print(f"Expected Answer: AKAP8P1")

Question: What is the symbol of A-kinase anchoring protein 8 pseudogene 1?
Generated Answer: The symbol for A-kinase anchoring protein 8 pseudogene 1 is AKAP8P1.
Expected Answer: AKAP8P1
