# Fine-Tune Llama 3.2-3B with MLX Distributed and LoRA

This notebook shows how to fine-tune Llama 3.2-3B on multiple GPUs with MLX Distributed and `mlx-lm`.

The MLX runtime uses the `mlx[cuda]` package to run distributed training on GPUs.

MLX Distributed: https://ml-explore.github.io/mlx/build/html/usage/distributed.html

MLX LM: https://github.com/ml-explore/mlx-lm

## Install the Kubeflow SDK

You need to install the Kubeflow SDK to interact with Kubeflow Trainer APIs:

In [None]:
# !pip install -U kubeflow

## Create Script to Fine-Tune Llama 3.2

We will use the `mlx-lm` library to fine-tune Llama 3.2.

`mlx-lm` is a Python package for generating text and fine-tuning LLMs with MLX.

We will perform LoRA (Low-Rank Adaptation) fine-tuning to reduce the number of trainable parameters and the GPU memory needed for training.
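
To make the parameter saving concrete, here is a minimal sketch of the arithmetic. It assumes a single square projection of width 3072 (an illustrative size, not read from the model config) and the rank of 8 used in this notebook. LoRA keeps the original weight frozen and trains two low-rank factors instead. The helper name `lora_param_counts` is hypothetical.

```python
def lora_param_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one linear layer."""
    full = d_in * d_out            # training the dense weight directly
    lora = rank * (d_in + d_out)   # factors A (d_in x rank) and B (rank x d_out)
    return full, lora


# Illustrative layer size (assumption) with the rank used in this notebook.
full, lora = lora_param_counts(d_in=3072, d_out=3072, rank=8)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
```

For a layer of this size, the adapter trains roughly 0.5% of the parameters that full fine-tuning of the same layer would.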

In [1]:
def fine_tune_llama(hf_token: str, num_samples: str, batch_size: str):
    import types
    import os
    import mlx.core as mx
    from mlx_lm.lora import train_model, CONFIG_DEFAULTS
    from mlx_lm.tuner.datasets import load_dataset
    from mlx_lm.utils import load
    from mlx_lm.generate import generate

    os.environ["HF_TOKEN"] = hf_token

    # Set parameters for the mlx-lm.
    args = types.SimpleNamespace()
    args.model = "meta-llama/Llama-3.2-3B-Instruct"
    args.data = "mlx-community/WikiSQL"
    args.train = True
    # Configure LoRA settings to reduce number of trainable params.
    args.lora_parameters = {
        "rank": 8,
        "dropout": 0.05,
        "scale": 20.0,
    }

    # Note: `num_samples` sets the number of training iterations, not dataset samples.
    args.iters = int(num_samples)
    args.batch_size = int(batch_size)

    # Set defaults for other required parameters
    for k, v in CONFIG_DEFAULTS.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    model, tokenizer = load(args.model)
    train_set, valid_set, test_set = load_dataset(args, tokenizer)

    # Start the Llama distributed fine-tuning.
    train_model(args, model, train_set, valid_set)

    # Evaluate the fine-tuned adapter.
    dist = mx.distributed.init(strict=True, backend="mpi")
    if dist.rank() == 0:
        print("=" * 100)
        print(f"Training is complete. Adapters saved to: {args.adapter_path}")
        print("Evaluate the fine-tuned LoRA adapter for Llama 3.2")

        finetuned_model, finetuned_tokenizer = load(
            args.model, adapter_path=args.adapter_path
        )

        # Generate response using the fine-tuned adapter.
        sample_prompt = "What is SQL?"

        print(f"Prompt: {sample_prompt}")
        print("Response:")

        response = generate(
            model=finetuned_model,
            tokenizer=finetuned_tokenizer,
            prompt=sample_prompt,
            max_tokens=1000,
            verbose=False,
        )

        print(response)
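
The default-filling loop in `fine_tune_llama` can be tried in isolation. The sketch below uses a hypothetical `DEFAULTS` dictionary as a stand-in for `CONFIG_DEFAULTS`; the real keys and values may differ. It shows that explicitly set attributes are kept while missing ones are filled in, which is presumably why `args.adapter_path` is available later even though the function never sets it.

```python
import types

# Hypothetical stand-in for mlx_lm.lora.CONFIG_DEFAULTS (assumed keys and values).
DEFAULTS = {"iters": 1000, "batch_size": 4, "adapter_path": "adapters"}

args = types.SimpleNamespace()
args.iters = 100      # set explicitly, so the default must not overwrite it
args.batch_size = 8   # set explicitly as well

# Same pattern as in fine_tune_llama: only fill attributes that are missing.
for key, value in DEFAULTS.items():
    if not hasattr(args, key):
        setattr(args, key, value)

print(args)
```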

## Get the MLX Runtime

You can list the available Kubeflow Trainer runtimes with the `list_runtimes()` API.

The name of the MLX runtime is `mlx-distributed`.

In [2]:
from kubeflow.trainer import TrainerClient, CustomTrainer

for r in TrainerClient().list_runtimes():
    if r.name == "mlx-distributed":
        print(f"Name: {r.name}, Framework: {r.trainer.framework}, Trainer Type: {r.trainer.trainer_type.value}\n")
        mlx_runtime = r

Name: mlx-distributed, Framework: mlx, Trainer Type: CustomTrainer



## Get the Runtime Packages

You can see the available Python packages and GPUs with the `get_runtime_packages()` API.

The output also lists the GPUs, with their CUDA driver, that are available on a single training node.

In [7]:
TrainerClient().get_runtime_packages(mlx_runtime)

Python: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0]
Package                Version
---------------------- -----------
aiohappyeyeballs       2.6.1
aiohttp                3.13.1
aiosignal              1.4.0
async-timeout          5.0.1
attrs                  25.4.0
certifi                2025.10.5
charset-normalizer     3.4.4
datasets               4.0.0
dill                   0.3.8
filelock               3.20.0
frozenlist             1.8.0
fsspec                 2025.3.0
hf-xet                 1.1.10
huggingface-hub        0.35.3
idna                   3.11
Jinja2                 3.1.6
MarkupSafe             3.0.3
mlx                    0.28.0
mlx-cuda               0.28.0
mlx-data               0.1.0
mlx-lm                 0.26.3
multidict              6.7.0
multiprocess           0.70.16
numpy                  2.2.6
nvidia-cublas-cu12     12.9.1.4
nvidia-cuda-nvrtc-cu12 12.9.86
nvidia-cudnn-cu12      9.14.0.64
packaging              25.0
pandas                 2.3.3
pip       

## Create TrainJob with MLX Distributed

Use the `train()` API to create a distributed TrainJob on **4 GPUs**. Each MPI training node uses 1 GPU.

**Note:** Uncomment the line below and set `HF_TOKEN` to your Hugging Face access token before creating the TrainJob.

In [8]:
# HF_TOKEN = "hf_add_your_token"

In [9]:
job_id = TrainerClient().train(
    trainer=CustomTrainer(
        func=fine_tune_llama,
        func_args={
            "hf_token": HF_TOKEN,
            "num_samples": "100",
            # Batch size must be divisible by the number of GPUs: 8 / 4 = 2 samples per training node.
            "batch_size": "8",
        },
        num_nodes=4,  # Fine-Tune Llama3.2 on 4 GPUs.
        resources_per_node={
            "gpu": 1
        },
    ),
    runtime=mlx_runtime,
)
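
The divisibility rule mentioned in the comment above can be checked before submitting the job. This is a minimal sketch with a hypothetical helper, `per_node_batch_size`; it is not part of the Kubeflow SDK or `mlx-lm`, and it only mirrors the arithmetic from the comment.

```python
def per_node_batch_size(global_batch_size: int, num_nodes: int) -> int:
    """Split a global batch evenly across nodes, or fail if it cannot be split."""
    if num_nodes <= 0:
        raise ValueError("num_nodes must be positive")
    if global_batch_size % num_nodes != 0:
        raise ValueError(
            f"batch_size={global_batch_size} is not divisible by num_nodes={num_nodes}"
        )
    return global_batch_size // num_nodes


# Values used in this notebook: batch size 8 on 4 single-GPU nodes.
print(per_node_batch_size(8, 4))
```

Running a check like this locally surfaces a bad combination, such as a batch size of 6 on 4 nodes, before any GPUs are allocated.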

In [10]:
# Train API generates a random TrainJob id.
job_id

'ja1595e2b2a3'

## Check the TrainJob Info

Use the `list_jobs()` and `get_job()` APIs to get information about the created TrainJob and its steps.

In [11]:
for job in TrainerClient().list_jobs():
    print(f"TrainJob: {job.name}, Status: {job.status}, Created at: {job.creation_timestamp}")

TrainJob: ja1595e2b2a3, Status: Created, Created at: 2025-10-18 00:32:46+00:00


In [13]:
# The mpirun command is executed on node-0, which acts as the MPI launcher node.
for c in TrainerClient().get_job(name=job_id).steps:
    print(f"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\n")

Step: node-0, Status: Running, Devices: gpu x 1

Step: node-1, Status: Running, Devices: gpu x 1

Step: node-2, Status: Running, Devices: gpu x 1

Step: node-3, Status: Running, Devices: gpu x 1
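
For jobs with many nodes, a compact summary can be easier to read than one line per step. The sketch below uses stand-in objects that mimic only the fields printed above (`name`, `status`, `device`, `device_count`); the real step type returned by `get_job()` may carry more fields or different types, and `summarize_steps` is a hypothetical helper.

```python
from collections import Counter
from types import SimpleNamespace


def summarize_steps(steps):
    """Return (count of steps per status, total number of devices)."""
    by_status = Counter(step.status for step in steps)
    # int() accepts both numeric and string device counts.
    total_devices = sum(int(step.device_count) for step in steps)
    return dict(by_status), total_devices


# Stand-in data shaped like the output above (assumption, not real API objects).
steps = [
    SimpleNamespace(name=f"node-{i}", status="Running", device="gpu", device_count="1")
    for i in range(4)
]

summary, gpus = summarize_steps(steps)
print(f"Steps by status: {summary}, total GPUs: {gpus}")
```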



## Get the TrainJob Logs

Use the `get_job_logs()` API to retrieve the TrainJob logs.

The fine-tuning runs on 4 training nodes.

In [14]:
for logline in TrainerClient().get_job_logs(job_id, follow=True):
    print(logline)

Fetching 11 files: 100%|██████████| 11/11 [00:08<00:00,  1.34it/s]
Fetching 11 files: 100%|██████████| 11/11 [00:08<00:00,  1.30it/s]
Fetching 11 files: 100%|██████████| 11/11 [00:08<00:00,  1.25it/s]
Fetching 11 files: 100%|██████████| 11/11 [00:08<00:00,  1.23it/s]
Loading Hugging Face dataset mlx-community/WikiSQL.
Loading Hugging Face dataset mlx-community/WikiSQL.
Loading Hugging Face dataset mlx-community/WikiSQL.
Loading Hugging Face dataset mlx-community/WikiSQL.
Generating train split: 100%|██████████| 1000/1000 [00:00<00:00, 285307.39 examples/s]
Generating valid split: 100%|██████████| 100/100 [00:00<00:00, 62110.23 examples/s]
Generating test split: 100%|██████████| 100/100 [00:00<00:00, 90336.08 examples/s]
Generating train split: 100%|██████████| 1000/1000 [00:00<00:00, 304752.16 examples/s]
Generating valid split: 100%|██████████| 100/100 [00:00<00:00, 81049.35 examples/s]
Generating test split: 100%|██████████| 100/100 [00:00<00:00, 66798.92 examples/s]
Trainable parame

## Delete the TrainJob

When the TrainJob is finished, you can delete the resource.

In [None]:
TrainerClient().delete_job(job_id)