# Fine-tuning DistilBERT for question answering

This guide describes how to fine-tune the DistilBERT model on the Stanford Question Answering Dataset (SQuAD) for question answering using Kubeflow Trainer.

This guide is adapted from the Hugging Face question answering task guide: https://huggingface.co/docs/transformers/en/tasks/question_answering

Pretrained DistilBERT: https://huggingface.co/docs/transformers/en/model_doc/distilbert

SQuAD dataset: https://huggingface.co/datasets/rajpurkar/squad

# Install the Kubeflow SDK

You need to install the Kubeflow SDK to interact with Kubeflow Trainer APIs:

In [None]:
# !pip install git+https://github.com/kubeflow/sdk.git@main#subdirectory=python

Install dependencies

In [1]:
# Notebook-side dependencies: cloudpathlib and transformers are imported by the inference cell.
!pip install "cloudpathlib[all]" "transformers[torch]"


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


# Define the HuggingFace training script

We need to wrap our training script in a function to create the Kubeflow TrainJob.

In [2]:
def train_distilbert(args):
    """Fine-tune DistilBERT for extractive question answering on a SQuAD subset.

    The function is self-contained: every import lives in the body and no
    notebook global is referenced, so it can be handed to a Kubeflow
    CustomTrainer (presumably the function is shipped to the training nodes
    on its own - keep it that way).

    Args:
        args: dict with the keys
            "MODEL_NAME": model name under the "distilbert/" Hub namespace;
                also used as the local output directory of the Trainer.
            "BUCKET": optional object-storage URI. When set, the output
                directory is uploaded to BUCKET/MODEL_NAME after training.
    """
    import os

    from cloudpathlib import CloudPath
    from datasets import load_dataset
    import torch
    import torch.distributed as dist
    from transformers import (
        AutoModelForQuestionAnswering,
        AutoTokenizer,
        DefaultDataCollator,
        Trainer,
        TrainingArguments,
    )

    model_name = args["MODEL_NAME"]
    model_id = f"distilbert/{model_name}"

    # Initialize distributed environment: NCCL when GPUs are present, Gloo otherwise
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)

    try:
        local_rank = int(os.getenv("LOCAL_RANK", 0))
        print(
            f"Distributed Training with WORLD_SIZE: {dist.get_world_size()}, "
            f"RANK: {dist.get_rank()}, LOCAL_RANK: {local_rank}."
        )

        # Download the dataset (first 100 training examples only) and tokenizer
        squad = load_dataset("squad", split="train[:100]")
        squad = squad.train_test_split(test_size=0.2, shuffle=False)

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Define the preprocessing function
        def preprocess_function(examples):
            """Tokenize question/context pairs and label answer token spans.

            Adds "start_positions" and "end_positions" to the tokenized batch.
            Examples whose answer does not overlap the (possibly truncated)
            context are labelled (0, 0).
            """
            questions = [q.strip() for q in examples["question"]]
            inputs = tokenizer(
                questions,
                examples["context"],
                max_length=384,
                truncation="only_second",
                return_offsets_mapping=True,
                padding="max_length",
            )

            # Offsets are only needed to build the labels, not by the model
            offset_mapping = inputs.pop("offset_mapping")
            answers = examples["answers"]
            start_positions = []
            end_positions = []

            for i, offset in enumerate(offset_mapping):
                answer = answers[i]
                start_char = answer["answer_start"][0]
                end_char = start_char + len(answer["text"][0])
                sequence_ids = inputs.sequence_ids(i)

                # Find the start and end of the context (sequence id 1)
                idx = 0
                while sequence_ids[idx] != 1:
                    idx += 1
                context_start = idx
                while sequence_ids[idx] == 1:
                    idx += 1
                context_end = idx - 1

                # If the answer lies outside the context, label it (0, 0)
                if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
                    start_positions.append(0)
                    end_positions.append(0)
                    continue

                # Otherwise walk inwards to the start and end token positions
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)

                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

            inputs["start_positions"] = start_positions
            inputs["end_positions"] = end_positions
            return inputs

        # Apply the preprocessing function to the dataset
        tokenized_squad = squad.map(
            preprocess_function,
            batched=True,
            remove_columns=squad["train"].column_names,
        )

        # Create a batch of examples using DefaultDataCollator
        data_collator = DefaultDataCollator()

        # Load the model
        model = AutoModelForQuestionAnswering.from_pretrained(model_id)

        # Define training hyperparameters
        training_args = TrainingArguments(
            output_dir=model_name,
            eval_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            num_train_epochs=1,
            weight_decay=0.01,
            push_to_hub=False,
        )

        # Prepare trainer with configuration
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_squad["train"],
            eval_dataset=tokenized_squad["test"],
            processing_class=tokenizer,
            data_collator=data_collator,
        )

        trainer.train()

        # Upload the fine-tuned model from global rank 0 only: with the default
        # save_on_each_node=False the Trainer writes checkpoints from the main
        # process, and a single uploader avoids every worker writing the same
        # bucket path at once.
        if args.get("BUCKET") and dist.get_rank() == 0:
            (CloudPath(args["BUCKET"]) / model_name).upload_from(model_name)
    finally:
        # Release the process group created above, even when training failed
        if dist.is_initialized():
            dist.destroy_process_group()

In [3]:
from kubeflow.trainer import TrainerClient, CustomTrainer

# Print each runtime known to the Trainer client, followed by the first
# three items of its entrypoint command.
for runtime in TrainerClient().list_runtimes():
    runtime_trainer = runtime.trainer
    print(
        f"Name: {runtime.name}, "
        f"Framework: {runtime_trainer.framework.value}, "
        f"Trainer Type: {runtime_trainer.trainer_type.value}"
    )
    print(f"Entrypoint: {runtime_trainer.entrypoint[:3]}")

Name: deepspeed-distributed, Framework: deepspeed, Trainer Type: CustomTrainer
Entrypoint: ['mpirun', '--hostfile', '/etc/mpi/hostfile']
Name: mlx-distributed, Framework: mlx, Trainer Type: CustomTrainer
Entrypoint: ['mpirun', '--hostfile', '/etc/mpi/hostfile']
Name: mpi-distributed, Framework: torch, Trainer Type: CustomTrainer
Entrypoint: ['torchrun']
Name: torch-distributed, Framework: torch, Trainer Type: CustomTrainer
Entrypoint: ['torchrun']


In [4]:
# Destination for the fine-tuned model in object storage (S3, GCS or Azure Blob Storage).
# Set the bucket with protocol, e.g., "s3://my-bucket/folder"; None skips the upload.
BUCKET = None

MODEL_NAME = "distilbert-base-uncased"

# Arguments passed to train_distilbert inside the TrainJob
args = dict(BUCKET=BUCKET, MODEL_NAME=MODEL_NAME)

# Resources requested for each training node
node_resources = {
    "cpu": "2",
    "memory": "12Gi",
}
# Uncomment this to distribute the TrainJob using GPU nodes
# node_resources["nvidia.com/gpu"] = 1

custom_trainer = CustomTrainer(
    func=train_distilbert,
    func_args=args,
    num_nodes=1,
    packages_to_install=["datasets", "transformers[torch]", "cloudpathlib[all]"],
    resources_per_node=node_resources,
)

job_id = TrainerClient().train(trainer=custom_trainer)

In [5]:
# Train API generates a random TrainJob id.
# Echo it here; later cells pass it to get_job() and get_job_logs().
job_id

'rafd89de924b'

# Check the TrainJob details

Use the `list_jobs()` and `get_job()` APIs to get details about the created TrainJob and its steps.

In [6]:
# One summary line per TrainJob: name, status and creation time
for trainjob in TrainerClient().list_jobs():
    summary = (
        f"TrainJob: {trainjob.name}, "
        f"Status: {trainjob.status}, "
        f"Created at: {trainjob.creation_timestamp}"
    )
    print(summary)

TrainJob: rafd89de924b, Status: Unknown, Created at: 2025-04-29 01:22:14+00:00


In [7]:
# TODO (andreyvelich): Use wait_for_job_status API from TrainerClient() when it is implemented.
import time

def wait_for_job_running(name=None, retries=100, interval=5, client=None):
    """Poll a TrainJob until one of its steps reports the "Running" status.

    Args:
        name: TrainJob name. Defaults to the notebook-level ``job_id``.
        retries: Maximum number of status checks before giving up.
        interval: Seconds to sleep between two checks.
        client: Object exposing ``get_job(name=...)``. Defaults to a new
            ``TrainerClient()``, created once and reused for every check.

    Raises:
        TimeoutError: If no step is running after ``retries`` checks. Without
            this, a job that never starts would be indistinguishable from one
            that is running.
    """
    job_name = job_id if name is None else name
    if client is None:
        client = TrainerClient()

    for _ in range(retries):
        trainjob = client.get_job(name=job_name)
        if any(step.status == "Running" for step in trainjob.steps):
            return
        print(f"Waiting for TrainJob running status. Sleep for {interval} seconds")
        time.sleep(interval)

    raise TimeoutError(
        f"TrainJob {job_name} has no step in Running status after {retries} checks"
    )

# Block until a TrainJob step reports "Running" (polls every 5 seconds).
wait_for_job_running()

Waiting for TrainJob running status. Sleep for 5 seconds


In [8]:
# Report every step of the TrainJob with its status and devices.
# With num_nodes=1 there is a single step, node-0.
trainjob = TrainerClient().get_job(name=job_id)
for step in trainjob.steps:
    print(f"Step: {step.name}, Status: {step.status}, Devices: {step.device} x {step.device_count}")

Step: node-0, Status: Running, Devices: cpu x 2


# Show the TrainJob logs

Use the `get_job_logs()` API to retrieve the TrainJob logs.

In [9]:
# Assign the return value to _ so the notebook does not echo it below the logs.
# NOTE(review): follow=True presumably keeps streaming log lines while the job runs - confirm in the SDK docs.
_ = TrainerClient().get_job_logs(name=job_id, follow=True)

[node-0]: W0429 01:22:31.907000 1 site-packages/torch/distributed/run.py:793] 
[node-0]: W0429 01:22:31.907000 1 site-packages/torch/distributed/run.py:793] *****************************************
[node-0]: W0429 01:22:31.907000 1 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[node-0]: W0429 01:22:31.907000 1 site-packages/torch/distributed/run.py:793] *****************************************
[node-0]: Distributed Training with WORLD_SIZE: 2, RANK: 1, LOCAL_RANK: 1.
[node-0]: Distributed Training with WORLD_SIZE: 2, RANK: 0, LOCAL_RANK: 0.
Generating train split: 100%|██████████| 87599/87599 [00:00<00:00, 627416.25 examples/s]
Generating validation split: 100%|██████████| 10570/10570 [00:00<00:00, 1041555.11 examples/s]
Map: 100%|██████████| 80/80 [00:00<00:00, 1361.10 examples/

# Inference

Download the model and run inference on some examples.

In [10]:
from cloudpathlib import CloudPath
from transformers import pipeline

from pathlib import Path

if BUCKET:
    # Fetch the fine-tuned model directory that the TrainJob uploaded
    (CloudPath(BUCKET) / MODEL_NAME).download_to(MODEL_NAME)

    # Checkpoint folders are named "checkpoint-<step>". The final step number
    # depends on dataset size, batch size and worker count, so pick the newest
    # checkpoint instead of hard-coding a step.
    checkpoints = sorted(
        (
            path
            for path in Path(MODEL_NAME).glob("checkpoint-*")
            if path.is_dir() and path.name.rpartition("-")[2].isdigit()
        ),
        key=lambda path: int(path.name.rpartition("-")[2]),
    )
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint-* directory found under ./{MODEL_NAME}")
    latest_checkpoint = checkpoints[-1]

    question = "How many programming languages does BLOOM support?"
    context = "BLOOM has 176 billion parameters and can generate text in 46 languages natural languages and 13 programming languages."

    question_answerer = pipeline("question-answering", model=f"./{latest_checkpoint.as_posix()}")
    # An expression nested inside the if block is not echoed by the notebook,
    # so print the answer explicitly.
    print(question_answerer(question=question, context=context))

# Clean up

To delete the TrainJob you can use the `delete_job()` API and pass the generated `job_id`.

In [11]:
# _ = TrainerClient().delete_job(job_id)