Skip to content

Commit

Permalink
Merge pull request #6104 from maniteja123/issue5715
Browse files Browse the repository at this point in the history
[MRG+1] Enable pandas input to log_loss
  • Loading branch information
agramfort committed Dec 31, 2015
2 parents f8d7d5f + 0d597c0 commit 4afbcce
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
1 change: 1 addition & 0 deletions sklearn/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,7 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None):
if T.shape[1] == 1:
T = np.append(1 - T, T, axis=1)

y_pred = check_array(y_pred, ensure_2d=False)
# Clipping
Y = np.clip(y_pred, eps, 1 - eps)

Expand Down
19 changes: 19 additions & 0 deletions sklearn/metrics/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sklearn.utils.testing import assert_warns_message
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import ignore_warnings
from sklearn.utils.mocking import MockDataFrame

from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score
Expand Down Expand Up @@ -53,6 +54,7 @@
###############################################################################
# Utilities for testing


def make_prediction(dataset=None, binary=False):
"""Make some classification predictions on a toy dataset using a SVC
Expand Down Expand Up @@ -1275,6 +1277,23 @@ def test_log_loss():
assert_almost_equal(loss, 1.0383217, decimal=6)


def test_log_loss_pandas_input():
# case when input is a pandas series and dataframe gh-5715
y_tr = np.array(["ham", "spam", "spam", "ham"])
y_pr = np.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]])
types = [(MockDataFrame, MockDataFrame)]
try:
from pandas import Series, DataFrame
types.append((Series, DataFrame))
except ImportError:
pass
for TrueInputType, PredInputType in types:
# y_pred dataframe, y_true series
y_true, y_pred = TrueInputType(y_tr), PredInputType(y_pr)
loss = log_loss(y_true, y_pred)
assert_almost_equal(loss, 1.0383217, decimal=6)


def test_brier_score_loss():
# Check brier_score_loss function
y_true = np.array([0, 1, 1, 0, 1, 1])
Expand Down

0 comments on commit 4afbcce

Please sign in to comment.