[cherry-pick][Doc][Train] Improve torch, lightning quickstarts and migration guides + fix torch restoration example #41843

Merged · 2 commits · Dec 14, 2023
4 changes: 2 additions & 2 deletions doc/source/train/distributed-tensorflow-keras.rst
@@ -1,7 +1,7 @@
.. _train-tensorflow-overview:

- Get Started with TensorFlow and Keras
- =====================================
+ Get Started with Distributed Training using TensorFlow/Keras
+ ============================================================

Ray Train's `TensorFlow <https://www.tensorflow.org/>`__ integration enables you
to scale your TensorFlow and Keras training functions to many machines and GPUs.
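Not part of this diff: for context on what the retitled guide covers, a minimal sketch of scaling a Keras training function with Ray Train's TensorflowTrainer. The model, data, and worker count below are placeholders, not taken from the PR.

import ray.train
from ray.train.tensorflow import TensorflowTrainer


def train_func():
    import numpy as np
    import tensorflow as tf

    # Ray Train sets up TF_CONFIG on each worker, so
    # MultiWorkerMirroredStrategy discovers the cluster automatically.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = tf.keras.Sequential(
            [tf.keras.layers.Dense(16, activation="relu"), tf.keras.layers.Dense(1)]
        )
        model.compile(optimizer="adam", loss="mse")

    # Toy data purely for illustration; replace with your own pipeline.
    x = np.random.rand(1024, 8).astype("float32")
    y = np.random.rand(1024, 1).astype("float32")
    model.fit(x, y, epochs=2, batch_size=64)


trainer = TensorflowTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()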
4 changes: 2 additions & 2 deletions doc/source/train/distributed-xgboost-lightgbm.rst
@@ -1,7 +1,7 @@
.. _train-gbdt-guide:

- Get Started with XGBoost and LightGBM
- =====================================
+ Get Started with Distributed Training using XGBoost and LightGBM
+ ================================================================

Ray Train has built-in support for XGBoost and LightGBM.

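Likewise not part of this diff: a rough sketch of what that built-in XGBoost support looks like in practice, assuming a tiny synthetic dataset built with Ray Data.

import random

import ray.data
import ray.train
from ray.train.xgboost import XGBoostTrainer

# Tiny synthetic binary-classification dataset, purely for illustration.
rows = []
for _ in range(1000):
    x0, x1 = random.random(), random.random()
    rows.append({"x0": x0, "x1": x1, "target": int(x0 + x1 > 1.0)})
train_ds = ray.data.from_items(rows)

trainer = XGBoostTrainer(
    label_column="target",
    params={"objective": "binary:logistic", "eval_metric": ["logloss", "error"]},
    datasets={"train": train_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()
print(result.metrics)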
14 changes: 8 additions & 6 deletions doc/source/train/doc_code/checkpoints.py
@@ -62,7 +62,6 @@ def train_func(config):
result = trainer.fit()
# __pytorch_save_end__


# __pytorch_restore_start__
import os
import tempfile
@@ -90,12 +89,19 @@ def train_func(config):
optimizer = Adam(model.parameters(), lr=3e-4)
criterion = nn.MSELoss()

+ # Wrap the model in DDP and move it to GPU.
+ model = ray.train.torch.prepare_model(model)
+
# ====== Resume training state from the checkpoint. ======
start_epoch = 0
checkpoint = train.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
- model.load_state_dict(torch.load(os.path.join(checkpoint_dir, "model.pt")))
+ model_state_dict = torch.load(
+     os.path.join(checkpoint_dir, "model.pt"),
+     # map_location=...,  # Load onto a different device if needed.
+ )
+ model.module.load_state_dict(model_state_dict)
optimizer.load_state_dict(
torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
)
@@ -104,9 +110,6 @@
)
# ========================================================

- # Wrap the model in DDP
- model = ray.train.torch.prepare_model(model)
-
for epoch in range(start_epoch, config["num_epochs"]):
y = model.forward(X)
loss = criterion(y, Y)
@@ -162,7 +165,6 @@
)
# __pytorch_restore_end__


# __checkpoint_from_single_worker_start__
import tempfile

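Taken together, the change to checkpoints.py enforces this restore ordering: wrap the model with prepare_model first, then load the checkpointed state into the underlying module. A condensed sketch of that pattern, assuming a multi-worker run where prepare_model wraps the model in DDP; build_model is a placeholder, not part of the diff.

import os

import torch

import ray.train
import ray.train.torch


def train_func(config):
    model = build_model()  # Placeholder for your model constructor.
    # Wraps the model in DDP (for multi-worker runs) and moves it to this
    # worker's device, so do this before restoring any state.
    model = ray.train.torch.prepare_model(model)

    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            state = torch.load(
                os.path.join(checkpoint_dir, "model.pt"),
                map_location="cpu",  # Remap if the checkpoint was saved on another device.
            )
            # prepare_model returned a DDP wrapper, so the saved state dict
            # keys match model.module rather than the wrapper itself.
            model.module.load_state_dict(state)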
91 changes: 62 additions & 29 deletions doc/source/train/getting-started-pytorch-lightning.rst
@@ -1,7 +1,7 @@
.. _train-pytorch-lightning:

- Get Started with PyTorch Lightning
- ==================================
+ Get Started with Distributed Training using PyTorch Lightning
+ =============================================================

This tutorial walks through the process of converting an existing PyTorch Lightning script to use Ray Train.

@@ -57,7 +57,9 @@ Compare a PyTorch Lightning training script with and without Ray Train.
def __init__(self):
super(ImageClassifier, self).__init__()
self.model = resnet18(num_classes=10)
- self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+ self.model.conv1 = torch.nn.Conv2d(
+     1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
+ )
self.criterion = torch.nn.CrossEntropyLoss()

def forward(self, x):
@@ -84,29 +86,32 @@ Compare a PyTorch Lightning training script with and without Ray Train.
trainer.fit(model, train_dataloaders=train_dataloader)



.. tab-item:: PyTorch Lightning + Ray Train

.. code-block:: python
- :emphasize-lines: 8-10, 34, 43, 48-50, 52, 53, 55-60
+ :emphasize-lines: 11-12, 38, 52-57, 59, 63, 66-73

+ import os
+ import tempfile
+
import torch
+ from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
- from torch.utils.data import DataLoader
import lightning.pytorch as pl

- from ray.train.torch import TorchTrainer
- from ray.train import ScalingConfig
import ray.train.lightning
+ from ray.train.torch import TorchTrainer

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
def __init__(self):
super(ImageClassifier, self).__init__()
self.model = resnet18(num_classes=10)
- self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+ self.model.conv1 = torch.nn.Conv2d(
+     1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
+ )
self.criterion = torch.nn.CrossEntropyLoss()

def forward(self, x):
@@ -124,10 +129,10 @@ Compare a PyTorch Lightning training script with and without Ray Train.


def train_func(config):

# Data
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
- train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
+ data_dir = os.path.join(tempfile.gettempdir(), "data")
+ train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

# Training
@@ -140,16 +145,34 @@ Compare a PyTorch Lightning training script with and without Ray Train.
strategy=ray.train.lightning.RayDDPStrategy(),
plugins=[ray.train.lightning.RayLightningEnvironment()],
callbacks=[ray.train.lightning.RayTrainReportCallback()],
+ # [1a] Optionally, disable the default checkpointing behavior
+ # in favor of the `RayTrainReportCallback` above.
+ enable_checkpointing=False,
)
trainer = ray.train.lightning.prepare_trainer(trainer)
trainer.fit(model, train_dataloaders=train_dataloader)

# [2] Configure scaling and resource requirements.
- scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
+ scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [3] Launch distributed training job.
- trainer = TorchTrainer(train_func, scaling_config=scaling_config)
- result = trainer.fit()
+ trainer = TorchTrainer(
+     train_func,
+     scaling_config=scaling_config,
+     # [3a] If running in a multi-node cluster, this is where you
+     # should configure the run's persistent storage.
+     # run_config=ray.train.RunConfig(storage_path="s3://..."),
+ )
+ result: ray.train.Result = trainer.fit()

+ # [4] Load the trained model.
+ with result.checkpoint.as_directory() as checkpoint_dir:
+     model = ImageClassifier.load_from_checkpoint(
+         os.path.join(
+             checkpoint_dir,
+             ray.train.lightning.RayTrainReportCallback.CHECKPOINT_NAME,
+         ),
+     )


Set up a training function
@@ -183,7 +206,7 @@ make a few changes to your Lightning Trainer definition.
datamodule = MyLightningDataModule(...)

trainer = pl.Trainer(
- - devices=[0,1,2,3],
+ - devices=[0, 1, 2, 3],
- strategy=DDPStrategy(),
- plugins=[LightningEnvironment()],
+ devices="auto",
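The hunk above is truncated in this view; assembled, the migrated Trainer construction the guide describes looks roughly like the sketch below. MyLightningModule and MyLightningDataModule are placeholders from the surrounding docs, and this code is illustrative rather than part of the PR.

import lightning.pytorch as pl

import ray.train.lightning


def train_func(config):
    model = MyLightningModule(...)           # Placeholder LightningModule.
    datamodule = MyLightningDataModule(...)  # Placeholder LightningDataModule.

    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        callbacks=[ray.train.lightning.RayTrainReportCallback()],
        enable_checkpointing=False,
    )
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, datamodule=datamodule)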
@@ -429,7 +452,7 @@ control over their native Lightning code.

config_builder = LightningConfigBuilder()
# [1] Collect model configs
- config_builder.module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)
+ config_builder.module(cls=MyLightningModule, lr=1e-3, feature_dim=128)

# [2] Collect checkpointing configs
config_builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)
@@ -439,11 +462,10 @@
max_epochs=10,
accelerator="gpu",
log_every_n_steps=100,
- logger=CSVLogger("./logs"),
)

# [4] Build datasets on the head node
- datamodule = MNISTDataModule(batch_size=32)
+ datamodule = MyLightningDataModule(batch_size=32)
config_builder.fit_params(datamodule=datamodule)

# [5] Execute the internal training function in a black box
@@ -458,7 +480,11 @@
),
)
)
- ray_trainer.fit()
+ result = ray_trainer.fit()
+
+ # [6] Load the trained model from an opaque Lightning-specific checkpoint.
+ lightning_checkpoint = result.checkpoint
+ model = lightning_checkpoint.get_model(MyLightningModule)



@@ -469,8 +495,11 @@
.. testcode::
:skipif: True

+ import os
+
import lightning.pytorch as pl
- from ray.air import CheckpointConfig, RunConfig
+
+ import ray.train
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
RayDDPStrategy,
@@ -481,18 +510,15 @@

def train_func(config):
# [1] Create a Lightning model
- model = MNISTClassifier(lr=1e-3, feature_dim=128)
+ model = MyLightningModule(lr=1e-3, feature_dim=128)

# [2] Report Checkpoint with callback
ckpt_report_callback = RayTrainReportCallback()

# [3] Create a Lighting Trainer
- datamodule = MNISTDataModule(batch_size=32)
-
trainer = pl.Trainer(
max_epochs=10,
log_every_n_steps=100,
- logger=CSVLogger("./logs"),
# New configurations below
devices="auto",
accelerator="auto",
@@ -505,19 +531,26 @@
trainer = prepare_trainer(trainer)

# [4] Build your datasets on each worker
- datamodule = MNISTDataModule(batch_size=32)
+ datamodule = MyLightningDataModule(batch_size=32)
trainer.fit(model, datamodule=datamodule)

# [5] Explicitly define and run the training function
ray_trainer = TorchTrainer(
train_func,
- scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
- run_config=RunConfig(
-     checkpoint_config=CheckpointConfig(
+ scaling_config=ray.train.ScalingConfig(num_workers=4, use_gpu=True),
+ run_config=ray.train.RunConfig(
+     checkpoint_config=ray.train.CheckpointConfig(
num_to_keep=3,
checkpoint_score_attribute="val_accuracy",
checkpoint_score_order="max",
),
)
)
- ray_trainer.fit()
+ result = ray_trainer.fit()
+
+ # [6] Load the trained model from a simplified checkpoint interface.
+ checkpoint: ray.train.Checkpoint = result.checkpoint
+ with checkpoint.as_directory() as checkpoint_dir:
+     print("Checkpoint contents:", os.listdir(checkpoint_dir))
+     checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ckpt")
+     model = MyLightningModule.load_from_checkpoint(checkpoint_path)