~~~
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
~~~

# Fine-tune MedGemma with Hugging Face

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/google-health/medgemma/blob/main/notebooks/fine_tune_with_hugging_face.ipynb">
      <img alt="Google Colab logo" src="https://www.tensorflow.org/images/colab_logo_32px.png" width="32px"><br> Run in Google Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2Fgoogle-health%2Fmedgemma%2Fmain%2Fnotebooks%2Ffine_tune_with_hugging_face.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/google-health/medgemma/blob/main/notebooks/fine_tune_with_hugging_face.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://huggingface.co/collections/google/medgemma-release-680aade845f90bec6a3f60c4">
      <img alt="Hugging Face logo" src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="32px"><br> View on Hugging Face
    </a>
  </td>
</tr></tbody></table>

This notebook demonstrates fine-tuning MedGemma on an image and text dataset for a vision task using Hugging Face libraries.

In this guide, you will use Hugging Face's [Transformer Reinforcement Learning (`TRL`)](https://github.com/huggingface/trl) library to train the model with Supervised Fine-Tuning (SFT), utilizing [Quantized Low-Rank Adaptation (QLoRA)](https://arxiv.org/abs/2305.14314) to reduce computational costs while maintaining high performance.


## Setup

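To complete this tutorial, you'll need a runtime with sufficient resources to fine-tune the MedGemma model. **Note:** The default configuration is intended for a GPU that supports the bfloat16 data type and has at least 40 GB of memory. The model-loading cell uses float16 so that the model also loads on GPUs without bfloat16 support, but on a 16 GB T4 GPU training runs out of memory with the default settings, as the output of the training cell shows.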

You can run this notebook in Google Colab using an A100 GPU:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **A100 GPU**.

### Get access to MedGemma

Before you get started, make sure that you have access to MedGemma models on Hugging Face:

1. If you don't already have a Hugging Face account, you can [create one for free](https://huggingface.co/join).
2. Head over to the [MedGemma model page](https://huggingface.co/google/medgemma-4b-it) and accept the usage conditions.

### Configure your HF token

Generate a Hugging Face `write` access token by going to [settings](https://huggingface.co/settings/tokens). **Note:** Make sure that the token has write access to push the fine-tuned model to Hugging Face Hub.

If you are using Google Colab, add your access token to the Colab Secrets manager to securely store it. If not, proceed to run the cell below to authenticate with Hugging Face.

1. Open your Google Colab notebook and click the 🔑 **Secrets** tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
2. Create a new secret with the name `HF_TOKEN`.
3. Copy/paste your token key into the Value input box of `HF_TOKEN`.
4. Toggle the button on the left to allow notebook access to the secret.

In [1]:
import os
import sys

if "google.colab" in sys.modules and not os.environ.get("VERTEX_PRODUCT"):
    # Use secret if running in Google Colab
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
else:
    # Store Hugging Face data under `/content` if running in Colab Enterprise
    if os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE":
        os.environ["HF_HOME"] = "/content/hf"
    # Authenticate with Hugging Face
    from huggingface_hub import get_token
    if get_token() is None:
        from huggingface_hub import notebook_login
        notebook_login()

### Install dependencies

In [2]:
! pip install --upgrade --quiet bitsandbytes datasets evaluate peft tensorboard transformers trl

## Prepare fine-tuning dataset

This notebook fine-tunes MedGemma to classify dermoscopic images of skin lesions from the ISIC 2019 Challenge training set. Each image is labeled with one of the nine diagnostic categories in the ground-truth CSV (`MEL`, `NV`, `BCC`, `AK`, `BKL`, `DF`, `VASC`, `SCC`, `UNK`). The download cell also keeps, commented out, the commands for the [NCT-CRC-HE-100K](https://zenodo.org/records/1214456) dataset, which is licensed under [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/legalcode) and contains image patches from histological images of human colorectal cancer (CRC) and normal tissue.

**Note:** By default, `train_size` and `validation_size` split all 25,331 labeled ISIC 2019 training images into 22,265 training and 3,066 validation examples. Lower these values if you want to train on a smaller subset.

**Dataset citation:** Kather, J. N., Halama, N., & Marx, A. (2018). 100,000 histological images of human colorectal cancer and healthy tissue (v0.1) [Data set]. Zenodo. https://doi.org/10.5281/zenodo.1214456

Download the ISIC 2019 training images (roughly 9 GB). This step may take 10 to 15 minutes to complete.

In [3]:
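# NCT-CRC-HE-100K download commands (not used by this notebook):
# ! wget -nc -q "https://zenodo.org/records/1214456/files/NCT-CRC-HE-100K.zip"
# ! unzip -q NCT-CRC-HE-100K.zip
# Download the ISIC 2019 training images
! curl -O "https://isic-archive.s3.amazonaws.com/challenges/2019/ISIC_2019_Training_Input.zip"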


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 9318M  100 9318M    0     0  14.9M      0  0:10:24  0:10:24 --:--:-- 15.2M

In [4]:
! curl -O "https://isic-archive.s3.amazonaws.com/challenges/2019/ISIC_2019_Training_Metadata.csv"
! curl -O "https://isic-archive.s3.amazonaws.com/challenges/2019/ISIC_2019_Training_GroundTruth.csv"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1185k  100 1185k    0     0   615k      0  0:00:01  0:00:01 --:--:--  615k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1261k  100 1261k    0     0   663k      0  0:00:01  0:00:01 --:--:--  663k


In [5]:
! rm -rf ISIC_2019_Training_Input
! unzip -q ISIC_2019_Training_Input.zip

In [6]:
train_size = 22265  # @param {type: "number"}
validation_size = 3066  # @param {type: "number"}

# 22,265 + 3,066 = 25,331, the number of labeled images in the ISIC 2019 training set.
# The images are loaded and split into these two sets in the cells below.

In [7]:
import pandas as pd

# Load the ground truth CSV to inspect its structure
ground_truth_df = pd.read_csv("ISIC_2019_Training_GroundTruth.csv")
display(ground_truth_df.head())

# Get the label names from the columns, excluding the 'image' column
ISIC_TISSUE_CLASSES_RAW = ground_truth_df.columns.drop("image").tolist()
print(f"Detected ISIC Tissue Classes: {ISIC_TISSUE_CLASSES_RAW}")

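|   | image | MEL | NV | BCC | AK | BKL | DF | VASC | SCC | UNK |
|---|-------|-----|----|-----|----|-----|----|------|-----|-----|
| 0 | ISIC_0000000 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | ISIC_0000001 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | ISIC_0000002 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | ISIC_0000003 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | ISIC_0000004 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |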


Detected ISIC Tissue Classes: ['MEL', 'NV', 'BCC', 'AK', 'BKL', 'DF', 'VASC', 'SCC', 'UNK']


With the ground-truth columns identified, define `TISSUE_CLASSES` from them, then load the images and attach each image's label from the CSV.
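The label lookup is easiest to see on a few rows. Below is a minimal, dependency-free sketch (using Python's `csv` module rather than pandas) that turns the one-hot ground-truth rows shown above into integer labels and lettered options. The sample string and the helper name `one_hot_csv_to_labels` are illustrative and not part of the notebook.

```python
import csv
import io

# A few rows copied from the ISIC 2019 ground-truth CSV shown above (one-hot labels)
GROUND_TRUTH_SAMPLE = """image,MEL,NV,BCC,AK,BKL,DF,VASC,SCC,UNK
ISIC_0000000,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ISIC_0000001,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ISIC_0000002,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""


def one_hot_csv_to_labels(csv_text: str) -> tuple[list[str], dict[str, int]]:
    """Return the class names and a map from image ID to integer label."""
    reader = csv.DictReader(io.StringIO(csv_text))
    class_names = [name for name in reader.fieldnames if name != "image"]
    labels = {}
    for row in reader:
        hot = [i for i, name in enumerate(class_names) if float(row[name]) == 1.0]
        if len(hot) == 1:  # skip rows that are not valid one-hot vectors
            labels[row["image"]] = hot[0]
    return class_names, labels


class_names, image_to_label = one_hot_csv_to_labels(GROUND_TRUTH_SAMPLE)
# Lettered answer options, built the same way as TISSUE_CLASSES in the next cell
tissue_classes = [f"{chr(65 + i)}: {name}" for i, name in enumerate(class_names)]
print(image_to_label)
print(tissue_classes[image_to_label["ISIC_0000000"]])
```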

In [8]:
from datasets import Image, load_dataset, Dataset
import os

# Define the new TISSUE_CLASSES based on the ground truth CSV
TISSUE_CLASSES = [f"{chr(65+i)}: {cls_name}" for i, cls_name in enumerate(ISIC_TISSUE_CLASSES_RAW)]

options = "\n".join(TISSUE_CLASSES)


# Map each image ID to the index of the single 1.0 in its one-hot label row.
# Building the dictionary once avoids scanning the whole DataFrame for every image.
one_hot_labels = ground_truth_df.set_index("image")[ISIC_TISSUE_CLASSES_RAW]
IMAGE_ID_TO_LABEL = {
    image_id: int(row.argmax())
    for image_id, row in zip(one_hot_labels.index, one_hot_labels.to_numpy())
    if row.max() == 1.0
}


def get_label_from_ground_truth(image_id: str) -> int:
    # Returns -1 for images that have no label in the ground-truth CSV
    return IMAGE_ID_TO_LABEL.get(image_id, -1)

# Load image paths and create a list of dictionaries with 'image_path' and 'label'
image_dir = "ISIC_2019_Training_Input"
image_data_list = []
# Sort the file names so that the seeded split below is reproducible
for filename in sorted(os.listdir(image_dir)):
    if filename.endswith((".jpg", ".jpeg", ".png")):
        image_id = os.path.splitext(filename)[0]
        label = get_label_from_ground_truth(image_id)
        if label != -1:
            image_data_list.append({"image_path": os.path.join(image_dir, filename), "label": label})

# Create a Dataset from the list and cast the image column
dataset_with_paths_and_labels = Dataset.from_list(image_data_list).cast_column("image_path", Image())

# Rename 'image_path' to 'image' to match original notebook structure
dataset_with_paths_and_labels = dataset_with_paths_and_labels.rename_column("image_path", "image")

# Create train and validation splits
data = dataset_with_paths_and_labels.train_test_split(
    train_size=train_size,
    test_size=validation_size,
    shuffle=True,
    seed=42,
)
# Use the test split as the validation set
data["validation"] = data.pop("test")

# Display dataset details
data

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 22265
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 3066
    })
})

In [9]:
print(TISSUE_CLASSES)
print(data['train'][0])

['A: MEL', 'B: NV', 'C: BCC', 'D: AK', 'E: BKL', 'F: DF', 'G: VASC', 'H: SCC', 'I: UNK']
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=600x450 at 0x7925DD86E600>, 'label': 1}


In [10]:
from typing import Any

# TISSUE_CLASSES and options are defined in the previous cell from the ISIC ground truth

PROMPT = f"What is the most likely condition in the given dermascopic image?\n{options}"
print("using prompt:", PROMPT,"options",TISSUE_CLASSES)


def format_data(example: dict[str, Any]) -> dict[str, Any]:
    example["messages"] = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {
                    "type": "text",
                    "text": PROMPT,
                },
            ],
        },
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": TISSUE_CLASSES[example["label"]],
                },
            ],
        },
    ]
    return example

using prompt: What is the most likely condition in the given dermascopic image?
A: MEL
B: NV
C: BCC
D: AK
E: BKL
F: DF
G: VASC
H: SCC
I: UNK options ['A: MEL', 'B: NV', 'C: BCC', 'D: AK', 'E: BKL', 'F: DF', 'G: VASC', 'H: SCC', 'I: UNK']


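For this classification task, `format_data` builds a multiple-choice question prompt and converts each example into a multimodal conversational format: a user turn containing the image and the prompt, followed by an assistant turn containing the correct lettered option.

Apply the processing function to the dataset.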

In [11]:
data = data.map(format_data)

# Display a processed data sample
data["train"][0]


{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=600x450>,
 'label': 1,
 'messages': [{'content': [{'text': None, 'type': 'image'},
    {'text': 'What is the most likely condition in the given dermascopic image?\nA: MEL\nB: NV\nC: BCC\nD: AK\nE: BKL\nF: DF\nG: VASC\nH: SCC\nI: UNK',
     'type': 'text'}],
   'role': 'user'},
  {'content': [{'text': 'B: NV', 'type': 'text'}], 'role': 'assistant'}]}

## Fine-tune the model with LoRA

Traditional fine-tuning of large language models is resource-intensive because it requires adjusting billions of parameters. Parameter-Efficient Fine-Tuning (PEFT) addresses this by training only a small number of parameters. A common PEFT technique is Low-Rank Adaptation (LoRA), which adapts a large model by training small, low-rank matrices whose product is added to the original weights instead of updating the full weight matrices. In QLoRA, the base model is quantized to 4-bit and its weights are frozen; LoRA adapter layers are then attached and trained.
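As a rough illustration of the LoRA idea, here is a minimal NumPy sketch of one adapted linear layer, assuming the formulation from the LoRA paper: the frozen weight `W` is combined with a trainable update `B @ A` scaled by `lora_alpha / r`, and `B` starts at zero. The layer sizes are hypothetical, and PEFT's actual implementation differs in details such as how `A` is initialized.

```python
import numpy as np


def lora_forward(x, W, A, B, alpha, r):
    """y = W x + (alpha / r) * B A x, with W frozen and only A and B trained."""
    return W @ x + (alpha / r) * (B @ (A @ x))


def lora_param_counts(d_out, d_in, r):
    """Parameters in the full weight matrix vs. in the factors A (r x d_in) and B (d_out x r)."""
    return d_out * d_in, r * (d_in + d_out)


rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 8, 6, 2, 4
W = rng.normal(size=(d_out, d_in))          # frozen pretrained weight
A = rng.normal(scale=0.01, size=(r, d_in))  # trainable down-projection
B = np.zeros((d_out, r))                    # trainable up-projection, starts at zero
x = rng.normal(size=d_in)

# With B = 0 the adapter is a no-op, so training starts from the base model's behavior
print(np.allclose(lora_forward(x, W, A, B, alpha, r), W @ x))

# Hypothetical 2560 x 2560 projection with r = 64 (the rank used later in this notebook)
full, lora = lora_param_counts(2560, 2560, 64)
print(full, lora, lora / full)
```

For the hypothetical 2560 × 2560 layer, the rank-64 adapter trains 327,680 parameters, about 5% of the 6,553,600 in the full matrix.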

This notebook demonstrates supervised fine-tuning MedGemma with QLoRA using the `SFTTrainer` from the Hugging Face `TRL` library.

### Load model from Hugging Face Hub

Initialize the quantization configuration and load the model.
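To get a feel for why 4-bit loading helps, the following back-of-the-envelope sketch estimates weight memory using the figures from the QLoRA paper: blocks of 64 weights with one 32-bit scale each, or, with double quantization, 8-bit scales plus one 32-bit scale per 256 blocks. The 4-billion parameter count is a round, hypothetical figure, and the estimate ignores layers kept in 16-bit, activations, LoRA weights, and optimizer state, so real usage is higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Memory for the weights alone, in decimal gigabytes."""
    return num_params * bits_per_param / 8 / 1e9


BLOCK_SIZE = 64  # weights per quantization block, as in the QLoRA paper

# One 32-bit scale per block of 64 weights: 0.5 extra bits per parameter
single_quant_overhead = 32 / BLOCK_SIZE
# Double quantization: 8-bit scales plus one 32-bit scale per 256 blocks (~0.127 bits)
double_quant_overhead = 8 / BLOCK_SIZE + 32 / (BLOCK_SIZE * 256)

num_params = 4e9  # hypothetical round figure for a "4B" model
print(f"16-bit weights:       {weight_memory_gb(num_params, 16):.2f} GB")
print(f"NF4:                  {weight_memory_gb(num_params, 4 + single_quant_overhead):.2f} GB")
print(f"NF4 + double quant:   {weight_memory_gb(num_params, 4 + double_quant_overhead):.2f} GB")
```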

In [20]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

model_id = "unsloth/medgemma-1.5-4b-it"

# float16 is used here because T4 GPUs do not support bfloat16.
# On GPUs with compute capability >= 8.0 (e.g. A100, H100), you can use torch.bfloat16;
# if you do, also replace `fp16=True` with `bf16=True` in the SFTConfig below.
model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.float16,
    device_map="auto",
)

model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,                       # Store the frozen base weights in 4-bit
    bnb_4bit_use_double_quant=True,          # Also quantize the quantization constants
    bnb_4bit_quant_type="nf4",               # 4-bit NormalFloat, introduced in the QLoRA paper
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(model_id)

# Use right padding to avoid issues during training
processor.tokenizer.padding_side = "right"
# Release cached GPU memory left over from loading the model
torch.cuda.empty_cache()


### Set up for fine-tuning

Create a [`LoraConfig`](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraConfig). It will be provided to the `SFTTrainer`, which supports built-in integration with the Hugging Face `PEFT` library.

In [13]:
from peft import LoraConfig

peft_config = LoraConfig(
    r=64,                         # Rank of the low-rank update matrices
    lora_alpha=64,                # Scaling factor; updates are scaled by lora_alpha / r
    lora_dropout=0.1,             # Dropout probability for the LoRA layers
    bias="none",                  # Keep all bias terms frozen
    target_modules="all-linear",  # Add adapters to every linear layer except the output layer
    task_type="CAUSAL_LM",
)

Define a custom data collator that processes examples containing text and images and returns batches of data in the expected model input format.
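The masking step in the collator can be illustrated on toy data. This NumPy sketch copies `input_ids` and replaces padding and image tokens with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss. The token IDs are made up for the example; the real collator reads them from the processor.

```python
import numpy as np

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss
# Illustrative token IDs only, not Gemma's real vocabulary IDs
PAD_ID, BOI_ID, IMAGE_SOFT_ID = 0, 7, 9


def build_labels(input_ids: np.ndarray) -> np.ndarray:
    """Copy input_ids and hide padding and image tokens from the loss."""
    labels = input_ids.copy()
    labels[np.isin(labels, [PAD_ID, BOI_ID, IMAGE_SOFT_ID])] = IGNORE_INDEX
    return labels


# Two right-padded sequences: a start token, image tokens, then text tokens
batch = np.array([
    [5, 7, 9, 9, 11, 12],
    [5, 7, 9, 9, 13, 0],
])
print(build_labels(batch))
```

As in the collator, only padding and image tokens are masked here, so the prompt text tokens still contribute to the loss alongside the answer.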

In [14]:
from typing import Any


def collate_fn(examples: list[dict[str, Any]]):
    texts = []
    images = []
    for example in examples:
        images.append([example["image"].convert("RGB")])
        texts.append(processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        ).strip())

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, with the padding and image tokens masked in
    # the loss computation
    labels = batch["input_ids"].clone()

    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens that are not used in the loss computation: padding, the
    # begin-of-image token, and the image soft tokens (ID 262144 in Gemma 3)
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100

    batch["labels"] = labels
    return batch

Configure training parameters in an [`SFTConfig`](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig).

In [15]:
from trl import SFTConfig

num_train_epochs = 1  # @param {type: "number"}
learning_rate = 2e-4  # @param {type: "number"}

args = SFTConfig(
    output_dir="medgemma-4b-it-sft-lora-crc100k",            # Directory and Hub repository id to save the model to
    num_train_epochs=num_train_epochs,                       # Number of training epochs
    per_device_train_batch_size=2,                           # Batch size per device during training
    per_device_eval_batch_size=2,                            # Batch size per device during evaluation
    gradient_accumulation_steps=16,                          # Micro-batches to accumulate before each optimizer update
    gradient_checkpointing=True,                             # Enable gradient checkpointing to reduce memory usage
    optim="adamw_torch_fused",                               # Use fused AdamW optimizer for better performance
    logging_steps=50,                                        # Number of steps between logs
    save_strategy="epoch",                                   # Save checkpoint every epoch
    eval_strategy="steps",                                   # Evaluate every `eval_steps`
    eval_steps=50,                                           # Number of steps between evaluations
    learning_rate=learning_rate,                             # Learning rate based on QLoRA paper
    fp16=True,                                               # float16 mixed precision, matching the model dtype (use bf16=True with bfloat16)
    max_grad_norm=0.3,                                       # Max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                                       # Warmup ratio based on QLoRA paper
    lr_scheduler_type="linear",                              # Use linear learning rate scheduler
    push_to_hub=True,                                        # Push model to Hub
    report_to="tensorboard",                                 # Report metrics to tensorboard
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Set gradient checkpointing to non-reentrant to avoid issues
    dataset_kwargs={"skip_prepare_dataset": True},           # Skip default dataset preparation to preprocess manually
    remove_unused_columns=False,                             # Keep the image and messages columns for the data collator
    label_names=["labels"],                                  # Input keys that correspond to the labels
)


warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
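The warning above suggests `warmup_steps` in place of `warmup_ratio`. Here is a minimal sketch of the conversion, assuming one optimizer step per `per_device_train_batch_size × gradient_accumulation_steps × devices` examples and rounding up. The Trainer's own step count may differ by one depending on how it handles the last partial batch, so treat the result as an estimate; `training_schedule` is a hypothetical helper.

```python
import math


def training_schedule(num_examples: int, per_device_batch_size: int,
                      gradient_accumulation_steps: int, num_epochs: int,
                      warmup_ratio: float, num_devices: int = 1) -> dict[str, int]:
    """Estimated optimizer-step counts, rounding partial batches up."""
    effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch_size)
    total_steps = steps_per_epoch * num_epochs
    return {
        "effective_batch_size": effective_batch_size,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
    }


# Values from the SFTConfig above and the default train_size
print(training_schedule(22265, 2, 16, 1, 0.03))
# {'effective_batch_size': 32, 'total_steps': 696, 'warmup_steps': 21}
```

With these numbers, passing `warmup_steps=21` instead of `warmup_ratio=0.03` should give approximately the same schedule.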


### Fine-tune the model

Construct an [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer) using the previously defined LoRA configuration, custom data collator, and training parameters.

In [21]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=data["train"],
    eval_dataset=data["validation"].shuffle(seed=42).select(range(200)),  # Use a fixed subset of the validation set for faster runs
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)
torch.cuda.empty_cache()  # Release cached GPU memory before training

In [19]:
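import os

# The CUDA caching allocator reads this setting when PyTorch first initializes CUDA.
# Because the model is already on the GPU at this point, it only takes effect if you set it
# at the top of the notebook (or restart the runtime and set it before loading the model).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
print("PYTORCH_CUDA_ALLOC_CONF set to expandable_segments:True")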

PYTORCH_CUDA_ALLOC_CONF set to expandable_segments:True


Launch the fine-tuning process.

**Note**: This may take around 3 hours to run using the default configuration.

In [22]:
trainer.train()

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 14.74 GiB of which 666.12 MiB is free. Process 148989 has 14.09 GiB memory in use. Of the allocated memory 11.90 GiB is allocated by PyTorch, and 2.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

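Save the final model. Because `push_to_hub=True` in the `SFTConfig`, `trainer.save_model()` writes the fine-tuned LoRA adapter to `output_dir` and also pushes it to the Hugging Face Hub.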

In [None]:
trainer.save_model()

Free up memory before proceeding to evaluate and test the fine-tuned model.

In [None]:
del model
del trainer
torch.cuda.empty_cache()

## Evaluate the fine-tuned model

### Prepare test dataset

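The [CRC-VAL-HE-7K](https://zenodo.org/records/1214456) dataset contains image patches from patients with colorectal adenocarcinoma and does not overlap with NCT-CRC-HE-100K. The evaluation cells below still use it as the test dataset, but its nine tissue classes are unrelated to the ISIC 2019 diagnostic categories in `TISSUE_CLASSES`, and `PROMPT` asks about a dermoscopic image. The resulting metrics therefore do not measure skin-lesion classification; for a meaningful evaluation, replace `test_data` with ISIC 2019 images that were used for neither training nor validation.

**Note:** The full CRC-VAL-HE-7K dataset contains over 7K samples. By default this guide only uses a subset with 1,000 samples to keep the evaluation example small.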

Download and prepare the test dataset.

In [None]:
! wget -nc -q "https://zenodo.org/records/1214456/files/CRC-VAL-HE-7K.zip"
! unzip -q CRC-VAL-HE-7K.zip

In [None]:
from typing import Any

from datasets import load_dataset


def format_test_data(example: dict[str, Any]) -> dict[str, Any]:
    example["messages"] = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {
                    "type": "text",
                    "text": PROMPT,
                },
            ],
        },
    ]
    return example


test_data = load_dataset("./CRC-VAL-HE-7K", split="train")
test_data = test_data.shuffle(seed=42).select(range(1000))
test_data = test_data.map(format_test_data)

### Set up for evaluation

Load the accuracy and F1 score metrics to evaluate the model's performance on the classification task.

In [None]:
import evaluate

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

# Ground-truth labels
REFERENCES = test_data["label"]


def compute_metrics(predictions: list[int]) -> dict[str, float]:
    metrics = {}
    metrics.update(accuracy_metric.compute(
        predictions=predictions,
        references=REFERENCES,
    ))
    metrics.update(f1_metric.compute(
        predictions=predictions,
        references=REFERENCES,
        average="weighted",
    ))
    return metrics
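The `evaluate` F1 metric wraps scikit-learn's `f1_score`, so `average="weighted"` averages the per-class F1 scores, weighting each class by its number of true examples. The pure-Python sketch below (helper names are illustrative) reproduces both metrics on a toy example and shows that an unparseable response, encoded as `-1`, simply counts as wrong.

```python
from collections import Counter


def accuracy(predictions: list[int], references: list[int]) -> float:
    return sum(p == r for p, r in zip(predictions, references)) / len(references)


def weighted_f1(predictions: list[int], references: list[int]) -> float:
    """Per-class F1, averaged with weights equal to each class's number of true examples."""
    support = Counter(references)
    total = 0.0
    for label, count in support.items():
        tp = sum(p == label and r == label for p, r in zip(predictions, references))
        fp = sum(p == label and r != label for p, r in zip(predictions, references))
        fn = count - tp
        total += count * (2 * tp / (2 * tp + fp + fn))
    return total / len(references)


references = [0, 0, 1, 1]
predictions = [0, 1, 1, -1]  # -1 = response that matched no class
print(accuracy(predictions, references))     # 0.5
print(weighted_f1(predictions, references))  # about 0.583 (7/12)
```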

Define a postprocessing function to convert responses to integer class labels before computing metrics.

In [None]:
from datasets import ClassLabel

# Rename the class names to the tissue classes, `X: tissue type`
test_data = test_data.cast_column(
    "label",
    ClassLabel(names=TISSUE_CLASSES)
)

LABEL_FEATURE = test_data.features["label"]
# Mapping to alternative label format, `(X) tissue type`
ALT_LABELS = dict([
    (label, f"({label.replace(': ', ') ')}") for label in TISSUE_CLASSES
])


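def postprocess(prediction: list[dict[str, str]], do_full_match: bool = False) -> int:
    response_text = prediction[0]["generated_text"].strip()
    if do_full_match:
        # Require an exact class name; any other response counts as -1 (incorrect)
        try:
            return LABEL_FEATURE.str2int(response_text)
        except ValueError:
            return -1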
    for label in TISSUE_CLASSES:
        # Search for `X: tissue type` or `(X) tissue type` in the response
        if label in response_text or ALT_LABELS[label] in response_text:
            return LABEL_FEATURE.str2int(label)
    return -1
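To check the matching rules without loading a model, here is a dependency-free sketch of the same logic applied to hand-written responses. It uses a local copy of the ISIC options printed earlier under the hypothetical names `OPTIONS` and `ALT_OPTIONS`, and returns list indices instead of calling `ClassLabel.str2int`.

```python
# Local copy of the ISIC 2019 answer options printed earlier in this notebook
OPTIONS = ["A: MEL", "B: NV", "C: BCC", "D: AK", "E: BKL", "F: DF", "G: VASC", "H: SCC", "I: UNK"]
# Alternative answer format, e.g. "A: MEL" -> "(A) MEL"
ALT_OPTIONS = {label: f"({label.replace(': ', ') ')}" for label in OPTIONS}


def parse_response(text: str, full_match: bool = False) -> int:
    """Map a generated answer to a class index, or -1 if it matches no class."""
    text = text.strip()
    if full_match:
        # Fine-tuned model: the whole response should be exactly one option
        return OPTIONS.index(text) if text in OPTIONS else -1
    for index, label in enumerate(OPTIONS):
        # Pretrained model: accept either format anywhere in the response
        if label in text or ALT_OPTIONS[label] in text:
            return index
    return -1


print(parse_response("B: NV"))                                # 1
print(parse_response("The most likely answer is (C) BCC."))  # 2
print(parse_response("B: NV, probably", full_match=True))    # -1
```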

### Compute baseline metrics on the pretrained model

Load the pretrained model using the `pipeline` API.

In [None]:
from transformers import pipeline

pt_pipe = pipeline(
    "image-text-to-text",
    model=model_id,
    torch_dtype=torch.float16,
)

# Set `do_sample = False` for deterministic responses
pt_pipe.model.generation_config.do_sample = False
pt_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id

Run batch inference on the test dataset.

In [None]:
pt_outputs = pt_pipe(
    text=test_data["messages"],
    images=test_data["image"],
    max_new_tokens=40,
    batch_size=64,
    return_full_text=False,
)

pt_predictions = [postprocess(out) for out in pt_outputs]

Compute metrics.

In [None]:
pt_metrics = compute_metrics(pt_predictions)
print(f"Baseline metrics: {pt_metrics}")

### Compute metrics on the fine-tuned model

Load the base model with the fine-tuned LoRA adapter using the `pipeline` API.

In [None]:
ft_pipe = pipeline(
    "image-text-to-text",
    model=args.output_dir,
    processor=processor,
    torch_dtype=torch.float16,
)

# Set `do_sample = False` for deterministic responses
ft_pipe.model.generation_config.do_sample = False
ft_pipe.model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
# Use left padding during inference
processor.tokenizer.padding_side = "left"

Run batch inference on the test dataset.

In [None]:
ft_outputs = ft_pipe(
    text=test_data["messages"],
    images=test_data["image"],
    max_new_tokens=20,
    batch_size=64,
    return_full_text=False,
)

ft_predictions = [postprocess(out, do_full_match=True) for out in ft_outputs]

Compute metrics.

In [None]:
ft_metrics = compute_metrics(ft_predictions)
print(f"Fine-tuned metrics: {ft_metrics}")

# Next steps

Explore the other [notebooks](https://github.com/google-health/medgemma/blob/main/notebooks) to learn what else you can do with the model.