Skip to content

Commit

Permalink
FIX: Let Theil-Sen handle n_samples < n_features case
Browse files Browse the repository at this point in the history
Theil-Sen will fall back to a Least Squares-like solution if the number of
samples is smaller than the number of features.
  • Loading branch information
FlorianWilhelm committed Mar 9, 2014
1 parent d1d221b commit 61a5195
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
25 changes: 25 additions & 0 deletions sklearn/linear_model/tests/test_theilsen.py
Expand Up @@ -205,6 +205,15 @@ def test__checksubparams_too_many_subsamples():
TheilSen(n_subsamples=101).fit(X, y)


@raises(AssertionError)
def test__checksubparams_n_subsamples_if_less_samples_than_features():
    """Fit with ``n_subsamples=9`` on 10 samples and 20 features.

    The ``raises`` decorator states the expectation: fitting this
    configuration must end in an ``AssertionError``.
    """
    np.random.seed(0)
    n_rows, n_cols = 10, 20
    # Draw the design matrix as one flat vector, then shape it to
    # (n_rows, n_cols); the targets are drawn afterwards.
    flat_values = np.random.randn(n_rows * n_cols)
    X = flat_values.reshape(n_rows, n_cols)
    y = np.random.randn(n_rows)
    estimator = TheilSen(n_subsamples=9)
    estimator.fit(X, y)


def test_subpopulation():
X, y, w, c = gen_toy_problem_4d()
theilsen = TheilSen(max_subpopulation=1000, random_state=0).fit(X, y)
Expand Down Expand Up @@ -264,3 +273,19 @@ def test_theilsen_parallel_all_but_one_CPUs():
def test_theilsen_parallel_no_CPUs():
    """Fit Theil-Sen on the 1-d toy problem with ``n_jobs=0``."""
    # Only the data and targets are needed; the remaining two values
    # returned by the toy-problem generator are unused here.
    X, y, _, _ = gen_toy_problem_1d()
    estimator = TheilSen(n_jobs=0)
    estimator.fit(X, y)


def test_less_samples_than_features():
    """Fit Theil-Sen on data with fewer samples (10) than features (20).

    Two cases are checked: without an intercept the coefficients must
    agree with ``LinearRegression`` to 12 decimals; with an intercept the
    predictions on the training data must reproduce ``y`` to 12 decimals.
    """
    np.random.seed(0)
    n_rows, n_cols = 10, 20
    X = np.random.randn(n_rows * n_cols).reshape(n_rows, n_cols)
    y = np.random.randn(n_rows)

    # fit_intercept=False: Theil-Sen should fall back to Least Squares.
    ts_coef = TheilSen(fit_intercept=False).fit(X, y).coef_
    ls_coef = LinearRegression(fit_intercept=False).fit(X, y).coef_
    nptest.assert_array_almost_equal(ts_coef, ls_coef, 12)

    # fit_intercept=True: the result is not equal to the Least Squares
    # solution because the intercept is calculated differently, so the
    # predictions are compared against the targets instead.
    predictions = TheilSen(fit_intercept=True).fit(X, y).predict(X)
    nptest.assert_array_almost_equal(predictions, y, 12)
10 changes: 7 additions & 3 deletions sklearn/linear_model/theilsen.py
Expand Up @@ -241,12 +241,16 @@ def _check_subparams(self, n_samples, n_features):
n_dim = n_features
n_subsamples = self.n_subsamples
if n_subsamples is not None:
assert n_dim <= n_subsamples <= n_samples
assert n_subsamples <= n_samples
if n_samples >= n_features:
assert n_dim <= n_subsamples
else: # if n_samples < n_features
assert n_subsamples == n_samples
else:
n_subsamples = n_dim
n_subsamples = min(n_dim, n_samples)
if self.max_subpopulation <= 0:
raise ValueError("Subpopulation must be positive.")
n_all = binom(n_samples, n_subsamples)
n_all = max(1, binom(n_samples, n_subsamples))
n_sp = int(min(self.max_subpopulation, n_all))
return n_dim, n_subsamples, n_sp

Expand Down

0 comments on commit 61a5195

Please sign in to comment.