diff --git a/dvclive/lgbm.py b/dvclive/lgbm.py new file mode 100644 index 00000000..a49b55c0 --- /dev/null +++ b/dvclive/lgbm.py @@ -0,0 +1,16 @@ +import dvclive + + +class DvcLiveCallback: + def __init__(self, model_file=None): + super().__init__() + self.model_file = model_file + + def __call__(self, env): + for eval_result in env.evaluation_result_list: + metric = eval_result[1] + value = eval_result[2] + dvclive.log(metric, value) + if self.model_file: + env.model.save_model(self.model_file) + dvclive.next_step() diff --git a/setup.py b/setup.py index 7cfd159b..56a0a8a0 100644 --- a/setup.py +++ b/setup.py @@ -39,8 +39,9 @@ def run(self): mmcv = ["mmcv", "torch", "torchvision"] tf = ["tensorflow"] xgb = ["xgboost"] +lgbm = ["lightgbm"] -all_libs = mmcv + tf + xgb +all_libs = mmcv + tf + xgb + lgbm tests_requires = [ "pylint==2.5.3", @@ -71,6 +72,7 @@ def run(self): "all": all_libs, "tf": tf, "xgb": xgb, + "lgbm": lgbm, }, keywords="data-science metrics machine-learning developer-tools ai", python_requires=">=3.6", diff --git a/tests/test_lgbm.py b/tests/test_lgbm.py new file mode 100644 index 00000000..8195f5fc --- /dev/null +++ b/tests/test_lgbm.py @@ -0,0 +1,71 @@ +import os + +import lightgbm as lgbm +import numpy as np +import pandas as pd +import pytest +from funcy import first +from sklearn import datasets +from sklearn.model_selection import train_test_split + +import dvclive +from dvclive.lgbm import DvcLiveCallback +from tests.test_main import read_logs + +# pylint: disable=redefined-outer-name, unused-argument + + +@pytest.fixture +def model_params(): + return {"objective": "multiclass", "n_estimators": 5, "seed": 0} + + +@pytest.fixture +def iris_data(): + iris = datasets.load_iris() + x = pd.DataFrame(iris["data"], columns=iris["feature_names"]) + y = iris["target"] + x_train, x_test, y_train, y_test = train_test_split( + x, y, test_size=0.33, random_state=42 + ) + return (x_train, y_train), (x_test, y_test) + + +def test_lgbm_integration(tmp_dir, model_params, iris_data): + dvclive.init("logs") + model = lgbm.LGBMClassifier() + model.set_params(**model_params) + + model.fit( + iris_data[0][0], + iris_data[0][1], + eval_set=(iris_data[1][0], iris_data[1][1]), + eval_metric=["multi_logloss"], + callbacks=[DvcLiveCallback()], + ) + + assert os.path.exists("logs") + + logs, _ = read_logs("logs") + assert len(logs) == 1 + assert len(first(logs.values())) == 5 + + +def test_lgbm_model_file(tmp_dir, model_params, iris_data): + dvclive.init("logs") + model = lgbm.LGBMClassifier() + model.set_params(**model_params) + + model.fit( + iris_data[0][0], + iris_data[0][1], + eval_set=(iris_data[1][0], iris_data[1][1]), + eval_metric=["multi_logloss"], + callbacks=[DvcLiveCallback("lgbm_model")], + ) + + preds = model.predict(iris_data[1][0]) + model2 = lgbm.Booster(model_file="lgbm_model") + preds2 = model2.predict(iris_data[1][0]) + preds2 = np.argmax(preds2, axis=1) + assert np.sum(np.abs(preds2 - preds)) == 0