Skip to content

Commit

Permalink
Save nesting_field.vocab.freqs (#403)
Browse files (browse the repository at this point in the history)
* Copy vocab of nesting field to nested field

* Add tests for freqs attribute in NestedField
  • Loading branch information
nzw0301 authored and mttk committed Sep 24, 2018
1 parent b18248e commit 1644970
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/data/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,9 @@ def test_build_vocab_from_dataset(self):
for c in expected:
assert c in CHARS.vocab.stoi

expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs

def test_build_vocab_from_iterable(self):
nesting_field = data.Field(unk_token="<cunk>", pad_token="<cpad>")
CHARS = data.NestedField(nesting_field)
Expand All @@ -527,6 +530,9 @@ def test_build_vocab_from_iterable(self):
for c in expected:
assert c in CHARS.vocab.stoi

expected_freqs = Counter({"a": 6, "b": 12, "c": 4})
assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs

def test_pad(self):
nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
init_token="<w>", eos_token="</w>")
Expand Down
1 change: 1 addition & 0 deletions torchtext/data/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,7 @@ def build_vocab(self, *args, **kwargs):
self.nesting_field.build_vocab(*flattened, **kwargs)
super(NestedField, self).build_vocab()
self.vocab.extend(self.nesting_field.vocab)
self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
if old_vectors is not None:
self.vocab.load_vectors(old_vectors,
unk_init=old_unk_init, cache=old_vectors_cache)
Expand Down

0 comments on commit 1644970

Please sign in to comment.