# ECE284 SP25 Project

## Part 1 Install and import libraries

In [None]:
!pip install datasets
!pip install bitsandbytes

In [10]:
# Python built-in libraries
import os

# Hugging face libraries
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)

# Pytorch libraries
import torch

# Other libraries

## Part 2 Set global parameters

In [13]:
# Baseline model
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# Datasets
dataset_name = "GBaker/MedQA-USMLE-4-options"

# Output dir
output_path = "output_models"
if not os.path.exists(output_path):
    os.mkdir(output_path)

## Part 3 Load model

In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Config bits and bytes quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_compute_dtype = torch.float16,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config = bnb_config,
    device_map = "auto"
)
model = prepare_model_for_kbit_training(model)

In [None]:
# Config lora
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "up_proj","down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)

In [None]:
# This cell is for test
model.print_trainable_parameters()

trainable params: 13,762,560 || all params: 1,557,476,864 || trainable%: 0.8836


## Part 4 Load dataset

In [14]:
# Define dataset preprocess function
def preprocess(data):
    '''
        Preprocess dataset for training and validation.
        Each data is a dict.
    '''

    # Pick question, options and answers from data
    question = data["question"]
    answer = data["answer_idx"]

    options = [key + ". " + val for key, val in data["options"].items()]
    options = "\n".join(options)

    # Concatenate information
    instruction = "Please answering the following question "        \
                    "by selecting the correct answer.\n\n"          \
                    f"Question:\n {question}\n\n"                   \
                    f"Options: {options}\n\n"                       \
                    "Provide only the letter of the correct answer."

    # Add prompt format
    instruction_formatted = "<|im_start|>user\n"                    \
                            f"{instruction} <|im_end|>\n"           \
                            "<|im_start|>assistant\n"               \
                            f"{answer} <|im_end|>\n"

    return {"text": instruction_formatted}


# Define dataset tokenization function
def tokenize(data):
    '''
        Tokenize dataset.
    '''
    # The longest input sequence length is 4424
    return tokenizer(data["text"], truncation=True, padding="max_length", max_length=5000)

In [22]:
# Load dataset
dataset = load_dataset(dataset_name)

# Preprocess dataset
column_names = dataset["train"].column_names
dataset = dataset.map(preprocess, remove_columns=column_names)

train_dataset = dataset["train"]
test_dataset = dataset["test"]

# Tokenize dataset
train_dataset = train_dataset.map(tokenize, batched=True, remove_columns=["text"])
test_dataset = test_dataset.map(tokenize, batched=True, remove_columns=["text"])

Map:   0%|          | 0/10178 [00:00<?, ? examples/s]

Map:   0%|          | 0/1273 [00:00<?, ? examples/s]

In [23]:
# This cell is for test
print(train_dataset[0])

{'input_ids': [151644, 872, 198, 5501, 35764, 279, 2701, 3405, 553, 26301, 279, 4396, 4226, 382, 14582, 510, 362, 220, 17, 18, 4666, 6284, 20280, 5220, 518, 220, 17, 17, 5555, 12743, 367, 18404, 448, 19675, 5193, 4335, 2554, 13, 2932, 5302, 432, 3855, 220, 16, 1899, 4134, 323, 702, 1012, 92305, 8818, 16163, 803, 3015, 323, 4633, 69537, 15357, 8649, 13, 2932, 5937, 11074, 1632, 323, 374, 8110, 553, 264, 10668, 369, 1059, 19636, 13, 6252, 9315, 374, 220, 24, 22, 13, 22, 58472, 320, 18, 21, 13, 20, 30937, 701, 6543, 7262, 374, 220, 16, 17, 17, 14, 22, 22, 9465, 39, 70, 11, 27235, 374, 220, 23, 15, 44173, 11, 32415, 804, 525, 220, 16, 24, 44173, 11, 323, 23552, 49743, 374, 220, 24, 23, 4, 389, 3054, 3720, 13, 27379, 7006, 374, 27190, 369, 458, 19265, 315, 2783, 1975, 665, 41643, 9210, 8376, 28568, 323, 264, 89554, 84556, 13, 15920, 315, 279, 2701, 374, 279, 1850, 6380, 369, 419, 8720, 1939, 3798, 25, 362, 13, 53687, 292, 60497, 198, 33, 13, 356, 823, 376, 685, 87, 603, 198, 34, 13, 3155, 8

## Part 5 Config training arguments

## Part 6 Train

## Part 7 Evaluate model