Merge branch 'master' into sc
vene committed Jun 26, 2011
2 parents 24c3a68 + 5698a4c commit 9847371
Showing 38 changed files with 669 additions and 377 deletions.
4 changes: 2 additions & 2 deletions doc/developers/index.rst
@@ -99,7 +99,7 @@ request. This will send an email to the commiters, but might also send an
email to the mailing list in order to get more visibility.

It is recommended to check that your contribution complies with the following
rules before submitting a pull request::
rules before submitting a pull request:

* Follow the `coding-guidelines`_ (see below).

@@ -127,7 +127,7 @@ rules before submitting a pull request::

To build the documentation see `documentation`_ below.

You can also check for common programming errors with the following tools::
You can also check for common programming errors with the following tools:

* Code with a good unittest coverage (at least 80%), check with::

1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -440,6 +440,7 @@ Cross Validation
cross_val.LeaveOneLabelOut
cross_val.LeavePLabelOut
cross_val.Bootstrap
cross_val.ShuffleSplit


Grid Search
10 changes: 7 additions & 3 deletions doc/modules/decomposition.rst
@@ -170,9 +170,13 @@ implemented here is based on [Mrl09]_ .
Independent component analysis (ICA)
====================================

ICA finds components that are maximally independent. It is classically
used to separate mixed signals (a problem know as *blind source
separation*), as in the example below:
Independent component analysis separates a multivariate signal into
additive subcomponents that are maximally independent. It is
implemented in scikit-learn using the :class:`Fast ICA <FastICA>`
algorithm.

It is classically used to separate mixed signals (a problem known as
*blind source separation*), as in the example below:

.. figure:: ../auto_examples/decomposition/images/plot_ica_blind_source_separation_1.png
:target: ../auto_examples/decomposition/plot_ica_blind_source_separation.html
8 changes: 5 additions & 3 deletions doc/modules/linear_model.rst
@@ -152,7 +152,8 @@ the coefficients. See :ref:`least_angle_regression` for another implementation.

>>> clf = linear_model.Lasso(alpha = 0.1)
>>> clf.fit ([[0, 0], [1, 1]], [0, 1])
Lasso(alpha=0.1, fit_intercept=True)
Lasso(precompute='auto', alpha=0.1, max_iter=1000, tol=0.0001,
fit_intercept=True)
>>> clf.predict ([[1, 1]])
array([ 0.8])

@@ -240,9 +241,10 @@ function of the norm of its coefficients.
>>> from scikits.learn import linear_model
>>> clf = linear_model.LassoLARS(alpha=.1)
>>> clf.fit ([[0, 0], [1, 1]], [0, 1])
LassoLARS(alpha=0.1, max_iter=500, verbose=False, fit_intercept=True)
LassoLARS(normalize=True, verbose=False, fit_intercept=True, max_iter=500,
precompute='auto', alpha=0.1)
>>> clf.coef_
array([ 0.30710678, 0. ])
array([ 0.50710678, 0. ])

.. topic:: Examples:

2 changes: 1 addition & 1 deletion doc/modules/neighbors.rst
@@ -10,7 +10,7 @@ using a majority vote among the k neighbors.

Despite its simplicity, nearest neighbors has been successful in a
large number of classification problems, including handwritten digits
or satellite image scenes. It is ofter successful in situation where
or satellite image scenes. It is often successful in situations where
the decision boundary is very irregular.

Classification
20 changes: 15 additions & 5 deletions examples/linear_model/plot_lasso_path_crossval.py
@@ -3,8 +3,14 @@
Cross validated Lasso path with coordinate descent
==================================================
Compute a 5-fold cross-validated Lasso path with coordinate descent to find the
optimal value of alpha.
Compute a 20-fold cross-validated Lasso path with coordinate descent to
find the optimal value of alpha.
Note how the optimal value of alpha varies for each fold. This
illustrates why nested cross-validation is necessary when trying to
evaluate the performance of a method for which a parameter is chosen by
cross-validation: this choice of parameter may not be optimal for unseen
data.
"""
print __doc__

@@ -15,6 +21,7 @@
import pylab as pl

from scikits.learn.linear_model import LassoCV
from scikits.learn.cross_val import KFold
from scikits.learn import datasets

diabetes = datasets.load_diabetes()
@@ -30,7 +37,7 @@
eps = 1e-3 # the smaller it is the longer is the path

print "Computing regularization path using the lasso..."
model = LassoCV(eps=eps).fit(X, y)
model = LassoCV(eps=eps, cv=KFold(len(y), 20)).fit(X, y)

##############################################################################
# Display results
@@ -51,9 +58,12 @@
pl.axis('tight')

pl.subplot(2, 1, 2)
ymin, ymax = 2600, 3800
pl.plot(m_log_alphas, model.mse_path_)
ymin, ymax = 2300, 3800
pl.plot(m_log_alphas, model.mse_path_, '--')
pl.plot(m_log_alphas, model.mse_path_.mean(axis=-1), 'k',
label='Average across the folds')
pl.vlines([m_log_alpha], ymin, ymax, linestyle='dashed')
pl.legend(loc='best')

pl.xlabel('-log(lambda)')
pl.ylabel('MSE')
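The docstring's point about per-fold optima can be sketched without any plotting. The snippet below is only an illustration: the `alphas` and `mse_path` values are invented toy numbers (not results from the diabetes dataset), laid out like `model.mse_path_` with one row per candidate alpha and one column per fold, and the `argmin` helper is hypothetical.

```python
# Rows are candidate alphas, columns are folds. These numbers are invented
# purely for illustration; they are not output from LassoCV.
alphas = [1.0, 0.1, 0.01]
mse_path = [
    [3.0, 2.0, 4.0],
    [2.5, 2.2, 3.0],
    [2.8, 2.6, 2.9],
]


def argmin(values):
    """Return the position of the smallest element (hypothetical helper)."""
    return min(range(len(values)), key=values.__getitem__)


n_folds = len(mse_path[0])

# Best alpha when each fold is considered on its own.
best_per_fold = [
    alphas[argmin([row[fold] for row in mse_path])] for fold in range(n_folds)
]

# Best alpha when the error is averaged across folds.
mean_mse = [sum(row) / len(row) for row in mse_path]
best_on_average = alphas[argmin(mean_mse)]

print("per fold:", best_per_fold)
print("on average:", best_on_average)
```

In this toy case every fold prefers a different alpha, which is the situation the docstring warns about: the alpha chosen on average is not the best one for any particular held-out subset.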
16 changes: 11 additions & 5 deletions scikits/learn/base.py
@@ -257,11 +257,10 @@ class TransformerMixin(object):
"""

def fit_transform(self, X, y=None, **fit_params):
"""Fit model to data and subsequently transform the data
"""Fit to data, then transform it
Sometimes, fit and transform can be implemented more efficiently
jointly than separately. In those cases, the estimator will typically
override the method.
Fits transformer to X and y with optional parameters fit_params
and returns a transformed version of X.
Parameters
----------
@@ -273,7 +272,14 @@ def fit_transform(self, X, y=None, **fit_params):
Returns
-------
self : returns an instance of self.
X_new : numpy array of shape [n_samples, n_features_new]
Transformed array.
Notes
-----
This method just calls fit and transform consecutively, i.e., it is not
an optimized implementation of fit_transform, unlike other transformers
such as PCA.
"""
if y is None:
# fit method of arity 1 (unsupervised transformation)
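As a rough illustration of the "fit, then transform" default described in the docstring, here is a minimal sketch. `TransformerSketch` and `MeanCenterer` are hypothetical names made up for this example and are not part of scikits.learn; only the dispatch on `y is None` is taken from the diff above.

```python
class TransformerSketch(object):
    """Hypothetical mixin mirroring the default fit-then-transform behaviour."""

    def fit_transform(self, X, y=None, **fit_params):
        if y is None:
            # Unsupervised transformer: fit takes X only.
            return self.fit(X, **fit_params).transform(X)
        # Supervised transformer: fit takes X and y.
        return self.fit(X, y, **fit_params).transform(X)


class MeanCenterer(TransformerSketch):
    """Toy transformer that subtracts the mean learned during fit."""

    def fit(self, X, y=None):
        self.mean_ = sum(X) / float(len(X))
        return self  # returning self is what allows fit(...).transform(...)

    def transform(self, X):
        return [x - self.mean_ for x in X]


centered = MeanCenterer().fit_transform([1.0, 2.0, 3.0])
print(centered)
```

A transformer that can compute both steps more cheaply in one pass would simply override `fit_transform` instead of inheriting this default.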
94 changes: 92 additions & 2 deletions scikits/learn/cross_val.py
@@ -320,8 +320,6 @@ def __len__(self):
return self.k


##############################################################################

class LeaveOneLabelOut(object):
"""Leave-One-Label_Out cross-validation iterator
@@ -599,6 +597,98 @@ def __len__(self):
return self.n_bootstraps


class ShuffleSplit(object):
    """Random split cross-validation iterator

    Provides train/test indices to split data into train and test sets.
    """

    def __init__(self, n, n_splits=20, test_fraction=0.1,
                 indices=False, random_state=None):
        """Random split cross validation

        Provides train/test indices to split data into train and test sets.

        Note: contrary to other cross-validation strategies, random
        splits do not guarantee that all folds will be different,
        although duplicate folds are unlikely for sizeable datasets.

        Parameters
        ----------
        n : int
            Total number of elements in the dataset.

        n_splits : int (default is 20)
            Number of splitting iterations.

        test_fraction : float (default is 0.1)
            Should be between 0.0 and 1.0; represents the proportion of
            the dataset to include in the test split.

        indices : boolean, optional (default False)
            Return train/test split with integer indices or boolean mask.
            Integer indices are useful when dealing with sparse matrices
            that cannot be indexed by boolean masks.

        random_state : int or RandomState
            Pseudo-random number generator state used for random sampling.

        Examples
        --------
        >>> from scikits.learn import cross_val
        >>> rs = cross_val.ShuffleSplit(4, n_splits=3, test_fraction=.25, random_state=0)
        >>> len(rs)
        3
        >>> print rs
        ShuffleSplit(4, n_splits=3, test_fraction=0.25, indices=False, random_state=0)
        >>> for train_index, test_index in rs:
        ...     print "TRAIN:", train_index, "TEST:", test_index
        ...
        TRAIN: [False True True True] TEST: [ True False False False]
        TRAIN: [ True True True False] TEST: [False False False True]
        TRAIN: [ True False True True] TEST: [False True False False]
        """
        self.n = n
        self.n_splits = n_splits
        self.test_fraction = test_fraction
        self.random_state = random_state
        self.indices = indices

    def __iter__(self):
        rng = self.random_state = check_random_state(self.random_state)
        # ceil returns a float; cast to int so it can be used as a slice bound
        n_test = int(ceil(self.test_fraction * self.n))
        for i in range(self.n_splits):
            # random partition
            permutation = rng.permutation(self.n)
            ind_train = permutation[:-n_test]
            ind_test = permutation[-n_test:]

            if self.indices:
                yield ind_train, ind_test
            else:
                train_mask = np.zeros(self.n, dtype=np.bool)
                train_mask[ind_train] = True
                test_mask = np.zeros(self.n, dtype=np.bool)
                test_mask[ind_test] = True
                yield train_mask, test_mask

    def __repr__(self):
        # %s rather than %d for random_state: it may be None or, after
        # iteration, a RandomState instance
        return ('%s(%d, n_splits=%d, test_fraction=%s, indices=%s, '
                'random_state=%s)' % (
                    self.__class__.__name__,
                    self.n,
                    self.n_splits,
                    str(self.test_fraction),
                    self.indices,
                    self.random_state,
                ))

    def __len__(self):
        return self.n_splits


##############################################################################

def _cross_val_score(estimator, X, y, score_func, train, test, iid):
"""Inner loop for cross validation"""
if score_func is None:
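The splitting logic of `ShuffleSplit` can be sketched in a few lines of dependency-free Python. This is an illustrative re-statement, not the committed implementation: the function name `shuffle_split` and its `seed` parameter are assumptions made for the example, it uses the standard `random` module instead of NumPy, and it is written for a current Python 3 interpreter rather than the Python 2 used in the diff.

```python
import math
import random


def shuffle_split(n, n_splits=20, test_fraction=0.1, seed=None):
    """Yield (train, test) index lists drawn from independent permutations."""
    rng = random.Random(seed)
    # Round the test size up and make sure it is an integer.
    n_test = int(math.ceil(test_fraction * n))
    for _ in range(n_splits):
        permutation = list(range(n))
        rng.shuffle(permutation)
        # The last n_test shuffled indices form the test set; slicing with
        # n - n_test (rather than -n_test) also behaves sensibly if n_test is 0.
        yield permutation[:n - n_test], permutation[n - n_test:]


splits = list(shuffle_split(10, n_splits=3, test_fraction=0.25, seed=0))
for train, test in splits:
    print("TRAIN:", sorted(train), "TEST:", sorted(test))
```

Because every split is an independent shuffle, two iterations may produce the same test set, and a given sample may never appear in any test set. That is the trade-off against K-fold that the docstring's note refers to.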
58 changes: 24 additions & 34 deletions scikits/learn/datasets/_svmlight_format.cpp
@@ -25,11 +25,8 @@
#include <Python.h>
#include <numpy/arrayobject.h>

#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
@@ -197,8 +194,8 @@ static PyObject *to_csr(std::vector<double> &data,

class SyntaxError : public std::runtime_error {
public:
SyntaxError(char const *msg)
: std::runtime_error(std::string(msg) + " in SVMlight/libSVM file")
SyntaxError(std::string const &msg)
: std::runtime_error(msg + " in SVMlight/libSVM file")
{
}
};
@@ -215,40 +212,30 @@ void parse_line(const std::string& line,
if (line.length() == 0)
throw SyntaxError("empty line");

// Parse label
// FIXME: this should be done using standard C++ IOstream facilities,
// so we don't need to read the lines into strings first and get better
// error handling.
const char *in_string = line.c_str();
double y;
if (line[0] == '#')
return;

if (!std::sscanf(in_string, "%lf", &y))
// FIXME: we shouldn't be parsing line-by-line.
// Also, we might catch more syntax errors with failbit.
std::istringstream in(line);
in.exceptions(std::ios::badbit);

double y;
if (!(in >> y))
throw SyntaxError("non-numeric or missing label");

labels.push_back(y);

const char* position;
position = std::strchr(in_string, ' ') + 1;

indptr.push_back(data.size());

// Parse feature-value pairs.
for ( ;
(position
&& position < in_string + line.length()
&& position[0] != '#');
position = std::strchr(position, ' ')) {

// Consume multiple spaces, if needed.
while (std::isspace(*position))
position++;

// Parse the feature-value pair.
int id = std::atoi(position);
position = std::strchr(position, ':') + 1;
double value = std::atof(position);
indices.push_back(id);
data.push_back(value);
char c;
double x;
unsigned idx;

while (in >> idx >> c >> x) {
if (c != ':')
throw SyntaxError(std::string("expected ':', got '") + c + "'");
indices.push_back(int(idx));
data.push_back(x);
}
}

@@ -269,6 +256,9 @@ void parse_file(char const *file_path,
file_stream.rdbuf()->pubsetbuf(&buffer[0], buffer_size);
file_stream.open(file_path);

if (!file_stream)
throw std::ios_base::failure("File doesn't exist!");

std::string line;
while (std::getline(file_stream, line))
parse_line(line, data, indices, indptr, labels);
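For readers less familiar with the SVMlight/libSVM text format, the line-level rules handled by `parse_line` can be approximated in Python as follows. This is a sketch under stated assumptions, not a port of the C++ code: `parse_svmlight_line` is a hypothetical helper, it raises `ValueError` where the C++ raises `SyntaxError`, and it strips comments by splitting on `#`, whereas the C++ version returns early only when `#` is the first character and otherwise relies on stream extraction stopping at the comment.

```python
def parse_svmlight_line(line):
    """Return (label, [(index, value), ...]), or None for a comment-only line."""
    if not line.strip():
        raise ValueError("empty line in SVMlight/libSVM file")
    # Drop an inline comment, then any surrounding whitespace.
    content = line.split("#", 1)[0].strip()
    if not content:
        return None
    tokens = content.split()  # splits on any run of spaces or tabs
    try:
        label = float(tokens[0])
    except ValueError:
        raise ValueError("non-numeric or missing label in SVMlight/libSVM file")
    features = []
    for token in tokens[1:]:
        index, sep, value = token.partition(":")
        if sep != ":":
            raise ValueError("expected ':' in feature %r" % token)
        features.append((int(index), float(value)))
    return label, features


print(parse_svmlight_line("1.0 2:2.5\t10:-5.2 15:1.5 # and an inline comment"))
print(parse_svmlight_line("# comment"))
```

The two calls mirror the cases added to the test data file in this commit: a line containing a tab and an inline comment, and a comment-only line.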
7 changes: 5 additions & 2 deletions scikits/learn/datasets/tests/data/svmlight_classification.txt
@@ -1,3 +1,6 @@
1.0 2:2.5 10:-5.2 15:1.5
2.0 5:1.0 12:-3
# comment
# note: the next line contains a tab
1.0 2:2.5 10:-5.2 15:1.5 # and an inline comment
2.0 5:1.0 12:-3
# another comment
3.0 20:27
4 changes: 4 additions & 0 deletions scikits/learn/datasets/tests/test_svmlight_format.py
@@ -67,3 +67,7 @@ def test_load_invalid_file():
@raises(TypeError)
def test_not_a_filename():
    load_svmlight_file(1)

@raises(IOError)
def test_invalid_filename():
    load_svmlight_file("trou pic nic douille")
