ENH add sample_weight support to dummy classifier
arjoly committed Jul 18, 2014
1 parent 3f5e691 commit f9773ba
Showing 2 changed files with 21 additions and 10 deletions.
6 changes: 4 additions & 2 deletions sklearn/dummy.py
@@ -2,6 +2,7 @@
 # Arnaud Joly <a.joly@ulg.ac.be>
 # Maheshakya Wijewardena<maheshakya.10@cse.mrt.ac.lk>
 # License: BSD 3 clause
+from __future__ import division
 
 import numpy as np
 
@@ -65,7 +66,7 @@ def __init__(self, strategy="stratified", random_state=None,
         self.random_state = random_state
         self.constant = constant
 
-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
         """Fit the random classifier.
 
         Parameters
@@ -111,7 +112,8 @@ def fit(self, X, y):
             classes, y_k = np.unique(y[:, k], return_inverse=True)
             self.classes_.append(classes)
             self.n_classes_.append(classes.shape[0])
-            self.class_prior_.append(np.bincount(y_k) / float(y_k.shape[0]))
+            class_prior = np.bincount(y_k, weights=sample_weight)
+            self.class_prior_.append(class_prior / class_prior.sum())
 
         # Checking in case of constant strategy if the constant provided
         # by the user is in y.
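For illustration, a minimal sketch (not the patched library code itself) of the weighted class-prior computation introduced above, using the toy weights from the new test further down: np.bincount with the weights argument sums the per-sample weights for each encoded class, and dividing by the total weight turns the weighted counts into priors.

import numpy as np

y_k = np.array([0, 1, 0])                  # encoded class indices, as returned by np.unique(..., return_inverse=True)
sample_weight = np.array([0.1, 1.0, 0.1])  # per-sample weights

class_prior = np.bincount(y_k, weights=sample_weight)  # weighted counts: [0.2, 1.0]
class_prior = class_prior / class_prior.sum()          # normalized priors: [0.1667, 0.8333]

Passing sample_weight=None keeps the unweighted behaviour: np.bincount then returns plain counts, and the normalization reduces to dividing by the number of samples, as in the removed line.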
25 changes: 17 additions & 8 deletions sklearn/tests/test_dummy.py
@@ -2,11 +2,11 @@
 import numpy as np
 
 from sklearn.base import clone
-from sklearn.externals.six.moves import xrange
-from sklearn.utils.testing import (assert_array_equal,
-                                   assert_equal,
-                                   assert_almost_equal,
-                                   assert_raises)
+from sklearn.utils.testing import assert_array_equal
+from sklearn.utils.testing import assert_array_almost_equal
+from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_almost_equal
+from sklearn.utils.testing import assert_raises
 
 from sklearn.dummy import DummyClassifier, DummyRegressor
 
@@ -29,7 +29,7 @@ def _check_predict_proba(clf, X, y):
         proba = [proba]
         log_proba = [log_proba]
 
-    for k in xrange(n_outputs):
+    for k in range(n_outputs):
         assert_equal(proba[k].shape[0], n_samples)
         assert_equal(proba[k].shape[1], len(np.unique(y[:, k])))
         assert_array_equal(proba[k].sum(axis=1), np.ones(len(X)))
@@ -136,7 +136,7 @@ def test_stratified_strategy_multioutput():
     X = [[0]] * 500
     y_pred = clf.predict(X)
 
-    for k in xrange(y.shape[1]):
+    for k in range(y.shape[1]):
         p = np.bincount(y_pred[:, k]) / float(len(X))
         assert_almost_equal(p[1], 3. / 5, decimal=1)
         assert_almost_equal(p[2], 2. / 5, decimal=1)
@@ -171,7 +171,7 @@ def test_uniform_strategy_multioutput():
     X = [[0]] * 500
     y_pred = clf.predict(X)
 
-    for k in xrange(y.shape[1]):
+    for k in range(y.shape[1]):
         p = np.bincount(y_pred[:, k]) / float(len(X))
         assert_almost_equal(p[1], 0.5, decimal=1)
         assert_almost_equal(p[2], 0.5, decimal=1)
@@ -388,3 +388,12 @@ def test_constant_strategy_exceptions():
     clf = DummyClassifier(strategy="constant", random_state=0,
                           constant=[2, 0])
     assert_raises(ValueError, clf.fit, X, y)
+
+
+def test_classification_sample_weight():
+    X = [[0], [0], [1]]
+    y = [0, 1, 0]
+    sample_weight = [0.1, 1., 0.1]
+
+    clf = DummyClassifier().fit(X, y, sample_weight)
+    assert_array_almost_equal(clf.class_prior_, [0.2 / 1.2, 1. / 1.2])
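The asserted priors follow from the weighted counts: class 0 receives 0.1 + 0.1 = 0.2, class 1 receives 1.0, and dividing by the total weight 1.2 gives [0.2 / 1.2, 1.0 / 1.2]. A quick end-to-end check of the new keyword (a usage sketch, assuming a scikit-learn build that includes this commit):

from sklearn.dummy import DummyClassifier

clf = DummyClassifier().fit([[0], [0], [1]], [0, 1, 0],
                            sample_weight=[0.1, 1.0, 0.1])
print(clf.class_prior_)  # approximately [0.1667, 0.8333]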
