In [1]:
import ray
from ray.experimental.client2.client import Client as Client2
# If a client is currently registered as active (e.g. because this cell was
# run before in the same kernel), disconnect it before opening a new one.
if Client2.active_client is not None:
    Client2.active_client.disconnect()
# Connect to the address below on the channel named "torch_fashion_mnist5".
# The runtime_env requests torch (pinned to 1.13.0) and torchvision
# (unpinned) via pip.
# NOTE(review): port 8265 is presumably the Ray dashboard/job endpoint — confirm.
client = Client2("http://localhost:8265", "torch_fashion_mnist5", runtime_env={"pip":["torch==1.13.0", "torchvision"]})

2023-11-20 21:39:26,009	INFO client.py:205 -- client2 channel torch_fashion_mnist5 connected!


In [2]:
# https://docs.ray.io/en/releases-2.7.0/train/examples/pytorch/torch_fashion_mnist_example.html#torch-fashion-mnist-ex
# This is an easy adaptation: because most, if not all, of the code is already inside functions, we can easily move it to run remotely.

import os
from filelock import FileLock
from typing import Dict



def get_dataloaders(batch_size):
    """Build the FashionMNIST train and test data loaders.

    Both splits are downloaded (if needed) into ``~/data`` while a file lock
    at ``~/data.lock`` is held, and are then wrapped in ``DataLoader``
    objects that use the requested batch size.

    Args:
        batch_size: Number of samples per batch, used for both loaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    # These imports run at call time, so torch/torchvision only need to be
    # importable in the process that actually executes this function.
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    from torchvision.transforms import ToTensor, Normalize

    data_root = "~/data"
    lock_path = os.path.expanduser("~/data.lock")

    # Convert images to tensors, then normalize with mean 0.5 and std 0.5.
    preprocess = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

    def load_split(is_train):
        # Fetch one FashionMNIST split, downloading it if it is not on disk.
        return datasets.FashionMNIST(
            root=data_root,
            train=is_train,
            download=True,
            transform=preprocess,
        )

    # Hold the lock while both splits are fetched, presumably so concurrent
    # callers on the same machine do not race on the download.
    with FileLock(lock_path):
        training_data = load_split(True)
        test_data = load_split(False)

    # Wrap each split in a loader, keeping the (train, test) order.
    loaders = tuple(
        DataLoader(split, batch_size=batch_size)
        for split in (training_data, test_data)
    )
    return loaders


def train_func_per_worker(config: Dict):
    """Train and evaluate a small MLP on FashionMNIST inside one training worker.

    For every epoch the model is trained on the worker's training loader,
    evaluated on its test loader, and the resulting metrics are reported to
    Ray Train.

    Args:
        config: Training hyperparameters. Must contain the keys ``"lr"``,
            ``"epochs"`` and ``"batch_size_per_worker"``.

    Returns:
        None. Metrics (``"loss"`` and ``"accuracy"``) are delivered through
        ``ray.train.report`` once per epoch.

    Raises:
        KeyError: If one of the required keys is missing from ``config``.
    """
    import torch
    from torch import nn
    from tqdm import tqdm

    import ray.train
    # `import ray.train` alone does not guarantee that the `ray.train.torch`
    # submodule is loaded in this process; import it explicitly so the
    # `ray.train.torch.prepare_*` calls below do not depend on some other
    # code having imported it first.
    import ray.train.torch

    # Model Definition
    class NeuralNetwork(nn.Module):
        """Three-layer fully connected classifier for 28x28 inputs, 10 outputs."""

        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28 * 28, 512),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(512, 10),
                # NOTE(review): this trailing ReLU clamps the logits to be
                # non-negative before CrossEntropyLoss. Kept unchanged to
                # preserve existing behavior — confirm it is intended.
                nn.ReLU(),
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits

    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    # Get dataloaders inside worker training function
    train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)

    # [1] Prepare Dataloader for distributed training
    # Shard the datasets among workers and move batches to the correct device
    # =======================================================================
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)

    model = NeuralNetwork()

    # [2] Prepare and wrap your model with DistributedDataParallel
    # Move the model the correct GPU/CPU device
    # ============================================================
    model = ray.train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Model training loop
    for epoch in range(epochs):
        # Training pass: enable dropout and update the weights batch by batch.
        model.train()
        for X, y in tqdm(train_dataloader, desc=f"Train Epoch {epoch}"):
            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluation pass: disable dropout and gradient tracking.
        model.eval()
        test_loss, num_correct, num_total = 0, 0, 0
        with torch.no_grad():
            for X, y in tqdm(test_dataloader, desc=f"Test Epoch {epoch}"):
                pred = model(X)
                loss = loss_fn(pred, y)

                test_loss += loss.item()
                num_total += y.shape[0]
                num_correct += (pred.argmax(1) == y).sum().item()

        # Average the per-batch losses over the number of test batches.
        test_loss /= len(test_dataloader)
        accuracy = num_correct / num_total

        # [3] Report metrics to Ray Train
        # ===============================
        ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})

@ray.remote
def train_fashion_mnist(num_workers=2, use_gpu=False):
    """Run the FashionMNIST training job with a Ray ``TorchTrainer``.

    Args:
        num_workers: Number of training workers. Must be at least 1 and at
            most the global batch size (32) so that every worker gets a
            per-worker batch size of at least 1.
        use_gpu: Forwarded to ``ScalingConfig(use_gpu=...)``.

    Returns:
        The object returned by ``TorchTrainer.fit()``.

    Raises:
        ValueError: If ``num_workers`` is outside ``[1, 32]``.
    """
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    global_batch_size = 32

    # Fail fast: num_workers == 0 would divide by zero below, and
    # num_workers > global_batch_size would give a per-worker batch size of
    # 0, which would only fail later when the workers build their DataLoader.
    if not 1 <= num_workers <= global_batch_size:
        raise ValueError(
            f"num_workers must be between 1 and {global_batch_size}, "
            f"got {num_workers!r}"
        )

    train_config = {
        "lr": 1e-3,
        "epochs": 10,
        # Floor division: when num_workers does not divide the global batch
        # size evenly, the effective global batch is slightly smaller.
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        scaling_config=scaling_config,
    )

    # [4] Start Distributed Training
    # Run `train_func_per_worker` on all workers
    # =============================================
    result = trainer.fit()
    print(f"Training result: {result}")
    return result


In [3]:
# Submit train_fashion_mnist through the client with one worker and no GPU.
o = client(train_fashion_mnist).remote(num_workers=1, use_gpu=False)
# Turn the handle returned by .remote() into the training result
# (presumably blocks until the remote run has finished — confirm).
result = client.get(o)

In [6]:
# Show the training result as this cell's output.
result

Result(
  metrics={'loss': 0.3557578268856667, 'accuracy': 0.8732},
  path='/Users/ruiyangwang/ray_results/TorchTrainer_2023-11-20_21-33-50/TorchTrainer_8d0b9_00000_0_2023-11-20_21-33-50',
  filesystem='local',
  checkpoint=None
)