## 3. Overview of the training loop in Ray Train

Let's see what this data-parallel training loop looks like with Ray Train and PyTorch.

```python
import ray.train
import ray.train.torch
import torchmetrics
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# load_model_ray_train, build_data_loader_ray_train, print_metrics_ray_train and
# save_checkpoint_and_metrics_ray_train are helpers defined in earlier cells.


def train_loop_ray_train(config: dict):  # hyperparameters are passed in via config
    loss_function = CrossEntropyLoss()

    # New: Use Ray Train to wrap the original PyTorch model
    model = load_model_ray_train()

    # Initialize the Adam optimizer
    optimizer = Adam(model.parameters(), lr=1e-5)

    # New: Calculate the batch size for each worker (global batch size / num workers)
    world_size = ray.train.get_context().get_world_size()
    batch_size = config["global_batch_size"] // world_size

    # New: Use Ray Train to wrap the data loader with a distributed sampler
    data_loader = build_data_loader_ray_train(batch_size=batch_size)

    # Place the metric on the device Ray Train assigned to this worker
    acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(
        ray.train.torch.get_device()
    )

    for epoch in range(config["num_epochs"]):
        # Make the distributed sampler shuffle differently each epoch
        # (a DistributedSampler is only added when there is more than one worker)
        if world_size > 1:
            data_loader.sampler.set_epoch(epoch)

        # New: each worker sees only its own shard of images and labels;
        # gradients are averaged across workers during loss.backward()
        for images, labels in data_loader:
            outputs = model(images)
            loss = loss_function(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc(outputs, labels)

        # Accuracy is aggregated across the workers when computed
        accuracy = acc.compute()

        # Collect (and print) this epoch's metrics
        metrics = print_metrics_ray_train(loss, accuracy, epoch)

        # Use Ray Train to save checkpoint and metrics
        save_checkpoint_and_metrics_ray_train(model, metrics)
        acc.reset()
```
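
To make the per-worker batch size and gradient averaging concrete, here is a pure-Python sketch of one data-parallel step with two simulated workers. It rests on simplifying assumptions: `shard_indices` is a hypothetical helper that mimics the shuffle, pad, and stride scheme of PyTorch's `DistributedSampler` but does not reproduce its exact permutation, and a one-parameter linear model with mean squared error stands in for the real network. Averaging the workers' local gradients, which is what DDP's all-reduce does, yields the same gradient as the full global batch when every shard has the same size.

```python
import math
import random


def shard_indices(num_samples, world_size, rank, epoch, seed=0):
    """Hypothetical, simplified stand-in for DistributedSampler (not its exact permutation)."""
    indices = list(range(num_samples))
    # The shuffle seed depends on the epoch, which is the role set_epoch() plays.
    random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating leading indices so every rank gets the same count
    # (assumes num_samples >= world_size).
    total = math.ceil(num_samples / world_size) * world_size
    indices += indices[: total - num_samples]
    # Rank r takes every world_size-th index starting at position r.
    return indices[rank:total:world_size]


def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


global_batch_size = 8
world_size = 2
per_worker_batch = global_batch_size // world_size  # same formula as the training loop

xs = [float(i) for i in range(global_batch_size)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5

# Each simulated worker computes a gradient on its own shard only.
shards = [shard_indices(global_batch_size, world_size, r, epoch=0) for r in range(world_size)]
local_grads = [grad_mse(w, [xs[i] for i in s], [ys[i] for i in s]) for s in shards]

# All-reduce with averaging: every worker ends up holding this same gradient.
averaged_grad = sum(local_grads) / world_size
full_batch_grad = grad_mse(w, xs, ys)

print(per_worker_batch, [len(s) for s in shards])  # 4 [4, 4]
print(math.isclose(averaged_grad, full_batch_grad))  # True
```

If the global batch size were not divisible by the number of workers, the shards would differ in size (or contain padded duplicates) and the simple average would no longer match the full-batch gradient exactly.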

### Configure scale and GPUs
Outside of our training function, we create a `ScalingConfig` object to configure:

- `num_workers`: The number of distributed training worker processes.
- `use_gpu`: Whether each worker uses a GPU; if `False`, the workers run on CPU.


See [docs on configuring scale and GPUs](https://docs.ray.io/en/latest/train/user-guides/using-gpus.html) for more details.
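
As a rough illustration of what `num_workers=2, use_gpu=True` asks Ray to reserve, the hypothetical helper `resources_needed` below multiplies a per-worker resource bundle by the number of workers. It assumes Ray Train's default bundle of 1 CPU per worker, plus 1 GPU per worker when `use_gpu=True`; the optional `resources_per_worker` override mirrors the `ScalingConfig` argument of the same name. This is a back-of-the-envelope sketch, not Ray's scheduling logic, and it ignores anything the trainer process itself reserves.

```python
def resources_needed(num_workers, use_gpu, resources_per_worker=None):
    """Hypothetical helper: total resources reserved by all training workers.

    Assumes 1 CPU per worker by default, plus 1 GPU per worker when use_gpu=True.
    """
    per_worker = {"CPU": 1, "GPU": 1 if use_gpu else 0}
    if resources_per_worker:
        # Explicit per-worker values replace the defaults key by key.
        per_worker.update(resources_per_worker)
    return {name: amount * num_workers for name, amount in per_worker.items()}


print(resources_needed(num_workers=2, use_gpu=True))
# {'CPU': 2, 'GPU': 2}
print(resources_needed(num_workers=2, use_gpu=True, resources_per_worker={"CPU": 4}))
# {'CPU': 8, 'GPU': 2}
```

In other words, a two-worker GPU configuration needs at least two GPUs available in the cluster; with fewer, the workers cannot all be scheduled.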

```python
from ray.train import ScalingConfig

# Two training workers, each using a GPU
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
```

Here is a high-level view of Ray Train's architecture:

<img src="https://docs.ray.io/en/latest/_images/overview.png" width=600>

#### Key points
- The scaling config specifies the number of training workers.
- Ray Train launches a trainer actor process that oversees the training workers.
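
The division of labor in the diagram can be mimicked in a few lines of plain Python. This is a toy analogy under loud assumptions: threads stand in for Ray actor processes, `WorkerContext` is a hypothetical stand-in for what `ray.train.get_context()` returns, and `run_trainer` is a hypothetical coordinator, not Ray Train's API. It shows the core idea: one coordinator launches `num_workers` copies of the same training function, each copy knows only its rank and the world size, and the coordinator gathers their results.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class WorkerContext:
    # Hypothetical stand-in for the rank / world-size info a real worker can query
    rank: int
    world_size: int


def run_trainer(train_fn, config, num_workers):
    """Toy coordinator: run num_workers copies of train_fn and collect their results."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [
            pool.submit(train_fn, config, WorkerContext(rank, num_workers))
            for rank in range(num_workers)
        ]
        # Results come back in rank order because we wait on futures in that order.
        return [f.result() for f in futures]


def toy_train_fn(config, ctx):
    # Every worker runs the same function; only its rank differs.
    per_worker_batch = config["global_batch_size"] // ctx.world_size
    return {"rank": ctx.rank, "batch_size": per_worker_batch}


reports = run_trainer(toy_train_fn, {"global_batch_size": 512}, num_workers=2)
print(reports)
# [{'rank': 0, 'batch_size': 256}, {'rank': 1, 'batch_size': 256}]
```

In Ray Train itself, the training function obtains this information by calling `ray.train.get_context()`, and the workers may run on different machines; passing the context as an argument here only keeps the toy self-contained.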