Update test for higher python versions
thieu1995 committed Feb 23, 2024
1 parent b776794 commit 9d667cd
Showing 2 changed files with 30 additions and 26 deletions.
36 changes: 17 additions & 19 deletions tests/test_comparisons/test_sklearn_classification.py
@@ -11,6 +11,12 @@
 from permetrics import ClassificationMetric


+def is_close_enough(x1, x2, eps=1e-5):
+    if abs(x1 - x2) <= eps:
+        return True
+    return False
+
+
 @pytest.fixture(scope="module") # scope: Call only 1 time at the beginning
 def data():
     y_true1 = np.array([0, 1, 0, 0, 1, 0, 0, 1, 1, 0])
@@ -44,64 +50,56 @@ def test_AS(data)
     (y_true1, y_pred1), (y_true2, y_pred2), (y_true3, y_pred3), cm1, cm2, cm3 = data
     res11 = cm1.PS(average="micro")
     res12 = accuracy_score(y_true1, y_pred1)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)

     # res21 = cm2.PS(average="micro")
     # res22 = accuracy_score(y_true2, y_pred2) # ValueError: Classification metrics can't handle a mix of multiclass and continuous-multioutput targets
-    # assert res21 == res22
+    # assert is_close_enough(res21, res22)

     # res31 = cm3.PS(average="micro")
     # res32 = accuracy_score(y_true3, y_pred3) # ValueError: Classification metrics can't handle a mix of multiclass and continuous-multioutput targets
-    # assert res31 == res32
-
-    # avg_paras = [None, "macro", "micro", "weighted"]
-    # outs = (dict, float, float, float)
-    #
-    # for idx, avg in enumerate(avg_paras):
-    #     for cm in data:
-    #         res = cm.PS(average=avg)
-    #         assert isinstance(res, outs[idx])
+    # assert is_close_enough(res31, res32)


 def test_F1S(data):
     (y_true1, y_pred1), (y_true2, y_pred2), (y_true3, y_pred3), cm1, cm2, cm3 = data
     res11 = cm1.F1S(average="micro")
     res12 = f1_score(y_true1, y_pred1, average="micro")
-    assert res11 == res12
+    assert is_close_enough(res11, res12)

     res11 = cm1.F1S(average="macro")
     res12 = f1_score(y_true1, y_pred1, average="macro")
-    assert res11 == res12
+    assert is_close_enough(res11, res12)


 def test_FBS(data):
     (y_true1, y_pred1), (y_true2, y_pred2), (y_true3, y_pred3), cm1, cm2, cm3 = data
     res11 = cm1.FBS(average="micro", beta=1.5)
     res12 = fbeta_score(y_true1, y_pred1, average="micro", beta=1.5)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)

     res11 = cm1.FBS(average="macro", beta=2.0)
     res12 = fbeta_score(y_true1, y_pred1, average="macro", beta=2.0)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)


 def test_PS(data):
     (y_true1, y_pred1), (y_true2, y_pred2), (y_true3, y_pred3), cm1, cm2, cm3 = data
     res11 = cm1.PS(average="micro")
     res12 = precision_score(y_true1, y_pred1, average="micro")
-    assert res11 == res12
+    assert is_close_enough(res11, res12)

     res11 = cm1.PS(average="macro")
     res12 = precision_score(y_true1, y_pred1, average="macro")
-    assert res11 == res12
+    assert is_close_enough(res11, res12)


 def test_RS(data):
     (y_true1, y_pred1), (y_true2, y_pred2), (y_true3, y_pred3), cm1, cm2, cm3 = data
     res11 = cm1.RS(average="micro")
     res12 = recall_score(y_true1, y_pred1, average="micro")
-    assert res11 == res12
+    assert is_close_enough(res11, res12)

     res11 = cm1.RS(average="macro")
     res12 = recall_score(y_true1, y_pred1, average="macro")
-    assert res11 == res12
+    assert is_close_enough(res11, res12)
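The new is_close_enough helper replaces exact float equality with an absolute-tolerance comparison, since permetrics and scikit-learn may round the same metric slightly differently across Python, NumPy, and scikit-learn versions. As an illustration only (not part of the commit), the same check can be written with the standard library's math.isclose or pytest's approx; the values a and b below are made up for the example:

import math
import pytest

a, b = 0.3333333, 1 / 3   # made-up values that differ only far below the 1e-5 tolerance

# all three are equivalent to is_close_enough(a, b, eps=1e-5)
assert abs(a - b) <= 1e-5
assert math.isclose(a, b, abs_tol=1e-5)
assert a == pytest.approx(b, abs=1e-5)

pytest.approx would also print the compared values on failure; the standalone helper keeps the comparison identical across both test modules.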
20 changes: 13 additions & 7 deletions tests/test_comparisons/test_sklearn_regression.py
@@ -12,6 +12,12 @@
 from permetrics import RegressionMetric


+def is_close_enough(x1, x2, eps=1e-5):
+    if abs(x1 - x2) <= eps:
+        return True
+    return False
+
+
 @pytest.fixture(scope="module") # scope: Call only 1 time at the beginning
 def data():
     y_true = np.array([3, -0.5, 2, 7, 5, 3, 4, -3, 10])
@@ -24,46 +30,46 @@ def test_EVS(data)
     y_true, y_pred, rm = data
     res11 = rm.EVS()
     res12 = explained_variance_score(y_true, y_pred)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)


 def test_ME(data):
     y_true, y_pred, rm = data
     res11 = rm.ME()
     res12 = max_error(y_true, y_pred)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)


 def test_MAE(data):
     y_true, y_pred, rm = data
     res11 = rm.MAE()
     res12 = mean_absolute_error(y_true, y_pred)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)


 def test_MSE(data):
     y_true, y_pred, rm = data
     res11 = rm.MSE()
     res12 = mean_squared_error(y_true, y_pred)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)


 def test_MedAE(data):
     y_true, y_pred, rm = data
     res11 = rm.MedAE()
     res12 = median_absolute_error(y_true, y_pred)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)


 def test_R2(data):
     y_true, y_pred, rm = data
     res11 = rm.R2()
     res12 = r2_score(y_true, y_pred)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)


 def test_MAPE(data):
     y_true, y_pred, rm = data
     res11 = rm.MAPE()
     res12 = mean_absolute_percentage_error(y_true, y_pred)
-    assert res11 == res12
+    assert is_close_enough(res11, res12)

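Since every regression test follows the same call-and-compare pattern, a parametrized variant is one possible follow-up. The sketch below is hypothetical and not part of the commit: the y_pred values are invented, and it assumes RegressionMetric accepts y_true/y_pred keyword arguments as in the (collapsed) fixture.

import numpy as np
import pytest
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from permetrics import RegressionMetric


def is_close_enough(x1, x2, eps=1e-5):
    return abs(x1 - x2) <= eps


y_true = np.array([3, -0.5, 2, 7, 5, 3, 4, -3, 10])
y_pred = np.array([2.5, 0.0, 2, 8, 5, 2.75, 4, -2.5, 9])  # hypothetical predictions, not from the fixture


@pytest.mark.parametrize("perm_name, skl_func", [
    ("MAE", mean_absolute_error),
    ("MSE", mean_squared_error),
    ("R2", r2_score),
])
def test_matches_sklearn(perm_name, skl_func):
    rm = RegressionMetric(y_true=y_true, y_pred=y_pred)  # assumed constructor, mirroring the fixture
    assert is_close_enough(getattr(rm, perm_name)(), skl_func(y_true, y_pred))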