In [1]:
# import necessary packages
import sys, os
import torch 
from transformers import (pipeline,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollatorWithPadding,
                          Llama4ForConditionalGeneration,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# config
sys.path.append('../')

  from .autonotebook import tqdm as notebook_tqdm


[2025-04-10 18:21:15,918] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/pkr/miniconda3/envs/rl/shared/python_compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status


In [2]:
# load llama4 tokenizer
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token # set pad token to eos token

## 1. Setup
* Load data with the huggingface `load_dataset` function. We are using a local file, but a remote path to the huggingface hub can also be passed
* partition the data into training and testing sets
* print an example input/output pair

In [3]:
# load qa pairs as dataset object from local file
dataset = load_dataset('json', data_files='data/qa_pairs.json')

# split the data into train and test sets
dataset = dataset['train'].train_test_split(test_size=0.2)

# load the first example from the training dataset
example = dataset['train'][0]
print(example)

{'question': 'What is the ENSEMBL id of dynein axonemal assembly factor 11?', 'answer': 'ENSG00000129295'}


## 2. Preprocessing
* We will create a function to format our question/answer pairs to simulate a user/assistant interaction. 
    - We follow the json "role", "content" format provided on huggingface docs
    - After json formatting, we use the `apply_chat_template` method to add special characters indicating the start of user and assistant interactions
* Next, we apply the Llama4 tokenizer to our text data
    - Tokenizers convert words into integer indices, which is how LLMs understand language
    - We specify a `max_length` for our sequences to avoid them becoming too large
    - For sequences below the `max_length`, we apply padding to make all sentence sizes even.

In [4]:
# define a formatting function to convert the dataset into a chat format
def format_chat(row):
    row_json_inp = [{"role": "user", "content": row["question"]}]
    row_json_out = [{"role": "assistant", "content": row["answer"]}]
    row["input"] = tokenizer.apply_chat_template(row_json_inp, tokenize=False)
    row["target"] = tokenizer.apply_chat_template(row_json_out, tokenize=False)
    return row

def preprocess_data(examples):
    inp = examples["input"]
    out = examples["target"]
    tokenized_data = tokenizer(text=inp, 
                               text_target=out,
                               padding='max_length', 
                               truncation=True, 
                               max_length=512)
    return tokenized_data

In [5]:
# run the functions and print examples
formatted_data = format_chat(example)
tokenized_data = preprocess_data(formatted_data)

print(formatted_data)
print(tokenized_data)

{'question': 'What is the ENSEMBL id of dynein axonemal assembly factor 11?', 'answer': 'ENSG00000129295', 'input': '<|begin_of_text|><|header_start|>user<|header_end|>\n\nWhat is the ENSEMBL id of dynein axonemal assembly factor 11?<|eot|>', 'target': '<|begin_of_text|><|header_start|>assistant<|header_end|>\n\nENSG00000129295<|eot|>'}
{'input_ids': [200000, 200000, 200005, 1556, 200006, 368, 3668, 373, 290, 8940, 1546, 13346, 56, 1182, 323, 14764, 588, 259, 4810, 261, 347, 278, 19543, 5437, 220, 825, 43, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 200008, 

### 2.1. Map Function
* We apply the formatting and tokenization to the entire dataset using the `.map` method, which applies dataset transformations in parallel

In [6]:
# apply the map function to format the dataset
formatted_dataset = dataset.map(format_chat, remove_columns=dataset['train'].column_names)
tokenized_dataset = formatted_dataset.map(preprocess_data, batched=True, remove_columns=formatted_dataset['train'].column_names)
tokenized_dataset = tokenized_dataset.with_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map: 100%|██████████| 162031/162031 [00:26<00:00, 6058.03 examples/s]
Map: 100%|██████████| 40508/40508 [00:06<00:00, 6073.72 examples/s]
Map: 100%|██████████| 162031/162031 [01:01<00:00, 2654.44 examples/s]
Map: 100%|██████████| 40508/40508 [00:13<00:00, 3074.08 examples/s]


In [7]:
tokenized_dataset['train'][0]['input_ids'].shape

torch.Size([512])

## 3. Creating DataLoaders
* We utilize pytorch dataloaders to feed multiple sentence to the model at a time in a "batch" format
    - We implement a collate function to ensure all sequences have the same length
    - We make one dataloader for our training data, and one for our test data
    - The shape of our inputs should now be `batch size` x `max_seq_len`, as seen in the example batch shape  

In [14]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
batch_size = 2
train_dataloader = DataLoader(tokenized_dataset['train'],
                            batch_size=batch_size)

val_dataloader = DataLoader(tokenized_dataset['test'],
                            batch_size=batch_size)

In [15]:
batch = next(iter(train_dataloader))
print(batch['input_ids'].shape)

torch.Size([2, 512])


## 4. Example Forward Pass
* We give an example to the model and examine the output it produces.
    - Calling `model()` will perform a forward pass. 
        * The model takes token indices and predicts the probability for the next token in a sequence given all previous tokens.
        * We can view the loss and logits. 
            - **The loss is a measure of how far the model's predicted next tokens deviate from the ground truth. A lower loss value indicates less error.**
            - The logits are the raw predicted values for each token at each point in the sequence. We can convert our logits into probabilities by taking the softmax of the logits, and use this probability distribution to sample the next tokens in the sequence.
    - Calling `model.generate()` will produce a text output.
        * Under the hood, the model converts predicted logits to tokens and maps these tokens back into words with the tokenizer. This produces human-readable text.

In [10]:
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

Loading checkpoint shards: 100%|██████████| 50/50 [00:51<00:00,  1.04s/it]


In [16]:
# example forward pass (on a single example - there's a bug in llama4)
with torch.no_grad():
    outputs = model(input_ids=batch['input_ids'], 
                    attention_mask=batch['attention_mask'], 
                    labels=batch['labels'])
    loss = outputs.loss
    logits = outputs.logits
    print(loss), print(logits.shape)

tensor(25.3750, dtype=torch.bfloat16)
torch.Size([2, 512, 202048])


In [19]:
# example generate (just for a single question here)
with torch.no_grad():
    outputs = model.generate(input_ids=batch['input_ids'].to('cuda'), 
                             attention_mask=batch['attention_mask'].to('cuda'), 
                             max_length=1024, 
                             do_sample=True, 
                             top_k=50, 
                             top_p=0.95, 
                             temperature=1.0)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

user

What is the ENSEMBL id of dynein axonemal assembly factor 11?assistant

I don't have the specific information on the ENSEMBL ID for dynein axonemal assembly factor 11. For the most accurate and up-to-date information, I recommend checking directly with the ENSEMBL database or other reliable genetic databases. They should have the most current details on gene identifiers, including ENSEMBL IDs for various genes and proteins.
