FIX stop words validation in text vectorizers with custom preprocessors / tokenizers (#12393)

rth authored and jnothman committed Nov 14, 2018
1 parent 155492f commit 2088072
Showing 3 changed files with 75 additions and 5 deletions.
8 changes: 8 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -85,6 +85,14 @@ Changelog
  where ``max_features`` was sometimes rounded down to zero.
  :issue:`12388` by :user:`Connor Tann <Connossor>`.

:mod:`sklearn.feature_extraction`
.................................

- |Fix| Fixed a regression in v0.20.0 where
  :class:`feature_extraction.text.CountVectorizer` and other text vectorizers
  could error during stop words validation with custom preprocessors
  or tokenizers. :issue:`12393` by `Roman Yurchak`_.
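
As a rough sketch of the scenario this entry describes (an illustration,
not part of the diff; it mirrors the test added below): a custom
preprocessor that expects dict records made stop word validation raise in
v0.20.0, because validation applies the preprocessor to plain stop-word
strings; the fix catches that failure and lets fitting proceed.

    from sklearn.feature_extraction.text import CountVectorizer

    # The preprocessor expects {'text': ...} records; validating the stop
    # word 'and' applies it to a plain string, raising TypeError. With this
    # fix the exception is caught and fit_transform completes.
    vec = CountVectorizer(preprocessor=lambda doc: doc['text'],
                          stop_words=['and'])
    X = vec.fit_transform([{'text': 'some text and more text'}])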

:mod:`sklearn.linear_model`
...........................

42 changes: 42 additions & 0 deletions sklearn/feature_extraction/tests/test_text.py
@@ -1,4 +1,5 @@
from __future__ import unicode_literals
import re
import warnings

import pytest
@@ -1121,6 +1122,14 @@ def test_vectorizers_invalid_ngram_range(vec):
        ValueError, message, vec.transform, ["good news everyone"])


def _check_stop_words_consistency(estimator):
    stop_words = estimator.get_stop_words()
    tokenize = estimator.build_tokenizer()
    preprocess = estimator.build_preprocessor()
    return estimator._check_stop_words_consistency(stop_words, preprocess,
                                                   tokenize)


@fails_if_pypy
def test_vectorizer_stop_words_inconsistent():
    if PY2:
@@ -1135,11 +1144,44 @@ def test_vectorizer_stop_words_inconsistent():
        vec.set_params(stop_words=["you've", "you", "you'll", 'AND'])
        assert_warns_message(UserWarning, message, vec.fit_transform,
                             ['hello world'])
        # reset stop word validation
        del vec._stop_words_id
        assert _check_stop_words_consistency(vec) is False

    # Only one warning per stop list
    assert_no_warnings(vec.fit_transform, ['hello world'])
    assert _check_stop_words_consistency(vec) is None

    # Test caching of inconsistency assessment
    vec.set_params(stop_words=["you've", "you", "you'll", 'blah', 'AND'])
    assert_warns_message(UserWarning, message, vec.fit_transform,
                         ['hello world'])


@fails_if_pypy
@pytest.mark.parametrize('Estimator',
                         [CountVectorizer, TfidfVectorizer, HashingVectorizer])
def test_stop_word_validation_custom_preprocessor(Estimator):
    data = [{'text': 'some text'}]

    vec = Estimator()
    assert _check_stop_words_consistency(vec) is True

    vec = Estimator(preprocessor=lambda x: x['text'],
                    stop_words=['and'])
    assert _check_stop_words_consistency(vec) == 'error'
    # checks are cached
    assert _check_stop_words_consistency(vec) is None
    vec.fit_transform(data)

    class CustomEstimator(Estimator):
        def build_preprocessor(self):
            return lambda x: x['text']

    vec = CustomEstimator(stop_words=['and'])
    assert _check_stop_words_consistency(vec) == 'error'

    vec = Estimator(tokenizer=lambda doc: re.compile(r'\w{1,}')
                    .findall(doc),
                    stop_words=['and'])
    assert _check_stop_words_consistency(vec) is True
30 changes: 25 additions & 5 deletions sklearn/feature_extraction/text.py
@@ -269,8 +269,22 @@ def get_stop_words(self):
        return _check_stop_list(self.stop_words)

    def _check_stop_words_consistency(self, stop_words, preprocess, tokenize):
        """Check if stop words are consistent

        Returns
        -------
        is_consistent : True if stop words are consistent with the
                        preprocessor and tokenizer, False if they are not,
                        None if the check was previously performed, "error"
                        if it could not be performed (e.g. because of the
                        use of a custom preprocessor / tokenizer)
        """
        if id(self.stop_words) == getattr(self, '_stop_words_id', None):
            # Stop words were previously validated
            return None

        # NB: stop_words is validated, unlike self.stop_words
        try:
            inconsistent = set()
            for w in stop_words or ():
                tokens = list(tokenize(preprocess(w)))
@@ -280,10 +294,16 @@ def _check_stop_words_consistency(self, stop_words, preprocess, tokenize):
            self._stop_words_id = id(self.stop_words)

            if inconsistent:
                warnings.warn('Your stop_words may be inconsistent with '
                              'your preprocessing. Tokenizing the stop '
                              'words generated tokens %r not in '
                              'stop_words.' % sorted(inconsistent))
            return not inconsistent
        except Exception:
            # Failed to check stop words consistency (e.g. because a custom
            # preprocessor or tokenizer was used)
            self._stop_words_id = id(self.stop_words)
            return 'error'

    def build_analyzer(self):
        """Return a callable that handles preprocessing and tokenization"""
