### 8. Define the Ray Train worker loop (Arrow-based, memory-efficient)  

Each Ray Train worker runs its own copy of `train_func`.  
Inside the loop, the worker pulls its **shard** of the train and validation datasets directly as **Arrow tables**. 

You then:  
1. **Materialize each shard** into a `pyarrow.Table` and drop any accidental index columns (like `__index_level_0__`)  
   that might have been added during Parquet serialization.  
2. **Convert Arrow → NumPy → XGBoost DMatrix** with explicit `feature_names`, ensuring consistent column order  
   across all workers and splits.  
3. **Optionally resume** from a prior checkpoint using `get_checkpoint()`.  
4. **Train the booster** with `xgb.train`, using Ray's built-in `RayTrainReportCallback` to report metrics to Ray Train  
   after every boosting round and to save model checkpoints.  

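This design keeps data loading distributed, because each worker reads only its own shard, and it leaves metric reporting to the callback instead of hand-written `ray.train.report` calls. Note that `np.column_stack` still makes one dense copy of each shard's feature columns.  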
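
The column-alignment idea behind steps 1 and 2 can be illustrated without Ray or XGBoost. This is a minimal sketch that assumes plain Python dicts stand in for `pyarrow.Table`; the `Table` alias, `drop_index_cols`, and `to_rows` are hypothetical helpers, not Ray or PyArrow APIs. The feature order is derived once from the train split and reused for the evaluation split, even when that split stores its columns in a different order and carries a stray index column.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical stand-in for a pyarrow.Table: column name -> list of values.
Table = Dict[str, list]

INDEX_COLS = {"__index_level_0__"}


def drop_index_cols(table: Table) -> Table:
    """Remove pandas index columns that can sneak in through Parquet files."""
    return {name: col for name, col in table.items() if name not in INDEX_COLS}


def to_rows(table: Table, feature_cols: Sequence[str], label_col: str) -> Tuple[List[list], list]:
    """Build a row-major feature matrix whose columns follow feature_cols exactly."""
    columns = [table[c] for c in feature_cols]  # select by name, not by position
    X = [list(row) for row in zip(*columns)]
    y = list(table[label_col])
    return X, y


train = drop_index_cols({"elev": [1.0, 2.0], "slope": [3.0, 4.0], "label": [0, 1]})
# The evaluation shard lists its columns in another order and carries an index column.
evaluation = drop_index_cols(
    {"__index_level_0__": [7, 8], "slope": [5.0, 6.0], "label": [1, 0], "elev": [9.0, 10.0]}
)

# Derive the feature order once, from the train split, and reuse it for every split.
feature_cols = [c for c in train if c != "label"]
X_train, y_train = to_rows(train, feature_cols, "label")
X_eval, y_eval = to_rows(evaluation, feature_cols, "label")
print(feature_cols, X_eval)  # → ['elev', 'slope'] [[9.0, 5.0], [10.0, 6.0]]
```

Selecting columns by name rather than by position is what keeps the evaluation matrix consistent with the `feature_names` the booster was trained on.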

```python
import numpy as np
import pyarrow as pa
import ray
import xgboost as xgb
from ray.train import get_checkpoint, get_context, get_dataset_shard
from ray.train.xgboost import RayTrainReportCallback

INDEX_COLS = {"__index_level_0__"}  # extend if needed


def _arrow_table_from_shard(name: str) -> pa.Table:
    """Collect this worker's Ray Dataset shard into one pyarrow.Table and
    drop accidental index columns (e.g., from pandas Parquet)."""
    ds_iter = get_dataset_shard(name)
    arrow_refs = ds_iter.materialize().to_arrow_refs()
    tables = ray.get(arrow_refs)  # fetch all blocks in a single call
    tbl = pa.concat_tables(tables, promote_options="none") if tables else pa.table({})
    # Drop index columns if present
    keep = [c for c in tbl.column_names if c not in INDEX_COLS]
    if len(keep) != len(tbl.column_names):
        tbl = tbl.select(keep)
    return tbl


def _dmat_from_arrow(table: pa.Table, feature_cols, label_col: str) -> xgb.DMatrix:
    """Build XGBoost DMatrix from pyarrow.Table with explicit feature_names."""
    # Stack columns in the given order so every split shares one feature layout
    X = np.column_stack([table[c].to_numpy(zero_copy_only=False) for c in feature_cols])
    y = table[label_col].to_numpy(zero_copy_only=False)
    return xgb.DMatrix(X, label=y, feature_names=feature_cols)


def train_func(config):
    label_col = config["label_column"]

    # -------- 1) Load this worker's shards as Arrow tables --------------------
    train_arrow = _arrow_table_from_shard("train")
    eval_arrow  = _arrow_table_from_shard("evaluation")

    # Use the SAME ordered feature list for both splits
    feature_cols = [c for c in train_arrow.column_names if c != label_col]

    dtrain = _dmat_from_arrow(train_arrow, feature_cols, label_col)
    deval  = _dmat_from_arrow(eval_arrow,  feature_cols, label_col)

    # -------- 2) Optional resume from checkpoint ------------------------------
    booster = None
    num_boost_round = config["num_boost_round"]
    ckpt = get_checkpoint()
    if ckpt:
        booster = RayTrainReportCallback.get_model(ckpt)
        # Train only the rounds the checkpointed model doesn't already have
        num_boost_round -= booster.num_boosted_rounds()
        print(f"[Rank {get_context().get_world_rank()}] Resumed from checkpoint "
              f"after {booster.num_boosted_rounds()} rounds")

    # -------- 3) Train with per-round reporting & checkpointing ---------------
    evals_result = {}
    xgb.train(
        params          = config["params"],
        dtrain          = dtrain,
        evals           = [(dtrain, "train"), (deval, "validation")],
        num_boost_round = num_boost_round,
        xgb_model       = booster,
        evals_result    = evals_result,
        # Report metrics every round; save a checkpoint every 10 rounds (and at the end)
        callbacks       = [RayTrainReportCallback(frequency=10)],
    )
```
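
Why the worker loop subtracts the rounds already stored in the checkpoint can be seen in a toy simulation. This is a hypothetical sketch in plain Python, not Ray or XGBoost code: `run` and the dict-based `store` stand in for `xgb.train` and Ray's checkpoint storage, and each "tree" is just an integer. With a checkpoint every 10 rounds, a crash partway through loses only the rounds since the last checkpoint, and the retry trains exactly the missing rounds so the total still equals `num_boost_round`.

```python
from typing import Dict, List, Optional


def run(num_boost_round: int, frequency: int, store: Dict[str, List[int]],
        fail_at: Optional[int] = None) -> List[int]:
    """Simulate boosting with periodic checkpoints; each 'tree' is its round index."""
    trees = list(store.get("latest", []))      # resume from the last checkpoint, if any
    remaining = num_boost_round - len(trees)   # train only the rounds still missing
    for _ in range(remaining):
        if fail_at is not None and len(trees) == fail_at:
            raise RuntimeError(f"worker crashed while building tree {len(trees)}")
        trees.append(len(trees))
        if len(trees) % frequency == 0:
            store["latest"] = list(trees)      # periodic checkpoint
    store["latest"] = list(trees)              # final checkpoint at the end
    return trees


store: Dict[str, List[int]] = {}
try:
    run(50, frequency=10, store=store, fail_at=23)
except RuntimeError:
    pass
print(len(store["latest"]))  # → 20 (trees built after round 20 were lost)

trees = run(50, frequency=10, store=store)  # the retry resumes instead of restarting
print(len(trees))  # → 50
```

Without the subtraction, the resumed run would add a full `num_boost_round` on top of the checkpointed trees and overshoot the intended model size.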

### 9. Configure XGBoost and build the Trainer  
Next, define the XGBoost hyperparameters and wrap `train_func` in an `XGBoostTrainer` for distributed execution.  
Each worker reserves `CPUS_PER_WORKER` CPUs (`resources_per_worker={"CPU": CPUS_PER_WORKER}`); set this to your nodes' core count to give each worker a full node.  
Setting XGBoost's `nthread` parameter to the same value lets each worker use all of its reserved cores.  
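
A rough capacity check helps when choosing `CPUS_PER_WORKER` and `num_workers`. This sketch assumes each worker's CPU request must be satisfied by a single node (one worker's resources are not split across machines) and ignores CPUs that other tasks already use; `placeable_workers` and the node sizes are hypothetical.

```python
from typing import List


def placeable_workers(node_cpus: List[int], cpus_per_worker: int) -> int:
    """Count workers that fit when each worker's CPU request must fit on one node."""
    return sum(cpus // cpus_per_worker for cpus in node_cpus)


nodes = [16, 16]  # hypothetical cluster: two 16-core worker nodes
print(placeable_workers(nodes, 4))   # → 8: four small workers per node
print(placeable_workers(nodes, 16))  # → 2: one full-node worker per node
print(placeable_workers(nodes, 32))  # → 0: a 32-CPU worker cannot span two nodes
```

If `num_workers` exceeds this count, the trainer waits for resources that never become free, so size the two settings together.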

Key settings:  
- **`ScalingConfig`**: controls how many workers to launch and their CPU/GPU allocation.  
- **`CheckpointConfig`**: saves a checkpoint every 10 boosting rounds, scores each checkpoint by  
  validation log-loss (`validation-mlogloss`, lower is better), and retains only the best model.  
- **`FailureConfig`**: if a worker fails, Ray restarts the training run from the latest checkpoint, at most once (`max_failures=1`).  

Passing the Ray Datasets directly to the trainer lets Ray shard them automatically: each worker receives its own slice  
of the data through `get_dataset_shard`, with no manual coordination.  
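
To picture what "each worker receives its own slice" means, here is an illustrative equal-split scheme. It is a hypothetical sketch (`shard_bounds` is not a Ray API), and Ray's own splitting works on data blocks and may assign rows differently; the point is only that shards are disjoint and together cover every row exactly once.

```python
from typing import List, Tuple


def shard_bounds(num_rows: int, num_workers: int) -> List[Tuple[int, int]]:
    """Split [0, num_rows) into contiguous, near-equal [start, end) ranges."""
    base, extra = divmod(num_rows, num_workers)
    bounds, start = [], 0
    for rank in range(num_workers):
        size = base + (1 if rank < extra else 0)  # the first `extra` ranks get one more row
        bounds.append((start, start + size))
        start += size
    return bounds


print(shard_bounds(10, 3))  # → [(0, 4), (4, 7), (7, 10)]
```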

```python
# 09. XGBoost config and Trainer (full-node CPU workers)
from ray.train import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
from ray.train.xgboost import XGBoostTrainer

# Adjust this to your node size if different (e.g., 16, 32, etc.)
CPUS_PER_WORKER = 4

xgb_params = {
    "objective": "multi:softprob",   # output per-class probabilities
    "num_class": 7,
    "eval_metric": "mlogloss",
    "tree_method": "hist",
    "eta": 0.3,
    "max_depth": 8,
    "nthread": CPUS_PER_WORKER,      # match XGBoost threads to the reserved CPUs
}

trainer = XGBoostTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=2,
        use_gpu=False,
        resources_per_worker={"CPU": CPUS_PER_WORKER},
    ),
    datasets={"train": train_ds, "evaluation": val_ds},
    train_loop_config={
        "label_column": "label",
        "params": xgb_params,
        "num_boost_round": 50,
    },
    run_config=RunConfig(
        name="covtype_xgb_cpu",
        storage_path="/mnt/cluster_storage/covtype/results",
        checkpoint_config=CheckpointConfig(
            # With a custom train_func, RayTrainReportCallback(frequency=...) sets the cadence
            checkpoint_frequency=10,
            num_to_keep=1,
            checkpoint_score_attribute="validation-mlogloss",  # score by val loss
            checkpoint_score_order="min",
        ),
        failure_config=FailureConfig(max_failures=1),
    ),
)
```
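
A simplified model of how `checkpoint_score_attribute`, `checkpoint_score_order`, and `num_to_keep` interact. This is an assumption-laden sketch: `retained` is a hypothetical helper, the losses are invented, and Ray's checkpoint manager applies extra rules (for example, around the most recent checkpoint), so treat it as intuition only. It shows why the kept checkpoint can come from an earlier round than the last one once validation loss starts rising.

```python
from typing import Dict, List, Tuple

Reported = Tuple[int, Dict[str, float]]  # (boosting round, metrics reported with it)


def retained(checkpoints: List[Reported], attr: str, order: str,
             num_to_keep: int) -> List[Reported]:
    """Keep the num_to_keep best checkpoints by `attr` ('min' means lower is better)."""
    ranked = sorted(checkpoints, key=lambda ck: ck[1][attr], reverse=(order == "max"))
    return ranked[:num_to_keep]


history = [  # made-up validation losses for checkpoints at rounds 10..50
    (10, {"validation-mlogloss": 0.52}),
    (20, {"validation-mlogloss": 0.41}),
    (30, {"validation-mlogloss": 0.37}),
    (40, {"validation-mlogloss": 0.38}),
    (50, {"validation-mlogloss": 0.39}),
]
print(retained(history, "validation-mlogloss", "min", num_to_keep=1))
# → [(30, {'validation-mlogloss': 0.37})]
```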

### 10. Start distributed training  
`trainer.fit()` blocks until all boosting rounds finish or until Ray exhausts its retries. The returned `Result` holds the last reported metrics (`result.metrics`), the latest checkpoint (`result.checkpoint`), and the retained checkpoints with their metrics (`result.best_checkpoints`). Print the final validation log-loss and keep a handle to the best checkpoint for inference.

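```python
# 10. Fit the trainer (reports eval metrics every boosting round)
result = trainer.fit()
print(f"Final validation log-loss: {result.metrics['validation-mlogloss']:.4f}")

# result.checkpoint is the *latest* checkpoint; pick the lowest validation loss
# among the checkpoints that CheckpointConfig retained
best_ckpt, _ = min(
    result.best_checkpoints,
    key=lambda ckpt_and_metrics: ckpt_and_metrics[1]["validation-mlogloss"],
)
```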

### 11. Evaluate the trained model  
Pull the XGBoost `Booster` back from the checkpoint, run predictions on the entire validation set, and compute overall accuracy. Converting the Ray Dataset to pandas keeps the example short; in production, stream batches instead of materializing the whole frame.

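```python
# 11. Retrieve Booster object from Ray checkpoint
import numpy as np
import xgboost as xgb
from ray.train.xgboost import RayTrainReportCallback
from sklearn.metrics import accuracy_score

booster = RayTrainReportCallback.get_model(best_ckpt)

# Convert Ray Dataset to pandas for quick local scoring
val_pd = val_ds.to_pandas()
# Select columns by the booster's own feature names so the order matches training
dmatrix = xgb.DMatrix(val_pd[booster.feature_names])
pred_prob = booster.predict(dmatrix)         # (n_rows, num_class) for multi:softprob
pred_labels = np.argmax(pred_prob, axis=1)   # most likely class per row

acc = accuracy_score(val_pd["label"], pred_labels)
print(f"Validation accuracy: {acc:.3f}")
```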
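
The note about streaming batches can be made concrete. This is a minimal sketch that assumes batches arrive as `(probabilities, labels)` pairs; in the notebook you would likely obtain them by iterating `val_ds.iter_batches(batch_format="pandas")` and calling `booster.predict` on each batch, but plain lists stand in here so the counting logic runs on its own. `streaming_accuracy` and `argmax` are hypothetical helpers.

```python
from typing import Iterable, List, Sequence, Tuple

Batch = Tuple[List[Sequence[float]], List[int]]  # (per-row class probabilities, true labels)


def argmax(row: Sequence[float]) -> int:
    """Index of the largest probability (ties go to the lowest index, like np.argmax)."""
    return max(range(len(row)), key=lambda i: row[i])


def streaming_accuracy(batches: Iterable[Batch]) -> float:
    """Accumulate correct/total counts batch by batch instead of holding all rows."""
    correct = total = 0
    for probs, labels in batches:
        correct += sum(argmax(p) == y for p, y in zip(probs, labels))
        total += len(labels)
    return correct / total if total else float("nan")


batches = [  # hypothetical 3-class probabilities for two small batches
    ([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]], [0, 1]),
    ([[0.2, 0.3, 0.5], [0.6, 0.3, 0.1]], [2, 1]),
]
print(streaming_accuracy(batches))  # → 0.75
```

Only two counters persist between batches, so memory use stays flat no matter how large the validation set grows.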