(hpu_bert_training)=
# BERT Model Training with Intel Gaudi

<a id="try-anyscale-quickstart-intel_gaudi-bert" href="https://console.anyscale.com/register/ha?render_flow=ray&utm_source=ray_docs&utm_medium=docs&utm_campaign=intel_gaudi-bert">
    <img src="../../../_static/img/run-on-anyscale.svg" alt="try-anyscale-quickstart">
</a>
<br></br>

In this notebook, we will train a BERT model for sequence classification using the Yelp review full dataset. We will use the `transformers` and `datasets` libraries from Hugging Face, along with `ray.train` for distributed training.

[Intel Gaudi AI Processors (HPUs)](https://habana.ai) are AI hardware accelerators designed by Intel Habana Labs. For more information, see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/index.html) and [Gaudi Developer Docs](https://developer.habana.ai/).

## Configuration

This example requires a node with Gaudi or Gaudi2 installed. Both Gaudi and Gaudi2 nodes have 8 HPUs. We will use 2 workers to train the model, each using 1 HPU.

We recommend using a prebuilt container to run these examples. To run a container, you need Docker. See [Install Docker Engine](https://docs.docker.com/engine/install/) for installation instructions.

Then follow [Run Using Containers](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html?highlight=installer#run-using-containers) to install the Gaudi drivers and container runtime.

Once the drivers and runtime are installed, pull and start the Gaudi container:
```bash
docker pull vault.habana.ai/gaudi-docker/1.22.1/ubuntu24.04/habanalabs/pytorch-installer-2.7.1:latest
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.22.1/ubuntu24.04/habanalabs/pytorch-installer-2.7.1:latest
```

Inside the container, install the following dependencies to run this notebook.
```bash
pip install "ray[train]" notebook transformers datasets evaluate scikit-learn
```

In [None]:
# Import necessary libraries

import os
from typing import Dict

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import numpy as np
import evaluate
from datasets import load_dataset
import transformers
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.torch import TorchConfig
from ray.runtime_env import RuntimeEnv

import habana_frameworks.torch.core as htcore

## Metrics Setup

We will use accuracy as our evaluation metric. The `compute_metrics` function will calculate the accuracy of our model's predictions.

In [None]:
# Metrics
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
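
To make the metric concrete, the following is a minimal, dependency-free sketch of the computation that `compute_metrics` performs: take the argmax of each row of logits and compare it with the reference label. The `argmax` and `accuracy` helpers and the sample values are illustrative assumptions, not part of the `evaluate` library.

```python
from typing import Dict, List


def argmax(row: List[float]) -> int:
    """Return the index of the largest value in a row of logits."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(logits: List[List[float]], labels: List[int]) -> Dict[str, float]:
    """Hypothetical stand-in for the accuracy metric: fraction of correct predictions."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Three made-up examples with five classes each, matching num_labels=5.
logits = [
    [0.1, 2.0, 0.3, 0.0, 0.0],  # predicted class 1
    [1.5, 0.2, 0.1, 0.0, 0.0],  # predicted class 0
    [0.0, 0.1, 0.2, 0.3, 3.0],  # predicted class 4
]
labels = [1, 0, 2]

result = accuracy(logits, labels)
print(result)
```

Two of the three predictions match their labels, so the reported accuracy is two thirds.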

## Training Function

This function is executed by each worker during training. It handles data loading, tokenization, model initialization, and the training loop. The code is the same as a training function written for GPU; no changes are needed to port it to HPU. Internally, Ray Train does the following:

* Detects the HPU and sets the device.

* Initializes the Habana PyTorch backend.

* Initializes the Habana distributed backend.

In [None]:
def train_func_per_worker(config: Dict):
    
    # Datasets
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    train_dataset = dataset["train"].select(range(1000)).map(tokenize_function, batched=True)
    eval_dataset = dataset["test"].select(range(1000)).map(tokenize_function, batched=True)

    # Prepare dataloader for each worker
    dataloaders = {}
    dataloaders["train"] = torch.utils.data.DataLoader(
        train_dataset, 
        shuffle=True, 
        collate_fn=transformers.default_data_collator, 
        batch_size=batch_size
    )
    dataloaders["test"] = torch.utils.data.DataLoader(
        eval_dataset, 
        shuffle=False,
        collate_fn=transformers.default_data_collator, 
        batch_size=batch_size
    )

    # Obtain HPU device automatically
    device = ray.train.torch.get_device()

    # Prepare model and optimizer
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    model = model.to(device)
    
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Start training loops
    for epoch in range(epochs):
        # Each epoch has a training and validation phase
        for phase in ["train", "test"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            for batch in dataloaders[phase]:
                batch = {k: v.to(device) for k, v in batch.items()}

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    
                    outputs = model(**batch)
                    loss = outputs.loss

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                        print(f"train epoch:[{epoch}]\tloss:{loss:.6f}")
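
The loop above prints the loss of every training batch and discards the loss computed in the `"test"` phase. One possible refinement, shown here only as a sketch, is to accumulate a weighted mean per phase and print it once per epoch. The `RunningMean` class is a hypothetical helper and the loss values are made up for illustration; in the training loop you would pass `loss.item()` and the batch size instead.

```python
class RunningMean:
    """Hypothetical helper: weighted running mean of per-batch values."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        # Weight each batch by its size so a smaller final batch counts less.
        self.total += value * n
        self.count += n

    @property
    def mean(self) -> float:
        return self.total / self.count if self.count else float("nan")


meter = RunningMean()
# Made-up (batch_loss, batch_size) pairs standing in for one epoch.
for batch_loss, batch_size in [(2.0, 4), (1.0, 4), (1.5, 2)]:
    meter.update(batch_loss, batch_size)

print(f"mean loss: {meter.mean:.4f}")
```

Creating one meter per phase at the start of each epoch would yield a single training line and a single evaluation line per epoch instead of one line per batch.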

## Main Training Function

The `train_bert` function sets up the distributed training environment using Ray and starts the training process. To enable training using HPU, we only need to make the following changes:

* Require an HPU for each worker in `ScalingConfig`
* Set backend to "hccl" in `TorchConfig`

In [None]:
def train_bert(num_workers=2):
    global_batch_size = 8

    train_config = {
        "lr": 1e-3,
        "epochs": 10,
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    # In ScalingConfig, require an HPU for each worker
    scaling_config = ScalingConfig(
        num_workers=num_workers, resources_per_worker={"CPU": 1, "HPU": 1}
    )
    # Set backend to hccl in TorchConfig
    torch_config = TorchConfig(backend="hccl")

    # Start the Ray cluster
    # Work around https://github.com/ray-project/ray/issues/45302 by explicitly setting the HPU resource
    ray.init(resources={"HPU": 8})
    
    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        torch_config=torch_config,
        scaling_config=scaling_config,
    )

    result = trainer.fit()
    print(f"Training result: {result}")

## Start Training

Finally, we call the `train_bert` function to start the training process. You can adjust the number of workers to use.

In [None]:
train_bert(num_workers=2)
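
When changing `num_workers`, keep in mind that `train_bert` derives the per-worker batch size by integer division of the global batch size, and that each worker requests one of the 8 HPUs registered in `ray.init`. The sketch below shows the effect of that division; `per_worker_batch_size` is a hypothetical helper that mirrors the `global_batch_size // num_workers` expression and adds input checks that the notebook itself does not perform.

```python
def per_worker_batch_size(global_batch_size: int, num_workers: int) -> int:
    """Hypothetical helper mirroring `global_batch_size // num_workers`."""
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    size = global_batch_size // num_workers
    if size == 0:
        raise ValueError("global batch size is smaller than the number of workers")
    return size


global_batch_size = 8  # same value as in train_bert
for workers in (1, 2, 3, 8):
    size = per_worker_batch_size(global_batch_size, workers)
    # The effective global batch shrinks when the division is not exact.
    print(f"workers={workers} per_worker={size} effective_global={size * workers}")
```

With 3 workers, each worker gets a batch of 2 and the effective global batch is 6 rather than 8, so worker counts that divide the global batch size evenly are the simplest choice.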

## Possible outputs

``` text
2025-11-17 22:15:24,256	INFO worker.py:2012 -- Started a local Ray instance.
/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py:2051: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
  warnings.warn(
(TrainController pid=87725) Calling add_step_closure function does not have any effect. It's lazy mode only functionality. (warning logged once)
(TrainController pid=87725) Calling mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
(TrainController pid=87725) Calling iter_mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
(TrainController pid=87725) Attempting to start training worker group of size 2 with the following resources: [{'CPU': 1, 'HPU': 1}] * 2
(RayTrainWorker pid=88179) Calling add_step_closure function does not have any effect. It's lazy mode only functionality. (warning logged once)
(RayTrainWorker pid=88179) Calling mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
(RayTrainWorker pid=88179) Calling iter_mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
(RayTrainWorker pid=88178) Calling add_step_closure function does not have any effect. It's lazy mode only functionality. (warning logged once)
(RayTrainWorker pid=88178) Calling mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
(RayTrainWorker pid=88178) Calling iter_mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
(RayTrainWorker pid=88178) Setting up process group for: env:// [rank=0, world_size=2]
(TrainController pid=87725) Started training worker group of size 2: 
(TrainController pid=87725) - (ip=100.83.67.100, pid=88178) world_rank=0, local_rank=0, node_rank=0
(TrainController pid=87725) - (ip=100.83.67.100, pid=88179) world_rank=1, local_rank=1, node_rank=0
(RayTrainWorker pid=88179)               0 COPY_FREE_VARS           1
(RayTrainWorker pid=88179) 
(RayTrainWorker pid=88179)   7           2 RESUME                   0
(RayTrainWorker pid=88179) 
(RayTrainWorker pid=88179)   8           4 PUSH_NULL
(RayTrainWorker pid=88179)               6 LOAD_DEREF               1 (tokenizer)
(RayTrainWorker pid=88179)               8 LOAD_FAST                0 (examples)
(RayTrainWorker pid=88179)              10 LOAD_CONST               1 ('text')
(RayTrainWorker pid=88179)              12 BINARY_SUBSCR
(RayTrainWorker pid=88179)              16 LOAD_CONST               2 ('max_length')
(RayTrainWorker pid=88179)              18 LOAD_CONST               3 (True)
(RayTrainWorker pid=88179)              20 KW_NAMES                 4 (('padding', 'truncation'))
(RayTrainWorker pid=88179)              22 CALL                     3
(RayTrainWorker pid=88179)              30 RETURN_VALUE
(RayTrainWorker pid=88179)               0 COPY_FREE_VARS           1
(RayTrainWorker pid=88179) 
(RayTrainWorker pid=88179)   7           2 RESUME                   0
(RayTrainWorker pid=88179) 
(RayTrainWorker pid=88179)   8           4 PUSH_NULL
(RayTrainWorker pid=88179)               6 LOAD_DEREF               1 (tokenizer)
(RayTrainWorker pid=88179)               8 LOAD_FAST                0 (examples)
(RayTrainWorker pid=88179)              10 LOAD_CONST               1 ('text')
(RayTrainWorker pid=88179)              12 BINARY_SUBSCR
(RayTrainWorker pid=88179)              16 LOAD_CONST               2 ('max_length')
(RayTrainWorker pid=88179)              18 LOAD_CONST               3 (True)
(RayTrainWorker pid=88179)              20 KW_NAMES                 4 (('padding', 'truncation'))
(RayTrainWorker pid=88179)              22 CALL                     3
(RayTrainWorker pid=88179)              30 RETURN_VALUE
Map:   0%|          | 0/1000 [00:00<?, ? examples/s]
(RayTrainWorker pid=88178) 
(RayTrainWorker pid=88178) 
(RayTrainWorker pid=88178) 
(RayTrainWorker pid=88178) 
Map: 100%|██████████| 1000/1000 [00:00<00:00, 4032.52 examples/s]
(RayTrainWorker pid=88179) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
(RayTrainWorker pid=88179) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(pid=gcs_server) [2025-11-17 22:15:52,392 E 77855 77855] (gcs_server) gcs_server.cc:302: Failed to establish connection to the event+metrics exporter agent. Events and metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(RayTrainWorker pid=88178) ============================= HPU PT BRIDGE CONFIGURATION ON RANK = 0 ============= 
(RayTrainWorker pid=88178)  PT_HPU_LAZY_MODE = 0
(RayTrainWorker pid=88178)  PT_HPU_RECIPE_CACHE_CONFIG = ,false,1024,false
(RayTrainWorker pid=88178)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(RayTrainWorker pid=88178)  PT_HPU_LAZY_ACC_PAR_MODE = 1
(RayTrainWorker pid=88178)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(RayTrainWorker pid=88178)  PT_HPU_EAGER_PIPELINE_ENABLE = 1
(RayTrainWorker pid=88178)  PT_HPU_EAGER_COLLECTIVE_PIPELINE_ENABLE = 1
(RayTrainWorker pid=88178)  PT_HPU_ENABLE_LAZY_COLLECTIVES = 0
(RayTrainWorker pid=88178) ---------------------------: System Configuration :---------------------------
(RayTrainWorker pid=88178) Num CPU Cores : 160
(RayTrainWorker pid=88178) CPU RAM       : 1007 GB
(RayTrainWorker pid=88178) ------------------------------------------------------------------------------
(RayTrainWorker pid=88179) train epoch:[0]	loss:2.253497
(RayTrainWorker pid=88179) train epoch:[0]	loss:1.718906
(raylet) [2025-11-17 22:15:54,166 E 78148 78148] (raylet) main.cc:975: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(RayTrainWorker pid=88178)               0 COPY_FREE_VARS           1 [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=88178)   7           2 RESUME                   0 [repeated 2x across cluster]
(RayTrainWorker pid=88178)   8           4 PUSH_NULL [repeated 2x across cluster]
(RayTrainWorker pid=88178)               6 LOAD_DEREF               1 (tokenizer) [repeated 2x across cluster]
(RayTrainWorker pid=88178)               8 LOAD_FAST                0 (examples) [repeated 2x across cluster]
(RayTrainWorker pid=88178)              10 LOAD_CONST               1 ('text') [repeated 2x across cluster]
(RayTrainWorker pid=88178)              12 BINARY_SUBSCR [repeated 2x across cluster]
(RayTrainWorker pid=88178)              16 LOAD_CONST               2 ('max_length') [repeated 2x across cluster]
(RayTrainWorker pid=88178)              18 LOAD_CONST               3 (True) [repeated 2x across cluster]
(RayTrainWorker pid=88178)              20 KW_NAMES                 4 (('padding', 'truncation')) [repeated 2x across cluster]
(RayTrainWorker pid=88178)              22 CALL                     3 [repeated 2x across cluster]
(RayTrainWorker pid=88178)              30 RETURN_VALUE [repeated 2x across cluster]
Map:   0%|          | 0/1000 [00:00<?, ? examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 4351.62 examples/s]
(RayTrainWorker pid=88178) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
(RayTrainWorker pid=88178) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(pid=78279) [2025-11-17 22:15:58,103 E 78279 78722] core_worker_process.cc:825: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(RayTrainWorker pid=88179) train epoch:[0]	loss:1.546598 [repeated 146x across cluster]
[2025-11-17 22:15:59,457 E 77711 78275] core_worker_process.cc:825: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(TrainController pid=87725) [2025-11-17 22:16:02,243 E 87725 87764] core_worker_process.cc:825: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14 [repeated 157x across cluster]
(RayTrainWorker pid=88179) train epoch:[0]	loss:1.599180 [repeated 151x across cluster]
(bundle_reservation_check_func pid=87994) [2025-11-17 22:16:08,707 E 87994 88113] core_worker_process.cc:825: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
(RayTrainWorker pid=88179) train epoch:[0]	loss:1.554767 [repeated 145x across cluster]
(RayTrainWorker pid=88178) [2025-11-17 22:16:11,577 E 88178 88342] core_worker_process.cc:825: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14 [repeated 3x across cluster]
(RayTrainWorker pid=88179) train epoch:[1]	loss:1.579086 [repeated 58x across cluster]
(RayTrainWorker pid=88179) train epoch:[1]	loss:1.228644 [repeated 136x across cluster]
(RayTrainWorker pid=88179) train epoch:[1]	loss:1.458077 [repeated 158x across cluster]
(RayTrainWorker pid=88178) train epoch:[1]	loss:1.093389 [repeated 155x across cluster]
(RayTrainWorker pid=88179) train epoch:[2]	loss:1.557385 [repeated 75x across cluster]
(RayTrainWorker pid=88179) train epoch:[2]	loss:1.406001 [repeated 111x across cluster]
(RayTrainWorker pid=88179) train epoch:[2]	loss:1.225072 [repeated 146x across cluster]
(RayTrainWorker pid=88178) train epoch:[2]	loss:1.274225 [repeated 147x across cluster]
(RayTrainWorker pid=88179) train epoch:[3]	loss:1.091613 [repeated 93x across cluster]
(RayTrainWorker pid=88179) train epoch:[3]	loss:1.352153 [repeated 100x across cluster]
(RayTrainWorker pid=88179) train epoch:[3]	loss:1.706026 [repeated 156x across cluster]
(RayTrainWorker pid=88178) train epoch:[3]	loss:2.390724 [repeated 149x across cluster]
(RayTrainWorker pid=88179) train epoch:[4]	loss:0.666302 [repeated 106x across cluster]
(RayTrainWorker pid=88179) train epoch:[4]	loss:1.550435 [repeated 100x across cluster]
(RayTrainWorker pid=88179) train epoch:[4]	loss:0.650674 [repeated 162x across cluster]
(RayTrainWorker pid=88178) train epoch:[4]	loss:2.413051 [repeated 116x across cluster]
(RayTrainWorker pid=88179) train epoch:[5]	loss:1.099013 [repeated 140x across cluster]
(RayTrainWorker pid=88179) train epoch:[5]	loss:1.775257 [repeated 103x across cluster]
(RayTrainWorker pid=88179) train epoch:[5]	loss:0.800103 [repeated 159x across cluster]
(RayTrainWorker pid=88179) train epoch:[6]	loss:0.799364 [repeated 103x across cluster]
(RayTrainWorker pid=88179) train epoch:[6]	loss:0.831640 [repeated 160x across cluster]
(RayTrainWorker pid=88179) train epoch:[6]	loss:0.791100 [repeated 101x across cluster]
(RayTrainWorker pid=88178) train epoch:[6]	loss:1.881995 [repeated 152x across cluster]
(RayTrainWorker pid=88179) train epoch:[7]	loss:0.614887 [repeated 102x across cluster]
(RayTrainWorker pid=88179) train epoch:[7]	loss:0.568946 [repeated 145x across cluster]
(RayTrainWorker pid=88179) train epoch:[7]	loss:1.241300 [repeated 108x across cluster]
(RayTrainWorker pid=88178) train epoch:[7]	loss:1.521225 [repeated 126x across cluster]
(RayTrainWorker pid=88179) train epoch:[8]	loss:1.929299 [repeated 110x across cluster]
(RayTrainWorker pid=88179) train epoch:[8]	loss:0.744064 [repeated 162x across cluster]
(RayTrainWorker pid=88179) train epoch:[8]	loss:0.533718 [repeated 114x across cluster]
(RayTrainWorker pid=88178) train epoch:[8]	loss:1.589755 [repeated 118x across cluster]
(RayTrainWorker pid=88179) train epoch:[9]	loss:1.681320 [repeated 123x across cluster]
(RayTrainWorker pid=88179) train epoch:[9]	loss:0.482950 [repeated 151x across cluster]
(RayTrainWorker pid=88179) train epoch:[9]	loss:0.122691 [repeated 125x across cluster]
(RayTrainWorker pid=88178) train epoch:[9]	loss:1.464374 [repeated 101x across cluster]
(RayTrainWorker pid=88178) train epoch:[9]	loss:1.761037 [repeated 80x across cluster]
(RayTrainWorker pid=88178) train epoch:[9]	loss:1.560819 [repeated 77x across cluster]
Training result: Result(metrics=None, checkpoint=None, error=None, path='/root/ray_results/ray_train_run-2025-11-17_22-15-29', metrics_dataframe=None, best_checkpoints=[], _storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7f792c895330>)
```
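
Many lines in this log end with `[repeated Nx across cluster]` because Ray deduplicates identical worker messages by default. As the log itself notes, setting `RAY_DEDUP_LOGS=0` disables this. A minimal sketch, assuming the variable is set in the driver process before Ray is started (setting it before importing `ray` is the safest choice):

```python
import os

# Must be set before Ray starts so that the setting takes effect for this session.
os.environ["RAY_DEDUP_LOGS"] = "0"

# Assumption: `train_bert` is the function defined earlier in this notebook.
# train_bert(num_workers=2)
```

Expect considerably more output with deduplication disabled, since every per-batch loss line from every worker is printed.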