ENH text vectorizers should raise warnings when user params will be unused (#14602)
getgaurav2 authored and jnothman committed Sep 6, 2019
1 parent a6103d3 commit 96ef6b8
Showing 3 changed files with 85 additions and 2 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -164,6 +164,13 @@ Changelog
:mod:`sklearn.feature_extraction`
.................................

- |Enhancement| A warning will now be raised if a parameter choice means
that another parameter will be unused on calling the fit() method for
:class:`feature_extraction.text.HashingVectorizer`,
:class:`feature_extraction.text.CountVectorizer` and
:class:`feature_extraction.text.TfidfVectorizer`.
:pr:`14602` by :user:`Gaurav Chawla <getgaurav2>`.

- |Fix| Functions created by build_preprocessor and build_analyzer of
:class:`feature_extraction.text.VectorizerMixin` can now be pickled.
:pr:`14430` by :user:`Dillon Niederhut <deniederhut>`.
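As a quick illustration of the entry above (a minimal sketch with made-up documents, not part of the commit): a stop word list is only consulted by the default word analyzer, so pairing it with a character analyzer now produces a warning.

from sklearn.feature_extraction.text import CountVectorizer

# 'stop_words' only applies when analyzer='word', so this combination
# triggers a UserWarning when fit() is called.
vect = CountVectorizer(analyzer='char', stop_words=['and', 'the'])
vect.fit(["one document", "another document"])
# UserWarning: The parameter 'stop_words' will not be used
# since 'analyzer' != 'word'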
43 changes: 41 additions & 2 deletions sklearn/feature_extraction/tests/test_text.py
Expand Up @@ -276,7 +276,7 @@ def test_countvectorizer_custom_vocabulary_pipeline():
('tfidf', TfidfTransformer())])
X = pipe.fit_transform(ALL_FOOD_DOCS)
assert (set(pipe.named_steps['count'].vocabulary_) ==
set(what_we_like))
assert X.shape[1] == len(what_we_like)


@@ -1121,7 +1121,7 @@ def test_tfidf_vectorizer_type(vectorizer_dtype, output_dtype,
expected_warning_cls = warning_cls if warning_expected else None
with pytest.warns(expected_warning_cls,
match=warning_msg_match) as record:
X_idf = vectorizer.fit_transform(X)
if expected_warning_cls is None:
relevant_warnings = [w for w in record
if isinstance(w, warning_cls)]
@@ -1304,3 +1304,42 @@ def analyzer(doc):

with pytest.raises(Exception, match="testing"):
Estimator(analyzer=analyzer, input='file').fit_transform([f])


@pytest.mark.parametrize(
    'Vectorizer',
    [CountVectorizer, HashingVectorizer, TfidfVectorizer]
)
@pytest.mark.parametrize(
    'stop_words, tokenizer, preprocessor, ngram_range, token_pattern,'
    'analyzer, unused_name, ovrd_name, ovrd_msg',
    [(["you've", "you'll"], None, None, (1, 1), None, 'char',
      "'stop_words'", "'analyzer'", "!= 'word'"),
     (None, lambda s: s.split(), None, (1, 1), None, 'char',
      "'tokenizer'", "'analyzer'", "!= 'word'"),
     (None, lambda s: s.split(), None, (1, 1), r'\w+', 'word',
      "'token_pattern'", "'tokenizer'", "is not None"),
     (None, None, lambda s: s.upper(), (1, 1), r'\w+', lambda s: s.upper(),
      "'preprocessor'", "'analyzer'", "is callable"),
     (None, None, None, (1, 2), None, lambda s: s.upper(),
      "'ngram_range'", "'analyzer'", "is callable"),
     (None, None, None, (1, 1), r'\w+', 'char',
      "'token_pattern'", "'analyzer'", "!= 'word'")]
)
def test_unused_parameters_warn(Vectorizer, stop_words,
                                tokenizer, preprocessor,
                                ngram_range, token_pattern,
                                analyzer, unused_name, ovrd_name,
                                ovrd_msg):
    train_data = JUNK_FOOD_DOCS
    # setting parameter and checking for corresponding warning messages
    vect = Vectorizer()
    vect.set_params(stop_words=stop_words, tokenizer=tokenizer,
                    preprocessor=preprocessor, ngram_range=ngram_range,
                    token_pattern=token_pattern, analyzer=analyzer)
    msg = ("The parameter %s will not be used"
           " since %s %s" % (unused_name, ovrd_name, ovrd_msg))
    with pytest.warns(UserWarning, match=msg):
        vect.fit(train_data)
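Any one of the parametrized cases above can also be reproduced standalone. A sketch of the tokenizer/token_pattern case (the sample documents are made up; JUNK_FOOD_DOCS is a fixture defined elsewhere in this test module):

import pytest
from sklearn.feature_extraction.text import TfidfVectorizer

# A custom tokenizer makes token_pattern irrelevant, so fit() warns.
vect = TfidfVectorizer(tokenizer=lambda s: s.split(), token_pattern=r'\w+')
with pytest.warns(UserWarning,
                  match="The parameter 'token_pattern' will not be used"
                        " since 'tokenizer' is not None"):
    vect.fit(["the pizza burger beer", "the coke sushi copyright"])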
37 changes: 37 additions & 0 deletions sklearn/feature_extraction/text.py
@@ -391,6 +391,7 @@ def build_analyzer(self):
and n-grams generation.
"""

if callable(self.analyzer):
if self.input in ['file', 'filename']:
self._validate_custom_analyzer()
@@ -405,6 +406,7 @@ def build_analyzer(self):
preprocessor=preprocess, decoder=self.decode)

elif self.analyzer == 'char_wb':

return partial(_analyze, ngrams=self._char_wb_ngrams,
preprocessor=preprocess, decoder=self.decode)

@@ -468,6 +470,32 @@ def _validate_params(self):
"lower boundary larger than the upper boundary."
% str(self.ngram_range))

    def _warn_for_unused_params(self):
        if self.tokenizer is not None and self.token_pattern is not None:
            warnings.warn("The parameter 'token_pattern' will not be used"
                          " since 'tokenizer' is not None")

        if self.preprocessor is not None and callable(self.analyzer):
            warnings.warn("The parameter 'preprocessor' will not be used"
                          " since 'analyzer' is callable")

        if (self.ngram_range != (1, 1) and self.ngram_range is not None
                and callable(self.analyzer)):
            warnings.warn("The parameter 'ngram_range' will not be used"
                          " since 'analyzer' is callable")

        if self.analyzer != 'word' or callable(self.analyzer):
            if self.stop_words is not None:
                warnings.warn("The parameter 'stop_words' will not be used"
                              " since 'analyzer' != 'word'")
            if self.token_pattern is not None and \
                    self.token_pattern != r"(?u)\b\w\w+\b":
                warnings.warn("The parameter 'token_pattern' will not be"
                              " used since 'analyzer' != 'word'")
            if self.tokenizer is not None:
                warnings.warn("The parameter 'tokenizer' will not be used"
                              " since 'analyzer' != 'word'")


class HashingVectorizer(TransformerMixin, VectorizerMixin, BaseEstimator):
"""Convert a collection of text documents to a matrix of token occurrences
@@ -549,6 +577,7 @@ class HashingVectorizer(TransformerMixin, VectorizerMixin, BaseEstimator):
preprocessor : callable or None (default)
Override the preprocessing (string transformation) stage while
preserving the tokenizing and n-grams generation steps.
Only applies if ``analyzer is not callable``.
tokenizer : callable or None (default)
Override the string tokenization step while preserving the
@@ -574,6 +603,7 @@ class HashingVectorizer(TransformerMixin, VectorizerMixin, BaseEstimator):
The lower and upper boundary of the range of n-values for different
n-grams to be extracted. All values of n such that min_n <= n <= max_n
will be used.
Only applies if ``analyzer is not callable``.
analyzer : string, {'word', 'char', 'char_wb'} or callable
Whether the feature should be made of word or character n-grams.
@@ -681,6 +711,7 @@ def fit(self, X, y=None):
"Iterable over raw text documents expected, "
"string object received.")

self._warn_for_unused_params()
self._validate_params()

self._get_hasher().fit(X, y=y)
@@ -805,6 +836,7 @@ class CountVectorizer(VectorizerMixin, BaseEstimator):
preprocessor : callable or None (default)
Override the preprocessing (string transformation) stage while
preserving the tokenizing and n-grams generation steps.
Only applies if ``analyzer is not callable``.
tokenizer : callable or None (default)
Override the string tokenization step while preserving the
@@ -834,6 +866,7 @@ class CountVectorizer(VectorizerMixin, BaseEstimator):
The lower and upper boundary of the range of n-values for different
n-grams to be extracted. All values of n such that min_n <= n <= max_n
will be used.
Only applies if ``analyzer is not callable``.
analyzer : string, {'word', 'char', 'char_wb'} or callable
Whether the feature should be made of word or character n-grams.
@@ -1093,6 +1126,7 @@ def fit(self, raw_documents, y=None):
-------
self
"""
self._warn_for_unused_params()
self.fit_transform(raw_documents)
return self

@@ -1456,6 +1490,7 @@ class TfidfVectorizer(CountVectorizer):
preprocessor : callable or None (default=None)
Override the preprocessing (string transformation) stage while
preserving the tokenizing and n-grams generation steps.
Only applies if ``analyzer is not callable``.
tokenizer : callable or None (default=None)
Override the string tokenization step while preserving the
@@ -1500,6 +1535,7 @@ class TfidfVectorizer(CountVectorizer):
The lower and upper boundary of the range of n-values for different
n-grams to be extracted. All values of n such that min_n <= n <= max_n
will be used.
Only applies if ``analyzer is not callable``.
max_df : float in range [0.0, 1.0] or int (default=1.0)
When building the vocabulary ignore terms that have a document
@@ -1698,6 +1734,7 @@ def fit(self, raw_documents, y=None):
self : TfidfVectorizer
"""
self._check_params()
self._warn_for_unused_params()
X = super().fit_transform(raw_documents)
self._tfidf.fit(X)
return self
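Since _warn_for_unused_params checks each condition independently, a single fit() call can emit several of these warnings at once. A sketch with made-up inputs, assuming the corrected message text above (no trailing quote):

import warnings
from sklearn.feature_extraction.text import HashingVectorizer

# A callable analyzer bypasses preprocessing and n-gram generation,
# so both 'preprocessor' and 'ngram_range' are reported as unused.
vect = HashingVectorizer(analyzer=lambda doc: doc.split(),
                         preprocessor=str.lower,
                         ngram_range=(1, 2))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    vect.fit(["ham and eggs", "green eggs and ham"])
for w in caught:
    print(w.message)
# The parameter 'preprocessor' will not be used since 'analyzer' is callable
# The parameter 'ngram_range' will not be used since 'analyzer' is callable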
