Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/dvclive/xgb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ruff: noqa: ARG002
from typing import Optional
from warnings import warn

from xgboost.callback import TrainingCallback

Expand All @@ -8,18 +9,29 @@

class DVCLiveCallback(TrainingCallback):
def __init__(
    self,
    metric_data: Optional[str] = None,
    model_file=None,
    live: Optional[Live] = None,
    **kwargs,
):
    """Set up the callback and the ``Live`` instance it logs to.

    Args:
        metric_data: Deprecated. Name of a single evaluation set whose
            metrics should be logged. Passing any value other than ``None``
            emits a ``DeprecationWarning``.
        model_file: Optional path; when truthy, the model is saved to it
            after every iteration.
        live: An existing ``Live`` instance to log to. When ``None`` a new
            one is created.
        **kwargs: Forwarded to ``Live(...)`` only when ``live`` is ``None``.
    """
    super().__init__()
    if metric_data is not None:
        # stacklevel=2 attributes the warning to the code constructing the
        # callback rather than to this module.
        warn(
            "`metric_data` is deprecated and will be removed",
            category=DeprecationWarning,
            stacklevel=2,
        )
    self._metric_data = metric_data
    self.model_file = model_file
    self.live = live if live is not None else Live(**kwargs)

def after_iteration(self, model, epoch, evals_log):
    """Log the latest value of each metric, optionally save the model, and step.

    Args:
        model: The model being trained; only used to call ``save_model``
            when ``self.model_file`` is set.
        epoch: Current iteration index (unused here).
        evals_log: Mapping of evaluation-set name to a mapping of metric
            name to the list of values recorded so far.

    Returns:
        ``False`` always, so this callback never requests that training stop.

    Raises:
        KeyError: If the deprecated ``metric_data`` name is set but is not
            a key of ``evals_log``.
    """
    if self._metric_data:
        # Deprecated mode: log only the chosen evaluation set, under an
        # empty prefix so metric names stay un-prefixed.
        evals_log = {"": evals_log[self._metric_data]}
    for subdir, data in evals_log.items():
        for key, values in data.items():
            if not values:
                # Nothing recorded yet for this metric; `values[-1]` would
                # raise IndexError.
                continue
            name = f"{subdir}/{key}" if subdir else key
            self.live.log_metric(name, values[-1])
    if self.model_file:
        model.save_model(self.model_file)
    self.live.next_step()
    return False
Expand Down
43 changes: 37 additions & 6 deletions tests/test_frameworks/test_xgboost.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import os
from contextlib import nullcontext

import pytest

from dvclive import Live
from dvclive.plots.metric import Metric
from dvclive.utils import parse_metrics

try:
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn import datasets
from sklearn.model_selection import train_test_split

from dvclive.xgb import DVCLiveCallback
except ImportError:
Expand All @@ -29,24 +32,52 @@ def iris_data():
return xgb.DMatrix(x, y)


def test_xgb_integration(tmp_dir, train_params, iris_data, mocker):
callback = DVCLiveCallback("eval_data")
@pytest.fixture()
def iris_train_eval_data():
    """Return the iris dataset as a ``(train, eval)`` pair of ``xgb.DMatrix``.

    The split uses ``train_test_split`` with ``random_state=0``.
    """
    dataset = datasets.load_iris()
    features_train, features_eval, labels_train, labels_eval = train_test_split(
        dataset.data, dataset.target, random_state=0
    )
    train_matrix = xgb.DMatrix(features_train, labels_train)
    eval_matrix = xgb.DMatrix(features_eval, labels_eval)
    return (train_matrix, eval_matrix)


@pytest.mark.parametrize(
    ("metric_data", "subdirs", "context"),
    [
        (
            # Deprecated usage: one evaluation set selected by name, logged
            # without a sub-directory prefix, and a warning is expected.
            "eval",
            ("",),
            pytest.warns(DeprecationWarning, match="`metric_data`.+deprecated"),
        ),
        # Default usage: every evaluation set is logged under its own name.
        (None, ("train", "eval"), nullcontext()),
    ],
)
def test_xgb_integration(
    tmp_dir, train_params, iris_train_eval_data, metric_data, subdirs, context, mocker
):
    """Train for 5 rounds with the callback and check which metric files exist.

    ``subdirs`` lists the sub-directories expected to contain an
    ``mlogloss.tsv`` file; ``context`` wraps construction of the callback so
    the deprecation warning can be asserted where one is expected.
    """
    with context:
        callback = DVCLiveCallback(metric_data)
    live = callback.live
    spy = mocker.spy(live, "end")
    data_train, data_eval = iris_train_eval_data
    xgb.train(
        train_params,
        data_train,
        callbacks=[callback],
        num_boost_round=5,
        evals=[(data_train, "train"), (data_eval, "eval")],
    )
    spy.assert_called_once()

    assert os.path.exists("dvclive")

    logs, _ = parse_metrics(callback.live)
    # One log file per expected sub-directory, each with one row per round.
    assert len(logs) == len(subdirs)
    assert list(map(len, logs.values())) == [5] * len(logs)
    scalars = os.path.join(callback.live.plots_dir, Metric.subfolder)
    assert all(
        os.path.join(scalars, subdir, "mlogloss.tsv") in logs for subdir in subdirs
    )


def test_xgb_model_file(tmp_dir, train_params, iris_data):
Expand Down