DOC Reorganize plot_nca_illustration example into subsections #14795

Merged: 4 commits, Sep 18, 2019
59 changes: 31 additions & 28 deletions examples/neighbors/plot_nca_illustration.py
@@ -3,10 +3,10 @@
 Neighborhood Components Analysis Illustration
 =============================================

-An example illustrating the goal of learning a distance metric that maximizes
-the nearest neighbors classification accuracy. The example is solely for
-illustration purposes. Please refer to the :ref:`User Guide <nca>` for
-more information.
+This example illustrates a learned distance metric that maximizes
+the nearest neighbors classification accuracy. It provides a visual
+representation of this metric compared to the original point
+space. Please refer to the :ref:`User Guide <nca>` for more information.
 """

 # License: BSD 3 clause
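
The new docstring contrasts a learned distance metric with the original point space. One way to read that: NCA learns a linear transformation, and plain Euclidean distance measured after the transformation plays the role of the learned metric. The sketch below uses a hand-picked matrix `L` and three made-up points (both are assumptions for illustration, not output of NCA) to show how such a transformation can change which point is the nearest neighbor.

```python
import math

# Hand-picked linear map (hypothetical; NCA would learn one from the labels)
L = [[3.0, 0.0],
     [0.0, 0.25]]


def transform(point, matrix):
    """Apply a 2x2 linear map to a 2-D point."""
    return tuple(sum(m * x for m, x in zip(row, point)) for row in matrix)


def nearest(points, i):
    """Index of the point closest to points[i], excluding i itself."""
    dists = [math.inf if j == i else math.dist(points[i], p)
             for j, p in enumerate(points)]
    return dists.index(min(dists))


# Made-up data: point 0 and point 2 share a class, point 1 does not
points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)]
labels = [0, 1, 0]

embedded = [transform(p, L) for p in points]
nn_original = nearest(points, 0)   # neighbor in the original space
nn_embedded = nearest(embedded, 0)  # neighbor after the linear map
```

In the original space the nearest neighbor of point 0 belongs to the other class. After stretching the first axis and shrinking the second, the same-class point becomes the nearest one.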
@@ -20,23 +20,31 @@

 print(__doc__)

-random_state = 0
+##############################################################################
+# Original points
+# ---------------
+# First we create a data set of 9 samples from 3 classes, and plot the points
+# in the original space. For this example, we focus on the classification of
+# point no. 3. The thickness of a link between point no. 3 and another point
+# is proportional to their distance.

-# Create a tiny data set of 9 samples from 3 classes
 X, y = make_classification(n_samples=9, n_features=2, n_informative=2,
                            n_redundant=0, n_classes=3, n_clusters_per_class=1,
-                           class_sep=1.0, random_state=random_state)
+                           class_sep=1.0, random_state=0)

-# Plot the points in the original space
-plt.figure()
+plt.figure(1)
 ax = plt.gca()
-
-# Draw the graph nodes
 for i in range(X.shape[0]):
     ax.text(X[i, 0], X[i, 1], str(i), va='center', ha='center')
     ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

-def p_i(X, i):
+ax.set_title("Original points")
+ax.axes.get_xaxis().set_visible(False)
+ax.axes.get_yaxis().set_visible(False)
+ax.axis('equal') # so that boundaries are displayed correctly as circles
+
+
+def link_thickness_i(X, i):
     diff_embedded = X[i] - X
     dist_embedded = np.einsum('ij,ij->i', diff_embedded,
                               diff_embedded)
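
Only the first lines of `link_thickness_i` are visible in this hunk: they compute the squared distance from point `i` to every other point. The rest of the function is elided here. The sketch below assumes the weights are a softmax over negative squared distances, as in the usual NCA formulation. The helper name `link_weights` and the sample points are hypothetical, and the code uses only the standard library rather than the NumPy calls from the example.

```python
import math


def link_weights(points, i):
    """Softmax over negative squared distances from points[i] (illustrative)."""
    sq_dists = [sum((a - b) ** 2 for a, b in zip(points[i], p))
                for p in points]
    # Shift by the smallest distance to another point for numerical stability
    shift = min(d for j, d in enumerate(sq_dists) if j != i)
    # The point itself gets weight 0, mirroring an infinite self-distance
    exps = [0.0 if j == i else math.exp(-(d - shift))
            for j, d in enumerate(sq_dists)]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up points at squared distances 1, 4 and 18 from point 0
points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (3.0, 3.0)]
weights = link_weights(points, 0)
```

Under this assumption the weights sum to 1 and shrink as the squared distance grows, so nearer points would be drawn with thicker links.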
@@ -52,34 +60,30 @@ def p_i(X, i):
 def relate_point(X, i, ax):
     pt_i = X[i]
     for j, pt_j in enumerate(X):
-        thickness = p_i(X, i)
+        thickness = link_thickness_i(X, i)
         if i != j:
             line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])
             ax.plot(*line, c=cm.Set1(y[j]),
                     linewidth=5*thickness[j])


-# we consider only point 3
 i = 3
-
-# Plot bonds linked to sample i in the original space
 relate_point(X, i, ax)
-ax.set_title("Original points")
-ax.axes.get_xaxis().set_visible(False)
-ax.axes.get_yaxis().set_visible(False)
-ax.axis('equal')
+plt.show()

-# Learn an embedding with NeighborhoodComponentsAnalysis
-nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=random_state)
+##############################################################################
+# Learning an embedding
+# ---------------------
+# We use :class:`~sklearn.neighbors.NeighborhoodComponentsAnalysis` to learn an
Review comment (Member): Add a title, e.g. "learning the embedding"?
+# embedding and plot the points after the transformation. We then take the
+# embedding and find the nearest neighbors.
+
+nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)
 nca = nca.fit(X, y)

-# Plot the points after transformation with NeighborhoodComponentsAnalysis
-plt.figure()
+plt.figure(2)
 ax2 = plt.gca()
-
-# Get the embedding and find the new nearest neighbors
 X_embedded = nca.transform(X)
-
 relate_point(X_embedded, i, ax2)

 for i in range(len(X)):
@@ -88,7 +92,6 @@ def relate_point(X, i, ax):
     ax2.scatter(X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]),
                 alpha=0.4)

-# Make axes equal so that boundaries are displayed correctly as circles
 ax2.set_title("NCA embedding")
 ax2.axes.get_xaxis().set_visible(False)
 ax2.axes.get_yaxis().set_visible(False)
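
The "Learning an embedding" section fits NCA and then looks at nearest neighbors in the embedded space. A minimal sketch of that comparison without any plotting is given below. It reuses the data set parameters and `max_iter=30, random_state=0` from the diff. The helper `nearest_neighbor` is hypothetical and not part of the example or of scikit-learn.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neighbors import NeighborhoodComponentsAnalysis

# Same tiny data set as in the example: 9 samples, 3 classes, 2 features
X, y = make_classification(n_samples=9, n_features=2, n_informative=2,
                           n_redundant=0, n_classes=3, n_clusters_per_class=1,
                           class_sep=1.0, random_state=0)

nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)
X_embedded = nca.fit(X, y).transform(X)


def nearest_neighbor(points, i):
    """Hypothetical helper: index of the point closest to points[i]."""
    sq_dist = np.sum((points - points[i]) ** 2, axis=1)
    sq_dist[i] = np.inf  # exclude the point itself
    return int(np.argmin(sq_dist))


i = 3
nn_before = nearest_neighbor(X, i)
nn_after = nearest_neighbor(X_embedded, i)
print("class of point", i, ":", y[i])
print("nearest neighbor before:", nn_before, "class", y[nn_before])
print("nearest neighbor after: ", nn_after, "class", y[nn_after])
```

Whether the neighbor of point 3 changes depends on the fitted transformation, so the sketch prints both results rather than assuming an outcome.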