Skip to content

Commit

Permalink
DOC Reorganize plot_nca_illustration example into subsections (#14795)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-clare authored and thomasjpfan committed Sep 18, 2019
1 parent e52e9c8 commit 6680bff
Showing 1 changed file with 31 additions and 28 deletions.
59 changes: 31 additions & 28 deletions examples/neighbors/plot_nca_illustration.py
Expand Up @@ -3,10 +3,10 @@
Neighborhood Components Analysis Illustration
=============================================
An example illustrating the goal of learning a distance metric that maximizes
the nearest neighbors classification accuracy. The example is solely for
illustration purposes. Please refer to the :ref:`User Guide <nca>` for
more information.
This example illustrates a learned distance metric that maximizes
the nearest neighbors classification accuracy. It provides a visual
representation of this metric compared to the original point
space. Please refer to the :ref:`User Guide <nca>` for more information.
"""

# License: BSD 3 clause
Expand All @@ -20,23 +20,31 @@

print(__doc__)

random_state = 0
##############################################################################
# Original points
# ---------------
# First we create a data set of 9 samples from 3 classes, and plot the points
# in the original space. For this example, we focus on the classification of
# point no. 3. The thickness of a link between point no. 3 and another point
# is proportional to their distance.

# Create a tiny data set of 9 samples from 3 classes
X, y = make_classification(n_samples=9, n_features=2, n_informative=2,
n_redundant=0, n_classes=3, n_clusters_per_class=1,
class_sep=1.0, random_state=random_state)
class_sep=1.0, random_state=0)

# Plot the points in the original space
plt.figure()
plt.figure(1)
ax = plt.gca()

# Draw the graph nodes
for i in range(X.shape[0]):
ax.text(X[i, 0], X[i, 1], str(i), va='center', ha='center')
ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

def p_i(X, i):
ax.set_title("Original points")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.axis('equal') # so that boundaries are displayed correctly as circles


def link_thickness_i(X, i):
diff_embedded = X[i] - X
dist_embedded = np.einsum('ij,ij->i', diff_embedded,
diff_embedded)
Expand All @@ -52,34 +60,30 @@ def p_i(X, i):
def relate_point(X, i, ax):
pt_i = X[i]
for j, pt_j in enumerate(X):
thickness = p_i(X, i)
thickness = link_thickness_i(X, i)
if i != j:
line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])
ax.plot(*line, c=cm.Set1(y[j]),
linewidth=5*thickness[j])


# we consider only point 3
i = 3

# Plot bonds linked to sample i in the original space
relate_point(X, i, ax)
ax.set_title("Original points")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.axis('equal')
plt.show()

# Learn an embedding with NeighborhoodComponentsAnalysis
nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=random_state)
##############################################################################
# Learning an embedding
# ---------------------
# We use :class:`~sklearn.neighbors.NeighborhoodComponentsAnalysis` to learn an
# embedding and plot the points after the transformation. We then take the
# embedding and find the nearest neighbors.

nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)
nca = nca.fit(X, y)

# Plot the points after transformation with NeighborhoodComponentsAnalysis
plt.figure()
plt.figure(2)
ax2 = plt.gca()

# Get the embedding and find the new nearest neighbors
X_embedded = nca.transform(X)

relate_point(X_embedded, i, ax2)

for i in range(len(X)):
Expand All @@ -88,7 +92,6 @@ def relate_point(X, i, ax):
ax2.scatter(X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]),
alpha=0.4)

# Make axes equal so that boundaries are displayed correctly as circles
ax2.set_title("NCA embedding")
ax2.axes.get_xaxis().set_visible(False)
ax2.axes.get_yaxis().set_visible(False)
Expand Down

0 comments on commit 6680bff

Please sign in to comment.