# Introduction to Ray Tune

This template provides a hands-on introduction to **Ray Tune** — a scalable hyperparameter tuning library built on [Ray](https://docs.ray.io/en/latest/). You will learn what Ray Tune is, why it matters, and how to use its core APIs to efficiently search for the best hyperparameters for your models.

In the first half, we'll walk through the core Ray Tune workflow — from a baseline training loop to a fully tuned model with smart search and early stopping. The second half covers production concerns and advanced patterns you can explore as your workloads grow.

**Here is the roadmap for this template:**

**Core tutorial:**
- **Part 1:** Baseline PyTorch Training
- **Part 2:** Your First Tune Experiment
- **Part 3:** Smarter Search and Early Stopping

**Advanced topics:**
- **Part 4:** Checkpointing and Fault Tolerance
- **Part 5:** Integrating with Ray Train
- **Summary and Next Steps**

In [1]:
import os
import tempfile
from typing import Any

import numpy as np
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

import ray
from ray import tune
from ray.tune import Checkpoint, Stopper
from ray.tune.search.optuna import OptunaSearch
from ray.tune.schedulers import ASHAScheduler

### Note on Storage

Throughout this tutorial, we use `/mnt/cluster_storage` to represent a shared storage location. In a multi-node cluster, Ray workers on different nodes cannot access the head node's local file system. Use a [shared storage solution](https://docs.anyscale.com/configuration/storage#shared) accessible from every node.

## Part 1: Baseline PyTorch Training

We begin with a standard PyTorch training loop to establish a baseline. Our running example throughout this template is:

- **Objective**: Classify handwritten digits (0-9)
- **Model**: ResNet18 adapted for single-channel MNIST images
- **Evaluation Metric**: Training cross-entropy loss
- **Dataset**: MNIST (60,000 training images, 28x28 grayscale)

In [2]:
# Helper to build a DataLoader for MNIST.
def build_data_loader(batch_size: int) -> DataLoader:
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = MNIST(root="./data", train=True, download=True, transform=transform)
    return DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
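Here is a quick sanity check, in plain Python, of what this loader produces. The helper names are illustrative and not part of torchvision. `Normalize((0.5,), (0.5,))` computes `(x - 0.5) / 0.5`, which maps pixel values from `[0, 1]` to `[-1, 1]`. With `drop_last=True`, the 60,000 training images yield `60000 // batch_size` full batches per epoch.

```python
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    # Same per-channel formula torchvision's Normalize applies: (x - mean) / std
    return (x - mean) / std


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = True) -> int:
    # With drop_last=True the final partial batch is discarded
    full, remainder = divmod(num_samples, batch_size)
    return full if drop_last or remainder == 0 else full + 1


print(normalize(0.0), normalize(1.0))                   # -1.0 1.0
print(batches_per_epoch(60_000, 128))                   # 468
print(batches_per_epoch(60_000, 128, drop_last=False))  # 469
```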

Here is our baseline training function with predefined hyperparameters:

In [3]:
# Baseline PyTorch training loop with hardcoded hyperparameters.
def train_loop_torch(num_epochs: int = 2, batch_size: int = 128, lr: float = 1e-3):
    criterion = CrossEntropyLoss()

    model = resnet18()
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.to("cuda")

    data_loader = build_data_loader(batch_size)
    optimizer = Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for images, labels in data_loader:
            images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")  # loss of the last mini-batch

We schedule this on a GPU worker node using `@ray.remote` (GPU-intensive work shouldn't run directly on the head node):

In [4]:
# Initialize Ray (or connect to an existing cluster).
ray.init(ignore_reinit_error=True)

2026-02-14 00:39:19,206	INFO worker.py:1821 -- Connecting to existing Ray cluster at address: 10.0.27.62:6379...
2026-02-14 00:39:19,217	INFO worker.py:1998 -- Connected to Ray cluster. View the dashboard at https://session-q1yvvlhf7ul4se8w63az3crcti.i.anyscaleuserdata.com
2026-02-14 00:39:19,219	INFO packaging.py:463 -- Pushing file package 'gcs://_ray_pkg_6439ff843f0aaf49436206e9a35b9df50a23f55d.zip' (0.03MiB) to Ray cluster...
2026-02-14 00:39:19,220	INFO packaging.py:476 -- Successfully pushed file package 'gcs://_ray_pkg_6439ff843f0aaf49436206e9a35b9df50a23f55d.zip'.


| | |
| --- | --- |
| Python version | 3.12.12 |
| Ray version | 2.53.0 |
| Dashboard | http://session-q1yvvlhf7ul4se8w63az3crcti.i.anyscaleuserdata.com |


In [5]:
# Run the baseline training on a GPU worker node.
@ray.remote(num_gpus=1, resources={"accelerator_type:T4": 0.001})
def run_baseline():
    train_loop_torch(num_epochs=2)

ray.get(run_baseline.remote())



(run_baseline pid=3982, ip=10.0.25.52) Epoch 0, Loss: 0.0602
(run_baseline pid=3982, ip=10.0.25.52) Epoch 1, Loss: 0.0687


**Can we do better?** The training run depends on several hyperparameters (learning rate, batch size, number of epochs) that we chose somewhat arbitrarily. Tuning them systematically could lower the loss. However, every combination costs a full training run, so trying combinations one after another is slow and expensive.
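To put a rough number on "slow", the sketch below sizes a small, hypothetical grid (3 learning rates × 3 batch sizes × 2 epoch budgets). It then estimates wall-clock time when trials run one at a time versus four at a time. The estimate assumes every trial takes about 65 seconds, which is roughly the per-trial time reported in Part 2. It ignores scheduling overhead and differences between configurations.

```python
import itertools
import math

# Hypothetical grid: 3 learning rates x 3 batch sizes x 2 epoch budgets
grid = list(itertools.product([1e-4, 1e-3, 1e-2], [64, 128, 256], [2, 3]))
seconds_per_trial = 65  # assumption: roughly the per-trial time seen in Part 2


def wall_clock(num_trials: int, parallel_slots: int, secs_per_trial: int) -> int:
    # Trials run in "waves" of at most `parallel_slots` at a time
    return math.ceil(num_trials / parallel_slots) * secs_per_trial


print(len(grid))                                    # 18
print(wall_clock(len(grid), 1, seconds_per_trial))  # 1170 seconds, one trial at a time
print(wall_clock(len(grid), 4, seconds_per_trial))  # 325 seconds with 4 trials in parallel
```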

This is exactly what [Ray Tune](https://docs.ray.io/en/latest/tune/) solves — it's a distributed hyperparameter tuning library that runs many trials in parallel across your cluster:

| Challenge | **Ray Tune Solution** |
| --- | --- |
| **Scale tuning** | Distributes trials across cluster CPUs/GPUs for massive parallelism. |
| **Sophisticated search** | Wraps advanced search algorithms (e.g., Bayesian optimization) and runs them in a distributed fashion, with no custom parallelization code needed. |
| **Early stopping** | Schedulers such as **ASHA** stop underperforming trials early, freeing resources for promising ones; **PBT** instead replaces weak trials with perturbed copies of stronger ones. |
| **Ecosystem integration** | Plugs into Optuna, HyperOpt, Ax, and experiment tracking tools. |
| **Fault tolerance** | Failed trials can be retried from their latest checkpoint, and interrupted experiments can be resumed. |

Let's apply Ray Tune to our model.

## Part 2: Your First Tune Experiment

We'll tune our ResNet18/MNIST model in four steps: **define** a training function, **configure** a Tuner, **run** the experiment, and **inspect** the results.
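Before using the real API, it can help to see the four steps as a toy, single-machine loop. The sketch below is not how Ray Tune is implemented; Tune schedules trials as Ray actors across the cluster. `trainable`, `sample_config`, and `toy_fit` are hypothetical names. The "loss" is a synthetic function of the learning rate with its minimum at `lr = 1e-3`, standing in for a real training run.

```python
import math
import random
from concurrent.futures import ThreadPoolExecutor


def trainable(config: dict) -> dict:
    # Step 1 (define): synthetic stand-in for training, lowest "loss" at lr = 1e-3
    return {"loss": (math.log10(config["lr"]) + 3) ** 2}


def sample_config(rng: random.Random) -> dict:
    # Log-uniform learning rate in [1e-4, 1e-1], fixed batch size
    return {"lr": 10 ** rng.uniform(-4, -1), "batch_size": 128}


def toy_fit(num_samples: int, max_parallel: int, seed: int = 0) -> list:
    rng = random.Random(seed)
    # Step 2 (configure): one sampled config per trial
    configs = [sample_config(rng) for _ in range(num_samples)]
    # Step 3 (run): up to max_parallel trials at the same time
    with ThreadPoolExecutor(max_workers=max_parallel) as pool:
        metrics = list(pool.map(trainable, configs))
    return [{"config": c, **m} for c, m in zip(configs, metrics)]


results = toy_fit(num_samples=8, max_parallel=4)
best = min(results, key=lambda r: r["loss"])  # Step 4 (inspect): metric="loss", mode="min"
print(f"best lr = {best['config']['lr']:.2e}, loss = {best['loss']:.3f}")
```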

### Step 1: Define the training function

A Tune training function accepts a `config` dictionary containing hyperparameters and reports metrics back to Tune using `tune.report()`:

In [6]:
# Tune-compatible PyTorch training function.
def train_pytorch(config):
    criterion = CrossEntropyLoss()

    model = resnet18()
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.to("cuda")

    optimizer = Adam(model.parameters(), lr=config["lr"])
    data_loader = build_data_loader(config["batch_size"])

    for epoch in range(config["num_epochs"]):
        for images, labels in data_loader:
            images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report metrics at the end of each epoch
        tune.report({"loss": loss.item()})
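Note that `loss.item()` after the inner loop is the loss of the final mini-batch only, which can be noisy from epoch to epoch. A common alternative, not used in this template, is to report the mean loss over the epoch by accumulating each batch's `loss.item()`. The logic in plain Python, with hypothetical per-batch values:

```python
def epoch_mean_loss(batch_losses) -> float:
    # Running mean over all mini-batches in an epoch
    total, count = 0.0, 0
    for value in batch_losses:
        total += value
        count += 1
    return total / count


losses = [0.5, 0.25, 0.125, 0.375]  # hypothetical per-batch losses for one epoch
print(losses[-1])               # 0.375  (what reporting the last batch gives)
print(epoch_mean_loss(losses))  # 0.3125 (mean over the epoch)
```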

### Step 2: Configure the Tuner

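Each trial runs as a separate process. Use `tune.with_resources` to specify what each trial needs. Here, each trial gets 1 GPU, plus a tiny amount of the `accelerator_type:T4` resource so that trials are only scheduled on T4 nodes: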

In [7]:
# Allocate 1 T4 GPU per trial and search over learning rate.
tuner = tune.Tuner(
    trainable=tune.with_resources(
        train_pytorch, {"gpu": 1, "accelerator_type:T4": 0.001}
    ),
    param_space={
        "num_epochs": 3,
        "batch_size": 128,
        "lr": tune.loguniform(1e-4, 1e-1),
    },
    tune_config=tune.TuneConfig(
        mode="min",
        metric="loss",
        num_samples=4,
    ),
)

A few things to note about the configuration:

- **`param_space`** defines the hyperparameters to explore. `tune.loguniform(1e-4, 1e-1)` samples the learning rate on a log scale. Ray Tune provides other primitives such as `tune.choice`, `tune.uniform`, and `tune.randint` — see the [full Search Space API reference](https://docs.ray.io/en/latest/tune/api/search_space.html).
- **`tune_config`** tells Tune which metric to optimize (`"loss"`), whether to minimize or maximize (`"min"`), and how many trials to run (`num_samples=4`).
- By default, Tune uses random search (`BasicVariantGenerator`) to pick hyperparameters. We'll plug in a smarter search algorithm in Part 3.
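Why sample the learning rate on a log scale? Candidate learning rates span several orders of magnitude, and a plain uniform draw over `[1e-4, 1e-1]` would put roughly 90% of samples above `1e-2`. The sketch below contrasts the two in plain Python. The `loguniform` helper illustrates the idea (draw uniformly in log space, then exponentiate) and is not Ray's implementation.

```python
import math
import random


def loguniform(low: float, high: float, rng: random.Random) -> float:
    # Draw uniformly between log(low) and log(high), then map back with exp
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def decade_shares(samples, edges=(1e-4, 1e-3, 1e-2, 1e-1)) -> list:
    # Fraction of samples falling in each decade [edges[i], edges[i+1])
    return [sum(lo <= s < hi for s in samples) / len(samples) for lo, hi in zip(edges, edges[1:])]


rng = random.Random(42)
log_samples = [loguniform(1e-4, 1e-1, rng) for _ in range(30_000)]
uniform_samples = [rng.uniform(1e-4, 1e-1) for _ in range(30_000)]

print("log-uniform:", [round(s, 2) for s in decade_shares(log_samples)])      # about a third each
print("uniform:    ", [round(s, 2) for s in decade_shares(uniform_samples)])  # mostly the top decade
```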

Here's how the Tuner, search algorithm, and scheduler fit together:

<img src="https://docs.ray.io/en/latest/_images/tune_flow.png" width="800" />

To learn more about these concepts, visit the [Ray Tune Key Concepts documentation](https://docs.ray.io/en/latest/tune/key-concepts.html).

### Step 3: Run the Tuner

In [8]:
results = tuner.fit()

| | |
| --- | --- |
| Current time | 2026-02-14 00:41:20 |
| Running for | 00:01:12.19 |
| Memory | 3.6/31.0 GiB |

| Trial name | status | loc | lr | iter | total time (s) | loss |
| --- | --- | --- | --- | --- | --- | --- |
| train_pytorch_b5f8b_00000 | TERMINATED | 10.0.25.52:4164 | 0.00638812 | 3 | 65.0279 | 0.0291291 |
| train_pytorch_b5f8b_00001 | TERMINATED | 10.0.38.17:3841 | 0.000255295 | 3 | 66.1842 | 0.065753 |
| train_pytorch_b5f8b_00002 | TERMINATED | 10.0.4.27:3834 | 0.000761471 | 3 | 65.2487 | 0.0418063 |
| train_pytorch_b5f8b_00003 | TERMINATED | 10.0.49.157:3812 | 0.00133194 | 3 | 65.4621 | 0.00562479 |


2026-02-14 00:41:20,136	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/ray/ray_results/train_pytorch_2026-02-14_00-40-07' in 0.0052s.
2026-02-14 00:41:20,141	INFO tune.py:1041 -- Total run time: 72.56 seconds (72.18 seconds for the tuning loop).


### Step 4: Inspect the results

In [9]:
# Print the best configuration and loss.
best_result = results.get_best_result()
print(f"Best config: {best_result.config}")
print(f"Best loss: {best_result.metrics['loss']:.4f}")

Best config: {'num_epochs': 3, 'batch_size': 128, 'lr': 0.001331935044635197}
Best loss: 0.0056


In [10]:
# View results for all trials.
results.get_dataframe()[["loss", "config/lr"]]

| | loss | config/lr |
| --- | --- | --- |
| 0 | 0.029129 | 0.006388 |
| 1 | 0.065753 | 0.000255 |
| 2 | 0.041806 | 0.000761 |
| 3 | 0.005625 | 0.001332 |


### Recap

To summarize, a `tune.Tuner` accepts:
- **`trainable`** — a training function (or class) to be tuned
- **`param_space`** — a dictionary defining the hyperparameter search space
- **`tune_config`** — configuration for the metric to optimize (`metric`, `mode`) and how many trials to run (`num_samples`)

`tuner.fit()` runs multiple trials in parallel, each with a different set of hyperparameters, and returns a `ResultGrid` from which you can retrieve the best configuration.
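Under the hood, picking the best result comes down to comparing each trial's reported metric using `metric` and `mode`. Below is a plain-Python sketch of that comparison, using the final losses from the Part 2 table above. `best_trial` is a hypothetical helper. Per the Ray API reference, `get_best_result` compares each trial's last reported value by default, and its `scope` argument selects alternatives such as averages over recent results.

```python
# Final reported metrics of the four Part 2 trials, copied from the results table above
trials = [
    {"lr": 0.00638812, "loss": 0.0291291},
    {"lr": 0.000255295, "loss": 0.065753},
    {"lr": 0.000761471, "loss": 0.0418063},
    {"lr": 0.00133194, "loss": 0.00562479},
]


def best_trial(trials: list, metric: str = "loss", mode: str = "min") -> dict:
    # mode="min" keeps the smallest value of `metric`, mode="max" the largest
    pick = min if mode == "min" else max
    return pick(trials, key=lambda t: t[metric])


print(best_trial(trials))  # {'lr': 0.00133194, 'loss': 0.00562479}
```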

## Part 3: Smarter Search and Early Stopping

Now that we've seen the basic Tune workflow, let's make it smarter — with a better search algorithm and a scheduler that stops underperformers early.

In Part 2, Tune used random search and ran every trial to completion. We can do better on both fronts:

- **Search algorithm** — Instead of random search, use **Optuna** for Bayesian optimization. It learns from previous trial results to make smarter choices about which hyperparameters to try next.
- **Scheduler** — Instead of the default `FIFOScheduler` (which runs all trials to completion), use the **ASHAScheduler** (Asynchronous Successive Halving). It terminates underperforming trials early, freeing resources for more promising ones.
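To make ASHA's behavior concrete, here is a simplified sketch of its two ingredients. It is not Ray's code. First, decisions happen only at milestones ("rungs") at `grace_period * reduction_factor**k` iterations below `max_t`; `ASHAScheduler` uses `reduction_factor=4` by default. Second, a trial arriving at a rung continues only if its metric falls within roughly the best `1/reduction_factor` of results already recorded there. Because each trial is compared only with trials that arrived earlier, early arrivals are rarely stopped. This is what makes the procedure asynchronous, and why more than a quarter of trials can survive a rung. The losses fed in below are hypothetical.

```python
def asha_rungs(grace_period: int, max_t: int, reduction_factor: int = 4) -> list:
    # Milestones at grace_period * reduction_factor**k that lie below max_t
    rungs, milestone = [], grace_period
    while milestone < max_t:
        rungs.append(milestone)
        milestone *= reduction_factor
    return rungs


def percentile(values: list, q: float) -> float:
    # Linear-interpolation percentile, with q in [0, 1]
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


class Rung:
    """Continue/stop decision at a single rung for mode='min' (simplified)."""

    def __init__(self, reduction_factor: int = 4):
        self.rf = reduction_factor
        self.recorded = []  # losses of trials that reached this rung earlier

    def on_arrival(self, loss: float) -> bool:
        # The first arrival always continues; later ones must beat the cutoff
        keep = not self.recorded or loss <= percentile(self.recorded, 1 / self.rf)
        self.recorded.append(loss)
        return keep


print(asha_rungs(grace_period=2, max_t=10))                   # [2, 8]
rung = Rung()
print([rung.on_arrival(x) for x in [0.05, 0.03, 0.10, 0.02]])  # [True, True, False, True]
```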

You can combine both in a single Tuner:

In [11]:
# Combine Optuna search with ASHA early stopping.
tuner = tune.Tuner(
    trainable=tune.with_resources(
        train_pytorch, {"gpu": 1, "accelerator_type:T4": 0.001}
    ),
    param_space={
        "num_epochs": 8,
        "batch_size": 128,
        "lr": tune.loguniform(1e-4, 1e-1),
    },
    tune_config=tune.TuneConfig(
        metric="loss",
        mode="min",
        num_samples=4,
        search_alg=OptunaSearch(),
        scheduler=ASHAScheduler(
            max_t=10,        # Max training iterations
            grace_period=2,  # Min iterations before stopping is allowed
        ),
    ),
)

In [12]:
results = tuner.fit()

| | |
| --- | --- |
| Current time | 2026-02-14 00:44:22 |
| Running for | 00:03:01.88 |
| Memory | 3.6/31.0 GiB |

| Trial name | status | loc | lr | iter | total time (s) | loss |
| --- | --- | --- | --- | --- | --- | --- |
| train_pytorch_8094469c | TERMINATED | 10.0.4.27:4018 | 0.000226563 | 8 | 169.598 | 0.00182185 |
| train_pytorch_91241386 | TERMINATED | 10.0.49.157:4005 | 0.00011626 | 8 | 169.973 | 0.0026228 |
| train_pytorch_2b2e90b7 | TERMINATED | 10.0.25.52:4350 | 0.000388738 | 8 | 167.573 | 0.0412388 |
| train_pytorch_c16154ff | TERMINATED | 10.0.38.17:4048 | 0.00412688 | 2 | 45.9304 | 0.0231441 |



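Notice the `iter` column: the trial with `lr` ≈ 0.0041 stopped after 2 epochs instead of 8. Its status still reads `TERMINATED`. However, ASHA ended it at the first milestone (`grace_period=2`) because its loss was not among the best recorded there.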

In [13]:
print(f"Best config: {results.get_best_result().config}")
results.get_dataframe()[["loss", "training_iteration", "config/lr"]].sort_values("loss")

Best config: {'num_epochs': 8, 'batch_size': 128, 'lr': 0.00022656264464934007}


| | loss | training_iteration | config/lr |
| --- | --- | --- | --- |
| 0 | 0.001822 | 8 | 0.000227 |
| 1 | 0.002623 | 8 | 0.000116 |
| 3 | 0.023144 | 2 | 0.004127 |
| 2 | 0.041239 | 8 | 0.000389 |


Ray Tune integrates with many search libraries. A few common options:

| **Library** | **Search Algorithm** | **Best For** |
|------------|---------------------|--------------|
| Built-in | `BasicVariantGenerator` | Simple random/grid search |
| Optuna | `OptunaSearch` | Bayesian optimization (TPE sampler by default) |
| HyperOpt | `HyperOptSearch` | Tree-structured Parzen Estimators |
| Ax | `AxSearch` | Bayesian optimization |

See the full list in the [Search Algorithm API docs](https://docs.ray.io/en/latest/tune/api/suggestion.html) and [Scheduler API docs](https://docs.ray.io/en/latest/tune/api/schedulers.html).

---

You've seen the core Ray Tune workflow — from a baseline training loop to smart search with Optuna and early stopping with ASHA. The following sections cover production concerns and advanced patterns.

## Part 4: Checkpointing and Fault Tolerance

For production-grade experiments, you need persistent storage, checkpointing, and fault tolerance.

### Persistent Storage

On a distributed cluster, Ray Tune needs a persistent storage location accessible from all nodes to save checkpoints and experiment state. Configure it via `tune.RunConfig(storage_path="/mnt/cluster_storage")`.

<img src="https://docs.ray.io/en/latest/_images/checkpoint_lifecycle.png" alt="Checkpoint Lifecycle" width="700"/>

The checkpoint lifecycle: saved locally, then uploaded to persistent storage via `tune.report()`.

### Checkpointing Trials

To make trials resumable, save model state as a `Checkpoint` inside `tune.report()`. Here is the pattern for PyTorch:

In [14]:
# Training function with checkpointing for fault tolerance.
def train_pytorch_with_checkpoints(config):
    model = resnet18()
    model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    model.to("cuda")
    optimizer = Adam(model.parameters(), lr=config["lr"])
    criterion = CrossEntropyLoss()
    data_loader = build_data_loader(config["batch_size"])
    start_epoch = 0

    # Resume from checkpoint if available
    checkpoint = tune.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            state = torch.load(os.path.join(ckpt_dir, "model.pt"), weights_only=False)
            model.load_state_dict(state["model"])
            optimizer.load_state_dict(state["optimizer"])
            start_epoch = state["epoch"] + 1

    for epoch in range(start_epoch, config["num_epochs"]):
        for images, labels in data_loader:
            images, labels = images.to("cuda"), labels.to("cuda")
            loss = criterion(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Save checkpoint with each metric report
        with tempfile.TemporaryDirectory() as tmp_dir:
            torch.save(
                {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch},
                os.path.join(tmp_dir, "model.pt"),
            )
            tune.report(
                {"loss": loss.item()},
                checkpoint=Checkpoint.from_directory(tmp_dir),
            )
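The resume logic above can be exercised without a GPU or Ray. The sketch below is a hypothetical stand-in for `train_pytorch_with_checkpoints`. It writes a small JSON file in place of `torch.save` and `Checkpoint`, and crashes partway through on the first attempt. On the second attempt, it picks up from the epoch after the last saved one, just as `start_epoch = state["epoch"] + 1` does above. Unlike Tune, which stores each reported checkpoint in its own directory, the sketch simply overwrites one file.

```python
import json
import os
import tempfile


class SimulatedCrash(Exception):
    pass


def train(num_epochs: int, ckpt_dir: str, crash_at: int = -1) -> list:
    """Hypothetical training function that checkpoints after every epoch."""
    ckpt_file = os.path.join(ckpt_dir, "state.json")
    start_epoch, weight = 0, 0.0
    if os.path.exists(ckpt_file):  # resume, analogous to tune.get_checkpoint()
        with open(ckpt_file) as f:
            state = json.load(f)
        start_epoch, weight = state["epoch"] + 1, state["weight"]

    epochs_run = []
    for epoch in range(start_epoch, num_epochs):
        if epoch == crash_at:
            raise SimulatedCrash(f"worker died during epoch {epoch}")
        weight += 1.0  # pretend to train for one epoch
        epochs_run.append(epoch)
        with open(ckpt_file, "w") as f:  # save state after each completed epoch
            json.dump({"epoch": epoch, "weight": weight}, f)
    return epochs_run


with tempfile.TemporaryDirectory() as ckpt_dir:
    try:
        train(num_epochs=5, ckpt_dir=ckpt_dir, crash_at=3)
    except SimulatedCrash as err:
        print("first attempt:", err)
    print("retry ran epochs:", train(num_epochs=5, ckpt_dir=ckpt_dir))  # [3, 4]
```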

### Fault Tolerance

Ray Tune provides two mechanisms for handling failures:

**1. Automatic trial retries** — Configure `FailureConfig` to retry failed trials automatically. For example, `tune.FailureConfig(max_failures=3)` retries each trial up to 3 times.

**2. Experiment recovery** — If the entire experiment fails (e.g., driver crash), you can resume it with `tune.Tuner.restore(path=..., trainable=..., restart_errored=True)`. This picks up where the experiment left off, skipping completed trials and restarting errored ones.
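A note on counting: per the `FailureConfig` documentation, `max_failures` is the number of retries. So `max_failures=3` allows up to four attempts per trial, `0` disables retries, and `-1` retries indefinitely. The simplified loop below models only that counting. `run_with_retries` and `flaky` are hypothetical helpers; real Tune retries also restore the trial from its latest checkpoint.

```python
def run_with_retries(attempt_fn, max_failures: int):
    """Call attempt_fn, retrying after failures; -1 means retry without limit."""
    failures = 0
    while True:
        try:
            return attempt_fn()
        except RuntimeError:
            failures += 1
            if max_failures != -1 and failures > max_failures:
                raise  # retry budget exhausted: the trial ends in an error state


def flaky(fail_times: int):
    """Hypothetical trial that fails on its first `fail_times` attempts."""
    calls = {"n": 0}

    def attempt():
        calls["n"] += 1
        if calls["n"] <= fail_times:
            raise RuntimeError(f"attempt {calls['n']} failed")
        return calls["n"]  # number of the attempt that succeeded

    return attempt


print(run_with_retries(flaky(3), max_failures=3))  # 4
```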

Here is a complete example combining checkpointing and fault tolerance:

In [15]:
# Combine checkpointing with fault tolerance and persistent storage.
tuner = tune.Tuner(
    trainable=tune.with_resources(
        train_pytorch_with_checkpoints, {"gpu": 1, "accelerator_type:T4": 0.001}
    ),
    param_space={
        "num_epochs": 2,
        "batch_size": 128,
        "lr": tune.loguniform(1e-4, 1e-1),
    },
    tune_config=tune.TuneConfig(
        metric="loss",
        mode="min",
        num_samples=4,
    ),
    run_config=tune.RunConfig(
        storage_path="/mnt/cluster_storage",
        name="resnet18_fault_tolerant",
        failure_config=tune.FailureConfig(max_failures=2),
    ),
)

In [16]:
results = tuner.fit()

| | |
| --- | --- |
| Current time | 2026-02-14 00:45:14 |
| Running for | 00:00:52.11 |
| Memory | 3.6/31.0 GiB |

| Trial name | status | loc | lr | iter | total time (s) | loss |
| --- | --- | --- | --- | --- | --- | --- |
| train_pytorch_with_checkpoints_4df54_00000 | TERMINATED | 10.0.25.52:4693 | 0.000349043 | 2 | 46.395 | 0.016914 |
| train_pytorch_with_checkpoints_4df54_00001 | TERMINATED | 10.0.4.27:4387 | 0.0199833 | 2 | 46.7967 | 0.139688 |
| train_pytorch_with_checkpoints_4df54_00002 | TERMINATED | 10.0.49.157:4367 | 0.000149897 | 2 | 47.0556 | 0.00468507 |
| train_pytorch_with_checkpoints_4df54_00003 | TERMINATED | 10.0.38.17:4385 | 0.000154253 | 2 | 46.6599 | 0.0382939 |


(train_pytorch_with_checkpoints pid=4693, ip=10.0.25.52) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/resnet18_fault_tolerant/train_pytorch_with_checkpoints_4df54_00000_0_lr=0.0003_2026-02-14_00-44-22/checkpoint_000000)

In [17]:
# Inspect the fault-tolerant experiment results.
best_result = results.get_best_result()
print(f"Best config: {best_result.config}")
print(f"Best loss: {best_result.metrics['loss']:.4f}")
print(f"Best checkpoint: {best_result.checkpoint}")

Best config: {'num_epochs': 2, 'batch_size': 128, 'lr': 0.00014989705601956802}
Best loss: 0.0047
Best checkpoint: Checkpoint(filesystem=local, path=/mnt/cluster_storage/resnet18_fault_tolerant/train_pytorch_with_checkpoints_4df54_00002_2_lr=0.0001_2026-02-14_00-44-22/checkpoint_000001)


### Stopping Criteria

Beyond ASHA, Ray Tune offers additional ways to stop trials and experiments:

**Metric-based stopping** — Define a custom `Stopper` to stop individual trials or the entire experiment based on metric thresholds:

```python
class CustomStopper(Stopper):
    def __init__(self):
        self.should_stop = False

    def __call__(self, trial_id: str, result: dict) -> bool:
        if result["loss"] > 1.0 and result["training_iteration"] > 5:
            return True  # Stop this underperforming trial
        if result["loss"] <= 0.05:
            self.should_stop = True  # Found a great result
        return False

    def stop_all(self) -> bool:
        return self.should_stop
```

Pass the stopper via `tune.RunConfig(stop=CustomStopper())`.
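Since a stopper is plain Python, its rules can be checked before launching an experiment. The sketch below repeats the `CustomStopper` class from above, falling back to `object` as the base class if Ray is not installed. It feeds the stopper hand-written results in an order Tune might deliver them. `training_iteration` is the counter Tune adds to every reported result. The trial IDs and loss values are hypothetical, and `simulate` only roughly mirrors Tune's bookkeeping.

```python
try:
    from ray.tune import Stopper
except ImportError:  # let the check run without Ray installed
    Stopper = object


class CustomStopper(Stopper):
    def __init__(self):
        self.should_stop = False

    def __call__(self, trial_id: str, result: dict) -> bool:
        if result["loss"] > 1.0 and result["training_iteration"] > 5:
            return True  # Stop this underperforming trial
        if result["loss"] <= 0.05:
            self.should_stop = True  # Found a great result
        return False

    def stop_all(self) -> bool:
        return self.should_stop


def simulate(stopper, stream):
    """Feed (trial_id, result) pairs to the stopper; return (stopped trials, experiment ended)."""
    stopped = set()
    for trial_id, result in stream:
        if trial_id in stopped:
            continue  # a stopped trial reports nothing further
        if stopper(trial_id, result):
            stopped.add(trial_id)
        if stopper.stop_all():
            return stopped, True
    return stopped, False


stream = [
    ("a", {"loss": 1.4, "training_iteration": 6}),   # poor loss past iteration 5: stop "a"
    ("b", {"loss": 0.30, "training_iteration": 6}),  # keeps going
    ("a", {"loss": 1.3, "training_iteration": 7}),   # ignored, "a" is already stopped
    ("b", {"loss": 0.04, "training_iteration": 7}),  # good enough: end the experiment
    ("c", {"loss": 0.90, "training_iteration": 1}),  # never processed
]
print(simulate(CustomStopper(), stream))  # ({'a'}, True)
```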

**Time-based stopping** — Stop trials after a certain duration with `RunConfig(stop={"time_total_s": 120})`, or cap the full experiment time with `TuneConfig(time_budget_s=600.0)`.

### Resource Management

When running many concurrent trials, out-of-memory (OOM) errors can occur. Mitigate them by:
- **Setting memory resources:** `tune.with_resources(trainable, {"gpu": 1, "memory": 6 * 1024**3})`
- **Limiting concurrency:** `tune.TuneConfig(max_concurrent_trials=4)`
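One back-of-the-envelope way to see how these settings interact is to look at a single node. The number of trials that fit at once is set by whichever resource runs out first relative to the per-trial request. The node sizes below are hypothetical, and Ray's scheduler does the real accounting, including resources used by other tasks. The example also shows why the tiny `accelerator_type:T4` request used in this template never becomes the bottleneck.

```python
GIB = 1024**3


def trials_per_node(node: dict, request: dict) -> int:
    # Each requested (positive) amount allows node[name] // amount trials; the smallest count wins
    return min(int(node.get(name, 0) // amount) for name, amount in request.items())


# Hypothetical node: 4 T4 GPUs and 48 GiB of memory available for trials
node = {"gpu": 4, "memory": 48 * GIB, "accelerator_type:T4": 1}

print(trials_per_node(node, {"gpu": 1}))                                # 4
print(trials_per_node(node, {"gpu": 1, "memory": 16 * GIB}))            # 3: memory now limits
print(trials_per_node(node, {"gpu": 1, "accelerator_type:T4": 0.001}))  # 4: the marker never binds
```

Across a cluster, `max_concurrent_trials` then acts as an additional cap on top of these per-node limits.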

## Part 5: Integrating with Ray Train

By default, each Ray Tune trial runs as a single-worker process — training happens on one machine with one GPU. For large models that require distributed (multi-GPU or multi-node) training, you need the Ray Train integration. Wrapping a Ray Train `Trainer` inside a Tune trial lets each trial run a full distributed training job — giving you distributed hyperparameter search *and* distributed training at the same time.

To set this up, wrap your Ray Train `Trainer` creation in a driver function that Tune calls with different hyperparameter configurations:

In [18]:
def train_loop_per_worker(config):
    """Adapted from train_pytorch for Ray Train"""
    import ray.train
    import ray.train.torch  # import the submodule explicitly; `import ray.train` alone does not guarantee it

    criterion = CrossEntropyLoss()
    model = resnet18()
    model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    model = ray.train.torch.prepare_model(model)  # added: moves model to this worker's device, wraps in DDP

    optimizer = Adam(model.parameters(), lr=config["lr"])
    # added: shards data across workers and moves batches to the device (batch_size is now per worker)
    data_loader = ray.train.torch.prepare_data_loader(build_data_loader(config["batch_size"]))

    for epoch in range(config["num_epochs"]):
        if ray.train.get_context().get_world_size() > 1:
            data_loader.sampler.set_epoch(epoch)  # added: reshuffle the shards every epoch
        for images, labels in data_loader:
            loss = criterion(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        ray.train.report({"loss": loss.item()})  # changed from tune.report
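Data-parallel training only helps if each worker processes a different slice of the data. According to the Ray Train docs, `ray.train.torch.prepare_data_loader` arranges this by adding a `DistributedSampler` when there is more than one worker, and it also moves batches to the right device. A loader that skips this step has every worker iterate over the full dataset. The sketch below reproduces the non-shuffled sharding rule of `torch.utils.data.DistributedSampler` in plain Python; `shard_indices` is a hypothetical helper.

```python
import math


def shard_indices(num_samples: int, rank: int, world_size: int) -> list:
    """Non-shuffled sharding in the style of torch.utils.data.DistributedSampler (drop_last=False)."""
    per_worker = math.ceil(num_samples / world_size)
    total = per_worker * world_size
    indices = list(range(num_samples))
    indices += indices[: total - num_samples]  # pad by wrapping so all workers get equal counts
    return indices[rank:total:world_size]       # every world_size-th index, offset by rank


shards = [shard_indices(60_000, rank, world_size=2) for rank in range(2)]
print([len(s) for s in shards])         # [30000, 30000]
print(set(shards[0]) & set(shards[1]))  # set()
```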

In [19]:
# Define a driver function that creates a Ray Train Trainer per Tune trial.
from ray.train.torch import TorchTrainer
from ray.tune.integration.ray_train import TuneReportCallback

def train_driver_fn(config):
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=config["train_loop_config"],
        # 2 workers x 1 GPU each = 2 GPUs per trial
        scaling_config=ray.train.ScalingConfig(num_workers=2, use_gpu=True),
        run_config=ray.train.RunConfig(
            name=f"train-trial_id={tune.get_context().get_trial_id()}",
            callbacks=[TuneReportCallback()],
        ),
    )
    trainer.fit()

In [20]:
# Launch multi-GPU Tune trials with Ray Train.
tuner = tune.Tuner(
    train_driver_fn,
    param_space={"train_loop_config": {"lr": tune.loguniform(1e-4, 1e-1), "batch_size": 128, "num_epochs": 3}},
    tune_config=tune.TuneConfig(num_samples=2, max_concurrent_trials=2),
)
results = tuner.fit()

| | |
| --- | --- |
| Current time | 2026-02-14 00:46:54 |
| Running for | 00:01:39.00 |
| Memory | 3.8/31.0 GiB |

| Trial name | status | loc | train_loop_config/lr | iter | total time (s) | loss |
| --- | --- | --- | --- | --- | --- | --- |
| train_driver_fn_6d54d_00000 | TERMINATED | 10.0.4.27:4542 | 0.00375019 | 3 | 90.6543 | 0.0641301 |
| train_driver_fn_6d54d_00001 | TERMINATED | 10.0.38.17:4537 | 0.000293321 | 3 | 90.7852 | 0.000826244 |


(TrainController pid=4603, ip=10.0.4.27) [State Transition] INITIALIZING -> SCHEDULING.
(TrainController pid=4603, ip=10.0.4.27) Attempting to start training worker group of size 2 with the following resources: [{'GPU': 1}] * 2
(RayTrainWorker pid=4854, ip=10.0.25.52) Setting up process group for: env:// [rank=0, world_size=2]
(TrainController pid=4598, ip=10.0.38.17) [State Transition] INITIALIZING -> SCHEDULING.
(TrainController pid=4598, ip=10.0.38.17) Attempting to start training worker group of size 2 with the following resources: [{'GPU': 1}] * 2
(TrainController pid=4603, ip=10.0.4.27) Started training worker group of size 2: 
(TrainController pid=4603, ip=10.0.4.27) - (ip=10.0.25.52, pid=4854) world_rank=0, local_rank=0, node_rank=0
(TrainController pid=4603, ip=10.0.4.27) - (ip=10.0.49.157, pid=4627) world_rank=1, local_rank=0, node_rank=1
(TrainController pid=4603, ip=10.0.4.27) [State Transition

Key details:
- **`TuneReportCallback`** propagates metrics reported by Ray Train workers back to Tune, so the `Tuner` can track and compare trial results.
- **`tune.get_context().get_trial_id()`** ensures each Train run gets a unique name tied to the Tune trial, which is required for proper fault tolerance.
- **`max_concurrent_trials`** limits how many Train runs compete for cluster resources at once. Set this based on your GPU budget (e.g., `total_gpus // gpus_per_trial`).

See the [Ray Train + Tune guide](https://docs.ray.io/en/latest/train/user-guides/hyperparameter-optimization.html) for full details.

## Summary and Next Steps

In this template, you learned:

- **What** Ray Tune is — a scalable, distributed hyperparameter tuning library
- **Why** to use it — parallel trial execution, smart search algorithms, early stopping, fault tolerance, and ecosystem integration
- **How** to use it — defining trainable functions with `tune.report()`, configuring `tune.Tuner` with search spaces and `TuneConfig`, running experiments with `tuner.fit()`, and retrieving best results
- **Core concepts** — resources (`tune.with_resources`), search algorithms (random, Optuna), schedulers (FIFO, ASHA)
- **Production features** — checkpointing, persistent storage, fault tolerance, experiment recovery

### Next Steps

1. **[Ray Tune User Guide](https://docs.ray.io/en/latest/tune/getting-started.html)** — Complete guide to Ray Tune
2. **[Search Algorithm Reference](https://docs.ray.io/en/latest/tune/api/suggestion.html)** — All supported search algorithms
3. **[Scheduler Reference](https://docs.ray.io/en/latest/tune/api/schedulers.html)** — All supported schedulers including ASHA and PBT
4. **[Ray Train + Tune Integration](https://docs.ray.io/en/latest/train/user-guides/hyperparameter-optimization.html)** — Combining distributed training with HPO
5. **[Tune Examples Gallery](https://docs.ray.io/en/latest/tune/examples/index.html)** — End-to-end examples with popular frameworks