In [None]:
!pip install -q -U transformers trl[vllm] datasets tensorflow fastai gensim wandb
# Tested with python 3.12 and !pip install transformers==4.57.3 trl[vllm]==4.4.1 torch==2.8.0+cu128

## Dataset processing

This notebook uses the [MedQA](https://arxiv.org/abs/2009.13081) dataset, a multiple-choice question dataset derived from medical licensing exams in the US, China, and Taiwan, designed to assess medical knowledge and clinical reasoning skills.

Load the data using the Hugging Face `datasets` library, then build the train and validation sets from the `train` and `dev` splits. The dev split is subsampled later (to its first 100 examples) for faster evaluation times.

**Dataset citation:** Jin, D., Pan, E., Oufattole, N., Weng, W. H., Fang, H., & Szolovits, P. (2021). What disease does this patient have? a large-scale open domain question answering dataset from medical exams. Applied Sciences, 11(14), 6421.

In [None]:
import datasets

def process_medqa(data):
    prompt_template = f"""Answer the given question. Think step by step.
    You can directly provide the answer (A single letter), without further additions. E.g. "Final Answer: (A)".
    Question: [QUESTION]
    [OPTIONS]
    """
    return data.map(lambda x: {
                        'prompt': [
                            {'role': 'system', 'content': 'SYSTEM INSTRUCTION: think silently if needed.'},
                            {'role': 'user', 'content': prompt_template.replace('[QUESTION]', x['data']['Question']).replace(
                                '[OPTIONS]', f"(A) {x['data']['Options']['A']} (B) {x['data']['Options']['B']} (C) {x['data']['Options']['C']} (D) {x['data']['Options']['D']}")}
                        ],
                        'answer': x['data']['Correct Option']
                    })

medqa_dataset = datasets.load_dataset("openlifescienceai/medqa")
train_dataset = process_medqa(medqa_dataset["train"])
val_dataset = process_medqa(medqa_dataset["dev"])

In [None]:
train_dataset['prompt'][0]

[{'content': 'SYSTEM INSTRUCTION: think silently if needed.',
  'role': 'system'},
 {'content': 'Answer the given question. Think step by step.\n  You can directly provide the answer (A single letter), without further additions. E.g. "Final Answer: (A)".\n  Question: A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?\n  (A) Ampicillin (B) Ceftriaxone (C) Doxycycline (D) Nitrofurantoin\n  ',
  'role': 'user'}]

## Post-train the model with LoRA via GRPO on MedQA

Traditional fine-tuning of large language models is resource-intensive because it requires adjusting billions of parameters. Parameter-Efficient Fine-Tuning (PEFT) addresses this by training a smaller number of parameters. A common PEFT technique is *Low-Rank Adaptation (LoRA)*, which efficiently adapts large language models by training small, low-rank matrices that are added to the original model instead of updating the full-weight matrices.
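
To make the savings concrete, the sketch below counts trainable parameters for a single linear layer, where the frozen weight of shape `d_out x d_in` is augmented with two low-rank matrices `A` (`r x d_in`) and `B` (`d_out x r`). The layer dimensions are illustrative assumptions and not MedGemma's actual shapes; `r=64` matches the LoRA rank configured later in this notebook.

```python
def lora_param_counts(d_in: int, d_out: int, r: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one linear layer.

    Full fine-tuning updates the whole d_out x d_in weight matrix.
    LoRA freezes it and trains A (r x d_in) and B (d_out x r) instead.
    """
    full = d_in * d_out
    lora = r * d_in + d_out * r
    return full, lora


# Hypothetical layer size, chosen for illustration only.
full, lora = lora_param_counts(d_in=2560, d_out=2560, r=64)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.1%}")
```

The ratio shrinks further as the layer gets wider, because the LoRA count grows linearly in the layer dimensions while the full count grows with their product.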

*GRPO (Group Relative Policy Optimization)* is a reinforcement learning (RL) algorithm that aims to improve efficiency and reduce training costs by eliminating the need for a separate value function. Instead, GRPO uses group-based advantage estimation and incorporates KL divergence into the loss function for better stability.
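
The group-based advantage can be sketched in a few lines: each completion's reward is normalised against the other completions sampled for the same prompt. This is a minimal illustration, not TRL's implementation; the helper name, the epsilon value, and the use of the population standard deviation are assumptions.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalise each reward against the completions sampled for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # assumption: population standard deviation
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four completions for one prompt; only the first answered correctly.
print(group_relative_advantages([1.0, 0.0, 0.0, 0.0]))

# Identical rewards within a group give zero advantage, so no learning signal.
print(group_relative_advantages([0.0, 0.0, 0.0, 0.0]))
```

With a binary correctness reward, a prompt only contributes to the update when the group contains both correct and incorrect completions.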

This notebook demonstrates RL training of MedGemma with LoRA, using verifiable rewards.

First, define a reward function that checks whether the model's answer letter matches the correct answer letter (i.e. 'A', 'B', 'C', or 'D').

In [None]:
import re

def extract_xml_answer(answer: str) -> str | None:
    """Extract the final answer letter from a model response, or return None if no known answer format is found."""
    if not isinstance(answer, str):
        return None
    if not answer:
        return None

    final_answers = [
        r'The final answer is\s\(([A-J])\)',
        r'The final answer is\s\**\(([A-J])\)\**',
        r'The final answer is\s\$\\boxed{([A-J])}\$',
        r'Final Answer:\(([A-J])\)',
        r'Final Answer:\s\(([A-J])\)',
        r'Final Answer:\s\(?([A-J])',
        r'Final Answer:\s*\**\(([A-J])\)\**',
        r'\**Final Answer:\**\s\(([A-J])\)',
    ]
    for final_type in final_answers:
        match = re.search(final_type, answer)
        if match:
            answer_letter = match.group(1)
            return f'{answer_letter}'

    return None

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Return 1.0 for each completion whose extracted answer letter matches the correct letter (e.g. 'A', 'B', ...), else 0.0."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # print(f"-----Question:\n{prompts[0][-1]['content']}\nAnswer:\n{answer[0]}\nResponse:\n{responses[0]}\nExtracted:\n{extracted_responses[0]}")
    # print([(r,a, r == a) for r, a in zip(extracted_responses, answer)])
    return [1.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

Next, configure training with the `GRPOConfig`.

In [None]:
import torch
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

ckpt = "google/medgemma-1.5-4b-it"
output_dir = "./tuned_medgemma4b"  # Plain string; a trailing comma here would turn it into a tuple.

training_args = GRPOConfig(
    output_dir=output_dir,
    eval_on_start=False,                     # Skip the evaluation pass that would otherwise run before the first training step.
    learning_rate=5e-6,                      # The initial learning rate for the AdamW optimizer.
    per_device_train_batch_size=3,
    gradient_accumulation_steps=4,           # Accumulate gradients for this many steps to simulate a larger batch size (per_device_train_batch_size * gradient_accumulation_steps).
    num_generations=4,                       # Number of completions sampled per prompt; GRPO scores each one relative to the rest of its group.
    max_prompt_length=512,                   # Maximum token length for input prompts.
    max_completion_length=1024,              # Maximum token length for the model's generated completions.
    max_steps=1700,
    logging_steps=20,
    save_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    report_to="tensorboard",
    use_vllm=True,                           # Use the vLLM library for significantly faster inference during generation.
    vllm_mode="colocate",                    # vLLM deployment mode; 'colocate' runs vLLM on the same GPU(s) as the trainer.
    vllm_gpu_memory_utilization=.30,         # Fraction of GPU memory that vLLM is allowed to use.
    bf16=True,                               # Enable bfloat16 mixed precision training to save memory and speed up training.
    gradient_checkpointing=True,             # Save memory by trading compute (avoids storing all intermediate activations).
    gradient_checkpointing_kwargs={
        "use_reentrant": False               # Use a more efficient implementation of gradient checkpointing.
    },
    model_init_kwargs={
        "device_map": "auto",
        "dtype": torch.bfloat16,             # Set model parameter data type to bfloat16.
        "attn_implementation": "eager"       # Gemma 3 recommends using the 'eager' attention implementation.
    },
    push_to_hub=True
)

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=64,
    target_modules="all-linear",
)
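
The batch-related settings interact, so a rough calculation helps when changing them. The sketch below assumes a single GPU, assumes that `per_device_train_batch_size` counts completions rather than prompts, and assumes that one optimizer step consumes one generation batch. These are assumptions about the trainer's behavior; check the TRL documentation for the installed version.

```python
# Values copied from the GRPOConfig above.
per_device_train_batch_size = 3
gradient_accumulation_steps = 4
num_generations = 4
max_steps = 1700

num_devices = 1  # assumption: single-GPU run

# Completions consumed by one optimizer step (under the assumptions above).
completions_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_devices

# Each prompt contributes a full group of num_generations completions,
# so the total must divide evenly into groups.
assert completions_per_step % num_generations == 0
prompts_per_step = completions_per_step // num_generations

prompts_seen = prompts_per_step * max_steps
print(f"{completions_per_step} completions/step, {prompts_per_step} prompts/step, "
      f"{prompts_seen} prompts over {max_steps} steps")
```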

Train the model.

Note that this will take a long time to run (~11 hrs total on an A100 40GB GPU).

In [None]:
trainer = GRPOTrainer(
    model=ckpt,
    reward_funcs=[correctness_reward_func],
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset.select(range(100)), # Use a very small subset for validation
    peft_config=lora_config,
)
trainer.train()
trainer.save_model(output_dir=training_args.output_dir)

In [None]:
# Change the relevant paths to store training results in Google Drive
# (assumes Drive is already mounted at /content/drive)
import os
import sys

if "google.colab" in sys.modules and not os.environ.get("VERTEX_PRODUCT"):
    ! cp -r ./tuned_medgemma4b/ /content/drive/MyDrive/trl_colab_storage/

In [None]:
# Visualize training curves
! pip install -q tensorboard
%load_ext tensorboard
%tensorboard --logdir /content/tuned_medgemma4b/ --port 6007

## Model evaluation: Effect of RL-tuning

**Important: Before you continue, you may need to restart the runtime due to the VRAM limitation on Colab kernels.**

The following cells compute and print the accuracy of the baseline and fine-tuned models on the test dataset to assess the effect of RL-tuning.

We also load and process the test split using the same logic as before. These functions are repeated below for convenience.

In [None]:
# Reinstantiate environment variables
import os
import sys

if "google.colab" in sys.modules and not os.environ.get("VERTEX_PRODUCT"):
    # Use secret if running in Google Colab
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
else:
    # Store Hugging Face data under `/content` if running in Colab Enterprise
    if os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE":
        os.environ["HF_HOME"] = "/content/hf"
    # Authenticate with Hugging Face
    from huggingface_hub import get_token
    if get_token() is None:
        from huggingface_hub import notebook_login
        notebook_login()

In [None]:
import datasets
import re

def extract_xml_answer(answer: str):
    """Extract the answer letter from an XML answer string."""
    if not isinstance(answer, str):
        return None
    if not answer:
        return None

    final_answers = [
        r'The final answer is\s\(([A-J])\)',
        r'The final answer is\s\**\(([A-J])\)\**',
        r'The final answer is\s\$\\boxed{([A-J])}\$',
        r'Final Answer:\(([A-J])\)',
        r'Final Answer:\s\(([A-J])\)',
        r'Final Answer:\s\(?([A-J])',
        r'Final Answer:\s*\**\(([A-J])\)\**',
        r'\**Final Answer:\**\s\(([A-J])\)',
    ]
    for final_type in final_answers:
        match = re.search(final_type, answer)
        if match:
            answer_letter = match.group(1)
            return f'{answer_letter}'

    return None


def process_medqa(data):
    prompt_template = f"""Answer the given question. Think step by step.
    You can directly provide the answer (A single letter), without further additions. E.g. "Final Answer: (A)".
    Question: [QUESTION]
    [OPTIONS]
    """
    return data.map(lambda x: {
                        'prompt': [
                            {'role': 'system', 'content': 'SYSTEM INSTRUCTION: think silently if needed.'},
                            {'role': 'user', 'content': prompt_template.replace('[QUESTION]', x['data']['Question']).replace(
                                '[OPTIONS]', f"(A) {x['data']['Options']['A']} (B) {x['data']['Options']['B']} (C) {x['data']['Options']['C']} (D) {x['data']['Options']['D']}")}
                        ],
                        'answer': x['data']['Correct Option']
                    })

medqa_dataset = datasets.load_dataset("openlifescienceai/medqa")
test_dataset = process_medqa(medqa_dataset["test"])

Define a method to run batch inference on the test dataset.

In [None]:
import torch
import pandas as pd
from tqdm.auto import tqdm

def run_inference_batched(test_dataset, model, processor, batch_size=4, device="cuda", verbose=True):
    """
    Runs inference on a processed test dataset using batching for efficiency.

    Args:
        test_dataset: A dataset where each item has 'prompt' (chat history) and 'answer' (ground truth).
        model: The loaded model (baseline or PEFT-tuned) for inference.
        processor: The processor for tokenizing the input.
        batch_size (int): The number of samples to process at once. Adjust based on VRAM.
        device (str): The device to run inference on ('cuda' or 'cpu').
        verbose (bool): Whether to print progress and sample outputs.

    Returns:
        A list of dictionaries, with each dictionary containing the prompt,
        ground truth answer, and the model's generated answer.
    """
    results = []

    # Create an iterator for the batches
    num_samples = len(test_dataset)

    # Use tqdm for a progress bar if verbose is True
    batch_iterator = range(0, num_samples, batch_size)
    if verbose:
        print(f"Starting batched inference on {num_samples} samples with batch size {batch_size}...")
        batch_iterator = tqdm(batch_iterator, desc="Batch Inference")

    for i in batch_iterator:
        # 1. Prepare the current batch
        batch_data = test_dataset[i : i + batch_size]
        batch_prompts = batch_data['prompt']
        batch_ground_truths = batch_data['answer']

        # 2. Tokenize the entire batch at once with left-padding
        inputs = processor.tokenizer.apply_chat_template(
            batch_prompts,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            truncation=True,
            padding=True,  # Pad sequences to the length of the longest in the batch
            max_length=1024,  # Set a fixed maximum length
        ).to(device)

        # 3. Generate responses for the entire batch in one go
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,  # Greedy decoding for deterministic output, appropriate for logical reasoning
        )

        # 4. Decode the generated part of the output
        # This is more robust than decoding the whole sequence and stripping the prompt
        input_token_length = inputs['input_ids'].shape[1]
        generated_tokens = outputs[:, input_token_length:]
        model_generated_answers = processor.batch_decode(generated_tokens, skip_special_tokens=True)

        # 5. Store the results for the current batch
        for j in range(len(batch_prompts)):
            results.append({
                'prompt': batch_prompts[j],
                'ground_truth': batch_ground_truths[j],
                'model_answer': model_generated_answers[j].strip() # Use .strip() for clean output
            })

    # Optional: print a few examples from the final results
    if verbose:
        print("\n--- Sample of Batched Inference Results ---")
        for res in results[:3]: # Print first 3 results
            print(f"Ground Truth: {res['ground_truth']}")
            print(f"Model Answer: {res['model_answer']}\n")

    return results
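
Step 4 of the function relies on left-padding: every prompt in the batch then ends at the same column, so one slice at the shared prompt length isolates the new tokens for all rows. The toy sketch below illustrates this with plain lists. The token IDs, the pad ID, and the `left_pad` helper are hypothetical, and it assumes the tokenizer pads on the left, as the comment in step 2 states.

```python
PAD = 0  # hypothetical pad token ID


def left_pad(seqs: list[list[int]], pad: int = PAD) -> list[list[int]]:
    """Pad every sequence on the left to the length of the longest one."""
    width = max(len(s) for s in seqs)
    return [[pad] * (width - len(s)) + s for s in seqs]


prompts = [[11, 12, 13], [21]]            # two prompts of different lengths
padded = left_pad(prompts)                 # [[11, 12, 13], [0, 0, 21]]

# Pretend the model appended two new tokens to each row.
new_tokens = [[101, 102], [201, 202]]
outputs = [row + new for row, new in zip(padded, new_tokens)]

# Same idea as outputs[:, input_token_length:] on a tensor.
prompt_len = len(padded[0])
generated = [row[prompt_len:] for row in outputs]
print(generated)
```

With right-padding, the pad tokens would sit between the prompt and the generated text, and the same slice would no longer line up with where each prompt actually ends.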

### Evaluate baseline performance

This cell calculates the baseline model's accuracy on the test data. This baseline serves as a benchmark to measure the fine-tuned model's performance improvement.

In [None]:
from transformers import Gemma3Processor, AutoModelForCausalLM

# Load model and processor
print("Loading model and processor...")
ckpt = "google/medgemma-1.5-4b-it"
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = Gemma3Processor.from_pretrained(ckpt)

# Run inference on the test dataset
inference_results = run_inference_batched(
    test_dataset=test_dataset,
    model=model,
    processor=processor,
    batch_size=32,
)
results_df = pd.DataFrame(inference_results)
results_df['model_pred'] = results_df['model_answer'].apply(extract_xml_answer)
results_df['correct'] = results_df['ground_truth'] == results_df['model_pred']
print('Baseline Accuracy', results_df['correct'].mean())
results_df.to_csv('baseline_test_results.csv') # Save baseline results

del model # To free up VRAM
del processor
torch.cuda.empty_cache()

Baseline Accuracy 0.14072327044025157


As expected, we observe a low accuracy of 14% with only 1k output tokens. Let us look at an example output.

In [None]:
print(results_df['model_answer'].values[0])

<unused94>thought
The user wants me to identify the antibiotic that blocks cell wall synthesis based on the clinical presentation and lab findings.

1.  **Analyze the clinical presentation:** A young, sexually active male presents with symptoms suggestive of urethritis (fever, dysuria) and septic arthritis (knee pain, inflammation). This points towards a potential sexually transmitted infection (STI) that can cause disseminated infection, like *Neisseria gonorrhoeae*.
2.  **Analyze the lab findings:** The joint fluid culture shows a bacterium that does not ferment maltose and has no polysaccharide capsule.
    *   *Neisseria gonorrhoeae* is the most common cause of septic arthritis in young, sexually active adults.
    *   *Neisseria gonorrhoeae* typically ferments glucose but *not* maltose.
    *   *Neisseria gonorrhoeae* typically has a polysaccharide capsule, although some strains may lack it (non-encapsulated strains).
    *   However, the description "does not ferment maltose and 

For this case (and many others), the 1k output limit is too low, and the model's response gets cut off before it is able to respond with the final answer.
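
To separate truncated or unparseable responses from genuinely wrong answers, it can help to report how often no answer letter was extracted at all. The sketch below uses a hypothetical helper, `summarize_predictions`, on toy data; the numbers are not results from this notebook.

```python
def summarize_predictions(preds: list, truths: list) -> dict:
    """Split overall accuracy into 'no answer extracted' and 'answered' parts."""
    n = len(preds)
    # Keep only the rows where an answer letter was extracted.
    parsed = [(p, t) for p, t in zip(preds, truths) if p is not None]
    correct = sum(p == t for p, t in parsed)
    return {
        'accuracy': correct / n,
        'unparsed_rate': (n - len(parsed)) / n,
        'accuracy_when_parsed': correct / len(parsed) if parsed else float('nan'),
    }


# Toy data: None marks a response with no extractable final answer.
stats = summarize_predictions(['A', None, 'C', None], ['A', 'B', 'D', 'D'])
print(stats)
```

With the DataFrame built above, the inputs would be `results_df['model_pred'].tolist()` and `results_df['ground_truth'].tolist()`.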

### Evaluate tuned model performance

This cell calculates the fine-tuned model's accuracy on the test data. Comparing this with the baseline score shows the improvement from fine-tuning.

In [None]:
from peft import AutoPeftModelForCausalLM
from transformers import Gemma3Processor
import torch
import pandas as pd

# Define model and processor information (make sure the paths are right!)
model_path = "/content/tuned_medgemma4b/checkpoint-1700"

# Load Model and Processor
print("Loading model and processor...")
model = AutoPeftModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = Gemma3Processor.from_pretrained("google/medgemma-1.5-4b-it")

# Run inference on the entire processed dataset
inference_results = run_inference_batched(
    test_dataset=test_dataset,
    model=model,
    processor=processor,
    batch_size=64,
)
results_df = pd.DataFrame(inference_results)
results_df['model_pred'] = results_df['model_answer'].apply(extract_xml_answer)
results_df['correct'] = results_df['ground_truth'] == results_df['model_pred']
print('GRPO-tuned Accuracy', results_df['correct'].mean())
results_df.to_csv('trained_results.csv') # Save trained_results

GRPO-tuned Accuracy 0.7051886792452831


Reproduced MedQA **1k output** test accuracy after 1700 steps:

| Model | Pre-RL Tuning | Post-RL Tuning |
| ----- | ------------- | -------------- |
| medgemma-1.5-4b-it | 0.141 | 0.705 |

Observations:
- Accuracy gain: RL-tuning with GRPO improved the model's test accuracy from a baseline of 14.1% to 70.5%.
- Fixing the token limit issue: The baseline model often failed because it generated long reasoning chains that were cut off by the 1,024-token output limit before it could state the final answer. The fine-tuned model learned to be more concise and provide the answer within the limit.
- High efficiency: Reaching 70.5% accuracy on MedQA with a relatively small 4B-parameter model suggests that LoRA combined with GRPO can be effective for specialized domain adaptation.



### Additional optimizations
Note that this notebook is meant to be a starting point. There are numerous optimizations that are not covered in this Colab, including [DeepSpeed](https://huggingface.co/docs/trl/main/en/deepspeed_integration), [parallelization on multiple nodes](https://huggingface.co/docs/trl/main/en/grpo_trainer#grpo-at-scale-train-a-70b-model-on-multiple-nodes), and more.

We recommend checking out [GRPO Trainer](https://huggingface.co/docs/trl/main/en/grpo_trainer) for further details.

## Next steps

Explore the other [notebooks](https://github.com/google-health/medgemma/blob/main/notebooks) to learn what else you can do with the model.