
[train] simplify TorchTrainer docstring #38049

Merged · 8 commits · Aug 3, 2023
1 change: 1 addition & 0 deletions doc/source/train/api/api.rst
@@ -59,6 +59,7 @@ Ray Train Config

~ray.train.DataConfig

.. _train-loop-api:

Ray Train Loop
--------------
6 changes: 3 additions & 3 deletions python/ray/air/config.py
@@ -105,9 +105,9 @@ class ScalingConfig:
worker can be overridden with the ``resources_per_worker``
argument.
resources_per_worker: If specified, the resources
defined in this Dict will be reserved for each worker. The
``CPU`` and ``GPU`` keys (case-sensitive) can be defined to
override the number of CPU/GPUs used by each worker.
defined in this Dict are reserved for each worker.
Define the ``"CPU"`` and ``"GPU"`` keys (case-sensitive) to
override the number of CPUs or GPUs used by each worker.
placement_strategy: The placement strategy to use for the
placement group of the Ray actors. See :ref:`Placement Group
Strategies <pgroup-strategy>` for the possible options.
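
To make the ``resources_per_worker`` override above concrete, here is a minimal sketch; the worker count and resource numbers are arbitrary assumptions for illustration:

    from ray.train import ScalingConfig

    # Reserve 2 CPUs and 1 GPU for each of the 4 training workers.
    scaling_config = ScalingConfig(
        num_workers=4,
        use_gpu=True,
        resources_per_worker={"CPU": 2, "GPU": 1},
    )
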
286 changes: 103 additions & 183 deletions python/ray/train/torch/torch_trainer.py
@@ -16,217 +16,123 @@
class TorchTrainer(DataParallelTrainer):
"""A Trainer for data parallel PyTorch training.
Contributor: Is this the same training for all torch-based trainers? Is it appropriate to mention that Lightning and HF also follow this pattern?

Contributor Author: Yes, but will make this change in a future PR when this becomes the standard way!


This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
Actors. These actors already have the necessary torch process group
configured for distributed PyTorch training.
At a high level, this Trainer does the following:

The ``train_loop_per_worker`` function is expected to take in either 0 or 1
arguments:
1. Launches multiple workers as defined by the ``scaling_config``.
2. Sets up a distributed PyTorch environment
on these workers as defined by the ``torch_config``.
3. Ingests the input ``datasets`` based on the ``dataset_config``.
4. Runs the input ``train_loop_per_worker(train_loop_config)``
on all workers.

Contributor: Is this a user-defined function?

Contributor Author: Yeah it is.

.. testcode::

    def train_loop_per_worker():
        ...

.. testcode::

    from typing import Dict, Any

    def train_loop_per_worker(config: Dict[str, Any]):
        ...

If ``train_loop_per_worker`` accepts an argument, then
``train_loop_config`` will be passed in as the argument. This is useful if you
want to tune the values in ``train_loop_config`` as hyperparameters.
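
For example, a minimal sketch of how ``train_loop_config`` reaches the training function; the config keys here are arbitrary:

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_loop_per_worker(config):
        # `config` is exactly the dict passed as `train_loop_config`.
        lr = config["lr"]
        batch_size = config["batch_size"]
        ...

    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"lr": 1e-3, "batch_size": 64},
        scaling_config=ScalingConfig(num_workers=2),
    )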

If the ``datasets`` dict contains a training dataset (denoted by
the "train" key), then it will be split into multiple dataset
shards that can then be accessed by ``train.get_dataset_shard("train")`` inside
``train_loop_per_worker``. All the other datasets will not be split and
``train.get_dataset_shard(...)`` will return the entire Dataset.
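
For example, a sketch with one sharded and one unsharded dataset; the ``"valid"`` name and the toy data are assumptions made for illustration:

    import ray
    from ray import train
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_loop_per_worker():
        # Only this worker's shard of the "train" dataset.
        train_shard = train.get_dataset_shard("train")
        # The full "valid" dataset, since it is not split.
        valid_dataset = train.get_dataset_shard("valid")
        ...

    trainer = TorchTrainer(
        train_loop_per_worker,
        datasets={
            "train": ray.data.from_items([{"x": i, "y": 2 * i} for i in range(128)]),
            "valid": ray.data.from_items([{"x": i, "y": 2 * i} for i in range(16)]),
        },
        scaling_config=ScalingConfig(num_workers=2),
    )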

Inside the ``train_loop_per_worker`` function, you can use any of the
following methods:

.. testcode::

    from ray import train

    def train_loop_per_worker():
        # Report intermediate results for callbacks or logging and
        # checkpoint data.
        train.report(...)

        # Get the Dataset shard for the given key.
        train.get_dataset_shard("my_dataset")

        # Get dict of last saved checkpoint.
        train.get_checkpoint()

        # Get the total number of workers executing training.
        train.get_context().get_world_size()

        # Get the rank of this worker.
        train.get_context().get_world_rank()

        # Get the rank of the worker on the current node.
        train.get_context().get_local_rank()

You can also use any of the Torch-specific utility functions, such as
:func:`ray.train.torch.get_device` and :func:`ray.train.torch.prepare_model`:

.. testcode::

    def train_loop_per_worker():
        # Prepares the model for distributed training by wrapping it in
        # `DistributedDataParallel` and moving it to the correct device.
        train.torch.prepare_model(...)

        # Configures the dataloader for distributed training by adding a
        # `DistributedSampler`.
        # You should NOT use this if you are doing
        # `train.get_dataset_shard(...).iter_torch_batches(...)`
        train.torch.prepare_data_loader(...)

        # Get the current torch device.
        train.torch.get_device()

Any values returned from ``train_loop_per_worker`` are discarded and not
used or persisted anywhere.

To save a model for use with ``TorchPredictor``, you must save it under the
"model" kwarg in the ``Checkpoint`` passed to ``train.report()``.

.. note::
When you wrap the ``model`` with ``prepare_model``, the keys of its
``state_dict`` are prefixed by ``module.``. For example,
``layer1.0.bn1.bias`` becomes ``module.layer1.0.bn1.bias``.
However, when saving ``model`` through ``train.report()``
all ``module.`` prefixes are stripped.
As a result, when you load from a saved checkpoint, make sure that
you first load the ``state_dict`` into the model
before calling ``prepare_model``.
Otherwise, you will run into errors like
``Error(s) in loading state_dict for DistributedDataParallel:
Missing key(s) in state_dict: "module.conv1.weight", ...``. See snippet below.

.. testcode::

    from torchvision.models import resnet18
    from ray.train import Checkpoint
    import ray.train as train

    def train_func():
        ...
        model = resnet18()
        model = train.torch.prepare_model(model)
        for epoch in range(3):
            ...
            ckpt = Checkpoint.from_dict({
                "epoch": epoch,
                "model": model.state_dict(),
                # "model": model.module.state_dict(),
                # ** The above two are equivalent **
            })
            train.report({"foo": "bar"}, ckpt)
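
To make the load order concrete, here is a sketch of the recommended resume path, following the legacy dict-based ``Checkpoint`` API used in the snippet above; it assumes a checkpoint that stores a plain ``"model"`` ``state_dict`` as shown:

    from torchvision.models import resnet18
    import ray.train as train

    def train_func():
        model = resnet18()
        checkpoint = train.get_checkpoint()
        if checkpoint:
            # Load weights into the bare model *before* prepare_model,
            # so the un-prefixed keys in the saved state_dict match.
            model.load_state_dict(checkpoint.to_dict()["model"])
        # Only now wrap the model in DistributedDataParallel.
        model = train.torch.prepare_model(model)
        ...
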
For more details, see the
:ref:`Distributed PyTorch User Guides <train-pytorch-overview>`.

Example:

.. testcode::

import os
import tempfile

import torch
import torch.nn as nn
from torch import nn
from torch.nn.parallel import DistributedDataParallel

import ray
from ray import train
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

# If using GPUs, set this to True.
Contributor: If we're setting it to True by default, does it make sense to make the comment: "If not using GPUs, set this to False."?

Contributor Author: Good catch.

use_gpu = False
# Number of processes to run training on.
num_workers = 4
Contributor: Does everyone know what a worker is, or should we add a comment about what this is for?

Contributor Author: Added a comment.


# Define NN layer architecture, epochs, and number of workers
input_size = 1
layer_size = 32
output_size = 1
num_epochs = 20
num_workers = 3

# Define your network structure
# Define your network structure.
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.layer1 = nn.Linear(input_size, layer_size)
self.layer1 = nn.Linear(1, 32)
self.relu = nn.ReLU()
self.layer2 = nn.Linear(layer_size, output_size)
self.layer2 = nn.Linear(32, 1)

def forward(self, input):
return self.layer2(self.relu(self.layer1(input)))

# Define your train worker loop
def train_loop_per_worker():
torch.manual_seed(42)
# Training loop.
def train_loop_per_worker(config):

# Read configurations.
lr = config["lr"]
batch_size = config["batch_size"]
num_epochs = config["num_epochs"]

# Fetch training set
dataset_shard = train.get_dataset_shard("train")
# Fetch training dataset.
train_dataset_shard = ray.train.get_dataset_shard("train")

# Instantiate and prepare model for training.
model = NeuralNetwork()
model = ray.train.torch.prepare_model(model)

# Loss function, optimizer, prepare model for training.
# This moves the data and prepares model for distributed
# execution
# Define loss and optimizer.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),
lr=0.01,
weight_decay=0.01)
model = train.torch.prepare_model(model)

# Iterate over epochs and batches
for epoch in range(num_epochs):
for batches in dataset_shard.iter_torch_batches(batch_size=32,
dtypes=torch.float):
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Unsqueeze the inputs to add an extra dimension: [32] -> [32, 1]
inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
output = model(inputs)
# Create data loader.
Contributor: Does everyone know what a data loader is?

Contributor Author: Yeah this is standard PyTorch terminology.

dataloader = train_dataset_shard.iter_torch_batches(
batch_size=batch_size, dtypes=torch.float
)

# Make the output shape the same as the labels
loss = loss_fn(output.squeeze(), labels)
# Train multiple epochs.
for epoch in range(num_epochs):

# Zero out grads, do backward, and update optimizer
# Train epoch.
for batch in dataloader:
output = model(batch["input"])
loss = loss_fn(output, batch["label"])
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Print the loss every 20 epochs
if epoch % 20 == 0:
print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")

# Report and record metrics, checkpoint model at end of each
# epoch
train.report({"loss": loss.item(), "epoch": epoch},
checkpoint=Checkpoint.from_dict(
dict(epoch=epoch, model=model.state_dict()))
# Create checkpoint.
base_model = (model.module
if isinstance(model, DistributedDataParallel) else model)
checkpoint_dir = tempfile.mkdtemp()
torch.save(
{"model_state_dict": base_model.state_dict()},
os.path.join(checkpoint_dir, "model.pt"),
)
checkpoint = Checkpoint.from_directory(checkpoint_dir)

# Report metrics and checkpoint.
ray.train.report({"loss": loss.item()}, checkpoint=checkpoint)

train_dataset = ray.data.from_items(
[{"x": x, "y": 2 * x + 1} for x in range(2000)]
)

# Define scaling and run configs
scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)
# Define configurations.
train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

# Define datasets.
train_dataset = ray.data.from_items(
[{"input": [x], "label": [2 * x + 1]} for x in range(2000)]
)
datasets = {"train": train_dataset}

# Initialize the Trainer.
trainer = TorchTrainer(
train_loop_per_worker=train_loop_per_worker,
train_loop_config=train_loop_config,
scaling_config=scaling_config,
run_config=run_config,
datasets={"train": train_dataset})
datasets=datasets
)

# Train the model.
result = trainer.fit()

best_checkpoint_loss = result.metrics['loss']

# Assert the loss is less than 0.09
assert best_checkpoint_loss <= 0.09
# Inspect the results.
final_loss = result.metrics["loss"]

.. testoutput::
:hide:
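
As a follow-up to the example, a sketch of loading the reported checkpoint back from the ``Result``; it reuses ``result``, ``NeuralNetwork``, and the ``"model.pt"`` filename from the example above:

    import os
    import torch

    # Materialize the checkpoint locally and restore the model weights.
    with result.checkpoint.as_directory() as checkpoint_dir:
        state = torch.load(os.path.join(checkpoint_dir, "model.pt"))
        model = NeuralNetwork()
        model.load_state_dict(state["model_state_dict"])
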
@@ -235,24 +141,37 @@ def train_loop_per_worker():

Args:

train_loop_per_worker: The training function to execute.
This can either take in no arguments or a ``config`` dict.
train_loop_config: Configurations to pass into
``train_loop_per_worker`` if it accepts an argument.
torch_config: Configuration for setting up the PyTorch backend. If set to
None, use the default configuration. This replaces the ``backend_config``
arg of ``DataParallelTrainer``.
scaling_config: Configuration for how to scale data parallel training.
dataset_config: Configuration for dataset ingest.
run_config: Configuration for the execution of the training run.
datasets: Any Datasets to use for training. Use
the key "train" to denote which dataset is the training
dataset. If a ``preprocessor`` is provided and has not already been fit,
it will be fit on the training dataset. All datasets will be transformed
by the ``preprocessor`` if one is provided.
preprocessor: A ``ray.data.Preprocessor`` to preprocess the
provided datasets.
train_loop_per_worker: The training function to execute on each worker.
This function can either take in zero arguments or a single ``Dict``
argument which is set by defining ``train_loop_config``.
Within this function you can use any of the
:ref:`Ray Train Loop utilities <train-loop-api>`.
train_loop_config: A configuration ``Dict`` to pass in as an argument to
``train_loop_per_worker``.
This is typically used for specifying hyperparameters.
torch_config: The configuration for setting up the PyTorch Distributed backend.
If set to None, a default configuration will be used in which
GPU training uses NCCL and CPU training uses Gloo.
scaling_config: The configuration for how to scale data parallel training.
``num_workers`` determines how many Python processes are used for training,
and ``use_gpu`` determines whether or not each process should use GPUs.
See :class:`~ray.air.ScalingConfig` for more info.
run_config: The configuration for the execution of the training run.
See :class:`~ray.air.RunConfig` for more info.
datasets: The Ray Datasets to ingest for training.
Datasets are keyed by name (``{name: dataset}``).
Each dataset can be accessed from within the ``train_loop_per_worker``
by calling ``ray.train.get_dataset_shard(name)``.
Sharding and additional configuration can be done by
passing in a ``dataset_config``.
dataset_config: The configuration for ingesting the input ``datasets``.
By default:

- The ``"train"`` Dataset is split equally across workers.
- All other Datasets are **not** split.
resume_from_checkpoint: A checkpoint to resume training from.
This checkpoint can be accessed from within ``train_loop_per_worker``
by calling ``ray.train.get_checkpoint()``.
"""

def __init__(
Expand All @@ -262,11 +181,12 @@ def __init__(
train_loop_config: Optional[Dict] = None,
torch_config: Optional[TorchConfig] = None,
scaling_config: Optional[ScalingConfig] = None,
dataset_config: Optional[DataConfig] = None,
run_config: Optional[RunConfig] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
dataset_config: Optional[DataConfig] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
# Deprecated.
preprocessor: Optional["Preprocessor"] = None,
):
if not torch_config:
torch_config = TorchConfig()