# Distributed training with Ray Train, PyTorch and Hugging Face
© 2025, Anyscale. All Rights Reserved

💻 **Launch Locally**: You can run this notebook locally.

🚀 **Launch on Cloud**: You can also run this notebook on a Ray cluster (click [here](http://console.anyscale.com/register) to start a Ray cluster on Anyscale).


This notebook demonstrates distributed training of a BERT model for sequence classification using Ray Train, PyTorch, and Hugging Face libraries. The goal is to classify Yelp reviews into five rating classes. Ray Train runs the same training function on several worker processes, so the workload can be spread across multiple CPUs or GPUs.

The notebook starts by importing all the necessary libraries, including PyTorch for deep learning, Hugging Face Transformers for model and tokenizer utilities, and Ray Train for distributed training. It then sets up the evaluation metric (accuracy) and defines a function to compute this metric during model evaluation.

A key part of the notebook is the training function, which is executed by each worker in the distributed setup. This function handles loading the Yelp review dataset, tokenizing the text data, preparing data loaders for batching, and setting up the BERT model for training. The function is designed to automatically use the best available hardware, whether that's a CPU, GPU, or Apple Silicon's MPS.

The main training function, `train_bert`, configures the distributed environment using Ray, sets up the training parameters, and launches the training process across multiple workers. This approach allows you to scale up your training easily, making it suitable for both local machines and cloud platforms. After training, Ray is properly shut down to free up resources.

Overall, this notebook provides a practical introduction to distributed deep learning with modern Python tools, making it easier for machine learning engineers to train large models on big datasets efficiently.

### Outline
<div class="alert alert-block alert-info">
<ol>
    <li>Architecture Diagram
    <li>Library Imports
        <ul>
            <li>Importing PyTorch, Hugging Face Transformers, Ray Train, and other dependencies
        </ul>
    <li>Metrics Setup
        <ul>
            <li>Defining accuracy as the evaluation metric
            <li>Function to compute metrics during evaluation
        </ul>
    <li>Training Function Per Worker
        <ul>
            <li>Data loading and preprocessing (tokenization)
            <li>Preparing data loaders for batching
            <li>Model initialization (BERT for sequence classification)
            <li>Device selection (CPU, GPU, or MPS)
            <li>Training and evaluation loop
        </ul>
    <li>Main Training Function
        <ul>
            <li>Setting up distributed training configuration with Ray
            <li>Scaling configuration for CPUs/GPUs
            <li>Initializing and running the Ray TorchTrainer
        </ul>
    <li>Running the Training
        <ul>
            <li>Executing the main training function with a specified number of workers
        </ul>
    <li>Shutdown Ray Cluster
    <li>Summary
</ol>
</div>


## 1. Architecture

![Architecture Diagram](https://lz-public-demo.s3.us-east-1.amazonaws.com/anyscale101/01_examples/04_Ray_Train_architecture.svg?sanitize=true)

## 2. Library Imports

In [1]:
# Import necessary libraries

import os
from typing import Dict # For type hinting

import torch # PyTorch for tensor operations
from torch import nn # PyTorch for deep learning
from torch.utils.data import DataLoader # DataLoader for batching and shuffling data
from tqdm import tqdm

import numpy as np
import evaluate
from datasets import load_dataset # To load datasets from Hugging Face
import transformers # Transformers library for model and tokenizer
from transformers import (
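    Trainer, # High-level Hugging Face training loop (not used in this notebook)
    TrainingArguments, # Arguments for Trainer (not used in this notebook)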
    AutoTokenizer, # Tokenizer for Hugging Face models
    AutoModelForSequenceClassification, # Model for sequence classification
)

import ray.train # Ray Train for distributed training
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer # Trainer for PyTorch
from ray.train.torch import TorchConfig # Configuration for PyTorch training
from ray.runtime_env import RuntimeEnv # Runtime environment for Ray tasks


## 3. Metrics Setup
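We use accuracy as the evaluation metric. The `compute_metrics` function takes a `(logits, labels)` pair, picks the highest-scoring class for each example, and returns the accuracy of those predictions. Note that the training loop in section 4, as written, does not call it.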

In [2]:
# Metrics
metric = evaluate.load("accuracy") # Load accuracy metric from Hugging Face evaluate library

# Function to compute metrics
# This function takes the evaluation predictions and computes the accuracy metric
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
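
To make the argmax-then-compare step concrete, here is a small dependency-free sketch of what accuracy over logits amounts to. The helper names `argmax` and `accuracy_from_logits` and the toy values are made up for illustration; the notebook itself relies on `np.argmax` and the `evaluate` metric.

```python
from typing import Dict, List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Return the index of the largest value in one row of logits."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy_from_logits(
    logits: List[Sequence[float]], labels: List[int]
) -> Dict[str, float]:
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Toy example: 4 reviews, 5 classes (hypothetical values).
toy_logits = [
    [0.1, 2.0, 0.3, 0.0, -1.0],  # highest score at class 1
    [1.5, 0.2, 0.1, 0.0, 0.3],   # highest score at class 0
    [0.0, 0.1, 0.2, 0.3, 3.0],   # highest score at class 4
    [0.2, 0.1, 0.9, 0.4, 0.3],   # highest score at class 2
]
toy_labels = [1, 0, 3, 2]  # the third label disagrees with the prediction
result = accuracy_from_logits(toy_logits, toy_labels)
```

Three of the four toy predictions match their labels, so `result` holds an accuracy of 0.75.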

## 4. Training function per worker
Each worker executes this function during training. It handles data loading, tokenization, model initialization, and the training loop, and it selects a CUDA GPU, MPS (on Apple Silicon), or the CPU, in that order.

### Tokenizer
The tokenizer function converts each review's text into input IDs and attention masks.

Padding and truncation give every example the same length: with `padding="max_length"` and `truncation=True`, shorter reviews are padded and longer ones are cut to the tokenizer's maximum length. Fixed-size inputs let the collator stack examples into batch tensors. The function is applied with the dataset's `map` method; `batched=True` passes many examples to the tokenizer per call, which is more efficient than tokenizing them one at a time.

The resulting dataset contains the tokenized inputs in the numeric format the model expects.
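
As a rough picture of what padding and truncation do to a single example, the plain-Python sketch below cuts or pads a list of token IDs to a fixed length and builds the matching attention mask. `pad_and_truncate`, `PAD_ID`, and the token IDs are hypothetical; a real Hugging Face tokenizer also performs subword splitting and takes care of special tokens, which this sketch ignores.

```python
from typing import Dict, List

PAD_ID = 0  # hypothetical padding token id used only in this sketch


def pad_and_truncate(token_ids: List[int], max_length: int) -> Dict[str, List[int]]:
    """Cut a sequence to max_length, then right-pad it to exactly max_length."""
    kept = token_ids[:max_length]
    n_pad = max_length - len(kept)
    return {
        "input_ids": kept + [PAD_ID] * n_pad,
        # 1 marks a real token, 0 marks padding the model should ignore
        "attention_mask": [1] * len(kept) + [0] * n_pad,
    }


short_example = pad_and_truncate([101, 7592, 102], max_length=5)
long_example = pad_and_truncate([101, 1, 2, 3, 4, 5, 102], max_length=5)
```

Both results have exactly five input IDs: the short example ends in two padding entries with a mask of `[1, 1, 1, 0, 0]`, and the long example keeps only its first five IDs with a mask of all ones.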

### Dataloaders
Dataloaders load the dataset in batches for training and evaluation. The training `DataLoader` shuffles the data and collates it into batches.
The `collate_fn` is set to `transformers.default_data_collator`, which stacks the already padded examples into batch tensors. The `batch_size` is the per-worker batch size defined in the config. Note that these plain dataloaders are not sharded: as written, every worker iterates over the same subset of examples. To have each worker process a distinct portion of the dataset, Ray Train provides `ray.train.torch.prepare_data_loader`, and `ray.train.torch.prepare_model` wraps the model so that gradients are synchronized across workers.
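
As a rough picture of what "each worker processes a portion of the dataset" means, the sketch below splits example indices across workers with a simple strided scheme and then groups one worker's share into batches. This is a simplified illustration, not Ray's or PyTorch's implementation: `shard_indices` and `batches` are hypothetical helpers, and shuffling and the handling of uneven shards are left out.

```python
from typing import List


def shard_indices(num_examples: int, num_workers: int, rank: int) -> List[int]:
    """Give worker `rank` every num_workers-th example, starting at `rank`."""
    if not 0 <= rank < num_workers:
        raise ValueError("rank must be in [0, num_workers)")
    return list(range(rank, num_examples, num_workers))


def batches(indices: List[int], batch_size: int) -> List[List[int]]:
    """Group a worker's indices into consecutive batches."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


worker0 = shard_indices(num_examples=10, num_workers=2, rank=0)
worker1 = shard_indices(num_examples=10, num_workers=2, rank=1)
worker0_batches = batches(worker0, batch_size=4)
```

Here the two workers receive disjoint halves of the ten indices, and worker 0 sees one full batch of four followed by a final batch with the single remaining index.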

In [3]:
def train_func_per_worker(config: Dict):
    
    # Datasets
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    
    # Tokenization function
    def tokenize_function(examples):
        """    
        This function will tokenize the text data in the dataset
        It uses the tokenizer to convert text into input IDs and attention masks
        Padding and truncation are applied to ensure uniform input size
        This is essential for training models that require fixed-size inputs
        """
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    # select a subset of the dataset for training and evaluation
    # In a real-world scenario, you would use the entire dataset
    SMALL_SIZE = 100
    # The map method applies the function to each example in the dataset
    # The batched=True argument allows processing multiple examples at once, which is more efficient
    # The resulting dataset will have the tokenized inputs ready for training
    # This is a crucial step in preparing the dataset for model training
    # It ensures that the text data is converted into a format that the model can understand
    train_dataset = dataset["train"].select(range(SMALL_SIZE)).map(tokenize_function, batched=True)
    eval_dataset = dataset["test"].select(range(SMALL_SIZE)).map(tokenize_function, batched=True)

    # Prepare dataloader for each worker
    # Dataloaders are used to load the dataset in batches for training and evaluation
    # The dataloaders dictionary will hold the training and evaluation dataloaders
    # This allows for easy access to the dataloaders during training and evaluation
    # The dataloaders will be used in the training loop to fetch batches of data for each worker
    dataloaders = {}
    dataloaders["train"] = torch.utils.data.DataLoader(
        train_dataset, 
        shuffle=True, 
        collate_fn=transformers.default_data_collator, 
        batch_size=batch_size
    )
    dataloaders["test"] = torch.utils.data.DataLoader(
        eval_dataset, 
        shuffle=False,  # Evaluation does not need shuffling
        collate_fn=transformers.default_data_collator, 
        batch_size=batch_size
    )

    # Obtain GPU device automatically
    # device = ray.train.torch.get_device()
    
    # Alternatively, you can specify the device manually
    # Check if CUDA or MPS is available and set device accordingly
    # This is useful for running on different hardware configurations
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps") # For Apple Silicon Macs
    else:
        device = torch.device("cpu")

    # Prepare model and optimizer
    # Load a pre-trained BERT model for sequence classification
    # The model is initialized with the number of labels for classification
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    # The model is moved to the selected device (GPU, MPS, or CPU)
    model = model.to(device)
    
    # Use SGD with momentum to update the model parameters during training
    # The learning rate comes from the configuration; momentum is fixed at 0.9
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Start training loops
    # The model will be trained for the specified number of epochs
    # The model will be trained using the training dataloader
    # The model will be evaluated using the evaluation dataloader
    # The training loop will iterate over the epochs and batches
    for epoch in range(epochs):
        # Each epoch has a training and validation phase
        for phase in ["train", "test"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            for batch in dataloaders[phase]: # Iterate over batches in the dataloader
                batch = {k: v.to(device) for k, v in batch.items()}

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward pass
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    # The model processes the input batch and returns outputs
                    # The outputs include the loss and logits
                    # The loss is calculated based on the model's predictions and the true labels
                    # The logits are the raw predictions from the model
                    # The loss is used to update the model parameters during training
                    outputs = model(**batch)
                    loss = outputs.loss

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward() # Backpropagate the loss to compute gradients
                        # The optimizer updates the model parameters based on the computed gradients
                        optimizer.step()
                        print(f"train epoch:[{epoch}]\tloss:{loss:.6f}")
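
The loop above prints the loss of every training batch and computes, but does not keep, the loss of every evaluation batch. One common way to summarize either phase is a sample-weighted running mean, sketched below in plain Python. `RunningMean` and the toy losses are hypothetical and not part of the notebook; inside the loop, `value` would be `loss.item()` and `n` the number of examples in the batch.

```python
class RunningMean:
    """Accumulate a sample-weighted mean of per-batch values."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        """Add a batch-level mean `value` that was computed over `n` samples."""
        self.total += value * n
        self.count += n

    @property
    def mean(self) -> float:
        """Mean over all samples seen so far (0.0 before any update)."""
        return self.total / self.count if self.count else 0.0


# Hypothetical per-batch results for one phase: (mean loss, batch size).
meter = RunningMean()
for batch_loss, batch_size in [(2.0, 4), (1.0, 4), (4.0, 2)]:
    meter.update(batch_loss, n=batch_size)
epoch_loss = meter.mean
```

Weighting by batch size keeps a smaller final batch from counting as much as a full one.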

## 5. Main Training Function
The `train_bert` function sets up the distributed training environment using Ray and starts the training process. To train on GPUs, only the following changes are needed:

* Require a GPU for each worker in `ScalingConfig`
* Set the backend to `"nccl"` in `TorchConfig` and pass it to the `TorchTrainer`

The function builds the training configuration and the scaling configuration, starts Ray, creates the trainer, and runs it.
* It launches `num_workers` worker processes, each of which runs `train_func_per_worker`.
* It uses Ray Train's `TorchTrainer` to manage the workers and set up the PyTorch process group.
* It supports both CPU and GPU training, selected through the `ScalingConfig`.
* It can be adapted to other models and datasets by changing the model and dataset loading code in the per-worker function.
* The same code runs on a local machine and on a cloud cluster.

In [4]:
# function to train BERT model using Ray Train
# This function sets up the training configuration, scaling, and starts the Ray cluster.
# It initializes the Ray Train environment, configures the trainer, and starts the training process.
def train_bert(num_workers=2):
    global_batch_size = 8 # This is the total batch size across all workers

    # Define the training configuration
    # This configuration includes the learning rate, number of epochs, and batch size per worker
    train_config = {
        "lr": 1e-3,  # Learning rate
        "epochs": 2,  # Reduced for faster testing
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    # if using CPUs or MPS
    scaling_config = ScalingConfig(num_workers=num_workers, resources_per_worker={"CPU": 1,})
    
    # If using GPUs, you can specify resources_per_worker={"CPU": 1, "GPU": 1}
    # scaling_config = ScalingConfig(num_workers=num_workers, resources_per_worker={"CPU": 1, "GPU": 1})
    # Set backend to nccl in TorchConfig
    # torch_config = TorchConfig(backend = "nccl")
    
    # Start Ray (launches a local Ray instance when no cluster is configured)
    ray.init() 
    
    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        # torch_config=torch_config, # Uncomment if using nccl backend
        scaling_config=scaling_config,
    )

    result = trainer.fit() # Start the training process
    # Print the training result. Its metrics stay empty unless the training
    # function reports them, for example with ray.train.report
    print(f"Training result: {result}")
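
`train_bert` derives the per-worker batch size with integer division, so a global batch size that is not a multiple of `num_workers` is silently rounded down. The sketch below makes that arithmetic explicit. `split_batch_size` is a hypothetical helper for illustration, not something the notebook or Ray provides.

```python
from typing import Tuple


def split_batch_size(global_batch_size: int, num_workers: int) -> Tuple[int, int]:
    """Return (per-worker batch size, effective global batch size)."""
    if num_workers <= 0:
        raise ValueError("num_workers must be positive")
    per_worker = global_batch_size // num_workers
    if per_worker == 0:
        raise ValueError("global_batch_size must be at least num_workers")
    # The effective global batch is what all workers process per step together.
    return per_worker, per_worker * num_workers


even = split_batch_size(8, 2)    # 8 divides evenly across 2 workers
uneven = split_batch_size(8, 3)  # 8 does not divide evenly across 3 workers
```

With the notebook's defaults (`global_batch_size = 8`, `num_workers=2`) each worker gets a batch size of 4. With three workers the per-worker size drops to 2, so only 6 examples are processed per step across all workers.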

## 6. Start Training
Finally, we call the `train_bert` function to start the training process. You can adjust the number of workers to match your hardware.

In [5]:
# Run the training function with the specified number of workers
# You can adjust the number of workers based on your hardware configuration
train_bert(num_workers=2)

2025-07-11 10:16:10,519	INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
2025-07-11 10:16:11,092	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-07-11 10:16:11 (running for 00:00:00.11)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-07-11_10-16-09_200164_18044/artifacts/2025-07-11_10-16-11/TorchTrainer_2025-07-11_10-16-11/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-07-11 10:16:16 (running for 00:00:05.14)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-07-11_10-16-09_200164_18044/artifacts/2025-07-11_10-16-11/TorchTrainer_2025-07-11_10-16-11/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(RayTrainWorker pid=41599) Setting up process group for: env:// [rank=0, world_size=2]
(TorchTrainer pid=41521) Started distributed worker processes: 
(TorchTrainer pid=41521) - (node_id=0eca5bdc14957219701f50108487dbd39f13987d253f812c0d6b29a9, ip=127.0.0.1, pid=41599) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=41521) - (node_id=0eca5bdc14957219701f50108487dbd39f13987d253f812c0d6b29a9, ip=127.0.0.1, pid=41598) world_rank=1, local_rank=1, node_rank=0


== Status ==
Current time: 2025-07-11 10:16:21 (running for 00:00:10.22)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-07-11_10-16-09_200164_18044/artifacts/2025-07-11_10-16-11/TorchTrainer_2025-07-11_10-16-11/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




Map: 100%|██████████| 100/100 [00:00<00:00, 5090.92 examples/s]
(RayTrainWorker pid=41599) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
(RayTrainWorker pid=41599) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


(RayTrainWorker pid=41599) train epoch:[0]	loss:1.764641
== Status ==
Current time: 2025-07-11 10:16:26 (running for 00:00:15.28)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-07-11_10-16-09_200164_18044/artifacts/2025-07-11_10-16-11/TorchTrainer_2025-07-11_10-16-11/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


(RayTrainWorker pid=41598) train epoch:[0]	loss:1.949393 [repeated 27x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
== Status ==
Current time: 2025-07-11 10:16:31 (running for 00:00:20.35)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-07-11_10-16-09_200164_18044/artifacts/2025-07-11_10-16-11/TorchTrainer_2025-07-11_10-16-11/dri

2025-07-11 10:16:49,397	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/maxpumperla/ray_results/TorchTrainer_2025-07-11_10-16-11' in 0.0033s.
2025-07-11 10:16:49,400	INFO tune.py:1041 -- Total run time: 38.31 seconds (38.29 seconds for the tuning loop).


Trial TorchTrainer_4dd7a_00000 completed. Last result: 
== Status ==
Current time: 2025-07-11 10:16:49 (running for 00:00:38.30)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/16 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-07-11_10-16-09_200164_18044/artifacts/2025-07-11_10-16-11/TorchTrainer_2025-07-11_10-16-11/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)


Training result: Result(
  metrics={},
  path='/Users/maxpumperla/ray_results/TorchTrainer_2025-07-11_10-16-11/TorchTrainer_4dd7a_00000_0_2025-07-11_10-16-11',
  filesystem='local',
  checkpoint=None
)


## 7. Shutdown Ray Cluster

In [6]:
# Shutdown Ray after training is complete
ray.shutdown()

## 8. Summary
This notebook demonstrates how to use Ray Train, PyTorch, and Hugging Face Transformers to perform distributed training of a BERT model for sequence classification on the Yelp review dataset. It covers data loading, tokenization, model setup, and distributed training configuration, allowing you to efficiently train large models across multiple CPUs or GPUs. The notebook is designed to be accessible for machine learning engineers who want to learn scalable deep learning workflows using modern Python tools.