## 11 · `train_loop_per_worker`  
This is the workhorse function that every Ray worker runs. It builds the model, optimizer, and loss; resumes from the latest checkpoint if one exists; runs the training and validation loops and logs their metrics; and, on rank 0 only, saves a new checkpoint each epoch and appends the results to a history file. Once training finishes, it computes a final validation accuracy as a sanity check.

In [None]:
# 11. Define Ray Train train_loop_per_worker
def train_loop_per_worker(config):
    """Training function executed by every Ray Train worker.

    Builds a ResNet-18 with 101 output classes, an Adam optimizer and a
    cross-entropy loss, resumes from the checkpoint returned by
    ``get_checkpoint()`` (if any), then trains up to ``config["epochs"]``
    with a validation pass after every epoch. On rank 0 it saves a checkpoint
    (model, optimizer, epoch index) and appends ``epoch,train_loss,val_loss``
    to ``history.csv`` each epoch. Every rank calls ``train.report`` with the
    epoch metrics. After training, each worker prints a validation accuracy.

    Args:
        config: dict with keys ``"lr"`` (Adam learning rate), ``"batch_size"``
            (passed to ``build_dataloader``) and ``"epochs"`` (exclusive upper
            bound of the epoch range, i.e. the total epoch count including any
            epochs completed before a resume).
    """
    rank = get_context().get_world_rank()

    # === Model ===
    net = resnet18(num_classes=101)
    model = prepare_model(net)

    # === Optimizer / Loss ===
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    criterion = nn.CrossEntropyLoss()

    # === Resume from Checkpoint ===
    checkpoint = get_checkpoint()
    start_epoch = 0
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            # The files were written by rank 0 from its own device; map them
            # to CPU so any rank can read them. load_state_dict copies the
            # values onto the device of this worker's parameters.
            model.load_state_dict(
                torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
            )
            optimizer.load_state_dict(
                torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location="cpu")
            )
            last_epoch = torch.load(os.path.join(ckpt_dir, "extra.pt"))["epoch"]
        # The checkpoint is saved after epoch `last_epoch` has finished, so
        # continue with the next epoch rather than repeating it.
        start_epoch = last_epoch + 1
        print(f"[Rank {rank}] Resumed from checkpoint at epoch {last_epoch}")

    # === DataLoaders ===
    train_loader = build_dataloader(
        "/mnt/cluster_storage/food101_lite/train.parquet", config["batch_size"], shuffle=True
    )
    val_loader = build_dataloader(
        "/mnt/cluster_storage/food101_lite/val.parquet", config["batch_size"], shuffle=False
    )

    # === Training Loop ===
    for epoch in range(start_epoch, config["epochs"]):
        # A DistributedSampler needs the epoch to reshuffle differently each
        # epoch; other samplers (e.g. with a single worker) have no set_epoch.
        sampler = getattr(train_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        model.train()
        train_loss_total = 0.0
        train_batches = 0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            train_loss_total += loss.item()
            train_batches += 1
        train_loss = train_loss_total / train_batches

        # === Validation Loop ===
        model.eval()
        val_loss_total = 0.0
        val_batches = 0
        with torch.no_grad():
            for val_xb, val_yb in val_loader:
                val_loss_total += criterion(model(val_xb), val_yb).item()
                val_batches += 1
        val_loss = val_loss_total / val_batches

        metrics = {"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch}

        # === Rank 0 only: log, checkpoint, append history ===
        checkpoint = None
        if rank == 0:
            print(metrics)

            ckpt_dir = f"/mnt/cluster_storage/food101_lite/tmp_checkpoints/epoch_{epoch}_{uuid.uuid4().hex}"
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
            torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt"))
            torch.save({"epoch": epoch}, os.path.join(ckpt_dir, "extra.pt"))
            checkpoint = Checkpoint.from_directory(ckpt_dir)

            history_path = "/mnt/cluster_storage/food101_lite/results/history.csv"
            os.makedirs(os.path.dirname(history_path), exist_ok=True)
            with open(history_path, "a") as f:
                f.write(f"{epoch},{train_loss},{val_loss}\n")

        # Called on every rank; non-zero ranks report without a checkpoint.
        train.report(metrics, checkpoint=checkpoint)

    # === Final validation accuracy ===
    # NOTE(review): if build_dataloader shards the validation set across
    # workers, this is each worker's accuracy on its shard, not a global
    # figure — confirm against build_dataloader.
    device = next(model.parameters()).device
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in val_loader:
            # Move to wherever the model lives instead of hard-coding .cuda(),
            # which fails on CPU workers; .to() is a no-op if already there.
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb).argmax(dim=1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
    accuracy = correct / total
    print(f"Val Accuracy: {accuracy:.2%}")