Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG+2] TruncatedSVD: Calculate explained variance. #3067

Conversation

mdbecker
Copy link
Contributor

@mdbecker mdbecker commented Apr 14, 2014

Fixes #3047

We tested this similar to in #2663 and determined that it makes sense to calculate explained variance as part of the fit method but then we merged the _fit method with the fit_transform method to avoid doing some duplicate work. This change will cause a minor performance regression in the case where fit is called by itself separate from transform (i.e. when calling on different inputs) which I believe is not the normal use case for this estimator.

)



Copy link
Member

@ogrisel ogrisel Apr 15, 2014

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cosmetics: please remove multiple blank lines at the end of a file.

Copy link
Contributor Author

@mdbecker mdbecker Apr 15, 2014

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@mdbecker
Copy link
Contributor Author

@mdbecker mdbecker commented Apr 15, 2014

@ogrisel I need to run lint checks before this is ready. However can you please give me an idea if this looks okay?

for svd_10, svd_20 in svds_10_v_20:
assert_almost_equal(
svd_10.explained_variance_ratio_[0],
svd_20.explained_variance_ratio_[0],
Copy link
Member

@ogrisel ogrisel Apr 15, 2014

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can compare the first 10:

assert_almost_equal(
            svd_10.explained_variance_ratio_,
            svd_20.explained_variance_ratio_[:10],
)

Copy link
Contributor Author

@mdbecker mdbecker Apr 15, 2014

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@coveralls
Copy link

@coveralls coveralls commented Apr 15, 2014

Coverage Status

Coverage remained the same when pulling f7c8bb9 on mdbecker:truncated_svd_calculate_explained_variance into b88cec5 on scikit-learn:master.

@ogrisel
Copy link
Member

@ogrisel ogrisel commented Apr 15, 2014

Looks good to me.

@larsmans now the fit will always do a transform (to be able to compute the explained variance) but I don't think it's a problem in practice.

@mdbecker mdbecker changed the title [WIP] TruncatedSVD: Calculate explained variance. [MRG] TruncatedSVD: Calculate explained variance. Apr 15, 2014
@ogrisel ogrisel changed the title [MRG] TruncatedSVD: Calculate explained variance. [MRG+1] TruncatedSVD: Calculate explained variance. Apr 15, 2014
@coveralls
Copy link

@coveralls coveralls commented Apr 15, 2014

Coverage Status

Coverage remained the same when pulling 3784c68 on mdbecker:truncated_svd_calculate_explained_variance into b88cec5 on scikit-learn:master.

@mdbecker
Copy link
Contributor Author

@mdbecker mdbecker commented Apr 15, 2014

@ogrisel I forgot to update the docstrings. Should I do that as part of this PR?

@ogrisel
Copy link
Member

@ogrisel ogrisel commented Apr 15, 2014

Sure yes.

@mdbecker
Copy link
Contributor Author

@mdbecker mdbecker commented Apr 16, 2014

@larsmans @ogrisel I think that should do.

X = as_float_array(X, copy=False)
random_state = check_random_state(self.random_state)

# If sparse and not csr or csc, convert to csr
if sp.issparse(X) and not (
X.getformat() == 'csr' or X.getformat() == 'csc'):
Copy link
Member

@larsmans larsmans Apr 16, 2014

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

X.getformat() not in ["csr", "csc"]

@coveralls
Copy link

@coveralls coveralls commented Apr 16, 2014

Coverage Status

Changes Unknown when pulling 5c43aba on mdbecker:truncated_svd_calculate_explained_variance into * on scikit-learn:master*.

@mdbecker
Copy link
Contributor Author

@mdbecker mdbecker commented Apr 16, 2014

@larsmans Fixed. Let me know if you find anything else. Thanks!

@mdbecker
Copy link
Contributor Author

@mdbecker mdbecker commented Apr 16, 2014

@ogrisel @larsmans Repushed. I made one more minor change to the explanation of explained_variance_ratio_ that I think will be helpful to newbies. Let me know if you don't like it.

@@ -61,6 +63,29 @@ class TruncatedSVD(BaseEstimator, TransformerMixin):
----------
`components_` : array, shape (n_components, n_features)
`explained_variance_` : array, [n_components]
The variance of the training samples transformed by a projection to \
Copy link
Member

@ogrisel ogrisel Apr 16, 2014

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to put \ at the end of docstring lines. Please remove them.

Copy link
Contributor Author

@mdbecker mdbecker Apr 16, 2014

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@coveralls
Copy link

@coveralls coveralls commented Apr 16, 2014

Coverage Status

Coverage remained the same when pulling 248acb9 on mdbecker:truncated_svd_calculate_explained_variance into 6f6de86 on scikit-learn:master.

@ogrisel
Copy link
Member

@ogrisel ogrisel commented Apr 16, 2014

For some reason, github does not show your repo as potential target for a pull request from mine... Here is an updated example to demonstrate the usage of explained_variance_ratio_ in a LSA application:

diff --git a/examples/document_clustering.py b/examples/document_clustering.py
index 1480ba8..4785d73 100644
--- a/examples/document_clustering.py
+++ b/examples/document_clustering.py
@@ -159,11 +159,14 @@ if opts.n_components:
     # Vectorizer results are normalized, which makes KMeans behave as
     # spherical k-means for better results. Since LSA/SVD results are
     # not normalized, we have to redo the normalization.
-    lsa = make_pipeline(TruncatedSVD(opts.n_components),
-                        Normalizer(copy=False))
+    svd = TruncatedSVD(opts.n_components)
+    lsa = make_pipeline(svd, Normalizer(copy=False))
     X = lsa.fit_transform(X)
-
     print("done in %fs" % (time() - t0))
+
+    explained_variance = svd.explained_variance_ratio_.sum()
+    print("Explained variance of the SVD step: {}%".format(
+        int(explained_variance * 100)))
     print()


@@ -197,7 +200,7 @@ if not (opts.n_components or opts.use_hashing):
     print("Top terms per cluster:")
     order_centroids = km.cluster_centers_.argsort()[:, ::-1]
     terms = vectorizer.get_feature_names()
-    for i in xrange(true_k):
+    for i in range(true_k):
         print("Cluster %d:" % i, end='')
         for ind in order_centroids[i, :10]:
             print(' %s' % terms[ind], end='')

Please include it in your PR.

@ogrisel ogrisel changed the title [MRG+1] TruncatedSVD: Calculate explained variance. [WIP] TruncatedSVD: Calculate explained variance. Apr 17, 2014
>>> svd.fit(X) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
TruncatedSVD(algorithm='randomized', n_components=5, n_iter=5,
random_state=42, tol=0.0)
>>> print(svd.explained_variance_ratio_) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Copy link
Member

@ogrisel ogrisel Apr 17, 2014

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you don't need normalized whitespace anymore here.

@mdbecker mdbecker closed this Apr 17, 2014
@mdbecker mdbecker reopened this Apr 17, 2014
@coveralls
Copy link

@coveralls coveralls commented Apr 17, 2014

Coverage Status

Coverage remained the same when pulling eee44de9c17bc5cc052c528b361c2995eba4c99d on mdbecker:truncated_svd_calculate_explained_variance into b70a481 on scikit-learn:master.

@mdbecker mdbecker changed the title [WIP] TruncatedSVD: Calculate explained variance. [MRG] TruncatedSVD: Calculate explained variance. Apr 17, 2014
@ogrisel ogrisel changed the title [MRG] TruncatedSVD: Calculate explained variance. [MRG+1] TruncatedSVD: Calculate explained variance. Apr 17, 2014
@ogrisel
Copy link
Member

@ogrisel ogrisel commented Apr 17, 2014

This looks ready for merge to me. @larsmans any other comment?

@larsmans
Copy link
Member

@larsmans larsmans commented Apr 17, 2014

LGTM, feel free to merge!

@larsmans larsmans changed the title [MRG+1] TruncatedSVD: Calculate explained variance. [MRG+2] TruncatedSVD: Calculate explained variance. Apr 17, 2014
ogrisel added a commit that referenced this issue Apr 17, 2014
…ned_variance

[MRG+2] TruncatedSVD: Calculate explained variance.
@ogrisel ogrisel merged commit 47080ec into scikit-learn:master Apr 17, 2014
1 check passed
@ogrisel
Copy link
Member

@ogrisel ogrisel commented Apr 17, 2014

Great, thanks @mdbecker right on time for the official ending of the sprints!

@mdbecker
Copy link
Contributor Author

@mdbecker mdbecker commented Apr 17, 2014

😄 Thanks for all your help @ogrisel & @larsmans. I had a great time!

mdbecker added a commit to mdbecker/scikit-learn that referenced this issue Apr 17, 2014
ogrisel added a commit that referenced this issue Apr 18, 2014
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants