diff --git a/doc/source/train/api/api.rst b/doc/source/train/api/api.rst index f007b176155fa..d89261a02843d 100644 --- a/doc/source/train/api/api.rst +++ b/doc/source/train/api/api.rst @@ -59,6 +59,7 @@ Ray Train Config ~ray.train.DataConfig +.. _train-loop-api: Ray Train Loop -------------- diff --git a/python/ray/air/config.py b/python/ray/air/config.py index 8e8061c26177d..7169a1616910c 100644 --- a/python/ray/air/config.py +++ b/python/ray/air/config.py @@ -105,9 +105,9 @@ class ScalingConfig: worker can be overridden with the ``resources_per_worker`` argument. resources_per_worker: If specified, the resources - defined in this Dict will be reserved for each worker. The - ``CPU`` and ``GPU`` keys (case-sensitive) can be defined to - override the number of CPU/GPUs used by each worker. + defined in this Dict is reserved for each worker. + Define the ``"CPU"`` and ``"GPU"`` keys (case-sensitive) to + override the number of CPU or GPUs used by each worker. placement_strategy: The placement strategy to use for the placement group of the Ray actors. See :ref:`Placement Group Strategies ` for the possible options. diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py index 1c0043483d3c3..7e7a957ff33e3 100644 --- a/python/ray/train/torch/torch_trainer.py +++ b/python/ray/train/torch/torch_trainer.py @@ -16,217 +16,123 @@ class TorchTrainer(DataParallelTrainer): """A Trainer for data parallel PyTorch training. - This Trainer runs the function ``train_loop_per_worker`` on multiple Ray - Actors. These actors already have the necessary torch process group - configured for distributed PyTorch training. + At a high level, this Trainer does the following: - The ``train_loop_per_worker`` function is expected to take in either 0 or 1 - arguments: + 1. Launches multiple workers as defined by the ``scaling_config``. + 2. Sets up a distributed PyTorch environment + on these workers as defined by the ``torch_config``. + 3. Ingests the input ``datasets`` based on the ``dataset_config``. + 4. Runs the input ``train_loop_per_worker(train_loop_config)`` + on all workers. - .. testcode:: - - def train_loop_per_worker(): - ... - - .. testcode:: - - from typing import Dict, Any - def train_loop_per_worker(config: Dict[str, Any]): - ... - - If ``train_loop_per_worker`` accepts an argument, then - ``train_loop_config`` will be passed in as the argument. This is useful if you - want to tune the values in ``train_loop_config`` as hyperparameters. - - If the ``datasets`` dict contains a training dataset (denoted by - the "train" key), then it will be split into multiple dataset - shards that can then be accessed by ``train.get_dataset_shard("train")`` inside - ``train_loop_per_worker``. All the other datasets will not be split and - ``train.get_dataset_shard(...)`` will return the the entire Dataset. - - Inside the ``train_loop_per_worker`` function, you can use any of the - following methods: - - .. testcode:: - - from ray import train - - def train_loop_per_worker(): - # Report intermediate results for callbacks or logging and - # checkpoint data. - train.report(...) - - # Get the Dataset shard for the given key. - train.get_dataset_shard("my_dataset") - - # Get dict of last saved checkpoint. - train.get_checkpoint() - - # Get the total number of workers executing training. - train.get_context().get_world_size() - - # Get the rank of this worker. - train.get_context().get_world_rank() - - # Get the rank of the worker on the current node. 
- train.get_context().get_local_rank() - - You can also use any of the Torch specific function utils, - such as :func:`ray.train.torch.get_device` and :func:`ray.train.torch.prepare_model` - - .. testcode:: - - def train_loop_per_worker(): - # Prepares model for distribted training by wrapping in - # `DistributedDataParallel` and moving to correct device. - train.torch.prepare_model(...) - - # Configures the dataloader for distributed training by adding a - # `DistributedSampler`. - # You should NOT use this if you are doing - # `train.get_dataset_shard(...).iter_torch_batches(...)` - train.torch.prepare_data_loader(...) - - # Get the current torch device. - train.torch.get_device() - - Any returns from the ``train_loop_per_worker`` will be discarded and not - used or persisted anywhere. - - To save a model to use for the ``TorchPredictor``, you must save it under the - "model" kwarg in ``Checkpoint`` passed to ``train.report()``. - - .. note:: - When you wrap the ``model`` with ``prepare_model``, the keys of its - ``state_dict`` are prefixed by ``module.``. For example, - ``layer1.0.bn1.bias`` becomes ``module.layer1.0.bn1.bias``. - However, when saving ``model`` through ``train.report()`` - all ``module.`` prefixes are stripped. - As a result, when you load from a saved checkpoint, make sure that - you first load ``state_dict`` to the model - before calling ``prepare_model``. - Otherwise, you will run into errors like - ``Error(s) in loading state_dict for DistributedDataParallel: - Missing key(s) in state_dict: "module.conv1.weight", ...``. See snippet below. - - .. testcode:: - - from torchvision.models import resnet18 - from ray.train import Checkpoint - import ray.train as train - - def train_func(): - ... - model = resnet18() - model = train.torch.prepare_model(model) - for epoch in range(3): - ... - ckpt = Checkpoint.from_dict({ - "epoch": epoch, - "model": model.state_dict(), - # "model": model.module.state_dict(), - # ** The above two are equivalent ** - }) - train.report({"foo": "bar"}, ckpt) + For more details, see the + :ref:`Distributed PyTorch User Guides `. Example: .. testcode:: + import os + import tempfile + import torch - import torch.nn as nn + from torch import nn + from torch.nn.parallel import DistributedDataParallel import ray - from ray import train from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig from ray.train.torch import TorchTrainer # If using GPUs, set this to True. use_gpu = False + # Number of processes to run training on. + num_workers = 4 - # Define NN layers archicture, epochs, and number of workers - input_size = 1 - layer_size = 32 - output_size = 1 - num_epochs = 20 - num_workers = 3 - - # Define your network structure + # Define your network structure. class NeuralNetwork(nn.Module): def __init__(self): super(NeuralNetwork, self).__init__() - self.layer1 = nn.Linear(input_size, layer_size) + self.layer1 = nn.Linear(1, 32) self.relu = nn.ReLU() - self.layer2 = nn.Linear(layer_size, output_size) + self.layer2 = nn.Linear(32, 1) def forward(self, input): return self.layer2(self.relu(self.layer1(input))) - # Define your train worker loop - def train_loop_per_worker(): - torch.manual_seed(42) + # Training loop. + def train_loop_per_worker(config): + + # Read configurations. + lr = config["lr"] + batch_size = config["batch_size"] + num_epochs = config["num_epochs"] - # Fetch training set - dataset_shard = train.get_dataset_shard("train") + # Fetch training dataset. 
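+                # ``get_dataset_shard`` returns only this worker's split of the
+                # "train" dataset, which is split equally across workers by default.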
+ train_dataset_shard = ray.train.get_dataset_shard("train") + + # Instantiate and prepare model for training. model = NeuralNetwork() + model = ray.train.torch.prepare_model(model) - # Loss function, optimizer, prepare model for training. - # This moves the data and prepares model for distributed - # execution + # Define loss and optimizer. loss_fn = nn.MSELoss() - optimizer = torch.optim.Adam(model.parameters(), - lr=0.01, - weight_decay=0.01) - model = train.torch.prepare_model(model) - - # Iterate over epochs and batches - for epoch in range(num_epochs): - for batches in dataset_shard.iter_torch_batches(batch_size=32, - dtypes=torch.float): + optimizer = torch.optim.SGD(model.parameters(), lr=lr) - # Add batch or unsqueeze as an additional dimension [32, x] - inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"] - output = model(inputs) + # Create data loader. + dataloader = train_dataset_shard.iter_torch_batches( + batch_size=batch_size, dtypes=torch.float + ) - # Make output shape same as the as labels - loss = loss_fn(output.squeeze(), labels) + # Train multiple epochs. + for epoch in range(num_epochs): - # Zero out grads, do backward, and update optimizer + # Train epoch. + for batch in dataloader: + output = model(batch["input"]) + loss = loss_fn(output, batch["label"]) optimizer.zero_grad() loss.backward() optimizer.step() - # Print what's happening with loss per 30 epochs - if epoch % 20 == 0: - print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}") - - # Report and record metrics, checkpoint model at end of each - # epoch - train.report({"loss": loss.item(), "epoch": epoch}, - checkpoint=Checkpoint.from_dict( - dict(epoch=epoch, model=model.state_dict())) + # Create checkpoint. + base_model = (model.module + if isinstance(model, DistributedDataParallel) else model) + checkpoint_dir = tempfile.mkdtemp() + torch.save( + {"model_state_dict": base_model.state_dict()}, + os.path.join(checkpoint_dir, "model.pt"), ) + checkpoint = Checkpoint.from_directory(checkpoint_dir) + + # Report metrics and checkpoint. + ray.train.report({"loss": loss.item()}, checkpoint=checkpoint) - train_dataset = ray.data.from_items( - [{"x": x, "y": 2 * x + 1} for x in range(2000)] - ) - # Define scaling and run configs - scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu) + # Define configurations. + train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32} + scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu) run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1)) + # Define datasets. + train_dataset = ray.data.from_items( + [{"input": [x], "label": [2 * x + 1]} for x in range(2000)] + ) + datasets = {"train": train_dataset} + + # Initialize the Trainer. trainer = TorchTrainer( train_loop_per_worker=train_loop_per_worker, + train_loop_config=train_loop_config, scaling_config=scaling_config, run_config=run_config, - datasets={"train": train_dataset}) + datasets=datasets + ) + # Train the model. result = trainer.fit() - best_checkpoint_loss = result.metrics['loss'] - - # Assert loss is less 0.09 - assert best_checkpoint_loss <= 0.09 + # Inspect the results. + final_loss = result.metrics["loss"] .. testoutput:: :hide: @@ -235,24 +141,37 @@ def train_loop_per_worker(): Args: - train_loop_per_worker: The training function to execute. - This can either take in no arguments or a ``config`` dict. - train_loop_config: Configurations to pass into - ``train_loop_per_worker`` if it accepts an argument. 
- torch_config: Configuration for setting up the PyTorch backend. If set to - None, use the default configuration. This replaces the ``backend_config`` - arg of ``DataParallelTrainer``. - scaling_config: Configuration for how to scale data parallel training. - dataset_config: Configuration for dataset ingest. - run_config: Configuration for the execution of the training run. - datasets: Any Datasets to use for training. Use - the key "train" to denote which dataset is the training - dataset. If a ``preprocessor`` is provided and has not already been fit, - it will be fit on the training dataset. All datasets will be transformed - by the ``preprocessor`` if one is provided. - preprocessor: A ``ray.data.Preprocessor`` to preprocess the - provided datasets. + train_loop_per_worker: The training function to execute on each worker. + This function can either take in zero arguments or a single ``Dict`` + argument which is set by defining ``train_loop_config``. + Within this function you can use any of the + :ref:`Ray Train Loop utilities `. + train_loop_config: A configuration ``Dict`` to pass in as an argument to + ``train_loop_per_worker``. + This is typically used for specifying hyperparameters. + torch_config: The configuration for setting up the PyTorch Distributed backend. + If set to None, a default configuration will be used in which + GPU training uses NCCL and CPU training uses Gloo. + scaling_config: The configuration for how to scale data parallel training. + ``num_workers`` determines how many Python processes are used for training, + and ``use_gpu`` determines whether or not each process should use GPUs. + See :class:`~ray.air.ScalingConfig` for more info. + run_config: The configuration for the execution of the training run. + See :class:`~ray.air.RunConfig` for more info. + datasets: The Ray Datasets to ingest for training. + Datasets are keyed by name (``{name: dataset}``). + Each dataset can be accessed from within the ``train_loop_per_worker`` + by calling ``ray.train.get_dataset_shard(name)``. + Sharding and additional configuration can be done by + passing in a ``dataset_config``. + dataset_config: The configuration for ingesting the input ``datasets``. + By default: + + - The ``"train"`` Dataset is split equally across workers. + - All other Datasets are **not** split. resume_from_checkpoint: A checkpoint to resume training from. + This checkpoint can be accessed from within ``train_loop_per_worker`` + by calling ``ray.train.get_checkpoint()``. """ def __init__( @@ -262,11 +181,12 @@ def __init__( train_loop_config: Optional[Dict] = None, torch_config: Optional[TorchConfig] = None, scaling_config: Optional[ScalingConfig] = None, - dataset_config: Optional[DataConfig] = None, run_config: Optional[RunConfig] = None, datasets: Optional[Dict[str, GenDataset]] = None, - preprocessor: Optional["Preprocessor"] = None, + dataset_config: Optional[DataConfig] = None, resume_from_checkpoint: Optional[Checkpoint] = None, + # Deprecated. + preprocessor: Optional["Preprocessor"] = None, ): if not torch_config: torch_config = TorchConfig()
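
A minimal end-to-end sketch of the reordered constructor above (illustrative only: the two-worker sizing, the toy dataset, the placeholder loop body, and the ``final_metrics`` name are assumptions rather than part of this diff, and running it presumes a local Ray installation with at least five CPUs). It exercises the documented keyword arguments together with the ``resources_per_worker`` override described in the ``ScalingConfig`` hunk::

    import ray
    from ray.train import RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer


    def train_loop_per_worker(config):
        # Each worker iterates over its own equally sized split of "train".
        shard = ray.train.get_dataset_shard("train")
        for _ in range(config["num_epochs"]):
            for _batch in shard.iter_torch_batches(batch_size=config["batch_size"]):
                pass  # Placeholder for the real forward/backward pass.
        # Report a final metric so ``result.metrics`` is populated.
        ray.train.report({"num_epochs": config["num_epochs"]})


    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"num_epochs": 1, "batch_size": 4},
        scaling_config=ScalingConfig(
            num_workers=2,
            use_gpu=False,
            # Reserve two CPUs per worker instead of the default of one.
            resources_per_worker={"CPU": 2},
        ),
        run_config=RunConfig(),
        # Only "train" is provided, so the default ``dataset_config`` applies.
        datasets={"train": ray.data.from_items([{"x": [float(i)]} for i in range(32)])},
    )
    result = trainer.fit()
    final_metrics = result.metrics

``dataset_config`` and ``resume_from_checkpoint`` are left at their defaults here; passing a checkpoint makes it available inside ``train_loop_per_worker`` through ``ray.train.get_checkpoint()``.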