# Distributed Training with TransformersTrainer

This notebook demonstrates how to use `TransformersTrainer` from the Kubeflow SDK to run **distributed fine-tuning** of a Hugging Face model on **Red Hat OpenShift AI**.

## Overview

In this example, we fine-tune **DistilBERT** on the **IMDB sentiment classification** dataset using **2 nodes**.

| Feature | Description |
| --- | --- |
| **Distributed training** | Run the same `transformers.Trainer` code across multiple nodes without manual DDP setup |
| **Progress tracking** | View training progress in the OpenShift AI Dashboard (Training Jobs) |
| **Checkpointing (optional)** | Persist checkpoints to a shared RWX PVC via `output_dir="pvc://..."` |

### Model details

- **Model**: `distilbert-base-uncased`
- **Task**: binary sentiment classification

### Dataset details

- **Dataset**: `stanfordnlp/imdb` (train split; we use a 1,000-sample subset)

### What you will learn

- How to define a `train_func()` that uses `transformers.Trainer`
- How to configure and submit a multi-node training job with `TransformersTrainer`
- Where to monitor training progress in the OpenShift AI Dashboard

### Prerequisites

- OpenShift AI (RHOAI) 3.2+ with Kubeflow Trainer v2 enabled
- A workbench with Python and access to submit TrainJobs
- A shared PVC named `shared` with **ReadWriteMany (RWX)** access mode
  - **Suggested size**: 20Gi (model weights + dataset + checkpoints)
  - **Workbench mount**: `/opt/app-root/src/shared`
  - See `README.md` in this folder for step-by-step PVC setup instructions

## Setup and Imports

Install the Kubeflow SDK and required packages.

In [None]:
!python3 -m pip install datasets transformers accelerate huggingface_hub
!python3 -m pip install --force-reinstall --no-cache-dir -U "kubeflow @ git+https://github.com/opendatahub-io/kubeflow-sdk.git@v0.2.1+rhai0"
!python3 -m pip install --force-reinstall --no-cache-dir -U ipykernel

In [None]:
import os

import kubeflow
import torch
from kubeflow.common.types import KubernetesBackendConfig
from kubeflow.trainer import TrainerClient
from kubeflow.trainer.rhai import TransformersTrainer
from kubernetes import client as k8s

print(f"Kubeflow SDK version: {kubeflow.__version__}")
print(f"Torch version: {torch.__version__}")
print("✅ All imports successful")

## Configuration

This notebook needs:

- **Authentication** to talk to the OpenShift/Kubernetes API
- **A shared RWX PVC** (model + data + checkpoints)
- **Distributed training settings** (`num_nodes`, `resources_per_node`)

### Environment variables

The following environment variables are required for API authentication:

- `OPENSHIFT_API_URL`: your cluster API URL (e.g. `https://api.cluster.example.com:6443`)
- `NOTEBOOK_USER_TOKEN`: an access token for API calls

OpenShift AI workbenches typically set these automatically.

If they are not set in your environment, uncomment and populate the values in the next cell.

In [None]:
# ============================================================================
# AUTHENTICATION
# ============================================================================
# If your workbench does not auto-populate these env vars, uncomment and fill them in:
#
# api_server = "https://api.your-cluster.example.com:6443"
# token = "sha256~your-token-here"

api_server = os.getenv("OPENSHIFT_API_URL")
token = os.getenv("NOTEBOOK_USER_TOKEN")

if not api_server or not token:
    raise RuntimeError(
        "OPENSHIFT_API_URL and NOTEBOOK_USER_TOKEN must be set. "
        "Either set them in your environment or uncomment the values above."
    )

# Configure Kubernetes client
configuration = k8s.Configuration()
configuration.host = api_server
configuration.verify_ssl = False  # Disables TLS certificate verification; set to True if the cluster uses trusted certificates
configuration.api_key = {"authorization": f"Bearer {token}"}

# ============================================================================
# PVC MOUNT PATHS
# ============================================================================
# Workbench mount path
#
# NOTEBOOK_SHARED_PATH is where *your workbench* sees the PVC.
# This depends on which PVC you attached when you created the workbench.
# OpenShift AI typically mounts PVCs under:
#   /opt/app-root/src/<pvc-name>
PVC_NAME = "shared"
NOTEBOOK_SHARED_PATH = f"/opt/app-root/src/{PVC_NAME}"

# Training pod mount path
#
# SDK_MOUNT_PATH is a fixed path used by the Kubeflow SDK when you set:
#   TransformersTrainer(output_dir="pvc://<pvc-name>/<path>")
# The SDK mounts that PVC at this location inside the training pods.
SDK_MOUNT_PATH = "/mnt/kubeflow-checkpoints"

if not os.path.exists(NOTEBOOK_SHARED_PATH):
    print(
        "⚠️  Expected workbench PVC mount not found at: "
        f"{NOTEBOOK_SHARED_PATH}\n"
        "If your PVC has a different name/mount, update PVC_NAME/NOTEBOOK_SHARED_PATH.\n"
        "Tip: in a workbench, PVCs are typically under /opt/app-root/src/."
    )

# ============================================================================
# MODEL + DATASET
# ============================================================================
MODEL_NAME = "distilbert-base-uncased"

# Use the canonical Hub repo ID to avoid brittle revision lookups
DATASET_NAME = "stanfordnlp/imdb"
DATASET_REVISION = "main"

MODEL_PATH = f"{NOTEBOOK_SHARED_PATH}/models/{MODEL_NAME}"
DATA_PATH = f"{NOTEBOOK_SHARED_PATH}/data/imdb_train_1000"
CHECKPOINTS_PATH = f"{NOTEBOOK_SHARED_PATH}/checkpoints/transformer-trainer"

TRAINING_MODEL_PATH = f"{SDK_MOUNT_PATH}/models/{MODEL_NAME}"
TRAINING_DATA_PATH = f"{SDK_MOUNT_PATH}/data/imdb_train_1000"
TRAINING_CHECKPOINTS_PATH = f"{SDK_MOUNT_PATH}/checkpoints/transformer-trainer"

# ============================================================================
# DISTRIBUTED TRAINING
# ============================================================================
NUM_NODES = 2
GPUS_PER_NODE = 1

print(f"API Server: {api_server}")
print(f"PVC name: {PVC_NAME}")
print(f"Workbench PVC mount: {NOTEBOOK_SHARED_PATH}")
print(f"Training pod PVC mount (SDK): {SDK_MOUNT_PATH}")
print(f"Model: {MODEL_NAME}")
print(f"Dataset: {DATASET_NAME}")
print(f"Nodes: {NUM_NODES}")
print(f"GPUs per node: {GPUS_PER_NODE}")
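
The relationship between a `pvc://` URI, the training pod path, and the workbench path can be sketched as a small helper. This is only an illustration of the mapping described in the comments above, not the SDK's own resolution logic; `resolve_pvc_uri` and `PvcLocation` are hypothetical names, and the default mount roots are the values assumed in this notebook.

```python
from pathlib import PurePosixPath
from typing import NamedTuple
from urllib.parse import urlparse


class PvcLocation(NamedTuple):
    """Where one PVC-relative path appears from each side (hypothetical type)."""

    pvc_name: str
    pod_path: str
    workbench_path: str


def resolve_pvc_uri(
    uri: str,
    pod_mount: str = "/mnt/kubeflow-checkpoints",  # assumed SDK mount path
    workbench_root: str = "/opt/app-root/src",  # assumed workbench PVC root
) -> PvcLocation:
    """Split pvc://<pvc-name>/<path> and join <path> onto both mount points."""
    parsed = urlparse(uri)
    if parsed.scheme != "pvc" or not parsed.netloc:
        raise ValueError(f"Expected pvc://<pvc-name>/<path>, got: {uri!r}")
    relative = parsed.path.lstrip("/")
    return PvcLocation(
        pvc_name=parsed.netloc,
        pod_path=str(PurePosixPath(pod_mount) / relative),
        workbench_path=str(PurePosixPath(workbench_root) / parsed.netloc / relative),
    )
```

For example, `resolve_pvc_uri("pvc://shared/checkpoints/transformer-trainer")` yields the same two directories as `TRAINING_CHECKPOINTS_PATH` and `CHECKPOINTS_PATH` above, assuming the default mount points apply to your environment.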

## Define the training function

The training function runs inside each training pod as a distributed PyTorch process. `TransformersTrainer` serializes this function and executes it via `torchrun` on each node.

### Key points

- **All imports must be inside the function**: the function is serialized and executed in training pods
- **Use `transformers.Trainer` or `trl.SFTTrainer`**: both are automatically instrumented
- **PyTorch env vars are set for you**: `RANK`, `WORLD_SIZE`, `LOCAL_RANK` are set automatically
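
As a minimal sketch of how a training function might consume these variables, the helper below reads them into a small record with single-process defaults. `DistEnv` and `read_dist_env` are hypothetical names used for illustration only; the variable names and defaults mirror the ones used in `train_func` in this notebook.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistEnv:
    """Process coordinates within a distributed job (hypothetical helper type)."""

    rank: int  # global index of this process across all nodes
    world_size: int  # total number of processes in the job
    local_rank: int  # index of this process on its own node

    @property
    def is_main_process(self) -> bool:
        return self.rank == 0


def read_dist_env(env: Optional[Mapping[str, str]] = None) -> DistEnv:
    """Read RANK, WORLD_SIZE and LOCAL_RANK, defaulting to a single process."""
    env = os.environ if env is None else env
    return DistEnv(
        rank=int(env.get("RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
    )
```

With the defaults, the same code also runs unchanged as a single local process, which can be convenient for a quick smoke test before submitting a job.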

In [None]:
def train_func():
    """Distributed training function for IMDB sentiment classification."""
    import os

    import torch
    from datasets import load_from_disk
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
    )

    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

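    # The SDK mounts the PVC at a fixed path inside training pods.
    # Paths are hard-coded because the function is serialized, so notebook
    # variables such as TRAINING_MODEL_PATH are not available here.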
    model_path = "/mnt/kubeflow-checkpoints/models/distilbert-base-uncased"
    data_path = "/mnt/kubeflow-checkpoints/data/imdb_train_1000"

    print(f"🚀 Starting training on rank {rank}/{world_size}")

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        print(f"🔧 GPU: {torch.cuda.get_device_name(local_rank)}")

    # Load model + tokenizer from the PVC (downloaded by the workbench)
    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        num_labels=2,
        local_files_only=True,
    )

    # Load dataset from the PVC (saved by the workbench)
    dataset = load_from_disk(data_path)

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=256,
        )

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    # TransformersTrainer will patch output_dir + checkpoint settings when output_dir="pvc://..." is set
    training_args = TrainingArguments(
        output_dir="/tmp/output",  # placeholder; overridden by the SDK
        num_train_epochs=1,
        per_device_train_batch_size=8,
        learning_rate=2e-5,
        logging_steps=10,
        report_to="none",
        ddp_find_unused_parameters=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    print(f"💾 Trainer output_dir: {trainer.args.output_dir}")

    trainer.train()

    # Save final model on rank 0
    if rank == 0:
        final_path = os.path.join(trainer.args.output_dir, "final")
        os.makedirs(final_path, exist_ok=True)
        trainer.save_model(final_path)
        tokenizer.save_pretrained(final_path)
        print(f"✅ Final model saved to {final_path}")

    print(f"✅ Training complete on rank {rank}")

## Download model and dataset to the shared PVC (recommended)

To make training more reliable, we download the model and dataset to the shared PVC from the workbench first.

This avoids repeated downloads inside training pods and lets the training job run without direct internet access.

In [None]:
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Download model to PVC
if os.path.exists(MODEL_PATH) and os.listdir(MODEL_PATH):
    print(f"✅ Model already exists at {MODEL_PATH}")
else:
    print(f"🔄 Downloading model {MODEL_NAME} to {MODEL_PATH}...")
    os.makedirs(MODEL_PATH, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    tokenizer.save_pretrained(MODEL_PATH)

    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
    model.save_pretrained(MODEL_PATH, safe_serialization=True)

    print(f"✅ Model saved to {MODEL_PATH}")

# Download dataset subset to PVC
if os.path.exists(DATA_PATH) and os.listdir(DATA_PATH):
    print(f"✅ Dataset already exists at {DATA_PATH}")
else:
    print(f"🔄 Downloading dataset {DATASET_NAME} (train[:1000]) to {DATA_PATH}...")
    os.makedirs(DATA_PATH, exist_ok=True)

    dataset = load_dataset(
        DATASET_NAME, split="train[:1000]", revision=DATASET_REVISION
    )
    dataset.save_to_disk(DATA_PATH)

    print(f"✅ Dataset saved to {DATA_PATH}")

print("\n✅ Model and dataset ready on PVC!")
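
Because the training pods run offline, it can help to confirm that the staged artifacts are complete before submitting the job. The sketch below is one possible preflight check. `missing_files` is a hypothetical helper, and the file names are assumptions based on what `save_pretrained()` and `save_to_disk()` typically write, so adjust them to match what you actually see on the PVC.

```python
from pathlib import Path
from typing import Iterable, List

# Assumed file names; verify against the contents of your PVC.
MODEL_FILES = ("config.json", "model.safetensors", "tokenizer_config.json")
DATASET_FILES = ("dataset_info.json", "state.json")


def missing_files(root: str, expected: Iterable[str]) -> List[str]:
    """Return the expected file names that are not present directly under root."""
    base = Path(root)
    return [name for name in expected if not (base / name).is_file()]
```

In this notebook the check could be run as `missing_files(MODEL_PATH, MODEL_FILES)` and `missing_files(DATA_PATH, DATASET_FILES)`; an empty list from both suggests the staging step finished.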

## Configure and submit the training job

Create a `TransformersTrainer` and submit the job using the `TrainerClient`.

We also set `output_dir="pvc://..."` so checkpoints are written to the shared PVC and can be inspected from the workbench.

In [None]:
from kubeflow.trainer.rhai.transformers import PeriodicCheckpointConfig

checkpoint_config = PeriodicCheckpointConfig(
    save_strategy="steps",
    save_steps=25,
    save_total_limit=2,
)

trainer = TransformersTrainer(
    func=train_func,
    num_nodes=NUM_NODES,
    resources_per_node={
        "nvidia.com/gpu": GPUS_PER_NODE,
        "cpu": "4",
        "memory": "16Gi",
    },
    # Keep train_func focused; set offline mode at the pod level
    env={
        "HF_HUB_OFFLINE": "1",
        "TRANSFORMERS_OFFLINE": "1",
    },
    # Persist checkpoints on the shared PVC
    output_dir=f"pvc://{PVC_NAME}/checkpoints/transformer-trainer",
    periodic_checkpoint_config=checkpoint_config,
    enable_jit_checkpoint=True,
)

print("✅ TransformersTrainer configured")
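
To see how `save_steps=25` relates to the length of this run, the arithmetic below estimates the number of optimizer steps per epoch. It is a rough sketch that assumes one process per GPU, no gradient accumulation, and that the last partial batch is kept; the helper names are hypothetical.

```python
import math
from typing import List


def steps_per_epoch(num_samples: int, per_device_batch_size: int, world_size: int) -> int:
    """Estimate optimizer steps per epoch when samples are split evenly across ranks."""
    samples_per_rank = math.ceil(num_samples / world_size)
    return math.ceil(samples_per_rank / per_device_batch_size)


def periodic_save_steps(total_steps: int, save_steps: int) -> List[int]:
    """Steps at which a periodic (every save_steps) checkpoint would be due."""
    return list(range(save_steps, total_steps + 1, save_steps))


# Values from this notebook: 1,000 samples, batch size 8, 2 nodes x 1 GPU.
world_size = 2 * 1
total = steps_per_epoch(num_samples=1000, per_device_batch_size=8, world_size=world_size)
print(total)  # 63
print(periodic_save_steps(total, save_steps=25))  # [25, 50]
```

Under these assumptions, one epoch is 63 steps with periodic checkpoints due at steps 25 and 50. Whether an additional checkpoint is written at the final step depends on the `transformers` version and the SDK's checkpoint settings.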

In [None]:
# Create a TrainerClient using the explicit API server + token
api_client = k8s.ApiClient(configuration)
backend_config = KubernetesBackendConfig(
    client_configuration=api_client.configuration,
)
client = TrainerClient(backend_config)

runtime = client.backend.get_runtime("torch-distributed")
print(f"✅ Using runtime: {runtime.name}")

In [None]:
# Submit the training job
JOB_NAME = client.train(trainer=trainer, runtime=runtime)
print(f"Job submitted: {JOB_NAME}")

## Monitor the training job

Navigate to **Training Jobs** in the OpenShift AI Dashboard to see training progress and job status.

In [None]:
# Check job status
job = client.get_job(name=JOB_NAME)
print(f"Job: {job.name}")
print(f"Status: {job.status}")

In [None]:
# Wait for job to complete
import time

print("Waiting for job to complete...")
while True:
    job = client.get_job(name=JOB_NAME)
    print(f"Status: {job.status}")
    if job.status in ["Complete", "Failed"]:
        break
    time.sleep(15)

print(f"Job finished: {job.status}")
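
The loop above polls until the job reaches a terminal state and never gives up. If you prefer a bounded wait, a minimal sketch of the same polling pattern with a timeout is shown below. `wait_for_status` is a hypothetical helper, not part of the Kubeflow SDK; it takes any zero-argument callable that returns the current status string.

```python
import time
from typing import Callable, Iterable


def wait_for_status(
    get_status: Callable[[], str],
    terminal: Iterable[str] = ("Complete", "Failed"),
    timeout_s: float = 3600.0,
    poll_interval_s: float = 15.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll get_status() until it returns a terminal status or the timeout expires."""
    terminal_set = set(terminal)
    deadline = time.monotonic() + timeout_s
    while True:
        status = get_status()
        if status in terminal_set:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Last status was {status!r} after {timeout_s}s")
        sleep(poll_interval_s)
```

In this notebook it could be called as `wait_for_status(lambda: client.get_job(name=JOB_NAME).status)`, which reuses the same `get_job` call as the loop above.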

## Cleanup

Delete the training job to free cluster resources.

**Optional:** If you want to clean up all artifacts from the PVC (model, dataset, checkpoints), you can delete these directories from the workbench:

```bash
rm -rf /opt/app-root/src/shared/models/distilbert-base-uncased
rm -rf /opt/app-root/src/shared/data/imdb_train_1000
rm -rf /opt/app-root/src/shared/checkpoints/transformer-trainer
```

In [None]:
# Delete the training job
client.delete_job(name=JOB_NAME)
print(f"Job {JOB_NAME} deleted")

## Summary

Congratulations! You've successfully run a distributed fine-tuning job with `TransformersTrainer`.

### What you accomplished

| Step | Description |
| --- | --- |
| ✅ Model + dataset staging | Downloaded DistilBERT + IMDB subset to a shared PVC from the workbench |
| ✅ Distributed training | Ran 2-node training with PyTorch distributed via `torchrun` |
| ✅ Monitoring | Viewed progress in the OpenShift AI Dashboard (Training Jobs) |

### Why this pattern works well on OpenShift AI

- **Repeatable**: training pods load model/data from shared storage
- **Scalable**: change `num_nodes` / `resources_per_node` without changing training code
- **Observable**: progress tracking is enabled by default for `TransformersTrainer`

### How progress tracking works

When you use `TransformersTrainer` (default `enable_progression_tracking=True`):

1. **Automatic instrumentation**: a `KubeflowProgressCallback` is injected into your Hugging Face `Trainer`
2. **Metrics endpoint**: a lightweight HTTP server exposes progress metrics during training
3. **Dashboard integration**: the OpenShift AI Dashboard polls and displays progress
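
To make the second step more concrete, here is a rough standard-library sketch of what a lightweight progress endpoint could look like. It is not the SDK's implementation: `ProgressState`, `start_progress_server`, and the JSON field names are all hypothetical, and the real callback and endpoint may differ.

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer


class ProgressState:
    """Thread-safe holder for the latest training step (hypothetical)."""

    def __init__(self, total_steps: int) -> None:
        self._lock = threading.Lock()
        self._step = 0
        self._total = total_steps

    def update(self, step: int) -> None:
        # A training callback would call this after each optimizer step.
        with self._lock:
            self._step = step

    def snapshot(self) -> dict:
        with self._lock:
            percent = round(100.0 * self._step / self._total, 1) if self._total else 0.0
            return {"step": self._step, "total_steps": self._total, "percent": percent}


def start_progress_server(state: ProgressState, port: int = 0) -> ThreadingHTTPServer:
    """Serve the current snapshot as JSON on every GET, from a background thread."""

    class Handler(BaseHTTPRequestHandler):
        def do_GET(self) -> None:
            body = json.dumps(state.snapshot()).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def log_message(self, *args) -> None:
            pass  # keep training logs free of per-request noise

    server = ThreadingHTTPServer(("127.0.0.1", port), Handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server
```

In this sketch the training loop only writes to `ProgressState`, while a poller reads snapshots over HTTP, so a slow or absent poller does not block training.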

### Next steps

- Increase `NUM_NODES` or adjust `resources_per_node` for larger workloads
- Swap in a different model/dataset (as long as you use `transformers.Trainer` or `trl.SFTTrainer`)