diff --git a/src/dvclive/xgb.py b/src/dvclive/xgb.py
index 0c158a92..daf30a2e 100644
--- a/src/dvclive/xgb.py
+++ b/src/dvclive/xgb.py
@@ -1,5 +1,6 @@
 # ruff: noqa: ARG002
 from typing import Optional
+from warnings import warn
 
 from xgboost.callback import TrainingCallback
 
@@ -8,18 +9,33 @@
 
 class DVCLiveCallback(TrainingCallback):
     def __init__(
-        self, metric_data, model_file=None, live: Optional[Live] = None, **kwargs
+        self,
+        metric_data: Optional[str] = None,
+        model_file=None,
+        live: Optional[Live] = None,
+        **kwargs,
     ):
         super().__init__()
+        if metric_data is not None:
+            warn(
+                "`metric_data` is deprecated and will be removed",
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
         self._metric_data = metric_data
         self.model_file = model_file
         self.live = live if live is not None else Live(**kwargs)
 
     def after_iteration(self, model, epoch, evals_log):
-        for key, values in evals_log[self._metric_data].items():
-            if values:
-                latest_metric = values[-1]
-                self.live.log_metric(key, latest_metric)
+        if self._metric_data:
+            # Deprecated mode: log only the selected eval set, unprefixed.
+            evals_log = {"": evals_log[self._metric_data]}
+        for subdir, data in evals_log.items():
+            for key, values in data.items():
+                # Skip metrics with no recorded history, as before this change.
+                if values:
+                    name = f"{subdir}/{key}" if subdir else key
+                    self.live.log_metric(name, values[-1])
         if self.model_file:
             model.save_model(self.model_file)
         self.live.next_step()
diff --git a/tests/test_frameworks/test_xgboost.py b/tests/test_frameworks/test_xgboost.py
index 178a02a2..43c8edee 100644
--- a/tests/test_frameworks/test_xgboost.py
+++ b/tests/test_frameworks/test_xgboost.py
@@ -1,8 +1,10 @@
 import os
+from contextlib import nullcontext
 
 import pytest
 
 from dvclive import Live
+from dvclive.plots.metric import Metric
 from dvclive.utils import parse_metrics
 
 try:
@@ -10,6 +12,7 @@
     import pandas as pd
     import xgboost as xgb
     from sklearn import datasets
+    from sklearn.model_selection import train_test_split
 
     from dvclive.xgb import DVCLiveCallback
 except ImportError:
@@ -29,24 +32,52 @@ def iris_data():
     return xgb.DMatrix(x, y)
 
 
-def test_xgb_integration(tmp_dir, train_params, iris_data, mocker):
-    callback = DVCLiveCallback("eval_data")
+@pytest.fixture()
+def iris_train_eval_data():
+    iris = datasets.load_iris()
+    x_train, x_eval, y_train, y_eval = train_test_split(
+        iris.data, iris.target, random_state=0
+    )
+    return (xgb.DMatrix(x_train, y_train), xgb.DMatrix(x_eval, y_eval))
+
+
+@pytest.mark.parametrize(
+    ("metric_data", "subdirs", "context"),
+    [
+        (
+            "eval",
+            ("",),
+            pytest.warns(DeprecationWarning, match="`metric_data`.+deprecated"),
+        ),
+        (None, ("train", "eval"), nullcontext()),
+    ],
+)
+def test_xgb_integration(
+    tmp_dir, train_params, iris_train_eval_data, metric_data, subdirs, context, mocker
+):
+    with context:
+        callback = DVCLiveCallback(metric_data)
     live = callback.live
     spy = mocker.spy(live, "end")
+    data_train, data_eval = iris_train_eval_data
     xgb.train(
         train_params,
-        iris_data,
+        data_train,
         callbacks=[callback],
         num_boost_round=5,
-        evals=[(iris_data, "eval_data")],
+        evals=[(data_train, "train"), (data_eval, "eval")],
     )
     spy.assert_called_once()
 
     assert os.path.exists("dvclive")
 
     logs, _ = parse_metrics(callback.live)
-    assert len(logs) == 1
-    assert len(list(logs.values())[0]) == 5
+    assert len(logs) == len(subdirs)
+    assert list(map(len, logs.values())) == [5] * len(logs)
+    scalars = os.path.join(callback.live.plots_dir, Metric.subfolder)
+    assert all(
+        os.path.join(scalars, subdir, "mlogloss.tsv") in logs for subdir in subdirs
+    )
 
 
 def test_xgb_model_file(tmp_dir, train_params, iris_data):