## 4. Migrating the model and dataset to Ray Train

Use the `ray.train.torch.prepare_model()` utility function to:

- Automatically move your model to the correct device.
- Wrap the model in PyTorch's `DistributedDataParallel`.

To learn more about the `prepare_model()` function, see the [API reference](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_model.html#ray-train-torch-prepare-model).
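
For context, the two steps listed above can be approximated by hand in plain PyTorch. The following is a minimal sketch, not Ray Train's actual implementation: `manual_prepare_model` is a hypothetical helper name, and the sketch assumes that a process group (if any) is already initialized, that each process has already selected its own GPU, and that wrapping is only needed when more than one worker participates.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def manual_prepare_model(model: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: a hand-written approximation of the two steps above."""
    # Step 1: move the model to this process's device (GPU if available, else CPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Step 2: wrap the model in DistributedDataParallel, but only when a process
    # group with more than one worker exists (assumption made for this sketch).
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # Assumes torch.cuda.set_device() was already called for this process.
        device_ids = [torch.cuda.current_device()] if device.type == "cuda" else None
        model = DistributedDataParallel(model, device_ids=device_ids)
    return model
```

In a single process with no process group, this sketch returns the model unwrapped on the selected device, so the same function can be exercised locally before scaling out.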

In [9]:
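import ray.train.torch
import torch

def load_model_ray_train() -> torch.nn.Module:
    model = build_resnet18()
    # prepare_model() replaces the manual model = model.to("cuda") call:
    # it moves the model to the correct device and wraps it in DistributedDataParallel.
    model = ray.train.torch.prepare_model(model)
    return model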

Use the `ray.train.torch.prepare_data_loader()` utility function to:

- Automatically move the batches to the right device.
- Add PyTorch's `DistributedSampler` to the data loader.

To learn more about the `prepare_data_loader()` function, see the [API reference](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_data_loader.html#ray-train-torch-prepare-data-loader).
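
To build intuition for what a distributed sampler does, the sketch below splits dataset indices across workers in pure Python. It is an illustration only, not PyTorch's or Ray's code: `shard_indices` is a hypothetical helper, Python's `random` module stands in for the framework's random generator, and the sketch assumes the dataset has at least as many samples as there are workers.

```python
import math
import random


def shard_indices(dataset_size, world_size, rank, shuffle=True, seed=0, epoch=0):
    """Hypothetical helper: pick the dataset indices that one worker would read."""
    indices = list(range(dataset_size))
    if shuffle:
        # Every worker uses the same seed, so all workers agree on the permutation.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating leading indices so the total divides evenly across workers.
    per_worker = math.ceil(dataset_size / world_size)
    total = per_worker * world_size
    indices += indices[: total - len(indices)]
    # Each worker takes every world_size-th index, starting at its own rank.
    return indices[rank:total:world_size]


# 10 samples over 4 workers: each worker gets 3 indices, two of them padded repeats.
print(shard_indices(10, world_size=4, rank=2, shuffle=False))  # [2, 6, 0]
```

Because each worker reads a different slice, one pass over all workers covers the whole dataset once (plus a few padded repeats), which is why the per-worker `DataLoader` no longer iterates over every sample.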

In [10]:
def build_data_loader_ray_train(batch_size: int) -> DataLoader:
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

    # Add DistributedSampler to the DataLoader
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    return train_loader

<div class="alert alert-block alert-warning">
<b>Note</b> that this step isn’t necessary if you are integrating your Ray Train implementation with Ray Data.
</div>