## 11. `train_loop_per_worker`  

This function defines the **per-worker training logic** that Ray Train executes on each distributed worker.  

Each worker builds its own model, optimizer, and dataloaders; resumes automatically from the most recent Ray-managed checkpoint (if available); and then trains and validates the model across epochs.  

Key behaviors to note:

- **Checkpoints** are first written to a fast **temporary local directory** on the rank-0 worker, then safely persisted to the run’s configured `storage_path` by `train.report()`—ensuring reliability and retry support even under transient node failures. The other workers call `train.report()` with metrics only.  
- **Metrics** (train and validation loss) are automatically collected and stored by Ray Train—no need for manual file writes or JSON logging.  
- **Fault tolerance** is fully handled by Ray Train’s checkpointing and retry mechanism via `RunConfig` and `FailureConfig`.  
- **Final accuracy** is computed using `torchmetrics.classification.MulticlassAccuracy`, which performs synchronized, **distributed accuracy aggregation** across all workers, ensuring a correct global metric instead of rank-0-only evaluation.  

This design keeps the training loop clean, fault-tolerant, and fully aligned with Ray Train’s built-in distributed orchestration.


In [None]:
# 11. Define Ray Train train_loop_per_worker (tempdir checkpoints + Ray-managed metrics)
def train_loop_per_worker(config):
    """Train and validate a ResNet-18 classifier on one Ray Train worker.

    Builds the model, optimizer, and dataloaders, resumes from the checkpoint
    returned by ``get_checkpoint()`` when one exists, then runs the remaining
    epochs. After each epoch every worker calls ``train.report()``; only
    rank 0 attaches a checkpoint (model state, optimizer state, and the epoch
    index). Once training finishes, validation accuracy is computed with
    TorchMetrics and printed on rank 0.

    Args:
        config: Dict of hyperparameters. Required keys: ``"lr"``,
            ``"batch_size"``, ``"epochs"``. Optional keys: ``"num_classes"``
            (default 101), ``"train_path"`` and ``"val_path"`` (default to
            the food101_lite parquet files under ``/mnt/cluster_storage``).
    """
    import tempfile

    rank = get_context().get_world_rank()

    # Optional overrides; the defaults reproduce the previously hard-coded
    # values. The class count is read once so the model head and the accuracy
    # metric cannot drift apart.
    num_classes = config.get("num_classes", 101)
    train_path = config.get(
        "train_path", "/mnt/cluster_storage/food101_lite/train.parquet"
    )
    val_path = config.get(
        "val_path", "/mnt/cluster_storage/food101_lite/val.parquet"
    )

    # === Model ===
    net = resnet18(num_classes=num_classes)
    model = prepare_model(net)

    # === Optimizer / Loss ===
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    criterion = nn.CrossEntropyLoss()

    # === Resume from Checkpoint ===
    start_epoch = 0
    ckpt = get_checkpoint()
    if ckpt is not None:
        with ckpt.as_directory() as ckpt_dir:
            # Deserialize onto CPU; load_state_dict() then copies the values
            # into the already-prepared model's existing parameters.
            model.load_state_dict(
                torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
            )
            opt_path = os.path.join(ckpt_dir, "optimizer.pt")
            if os.path.exists(opt_path):
                optimizer.load_state_dict(torch.load(opt_path, map_location="cpu"))
            meta_path = os.path.join(ckpt_dir, "meta.pt")
            if os.path.exists(meta_path):
                # Continue from the next epoch after the saved one
                meta = torch.load(meta_path, map_location="cpu")
                start_epoch = int(meta.get("epoch", -1)) + 1
        if rank == 0:
            print(f"[Rank {rank}] Resumed from checkpoint at epoch {start_epoch}")

    # === DataLoaders ===
    train_loader = build_dataloader(train_path, config["batch_size"], shuffle=True)
    val_loader = build_dataloader(val_path, config["batch_size"], shuffle=False)

    # === Training Loop ===
    for epoch in range(start_epoch, config["epochs"]):
        # Required when using DistributedSampler
        if hasattr(train_loader, "sampler") and hasattr(train_loader.sampler, "set_epoch"):
            train_loader.sampler.set_epoch(epoch)

        model.train()
        train_loss_total, train_batches = 0.0, 0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            train_loss_total += loss.item()
            train_batches += 1
        # max(..., 1) guards against division by zero on an empty loader.
        train_loss = train_loss_total / max(train_batches, 1)

        # === Validation Loop ===
        model.eval()
        val_loss_total, val_batches = 0.0, 0
        with torch.no_grad():
            for val_xb, val_yb in val_loader:
                val_loss_total += criterion(model(val_xb), val_yb).item()
                val_batches += 1
        val_loss = val_loss_total / max(val_batches, 1)

        # Losses are averaged over the batches this worker iterated.
        metrics = {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss}
        if rank == 0:
            print(metrics)

        # ---- Save checkpoint to fast local temp dir; Ray persists it via report() ----
        if rank == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save(model.state_dict(), os.path.join(tmpdir, "model.pt"))
                torch.save(optimizer.state_dict(), os.path.join(tmpdir, "optimizer.pt"))
                torch.save({"epoch": epoch}, os.path.join(tmpdir, "meta.pt"))
                ckpt_out = Checkpoint.from_directory(tmpdir)
                # report() is called inside the with-block, i.e. before the
                # temporary directory is deleted.
                train.report(metrics, checkpoint=ckpt_out)
        else:
            # Non-zero ranks report metrics only (no checkpoint attachment)
            train.report(metrics)

    # === Final validation accuracy (distributed via TorchMetrics) ===
    from torchmetrics.classification import MulticlassAccuracy

    model.eval()
    # Place the metric on the same device as the model parameters.
    device = next(model.parameters()).device
    # Sync across DDP workers when computing the final value
    acc_metric = MulticlassAccuracy(
        num_classes=num_classes, average="micro", sync_on_compute=True
    ).to(device)

    with torch.no_grad():
        for xb, yb in val_loader:
            logits = model(xb)
            preds = torch.argmax(logits, dim=1)
            acc_metric.update(preds, yb)

    dist_val_acc = acc_metric.compute().item()
    if rank == 0:
        print(f"Val Accuracy (distributed): {dist_val_acc:.2%}")