## Ray Train

Ray Train is Ray's library for running training jobs on a cluster. It is compatible with PyTorch, Hugging Face, and TensorFlow (although TensorFlow support is wearing thin as TensorFlow loses traction in the ML/DL community), but not with scikit-learn, which is why we did not use it in the previous class.

Ray Train is most useful for deep learning models. Let's prepare our environment:

```shell
conda activate ray
```

Once the environment is activated, we install the Ray components needed for the demo:

```shell
pip install ray==2.36.0 -U "ray[train]" torch torchvision
```

The following example shows how you can use Ray Train to set up multi-worker training with PyTorch.

In [None]:
!pip install ray==2.36.0 -U "ray[train]" torch torchvision

To use Ray Train effectively, you need to understand four main concepts:

* Training function: A Python function that contains your model training logic.

* Worker: A process that runs the training function.

* Scaling configuration: A configuration of the number of workers and compute resources (for example, CPUs or GPUs).

* Trainer: A Python class that ties together the training function, workers, and scaling configuration to execute a distributed training job.
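
To make the relationship between these four concepts concrete, here is a toy analogy that uses only the Python standard library. It is not Ray code: `ToyScalingConfig`, `ToyTrainer`, and `toy_train_func` are hypothetical names, and threads stand in for the worker processes that Ray would start across the cluster.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class ToyScalingConfig:
    """Hypothetical stand-in for a scaling configuration (not a Ray class)."""
    num_workers: int = 2


def toy_train_func(rank: int, world_size: int) -> dict:
    """Training function: the logic that every worker runs."""
    # A real training function would build a model and iterate over data here.
    return {"rank": rank, "world_size": world_size}


class ToyTrainer:
    """Hypothetical trainer tying together function, workers, and scaling."""

    def __init__(self, train_func, scaling_config: ToyScalingConfig):
        self.train_func = train_func
        self.scaling_config = scaling_config

    def fit(self) -> list:
        n = self.scaling_config.num_workers
        # Each thread plays the role of one worker.
        with ThreadPoolExecutor(max_workers=n) as pool:
            futures = [pool.submit(self.train_func, rank, n) for rank in range(n)]
            return [f.result() for f in futures]


toy_results = ToyTrainer(toy_train_func, ToyScalingConfig(num_workers=3)).fit()
print(toy_results)
```

The same function runs once per worker, and each copy only differs by its rank. Ray Train follows the same shape, with real processes, resource allocation, and communication between workers.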

Now define the PyTorch training function that each worker will run.

The Ray documentation suggests calling the training function `train_func`. This function should take care of:
- loading the model
- loading the dataset
- training the model
- saving checkpoints
- logging metrics

This framework works with PyTorch, Hugging Face, TensorFlow and Keras, XGBoost...

In [None]:
import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train.torch

def train_func():
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    # model.to("cuda")  # This is done by `prepare_model`
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        for images, labels in train_loader:
            # This is done by `prepare_data_loader`!
            # images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
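            # `prepare_model` only wraps the model in DistributedDataParallel
            # when there are several workers, so unwrap it defensively.
            unwrapped = model.module if hasattr(model, "module") else model
            torch.save(
                unwrapped.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )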
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)
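
The `set_epoch` call in the training loop matters because, in multi-worker mode, each worker reads a different shard of the dataset and the shuffle order is derived from the epoch number. The sketch below imitates that behaviour in plain Python. `shard_indices` is a hypothetical helper written for illustration; it mirrors the idea behind PyTorch's `DistributedSampler` but is not its actual implementation.

```python
import random


def shard_indices(num_samples, world_size, rank, epoch, seed=0):
    """Return the sample indices one worker would see in a given epoch."""
    indices = list(range(num_samples))
    # Every worker shuffles with the same seed + epoch, so all workers agree
    # on the permutation and the order changes from one epoch to the next.
    random.Random(seed + epoch).shuffle(indices)
    # Worker `rank` takes every `world_size`-th index, starting at `rank`.
    return indices[rank::world_size]


epoch0 = [shard_indices(12, 3, rank, epoch=0) for rank in range(3)]
epoch1 = [shard_indices(12, 3, rank, epoch=1) for rank in range(3)]
print(epoch0)
print(epoch1)
```

Within an epoch the three shards are disjoint and together cover the whole dataset. If the epoch were never updated, every epoch would reuse the same permutation, which is why the loop calls `set_epoch(epoch)` when there is more than one worker.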

This training function can be executed with:

In [None]:
# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=3, use_gpu=False)

# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [5a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()
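
It helps to work out what this scaling configuration means for the training loop. The arithmetic below assumes that the `batch_size=128` passed to the `DataLoader` is applied per worker, that the dataset is split evenly across the 3 workers, and that the FashionMNIST training split holds 60,000 images.

```python
import math

num_samples = 60_000      # size of the FashionMNIST training split
num_workers = 3           # from ScalingConfig(num_workers=3)
per_worker_batch = 128    # batch_size passed to DataLoader

samples_per_worker = math.ceil(num_samples / num_workers)
steps_per_epoch = math.ceil(samples_per_worker / per_worker_batch)
global_batch = per_worker_batch * num_workers

print(samples_per_worker, steps_per_epoch, global_batch)  # 20000 157 384
```

Under these assumptions each optimizer step aggregates gradients from 384 images, so a learning rate tuned for a single worker with batches of 128 may need to be revisited.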

We can load the latest checkpoint as follows:

In [None]:
# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)
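
Loading works directly here because the training function saved the state dict of the unwrapped model. A checkpoint saved from a `DistributedDataParallel` wrapper instead typically has every key prefixed with `module.`, which a plain `resnet18` would reject. The sketch below shows one way to handle that case. `strip_module_prefix` is a hypothetical helper, and strings stand in for tensors so that the example runs without torch.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with a leading `prefix` removed from keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Placeholder values stand in for tensors so the sketch runs without torch.
wrapped = {"module.conv1.weight": "w", "module.fc.bias": "b"}
cleaned = strip_module_prefix(wrapped)
print(cleaned)  # {'conv1.weight': 'w', 'fc.bias': 'b'}
```

Keys without the prefix are left untouched, so the helper is safe to apply to a checkpoint saved either way.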

## Resources 📚📚

[Ray Train PyTorch](https://docs.ray.io/en/latest/train/getting-started-pytorch.html)