Skip to content

Commit

Permalink
fix parameter specials in Field.build_vocab (#495)
Browse files Browse the repository at this point in the history
* fix parameter specials in Field.build_vocab
  • Loading branch information
speedcell4 authored and mttk committed Feb 1, 2019
1 parent edafd9c commit a6e520e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
19 changes: 10 additions & 9 deletions test/data/test_field.py
Expand Up @@ -154,19 +154,20 @@ def test_build_vocab(self):
fields=json_fields) fields=json_fields)


# Test build_vocab default # Test build_vocab default
question_field.build_vocab(tsv_dataset, json_dataset) question_field.build_vocab(tsv_dataset, json_dataset, specials=['<space>'])
assert question_field.vocab.freqs == Counter( assert question_field.vocab.freqs == Counter(
{'When': 4, 'do': 4, 'you': 4, 'use': 4, 'instead': 4, {'When': 4, 'do': 4, 'you': 4, 'use': 4, 'instead': 4,
'of': 4, 'was': 4, 'Lincoln': 4, 'born?': 4, 'シ': 2, 'of': 4, 'was': 4, 'Lincoln': 4, 'born?': 4, 'シ': 2,
'し?': 2, 'Where': 2, 'What': 2, 'is': 2, '2+2': 2, 'し?': 2, 'Where': 2, 'What': 2, 'is': 2, '2+2': 2,
'"&"': 2, '"and"?': 2, 'Which': 2, 'location': 2, '"&"': 2, '"and"?': 2, 'Which': 2, 'location': 2,
'Abraham': 2, '2+2=?': 2}) 'Abraham': 2, '2+2=?': 2})
expected_stoi = {'<unk>': 0, '<pad>': 1, 'Lincoln': 2, 'When': 3, expected_stoi = {'<unk>': 0, '<pad>': 1, '<space>': 2,
'born?': 4, 'do': 5, 'instead': 6, 'of': 7, 'Lincoln': 3, 'When': 4,
'use': 8, 'was': 9, 'you': 10, '"&"': 11, 'born?': 5, 'do': 6, 'instead': 7, 'of': 8,
'"and"?': 12, '2+2': 13, '2+2=?': 14, 'Abraham': 15, 'use': 9, 'was': 10, 'you': 11, '"&"': 12,
'What': 16, 'Where': 17, 'Which': 18, 'is': 19, '"and"?': 13, '2+2': 14, '2+2=?': 15, 'Abraham': 16,
'location': 20, 'し?': 21, 'シ': 22} 'What': 17, 'Where': 18, 'Which': 19, 'is': 20,
'location': 21, 'し?': 22, 'シ': 23}
assert dict(question_field.vocab.stoi) == expected_stoi assert dict(question_field.vocab.stoi) == expected_stoi
# Turn the stoi dictionary into an itos list # Turn the stoi dictionary into an itos list
expected_itos = [x[0] for x in sorted(expected_stoi.items(), expected_itos = [x[0] for x in sorted(expected_stoi.items(),
Expand Down Expand Up @@ -348,9 +349,9 @@ def test_numericalize_stop_words(self):
test_example_data = question_field.pad( test_example_data = question_field.pad(
[question_field.preprocess(x) for x in [question_field.preprocess(x) for x in
[["When", "do", "you", "use", "シ", [["When", "do", "you", "use", "シ",
"instead", "of", "し?"], "instead", "of", "し?"],
["What", "is", "2+2", "<pad>", "<pad>", ["What", "is", "2+2", "<pad>", "<pad>",
"<pad>", "<pad>", "<pad>"], "<pad>", "<pad>", "<pad>"],
["Here", "is", "a", "sentence", "with", ["Here", "is", "a", "sentence", "with",
"some", "oovs", "<pad>"]]] "some", "oovs", "<pad>"]]]
) )
Expand Down
2 changes: 1 addition & 1 deletion torchtext/data/field.py
Expand Up @@ -304,7 +304,7 @@ def build_vocab(self, *args, **kwargs):
counter.update(chain.from_iterable(x)) counter.update(chain.from_iterable(x))
specials = list(OrderedDict.fromkeys( specials = list(OrderedDict.fromkeys(
tok for tok in [self.unk_token, self.pad_token, self.init_token, tok for tok in [self.unk_token, self.pad_token, self.init_token,
self.eos_token] self.eos_token] + kwargs.pop('specials', [])
if tok is not None)) if tok is not None))
self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)


Expand Down

0 comments on commit a6e520e

Please sign in to comment.