Skip to content
Permalink
Browse files

Fix handling of the `specials` parameter in Field.build_vocab (#495)

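* Fix the `specials` parameter in Field.build_vocab: user-supplied specials are now popped from kwargs and appended after the field's own unk/pad/init/eos tokens, instead of clashing with the `specials` keyword already passed to the vocab constructor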
  • Loading branch information...
speedcell4 authored and mttk committed Feb 1, 2019
1 parent edafd9c commit a6e520e99a4075bc99ec1df52e4ebf7c1ee01cd9
Showing with 11 additions and 10 deletions.
  1. +10 −9 test/data/test_field.py
  2. +1 −1 torchtext/data/field.py
@@ -154,19 +154,20 @@ def test_build_vocab(self):
fields=json_fields)

# Test build_vocab default
question_field.build_vocab(tsv_dataset, json_dataset)
question_field.build_vocab(tsv_dataset, json_dataset, specials=['<space>'])
assert question_field.vocab.freqs == Counter(
{'When': 4, 'do': 4, 'you': 4, 'use': 4, 'instead': 4,
'of': 4, 'was': 4, 'Lincoln': 4, 'born?': 4, '': 2,
'し?': 2, 'Where': 2, 'What': 2, 'is': 2, '2+2': 2,
'"&"': 2, '"and"?': 2, 'Which': 2, 'location': 2,
'Abraham': 2, '2+2=?': 2})
expected_stoi = {'<unk>': 0, '<pad>': 1, 'Lincoln': 2, 'When': 3,
'born?': 4, 'do': 5, 'instead': 6, 'of': 7,
'use': 8, 'was': 9, 'you': 10, '"&"': 11,
'"and"?': 12, '2+2': 13, '2+2=?': 14, 'Abraham': 15,
'What': 16, 'Where': 17, 'Which': 18, 'is': 19,
'location': 20, 'し?': 21, '': 22}
expected_stoi = {'<unk>': 0, '<pad>': 1, '<space>': 2,
'Lincoln': 3, 'When': 4,
'born?': 5, 'do': 6, 'instead': 7, 'of': 8,
'use': 9, 'was': 10, 'you': 11, '"&"': 12,
'"and"?': 13, '2+2': 14, '2+2=?': 15, 'Abraham': 16,
'What': 17, 'Where': 18, 'Which': 19, 'is': 20,
'location': 21, 'し?': 22, '': 23}
assert dict(question_field.vocab.stoi) == expected_stoi
# Turn the stoi dictionary into an itos list
expected_itos = [x[0] for x in sorted(expected_stoi.items(),
@@ -348,9 +349,9 @@ def test_numericalize_stop_words(self):
test_example_data = question_field.pad(
[question_field.preprocess(x) for x in
[["When", "do", "you", "use", "",
"instead", "of", "し?"],
"instead", "of", "し?"],
["What", "is", "2+2", "<pad>", "<pad>",
"<pad>", "<pad>", "<pad>"],
"<pad>", "<pad>", "<pad>"],
["Here", "is", "a", "sentence", "with",
"some", "oovs", "<pad>"]]]
)
@@ -304,7 +304,7 @@ def build_vocab(self, *args, **kwargs):
counter.update(chain.from_iterable(x))
specials = list(OrderedDict.fromkeys(
tok for tok in [self.unk_token, self.pad_token, self.init_token,
self.eos_token]
self.eos_token] + kwargs.pop('specials', [])
if tok is not None))
self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)

0 comments on commit a6e520e

Please sign in to comment.
You can’t perform that action at this time.