From a6e520e99a4075bc99ec1df52e4ebf7c1ee01cd9 Mon Sep 17 00:00:00 2001
From: Izen
Date: Fri, 1 Feb 2019 21:28:23 +0900
Subject: [PATCH] fix parameter specials in Field.build_vocab (#495)

* fix parameter specials in Field.build_vocab
---
 test/data/test_field.py | 19 ++++++++++---------
 torchtext/data/field.py |  2 +-
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/test/data/test_field.py b/test/data/test_field.py
index 29c7be5274..1b734b6d9b 100644
--- a/test/data/test_field.py
+++ b/test/data/test_field.py
@@ -154,19 +154,20 @@ def test_build_vocab(self):
             fields=json_fields)
 
         # Test build_vocab default
-        question_field.build_vocab(tsv_dataset, json_dataset)
+        question_field.build_vocab(tsv_dataset, json_dataset, specials=['<space>'])
         assert question_field.vocab.freqs == Counter(
             {'When': 4, 'do': 4, 'you': 4, 'use': 4, 'instead': 4,
              'of': 4, 'was': 4, 'Lincoln': 4, 'born?': 4, 'シ': 2,
              'し?': 2, 'Where': 2, 'What': 2, 'is': 2, '2+2': 2,
              '"&"': 2, '"and"?': 2, 'Which': 2, 'location': 2,
              'Abraham': 2, '2+2=?': 2})
-        expected_stoi = {'<unk>': 0, '<pad>': 1, 'Lincoln': 2, 'When': 3,
-                         'born?': 4, 'do': 5, 'instead': 6, 'of': 7,
-                         'use': 8, 'was': 9, 'you': 10, '"&"': 11,
-                         '"and"?': 12, '2+2': 13, '2+2=?': 14, 'Abraham': 15,
-                         'What': 16, 'Where': 17, 'Which': 18, 'is': 19,
-                         'location': 20, 'し?': 21, 'シ': 22}
+        expected_stoi = {'<unk>': 0, '<pad>': 1, '<space>': 2,
+                         'Lincoln': 3, 'When': 4,
+                         'born?': 5, 'do': 6, 'instead': 7, 'of': 8,
+                         'use': 9, 'was': 10, 'you': 11, '"&"': 12,
+                         '"and"?': 13, '2+2': 14, '2+2=?': 15, 'Abraham': 16,
+                         'What': 17, 'Where': 18, 'Which': 19, 'is': 20,
+                         'location': 21, 'し?': 22, 'シ': 23}
         assert dict(question_field.vocab.stoi) == expected_stoi
         # Turn the stoi dictionary into an itos list
         expected_itos = [x[0] for x in sorted(expected_stoi.items(),
@@ -348,9 +349,9 @@ def test_numericalize_stop_words(self):
         test_example_data = question_field.pad(
             [question_field.preprocess(x) for x in
              [["When", "do", "you", "use", "シ",
-               "instead", "of", "し?"],
+               "instead", "of", "し?"],
               ["What", "is", "2+2", "<pad>", "<pad>",
-               "<pad>", "<pad>", "<pad>"],
+               "<pad>", "<pad>", "<pad>"],
               ["Here", "is", "a", "sentence", "with",
                "some", "oovs", "<pad>"]]]
         )
diff --git a/torchtext/data/field.py b/torchtext/data/field.py
index 0311bfcf88..090821b315 100644
--- a/torchtext/data/field.py
+++ b/torchtext/data/field.py
@@ -304,7 +304,7 @@ def build_vocab(self, *args, **kwargs):
             counter.update(chain.from_iterable(x))
         specials = list(OrderedDict.fromkeys(
             tok for tok in [self.unk_token, self.pad_token, self.init_token,
-                            self.eos_token]
+                            self.eos_token] + kwargs.pop('specials', [])
             if tok is not None))
         self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
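
For context, a minimal usage sketch of the behavior this patch enables, assuming the torchtext legacy data API at the time of the change. The dataset path, field name, and the '<space>' token below are illustrative placeholders; only Field.build_vocab and its 'specials' keyword come from the diff itself. Before the fix, a user-supplied 'specials' stayed in **kwargs and collided with the specials list that build_vocab already passes to the Vocab constructor; with the fix it is popped and merged with the default special tokens.

    from torchtext import data

    # Hypothetical field/dataset setup -- names and path are placeholders.
    question = data.Field(sequential=True, batch_first=True)
    dataset = data.TabularDataset(
        path="questions.tsv", format="tsv",
        fields=[("q1", question), ("q2", question)])

    # With this patch, the extra token is merged with the default specials
    # (unk/pad, plus init/eos when they are set) rather than clashing with
    # the specials= keyword that build_vocab itself forwards to Vocab.
    question.build_vocab(dataset, specials=["<space>"])
    assert question.vocab.stoi["<unk>"] == 0
    assert question.vocab.stoi["<pad>"] == 1
    assert question.vocab.stoi["<space>"] == 2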