From 58d560b545e6df4cfc5fd3879f8647ba3a7a0e3b Mon Sep 17 00:00:00 2001 From: Dmitry Persiyanov Date: Tue, 20 Mar 2018 10:06:13 +0300 Subject: [PATCH] Add `gensim.models.BaseKeyedVectors.add_entity` method to fill `KeyedVectors` in a manual way. Fix #1942 (#1957) * Introduce BaseKeyedVectors.add(...) method * make default count=1 * add test on add_word method * address @menshikh-iv comments * fix test_keyedvectors after removing add_word alias * add __setitem__, add bulk entities processing + some tests on new functionality * addressing @menshikh-iv comments on docstrings * addressing @gojomo comments * addressing nitpicks * make self.vectors = np.zeros((0, vector_size)) by default * fix pep8 --- gensim/models/keyedvectors.py | 66 +++++++++++++++++++++++++++-- gensim/test/test_keyedvectors.py | 72 ++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 3 deletions(-) diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index 47632a9650..dd2950dc4e 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -72,8 +72,8 @@ except ImportError: PYEMD_EXT = False -from numpy import dot, zeros, float32 as REAL, empty, memmap as np_memmap, \ - double, array, vstack, sqrt, newaxis, integer, \ +from numpy import dot, float32 as REAL, empty, memmap as np_memmap, \ + double, array, zeros, vstack, sqrt, newaxis, integer, \ ndarray, sum as np_sum, prod, argmax, divide as np_divide import numpy as np from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc @@ -109,7 +109,7 @@ def __str__(self): class BaseKeyedVectors(utils.SaveLoad): def __init__(self, vector_size): - self.vectors = [] + self.vectors = zeros((0, vector_size)) self.vocab = {} self.vector_size = vector_size self.index2entity = [] @@ -154,6 +154,65 @@ def get_vector(self, entity): else: raise KeyError("'%s' not in vocabulary" % entity) + def add(self, entities, weights, replace=False): + """Add entities and their vectors in a 
manual way. + If some entity is already in the vocabulary, the old vector is kept unless `replace` flag is True. + + Parameters + ---------- + entities : {str, list of str} + Entities specified by string tags. + weights: {list of numpy.ndarray, numpy.ndarray} + List of 1D np.array vectors or 2D np.array of vectors. + replace: bool, optional + Flag indicating whether to replace vectors for entities which are already in the vocabulary, + if True - replace vectors, otherwise - keep old vectors. + + """ + if isinstance(entities, string_types): + entities = [entities] + weights = np.array(weights).reshape(1, -1) + elif isinstance(weights, list): + weights = np.array(weights) + + in_vocab_mask = np.zeros(len(entities), dtype=np.bool) + for idx, entity in enumerate(entities): + if entity in self.vocab: + in_vocab_mask[idx] = True + + # add new entities to the vocab + for idx in np.nonzero(~in_vocab_mask)[0]: + entity = entities[idx] + self.vocab[entity] = Vocab(index=len(self.vocab), count=1) + self.index2entity.append(entity) + + # add vectors for new entities + self.vectors = vstack((self.vectors, weights[~in_vocab_mask])) + + # change vectors for in_vocab entities if `replace` flag is specified + if replace: + in_vocab_idxs = [self.vocab[entities[idx]].index for idx in np.nonzero(in_vocab_mask)[0]] + self.vectors[in_vocab_idxs] = weights[in_vocab_mask] + + def __setitem__(self, entities, weights): + """Add entities and their vectors in a manual way. + If some entity is already in the vocabulary, the old vector is replaced with the new one. + This method is an alias for `add` with `replace=True`. + + Parameters + ---------- + entities : {str, list of str} + Entities specified by string tags. + weights: {list of numpy.ndarray, numpy.ndarray} + List of 1D np.array vectors or 2D np.array of vectors. 
+ + """ + if not isinstance(entities, list): + entities = [entities] + weights = weights.reshape(1, -1) + + self.add(entities, weights, replace=True) + def __getitem__(self, entities): """ Accept a single entity (string tag) or list of entities as input. @@ -163,6 +222,7 @@ def __getitem__(self, entities): If a list, return designated tags' vector representations as a 2D numpy array: #tags x #vector_size. + """ if isinstance(entities, string_types): # allow calls like trained_model['office'], as a shorthand for trained_model[['office']] diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index 69579d5f58..26eb443cc1 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -184,6 +184,78 @@ def test_wv_property(self): """Test that the deprecated `wv` property returns `self`. To be removed in v4.0.0.""" self.assertTrue(self.vectors is self.vectors.wv) + def test_add_single(self): + """Test that adding entity in a manual way works correctly.""" + entities = ['___some_entity{}_not_present_in_keyed_vectors___'.format(i) for i in range(5)] + vectors = [np.random.randn(self.vectors.vector_size) for _ in range(5)] + + # Test `add` on already filled kv. + for ent, vector in zip(entities, vectors): + self.vectors.add(ent, vector) + + for ent, vector in zip(entities, vectors): + self.assertTrue(np.allclose(self.vectors[ent], vector)) + + # Test `add` on empty kv. + kv = EuclideanKeyedVectors(self.vectors.vector_size) + for ent, vector in zip(entities, vectors): + kv.add(ent, vector) + + for ent, vector in zip(entities, vectors): + self.assertTrue(np.allclose(kv[ent], vector)) + + def test_add_multiple(self): + """Test that adding a bulk of entities in a manual way works correctly.""" + entities = ['___some_entity{}_not_present_in_keyed_vectors___'.format(i) for i in range(5)] + vectors = [np.random.randn(self.vectors.vector_size) for _ in range(5)] + + # Test `add` on already filled kv. 
+ vocab_size = len(self.vectors.vocab) + self.vectors.add(entities, vectors, replace=False) + self.assertEqual(vocab_size + len(entities), len(self.vectors.vocab)) + + for ent, vector in zip(entities, vectors): + self.assertTrue(np.allclose(self.vectors[ent], vector)) + + # Test `add` on empty kv. + kv = EuclideanKeyedVectors(self.vectors.vector_size) + kv[entities] = vectors + self.assertEqual(len(kv.vocab), len(entities)) + + for ent, vector in zip(entities, vectors): + self.assertTrue(np.allclose(kv[ent], vector)) + + def test_set_item(self): + """Test that __setitem__ works correctly.""" + vocab_size = len(self.vectors.vocab) + + # Add new entity. + entity = '___some_new_entity___' + vector = np.random.randn(self.vectors.vector_size) + self.vectors[entity] = vector + + self.assertEqual(len(self.vectors.vocab), vocab_size + 1) + self.assertTrue(np.allclose(self.vectors[entity], vector)) + + # Replace vector for entity in vocab. + vocab_size = len(self.vectors.vocab) + vector = np.random.randn(self.vectors.vector_size) + self.vectors['war'] = vector + + self.assertEqual(len(self.vectors.vocab), vocab_size) + self.assertTrue(np.allclose(self.vectors['war'], vector)) + + # __setitem__ on several entities. + vocab_size = len(self.vectors.vocab) + entities = ['war', '___some_new_entity1___', '___some_new_entity2___', 'terrorism', 'conflict'] + vectors = [np.random.randn(self.vectors.vector_size) for _ in range(len(entities))] + + self.vectors[entities] = vectors + + self.assertEqual(len(self.vectors.vocab), vocab_size + 2) + for ent, vector in zip(entities, vectors): + self.assertTrue(np.allclose(self.vectors[ent], vector)) + if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)