# Finetuning a Gemma chat model for Q&A

In this tutorial we will finetune an instruction-tuned Gemma model to do Q&A on difficult medical questions. An instruction-tuned model responds to text in a conversational manner, a format that suits a wide variety of tasks.

We will load a Q&A dataset and convert it into a "user" and "assistant" chat format, where the user poses the question and the model answers it. We will then finetune the model to improve its answers.
To facilitate the finetuning of such large models, we [load and finetune the model with 4-bit quantization](https://huggingface.co/docs/bitsandbytes/main/en/fsdp_qlora).

This tutorial is loosely based on [this one](https://medium.com/the-ai-forum/instruction-fine-tuning-gemma-2b-on-medical-reasoning-and-convert-the-finetuned-model-into-gguf-844191f8d329). It is designed to work on an L4 GPU with 24 GB of VRAM, using the instruction-tuned variant of either [Gemma-2B](https://huggingface.co/google/gemma-2b-it) or [Gemma-2-9B](https://huggingface.co/google/gemma-2-9b-it).


## Config

The config below exposes some key variables that affect the amount of training / eval data, how the model behaves, and important knobs for finetuning.

GPU memory usage during finetuning is a prime concern for hobbyists. If you find yourself running out of memory, adjust some of these config variables.

In [1]:
# Model
# BASE_MODEL_ID = "google/gemma-2b-it"
BASE_MODEL_ID = "google/gemma-2-9b-it"

# Dataset
INPUT_LIMIT = 700  # lower this to limit GPU memory usage
TRAIN_SIZE = 1000
TEST_SIZE = 100
EVAL_SIZE = 50
SEED = 123

# Generation / eval
GENERATE_KWARGS = dict(
    do_sample=True,
    max_new_tokens=512,
    temperature=1e-3,
)
EVAL_BATCH_SIZE = 4

# Finetuning
TRAIN_MAX_LENGTH = 512  # lower this to limit GPU memory usage
TRAIN_NUM_EPOCHS = 1
TRAIN_BATCH_SIZE = 4  # lower this to limit GPU memory usage, while slowing down training
TRAIN_GRADIENT_ACCUMULATION_STEPS = 1  # multiplies the effective batch size without the extra memory, but also without the computational speedup
TRAIN_LOGGING_STEPS = 10
EVAL_ACCUMULATION_STEPS = 4  # lower this to limit GPU memory usage
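
To see how the batch-related knobs interact, the number of optimizer steps in one epoch follows from the training set size and the effective batch size. The sketch below is plain arithmetic on the config values above; it assumes a single GPU, and `steps_per_epoch` is a helper name made up for this illustration. The exact rounding used by a training framework may differ by a step.

```python
import math

# Values copied from the config cell above.
TRAIN_SIZE = 1000
TRAIN_BATCH_SIZE = 4
TRAIN_GRADIENT_ACCUMULATION_STEPS = 1
TRAIN_NUM_EPOCHS = 1


def steps_per_epoch(num_samples: int, batch_size: int, accumulation_steps: int) -> int:
    """Optimizer steps per epoch on a single device (hypothetical helper)."""
    effective_batch_size = batch_size * accumulation_steps
    return math.ceil(num_samples / effective_batch_size)


total_steps = TRAIN_NUM_EPOCHS * steps_per_epoch(
    TRAIN_SIZE, TRAIN_BATCH_SIZE, TRAIN_GRADIENT_ACCUMULATION_STEPS
)
print(total_steps)

# Halving the batch size and doubling accumulation keeps the step count the same
# while lowering peak memory per forward/backward pass.
print(steps_per_epoch(TRAIN_SIZE, 2, 2))
```

With the defaults this gives 250 steps, which is consistent with the `global_step` reported in the training output further down.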

## Load Gemma instruct model

Notice that we are loading the model with 4-bit quantization. When finetuning Gemma-2, we need to use the `eager` attention implementation to stabilize training.

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    attn_implementation="eager",
    quantization_config=quantization_config
)
print(model.device)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


cuda:0
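
To get a feel for why 4-bit loading matters, here is a back-of-envelope estimate of the memory taken by the weights alone. It assumes roughly 9.24 billion base parameters for Gemma-2-9B (the total that PEFT reports later in this tutorial, minus the adapter's trainable parameters). It ignores quantization constants, any layers that are left unquantized, activations, and optimizer state, so treat the result as a rough lower bound rather than a measurement.

```python
# Total parameters minus LoRA parameters, both taken from the PEFT printout
# later in this tutorial (Gemma-2-9B run).
BASE_PARAMS = 9_457_778_176 - 216_072_192


def weight_gib(num_params: int, bits_per_param: float) -> float:
    """Memory needed to store the weights, in GiB (hypothetical helper)."""
    return num_params * bits_per_param / 8 / 2**30


bf16_gib = weight_gib(BASE_PARAMS, 16)
nf4_gib = weight_gib(BASE_PARAMS, 4)
print(f"bf16 weights:  {bf16_gib:.1f} GiB")
print(f"4-bit weights: {nf4_gib:.1f} GiB")
```

Under these assumptions the 16-bit weights alone take a bit over 17 GiB, which leaves little of a 24 GB card for activations and gradients, while the 4-bit weights take about a quarter of that.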


## Medical Q & A

The dataset we will finetune on is a processed version of [`medalpaca/medical_meadow_medqa`](https://huggingface.co/datasets/medalpaca/medical_meadow_medqa). We will skip discussing the data processing and just load and inspect the dataset.

In [3]:
from medqa_data import load_train_test_data

dataset = load_train_test_data(
    train_size=TRAIN_SIZE,
    test_size=TEST_SIZE,
    seed=SEED,
    input_limit=INPUT_LIMIT,
)
display(dataset)

def print_sample(sample: dict[str, str]):
    print("\n".join(f"\n# {k}\n{v}" for k, v in sample.items())[1:])

print_sample(dataset["test"][0])

DatasetDict({
    train: Dataset({
        features: ['input', 'output'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input', 'output'],
        num_rows: 100
    })
})

# input
Q:A 27-year-old previously healthy man presents to the clinic complaining of bloody diarrhea and abdominal pain. Sexual history reveals that he has sex with men and women and uses protection most of the time. He is febrile with all other vital signs within normal limits. Physical exam demonstrates tenderness to palpation of the right upper quadrant. Subsequent ultrasound shows a uniform cyst in the liver. In addition to draining the potential abscess and sending it for culture, appropriate medical therapy would involve which of the following?? 
{'A': 'Amphotericin', 'B': 'Nifurtimox', 'C': 'Supportive therapy', 'D': 'Sulfadiazine and pyrimethamine', 'E': 'Metronidazole and iodoquinol'},

# output
E: Metronidazole and iodoquinol


We can see that this is a multiple choice medical Q&A dataset. Each question comes with 5 possible answers to choose from.

Let's pose this in a user-assistant chat format for the model and, additionally, ask it to format the answer as JSON. The chat template wraps each turn in `<start_of_turn>user\n` or `<start_of_turn>model\n` tags, which keep track of which parts of the text correspond to the question and which to the model-generated answer.

In [4]:
question = """
Q:A child is in the nursery one day after birth. A nurse notices a urine-like discharge being expressed through the umbilical stump. What two structures in the embryo are connected by the structure that failed to obliterate during the embryologic development of this child??
{'A': 'Pulmonary artery - aorta', 'B': 'Bladder - yolk sac', 'C': 'Bladder - small bowel', 'D': 'Liver - umbilical vein', 'E': 'Kidney - large bowel'},
Give your answer as a JSON dictionary with the "option" (a letter from A-E) and the  corresponding"option_text". No yapping.
""".strip()

chat = [{"role": "user", "content": question}]
input_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(model.device)
output_ids = model.generate(input_ids, **GENERATE_KWARGS)
input_ids = input_ids.to("cpu")
output_ids = output_ids.to("cpu")
print(tokenizer.decode(output_ids[0]))

<bos><start_of_turn>user
Q:A child is in the nursery one day after birth. A nurse notices a urine-like discharge being expressed through the umbilical stump. What two structures in the embryo are connected by the structure that failed to obliterate during the embryologic development of this child??
{'A': 'Pulmonary artery - aorta', 'B': 'Bladder - yolk sac', 'C': 'Bladder - small bowel', 'D': 'Liver - umbilical vein', 'E': 'Kidney - large bowel'},
Give your answer as a JSON dictionary with the "option" (a letter from A-E) and the  corresponding"option_text". No yapping.<end_of_turn>
<start_of_turn>model
{"option": "B", "option_text": "Bladder - yolk sac"}  
<end_of_turn><eos>


Here we are only interested in using the model for Q&A, so all our chats will consist of one turn each from the user and the assistant. However, it is certainly possible to generate and finetune on multi-turn dialogue. For more on how to do that, see the [documentation](https://huggingface.co/docs/transformers/main/en/chat_templating).
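
For intuition, the decoded output above shows how Gemma lays out a conversation. The function below is a hand-written approximation of that layout, not the tokenizer's actual template, and `render_gemma_chat` is a hypothetical name. In real code, `tokenizer.apply_chat_template` should be used instead. The sketch also shows that a multi-turn chat is simply more turns concatenated in order.

```python
def render_gemma_chat(chat: list[dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Approximate the Gemma chat layout seen in the decoded output above."""
    text = "<bos>"
    for message in chat:
        # Gemma labels the assistant's turns "model".
        role = "model" if message["role"] == "assistant" else message["role"]
        text += f"<start_of_turn>{role}\n{message['content'].strip()}<end_of_turn>\n"
    if add_generation_prompt:
        # Open a model turn so that generation continues with the answer.
        text += "<start_of_turn>model\n"
    return text


multi_turn_chat = [
    {"role": "user", "content": "Which option is correct, A or B?"},
    {"role": "assistant", "content": '{"option": "B"}'},
    {"role": "user", "content": "Why?"},
]
print(render_gemma_chat(multi_turn_chat, add_generation_prompt=True))
```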

Next we will reformat the entire dataset into this chat format:

In [5]:
from medqa_data import reformat_sample

dataset = dataset.map(reformat_sample)
train_data, test_data = dataset["train"], dataset["test"]

print_sample(test_data[0])

# input
Q: A 27-year-old previously healthy man presents to the clinic complaining of bloody diarrhea and abdominal pain. Sexual history reveals that he has sex with men and women and uses protection most of the time. He is febrile with all other vital signs within normal limits. Physical exam demonstrates tenderness to palpation of the right upper quadrant. Subsequent ultrasound shows a uniform cyst in the liver. In addition to draining the potential abscess and sending it for culture, appropriate medical therapy would involve which of the following?? 
{'A': 'Amphotericin', 'B': 'Nifurtimox', 'C': 'Supportive therapy', 'D': 'Sulfadiazine and pyrimethamine', 'E': 'Metronidazole and iodoquinol'}
Give your answer as a JSON dictionary with the "option" (a letter from A-E) and the  corresponding"option_text". No yapping.

# output
{"option": "E", "text": "Metronidazole and iodoquinol"}

# true_label
E


Now that we have our dataset, we are almost ready to start finetuning. Before we do that, let's evaluate Gemma on our test set. 

If you are running Gemma-2-9B, you will notice that its accuracy is already fairly good at ~60%, well above the 20% expected from random guessing among five options.

Gemma-2B, on the other hand, performs at roughly chance level (~20%).

In [6]:
import numpy as np
from medqa_data import create_predict

batch_predict = create_predict(tokenizer, model, "gemma", batch=True, generate_kwargs=GENERATE_KWARGS)
test_data = test_data.map(batch_predict, batched=True, batch_size=EVAL_BATCH_SIZE)

gemma_accuracy = (np.asarray(test_data["gemma_label"]) == np.asarray(test_data["true_label"])).mean()
print(f"Gemma accuracy: {round(gemma_accuracy * 100, 1)}%")




Gemma accuracy: 58.0%
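
The prediction helper `create_predict` lives in `medqa_data` and is not shown here. As a rough idea of what turning generated text into a label can involve, the sketch below pulls the first flat JSON object out of a model response and reads its `"option"` field. `extract_option` is a hypothetical helper written for illustration, not the tutorial's implementation, and it would not handle nested JSON or a `}` inside the option text.

```python
import json
import re
from typing import Optional

VALID_OPTIONS = ("A", "B", "C", "D", "E")


def extract_option(generated_text: str) -> Optional[str]:
    """Return the answer letter from the first flat JSON object in the text, if any."""
    match = re.search(r"\{.*?\}", generated_text, flags=re.DOTALL)
    if match is None:
        return None
    try:
        answer = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    option = answer.get("option")
    # Anything other than a single valid letter counts as "no prediction".
    return option if option in VALID_OPTIONS else None


predictions = [
    extract_option('{"option": "B", "option_text": "Bladder - yolk sac"}  \n<end_of_turn>'),
    extract_option("I think the answer is B."),  # no JSON, so no label
]
true_labels = ["B", "B"]
accuracy = sum(p == t for p, t in zip(predictions, true_labels)) / len(true_labels)
print(predictions, accuracy)
```

Treating unparseable responses as wrong, as this sketch does, means the reported accuracy also penalizes answers that ignore the requested JSON format.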


## Supervised finetuning (SFT)

Now we will finetune our model on the training data that we loaded. Since we are training the model to obey instructions, this is also a form of instruction tuning called [supervised finetuning (SFT)](https://huggingface.co/docs/trl/en/sft_trainer).

First, we need to have a function that applies the model's chat template to the input questions and the expected outputs. Note that the implementation below does this in batches.

In [7]:
def create_chat_for_finetuning(samples: dict) -> list[str]:
    chat_texts = tokenizer.apply_chat_template(
        [
            [
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": model_text},
            ]
            # Avoid naming a loop variable `input`, which would shadow the builtin.
            for user_text, model_text in zip(samples["input"], samples["output"])
        ],
        tokenize=False,
    )
    return [text.removeprefix("<bos>") for text in chat_texts]

display(create_chat_for_finetuning(train_data[:3]))

['<start_of_turn>user\nQ: A 70-year-old man with a long-standing history of diabetes mellitus type 2 and hypertension presents with complaints of constant wrist and shoulder pain. Currently, the patient undergoes hemodialysis 2 to 3 times a week and is on the transplant list for a kidney. The patient denies any recent traumas. Which of the following proteins is likely to be increased in his plasma, causing the patient’s late complaints?? \n{\'A\': \'Ig light chains\', \'B\': \'Amyloid A (AA)\', \'C\': \'Amyloid precursor protein\', \'D\': \'Urine tests will only be diagnostic of end-stage kidney disease\', \'E\': \'β2-microglobulin\'}\nGive your answer as a JSON dictionary with the "option" (a letter from A-E) and the  corresponding"option_text". No yapping.<end_of_turn>\n<start_of_turn>model\n{"option": "E", "text": "\\u03b22-microglobulin"}<end_of_turn>\n',
 '<start_of_turn>user\nQ: A 60-year-old African-American female presents to your office complaining of dysuria, paresthesias, an

Next, we will train the model using [parameter-efficient finetuning](https://huggingface.co/docs/peft/en/index). This is again because GPU memory is limited. We will use a [quantized low-rank adapter (QLoRA)](https://huggingface.co/docs/bitsandbytes/main/en/fsdp_qlora), which lets us train a much smaller set of parameters than full finetuning would require.

The adapter essentially sits on top of the base Gemma model: the model we train contains both the base model and the adapter components, but only the adapter's weights get updated.

In [8]:
# pip install peft
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
)

model = get_peft_model(model, lora_config)

trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable:,} | total: {total:,} | Percentage: {trainable/total*100:.4f}%")

Trainable: 216,072,192 | total: 9,457,778,176 | Percentage: 2.2846%
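
The trainable-parameter count can be reproduced by hand. A LoRA adapter of rank `r` on a linear layer with `d_in` inputs and `d_out` outputs adds two small matrices with `r * (d_in + d_out)` parameters in total. The sketch below assumes the layer shapes from the published Gemma-2-9B configuration (42 decoder layers, hidden size 3584, MLP size 14336, 16 query heads and 8 key/value heads of dimension 256). These shapes are not stated elsewhere in this tutorial, so check them against `model.config` before relying on them. It also assumes that `target_modules="all-linear"` adapts the seven projection layers in each decoder block and leaves the output head alone.

```python
R = 64  # LoRA rank from the LoraConfig above

# Assumed Gemma-2-9B shapes (verify against model.config).
NUM_LAYERS = 42
HIDDEN = 3584
MLP = 14336
Q_DIM = 16 * 256   # query heads * head dimension
KV_DIM = 8 * 256   # key/value heads * head dimension

# (d_in, d_out) of each adapted linear layer in one decoder block.
LINEAR_SHAPES = {
    "q_proj": (HIDDEN, Q_DIM),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (Q_DIM, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in A (r x d_in) plus B (d_out x r)."""
    return r * (d_in + d_out)


per_layer = sum(lora_params(d_in, d_out, R) for d_in, d_out in LINEAR_SHAPES.values())
trainable = NUM_LAYERS * per_layer
total = 9_457_778_176  # total reported by PEFT above
print(f"Trainable: {trainable:,} | Percentage: {trainable / total * 100:.4f}%")
```

Under these assumptions the result agrees with the count printed above, and it makes clear that the adapter size scales linearly with `r`: halving the rank would halve the number of trainable parameters.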


Finally, let's finetune! Note the various flags that need to be set before training starts, e.g. `model.gradient_checkpointing_enable()` or `tokenizer.padding_side = "right"`. Working with LLMs and neural networks in general often involves dealing with many of these technical details.

In [9]:
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

tokenizer.padding_side = "right"
model.config.use_cache=False
torch.cuda.empty_cache()

collator = DataCollatorForCompletionOnlyLM(
    instruction_template="<start_of_turn>user\n",
    response_template="<start_of_turn>model\n",
    tokenizer=tokenizer,
    mlm=False,
)

trainer = SFTTrainer(
    model,
    args=SFTConfig(
        output_dir="/tmp/finetuned_gemma_2b",
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=TRAIN_GRADIENT_ACCUMULATION_STEPS,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=dict(use_reentrant=False),
        max_seq_length=TRAIN_MAX_LENGTH,
        num_train_epochs=TRAIN_NUM_EPOCHS,
        save_strategy="epoch",
        logging_steps=TRAIN_LOGGING_STEPS,
        eval_steps=TRAIN_LOGGING_STEPS,
        eval_strategy="steps",
        eval_accumulation_steps=EVAL_ACCUMULATION_STEPS,
    ),
    data_collator=collator,
    eval_dataset=test_data.select(range(min(len(test_data), EVAL_SIZE))),
    formatting_func=create_chat_for_finetuning,
    peft_config=lora_config,
    train_dataset=train_data,
    tokenizer=tokenizer,
)

train_result = trainer.train()
display(train_result._asdict())


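| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 10   | 0.2691        | 0.130881        |
| 20   | 0.1302        | 0.097202        |
| 30   | 0.0803        | 0.089198        |
| 40   | 0.0823        | 0.089355        |
| 50   | 0.0762        | 0.087239        |
| 60   | 0.0778        | 0.082603        |
| 70   | 0.1061        | 0.075994        |
| 80   | 0.0767        | 0.074501        |
| 90   | 0.0576        | 0.074349        |
| 100  | 0.084         | 0.072865        |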


{'global_step': 250,
 'training_loss': 0.07853067517280579,
 'metrics': {'train_runtime': 1239.651,
  'train_samples_per_second': 0.807,
  'train_steps_per_second': 0.202,
  'total_flos': 1.17394608823296e+16,
  'train_loss': 0.07853067517280579,
  'epoch': 1.0}}
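
The `DataCollatorForCompletionOnlyLM` used above restricts the loss to the model's answer: label positions up to and including the response template are set to an ignore index, so the prompt does not contribute to the loss. The toy function below illustrates that idea for a single-turn chat on a list of made-up token ids. It is a simplified sketch, not TRL's implementation, and `mask_prompt_labels` is a hypothetical name. The value -100 is the label index that PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # labels with this value are skipped by the loss


def mask_prompt_labels(input_ids: list[int], response_template: list[int]) -> list[int]:
    """Copy input_ids into labels, ignoring everything up to the end of the response template."""
    n = len(response_template)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_template:
            answer_start = start + n
            return [IGNORE_INDEX] * answer_start + input_ids[answer_start:]
    # Template not found: nothing to train on, so ignore the whole sequence.
    return [IGNORE_INDEX] * len(input_ids)


# Made-up ids standing in for:
# <start_of_turn> user\n  q  q  q  <end_of_turn> <start_of_turn> model\n  a  a  <end_of_turn>
input_ids = [10, 11, 5, 6, 7, 12, 10, 13, 8, 9, 12]
response_template = [10, 13]  # stands in for "<start_of_turn>model\n"

labels = mask_prompt_labels(input_ids, response_template)
print(labels)
```

Only the last three positions (the answer and its closing tag) keep their labels here, which is why the reported training loss reflects the short JSON answers rather than the long question text.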

Awesome, we finished training! Now let's reset all the flags toggled for training.

In [10]:
model.config.use_cache=True
model.gradient_checkpointing_disable()
model.eval()
tokenizer.padding_side = "left"

## Merge and evaluate the finetuned model

After training, the LoRA adapter is still separate from the base Gemma model. This makes inference a lot slower, because every adapted layer has to compute an extra low-rank path. To merge the LoRA adapter into the base weights, we will follow these steps:
- https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.merge_and_unload
- https://discuss.huggingface.co/t/help-with-merging-lora-weights-back-into-base-model/40968/3
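
Conceptually, merging folds the low-rank update into the frozen weight: for a linear layer with weight `W`, LoRA matrices `A` and `B`, and scaling `lora_alpha / r`, the merged weight is `W + (lora_alpha / r) * B @ A`. The toy example below checks, with tiny hand-picked matrices in plain Python, that the merged layer gives the same output as running the base layer and the adapter side by side. It illustrates the arithmetic only (dropout is inactive at inference time) and is not PEFT's code; the helper names are made up.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Plain-Python matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_scaled(a: list[list[float]], b: list[list[float]], scale: float) -> list[list[float]]:
    """Element-wise a + scale * b."""
    return [[x + scale * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


# Scaling from the LoraConfig above (lora_alpha=16, r=64).
# The toy matrices below use rank 1 only to stay readable.
scaling = 16 / 64

W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # frozen weight, shape (d_out=2, d_in=3)
A = [[1.0, 2.0, 0.0]]                    # LoRA A, shape (rank, d_in)
B = [[4.0], [-8.0]]                      # LoRA B, shape (d_out, rank)
x = [[1.0], [2.0], [3.0]]                # input as a column vector

# Unmerged: base path plus a separate low-rank path.
unmerged = add_scaled(matmul(W, x), matmul(B, matmul(A, x)), scaling)

# Merged: fold the update into the weight once, then use a single matmul.
W_merged = add_scaled(W, matmul(B, A), scaling)
merged = matmul(W_merged, x)

print(unmerged, merged)
```

After merging, the adapter matrices are no longer needed at inference time, so the merged model has the same architecture and cost as the original base model.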

In [11]:
from peft import PeftModel

trainer.model.save_pretrained("models/lora_adapter")

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
)

model = PeftModel.from_pretrained(base_model, "models/lora_adapter").merge_and_unload()
model.save_pretrained("models/finetuned_model", safe_serialization=True)


Now let's clean up all our prior models from GPU memory, and load the merged model.

In [12]:
import gc

try: del model
except NameError: pass
try: del trainer
except NameError: pass
try: del base_model
except NameError: pass

gc.collect()
torch.cuda.empty_cache()

model = AutoModelForCausalLM.from_pretrained(
    "models/finetuned_model",
    quantization_config=quantization_config
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.



Finally, we will run the finetuned model on our test data once more and compare the performance.

If you ran the tutorial using Gemma-2B, accuracy would have increased from ~19% to ~25%. For Gemma-2-9B, the increase is from ~58% to ~62%. So the improvement is larger for Gemma-2B (about 6 percentage points versus 4), but Gemma-2-9B remains significantly more capable overall.

In [13]:
batch_predict = create_predict(tokenizer, model, "finetuned", batch=True, generate_kwargs=GENERATE_KWARGS)
test_data = test_data.map(batch_predict, batched=True, batch_size=EVAL_BATCH_SIZE)

finetuned_accuracy = (np.asarray(test_data["finetuned_label"]) == np.asarray(test_data["true_label"])).mean()
print(f"Gemma accuracy: {round(gemma_accuracy * 100, 1)}%")
print(f"Finetuned accuracy: {round(finetuned_accuracy * 100, 1)}%")


Gemma accuracy: 58.0%
Finetuned accuracy: 62.0%
