## 8 · Custom per-worker training loop  
Ray Train launches one copy of `train_func` on every worker (16 workers with one CPU each in this example).  
Inside the loop you:  
1. Pull the local shard of both the training and validation Ray datasets.  
2. Convert each pandas shard into an XGBoost `DMatrix`, XGBoost's efficient native data structure (stored internally in Compressed Sparse Row (CSR) format).  
3. Resume from an existing checkpoint if Ray passed one in with `get_checkpoint()`.  
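4. Call `xgb.train`, handing it a `RayTrainReportCallback` so that **every boosting round automatically reports metrics**.  
5. On rank 0, save the per-round evaluation history (plus a best-round summary) to a JSON file on shared storage.  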

In [None]:
# 08. Custom Ray Train loop for XGBoost (CPU)

def summarize_best_validation(evals_result, metric="mlogloss", start_iteration=0):
    """Return the lowest validation score recorded by ``xgb.train``.

    Args:
        evals_result: Dict filled by ``xgb.train(evals_result=...)``, shaped
            ``{eval_name: {metric_name: [score_per_round, ...]}}``.
        metric: Metric name to look up under the ``"validation"`` entry.
        start_iteration: Boosting rounds the model already contained before
            this ``xgb.train`` call (non-zero after resuming from a
            checkpoint), so the reported iteration counts from the first
            round of the whole run.

    Returns:
        ``{"iteration": <1-based round or None>, "validation-<metric>": <float or None>}``.
        Both values are ``None`` when no validation history was recorded.
    """
    v_hist = evals_result.get("validation", {}).get(metric, [])
    if not v_hist:
        return {"iteration": None, f"validation-{metric}": None}
    best_idx = int(np.argmin(v_hist))
    return {
        "iteration": start_iteration + best_idx + 1,
        f"validation-{metric}": float(v_hist[best_idx]),
    }


def train_func(config):
    """Per-worker training loop executed by Ray Train.

    ``config`` keys read here:
        label_column: Name of the target column in both datasets.
        params: Booster parameters forwarded to ``xgb.train``.
        num_boost_round: Total boosting rounds for the whole run. When resuming
            from a checkpoint, only the rounds still missing are trained.
        out_json_path: File that rank 0 writes the metrics JSON to.
        checkpoint_frequency (optional, default 0): Forwarded to
            ``RayTrainReportCallback(frequency=...)``.
    """
    # --------------------------------------------------------
    # 1. Pull this worker's data shard from Ray Datasets
    # --------------------------------------------------------
    label_col = config["label_column"]
    train_df = get_dataset_shard("train").materialize().to_pandas()
    eval_df = get_dataset_shard("evaluation").materialize().to_pandas()
    feature_cols = [c for c in train_df.columns if c != label_col]

    # Convert pandas -> DMatrix (XGBoost's native training data structure)
    dtrain = xgb.DMatrix(train_df[feature_cols], label=train_df[label_col])
    deval = xgb.DMatrix(eval_df[feature_cols], label=eval_df[label_col])

    # --------------------------------------------------------
    # 2. Optionally resume, then train the booster.
    #    RayTrainReportCallback handles:
    #       • per-round ray.train.report(...)
    #       • checkpoint upload to Ray storage
    # --------------------------------------------------------
    rank = get_context().get_world_rank()
    booster = None
    completed_rounds = 0
    ckpt = get_checkpoint()  # Ray sets this automatically when resuming
    if ckpt:
        with ckpt.as_directory() as d:
            booster = xgb.Booster()
            booster.load_model(os.path.join(d, RayTrainReportCallback.CHECKPOINT_NAME))
        # num_boost_round is the budget for the whole run: don't re-train the
        # rounds the checkpointed model already contains.
        completed_rounds = booster.num_boosted_rounds()
        print(f"[Rank {rank}] Resumed from checkpoint after {completed_rounds} rounds")
    remaining_rounds = max(config["num_boost_round"] - completed_rounds, 0)

    evals_result = {}  # xgb.train fills this with per-round metrics per eval set

    xgb.train(
        params=config["params"],
        dtrain=dtrain,
        evals=[(dtrain, "train"), (deval, "validation")],
        num_boost_round=remaining_rounds,
        xgb_model=booster,  # None -> train from scratch
        evals_result=evals_result,
        callbacks=[
            RayTrainReportCallback(frequency=config.get("checkpoint_frequency", 0)),
        ],
    )

    # --------------------------------------------------------
    # 3. Rank 0 writes metrics JSON to the shared path
    # --------------------------------------------------------
    if rank == 0:
        out_json_path = config["out_json_path"]
        payload = {
            "evals_result": evals_result,
            # Convenience summary; "iteration" counts from round 1 of the
            # whole run, including rounds restored from a checkpoint.
            "best": summarize_best_validation(evals_result, start_iteration=completed_rounds),
            # Rounds that precede the first entry of each evals_result list
            # (non-zero only after a resume).
            "start_iteration": completed_rounds,
        }

        out_dir = os.path.dirname(out_json_path)
        if out_dir:  # os.makedirs("") raises for a bare filename
            os.makedirs(out_dir, exist_ok=True)
        with open(out_json_path, "w") as f:
            json.dump(payload, f)
        print(f"[Rank 0] Wrote metrics JSON → {out_json_path}")

### 9 · Configure XGBoost and build the Trainer  
Here you define all model hyper-parameters (objective, number of classes, CPU tree method, etc.) and wrap `train_func` inside an `XGBoostTrainer`.  
* `ScalingConfig(num_workers=16, use_gpu=False)` allocates one CPU per worker.  
* `CheckpointConfig(checkpoint_frequency=10, num_to_keep=1)` keeps only the most recent checkpoint.  
* `FailureConfig(max_failures=1)` tells Ray to retry training at most once if a worker fails.  
Because you pass the Ray Datasets directly, Ray takes care of sharding them evenly across workers.

In [None]:
# 09. XGBoost config + Trainer (uses train_func above)

# Single source of truth for the run's output location, so the metrics JSON
# always lands inside this run's directory under storage_path/name.
RESULTS_DIR = "/mnt/cluster_storage/covtype/results"
EXPERIMENT_NAME = "covtype_xgb_cpu"

xgb_params = {
    "objective": "multi:softprob",
    "num_class": 7,
    "eval_metric": "mlogloss",
    "tree_method": "hist",  # CPU histogram algorithm
    "eta": 0.3,
    "max_depth": 8,
}

trainer = XGBoostTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=16, use_gpu=False),
    datasets={"train": train_ds, "evaluation": val_ds},
    train_loop_config={
        "label_column": "label",
        "params": xgb_params,
        "num_boost_round": 50,  # Increase or decrease to adjust training iterations
        # Rank 0 writes per-round metrics here; the plotting cell reads this file.
        "out_json_path": f"{RESULTS_DIR}/{EXPERIMENT_NAME}/metrics.json",
    },
    run_config=RunConfig(
        name=EXPERIMENT_NAME,
        storage_path=RESULTS_DIR,
        # NOTE(review): Ray documents checkpoint_frequency as applying only to
        # trainers without a custom training loop; with train_func the cadence
        # is governed by RayTrainReportCallback — confirm for your Ray version.
        checkpoint_config=CheckpointConfig(checkpoint_frequency=10, num_to_keep=1),
        failure_config=FailureConfig(max_failures=1),  # retry at most once after a failure
    ),
)

### 10 · Start distributed training  
`trainer.fit()` blocks until all boosting rounds finish (or until Ray exhausts retries). The result object contains the last reported metrics and the most recent checkpoint. Print the final validation log-loss and keep a handle to the checkpoint for inference.

In [None]:
# 10. Fit the trainer (reports eval metrics every boosting round)
result = trainer.fit()

# result.checkpoint is the latest checkpoint reported during training (not
# necessarily the lowest-loss one); the best_ckpt name is kept for
# compatibility with any later cells that use it.
best_ckpt = result.checkpoint

# RayTrainReportCallback reports metrics as "<eval_name>-<metric>"; .get keeps
# this cell from failing if the key is absent.
print("Final validation log-loss:", result.metrics.get("validation-mlogloss"))

### 11 · Plot log-loss over boosting rounds  
During training you captured the full per-round evaluation history using XGBoost's built-in `evals_result` and saved it to JSON. Reloading that JSON now gives you both training and validation log-loss values for each boosting round. Plotting these lists against their round index shows how the model converges: training loss decreases steadily, while validation loss follows with a small gap.

In [None]:
# 11. Plot evaluation history from saved JSON
# Reload the metrics file written by rank 0 and draw train/validation
# log-loss per boosting round.

metrics_path = "/mnt/cluster_storage/covtype/results/covtype_xgb_cpu/metrics.json"
with open(metrics_path) as f:
    payload = json.load(f)

hist = payload["evals_result"]
train, val = (hist[split]["mlogloss"] for split in ("train", "validation"))

# 1-based boosting-round index for the x-axis
xs = np.arange(len(val)) + 1

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(xs, train, label="Train")
ax.plot(xs, val, label="Val")
ax.set_xlabel("Boosting round")
ax.set_ylabel("Log-loss")
ax.set_title("XGBoost log-loss")
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()

best = payload["best"]["validation-mlogloss"]
print("Best validation log-loss:", best)