Skip to content

Commit

Permalink
Fixed issue when using add_vector with FastTextKeyedVectors (#3389)
Browse files Browse the repository at this point in the history
Since Gensim 4.0, 'key' in FastTextKeyedVectors always returns True by design.
The proper way to check if a key already exists is with 'key' in FastTextKeyedVectors.key_to_index.

Co-authored-by: dcarron <dcarron@idiap.ch>
  • Loading branch information
globba and dcarron committed Dec 12, 2022
1 parent ca8e4e8 commit 50a9e6b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def add_vectors(self, keys, weights, extras=None, replace=False):

in_vocab_mask = np.zeros(len(keys), dtype=bool)
for idx, key in enumerate(keys):
if key in self:
if key in self.key_to_index:
in_vocab_mask[idx] = True

# add new entities to the vocab
Expand Down
22 changes: 22 additions & 0 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,28 @@ def test_identity(self):
self.assertTrue(np.all(np.array([6, 7, 8]) == n[2]))


class FastTextKeyedVectorsTest(unittest.TestCase):
def test_add_vector(self):
wv = FastTextKeyedVectors(vector_size=2, min_n=3, max_n=6, bucket=2000000)
wv.add_vector("test_key", np.array([0, 0]))

self.assertEqual(wv.key_to_index["test_key"], 0)
self.assertEqual(wv.index_to_key[0], "test_key")
self.assertTrue(np.all(wv.vectors[0] == np.array([0, 0])))

def test_add_vectors(self):
wv = FastTextKeyedVectors(vector_size=2, min_n=3, max_n=6, bucket=2000000)
wv.add_vectors(["test_key1", "test_key2"], np.array([[0, 0], [1, 1]]))

self.assertEqual(wv.key_to_index["test_key1"], 0)
self.assertEqual(wv.index_to_key[0], "test_key1")
self.assertTrue(np.all(wv.vectors[0] == np.array([0, 0])))

self.assertEqual(wv.key_to_index["test_key2"], 1)
self.assertEqual(wv.index_to_key[1], "test_key2")
self.assertTrue(np.all(wv.vectors[1] == np.array([1, 1])))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()

0 comments on commit 50a9e6b

Please sign in to comment.