# JIT Checkpointing with TransformersTrainer

This notebook demonstrates how to enable **Just-In-Time (JIT) checkpointing** for distributed training using `TransformersTrainer` on Red Hat OpenShift AI.

## Overview

In this example, we fine-tune the **Qwen 2.5 1.5B Instruct** model on the **Stanford Alpaca** instruction-following dataset. The training runs on 2 GPU nodes with **JIT checkpointing enabled**, which automatically saves training state when pods receive a SIGTERM signal (e.g., during preemption, scaling, or graceful shutdown).

### What You'll Learn

| Feature | Description |
|---------|-------------|
| **JIT Checkpointing** | Automatically save training state on SIGTERM signal (preemption-safe) |
| **Periodic Checkpointing** | Configure regular checkpoint saves using `PeriodicCheckpointConfig` |
| **Auto-Resume** | Automatically resume training from the latest checkpoint on restart |
| **PVC-Based Storage** | Save checkpoints to a shared PersistentVolumeClaim (PVC) for durability |

### Why JIT Checkpointing?

In cloud and Kubernetes environments, training pods can be preempted or terminated for various reasons:
- **Spot/Preemptible instances** - Cost-effective but can be reclaimed
- **Kueue preemption** - Higher-priority workloads may preempt lower-priority jobs
- **Node maintenance** - Cluster upgrades or node drains
- **Resource pressure** - Pods evicted due to memory or resource limits

JIT checkpointing ensures that when a pod receives SIGTERM:
1. Training pauses safely after the current optimizer step
2. Model state, optimizer state, and training progress are saved
3. When the job restarts, training automatically resumes from the checkpoint

### Model Details

**Qwen 2.5 1.5B Instruct** is a compact instruction-tuned language model from the Qwen family:
- **Parameters:** 1.5 billion
- **Context Length:** 32K tokens
- **Languages:** Multilingual with strong English and Chinese support
- **Use Case:** Ideal for instruction-following, chat, and text generation tasks
- **Why this model?** Small enough to train quickly for demonstration, yet powerful enough for real-world tasks

### Dataset Details

We use the **Stanford Alpaca** dataset (`tatsu-lab/alpaca`), a widely-used instruction-following dataset:

| Property | Value |
|----------|-------|
| **Source** | Stanford University |
| **Size** | 52,000 instruction-response pairs (we use 500 samples for this demo) |
| **Format** | Instruction, optional input, and response |
| **Use Case** | Instruction-tuning language models |

### Prerequisites

Before running this notebook, ensure you have:

1. **OpenShift AI Cluster** with Kubeflow Trainer v2 enabled
2. **Workbench** running Python 3.12+ with GPU access
3. **Shared PVC** named `shared` with **ReadWriteMany (RWX)** access mode
   - **Recommended Size:** 20Gi (for model weights, dataset, and checkpoints)
   - **Mount Path:** `/opt/app-root/src/shared` in the workbench
   - See the [README](./README.md) for detailed PVC setup instructions

## Setup and Imports

Install the Kubeflow SDK and required packages.

In [None]:
!pip install kubeflow --no-cache-dir --index-url https://console.redhat.com/api/pypi/public-rhai/rhoai/3.3/cuda12.9-ubi9/simple/
!python3 -m pip install datasets transformers accelerate huggingface_hub

In [None]:
import os

import kubeflow
import torch
from datasets import load_dataset
from kubeflow.common.types import KubernetesBackendConfig
from kubeflow.trainer import TrainerClient
from kubeflow.trainer.rhai import TransformersTrainer
from kubeflow.trainer.rhai.transformers import PeriodicCheckpointConfig
from kubernetes import client as k8s
from transformers import AutoModelForCausalLM, AutoTokenizer

print(f"Kubeflow SDK version: {kubeflow.__version__}")
print("✅ All imports successful")

## Configuration

Configure authentication and paths.

### Environment Variables

The following environment variables are required for API authentication:
- `OPENSHIFT_API_URL` - Your OpenShift API server URL (e.g., `https://api.cluster.example.com:6443`)
- `NOTEBOOK_USER_TOKEN` - Authentication token for API access

**Note:** These are typically auto-set in OpenShift AI workbenches. If not set, uncomment and fill in the values in the cell below.

In [None]:
# ============================================================================
# AUTHENTICATION CONFIGURATION
# ============================================================================
# These values are typically auto-set in OpenShift AI workbenches.
# If not set, uncomment and fill in your values below:

# api_server = "https://api.your-cluster.example.com:6443"
# token = "sha256~your-token-here"

# Try to get from environment variables first
api_server = os.getenv("OPENSHIFT_API_URL")
token = os.getenv("NOTEBOOK_USER_TOKEN")

if not api_server or not token:
    raise RuntimeError(
        "OPENSHIFT_API_URL and NOTEBOOK_USER_TOKEN environment variables are required.\n"
        "Either set them in your environment, or uncomment and fill in the values above."
    )

# Configure Kubernetes client
configuration = k8s.Configuration()
configuration.host = api_server
configuration.verify_ssl = False  # Set to True if using trusted certificates
configuration.api_key = {"authorization": f"Bearer {token}"}

# ============================================================================
# PVC MOUNT PATHS
# ============================================================================
# Workbench (notebook) mount path
#
# NOTEBOOK_SHARED_PATH is where *your workbench* sees the PVC.
# This depends on what PVC you attached when you created the workbench.
# On OpenShift AI, the default mount convention is:
#   /opt/app-root/src/<pvc-name>
#
# If your PVC is not named "shared", set PVC_NAME accordingly.
PVC_NAME = "shared"
NOTEBOOK_SHARED_PATH = f"/opt/app-root/src/{PVC_NAME}"

# Training pods mount path
#
# SDK_MOUNT_PATH is a fixed path used by the Kubeflow SDK when you set:
#   TransformersTrainer(output_dir="pvc://<pvc-name>/<path>")
# The SDK mounts the PVC at this location inside the training pods.
# (This comes from the SDK constant CHECKPOINT_MOUNT_PATH.)
SDK_MOUNT_PATH = "/mnt/kubeflow-checkpoints"

# Quick sanity check to help users discover the right workbench mount
if not os.path.exists(NOTEBOOK_SHARED_PATH):
    print(
        "⚠️  Expected workbench PVC mount not found at: "
        f"{NOTEBOOK_SHARED_PATH}\n"
        "If your PVC has a different name or mount, update PVC_NAME/NOTEBOOK_SHARED_PATH.\n"
        "Tip: in a workbench, PVCs are typically under /opt/app-root/src/."
    )

# Model Configuration
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Paths for notebook operations only (downloading model/data, reading checkpoints)
# Note: Training pods use SDK mount convention (/mnt/kubeflow-checkpoints/...)
#       which is handled automatically by the pvc:// URI in TransformersTrainer
NOTEBOOK_MODEL_PATH = f"{NOTEBOOK_SHARED_PATH}/models/qwen2.5-1.5b-instruct"
NOTEBOOK_DATA_PATH = f"{NOTEBOOK_SHARED_PATH}/data/alpaca_processed"
NOTEBOOK_CHECKPOINTS_PATH = f"{NOTEBOOK_SHARED_PATH}/checkpoints/jit-checkpointing"

print(f"API Server: {api_server}")
print(f"Model: {MODEL_NAME}")
print(f"PVC name: {PVC_NAME}")
print(f"Workbench PVC mount: {NOTEBOOK_SHARED_PATH}")
print(f"Training pod PVC mount (SDK): {SDK_MOUNT_PATH}")
print(f"Notebook Model Path: {NOTEBOOK_MODEL_PATH}")
print(f"Notebook Data Path: {NOTEBOOK_DATA_PATH}")
print(f"Notebook Checkpoints Path: {NOTEBOOK_CHECKPOINTS_PATH}")

## Download Model and Dataset to Shared PVC

Before submitting the training job, we pre-download the model and dataset to the shared PVC. This ensures:
- **Offline Training:** Training pods don't need internet access during training
- **Faster Startup:** No download delays when training pods start
- **Consistency:** All nodes use the same model weights and data

In [None]:
# Download model to PVC
if os.path.exists(NOTEBOOK_MODEL_PATH) and os.listdir(NOTEBOOK_MODEL_PATH):
    print(f"✅ Model already exists at {NOTEBOOK_MODEL_PATH}")
else:
    print(f"🔄 Downloading model {MODEL_NAME} to {NOTEBOOK_MODEL_PATH}...")
    os.makedirs(NOTEBOOK_MODEL_PATH, exist_ok=True)

    # Use fast tokenizer for compatibility
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME, use_fast=True, trust_remote_code=True
    )
    tokenizer.save_pretrained(NOTEBOOK_MODEL_PATH)

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model.save_pretrained(NOTEBOOK_MODEL_PATH, safe_serialization=True)
    print(f"✅ Model saved to {NOTEBOOK_MODEL_PATH}")
    print(f"📁 Files: {os.listdir(NOTEBOOK_MODEL_PATH)}")

In [None]:
# Download and prepare dataset
if os.path.exists(NOTEBOOK_DATA_PATH) and os.listdir(NOTEBOOK_DATA_PATH):
    print(f"✅ Dataset already exists at {NOTEBOOK_DATA_PATH}")
else:
    print("🔄 Downloading and processing Alpaca dataset...")
    os.makedirs(NOTEBOOK_DATA_PATH, exist_ok=True)

    # Load subset of Alpaca dataset
    dataset = load_dataset("tatsu-lab/alpaca", split="train[:500]")

    # Load tokenizer for preprocessing
    tokenizer = AutoTokenizer.from_pretrained(
        NOTEBOOK_MODEL_PATH, use_fast=True, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def format_instruction(example):
        if example.get("input"):
            text = f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:\n{example['output']}"
        else:
            text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}"
        return {"text": text}

    dataset = dataset.map(format_instruction, remove_columns=dataset.column_names)

    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=512,
        )
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    tokenized_dataset = dataset.map(
        tokenize_function, batched=True, remove_columns=["text"]
    )
    tokenized_dataset.save_to_disk(NOTEBOOK_DATA_PATH)
    print(f"✅ Dataset saved to {NOTEBOOK_DATA_PATH}")

print("\n✅ Model and dataset ready on PVC!")

## Define the Training Function

The training function runs inside each training pod as a distributed PyTorch process. TransformersTrainer serializes this function and executes it via `torchrun` on each node.
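`torchrun` tells each process where it sits in the job through environment variables such as `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`, and the training function reads `RANK` and `LOCAL_RANK` the same way. The small helper below is a hypothetical illustration (it is not part of the SDK) of what those values look like for this notebook's layout of 2 nodes × 1 GPU. Both pods have `LOCAL_RANK=0`, but only one has `RANK=0`, which is why the final model save is gated on `RANK` rather than `LOCAL_RANK`.

```python
import os
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass(frozen=True)
class DistInfo:
    rank: int        # global rank across all nodes
    local_rank: int  # GPU index on this node
    world_size: int  # total number of processes

    @property
    def is_main_process(self) -> bool:
        return self.rank == 0


def dist_info(env: Mapping[str, str] = os.environ) -> DistInfo:
    # Defaults match a single-process run (no torchrun)
    return DistInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


# Illustrative values torchrun would set for 2 nodes with 1 GPU each
node0 = dist_info({"RANK": "0", "LOCAL_RANK": "0", "WORLD_SIZE": "2"})
node1 = dist_info({"RANK": "1", "LOCAL_RANK": "0", "WORLD_SIZE": "2"})
print(node0.is_main_process, node1.is_main_process)  # True False
```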

### Training Configuration

| Parameter | Value | Description |
|-----------|-------|-------------|
| `num_train_epochs` | 5 | Multiple epochs to allow time for testing pause/resume |
| `per_device_train_batch_size` | 2 | Samples per GPU per step |
| `gradient_accumulation_steps` | 4 | Effective batch size = 2 × 4 × 2 GPUs (2 nodes × 1 GPU) = 16 |
| `learning_rate` | 2e-5 | Standard fine-tuning rate |
| `save_steps` | 20 | Checkpoint every 20 steps (set via `PeriodicCheckpointConfig`, not in `TrainingArguments`) |
| `bf16` | True | Use bfloat16 mixed precision |
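The effective batch size and the approximate number of optimizer steps follow from the table above. The sketch below is an estimate under simple assumptions: samples are split evenly across ranks and partial batches are rounded up. The exact step count can differ by one per epoch depending on the `transformers` version. It gives a rough idea of how many periodic checkpoints `save_steps=20` produces, and therefore how much room there is to test a pause/resume.

```python
import math


def training_schedule(num_samples, per_device_batch, grad_accum, world_size, epochs, save_steps):
    """Rough schedule estimate for data-parallel training."""
    effective_batch = per_device_batch * grad_accum * world_size
    samples_per_rank = math.ceil(num_samples / world_size)          # assume an even split across ranks
    micro_batches = math.ceil(samples_per_rank / per_device_batch)  # forward/backward passes per epoch per rank
    steps_per_epoch = math.ceil(micro_batches / grad_accum)         # optimizer steps per epoch (rounded up)
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "periodic_saves": total_steps // save_steps,
    }


# 500 Alpaca samples, 2 GPUs (2 nodes x 1 GPU), settings from the table above
print(training_schedule(500, 2, 4, 2, 5, 20))
# {'effective_batch': 16, 'steps_per_epoch': 32, 'total_steps': 160, 'periodic_saves': 8}
```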

### Key Points

- **Supported Trainers:** Use `transformers.Trainer` or `trl.SFTTrainer` - both are auto-instrumented
- **No Manual Setup:** JIT checkpointing and progress tracking callbacks are injected automatically
- **Local Files Only:** Model and data are loaded from the mounted PVC (no network access needed)

In [None]:
def train_func():
    """SFT training function using HuggingFace Trainer.

    TransformersTrainer automatically:
    - Injects JIT checkpoint handler for SIGTERM (preemption-safe)
    - Injects KubeflowProgressCallback for real-time metrics
    - Auto-resumes from the latest checkpoint when available
    """
    import os

    import torch
    from datasets import load_from_disk
    from transformers import (
        AutoModelForCausalLM,
        DataCollatorForLanguageModeling,
        PreTrainedTokenizerFast,
        Trainer,
        TrainingArguments,
    )

    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Model/data are on the shared PVC mounted at /mnt/kubeflow-checkpoints via SDK's pvc:// URI
    model_path = "/mnt/kubeflow-checkpoints/models/qwen2.5-1.5b-instruct"
    data_path = "/mnt/kubeflow-checkpoints/data/alpaca_processed"

    print(f"üöÄ Starting training on rank {rank}")

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        print(f"üîß GPU: {torch.cuda.get_device_name(local_rank)}")

    # Load tokenizer directly from tokenizer.json file
    # This bypasses AutoTokenizer's hub validation that fails with local paths
    print(f"üì• Loading tokenizer from: {model_path}")
    tokenizer_file = os.path.join(model_path, "tokenizer.json")
    tokenizer_config_file = os.path.join(model_path, "tokenizer_config.json")

    # Load tokenizer config to get special tokens
    import json

    with open(tokenizer_config_file) as f:
        tokenizer_config = json.load(f)

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        eos_token=tokenizer_config.get("eos_token", "<|endoftext|>"),
        pad_token=tokenizer_config.get("pad_token"),
        bos_token=tokenizer_config.get("bos_token"),
        unk_token=tokenizer_config.get("unk_token"),
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model directly from local path
    print(f"üì• Loading model from: {model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map={"": local_rank},
        local_files_only=True,
        trust_remote_code=True,
    )

    # Load dataset
    print(f"üì• Loading dataset from: {data_path}")
    tokenized_dataset = load_from_disk(data_path)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # TransformersTrainer automatically configures:
    # - output_dir: Set from the pvc:// URI (mounted at /mnt/kubeflow-checkpoints)
    # - save_strategy, save_steps, save_total_limit: Set from PeriodicCheckpointConfig
    training_args = TrainingArguments(
        output_dir="/tmp/output",  # Placeholder - SDK overrides this
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        logging_steps=5,
        report_to="none",
        bf16=True,
        ddp_find_unused_parameters=False,
    )

    # Trainer - TransformersTrainer injects:
    # - JIT checkpoint handler for SIGTERM
    # - KubeflowProgressCallback for real-time metrics
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    print(f"üíæ Trainer output_dir: {trainer.args.output_dir}")

    # Train - auto-resumes from latest checkpoint if available
    trainer.train()

    # Save final model (only on rank 0)
    if rank == 0:
        final_path = os.path.join(trainer.args.output_dir, "final")
        os.makedirs(final_path, exist_ok=True)
        trainer.save_model(final_path)
        tokenizer.save_pretrained(final_path)
        print(f"✅ Final model saved to {final_path}")

    print(f"✅ Training complete on rank {rank}")


print("✅ Training function defined")

## Create the Trainer Client

Initialize the TrainerClient with authentication configuration.

In [None]:
# Create client with authentication
api_client = k8s.ApiClient(configuration)

backend_config = KubernetesBackendConfig(
    client_configuration=api_client.configuration,
)

client = TrainerClient(backend_config)
print("✅ TrainerClient created")

# Get the torch-distributed runtime
runtime = client.get_runtime("torch-distributed")
print(f"✅ Using runtime: {runtime.name}")

## Submit the Training Job with JIT Checkpointing

Now we create and submit the distributed training job with JIT checkpointing enabled.

### Job Configuration

| Parameter | Value | Description |
|-----------|-------|-------------|
| `num_nodes` | 2 | Number of GPU nodes for distributed training |
| `nvidia.com/gpu` | 1 | GPUs per node |
| `cpu` | 4 | CPU cores per node |
| `memory` | 16Gi | Memory per node |

### JIT Checkpointing Configuration

| Parameter | Value | Description |
|-----------|-------|-------------|
| `enable_jit_checkpoint` | `True` | Save checkpoint on SIGTERM (preemption) |
| `periodic_checkpoint_config` | See below | Configure periodic checkpoint saves |
| `output_dir` | `pvc://shared/...` | PVC path for checkpoint storage |
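The `pvc://` URI carries two pieces of information: the PVC name and a path inside it. This notebook describes the training pods seeing the PVC at `/mnt/kubeflow-checkpoints` and the workbench seeing it at `/opt/app-root/src/<pvc-name>`. The hypothetical helper below (not an SDK function) shows how one URI maps to the path each side sees, assuming the workbench uses that default mount location.

```python
from urllib.parse import urlparse

SDK_MOUNT_PATH = "/mnt/kubeflow-checkpoints"  # where training pods see the PVC
WORKBENCH_ROOT = "/opt/app-root/src"          # assumed default workbench mount parent


def resolve_pvc_uri(uri):
    """Hypothetical helper: map pvc://<pvc-name>/<path> to pod and workbench paths."""
    parsed = urlparse(uri)
    if parsed.scheme != "pvc" or not parsed.netloc:
        raise ValueError(f"Expected pvc://<pvc-name>/<path>, got {uri!r}")
    pvc_name = parsed.netloc
    sub_path = parsed.path.lstrip("/")
    pod_path = f"{SDK_MOUNT_PATH}/{sub_path}".rstrip("/")
    workbench_path = f"{WORKBENCH_ROOT}/{pvc_name}/{sub_path}".rstrip("/")
    return pvc_name, pod_path, workbench_path


pvc, pod_path, workbench_path = resolve_pvc_uri("pvc://shared/checkpoints/jit-checkpointing")
print(pvc)             # shared
print(pod_path)        # /mnt/kubeflow-checkpoints/checkpoints/jit-checkpointing
print(workbench_path)  # /opt/app-root/src/shared/checkpoints/jit-checkpointing
```

The same relative path appears under both mount points. That is why the notebook can read `NOTEBOOK_CHECKPOINTS_PATH` while the pods write to `/mnt/kubeflow-checkpoints/...`.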

### How JIT Checkpointing Works

When `enable_jit_checkpoint=True`:

1. **SIGTERM Handler Registered:** TransformersTrainer registers a signal handler for SIGTERM
2. **Safe Checkpoint on Signal:** When SIGTERM is received, training pauses after the current optimizer step
3. **Async Checkpoint Save:** Model state is saved asynchronously using CUDA streams (if available)
4. **Sentinel File:** A marker file ensures incomplete checkpoints are detected and cleaned up
5. **Auto-Resume:** On restart, training automatically resumes from the latest valid checkpoint
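Steps 4 and 5 describe a sentinel-file pattern. The sketch below illustrates the general idea with plain files. It is not the SDK's implementation, and the marker name `.checkpoint_incomplete` is a made-up placeholder. The marker is created before any state is written and removed only after the write finishes. Any checkpoint directory that still contains the marker on restart is therefore known to be incomplete and can be discarded before picking the newest valid one.

```python
import os
import shutil
import tempfile

SENTINEL = ".checkpoint_incomplete"  # placeholder name, not the SDK's actual marker


def save_checkpoint(output_dir, step, write_state):
    """Write checkpoint-<step>, bracketed by a sentinel marker."""
    ckpt_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(ckpt_dir, exist_ok=True)
    sentinel = os.path.join(ckpt_dir, SENTINEL)
    open(sentinel, "w").close()  # 1. mark the checkpoint as in progress
    write_state(ckpt_dir)        # 2. write model/optimizer/trainer state
    os.remove(sentinel)          # 3. only a fully written checkpoint loses the marker
    return ckpt_dir


def latest_valid_checkpoint(output_dir):
    """Delete incomplete checkpoints, then return the valid one with the highest step."""
    valid = []
    for name in os.listdir(output_dir):
        path = os.path.join(output_dir, name)
        step = name.removeprefix("checkpoint-")
        if not (name.startswith("checkpoint-") and step.isdigit() and os.path.isdir(path)):
            continue
        if os.path.exists(os.path.join(path, SENTINEL)):
            shutil.rmtree(path)  # interrupted mid-save: discard it
            continue
        valid.append((int(step), path))
    return max(valid)[1] if valid else None


def write_ok(ckpt_dir):
    # Dummy stand-in for real model weights
    with open(os.path.join(ckpt_dir, "model.safetensors"), "w") as f:
        f.write("weights")


def write_interrupted(ckpt_dir):
    # Stands in for the process being killed during the write
    raise RuntimeError("pod killed mid-save")


out = tempfile.mkdtemp()
save_checkpoint(out, 20, write_ok)
try:
    save_checkpoint(out, 40, write_interrupted)
except RuntimeError:
    pass
print(os.path.basename(latest_valid_checkpoint(out)))  # checkpoint-20
```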

### Periodic Checkpointing

In addition to JIT checkpointing, you can configure periodic saves:

```python
from kubeflow.trainer.rhai.transformers import PeriodicCheckpointConfig

PeriodicCheckpointConfig(
    save_strategy="steps",  # "steps" or "epoch"
    save_steps=20,          # save every 20 optimizer steps
    save_total_limit=2,     # keep only the 2 most recent checkpoints
)
```

> **Note:** Periodic checkpointing blocks GPU training during the save operation. Avoid checkpointing too frequently (e.g., every step) as this can significantly increase total training time and waste GPU cycles.
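To see why very frequent saves are costly, consider the fraction of wall-clock time spent blocked on saves. The timings below (1 second per optimizer step, 10 seconds per checkpoint) are made-up assumptions chosen only to illustrate the trade-off. Measure your own step and save times before choosing `save_steps`.

```python
def checkpoint_overhead(step_seconds, save_seconds, save_steps):
    """Fraction of wall-clock time spent in blocking periodic saves."""
    train_time = save_steps * step_seconds  # training time between two saves
    return save_seconds / (train_time + save_seconds)


# Hypothetical timings: 1 s per optimizer step, 10 s per checkpoint save
for save_steps in (1, 20, 200):
    frac = checkpoint_overhead(step_seconds=1.0, save_seconds=10.0, save_steps=save_steps)
    print(f"save_steps={save_steps:>3}: {frac:.0%} of time spent saving")
# save_steps=  1: 91% of time spent saving
# save_steps= 20: 33% of time spent saving
# save_steps=200: 5% of time spent saving
```

JIT checkpoints only cost time when a SIGTERM actually arrives. Pairing a longer `save_steps` interval with `enable_jit_checkpoint=True` is therefore one way to keep this overhead low while still limiting the work lost to preemption.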

In [None]:
# Configure periodic checkpointing - SDK injects these into TrainingArguments
checkpoint_config = PeriodicCheckpointConfig(
    save_strategy="steps",
    save_steps=20,
    save_total_limit=2,
)

# Create TransformersTrainer with JIT checkpointing enabled.
# output_dir="pvc://shared/checkpoints/jit-checkpointing" tells the SDK to:
# - Mount PVC "shared" at /mnt/kubeflow-checkpoints on all training pods
# - Set TrainingArguments.output_dir to /mnt/kubeflow-checkpoints/checkpoints/jit-checkpointing
# enable_jit_checkpoint=True saves training state on SIGTERM; on restart,
# training auto-resumes from the latest checkpoint in output_dir.
trainer = TransformersTrainer(
    func=train_func,
    num_nodes=2,
    resources_per_node={
        "nvidia.com/gpu": 1,
        "cpu": "4",
        "memory": "16Gi",
    },
    # Force Hugging Face offline mode in every training pod (no Hub access needed)
    env={
        "HF_HUB_OFFLINE": "1",
        "TRANSFORMERS_OFFLINE": "1",
    },
    # JIT Checkpointing: Save checkpoint on SIGTERM (preemption)
    enable_jit_checkpoint=True,
    # Periodic Checkpointing: Save checkpoint every save_steps
    periodic_checkpoint_config=checkpoint_config,
    # PVC path for checkpoints - SDK handles mounting automatically
    output_dir="pvc://shared/checkpoints/jit-checkpointing",
)

print("✅ TransformersTrainer configured with:")
print("   - JIT Checkpointing: ENABLED (saves on SIGTERM)")
print("   - Periodic Checkpointing: Every 20 steps, keep 2 most recent")
print("   - Auto-Resume: ENABLED (resumes from latest checkpoint)")

# Submit the training job
job_name = client.train(
    trainer=trainer,
    runtime=runtime,
)
print(f"\n✅ Training job submitted: {job_name}")
print(f"💾 Checkpoints will appear in workbench at: {NOTEBOOK_CHECKPOINTS_PATH}")

## Follow Job Logs

Stream the job logs to confirm that training is progressing as expected. Logs arrive in real time, so the cell keeps running until the training pod exits or you interrupt the kernel.

In [None]:
# Stream logs (interrupt the kernel to stop streaming and continue with other cells)
for logline in client.get_job_logs(job_name, follow=True):
    print(logline, end="")

## Get Job Status

Check the final status of the training job after completion.

In [None]:
# Check job status
job = client.get_job(job_name)
print("Final TrainJob Status:")
print(f"   Name: {job.name}")
print(f"   Status: {job.status}")
print(f"   Created: {job.creation_timestamp}")
print(f"   Nodes: {job.num_nodes}")
print(f"   Runtime: {job.runtime.name}")

if job.steps:
    print("   Steps:")
    for step in job.steps:
        print(f"     - {step.name}: {step.status}")
    print()

## Verify Checkpoints

After training completes (or after a preemption/restart), you can verify the checkpoints saved on the PVC.

### Checkpoint Structure

With the settings above, the checkpoint directory on the PVC (as seen from the workbench) looks like this:
```
/opt/app-root/src/shared/checkpoints/jit-checkpointing/
├── checkpoint-<step>/  # Intermediate checkpoints (every save_steps; at most save_total_limit are kept)
├── checkpoint-<N>/     # Checkpoint at the final step (N = last step number)
└── final/              # Final model and tokenizer, ready for inference
```

### JIT Checkpoint Behavior

If training was interrupted by SIGTERM:
- A checkpoint is saved at the last completed optimizer step
- Incomplete checkpoints (with sentinel files) are automatically cleaned up on resume
- Training resumes from the most recent valid checkpoint
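If you pick "the most recent" checkpoint yourself, order by step number rather than by name. A plain `sorted()` on directory names, as in the listing cell that follows, places `checkpoint-100` before `checkpoint-20`. A small sketch of a step-aware sort key (the helper name is illustrative):

```python
import re

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def checkpoint_step(name):
    """Return the step number for 'checkpoint-<step>', or -1 for anything else."""
    match = _CHECKPOINT_RE.match(name)
    return int(match.group(1)) if match else -1


names = ["checkpoint-100", "checkpoint-20", "final", "checkpoint-160"]
print(sorted(names))
# ['checkpoint-100', 'checkpoint-160', 'checkpoint-20', 'final']
print(sorted(names, key=checkpoint_step))
# ['final', 'checkpoint-20', 'checkpoint-100', 'checkpoint-160']
print(max(names, key=checkpoint_step))
# checkpoint-160
```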

In [None]:
# List checkpoints on PVC
import os

if os.path.exists(NOTEBOOK_CHECKPOINTS_PATH):
    print(f"üìÇ Checkpoints at {NOTEBOOK_CHECKPOINTS_PATH}:")
    for item in sorted(os.listdir(NOTEBOOK_CHECKPOINTS_PATH)):
        item_path = os.path.join(NOTEBOOK_CHECKPOINTS_PATH, item)
        if os.path.isdir(item_path):
            files = os.listdir(item_path)
            print(f"   üìÅ {item}/ ({len(files)} files)")
else:
    print(f"‚ö†Ô∏è Checkpoint directory not found: {NOTEBOOK_CHECKPOINTS_PATH}")
    print("   This is expected if training hasn't completed yet.")

## Test the Trained Model (Optional)

After training completes, you can load the fine-tuned model from the checkpoint saved on the shared PVC.
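The model was fine-tuned on text built by `format_instruction`, so inference prompts are most likely to work well when they follow the same `### Instruction:` / `### Input:` / `### Response:` template. Here is a minimal sketch of a matching prompt builder (the function name is hypothetical). It leaves the response empty for the model to complete and keeps the newline after `### Response:` so the prompt matches the training text exactly.

```python
def build_prompt(instruction, input_text=None):
    """Mirror the training template from format_instruction, minus the response."""
    if input_text:
        return (
            f"### Instruction:\n{instruction}\n\n"
            f"### Input:\n{input_text}\n\n"
            "### Response:\n"
        )
    return f"### Instruction:\n{instruction}\n\n### Response:\n"


print(build_prompt("Explain what machine learning is in one sentence."))
```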

In [None]:
def find_most_recent_checkpoint(output_dir):
    """Return the 'final' model directory if present, else the checkpoint with the highest step."""
    if not os.path.exists(output_dir):
        raise FileNotFoundError(f"Output directory not found: {output_dir}")

    # Prefer 'final' (written by rank 0 after training completes)
    final_path = os.path.join(output_dir, "final")
    if os.path.isdir(final_path):
        return final_path

    # Otherwise pick checkpoint-<step> with the largest step number.
    # Step numbers are more reliable than os.path.getctime, which on Linux
    # is the inode change time rather than the creation time.
    checkpoint_steps = {
        int(d.split("-", 1)[1]): os.path.join(output_dir, d)
        for d in os.listdir(output_dir)
        if d.startswith("checkpoint-")
        and d.split("-", 1)[1].isdigit()
        and os.path.isdir(os.path.join(output_dir, d))
    }

    if not checkpoint_steps:
        raise FileNotFoundError(f"No checkpoints found in {output_dir}")

    return checkpoint_steps[max(checkpoint_steps)]


print("✅ Checkpoint utility defined")

In [None]:
# Find and load the trained model
final_checkpoint = find_most_recent_checkpoint(NOTEBOOK_CHECKPOINTS_PATH)
print(f"üìÇ Loading checkpoint from: {final_checkpoint}")

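# Intermediate Trainer checkpoints (checkpoint-<step>) may not contain tokenizer
# files because no tokenizer is passed to Trainer, so fall back to the base model's
tokenizer_source = (
    final_checkpoint
    if os.path.exists(os.path.join(final_checkpoint, "tokenizer_config.json"))
    else NOTEBOOK_MODEL_PATH
)
trained_tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_source, trust_remote_code=True
)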
trained_model = AutoModelForCausalLM.from_pretrained(
    final_checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    trust_remote_code=True,
)

print("✅ Model loaded successfully")
print(f"📊 Model parameters: {trained_model.num_parameters():,}")

# Test the model
test_prompt = "### Instruction:\nExplain what machine learning is in one sentence.\n\n### Response:"

print("\nüìù Testing model with prompt:")
print(test_prompt)
print("\nü§ñ Model response:")

inputs = trained_tokenizer(test_prompt, return_tensors="pt").to(trained_model.device)

# Remove token_type_ids if present (not used by some models like Qwen)
if "token_type_ids" in inputs:
    del inputs["token_type_ids"]

with torch.no_grad():
    outputs = trained_model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
    )

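# Decode only the newly generated tokens, skipping the echoed prompt
prompt_length = inputs["input_ids"].shape[1]
response = trained_tokenizer.decode(
    outputs[0][prompt_length:], skip_special_tokens=True
)
print(response.strip())

print("\n✅ Model test completed!")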

## Cleanup

Delete the training job and free resources.

**Note:** To fully clean up, you may also want to delete the downloaded model, dataset, and checkpoints from the PVC:
```bash
rm -rf /opt/app-root/src/shared/models/qwen2.5-1.5b-instruct
rm -rf /opt/app-root/src/shared/data/alpaca_processed
rm -rf /opt/app-root/src/shared/checkpoints/jit-checkpointing
```

In [None]:
# Delete the training job
client.delete_job(name=job_name)
print(f"✅ Job {job_name} deleted")

In [None]:
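import gc

# Drop notebook references to loaded models/tokenizers (if they exist) so
# their memory can actually be reclaimed
for _name in ("trained_model", "trained_tokenizer", "model", "tokenizer"):
    globals().pop(_name, None)

gc.collect()

# Release cached CUDA memory held by PyTorch's allocator
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

print("✅ Resources freed, CUDA cache cleared")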

## Summary

Congratulations! You've successfully completed a distributed fine-tuning job with JIT checkpointing on OpenShift AI.

### What You Accomplished

| Step | Description |
|------|-------------|
| ✅ Model Download | Downloaded Qwen 2.5 1.5B Instruct to shared PVC |
| ✅ Dataset Preparation | Processed Stanford Alpaca dataset for instruction-tuning |
| ✅ Distributed Training | Ran 2-node distributed training with PyTorch DDP |
| ✅ JIT Checkpointing | Enabled automatic checkpoint saves on SIGTERM |
| ✅ Periodic Checkpointing | Configured regular checkpoint saves every 20 steps |
| ✅ Auto-Resume | Training can resume from latest checkpoint on restart |
| ✅ Model Testing | Loaded and tested the fine-tuned model |

### Key Takeaways

1. **JIT Checkpointing** makes training preemption-safe:
   - Enable with `enable_jit_checkpoint=True`
   - Automatically saves state when pod receives SIGTERM
   - Uses CUDA streams for async checkpoint saves

2. **Periodic Checkpointing** provides regular saves:
   - Configure with `PeriodicCheckpointConfig`
   - Control save frequency (`save_strategy`, `save_steps`)
   - Limit disk usage (`save_total_limit`)

3. **Auto-Resume** minimizes lost training progress:
   - Training automatically resumes from latest valid checkpoint
   - Incomplete checkpoints are detected and cleaned up
   - No manual intervention required

4. **PVC Storage** ensures durability:
   - Use `output_dir="pvc://<pvc-name>/..."` for automatic mounting
   - Checkpoints persist across pod restarts
   - All nodes can access the same checkpoint storage

### When to Use JIT Checkpointing

| Scenario | Recommendation |
|----------|----------------|
| Spot/Preemptible instances | **Enable** - Instances can be reclaimed anytime |
| Kueue-managed workloads | **Enable** - Higher-priority jobs may preempt |
| Long-running training | **Enable** - Protect against interruptions |
| Short training runs | Optional - May add small overhead |

### TransformersTrainer Checkpointing Reference

| Parameter | Description | Default |
|-----------|-------------|----------|
| `enable_jit_checkpoint` | Save checkpoint on SIGTERM | `True` |
| `periodic_checkpoint_config` | Configure periodic saves | `None` |
| `output_dir` | PVC path for checkpoints (`pvc://...`) | Required |

### Resources

- [Kubeflow Trainer Documentation](https://www.kubeflow.org/docs/components/trainer/)
- [HuggingFace Transformers](https://huggingface.co/docs/transformers/)
- [Stanford Alpaca Dataset](https://huggingface.co/datasets/tatsu-lab/alpaca)