Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
DOC: add an example of robust fitting
This compares RANSAC and ThielSen
- Loading branch information
1 parent
5189cfb
commit 434c20d
Showing
1 changed file
with
88 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
""" | ||
Demo robust fitting | ||
=================== | ||
Here a sine function is fit with a polynomial of order 3, for values | ||
close to zero. | ||
Robust fitting is demoed in different situations: | ||
- No measurement errors, only modelling errors (fitting a sine with a | ||
polynomial) | ||
- Measurement errors in X | ||
- Measurement errors in y | ||
The median absolute deviation to non corrupt new data is used to judge | ||
the quality of the prediction. | ||
What we can see that: | ||
- RANSAC is good for strong outliers in the y direction | ||
- TheilSen is good for small outliers, both in direction X and y, but has | ||
a break point above which it performs worst than OLS. | ||
""" | ||
|
||
from matplotlib import pyplot as plt | ||
import numpy as np | ||
|
||
from sklearn import linear_model, metrics | ||
from sklearn.preprocessing import PolynomialFeatures | ||
from sklearn.pipeline import make_pipeline | ||
|
||
np.random.seed(42) | ||
|
||
X = np.random.normal(size=400) | ||
y = np.sin(X) | ||
# Make sure that it X is 2D | ||
X = X[:, np.newaxis] | ||
|
||
X_test = np.random.normal(size=200) | ||
y_test = np.sin(X_test) | ||
X_test = X_test[:, np.newaxis] | ||
|
||
y_errors = y.copy() | ||
y_errors[::3] = 3 | ||
|
||
X_errors = X.copy() | ||
X_errors[::3] = 3 | ||
|
||
y_errors_large = y.copy() | ||
y_errors_large[::3] = 10 | ||
|
||
X_errors_large = X.copy() | ||
X_errors_large[::3] = 10 | ||
|
||
estimators = [('OLS', linear_model.LinearRegression()), | ||
('Theil-Sen', linear_model.TheilSen(random_state=42)), | ||
('RANSAC', linear_model.RANSACRegressor(random_state=42)), ] | ||
|
||
x_plot = np.linspace(X.min(), X.max()) | ||
|
||
for title, this_X, this_y in [ | ||
('Modeling errors only', X, y), | ||
('Corrupt X, small deviants', X_errors, y), | ||
('Corrupt y, small deviants', X, y_errors), | ||
('Corrupt X, large deviants', X_errors_large, y), | ||
('Corrupt y, large deviants', X, y_errors_large), | ||
]: | ||
plt.figure() | ||
plt.plot(this_X[:, 0], this_y, 'k+') | ||
|
||
for name, estimator in estimators: | ||
model = make_pipeline(PolynomialFeatures(3), estimator) | ||
model.fit(this_X, this_y) | ||
mse = metrics.mean_squared_error(model.predict(X_test), y_test) | ||
y_plot = model.predict(x_plot[:, np.newaxis]) | ||
plt.plot(x_plot, y_plot, | ||
label='%s: error = %.3f' % (name, mse)) | ||
|
||
plt.legend(loc='best', frameon=False, | ||
title='Error: mean absolute deviation to non corrupt data') | ||
plt.xlim(-4, 10.2) | ||
plt.ylim(-2, 10.2) | ||
plt.title(title) | ||
plt.show() |