Add (slow) tests for loading glove and charngram vectors
nelson-liu committed Sep 4, 2017
1 parent 2252f6f commit 6db1c5d
Showing 1 changed file with 51 additions and 1 deletion.
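
The new tests exercise downloading and loading pretrained GloVe Twitter and CharNGram vectors through vocab.Vocab, the same path the existing fastText test already covers. As a rough standalone sketch of that usage, distilled from the test code in the diff below (2017-era torchtext API; the final print is illustrative only, not part of the tests):

from collections import Counter

from torchtext import vocab

c = Counter({'hello': 4, 'world': 3})
# Passing a known vectors name triggers a download on first use; the
# "build twice" loops in the tests below check that the second build
# uses the cached copy instead of downloading again.
v = vocab.Vocab(c, min_freq=3, specials=['<pad>', '<bos>'],
                vectors='glove.twitter.27B.200d')
vectors = v.vectors.numpy()          # one 200-dim row per token in v.itos
print(vectors[v.stoi['hello']][:5])  # first five GloVe entries for 'hello'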
52 changes: 51 additions & 1 deletion test/test_vocab.py
@@ -8,6 +8,7 @@
from numpy.testing import assert_allclose
from torchtext import vocab

from .common.test_markers import slow

logging.basicConfig(format="%(asctime)s - %(levelname)s "
"- %(name)s - %(message)s",
@@ -22,7 +23,7 @@ def test_vocab_basic(self):
self.assertEqual(v.itos, ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])

def test_vocab_download_vectors(self):
def test_vocab_download_fasttext_vectors(self):
c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
# Build a vocab and get vectors twice to test caching.
for i in range(2):
@@ -44,3 +45,52 @@ def test_vocab_download_vectors(self):
expected_fasttext_simple_en[word])

assert_allclose(vectors[v.stoi['<unk>']], np.zeros(300))

@slow
def test_vocab_download_glove_vectors(self):
c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
# Build a vocab and get vectors twice to test caching.
for i in range(2):
v = vocab.Vocab(c, min_freq=3, specials=['<pad>', '<bos>'],
vectors='glove.twitter.27B.200d')

self.assertEqual(v.itos, ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])
vectors = v.vectors.numpy()

# The first 5 entries in each vector.
expected_twitter = {
'hello': [0.34683, -0.19612, -0.34923, -0.28158, -0.75627],
'world': [0.035771, 0.62946, 0.27443, -0.36455, 0.39189],
}

for word in expected_twitter:
assert_allclose(vectors[v.stoi[word], :5],
expected_twitter[word])

assert_allclose(vectors[v.stoi['<unk>']], np.zeros(200))

@slow
def test_vocab_download_charngram_vectors(self):
c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
# Build a vocab and get vectors twice to test caching.
for i in range(2):
v = vocab.Vocab(c, min_freq=3, specials=['<pad>', '<bos>'],
vectors='charngram.100d')

self.assertEqual(v.itos, ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])
vectors = v.vectors.numpy()

# The first 5 entries in each vector.
expected_charngram = {
'hello': [-0.44782442, -0.08937783, -0.34227219,
-0.16233221, -0.39343098],
'world': [-0.29590717, -0.05275926, -0.37334684, 0.27117205, -0.3868292],
}

for word in expected_charngram:
assert_allclose(vectors[v.stoi[word], :5],
expected_charngram[word])

assert_allclose(vectors[v.stoi['<unk>']], np.zeros(100))
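
The @slow decorator on the two new tests is imported from the repository's .common.test_markers helper (test/common/test_markers.py), which is not part of this diff. A minimal sketch of how such a marker could be written with unittest is below; the environment-variable gate and its RUN_SLOW_TESTS name are assumptions for illustration, not the actual contents of test_markers.py:

# Hypothetical sketch of a `slow` test marker -- NOT the real
# implementation behind `from .common.test_markers import slow`.
import os
import unittest

# Skip decorated tests unless an (assumed) RUN_SLOW_TESTS flag is set,
# since downloading the pretrained GloVe/CharNGram files is expensive.
slow = unittest.skipUnless(
    os.environ.get('RUN_SLOW_TESTS') == '1',
    'slow test; set RUN_SLOW_TESTS=1 to run')

With a gate like this, the download-heavy tests run only when the flag is exported, e.g. RUN_SLOW_TESTS=1 python -m unittest test.test_vocab (flag name hypothetical).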
