diff --git a/examples/linear_model/plot_polynomial_interpolation.py b/examples/linear_model/plot_polynomial_interpolation.py
index 267bb144e39f7..623d304e5ecc4 100644
--- a/examples/linear_model/plot_polynomial_interpolation.py
+++ b/examples/linear_model/plot_polynomial_interpolation.py
@@ -18,18 +18,21 @@
 matrix induced by a polynomial kernel.
 
 This example shows that you can do non-linear regression with a linear model,
-by manually adding non-linear features. Kernel methods extend this idea and can
-induce very high (even infinite) dimensional feature spaces.
+using a pipeline to add non-linear features. Kernel methods extend this idea
+and can induce very high (even infinite) dimensional feature spaces.
 """
 print(__doc__)
 
 # Author: Mathieu Blondel
+#         Jake Vanderplas
 # License: BSD 3 clause
 
 import numpy as np
-import pylab as pl
+import matplotlib.pyplot as plt
 
 from sklearn.linear_model import Ridge
+from sklearn.preprocessing import PolynomialFeatures
+from sklearn.pipeline import Pipeline
 
 
 def f(x):
@@ -46,15 +49,20 @@ def f(x):
 x = np.sort(x[:20])
 y = f(x)
 
-pl.plot(x_plot, f(x_plot), label="ground truth")
-pl.scatter(x, y, label="training points")
+# create matrix versions of these arrays
+X = x[:, np.newaxis]
+X_plot = x_plot[:, np.newaxis]
+
+plt.plot(x_plot, f(x_plot), label="ground truth")
+plt.scatter(x, y, label="training points")
 
 for degree in [3, 4, 5]:
-    ridge = Ridge()
-    ridge.fit(np.vander(x, degree + 1), y)
-    pl.plot(x_plot, ridge.predict(np.vander(x_plot, degree + 1)),
-            label="degree %d" % degree)
+    model = Pipeline([('poly', PolynomialFeatures(degree)),
+                      ('ridge', Ridge())])
+    model.fit(X, y)
+    y_plot = model.predict(X_plot)
+    plt.plot(x_plot, y_plot, label="degree %d" % degree)
 
-pl.legend(loc='lower left')
+plt.legend(loc='lower left')
 
-pl.show()
+plt.show()
diff --git a/sklearn/preprocessing/__init__.py b/sklearn/preprocessing/__init__.py
index 0866b70d99c4e..7b6648172ff8e 100644
--- a/sklearn/preprocessing/__init__.py
+++ b/sklearn/preprocessing/__init__.py
@@ -15,6 +15,8 @@
 from .data import scale
 from .data import OneHotEncoder
 
+from .data import PolynomialFeatures
+
 from .label import label_binarize
 from .label import LabelBinarizer
 from .label import LabelEncoder
@@ -34,6 +36,7 @@
     'Scaler',
     'StandardScaler',
     'add_dummy_feature',
+    'PolynomialFeatures',
     'binarize',
     'normalize',
     'scale',
diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py
index 428d3a7944b4a..8615b1ec9233e 100644
--- a/sklearn/preprocessing/data.py
+++ b/sklearn/preprocessing/data.py
@@ -6,6 +6,7 @@
 
 import numbers
 import warnings
+import itertools
 
 import numpy as np
 from scipy import sparse
@@ -394,6 +395,99 @@ def __init__(self, copy=True, with_mean=True, with_std=True):
         super(Scaler, self).__init__(copy, with_mean, with_std)
 
 
+class PolynomialFeatures(BaseEstimator, TransformerMixin):
+    """Generate polynomial and interaction features
+
+    Generate a new feature matrix consisting of all polynomial combinations
+    of the features with degree less than or equal to the specified degree.
+    For example, if an input sample is two dimensional and of the form
+    [a, b], the degree-2 polynomial features are [1, a, b, a^2, ab, b^2].
+
+    Parameters
+    ----------
+    degree : integer
+        The degree of the polynomial features. Default = 2.
+    include_bias : boolean
+        If True (default), then include a bias column, the feature in which
+        all polynomial powers are zero (i.e. a column of ones, which acts
+        as an intercept term in a linear model).
+
+    Examples
+    --------
+    >>> X = np.arange(6).reshape(3, 2)
+    >>> X
+    array([[0, 1],
+           [2, 3],
+           [4, 5]])
+    >>> poly = PolynomialFeatures(2)
+    >>> poly.fit_transform(X)
+    array([[ 1,  0,  1,  0,  0,  1],
+           [ 1,  2,  3,  4,  6,  9],
+           [ 1,  4,  5, 16, 20, 25]])
+
+    Attributes
+    ----------
+    `powers_` : np.ndarray, shape = (NP, n_features)
+        The matrix of powers used to construct the polynomial features:
+        powers_[i, j] is the exponent of the j^th input feature in the
+        i^th output feature. NP is the number of output features.
+
+    Notes
+    -----
+    Be aware that the number of features in the output array scales
+    polynomially in the number of input features and exponentially in
+    the degree, so this is not suitable for high-dimensional data.
+    """
+    def __init__(self, degree=2, include_bias=True):
+        self.degree = degree
+        self.include_bias = include_bias
+
+    @staticmethod
+    def _power_matrix(n_features, degree, include_bias):
+        """Compute the matrix of polynomial powers"""
+        # Find permutations/combinations which add to degree or less
+        deg_min = 0 if include_bias else 1
+        powers = itertools.product(*(range(degree + 1)
+                                     for i in range(n_features)))
+        powers = np.array([c for c in powers if deg_min <= sum(c) <= degree])
+
+        # sort so that the order of the powers makes sense
+        i = np.lexsort(np.vstack([powers.T, powers.sum(1)]))
+        return powers[i]
+
+    def fit(self, X, y=None):
+        """
+        Compute the polynomial feature combinations
+        """
+        n_samples, n_features = array2d(X).shape
+        self.powers_ = self._power_matrix(n_features,
+                                          self.degree,
+                                          self.include_bias)
+        return self
+
+    def transform(self, X, y=None):
+        """Transform data to polynomial features
+
+        Parameters
+        ----------
+        X : array, shape [n_samples, n_features]
+            The data to transform, row by row.
+
+        Returns
+        -------
+        XP : np.ndarray, shape [n_samples, NP]
+            The matrix of features, where NP is the number of polynomial
+            features generated from the combination of inputs.
+ """ + X = array2d(X) + n_samples, n_features = X.shape + + if n_features != self.powers_.shape[1]: + raise ValueError("X shape does not match training shape") + + return (X[:, None, :] ** self.powers_).prod(-1) + + def normalize(X, norm='l2', axis=1, copy=True): """Normalize a dataset along any axis diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 5201a8be29093..1f84a4b42dd07 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -23,6 +23,7 @@ from sklearn.preprocessing.data import scale from sklearn.preprocessing.data import MinMaxScaler from sklearn.preprocessing.data import add_dummy_feature +from sklearn.preprocessing.data import PolynomialFeatures from sklearn import datasets @@ -35,6 +36,32 @@ def toarray(a): return a +def test_polynomial_features(): + """Test Polynomial Features""" + X1 = np.arange(6)[:, np.newaxis] + P1 = np.hstack([np.ones_like(X1), + X1, X1 ** 2, X1 ** 3]) + deg1 = 3 + + X2 = np.arange(6).reshape((3, 2)) + x1 = X2[:, :1] + x2 = X2[:, 1:] + P2 = np.hstack([x1 ** 0 * x2 ** 0, + x1 ** 1 * x2 ** 0, + x1 ** 0 * x2 ** 1, + x1 ** 2 * x2 ** 0, + x1 ** 1 * x2 ** 1, + x1 ** 0 * x2 ** 2]) + deg2 = 2 + + for (deg, X, P) in [(deg1, X1, P1), (deg2, X2, P2)]: + P_test = PolynomialFeatures(deg, include_bias=True).fit_transform(X) + assert_array_almost_equal(P_test, P) + + P_test = PolynomialFeatures(deg, include_bias=False).fit_transform(X) + assert_array_almost_equal(P_test, P[:, 1:]) + + def test_scaler_1d(): """Test scaling of dataset along single axis""" rng = np.random.RandomState(0) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 42d5e7fc1dddb..2754bb96150bd 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -49,7 +49,7 @@ 'DictVectorizer', 'LabelBinarizer', 'LabelEncoder', 'TfidfTransformer', 'IsotonicRegression', 'OneHotEncoder', 'RandomTreesEmbedding', 'FeatureHasher', 'DummyClassifier', - 'DummyRegressor', 'TruncatedSVD'] + 'DummyRegressor', 'TruncatedSVD', 'PolynomialFeatures'] def test_all_estimators():