diff --git a/.gitignore b/.gitignore
index 77ee74c7e6..78232efd98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ torchtext.egg-info/
 */**/*.pyc
 */**/*~
 *~
+.cache
diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh
index 359917658e..c0c92a5284 100644
--- a/build_tools/travis/install.sh
+++ b/build_tools/travis/install.sh
@@ -54,6 +54,10 @@ pip install -r requirements.txt
 if [[ "$SKIP_TESTS" != "true" ]]; then
     # SpaCy English models
     python -m spacy download en
+
+    # NLTK data needed for Moses tokenizer
+    python -m nltk.downloader perluniprops nonbreaking_prefixes
+
     # PyTorch
     conda install --yes pytorch torchvision -c soumith
 fi
diff --git a/test/data/test_field.py b/test/data/test_field.py
index f8bd82785d..8a8d9f1b78 100644
--- a/test/data/test_field.py
+++ b/test/data/test_field.py
@@ -1,6 +1,5 @@
 from unittest import TestCase
 
-import six
 import torchtext.data as data
 
 
@@ -88,20 +87,3 @@ def test_pad(self):
         field = data.Field(init_token="<bos>", eos_token="<eos>",
                            sequential=False, include_lengths=True)
         assert field.pad(minibatch) == minibatch
-
-    def test_get_tokenizer(self):
-        # Test the default case with str.split
-        assert data.get_tokenizer(str.split) == str.split
-        test_str = "A string, particularly one with slightly complex punctuation."
-        assert data.get_tokenizer(str.split)(test_str) == str.split(test_str)
-
-        # Test SpaCy option, and verify it properly handles punctuation.
-        assert data.get_tokenizer("spacy")(six.text_type(test_str)) == [
-            "A", "string", ",", "particularly", "one", "with", "slightly",
-            "complex", "punctuation", "."]
-
-        # Test that errors are raised for invalid input arguments.
-        with self.assertRaises(ValueError):
-            data.get_tokenizer(1)
-        with self.assertRaises(ValueError):
-            data.get_tokenizer("some other string")
diff --git a/test/data/test_utils.py b/test/data/test_utils.py
new file mode 100644
index 0000000000..508b1361df
--- /dev/null
+++ b/test/data/test_utils.py
@@ -0,0 +1,33 @@
+from unittest import TestCase
+
+import six
+import torchtext.data as data
+
+
+class TestUtils(TestCase):
+    def test_get_tokenizer(self):
+        # Test the default case with str.split
+        assert data.get_tokenizer(str.split) == str.split
+        test_str = "A string, particularly one with slightly complex punctuation."
+        assert data.get_tokenizer(str.split)(test_str) == str.split(test_str)
+
+        # Test SpaCy option, and verify it properly handles punctuation.
+        assert data.get_tokenizer("spacy")(six.text_type(test_str)) == [
+            "A", "string", ",", "particularly", "one", "with", "slightly",
+            "complex", "punctuation", "."]
+
+        # Test Moses option. Test strings taken from NLTK doctests.
+        # Note that internally, MosesTokenizer converts to unicode if applicable
+        moses_tokenizer = data.get_tokenizer("moses")
+        assert moses_tokenizer(test_str) == [
+            "A", "string", ",", "particularly", "one", "with", "slightly",
+            "complex", "punctuation", "."]
+
+        # Nonbreaking prefixes should tokenize the final period.
+        assert moses_tokenizer(six.text_type("abc def.")) == ["abc", "def", "."]
+
+        # Test that errors are raised for invalid input arguments.
+        with self.assertRaises(ValueError):
+            data.get_tokenizer(1)
+        with self.assertRaises(ValueError):
+            data.get_tokenizer("some other string")
diff --git a/torchtext/data/example.py b/torchtext/data/example.py
index c5faf73638..4aa9e10e34 100644
--- a/torchtext/data/example.py
+++ b/torchtext/data/example.py
@@ -49,8 +49,8 @@ def fromtree(cls, data, fields, subtrees=False):
         try:
             from nltk.tree import Tree
         except ImportError:
-            print('''Please install NLTK:
-    $ pip install nltk''')
+            print("Please install NLTK. "
+                  "See the docs at http://nltk.org for more information.")
             raise
         tree = Tree.fromstring(data)
         if subtrees:
diff --git a/torchtext/data/utils.py b/torchtext/data/utils.py
index 6703c57421..9aa36a3078 100644
--- a/torchtext/data/utils.py
+++ b/torchtext/data/utils.py
@@ -1,7 +1,7 @@
 def get_tokenizer(tokenizer):
     if callable(tokenizer):
         return tokenizer
-    if tokenizer == 'spacy':
+    if tokenizer == "spacy":
         try:
             import spacy
             spacy_en = spacy.load('en')
@@ -14,10 +14,24 @@ def get_tokenizer(tokenizer):
         print("Please install SpaCy and the SpaCy English tokenizer. "
               "See the docs at https://spacy.io for more information.")
         raise
+    elif tokenizer == "moses":
+        try:
+            from nltk.tokenize.moses import MosesTokenizer
+            moses_tokenizer = MosesTokenizer()
+            return moses_tokenizer.tokenize
+        except ImportError:
+            print("Please install NLTK. "
+                  "See the docs at http://nltk.org for more information.")
+            raise
+        except LookupError:
+            print("Please install the necessary NLTK corpora. "
+                  "See the docs at http://nltk.org for more information.")
+            raise
     raise ValueError("Requested tokenizer {}, valid choices are a "
-                     "callable that takes a single string as input "
-                     "and \"spacy\" for the SpaCy English "
-                     "tokenizer.".format(tokenizer))
+                     "callable that takes a single string as input, "
+                     "\"spacy\" for the SpaCy English tokenizer, or "
+                     "\"moses\" for the NLTK port of the Moses tokenization "
+                     "script.".format(tokenizer))
 
 
 def interleave_keys(a, b):
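
Usage sketch (not part of the diff above): with this change applied, get_tokenizer
accepts "moses" alongside "spacy" and plain callables. This assumes an NLTK version
that still ships nltk.tokenize.moses, plus the perluniprops and nonbreaking_prefixes
data downloaded as in the Travis script. The expected output is taken from the new
test in test_utils.py.

    import torchtext.data as data

    # get_tokenizer("moses") returns the bound tokenize method of an
    # NLTK MosesTokenizer instance, per torchtext/data/utils.py above.
    tokenize = data.get_tokenizer("moses")
    print(tokenize("A string, particularly one with slightly complex punctuation."))
    # ['A', 'string', ',', 'particularly', 'one', 'with', 'slightly',
    #  'complex', 'punctuation', '.']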