Skip to content

Commit

Permalink
DOC improved example plot in plot_lda_qda.py (#12942)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyxue authored and jnothman committed Feb 19, 2019
1 parent cff7af7 commit 86f958e
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions examples/classification/plot_lda_qda.py
Expand Up @@ -72,19 +72,15 @@ def plot_data(lda, X, y, y_pred, fig_index):
X0_tp, X0_fp = X0[tp0], X0[~tp0]
X1_tp, X1_fp = X1[tp1], X1[~tp1]

alpha = 0.5

# class 0: dots
plt.plot(X0_tp[:, 0], X0_tp[:, 1], 'o', alpha=alpha,
color='red', markeredgecolor='k')
plt.plot(X0_fp[:, 0], X0_fp[:, 1], '*', alpha=alpha,
color='#990000', markeredgecolor='k') # dark red
plt.scatter(X0_tp[:, 0], X0_tp[:, 1], marker='.', color='red')
plt.scatter(X0_fp[:, 0], X0_fp[:, 1], marker='x',
s=20, color='#990000') # dark red

# class 1: dots
plt.plot(X1_tp[:, 0], X1_tp[:, 1], 'o', alpha=alpha,
color='blue', markeredgecolor='k')
plt.plot(X1_fp[:, 0], X1_fp[:, 1], '*', alpha=alpha,
color='#000099', markeredgecolor='k') # dark blue
plt.scatter(X1_tp[:, 0], X1_tp[:, 1], marker='.', color='blue')
plt.scatter(X1_fp[:, 0], X1_fp[:, 1], marker='x',
s=20, color='#000099') # dark blue

# class 0 and 1 : areas
nx, ny = 200, 100
Expand All @@ -95,14 +91,14 @@ def plot_data(lda, X, y, y_pred, fig_index):
Z = lda.predict_proba(np.c_[xx.ravel(), yy.ravel()])
Z = Z[:, 1].reshape(xx.shape)
plt.pcolormesh(xx, yy, Z, cmap='red_blue_classes',
norm=colors.Normalize(0., 1.))
plt.contour(xx, yy, Z, [0.5], linewidths=2., colors='k')
norm=colors.Normalize(0., 1.), zorder=0)
plt.contour(xx, yy, Z, [0.5], linewidths=2., colors='white')

# means
plt.plot(lda.means_[0][0], lda.means_[0][1],
'o', color='black', markersize=10, markeredgecolor='k')
'*', color='yellow', markersize=15, markeredgecolor='grey')
plt.plot(lda.means_[1][0], lda.means_[1][1],
'o', color='black', markersize=10, markeredgecolor='k')
'*', color='yellow', markersize=15, markeredgecolor='grey')

return splot

Expand All @@ -115,10 +111,9 @@ def plot_ellipse(splot, mean, cov, color):
# filled Gaussian at 2 standard deviation
ell = mpl.patches.Ellipse(mean, 2 * v[0] ** 0.5, 2 * v[1] ** 0.5,
180 + angle, facecolor=color,
edgecolor='yellow',
linewidth=2, zorder=2)
edgecolor='black', linewidth=2)
ell.set_clip_box(splot.bbox)
ell.set_alpha(0.5)
ell.set_alpha(0.2)
splot.add_artist(ell)
splot.set_xticks(())
splot.set_yticks(())
Expand All @@ -133,6 +128,8 @@ def plot_qda_cov(qda, splot):
plot_ellipse(splot, qda.means_[0], qda.covariance_[0], 'red')
plot_ellipse(splot, qda.means_[1], qda.covariance_[1], 'blue')


plt.figure(figsize=(10, 8), facecolor='white')
for i, (X, y) in enumerate([dataset_fixed_cov(), dataset_cov()]):
# Linear Discriminant Analysis
lda = LinearDiscriminantAnalysis(solver="svd", store_covariance=True)
Expand All @@ -147,6 +144,7 @@ def plot_qda_cov(qda, splot):
splot = plot_data(qda, X, y, y_pred, fig_index=2 * i + 2)
plot_qda_cov(qda, splot)
plt.axis('tight')
plt.suptitle('Linear Discriminant Analysis vs Quadratic Discriminant'
'Analysis')
plt.suptitle('Linear Discriminant Analysis vs Quadratic Discriminant Analysis',
y=1.02, fontsize=15)
plt.tight_layout()
plt.show()

0 comments on commit 86f958e

Please sign in to comment.