Skip to content

Commit

Permalink
more code relo, to better FT or KV groupings
Browse files Browse the repository at this point in the history
  • Loading branch information
gojomo committed Dec 5, 2019
1 parent 381f5dc commit b912c75
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 277 deletions.
3 changes: 1 addition & 2 deletions gensim/models/fasttext.py
Expand Up @@ -290,10 +290,9 @@
import gensim.models._fasttext_bin

from gensim.models.word2vec import Word2VecVocab, Word2VecTrainables, train_sg_pair, train_cbow_pair # noqa
from gensim.models.keyedvectors import KeyedVectors, _l2_norm
from gensim.models.keyedvectors import KeyedVectors, _l2_norm, _save_word2vec_format
from gensim.models.base_any2vec import BaseWordEmbeddingsModel
from gensim.models.utils_any2vec import ft_ngram_hashes
from gensim.models.utils_any2vec import _save_word2vec_format
from gensim import utils
from gensim.utils import deprecated, call_on_class_only

Expand Down
218 changes: 203 additions & 15 deletions gensim/models/keyedvectors.py
Expand Up @@ -169,20 +169,16 @@

from numpy import dot, float32 as REAL, memmap as np_memmap, \
double, array, zeros, vstack, sqrt, newaxis, integer, \
ndarray, sum as np_sum, prod, argmax
ndarray, sum as np_sum, prod, argmax, dtype, ascontiguousarray, \
frombuffer
import numpy as np

from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
from gensim.corpora.dictionary import Dictionary
from six import string_types, integer_types
from six import string_types, integer_types, iteritems
from six.moves import zip, range
from scipy import stats
from gensim.utils import deprecated
from gensim.models.utils_any2vec import (
_save_word2vec_format,
_load_word2vec_format,
ft_ngram_hashes,
)

# For backwards compatibility, see https://github.com/RaRe-Technologies/gensim/issues/2201
#
Expand Down Expand Up @@ -318,7 +314,7 @@ def add(self, entities, weights, replace=False):
def __setitem__(self, entities, weights):
"""Add entities and theirs vectors in a manual way.
If some entity is already in the vocabulary, old vector is replaced with the new one.
This method is alias for :meth:`~gensim.models.keyedvectors.BaseKeyedVectors.add` with `replace=True`.
This method is alias for :meth:`~gensim.models.keyedvectors.KeyedVectors.add` with `replace=True`.
Parameters
----------
Expand Down Expand Up @@ -639,7 +635,7 @@ def most_similar_cosmul(self, positive=None, negative=None, topn=10):
Additional positive or negative examples contribute to the numerator or denominator,
respectively - a potentially sensible but untested extension of the method.
With a single positive example, rankings will be the same as in the default
:meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.most_similar`.
:meth:`~gensim.models.keyedvectors.KeyedVectors.most_similar`.
Parameters
----------
Expand Down Expand Up @@ -789,7 +785,7 @@ def distances(self, word_or_vector, other_words=()):

def distance(self, w1, w2):
"""Compute cosine distance between two words.
Calculate 1 - :meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity`.
Calculate 1 - :meth:`~gensim.models.keyedvectors.KeyedVectors.similarity`.
Parameters
----------
Expand Down Expand Up @@ -849,7 +845,7 @@ def n_similarity(self, ws1, ws2):
@staticmethod
def _log_evaluate_word_analogies(section):
"""Calculate score by section, helper for
:meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.evaluate_word_analogies`.
:meth:`~gensim.models.keyedvectors.KeyedVectors.evaluate_word_analogies`.
Parameters
----------
Expand All @@ -871,7 +867,7 @@ def _log_evaluate_word_analogies(section):
def evaluate_word_analogies(self, analogies, restrict_vocab=300000, case_insensitive=True, dummy4unknown=False):
"""Compute performance of the model on an analogy test set.
This is modern variant of :meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.accuracy`, see
This is modern variant of :meth:`~gensim.models.keyedvectors.KeyedVectors.accuracy`, see
`discussion on GitHub #1935 <https://github.com/RaRe-Technologies/gensim/pull/1935>`_.
The accuracy is reported (printed to log and returned as a score) for each section separately,
Expand Down Expand Up @@ -1162,7 +1158,6 @@ def save_word2vec_format(self, fname, fvocab=None, binary=False, total_vec=None)
(in case word vectors are appended with document vectors afterwards).
"""
# from gensim.models.word2vec import save_word2vec_format
_save_word2vec_format(
fname, self.vocab, self.vectors, fvocab=fvocab, binary=binary, total_vec=total_vec)

Expand Down Expand Up @@ -1202,11 +1197,10 @@ def load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8',
Returns
-------
:class:`~gensim.models.keyedvectors.Word2VecKeyedVectors`
:class:`~gensim.models.keyedvectors.KeyedVectors`
Loaded model.
"""
# from gensim.models.word2vec import load_word2vec_format
return _load_word2vec_format(
cls, fname, fvocab=fvocab, binary=binary, encoding=encoding, unicode_errors=unicode_errors,
limit=limit, datatype=datatype)
Expand Down Expand Up @@ -1675,3 +1669,197 @@ def __str__(self):



def _save_word2vec_format(fname, vocab, vectors, fvocab=None, binary=False, total_vec=None):
"""Store the input-hidden weight matrix in the same format used by the original
C word2vec-tool, for compatibility.
Parameters
----------
fname : str
The file path used to save the vectors in.
vocab : dict
The vocabulary of words.
vectors : numpy.array
The vectors to be stored.
fvocab : str, optional
File path used to save the vocabulary.
binary : bool, optional
If True, the data wil be saved in binary word2vec format, else it will be saved in plain text.
total_vec : int, optional
Explicitly specify total number of vectors
(in case word vectors are appended with document vectors afterwards).
"""
if not (vocab or vectors):
raise RuntimeError("no input")
if total_vec is None:
total_vec = len(vocab)
vector_size = vectors.shape[1]
if fvocab is not None:
logger.info("storing vocabulary in %s", fvocab)
with utils.open(fvocab, 'wb') as vout:
for word, vocab_ in sorted(iteritems(vocab), key=lambda item: -item[1].count):
vout.write(utils.to_utf8("%s %s\n" % (word, vocab_.count)))
logger.info("storing %sx%s projection weights into %s", total_vec, vector_size, fname)
assert (len(vocab), vector_size) == vectors.shape
with utils.open(fname, 'wb') as fout:
fout.write(utils.to_utf8("%s %s\n" % (total_vec, vector_size)))
# store in sorted order: most frequent words at the top
for word, vocab_ in sorted(iteritems(vocab), key=lambda item: -item[1].count):
row = vectors[vocab_.index]
if binary:
row = row.astype(REAL)
fout.write(utils.to_utf8(word) + b" " + row.tostring())
else:
fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join(repr(val) for val in row))))


# Functions for internal use by _load_word2vec_format function


def _add_word_to_result(result, counts, word, weights, vocab_size):

word_id = len(result.vocab)
if word in result.vocab:
logger.warning("duplicate word '%s' in word2vec file, ignoring all but first", word)
return
if counts is None:
# most common scenario: no vocab file given. just make up some bogus counts, in descending order
word_count = vocab_size - word_id
elif word in counts:
# use count from the vocab file
word_count = counts[word]
else:
logger.warning("vocabulary file is incomplete: '%s' is missing", word)
word_count = None

result.vocab[word] = Vocab(index=word_id, count=word_count)
result.vectors[word_id] = weights
result.index2word.append(word)


def _add_bytes_to_result(result, counts, chunk, vocab_size, vector_size, datatype, unicode_errors):
start = 0
processed_words = 0
bytes_per_vector = vector_size * dtype(REAL).itemsize
max_words = vocab_size - len(result.vocab)
for _ in range(max_words):
i_space = chunk.find(b' ', start)
i_vector = i_space + 1

if i_space == -1 or (len(chunk) - i_vector) < bytes_per_vector:
break

word = chunk[start:i_space].decode("utf-8", errors=unicode_errors)
# Some binary files are reported to have obsolete new line in the beginning of word, remove it
word = word.lstrip('\n')
vector = frombuffer(chunk, offset=i_vector, count=vector_size, dtype=REAL).astype(datatype)
_add_word_to_result(result, counts, word, vector, vocab_size)
start = i_vector + bytes_per_vector
processed_words += 1

return processed_words, chunk[start:]


def _word2vec_read_binary(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, binary_chunk_size):
chunk = b''
tot_processed_words = 0

while tot_processed_words < vocab_size:
new_chunk = fin.read(binary_chunk_size)
chunk += new_chunk
processed_words, chunk = _add_bytes_to_result(
result, counts, chunk, vocab_size, vector_size, datatype, unicode_errors)
tot_processed_words += processed_words
if len(new_chunk) < binary_chunk_size:
break
if tot_processed_words != vocab_size:
raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")


def _word2vec_read_text(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, encoding):
for line_no in range(vocab_size):
line = fin.readline()
if line == b'':
raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ")
if len(parts) != vector_size + 1:
raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
word, weights = parts[0], [datatype(x) for x in parts[1:]]
_add_word_to_result(result, counts, word, weights, vocab_size)


def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict',
limit=None, datatype=REAL, binary_chunk_size=100 * 1024):
"""Load the input-hidden weight matrix from the original C word2vec-tool format.
Note that the information stored in the file is incomplete (the binary tree is missing),
so while you can query for word similarity etc., you cannot continue training
with a model loaded this way.
Parameters
----------
fname : str
The file path to the saved word2vec-format file.
fvocab : str, optional
File path to the vocabulary.Word counts are read from `fvocab` filename, if set
(this is the file generated by `-save-vocab` flag of the original C tool).
binary : bool, optional
If True, indicates whether the data is in binary word2vec format.
encoding : str, optional
If you trained the C model using non-utf8 encoding for words, specify that encoding in `encoding`.
unicode_errors : str, optional
default 'strict', is a string suitable to be passed as the `errors`
argument to the unicode() (Python 2.x) or str() (Python 3.x) function. If your source
file may include word tokens truncated in the middle of a multibyte unicode character
(as is common from the original word2vec.c tool), 'ignore' or 'replace' may help.
limit : int, optional
Sets a maximum number of word-vectors to read from the file. The default,
None, means read all.
datatype : type, optional
(Experimental) Can coerce dimensions to a non-default float type (such as `np.float16`) to save memory.
Such types may result in much slower bulk operations or incompatibility with optimized routines.)
binary_chunk_size : int, optional
Read input file in chunks of this many bytes for performance reasons.
Returns
-------
object
Returns the loaded model as an instance of :class:`cls`.
"""

counts = None
if fvocab is not None:
logger.info("loading word counts from %s", fvocab)
counts = {}
with utils.open(fvocab, 'rb') as fin:
for line in fin:
word, count = utils.to_unicode(line, errors=unicode_errors).strip().split()
counts[word] = int(count)

logger.info("loading projection weights from %s", fname)
with utils.open(fname, 'rb') as fin:
header = utils.to_unicode(fin.readline(), encoding=encoding)
vocab_size, vector_size = (int(x) for x in header.split()) # throws for invalid file format
if limit:
vocab_size = min(vocab_size, limit)
result = cls(vector_size)
result.vector_size = vector_size
result.vectors = zeros((vocab_size, vector_size), dtype=datatype)

if binary:
_word2vec_read_binary(fin, result, counts,
vocab_size, vector_size, datatype, unicode_errors, binary_chunk_size)
else:
_word2vec_read_text(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, encoding)
if result.vectors.shape[0] != len(result.vocab):
logger.info(
"duplicate words detected, shrinking matrix size from %i to %i",
result.vectors.shape[0], len(result.vocab)
)
result.vectors = ascontiguousarray(result.vectors[: len(result.vocab)])
assert (len(result.vocab), vector_size) == result.vectors.shape

logger.info("loaded %s matrix from %s", result.vectors.shape, fname)
return result
5 changes: 2 additions & 3 deletions gensim/models/poincare.py
Expand Up @@ -56,8 +56,7 @@
from six.moves import zip, range

from gensim import utils, matutils
from gensim.models.keyedvectors import Vocab, BaseKeyedVectors
from gensim.models.utils_any2vec import _save_word2vec_format, _load_word2vec_format
from gensim.models.keyedvectors import Vocab, KeyedVectors, _save_word2vec_format, _load_word2vec_format
from numpy import float32 as REAL

try:
Expand Down Expand Up @@ -860,7 +859,7 @@ def compute_loss(self):
self._loss_computed = True


class PoincareKeyedVectors(BaseKeyedVectors):
class PoincareKeyedVectors(KeyedVectors):
"""Vectors and vocab for the :class:`~gensim.models.poincare.PoincareModel` training class.
Used to perform operations on the vectors such as vector lookup, distance calculations etc.
Expand Down

0 comments on commit b912c75

Please sign in to comment.