[train] simplify TorchTrainer docstring #38049
@@ -59,6 +59,7 @@ Ray Train Config

    ~ray.train.DataConfig

.. _train-loop-api:

Ray Train Loop
--------------
@@ -16,217 +16,123 @@
class TorchTrainer(DataParallelTrainer):
"""A Trainer for data parallel PyTorch training.

This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
Actors. These actors already have the necessary torch process group
configured for distributed PyTorch training.
At a high level, this Trainer does the following:

The ``train_loop_per_worker`` function is expected to take in either 0 or 1
arguments:

1. Launches multiple workers as defined by the ``scaling_config``.
2. Sets up a distributed PyTorch environment
   on these workers as defined by the ``torch_config``.
3. Ingests the input ``datasets`` based on the ``dataset_config``.
4. Runs the input ``train_loop_per_worker(train_loop_config)``
   on all workers.

Review comment: Is this a user-defined function?
Reply: Yeah, it is.
.. testcode::

    def train_loop_per_worker():
        ...

.. testcode::

    from typing import Dict, Any

    def train_loop_per_worker(config: Dict[str, Any]):
        ...
If ``train_loop_per_worker`` accepts an argument, then
``train_loop_config`` will be passed in as the argument. This is useful if you
want to tune the values in ``train_loop_config`` as hyperparameters.
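For example, here is a minimal sketch (an illustration, not part of the
original docstring) of wiring hyperparameters through ``train_loop_config``;
the ``lr`` and ``batch_size`` keys are arbitrary names chosen for the example::

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_loop_per_worker(config):
        # Hyperparameters arrive as entries of the config dict.
        lr = config["lr"]
        batch_size = config["batch_size"]
        ...

    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"lr": 1e-3, "batch_size": 64},
        scaling_config=ScalingConfig(num_workers=2),
    )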
If the ``datasets`` dict contains a training dataset (denoted by
the "train" key), then it will be split into multiple dataset
shards that can then be accessed by ``train.get_dataset_shard("train")`` inside
``train_loop_per_worker``. All the other datasets will not be split and
``train.get_dataset_shard(...)`` will return the entire Dataset.
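As an illustration (a sketch for clarity, not taken from the docstring), each
worker can consume its shard of the "train" dataset like this::

    import torch
    from ray import train

    def train_loop_per_worker():
        # Each worker receives only its own shard of the "train" dataset.
        train_shard = train.get_dataset_shard("train")
        for batch in train_shard.iter_torch_batches(batch_size=32, dtypes=torch.float):
            ...  # batch is a dict mapping column names to torch.Tensor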
Inside the ``train_loop_per_worker`` function, you can use any of the
following methods:

.. testcode::

    from ray import train

    def train_loop_per_worker():
        # Report intermediate results for callbacks or logging and
        # checkpoint data.
        train.report(...)

        # Get the Dataset shard for the given key.
        train.get_dataset_shard("my_dataset")

        # Get the last saved checkpoint.
        train.get_checkpoint()

        # Get the total number of workers executing training.
        train.get_context().get_world_size()

        # Get the rank of this worker.
        train.get_context().get_world_rank()

        # Get the rank of the worker on the current node.
        train.get_context().get_local_rank()

You can also use any of the Torch-specific utility functions,
such as :func:`ray.train.torch.get_device` and :func:`ray.train.torch.prepare_model`:
.. testcode::

    def train_loop_per_worker():
        # Prepares the model for distributed training by wrapping it in
        # `DistributedDataParallel` and moving it to the correct device.
        train.torch.prepare_model(...)

        # Configures the dataloader for distributed training by adding a
        # `DistributedSampler`.
        # You should NOT use this if you are doing
        # `train.get_dataset_shard(...).iter_torch_batches(...)`.
        train.torch.prepare_data_loader(...)

        # Get the current torch device.
        train.torch.get_device()
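A slightly fuller sketch of these utilities (illustrative only; the toy model
and data are stand-ins) might look like::

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import ray.train.torch

    def train_loop_per_worker():
        model = torch.nn.Linear(1, 1)
        # Wraps the model in DistributedDataParallel and moves it to the right device.
        model = ray.train.torch.prepare_model(model)

        dataset = TensorDataset(torch.randn(8, 1), torch.randn(8, 1))
        loader = DataLoader(dataset, batch_size=4)
        # Adds a DistributedSampler and moves each batch to the right device.
        loader = ray.train.torch.prepare_data_loader(loader)

        for x, y in loader:
            ...  # forward/backward pass as usual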
Any values returned from ``train_loop_per_worker`` are discarded and not
used or persisted anywhere.

To save a model to use with the ``TorchPredictor``, you must save it under the
"model" kwarg in the ``Checkpoint`` passed to ``train.report()``.

.. note::
    When you wrap the ``model`` with ``prepare_model``, the keys of its
    ``state_dict`` are prefixed by ``module.``. For example,
    ``layer1.0.bn1.bias`` becomes ``module.layer1.0.bn1.bias``.
    However, when saving ``model`` through ``train.report()``,
    all ``module.`` prefixes are stripped.
    As a result, when you load from a saved checkpoint, make sure that
    you first load the ``state_dict`` into the model
    before calling ``prepare_model``.
    Otherwise, you will run into errors like
    ``Error(s) in loading state_dict for DistributedDataParallel:
    Missing key(s) in state_dict: "module.conv1.weight", ...``. See the snippet below.
.. testcode::

    from torchvision.models import resnet18
    from ray.train import Checkpoint
    import ray.train as train

    def train_func():
        ...
        model = resnet18()
        model = train.torch.prepare_model(model)
        for epoch in range(3):
            ...
            ckpt = Checkpoint.from_dict({
                "epoch": epoch,
                "model": model.state_dict(),
                # "model": model.module.state_dict(),
                # ** The above two are equivalent **
            })
            train.report({"foo": "bar"}, checkpoint=ckpt)
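The loading side of the note above might look like the following sketch (an
illustration that assumes the legacy dict-based ``Checkpoint`` API used in the
snippet, with the weights stored under a "model" key)::

    from torchvision.models import resnet18
    import ray.train as train

    def train_func():
        model = resnet18()
        checkpoint = train.get_checkpoint()
        if checkpoint:
            # Load weights into the bare model *before* prepare_model wraps it
            # in DistributedDataParallel and adds the ``module.`` prefix.
            model.load_state_dict(checkpoint.to_dict()["model"])
        model = train.torch.prepare_model(model)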
For more details, see the
:ref:`Distributed PyTorch User Guides <train-pytorch-overview>`.
Example:

.. testcode::

    import os
    import tempfile

    import torch
    import torch.nn as nn
    from torch import nn
    from torch.nn.parallel import DistributedDataParallel

    import ray
    from ray import train
    from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer

    # If using GPUs, set this to True.
    use_gpu = False
    # Number of processes to run training on.
    num_workers = 4

Review comment: If we're setting it to True by default, does it make sense to make the comment: "If not using GPUs, set this to False."?
Reply: Good catch.

Review comment: Does everyone know what a worker is, or should we add a comment about what this is for?
Reply: Added a comment.
    # Define NN layer architecture, epochs, and number of workers.
    input_size = 1
    layer_size = 32
    output_size = 1
    num_epochs = 20
    num_workers = 3
    # Define your network structure
    # Define your network structure.
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            self.layer1 = nn.Linear(input_size, layer_size)
            self.layer1 = nn.Linear(1, 32)
            self.relu = nn.ReLU()
            self.layer2 = nn.Linear(layer_size, output_size)
            self.layer2 = nn.Linear(32, 1)

        def forward(self, input):
            return self.layer2(self.relu(self.layer1(input)))

    # Define your train worker loop
    def train_loop_per_worker():
        torch.manual_seed(42)
    # Training loop.
    def train_loop_per_worker(config):

        # Read configurations.
        lr = config["lr"]
        batch_size = config["batch_size"]
        num_epochs = config["num_epochs"]

        # Fetch training set
        dataset_shard = train.get_dataset_shard("train")
        # Fetch training dataset.
        train_dataset_shard = ray.train.get_dataset_shard("train")

        # Instantiate and prepare model for training.
        model = NeuralNetwork()
        model = ray.train.torch.prepare_model(model)

        # Loss function, optimizer, prepare model for training.
        # This moves the data and prepares model for distributed
        # execution.
        # Define loss and optimizer.
        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=0.01,
                                     weight_decay=0.01)
        model = train.torch.prepare_model(model)

        # Iterate over epochs and batches
        for epoch in range(num_epochs):
            for batches in dataset_shard.iter_torch_batches(batch_size=32,
                                                            dtypes=torch.float):
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)

                # Add batch or unsqueeze as an additional dimension [32, x]
                inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
                output = model(inputs)
        # Create data loader.
        dataloader = train_dataset_shard.iter_torch_batches(
            batch_size=batch_size, dtypes=torch.float
        )

Review comment: Does everyone know what a data loader is?
Reply: Yeah, this is standard PyTorch terminology.
                # Make output shape same as the labels
                loss = loss_fn(output.squeeze(), labels)
        # Train multiple epochs.
        for epoch in range(num_epochs):

            # Train epoch.
            for batch in dataloader:
                output = model(batch["input"])
                loss = loss_fn(output, batch["label"])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Print the loss every 20 epochs
            if epoch % 20 == 0:
                print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")

            # Report and record metrics, checkpoint model at end of each
            # epoch
            train.report({"loss": loss.item(), "epoch": epoch},
                         checkpoint=Checkpoint.from_dict(
                             dict(epoch=epoch, model=model.state_dict())))
            # Create checkpoint.
            base_model = (model.module
                          if isinstance(model, DistributedDataParallel) else model)
            checkpoint_dir = tempfile.mkdtemp()
            torch.save(
                {"model_state_dict": base_model.state_dict()},
                os.path.join(checkpoint_dir, "model.pt"),
            )
            checkpoint = Checkpoint.from_directory(checkpoint_dir)

            # Report metrics and checkpoint.
            ray.train.report({"loss": loss.item()}, checkpoint=checkpoint)
    train_dataset = ray.data.from_items(
        [{"x": x, "y": 2 * x + 1} for x in range(2000)]
    )

    # Define scaling and run configs
    scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)
    # Define configurations.
    train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32}
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
    run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

    # Define datasets.
    train_dataset = ray.data.from_items(
        [{"input": [x], "label": [2 * x + 1]} for x in range(2000)]
    )
    datasets = {"train": train_dataset}

    # Initialize the Trainer.
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
        run_config=run_config,
        datasets={"train": train_dataset})
        datasets=datasets
    )

    # Train the model.
    result = trainer.fit()

    best_checkpoint_loss = result.metrics['loss']

    # Assert loss is less than 0.09
    assert best_checkpoint_loss <= 0.09
    # Inspect the results.
    final_loss = result.metrics["loss"]

.. testoutput::
    :hide:
@@ -235,24 +141,37 @@ def train_loop_per_worker():

Args:
    train_loop_per_worker: The training function to execute.
        This can either take in no arguments or a ``config`` dict.
    train_loop_config: Configurations to pass into
        ``train_loop_per_worker`` if it accepts an argument.
    torch_config: Configuration for setting up the PyTorch backend. If set to
        None, use the default configuration. This replaces the ``backend_config``
        arg of ``DataParallelTrainer``.
    scaling_config: Configuration for how to scale data parallel training.
    dataset_config: Configuration for dataset ingest.
    run_config: Configuration for the execution of the training run.
    datasets: Any Datasets to use for training. Use
        the key "train" to denote which dataset is the training
        dataset. If a ``preprocessor`` is provided and has not already been fit,
        it will be fit on the training dataset. All datasets will be transformed
        by the ``preprocessor`` if one is provided.
    preprocessor: A ``ray.data.Preprocessor`` to preprocess the
        provided datasets.
    train_loop_per_worker: The training function to execute on each worker.
        This function can either take in zero arguments or a single ``Dict``
        argument which is set by defining ``train_loop_config``.
        Within this function you can use any of the
        :ref:`Ray Train Loop utilities <train-loop-api>`.
    train_loop_config: A configuration ``Dict`` to pass in as an argument to
        ``train_loop_per_worker``.
        This is typically used for specifying hyperparameters.
    torch_config: The configuration for setting up the PyTorch Distributed backend.
        If set to None, a default configuration will be used in which
        GPU training uses NCCL and CPU training uses Gloo.
    scaling_config: The configuration for how to scale data parallel training.
        ``num_workers`` determines how many Python processes are used for training,
        and ``use_gpu`` determines whether or not each process should use GPUs.
        See :class:`~ray.air.ScalingConfig` for more info.
    run_config: The configuration for the execution of the training run.
        See :class:`~ray.air.RunConfig` for more info.
    datasets: The Ray Datasets to ingest for training.
        Datasets are keyed by name (``{name: dataset}``).
        Each dataset can be accessed from within the ``train_loop_per_worker``
        by calling ``ray.train.get_dataset_shard(name)``.
        Sharding and additional configuration can be done by
        passing in a ``dataset_config``.
    dataset_config: The configuration for ingesting the input ``datasets``.
        By default:

        - The ``"train"`` Dataset is split equally across workers.
        - All other Datasets are **not** split.
    resume_from_checkpoint: A checkpoint to resume training from.
        This checkpoint can be accessed from within ``train_loop_per_worker``
        by calling ``ray.train.get_checkpoint()``.
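As a sketch of how ``resume_from_checkpoint`` is typically consumed inside the
training loop (illustrative only; the ``model.pt`` file name matches the example
above and is not required by the API)::

    import os
    import torch
    import ray.train

    def train_loop_per_worker(config):
        checkpoint = ray.train.get_checkpoint()
        if checkpoint:
            # Restore the weights written by a previous run before continuing.
            with checkpoint.as_directory() as checkpoint_dir:
                # The "model_state_dict" key mirrors the example above.
                state = torch.load(os.path.join(checkpoint_dir, "model.pt"))
        ...

    # Passing a previous result's checkpoint resumes training from it:
    # trainer = TorchTrainer(..., resume_from_checkpoint=result.checkpoint)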
""" | ||
|
||
def __init__(

@@ -262,11 +181,12 @@ def __init__(
    train_loop_config: Optional[Dict] = None,
    torch_config: Optional[TorchConfig] = None,
    scaling_config: Optional[ScalingConfig] = None,
    dataset_config: Optional[DataConfig] = None,
    run_config: Optional[RunConfig] = None,
    datasets: Optional[Dict[str, GenDataset]] = None,
    preprocessor: Optional["Preprocessor"] = None,
    dataset_config: Optional[DataConfig] = None,
    resume_from_checkpoint: Optional[Checkpoint] = None,
    # Deprecated.
    preprocessor: Optional["Preprocessor"] = None,
):
    if not torch_config:
        torch_config = TorchConfig()
Review comment: Is this the same training for all torch-based trainers? Is it appropriate to mention that Lightning and HF also follow this pattern?
Reply: Yes, but will make this change in a future PR when this becomes the standard way!