Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions dvclive/lgbm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import dvclive


class DvcLiveCallback:
    """LightGBM training callback that reports evaluation results to dvclive.

    Pass an instance in the ``callbacks`` list of a LightGBM training call.
    After every boosting iteration it logs each evaluation result, optionally
    saves the current model, and advances the dvclive step.
    """

    def __init__(self, model_file=None):
        """Create the callback.

        Args:
            model_file: Optional path. When truthy, the model is saved to
                this path after every iteration, overwriting the previous
                save.
        """
        super().__init__()
        self.model_file = model_file

    @staticmethod
    def _metric_names(evaluation_result_list):
        """Return the dvclive metric name for each evaluation result.

        Each result is a sequence whose first two items are the dataset name
        and the metric name. The bare metric name is used while all results
        come from a single dataset. Once several datasets are present the
        name becomes ``"<dataset>/<metric>"`` so that the same metric
        evaluated on different datasets is not logged twice under one name
        within a step.
        """
        results = list(evaluation_result_list)
        several_datasets = len({result[0] for result in results}) > 1
        return [
            f"{result[0]}/{result[1]}" if several_datasets else result[1]
            for result in results
        ]

    def __call__(self, env):
        """Log the results of the iteration described by ``env``.

        Args:
            env: Callback environment supplied by LightGBM. Only
                ``evaluation_result_list`` and ``model`` are read.
        """
        # A missing result list is treated like an empty one: nothing is
        # logged, but the model is still saved and the step still advances.
        results = list(env.evaluation_result_list or [])
        for name, result in zip(self._metric_names(results), results):
            dvclive.log(name, result[2])
        if self.model_file:
            env.model.save_model(self.model_file)
        dvclive.next_step()
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def run(self):
# Package requirement lists, one per optional integration; the names match
# the extras keys ("tf", "xgb", "lgbm", ...) passed to setup() further down.
mmcv = ["mmcv", "torch", "torchvision"]
tf = ["tensorflow"]
xgb = ["xgboost"]
lgbm = ["lightgbm"]

# Union of the optional groups, installed through the "all" extra.
all_libs = mmcv + tf + xgb
# The later assignment wins, so the union also contains the lightgbm group.
all_libs = mmcv + tf + xgb + lgbm

tests_requires = [
"pylint==2.5.3",
Expand Down Expand Up @@ -71,6 +72,7 @@ def run(self):
"all": all_libs,
"tf": tf,
"xgb": xgb,
"lgbm": lgbm,
},
keywords="data-science metrics machine-learning developer-tools ai",
python_requires=">=3.6",
Expand Down
71 changes: 71 additions & 0 deletions tests/test_lgbm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os

import lightgbm as lgbm
import numpy as np
import pandas as pd
import pytest
from funcy import first
from sklearn import datasets
from sklearn.model_selection import train_test_split

import dvclive
from dvclive.lgbm import DvcLiveCallback
from tests.test_main import read_logs

# pylint: disable=redefined-outer-name, unused-argument


@pytest.fixture
def model_params():
    """Provide the LightGBM parameters shared by the tests in this module."""
    params = dict(objective="multiclass", n_estimators=5, seed=0)
    return params


@pytest.fixture
def iris_data():
    """Load the iris dataset and split it into train and test parts.

    Returns:
        A pair ``((x_train, y_train), (x_test, y_test))`` where the feature
        parts are DataFrames with the iris feature names as columns.
    """
    dataset = datasets.load_iris()
    features = pd.DataFrame(
        dataset["data"], columns=dataset["feature_names"]
    )
    labels = dataset["target"]
    split = train_test_split(
        features, labels, test_size=0.33, random_state=42
    )
    train_x, test_x, train_y, test_y = split
    return (train_x, train_y), (test_x, test_y)


def test_lgbm_integration(tmp_dir, model_params, iris_data):
    """Training with the callback writes one metric with five logged steps."""
    dvclive.init("logs")
    classifier = lgbm.LGBMClassifier()
    classifier.set_params(**model_params)

    (x_train, y_train), (x_test, y_test) = iris_data
    classifier.fit(
        x_train,
        y_train,
        eval_set=(x_test, y_test),
        eval_metric=["multi_logloss"],
        callbacks=[DvcLiveCallback()],
    )

    assert os.path.exists("logs")

    # read_logs returns a pair; only its first element is inspected here.
    logged, _ = read_logs("logs")
    assert len(logged) == 1
    history = first(logged.values())
    assert len(history) == 5


def test_lgbm_model_file(tmp_dir, model_params, iris_data):
    """The model file saved by the callback predicts like the fitted model."""
    dvclive.init("logs")
    classifier = lgbm.LGBMClassifier()
    classifier.set_params(**model_params)

    (x_train, y_train), (x_test, y_test) = iris_data
    classifier.fit(
        x_train,
        y_train,
        eval_set=(x_test, y_test),
        eval_metric=["multi_logloss"],
        callbacks=[DvcLiveCallback("lgbm_model")],
    )

    expected = classifier.predict(x_test)

    # The reloaded Booster's output is reduced to one class index per row
    # with argmax before it is compared with the classifier's predictions.
    reloaded = lgbm.Booster(model_file="lgbm_model")
    scores = reloaded.predict(x_test)
    actual = np.argmax(scores, axis=1)
    mismatch = np.sum(np.abs(actual - expected))
    assert mismatch == 0