WIP : added a module for "Label Propagation" #301

Closed
wants to merge 94 commits
Commits (94)
fe62153
added label propagation class
clayw Jul 18, 2011
c69a904
switch map and sum commands to numpy
clayw Jul 20, 2011
f801626
fixing up tests, adding "unlabeled_identifier"
clayw Jul 23, 2011
b0eeaa2
basic features of multiclass labeling up
clayw Jul 26, 2011
9a4a684
fixing the way labeling works
clayw Jul 28, 2011
436f0d3
checking in minor changes
clayw Jul 29, 2011
8c7cc15
added documentation, reworking tests
clayw Jul 30, 2011
071d14b
fixing up tests
clayw Aug 2, 2011
b1f1dcd
added a lot more to label propagation, explained algorithms and diffe…
clayw Aug 5, 2011
ea5a422
more documentation
clayw Aug 5, 2011
e664434
added beginning of examples
clayw Aug 5, 2011
49e6f5e
added "structure" example
clayw Aug 8, 2011
bcf48f9
tweaked structure plot
clayw Aug 8, 2011
d814dbf
finalized SVM comparison example
clayw Aug 8, 2011
840ef6e
all tests pass
clayw Aug 8, 2011
b1abfdc
removed some stuff from documentation
clayw Aug 8, 2011
2c63ba0
updated pydoc to make behaviour clearer
clayw Aug 8, 2011
a9eef34
passed PEP8, using already implemented kernel functions
clayw Aug 9, 2011
d52985e
making everything more numpy compatible
clayw Aug 9, 2011
aa77f82
graph construction and example more numpy-like
clayw Aug 9, 2011
e53ecc7
fixed other diagonal matrix construction
clayw Aug 9, 2011
877068c
rename misnamed "plot" example
clayw Aug 9, 2011
12fc1dc
example conforms to pep8
clayw Aug 9, 2011
6d4231c
other example conforms to pep8
clayw Aug 9, 2011
ac54d65
made test conform to pep8
clayw Aug 9, 2011
88dee75
predict() method now numpy friendly (100% numpy friendly now)
clayw Aug 10, 2011
fef95fc
more numpy integration
clayw Aug 10, 2011
882fa66
removed function kernel, switched to string for picklability
clayw Aug 12, 2011
803e8db
fixed a bug in the circle example
clayw Aug 12, 2011
8353bc1
moved label propagation examples to lower subfolder
clayw Aug 12, 2011
97b537c
more numpy friendliness
clayw Aug 12, 2011
24c0109
more numpy use,
clayw Aug 12, 2011
e7412c6
fine tuned some documentation
clayw Aug 12, 2011
0c03b91
added a snazzy label propagation versus SVM decision boundary plot
clayw Aug 14, 2011
1505c71
added more explanation to the plot
clayw Aug 14, 2011
208a70e
added semi_supervised directory
clayw Aug 14, 2011
6a99439
removed old, useless code
clayw Aug 14, 2011
9d1da53
removed unused imports
clayw Aug 14, 2011
7a0382c
added more documentation, another doctest for LabelSpreading
clayw Aug 14, 2011
8e06576
minor tweaks to the overall layout of the code
clayw Aug 14, 2011
5071a05
reverted plot_iris accidental commit
clayw Aug 14, 2011
0c2196a
added unlabeled_identifier explanation to docstrings
clayw Aug 14, 2011
1fcb4f8
Merge remote-tracking branch 'upstream/master'
clayw Aug 16, 2011
2cd1c83
fixed indentation problem in documentation rst
clayw Aug 16, 2011
2a7af09
conformance to pep8
clayw Aug 16, 2011
84090fc
fixed bug in tests causing gram matrix construction to not work prope…
clayw Aug 16, 2011
a3dc4e7
added two new examples, including an active learning demo with label …
clayw Aug 17, 2011
18dea8a
heavily downsampled digits examples (runtime a few seconds now) and r…
clayw Aug 18, 2011
2412daf
changed doc to remove long runningtime warning
clayw Aug 18, 2011
a31e639
rennamed active learning example so it won't be run for doc compilation
clayw Aug 22, 2011
38de418
changed subplot titles so the plot is more clear
clayw Aug 22, 2011
04f354f
Prettify structure example
vene Aug 23, 2011
1216546
DOC: minor style changes
vene Aug 23, 2011
fb906b4
DOC: tweaks
vene Aug 23, 2011
1aec073
Removed print in digits classification example
vene Aug 23, 2011
45f218c
DOC: fixed links and made examples build
vene Aug 23, 2011
f0caafe
Merge branch 'clayw-label_prop' of github.com:vene/scikit-learn into …
vene Aug 23, 2011
67d185e
DOC: clarified example titles
vene Aug 23, 2011
d1f38c1
fixed structure example
clayw Aug 23, 2011
c90b2d9
added vene's subplot adjustments
clayw Aug 23, 2011
e2cf62a
Merge branch 'new_lp'
clayw Aug 23, 2011
fa339f6
made convergence check function private
clayw Aug 23, 2011
3d6ee4c
fixed spelling error with variable name (indicies -> indices)
clayw Aug 26, 2011
4fd0a3c
optimized _build_graph with inplace methods, conform to standards wit…
clayw Sep 2, 2011
f0d88a8
one more optimization! avoids cast to numpy matrix and does in place …
clayw Sep 3, 2011
4cb9e70
fixed test cases to conform to api changes & new internal parameters
clayw Sep 3, 2011
1c75500
updated docs!
clayw Sep 4, 2011
00c3d4d
Merge git://github.com/scikit-learn/scikit-learn
clayw Sep 4, 2011
2eba5a4
localized a variable
clayw Sep 5, 2011
3e83638
fixed test suite, changed module to conform to new sklearn naming scheme
clayw Sep 5, 2011
472cf5b
fixed examples for new naming scheme
clayw Sep 5, 2011
732f18a
FIX: compat with numpy version lacking the out argument for dot
ogrisel Sep 5, 2011
a8c77e8
ENH: misc style / docstrings improvements
ogrisel Sep 5, 2011
59879d4
merged ogrisel's docs & optimization, also fixed active learning exam…
clayw Sep 5, 2011
ff4a3f3
more enhancements, variable names and test fixes
ogrisel Sep 13, 2011
1a06a46
changed a bunch of variable names, fixed some test cases
clayw Sep 14, 2011
ec149ed
all code works great, all tests pass, full coverage
clayw Sep 14, 2011
fba4c7c
changed a variable name to conform to scikits code
clayw Sep 14, 2011
431723c
correct variable names and added inline comments for active learning …
clayw Sep 14, 2011
3ee7314
added attributes text to explain named attributes
clayw Sep 14, 2011
efddf2a
STY: mostly style + avoid a zip in favor of an np.argsort
agramfort Sep 15, 2011
3a1e98c
STY : in label_propagation.py
agramfort Sep 15, 2011
082c873
ENH : using numpy broadcasting instead of dot_out
agramfort Sep 15, 2011
ac198aa
Merge branch 'master' of git://github.com/scikit-learn/scikit-learn
clayw Sep 21, 2011
2f7e997
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
clayw Oct 11, 2011
0ac341a
added support for sparse KNN graphs and tests
clayw Oct 12, 2011
5e7d3c9
finishing up sparse additions (need to complete todo)
clayw Oct 12, 2011
24ac2f5
sparse KNN graphs now work
clayw Oct 12, 2011
0faf870
ENH add label propagation algorithm
clayw Jul 18, 2011
fbc1bac
scikits.learn -> sklearn migration in label propagation
larsmans Jan 6, 2012
4c44755
BUG don't pass estimator params to fit in label propagation
larsmans Jan 6, 2012
68444c0
finalized KNN work, all tests pass properly
clayw Jan 10, 2012
fdfb531
Merge branch 'larsmans-label-propagation'
clayw Jan 10, 2012
af0feb9
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
clayw Jan 10, 2012
14 changes: 14 additions & 0 deletions doc/modules/classes.rst
@@ -415,6 +415,20 @@ From text
kernel_approximation.AdditiveChi2Sampler
kernel_approximation.SkewedChi2Sampler

Label propagation
=================

.. automodule:: scikits.learn.label_propagation
:no-members:
:no-inherited-members:

.. currentmodule:: scikits.learn

.. autosummary::
:toctree: generated/
:template: class.rst
label_propagation.LabelPropagation
label_propagation.LabelSpreading

.. _lda_ref:

50 changes: 50 additions & 0 deletions doc/modules/label_propagation.rst
@@ -0,0 +1,50 @@
.. _label_propagation:

===================================================
Label Propagation
===================================================

`sklearn.semisupervised.label_propagation` contains a few variations of
semi-supervised graph-inference algorithms. In the semi-supervised
classification setting, the learning algorithm is fed both labeled and
unlabeled data. By exploiting the unlabeled data during training, the
algorithm can better learn the overall structure of the data. These algorithms
generally do very well in practice even when given far fewer labeled points
than ordinary classification models require.

A few features available in these models:
* Can be used for classification and regression tasks
* Kernel methods to project data into alternate dimensional spaces

.. currentmodule:: scikits.learn.label_propagation
.. topic:: Input labels for semi-supervised learning

    It is important to assign an identifier to unlabeled points along with the
    labeled data when training the model with the `fit` method.
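A minimal sketch of this convention (assuming ``-1`` as the unlabeled marker,
as in the examples shipped with this pull request)::

    import numpy as np
    from sklearn import label_propagation

    X = np.array([[1.0, 1.0], [1.2, 0.8], [8.0, 8.0], [7.9, 8.1]])
    # the fourth point is unlabeled, marked with -1
    y = np.array([0, 0, 1, -1])

    model = label_propagation.LabelPropagation()
    model.fit(X, y)
    print model.transduction_  # inferred labels for every point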

This module provides two label propagation models: :class:`LabelPropagation` and
:class:`LabelSpreading`. Both work by constructing a similarity graph over all
items in the input dataset. They differ only in the definition of the matrix
that represents the graph and the clamp effect on the label distributions.
:class:`LabelPropagation` is far more intuitive than :class:`LabelSpreading`,
which is motivated by deeper mathematics.
Member
In practice, why would someone choose LabelSpreading over LabelPropagation? Do they have the same computational complexity / scalability behavior? Do they converge to the same solution? If not, what kinds of assumptions do they make about the data?

Also, is this model solving a convex problem with a unique solution, or a problem with potentially several global minima in the objective function, so that random initialization or the order of the samples can lead to different solutions?

Contributor Author
The new documentation should answer all of these questions. I also added some other reference material. Really, the Label Spreading algorithm should outperform Label Propagation in nearly every case, but Label Propagation is more intuitive and it's easier to understand how it works.

Do you think we should remove the Label Propagation class?

Member
If it can serve as a simple baseline for semi-supervised learning then I think we should keep it. A bit like k-NN for supervised learning: it can be used as a sanity check for more advanced algorithms.


Clamping
========

Clamping allows the algorithm to change the weight given to the ground truth
labeled data to some degree. The :class:`LabelPropagation` algorithm performs
hard clamping of input labels, which means :math:`\alpha=1`. This clamping
factor can be relaxed, to say :math:`\alpha=0.8`, which means that we will
always retain 80 percent of our original label distribution, but the algorithm
gets to change its confidence in the distribution within the remaining 20
percent.
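For instance, a minimal sketch of soft clamping (reusing ``X`` and ``y`` from
the sketch above, with ``-1`` marking unlabeled points)::

    from sklearn import label_propagation

    # alpha=0.8: retain 80% of the original label distribution at each step,
    # letting the algorithm redistribute the remaining 20%
    model = label_propagation.LabelSpreading(alpha=0.8)
    model.fit(X, y)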

Examples
========
* :ref:`example_label_propagation_plot_label_propagation_versus_svm_iris.py`
* :ref:`example_label_propagation_structure.py`


References
==========
[1] Yoshua Bengio, Olivier Delalleau, Nicolas Le Roux. "Label Propagation and
Quadratic Criterion." In Semi-Supervised Learning (2006), pp. 193-216.
Member
Do you know if there is an open-access resource (e.g. with a downloadable PDF) available online (typically on the paper's author page) to add as a secondary reference?

Member
Here they are:

BTW on the author page of the first paper it is said that Label Propagation is a precursor to this paper:

A graph-based semi-supervised learning algorithm that creates a graph over labeled and unlabeled examples. More similar examples are connected by edges with higher weights. The intuition is for the labels to propagate on the graph to unlabeled data. The solution can be found with simple matrix operations, and has strong connections to spectral graph theory. [ps.gz] [pdf] [Matlab code] [data]
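For reference, a rough numpy sketch of the iteration that abstract describes
(row-normalized weights, hard clamping of the labeled points; illustrative
only, not the implementation in this PR):

    import numpy as np

    def propagate(W, y, labeled, n_iter=100):
        # W: dense symmetric affinity matrix, y: integer labels,
        # labeled: boolean mask of the points whose labels are known
        classes = np.unique(y[labeled])
        Y = np.zeros((len(y), classes.size))
        Y[labeled, np.searchsorted(classes, y[labeled])] = 1.0
        Y_clamp = Y.copy()
        P = W / W.sum(axis=1)[:, np.newaxis]  # row-normalize: D^-1 W
        for _ in range(n_iter):
            Y = np.dot(P, Y)               # propagate label distributions
            Y[labeled] = Y_clamp[labeled]  # re-clamp the known labels
        return classes[Y.argmax(axis=1)]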

Member
We should really investigate the more recent algorithm, then, especially
given that spectral graph problems can be solved efficiently, and that we
are starting to build fairly good expertise on these problems.

Member
I had the semi-supervised learning book on my shelf but had not read it yet. Diving into it right now.

Member
Also, regarding scalability, it might be possible to improve it by using a graph Laplacian based on a sparse k-NN connectivity matrix instead of a dense n_samples * n_samples heat kernel. But we would need to review the literature first, and we can probably investigate this track later (after the merge of this PR).
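Something like the following could be a starting point (a rough sketch using
`sklearn.neighbors.kneighbors_graph`, untested against this branch):

    from sklearn.neighbors import kneighbors_graph

    # sparse k-NN connectivity matrix: still n_samples x n_samples, but with
    # only n_neighbors nonzeros per row instead of a dense heat kernel
    A = kneighbors_graph(X, n_neighbors=7)
    W = 0.5 * (A + A.T)  # symmetrize so the graph Laplacian is well defined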

1 change: 1 addition & 0 deletions doc/unsupervised_learning.rst
@@ -14,5 +14,6 @@ Unsupervised learning
modules/covariance
modules/outlier_detection
modules/hmm
modules/label_propagation


93 changes: 93 additions & 0 deletions examples/semi_supervised/label_propagation_digits_active_learning.py
@@ -0,0 +1,93 @@
"""
========================================
Label Propagation digits active learning
========================================

Demonstrates an active learning technique to learn handwritten digits
using label propagation.

We start by training a label propagation model with only 10 labeled points,
then we select the top five most uncertain points to label. Next, we train
with 15 labeled points (original 10 + 5 new ones). We repeat this process
four times to have a model trained with 30 labeled examples.

A plot will appear showing the top 5 most uncertain digits for each iteration
of training. These may or may not contain mistakes, but we will train the next
model with their true labels.
"""
print __doc__

import numpy as np
import pylab as pl

from scipy import stats

from sklearn import datasets
from sklearn import label_propagation

from sklearn.metrics import metrics
from sklearn.metrics.metrics import confusion_matrix

digits = datasets.load_digits()
X = digits.data[:330]
y = digits.target[:330]

n_total_samples = len(y)
n_labeled_points = 10

unlabeled_indices = np.arange(n_total_samples)[n_labeled_points:]
f = pl.figure()

for i in range(5):
    y_train = np.copy(y)
    y_train[unlabeled_indices] = -1

    lp_model = label_propagation.LabelSpreading(gamma=0.25, max_iter=5)
    lp_model.fit(X, y_train)

    predicted_labels = lp_model.transduction_[unlabeled_indices]
    true_labels = y[unlabeled_indices]

    cm = confusion_matrix(true_labels, predicted_labels,
                          labels=lp_model.unique_labels_)

    print "Label Spreading model: %d labeled & %d unlabeled (%d total)" % \
        (n_labeled_points, n_total_samples - n_labeled_points,
         n_total_samples)

    print metrics.classification_report(true_labels, predicted_labels)

    print "Confusion matrix"
    print cm

    # compute the entropies of transduced label distributions
    pred_entropies = stats.distributions.entropy(
        lp_model.label_distributions_.T)

    # select the five digit examples that the classifier is most uncertain about
    uncertainty_index = np.argsort(pred_entropies)[-5:]

    # keep track of indices that we get labels for
    delete_indices = np.array([], dtype=int)

    f.text(.05, (1 - (i + 1) * .183),
           "model %d\n\nfit with\n%d labels" % ((i + 1), i * 5 + 10), size=10)
    for index, image_index in enumerate(uncertainty_index):
        image = digits.images[image_index]

        sub = f.add_subplot(5, 5, index + 1 + (5 * i))
        sub.imshow(image, cmap=pl.cm.gray_r)
        sub.set_title('predict: %i\ntrue: %i' % (
            lp_model.transduction_[image_index], y[image_index]), size=10)
        sub.axis('off')

        # labeling 5 points, remote from labeled set
        delete_index, = np.where(unlabeled_indices == image_index)
        delete_indices = np.concatenate((delete_indices, delete_index))

    unlabeled_indices = np.delete(unlabeled_indices, delete_indices)
    n_labeled_points += 5

f.suptitle("Active learning with Label Propagation.\nRows show 5 most "
           "uncertain labels to learn with the next model.")
pl.subplots_adjust(0.12, 0.03, 0.9, 0.8, 0.2, 0.45)
pl.show()
96 changes: 96 additions & 0 deletions examples/semi_supervised/label_propagation_versus_svm_iris.py
@@ -0,0 +1,96 @@
"""
================================================
Label Propagation versus SVM on the Iris dataset
================================================

Performance comparison of Label Propagation in the semi-supervised setting
with SVM in the supervised setting on the iris dataset.

The first nine experiments, covering SVM (SVM), Label Propagation (LP), and
Label Spreading (LS), operate in the "inductive setting": the system is
trained with some percentage of the data and then queried on unseen
datapoints to infer a label.

The 10th and final experiment is in the transductive setting. Using a label
spreading algorithm, the system is trained with approximately 24 percent of
the data labeled, and during training, unlabeled points are transductively
assigned values. The test precision, recall, and F1 scores are based on these
transductively assigned labels.
"""
print __doc__

import numpy as np

from sklearn import datasets
from sklearn import svm
from sklearn import label_propagation

from sklearn.metrics.metrics import precision_score
from sklearn.metrics.metrics import recall_score
from sklearn.metrics.metrics import f1_score

rng = np.random.RandomState(0)

iris = datasets.load_iris()

X = iris.data
y = iris.target

# 80% data to keep
hold_80 = rng.rand(len(y)) < 0.8
train, = np.where(hold_80)

# 20% test data
test, = np.where(~hold_80)

X_all = X[train]
y_all = y[train]

svc = svm.SVC(kernel='rbf')
svc.fit(X_all, y_all)
print "Limited Label data example"
print "Test name\tprecision\trecall \tf1"
print "SVM 80.0pct\t%0.6f\t%0.6f\t%0.6f" %\
(precision_score(svc.predict(X[test]), y[test]),
recall_score(svc.predict(X[test]), y[test]),
f1_score(svc.predict(X[test]), y[test]))

print "-------"

for num in [0.2, 0.3, 0.4, 1.0]:
    lp = label_propagation.LabelPropagation()
    hold_new = rng.rand(len(train)) > num
    train_new, = np.where(hold_new)
    y_dup = np.copy(y_all)
    y_dup[train_new] = -1
    lp.fit(X_all, y_dup)
    y_lp_pred = lp.predict(X[test])
    print "LP %0.1fpct\t%0.6f\t%0.6f\t%0.6f" % \
        (80 * num, precision_score(y[test], y_lp_pred),
         recall_score(y[test], y_lp_pred),
         f1_score(y[test], y_lp_pred))

# label spreading
for num in [0.2, 0.3, 0.4, 1.0]:
    lspread = label_propagation.LabelSpreading()
    hold_new = rng.rand(len(train)) > num
    train_new, = np.where(hold_new)
    y_dup = np.copy(y_all)
    y_dup[train_new] = -1
    lspread.fit(X_all, y_dup)
    y_ls_pred = lspread.predict(X[test])
    print "LS %0.1fpct\t%0.6f\t%0.6f\t%0.6f" % \
        (80 * num, precision_score(y[test], y_ls_pred),
         recall_score(y[test], y_ls_pred),
         f1_score(y[test], y_ls_pred))

print "-------"
lspread = label_propagation.LabelSpreading(alpha=0.8)
y_dup = np.copy(y)
hold_new = rng.rand(len(train)) > 0.3
train_new, = np.where(hold_new)
y_dup = np.copy(y)
y_dup[train_new] = -1
lspread.fit(X, y)
trans_result = np.asarray(lspread.transduction_)
print "LS 20tran\t%0.6f\t%0.6f\t%0.6f" % \
(precision_score(trans_result[test], y[test]),
recall_score(trans_result[test], y[test]),
f1_score(trans_result[test], y[test]))
83 changes: 83 additions & 0 deletions examples/semi_supervised/plot_label_propagation_digits.py
@@ -0,0 +1,83 @@
"""
===================================================
Label Propagation digits: Demonstrating performance
===================================================

This example demonstrates the power of semi-supervised learning by
training a Label Spreading model to classify handwritten digits
with sets of very few labels.

The model is trained on the first 330 points of the 1797-point
handwritten digit dataset, but only 30 of them are labeled. The
results, in the form of a confusion matrix and a series of metrics
over each class, will be very good.

At the end, the top 10 most uncertain predictions will be shown.
"""
print __doc__

import numpy as np
import pylab as pl

from scipy import stats

from sklearn import datasets
from sklearn import label_propagation

from sklearn.metrics import metrics
from sklearn.metrics.metrics import confusion_matrix

digits = datasets.load_digits()
X = digits.data[:330]
y = digits.target[:330]

n_total_samples = len(y)
n_labeled_points = 30

indices = np.arange(n_total_samples)

unlabeled_set = indices[n_labeled_points:]

# mark all points beyond the first 30 as unlabeled
y_train = np.copy(y)
y_train[unlabeled_set] = -1

###############################################################################
# Learn with LabelSpreading
lp_model = label_propagation.LabelSpreading(gamma=0.25, max_iter=5)
lp_model.fit(X, y_train)
predicted_labels = lp_model.transduction_[unlabeled_set]
true_labels = y[unlabeled_set]

cm = confusion_matrix(true_labels, predicted_labels,
labels=lp_model.unique_labels_)

print "Label Spreading model: %d labeled & %d unlabeled points (%d total)" % \
(n_labeled_points, n_total_samples - n_labeled_points, n_total_samples)

print metrics.classification_report(true_labels, predicted_labels)

print "Confusion matrix"
print cm

# calculate uncertainty values for each transduced distribution
pred_entropies = stats.distributions.entropy(lp_model.label_distributions_.T)

# pick the top 10 most uncertain labels
uncertainty_index = np.argsort(pred_entropies)[-10:]

###############################################################################
# plot
f = pl.figure(figsize=(7, 5))
for index, image_index in enumerate(uncertainty_index):
    image = digits.images[image_index]

    sub = f.add_subplot(2, 5, index + 1)
    sub.imshow(image, cmap=pl.cm.gray_r)
    pl.xticks([])
    pl.yticks([])
    sub.set_title('predict: %i\ntrue: %i' % (
        lp_model.transduction_[image_index], y[image_index]))

f.suptitle('Learning with small amount of labeled data')
pl.show()