## 5. Reporting checkpoints and metrics

To monitor progress, we can continue to print or log metrics as before. This time, we do so only on the first worker (world rank 0), so that each metric is printed once rather than once per worker.

In [24]:
def print_metrics_ray_train(
    loss: torch.Tensor, accuracy: torch.Tensor, epoch: int
) -> dict[str, float]:
    metrics = {"loss": loss.item(), "accuracy": accuracy.item(), "epoch": epoch}
    if ray.train.get_context().get_world_rank() == 0:
        print(metrics)
    return metrics

Next, we report intermediate metrics and checkpoints using the `ray.train.report` utility function.

In [23]:
def save_checkpoint_and_metrics_ray_train(
    model: torch.nn.Module, metrics: dict[str, float]
) -> None:
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        torch.save(
            model.module.state_dict(),  # note the .module to unwrap the DistributedDataParallel
            os.path.join(temp_checkpoint_dir, "model.pt"),
        )
        ray.train.report(  # use ray.train.report to save the metrics and checkpoint
            metrics,  # train.report will only save worker rank 0's metrics
            checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
        )

The diagram below shows the lifecycle of a checkpoint, from its creation at a local path to its upload to persistent storage.

<img src="https://docs.ray.io/en/latest/_images/checkpoint_lifecycle.png" width=800>


Because the model is identical across all workers, we can build the checkpoint only on the worker with rank 0. Note that we still need to call `ray.train.report` on all workers to keep the training loop synchronized.

In [22]:
import os
import tempfile

import ray.train
import torch


def save_checkpoint_and_metrics_ray_train(
    model: torch.nn.Module, metrics: dict[str, float]
) -> None:
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None
        # On the first worker
        if ray.train.get_context().get_world_rank() == 0:
            # Save PyTorch checkpoints locally
            torch.save(model.module.state_dict(), os.path.join(temp_checkpoint_dir, "model.pt"))
            # Convert to Ray Train checkpoint
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

        # Report checkpoint
        ray.train.report(
            metrics,
            checkpoint=checkpoint,
        )
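
To make the "only rank 0 writes, but every worker reports" contract concrete, here is a minimal sketch that simulates it with plain Python threads and no Ray or PyTorch. Everything in it is hypothetical: `fake_report` is a stand-in for `ray.train.report`, `WORLD_SIZE` is an arbitrary worker count, and modeling the synchronization as a `threading.Barrier` is an assumption made for illustration, not a description of how Ray Train is implemented.

```python
import os
import tempfile
import threading

WORLD_SIZE = 4  # hypothetical number of workers

# Assumption: model the synchronization as a barrier that every worker must reach.
barrier = threading.Barrier(WORLD_SIZE)
lock = threading.Lock()
reported = []  # (rank, metrics, checkpoint_files) tuples


def fake_report(rank, metrics, checkpoint_dir=None):
    """Hypothetical stand-in for ray.train.report (not the real API)."""
    # Record which files, if any, this worker attached as a checkpoint.
    files = sorted(os.listdir(checkpoint_dir)) if checkpoint_dir else None
    with lock:
        reported.append((rank, metrics, files))
    # If any worker skipped this call, the others would time out here.
    barrier.wait(timeout=5)


def worker(rank):
    metrics = {"loss": 1.0 / (rank + 1), "epoch": 0}
    with tempfile.TemporaryDirectory() as tmp:
        checkpoint_dir = None
        if rank == 0:
            # Only the first worker writes the (dummy) model file.
            with open(os.path.join(tmp, "model.pt"), "wb") as f:
                f.write(b"weights")
            checkpoint_dir = tmp
        # Every worker reports, with or without a checkpoint.
        fake_report(rank, metrics, checkpoint_dir)


def run_simulation():
    reported.clear()
    threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sorted(reported, key=lambda item: item[0])


results = run_simulation()
for rank, metrics, files in results:
    print(rank, metrics, files)
```

In this simulation, all four workers pass through the report call, but only rank 0 carries a checkpoint directory. If the `fake_report` call were moved inside the `if rank == 0` branch, the other workers would never reach the barrier and rank 0 would time out, which mirrors why the real function is called on every worker.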

For an in-depth guide on saving checkpoints and metrics, see the [docs](https://docs.ray.io/en/latest/train/user-guides/checkpoints.html).