diff --git a/gensim/models/base_any2vec.py b/gensim/models/base_any2vec.py index 5cdb54930d..8f7123e1d7 100644 --- a/gensim/models/base_any2vec.py +++ b/gensim/models/base_any2vec.py @@ -43,6 +43,7 @@ from types import GeneratorType from gensim.utils import deprecated import warnings +import itertools try: from queue import Queue @@ -130,6 +131,11 @@ def _check_training_sanity(self, epochs=None, total_examples=None, total_words=N """Check that the training parameters provided make sense. e.g. raise error if `epochs` not provided.""" raise NotImplementedError() + def _check_input_data_sanity(self, data_iterable=None, data_iterables=None): + """Check that only one argument is not None.""" + if not ((data_iterable is not None) ^ (data_iterables is not None)): + raise ValueError("You must provide only one of singlestream or multistream arguments.") + def _worker_loop(self, job_queue, progress_queue): """Train the model, lifting batches of data from the queue. @@ -322,7 +328,7 @@ def _log_epoch_progress(self, progress_queue, job_queue, cur_epoch=0, total_exam self.total_train_time += elapsed return trained_word_count, raw_word_count, job_tally - def _train_epoch(self, data_iterable, cur_epoch=0, total_examples=None, + def _train_epoch(self, data_iterable=None, data_iterables=None, cur_epoch=0, total_examples=None, total_words=None, queue_factor=2, report_delay=1.0): """Train the model for a single epoch. @@ -330,6 +336,8 @@ def _train_epoch(self, data_iterable, cur_epoch=0, total_examples=None, ---------- data_iterable : iterable of list of object The input corpus. This will be split in chunks and these chunks will be pushed to the queue. + data_iterables : iterable of iterables of list of object + The iterable of input streams like `data_iterable`. Use this parameter in multistream mode. cur_epoch : int, optional The current training epoch, needed to compute the training parameters for each job. For example in many implementations the learning rate would be dropping with the number of epochs. @@ -353,6 +361,7 @@ def _train_epoch(self, data_iterable, cur_epoch=0, total_examples=None, * Total word count used in training. """ + self._check_input_data_sanity(data_iterable, data_iterables) job_queue = Queue(maxsize=queue_factor * self.workers) progress_queue = Queue(maxsize=(queue_factor + 1) * self.workers) @@ -363,6 +372,9 @@ def _train_epoch(self, data_iterable, cur_epoch=0, total_examples=None, for _ in xrange(self.workers) ] + # Chain all input streams into one, because multistream training is not supported yet. + if data_iterables is not None: + data_iterable = itertools.chain(*data_iterables) workers.append(threading.Thread( target=self._job_producer, args=(data_iterable, job_queue), @@ -378,7 +390,7 @@ def _train_epoch(self, data_iterable, cur_epoch=0, total_examples=None, return trained_word_count, raw_word_count, job_tally - def train(self, data_iterable, epochs=None, total_examples=None, + def train(self, data_iterable=None, data_iterables=None, epochs=None, total_examples=None, total_words=None, queue_factor=2, report_delay=1.0, callbacks=(), **kwargs): """Train the model for multiple epochs using multiple workers. 
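For readers skimming the patch, here is a minimal, self-contained sketch of what the new sanity check and the stream chaining in `_train_epoch` amount to. The standalone function below is an illustrative stand-in, not the gensim method itself:

```python
import itertools

def _check_input_data_sanity(data_iterable=None, data_iterables=None):
    # Exactly one of the two arguments may be supplied (XOR on "is not None").
    if not ((data_iterable is not None) ^ (data_iterables is not None)):
        raise ValueError("You must provide only one of singlestream or multistream arguments.")

stream_a = [["cat", "say", "meow"]]
stream_b = [["dog", "say", "woof"]]

_check_input_data_sanity(data_iterables=[stream_a, stream_b])   # passes
# _check_input_data_sanity()                      # raises ValueError: neither argument given
# _check_input_data_sanity(stream_a, [stream_b])  # raises ValueError: both arguments given

# Multistream training is not yet supported at the epoch level, so _train_epoch
# simply flattens all streams into one iterable before handing it to the job producer:
flattened = itertools.chain(*[stream_a, stream_b])
print(list(flattened))  # [['cat', 'say', 'meow'], ['dog', 'say', 'woof']]
```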
@@ -433,8 +445,9 @@ def train(self, data_iterable, epochs=None, total_examples=None, callback.on_epoch_begin(self) trained_word_count_epoch, raw_word_count_epoch, job_tally_epoch = self._train_epoch( - data_iterable, cur_epoch=cur_epoch, total_examples=total_examples, total_words=total_words, - queue_factor=queue_factor, report_delay=report_delay) + data_iterable=data_iterable, data_iterables=data_iterables, cur_epoch=cur_epoch, + total_examples=total_examples, total_words=total_words, queue_factor=queue_factor, + report_delay=report_delay) trained_word_count += trained_word_count_epoch raw_word_count += raw_word_count_epoch job_tally += job_tally_epoch @@ -525,9 +538,9 @@ def _do_train_job(self, data_iterable, job_parameters, thread_private_mem): def _set_train_params(self, **kwargs): raise NotImplementedError() - def __init__(self, sentences=None, workers=3, vector_size=100, epochs=5, callbacks=(), batch_words=10000, - trim_rule=None, sg=0, alpha=0.025, window=5, seed=1, hs=0, negative=5, ns_exponent=0.75, cbow_mean=1, - min_alpha=0.0001, compute_loss=False, fast_version=0, **kwargs): + def __init__(self, sentences=None, input_streams=None, workers=3, vector_size=100, epochs=5, callbacks=(), + batch_words=10000, trim_rule=None, sg=0, alpha=0.025, window=5, seed=1, hs=0, negative=5, + ns_exponent=0.75, cbow_mean=1, min_alpha=0.0001, compute_loss=False, fast_version=0, **kwargs): """ Parameters @@ -624,13 +637,20 @@ def __init__(self, sentences=None, workers=3, vector_size=100, epochs=5, callbac self.neg_labels = zeros(self.negative + 1) self.neg_labels[0] = 1. - if sentences is not None: - if isinstance(sentences, GeneratorType): + if sentences is not None or input_streams is not None: + self._check_input_data_sanity(data_iterable=sentences, data_iterables=input_streams) + if input_streams is not None: + if not isinstance(input_streams, (tuple, list)): + raise TypeError("You must pass tuple or list as the input_streams argument.") + if any(isinstance(stream, GeneratorType) for stream in input_streams): + raise TypeError("You can't pass a generator as any of input streams. Try an iterator.") + elif isinstance(sentences, GeneratorType): raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.") - self.build_vocab(sentences, trim_rule=trim_rule) + + self.build_vocab(sentences=sentences, input_streams=input_streams, trim_rule=trim_rule) self.train( - sentences, total_examples=self.corpus_count, epochs=self.epochs, start_alpha=self.alpha, - end_alpha=self.min_alpha, compute_loss=compute_loss) + sentences=sentences, input_streams=input_streams, total_examples=self.corpus_count, epochs=self.epochs, + start_alpha=self.alpha, end_alpha=self.min_alpha, compute_loss=compute_loss) else: if trim_rule is not None: logger.warning( @@ -763,7 +783,8 @@ def __str__(self): self.__class__.__name__, len(self.wv.index2word), self.vector_size, self.alpha ) - def build_vocab(self, sentences, update=False, progress_per=10000, keep_raw_vocab=False, trim_rule=None, **kwargs): + def build_vocab(self, sentences=None, input_streams=None, workers=None, update=False, progress_per=10000, + keep_raw_vocab=False, trim_rule=None, **kwargs): """Build vocabulary from a sequence of sentences (can be a once-only generator stream). Parameters @@ -773,7 +794,13 @@ def build_vocab(self, sentences, update=False, progress_per=10000, keep_raw_voca consider an iterable that streams the sentences directly from disk/network. 
See :class:`~gensim.models.word2vec.BrownCorpus`, :class:`~gensim.models.word2vec.Text8Corpus` or :class:`~gensim.models.word2vec.LineSentence` module for such examples. - update : bool, optional + input_streams : list or tuple of iterable of iterables + The tuple or list of `sentences`-like arguments. Use it if you have multiple input streams. It is possible + to process streams in parallel, using `workers` parameter. + workers : int + Used if `input_streams` is passed. Determines how many processes to use for vocab building. + Actual number of workers is determined by `min(len(input_streams), workers)`. + update : bool If true, the new words in `sentences` will be added to model's vocab. progress_per : int, optional Indicates how many words to process before showing/updating the progress. @@ -797,8 +824,10 @@ def build_vocab(self, sentences, update=False, progress_per=10000, keep_raw_voca Key word arguments propagated to `self.vocabulary.prepare_vocab` """ + workers = workers or self.workers total_words, corpus_count = self.vocabulary.scan_vocab( - sentences, progress_per=progress_per, trim_rule=trim_rule) + sentences=sentences, input_streams=input_streams, progress_per=progress_per, trim_rule=trim_rule, + workers=workers) self.corpus_count = corpus_count report_values = self.vocabulary.prepare_vocab( self.hs, self.negative, self.wv, update=update, keep_raw_vocab=keep_raw_vocab, @@ -887,7 +916,7 @@ def estimate_memory(self, vocab_size=None, report=None): ) return report - def train(self, sentences, total_examples=None, total_words=None, + def train(self, sentences=None, input_streams=None, total_examples=None, total_words=None, epochs=None, start_alpha=None, end_alpha=None, word_count=0, queue_factor=2, report_delay=1.0, compute_loss=False, callbacks=()): """Train the model. If the hyper-parameters are passed, they override the ones set in the constructor. 
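The `workers` handling described in the `build_vocab` docstring above can be summarized in a few lines. `effective_scan_workers` is a hypothetical helper written only to illustrate the `workers = workers or self.workers` fallback and the `min(len(input_streams), workers)` cap:

```python
def effective_scan_workers(input_streams, requested_workers, model_workers):
    # build_vocab falls back to the model's own worker count when none is given ...
    workers = requested_workers or model_workers
    # ... and the vocab scan never uses more processes than there are input streams.
    return min(len(input_streams), workers)

streams = [[["a", "b"]], [["c", "d"]], [["e", "f"]]]
print(effective_scan_workers(streams, requested_workers=None, model_workers=8))  # 3
print(effective_scan_workers(streams, requested_workers=2, model_workers=8))     # 2
```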
@@ -933,8 +962,8 @@ def train(self, sentences, total_examples=None, total_words=None, self.compute_loss = compute_loss self.running_training_loss = 0.0 return super(BaseWordEmbeddingsModel, self).train( - sentences, total_examples=total_examples, total_words=total_words, - epochs=epochs, start_alpha=start_alpha, end_alpha=end_alpha, word_count=word_count, + data_iterable=sentences, data_iterables=input_streams, total_examples=total_examples, + total_words=total_words, epochs=epochs, start_alpha=start_alpha, end_alpha=end_alpha, word_count=word_count, queue_factor=queue_factor, report_delay=report_delay, compute_loss=compute_loss, callbacks=callbacks) def _get_job_params(self, cur_epoch): diff --git a/gensim/models/doc2vec.py b/gensim/models/doc2vec.py index d73e6e777a..f5f2b6e66b 100644 --- a/gensim/models/doc2vec.py +++ b/gensim/models/doc2vec.py @@ -56,6 +56,7 @@ import logging import os import warnings +import multiprocessing try: from queue import Queue @@ -64,6 +65,7 @@ from collections import namedtuple, defaultdict from timeit import default_timer +from functools import reduce from numpy import zeros, float32 as REAL, empty, ones, \ memmap as np_memmap, vstack, integer, dtype, sum as np_sum, add as np_add, repeat as np_repeat, concatenate @@ -74,7 +76,7 @@ from gensim.models.word2vec import Word2VecKeyedVectors, Word2VecVocab, Word2VecTrainables, train_cbow_pair,\ train_sg_pair, train_batch_sg from six.moves import xrange -from six import string_types, integer_types, itervalues +from six import string_types, integer_types, itervalues, iteritems from gensim.models.base_any2vec import BaseWordEmbeddingsModel from gensim.models.keyedvectors import Doc2VecKeyedVectors from types import GeneratorType @@ -435,8 +437,9 @@ class Doc2Vec(BaseWordEmbeddingsModel): includes not only the word vectors of each word in the context, but also the paragraph vector. """ - def __init__(self, documents=None, dm_mean=None, dm=1, dbow_words=0, dm_concat=0, dm_tag_count=1, - docvecs=None, docvecs_mapfile=None, comment=None, trim_rule=None, callbacks=(), **kwargs): + def __init__(self, documents=None, input_streams=None, dm_mean=None, dm=1, dbow_words=0, dm_concat=0, + dm_tag_count=1, docvecs=None, docvecs_mapfile=None, comment=None, trim_rule=None, callbacks=(), + **kwargs): """ Parameters @@ -445,6 +448,9 @@ def __init__(self, documents=None, dm_mean=None, dm=1, dbow_words=0, dm_concat=0 Input corpus, can be simply a list of elements, but for larger corpora,consider an iterable that streams the documents directly from disk/network. If you don't supply `documents`, the model is left uninitialized -- use if you plan to initialize it in some other way. + input_streams : list or tuple of iterable of iterables + The tuple or list of `documents`-like arguments. Use it if you have multiple input streams. It is possible + to process streams in parallel, using `workers` parameter. dm : {1,0}, optional Defines the training algorithm. If `dm=1`, 'distributed memory' (PV-DM) is used. Otherwise, `distributed bag of words` (PV-DBOW) is employed. 
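Assuming a gensim build that includes this patch, constructing a `Doc2Vec` model from multiple streams might look like the sketch below. The corpus and tags are made up; tags are chosen so they do not collide across streams, which is the reason for the `TaggedLineDocument` warning added in the constructor further down:

```python
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

# Two independent document streams with non-overlapping tags.
stream_1 = [TaggedDocument(words=["cat", "say", "meow"], tags=["pet_0"])]
stream_2 = [TaggedDocument(words=["dog", "say", "woof"], tags=["pet_1"])]

model = Doc2Vec(input_streams=[stream_1, stream_2], size=10, min_count=1, iter=5, workers=2)
vector = model.infer_vector(["cat", "say", "woof"])  # infer a vector for an unseen document
```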
@@ -566,12 +572,22 @@ def __init__(self, documents=None, dm_mean=None, dm=1, dbow_words=0, dm_concat=0 self.docvecs = docvecs or Doc2VecKeyedVectors(self.vector_size, docvecs_mapfile) self.comment = comment - if documents is not None: - if isinstance(documents, GeneratorType): + if documents is not None or input_streams is not None: + self._check_input_data_sanity(data_iterable=documents, data_iterables=input_streams) + if input_streams is not None: + if not isinstance(input_streams, (tuple, list)): + raise TypeError("You must pass tuple or list as the input_streams argument.") + if any(isinstance(stream, GeneratorType) for stream in input_streams): + raise TypeError("You can't pass a generator as any of input streams. Try an iterator.") + if any(isinstance(stream, TaggedLineDocument) for stream in input_streams): + warnings.warn("Using TaggedLineDocument in multistream mode can lead to incorrect results " + "because of tags collision.") + elif isinstance(documents, GeneratorType): raise TypeError("You can't pass a generator as the documents argument. Try an iterator.") - self.build_vocab(documents, trim_rule=trim_rule) + self.build_vocab(documents=documents, input_streams=input_streams, + trim_rule=trim_rule, workers=self.workers) self.train( - documents, total_examples=self.corpus_count, epochs=self.epochs, + documents=documents, input_streams=input_streams, total_examples=self.corpus_count, epochs=self.epochs, start_alpha=self.alpha, end_alpha=self.min_alpha, callbacks=callbacks) @property @@ -661,7 +677,7 @@ def _do_train_job(self, job, alpha, inits): ) return tally, self._raw_word_count(job) - def train(self, documents, total_examples=None, total_words=None, + def train(self, documents=None, input_streams=None, total_examples=None, total_words=None, epochs=None, start_alpha=None, end_alpha=None, word_count=0, queue_factor=2, report_delay=1.0, callbacks=()): """Update the model's neural weights. @@ -683,6 +699,9 @@ def train(self, documents, total_examples=None, total_words=None, Can be simply a list of elements, but for larger corpora,consider an iterable that streams the documents directly from disk/network. If you don't supply `documents`, the model is left uninitialized -- use if you plan to initialize it in some other way. + input_streams : list or tuple of iterable of iterables + The tuple or list of `documents`-like arguments. Use it if you have multiple input streams. It is possible + to process streams in parallel, using `workers` parameter. total_examples : int, optional Count of sentences. 
total_words : int, optional @@ -712,7 +731,7 @@ def train(self, documents, total_examples=None, total_words=None, """ super(Doc2Vec, self).train( - documents, total_examples=total_examples, total_words=total_words, + sentences=documents, input_streams=input_streams, total_examples=total_examples, total_words=total_words, epochs=epochs, start_alpha=start_alpha, end_alpha=end_alpha, word_count=word_count, queue_factor=queue_factor, report_delay=report_delay, callbacks=callbacks) @@ -988,7 +1007,8 @@ def estimate_memory(self, vocab_size=None, report=None): report['doctag_syn0'] = self.docvecs.count * self.vector_size * dtype(REAL).itemsize return super(Doc2Vec, self).estimate_memory(vocab_size, report=report) - def build_vocab(self, documents, update=False, progress_per=10000, keep_raw_vocab=False, trim_rule=None, **kwargs): + def build_vocab(self, documents=None, input_streams=None, update=False, progress_per=10000, keep_raw_vocab=False, + trim_rule=None, workers=None, **kwargs): """Build vocabulary from a sequence of sentences (can be a once-only generator stream). Parameters @@ -997,6 +1017,9 @@ def build_vocab(self, documents, update=False, progress_per=10000, keep_raw_voca Can be simply a list of :class:`~gensim.models.doc2vec.TaggedDocument` elements, but for larger corpora, consider an iterable that streams the documents directly from disk/network. See :class:`~gensim.models.doc2vec.TaggedBrownCorpus` or :class:`~gensim.models.doc2vec.TaggedLineDocument` + input_streams : list or tuple of iterable of iterables + The tuple or list of `documents`-like arguments. Use it if you have multiple input streams. It is possible + to process streams in parallel, using `workers` parameter. update : bool If true, the new words in `sentences` will be added to model's vocab. progress_per : int @@ -1017,12 +1040,19 @@ def build_vocab(self, documents, update=False, progress_per=10000, keep_raw_voca * `count` (int) - the word's frequency count in the corpus * `min_count` (int) - the minimum count threshold. + workers : int + Used if `input_streams` is passed. Determines how many processes to use for vocab building. + Actual number of workers is determined by `min(len(input_streams), workers)`. + **kwargs Additional key word arguments passed to the internal vocabulary construction. 
""" + workers = workers or self.workers total_words, corpus_count = self.vocabulary.scan_vocab( - documents, self.docvecs, progress_per=progress_per, trim_rule=trim_rule) + documents=documents, input_streams=input_streams, docvecs=self.docvecs, + progress_per=progress_per, trim_rule=trim_rule, workers=workers + ) self.corpus_count = corpus_count report_values = self.vocabulary.prepare_vocab( self.hs, self.negative, self.wv, update=update, keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, @@ -1086,6 +1116,53 @@ def build_vocab_from_freq(self, word_freq, keep_raw_vocab=False, corpus_count=No self.hs, self.negative, self.wv, self.docvecs, update=update) +def _note_doctag(key, document_length, docvecs): + """Note a document tag during initial corpus scan, for structure sizing.""" + if isinstance(key, integer_types + (integer,)): + docvecs.max_rawint = max(docvecs.max_rawint, key) + else: + if key in docvecs.doctags: + docvecs.doctags[key] = docvecs.doctags[key].repeat(document_length) + else: + docvecs.doctags[key] = Doctag(len(docvecs.offset2doctag), document_length, 1) + docvecs.offset2doctag.append(key) + docvecs.count = docvecs.max_rawint + 1 + len(docvecs.offset2doctag) + + +def _scan_vocab_worker(stream, progress_queue, max_vocab_size, trim_rule): + min_reduce = 1 + vocab = defaultdict(int) + doclen2tags = defaultdict(list) + checked_string_types = 0 + document_no = -1 + total_words = 0 + for document_no, document in enumerate(stream): + if not checked_string_types: + if isinstance(document.words, string_types): + log_msg = "Each 'words' should be a list of words (usually unicode strings). " \ + "First 'words' here is instead plain %s." % type(document.words) + progress_queue.put(log_msg) + + checked_string_types += 1 + + document_length = len(document.words) + + for tag in document.tags: + doclen2tags[document_length].append(tag) + + for word in document.words: + vocab[word] += 1 + total_words += len(document.words) + + if max_vocab_size and len(vocab) > max_vocab_size: + utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule) + min_reduce += 1 + + progress_queue.put((total_words, document_no + 1)) + progress_queue.put(None) + return vocab, doclen2tags + + class Doc2VecVocab(Word2VecVocab): """Vocabulary used by :class:`~gensim.models.doc2vec.Doc2Vec`. @@ -1123,38 +1200,51 @@ def __init__(self, max_vocab_size=None, min_count=5, sample=1e-3, sorted_vocab=T max_vocab_size=max_vocab_size, min_count=min_count, sample=sample, sorted_vocab=sorted_vocab, null_word=null_word, ns_exponent=ns_exponent) - def scan_vocab(self, documents, docvecs, progress_per=10000, trim_rule=None): - """Create the models Vocabulary: A mapping from unique words in the corpus to their frequency count. - - Parameters - ---------- - documents : iterable of :class:`~gensim.models.doc2vec.TaggedDocument` - The tagged documents used to create the vocabulary. Their tags can be either str tokens or ints (faster). - docvecs : list of :class:`~gensim.models.keyedvectors.Doc2VecKeyedVectors` - The vector representations of the documents in our corpus. Each of them has a size == `vector_size`. - progress_per : int - Progress will be logged every `progress_per` documents. - trim_rule : function, optional - Vocabulary trimming rule, specifies whether certain words should remain in the vocabulary, - be trimmed away, or handled using the default (discard if word count < min_count). 
- Can be None (min_count will be used, look to :func:`~gensim.utils.keep_vocab_item`), - or a callable that accepts parameters (word, count, min_count) and returns either - :attr:`gensim.utils.RULE_DISCARD`, :attr:`gensim.utils.RULE_KEEP` or :attr:`gensim.utils.RULE_DEFAULT`. - The rule, if given, is only used to prune vocabulary during - :meth:`~gensim.models.doc2vec.Doc2Vec.build_vocab` and is not stored as part of the model. + def _scan_vocab_multistream(self, input_streams, docvecs, workers, trim_rule): + manager = multiprocessing.Manager() + progress_queue = manager.Queue() - The input parameters are of the following types: - * `word` (str) - the word we are examining - * `count` (int) - the word's frequency count in the corpus - * `min_count` (int) - the minimum count threshold. + workers = min(workers, len(input_streams)) + logger.info("Scanning vocab in %i processes.", workers) + pool = multiprocessing.Pool(processes=workers) - Returns - ------- - (int, int) - Tuple of (Total words in the corpus, number of documents) + worker_max_vocab_size = self.max_vocab_size // workers if self.max_vocab_size else None + results = [ + pool.apply_async(_scan_vocab_worker, + (stream, progress_queue, worker_max_vocab_size, trim_rule) + ) for stream in input_streams + ] + pool.close() - """ - logger.info("collecting all words and their counts") + unfinished_tasks = len(results) + total_words = 0 + total_documents = 0 + while unfinished_tasks > 0: + report = progress_queue.get() + if report is None: + unfinished_tasks -= 1 + logger.info("scan vocab task finished, processed %i documents and %i words;" + " awaiting finish of %i more tasks", total_documents, total_words, unfinished_tasks) + elif isinstance(report, string_types): + logger.warning(report) + else: + num_words, num_documents = report + total_words += num_words + total_documents += num_documents + + results = [res.get() for res in results] # pairs (vocab, doclen2tags) + self.raw_vocab = reduce(utils.merge_counts, [r[0] for r in results]) + if self.max_vocab_size: + utils.trim_vocab_by_freq(self.raw_vocab, self.max_vocab_size, trim_rule=trim_rule) + + # Update `docvecs` with document tags information. + for (_, doclen2tags) in results: + for document_length, tags in iteritems(doclen2tags): + for tag in tags: + _note_doctag(tag, document_length, docvecs) + return total_words, total_documents + + def _scan_vocab_singlestream(self, documents, docvecs, progress_per, trim_rule): document_no = -1 total_words = 0 min_reduce = 1 @@ -1182,7 +1272,7 @@ def scan_vocab(self, documents, docvecs, progress_per=10000, trim_rule=None): document_length = len(document.words) for tag in document.tags: - self.note_doctag(tag, document_no, document_length, docvecs) + _note_doctag(tag, document_length, docvecs) for word in document.words: vocab[word] += 1 @@ -1192,38 +1282,54 @@ def scan_vocab(self, documents, docvecs, progress_per=10000, trim_rule=None): utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule) min_reduce += 1 - logger.info( - "collected %i word types and %i unique tags from a corpus of %i examples and %i words", - len(vocab), docvecs.count, document_no + 1, total_words - ) corpus_count = document_no + 1 self.raw_vocab = vocab return total_words, corpus_count - def note_doctag(self, key, document_no, document_length, docvecs): - """Note a document tag during initial corpus scan, for correctly setting the keyedvectors size. 
+ def scan_vocab(self, documents=None, input_streams=None, docvecs=None, progress_per=10000, workers=None, + trim_rule=None): + """Create the models Vocabulary: A mapping from unique words in the corpus to their frequency count. Parameters ---------- - key : {int, str} - The tag to be noted. - document_no : int - The document's index in `docvecs`. Unused. - document_length : int - The document's length in words. + documents : iterable of :class:`~gensim.models.doc2vec.TaggedDocument` + The tagged documents used to create the vocabulary. Their tags can be either str tokens or ints (faster). docvecs : list of :class:`~gensim.models.keyedvectors.Doc2VecKeyedVectors` - Vector representations of the documents in the corpus. Each vector has size == `vector_size` + The vector representations of the documents in our corpus. Each of them has a size == `vector_size`. + progress_per : int + Progress will be logged every `progress_per` documents. + trim_rule : function, optional + Vocabulary trimming rule, specifies whether certain words should remain in the vocabulary, + be trimmed away, or handled using the default (discard if word count < min_count). + Can be None (min_count will be used, look to :func:`~gensim.utils.keep_vocab_item`), + or a callable that accepts parameters (word, count, min_count) and returns either + :attr:`gensim.utils.RULE_DISCARD`, :attr:`gensim.utils.RULE_KEEP` or :attr:`gensim.utils.RULE_DEFAULT`. + The rule, if given, is only used to prune vocabulary during + :meth:`~gensim.models.doc2vec.Doc2Vec.build_vocab` and is not stored as part of the model. + + The input parameters are of the following types: + * `word` (str) - the word we are examining + * `count` (int) - the word's frequency count in the corpus + * `min_count` (int) - the minimum count threshold. + + Returns + ------- + (int, int) + Tuple of (Total words in the corpus, number of documents) """ - if isinstance(key, integer_types + (integer,)): - docvecs.max_rawint = max(docvecs.max_rawint, key) + logger.info("collecting all words and their counts") + if input_streams is None: + total_words, corpus_count = self._scan_vocab_singlestream(documents, docvecs, progress_per, trim_rule) else: - if key in docvecs.doctags: - docvecs.doctags[key] = docvecs.doctags[key].repeat(document_length) - else: - docvecs.doctags[key] = Doctag(len(docvecs.offset2doctag), document_length, 1) - docvecs.offset2doctag.append(key) - docvecs.count = docvecs.max_rawint + 1 + len(docvecs.offset2doctag) + total_words, corpus_count = self._scan_vocab_multistream(input_streams, docvecs, workers, trim_rule) + + logger.info( + "collected %i word types and %i unique tags from a corpus of %i examples and %i words", + len(self.raw_vocab), docvecs.count, corpus_count, total_words + ) + + return total_words, corpus_count def indexed_doctags(self, doctag_tokens, docvecs): """Get the indexes and backing-arrays used in training examples. diff --git a/gensim/models/fasttext.py b/gensim/models/fasttext.py index 20430eb1e8..d2d4a6d8da 100644 --- a/gensim/models/fasttext.py +++ b/gensim/models/fasttext.py @@ -241,7 +241,7 @@ class FastText(BaseWordEmbeddingsModel): for the internal structure of words, besides their concurrence counts. 
""" - def __init__(self, sentences=None, sg=0, hs=0, size=100, alpha=0.025, window=5, min_count=5, + def __init__(self, sentences=None, input_streams=None, sg=0, hs=0, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None, word_ngrams=1, sample=1e-3, seed=1, workers=3, min_alpha=0.0001, negative=5, ns_exponent=0.75, cbow_mean=1, hashfxn=hash, iter=5, null_word=0, min_n=3, max_n=6, sorted_vocab=1, bucket=2000000, trim_rule=None, batch_words=MAX_WORDS_IN_BATCH, callbacks=()): @@ -256,6 +256,9 @@ def __init__(self, sentences=None, sg=0, hs=0, size=100, alpha=0.025, window=5, or :class:`~gensim.models.word2vec.LineSentence` in :mod:`~gensim.models.word2vec` module for such examples. If you don't supply `sentences`, the model is left uninitialized -- use if you plan to initialize it in some other way. + input_streams : list or tuple of iterable of iterables + The tuple or list of `sentences`-like arguments. Use it if you have multiple input streams. It is possible + to process streams in parallel, using `workers` parameter. min_count : int, optional The model ignores all words with total frequency lower than this. size : int, optional @@ -341,11 +344,11 @@ def __init__(self, sentences=None, sg=0, hs=0, size=100, alpha=0.025, window=5, Initialize and train a `FastText` model:: >>> from gensim.models import FastText - >>> sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] + >>> input_streams = [[["cat", "say", "meow"], ["dog", "say", "woof"]]] >>> - >>> model = FastText(sentences, min_count=1) - >>> say_vector = model['say'] # get vector for a word - >>> of_vector = model['of'] # get vector for an out-of-vocab word + >>> model = FastText(input_streams=input_streams, min_count=1) + >>> say_vector = model['say'] # get vector for word + >>> of_vector = model['of'] # get vector for out-of-vocab word """ self.load = call_on_class_only @@ -364,9 +367,9 @@ def __init__(self, sentences=None, sg=0, hs=0, size=100, alpha=0.025, window=5, self.wv.bucket = self.bucket super(FastText, self).__init__( - sentences=sentences, workers=workers, vector_size=size, epochs=iter, callbacks=callbacks, - batch_words=batch_words, trim_rule=trim_rule, sg=sg, alpha=alpha, window=window, seed=seed, - hs=hs, negative=negative, cbow_mean=cbow_mean, min_alpha=min_alpha, fast_version=FAST_VERSION) + sentences=sentences, input_streams=input_streams, workers=workers, vector_size=size, epochs=iter, + callbacks=callbacks, batch_words=batch_words, trim_rule=trim_rule, sg=sg, alpha=alpha, window=window, + seed=seed, hs=hs, negative=negative, cbow_mean=cbow_mean, min_alpha=min_alpha, fast_version=FAST_VERSION) @property @deprecated("Attribute will be removed in 4.0.0, use wv.min_n instead") @@ -418,7 +421,8 @@ def syn0_ngrams_lockf(self): def num_ngram_vectors(self): return self.wv.num_ngram_vectors - def build_vocab(self, sentences, update=False, progress_per=10000, keep_raw_vocab=False, trim_rule=None, **kwargs): + def build_vocab(self, sentences=None, input_streams=None, update=False, progress_per=10000, keep_raw_vocab=False, + trim_rule=None, workers=None, **kwargs): """Build vocabulary from a sequence of sentences (can be a once-only generator stream). Each sentence must be a list of unicode strings. @@ -429,6 +433,9 @@ def build_vocab(self, sentences, update=False, progress_per=10000, keep_raw_voca consider an iterable that streams the sentences directly from disk/network. 
See :class:`~gensim.models.word2vec.BrownCorpus`, :class:`~gensim.models.word2vec.Text8Corpus` or :class:`~gensim.models.word2vec.LineSentence` in :mod:`~gensim.models.word2vec` module for such examples. + input_streams : list or tuple of iterable of iterables + The tuple or list of `sentences`-like arguments. Use it if you have multiple input streams. It is possible + to process streams in parallel, using `workers` parameter. update : bool If true, the new words in `sentences` will be added to model's vocab. progress_per : int @@ -449,6 +456,9 @@ def build_vocab(self, sentences, update=False, progress_per=10000, keep_raw_voca * `count` (int) - the word's frequency count in the corpus * `min_count` (int) - the minimum count threshold. + workers : int + Used if `input_streams` is passed. Determines how many processes to use for vocab building. + Actual number of workers is determined by `min(len(input_streams), workers)`. **kwargs Additional key word parameters passed to :meth:`~gensim.models.base_any2vec.BaseWordEmbeddingsModel.build_vocab`. @@ -479,8 +489,8 @@ def build_vocab(self, sentences, update=False, progress_per=10000, keep_raw_voca self.trainables.old_hash2index_len = len(self.wv.hash2index) return super(FastText, self).build_vocab( - sentences, update=update, progress_per=progress_per, - keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, **kwargs) + sentences=sentences, input_streams=input_streams, update=update, progress_per=progress_per, + keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, workers=workers, **kwargs) def _set_train_params(self, **kwargs): pass @@ -559,7 +569,7 @@ def _do_train_job(self, sentences, alpha, inits): return tally, self._raw_word_count(sentences) - def train(self, sentences, total_examples=None, total_words=None, + def train(self, sentences=None, input_streams=None, total_examples=None, total_words=None, epochs=None, start_alpha=None, end_alpha=None, word_count=0, queue_factor=2, report_delay=1.0, callbacks=(), **kwargs): """Update the model's neural weights from a sequence of sentences (can be a once-only generator stream). @@ -577,11 +587,14 @@ def train(self, sentences, total_examples=None, total_words=None, Parameters ---------- - sentences : iterable of iterables + sentences : {iterable of iterables, list or tuple of iterable of iterables} The `sentences` iterable can be simply a list of lists of tokens, but for larger corpora, consider an iterable that streams the sentences directly from disk/network. See :class:`~gensim.models.word2vec.BrownCorpus`, :class:`~gensim.models.word2vec.Text8Corpus` or :class:`~gensim.models.word2vec.LineSentence` in :mod:`~gensim.models.word2vec` module for such examples. + input_streams : list or tuple of iterable of iterables + The tuple or list of `sentences`-like arguments. Use it if you have multiple input streams. It is possible + to process streams in parallel, using `workers` parameter. total_examples : int Count of sentences. 
total_words : int @@ -620,7 +633,7 @@ def train(self, sentences, total_examples=None, total_words=None, """ super(FastText, self).train( - sentences, total_examples=total_examples, total_words=total_words, + sentences=sentences, input_streams=input_streams, total_examples=total_examples, total_words=total_words, epochs=epochs, start_alpha=start_alpha, end_alpha=end_alpha, word_count=word_count, queue_factor=queue_factor, report_delay=report_delay, callbacks=callbacks) self.trainables.get_vocab_word_vecs(self.wv) diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py index d163784c1c..933e35e0bd 100755 --- a/gensim/models/word2vec.py +++ b/gensim/models/word2vec.py @@ -113,6 +113,7 @@ from copy import deepcopy from collections import defaultdict import threading +import multiprocessing import itertools import warnings @@ -135,6 +136,7 @@ from gensim.utils import deprecated from six import iteritems, itervalues, string_types from six.moves import xrange +from functools import reduce logger = logging.getLogger(__name__) @@ -626,7 +628,8 @@ class Word2Vec(BaseWordEmbeddingsModel): (which means that the size of the hidden layer is equal to the number of features `self.size`). """ - def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5, + + def __init__(self, sentences=None, input_streams=None, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None, sample=1e-3, seed=1, workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5, ns_exponent=0.75, cbow_mean=1, hashfxn=hash, iter=5, null_word=0, trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH, compute_loss=False, callbacks=(), @@ -644,6 +647,9 @@ def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5, `_. If you don't supply `sentences`, the model is left uninitialized -- use if you plan to initialize it in some other way. + input_streams : list or tuple of iterable of iterables + The tuple or list of `sentences`-like arguments. Use it if you have multiple input streams. It is possible + to process streams in parallel, using `workers` parameter. size : int, optional Dimensionality of the word vectors. window : int, optional @@ -707,7 +713,6 @@ def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5, * `word` (str) - the word we are examining * `count` (int) - the word's frequency count in the corpus * `min_count` (int) - the minimum count threshold. - sorted_vocab : {0, 1}, optional If 1, sort the vocabulary by descending frequency before assigning word indexes. See :meth:`~gensim.models.word2vec.Word2VecVocab.sort_vocab()`. 
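Both the doc2vec `_scan_vocab_multistream` shown earlier and the word2vec counterpart added further down follow the same worker/aggregator protocol: each process counts one stream, reports progress through a shared queue, signals completion with a `None` sentinel, and the parent merges the partial vocabularies. Below is a condensed, runnable sketch of that protocol; `count_words` and `merge` are simplified stand-ins (the real code uses `gensim.utils.merge_counts`, added at the end of this patch):

```python
import multiprocessing
from collections import defaultdict
from functools import reduce


def count_words(stream, progress_queue):
    """Worker: count words in a single stream, report totals, then send a None sentinel."""
    vocab = defaultdict(int)
    total_words = 0
    num_sentences = 0
    for num_sentences, sentence in enumerate(stream, start=1):
        for word in sentence:
            vocab[word] += 1
        total_words += len(sentence)
    progress_queue.put((total_words, num_sentences))
    progress_queue.put(None)  # sentinel: this worker is finished
    return dict(vocab)


def merge(d1, d2):
    """Simplified stand-in for gensim.utils.merge_counts."""
    for word, count in d2.items():
        d1[word] = d1.get(word, 0) + count
    return d1


if __name__ == '__main__':
    streams = [
        [["cat", "say", "meow"]],
        [["dog", "say", "woof"], ["cat", "say", "meow"]],
    ]
    manager = multiprocessing.Manager()
    progress_queue = manager.Queue()
    pool = multiprocessing.Pool(processes=min(2, len(streams)))
    results = [pool.apply_async(count_words, (stream, progress_queue)) for stream in streams]
    pool.close()

    # Drain progress reports until every worker has sent its None sentinel.
    unfinished, total_words, total_sentences = len(results), 0, 0
    while unfinished > 0:
        report = progress_queue.get()
        if report is None:
            unfinished -= 1
        else:
            words, sentences = report
            total_words += words
            total_sentences += sentences

    raw_vocab = reduce(merge, [res.get() for res in results])
    print(total_words, total_sentences, raw_vocab)  # 9 3 {'cat': 2, 'say': 3, ...}
```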
@@ -726,8 +731,8 @@ def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5, Initialize and train a :class:`~gensim.models.word2vec.Word2Vec` model >>> from gensim.models import Word2Vec - >>> sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] - >>> model = Word2Vec(sentences, min_count=1) + >>> input_streams = [[["cat", "say", "meow"], ["dog", "say", "woof"]]] + >>> model = Word2Vec(input_streams=input_streams, min_count=1) """ self.max_final_vocab = max_final_vocab @@ -742,9 +747,9 @@ def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5, self.trainables = Word2VecTrainables(seed=seed, vector_size=size, hashfxn=hashfxn) super(Word2Vec, self).__init__( - sentences=sentences, workers=workers, vector_size=size, epochs=iter, callbacks=callbacks, - batch_words=batch_words, trim_rule=trim_rule, sg=sg, alpha=alpha, window=window, seed=seed, - hs=hs, negative=negative, cbow_mean=cbow_mean, min_alpha=min_alpha, compute_loss=compute_loss, + sentences=sentences, input_streams=input_streams, workers=workers, vector_size=size, epochs=iter, + callbacks=callbacks, batch_words=batch_words, trim_rule=trim_rule, sg=sg, alpha=alpha, window=window, + seed=seed, hs=hs, negative=negative, cbow_mean=cbow_mean, min_alpha=min_alpha, compute_loss=compute_loss, fast_version=FAST_VERSION) def _do_train_job(self, sentences, alpha, inits): @@ -782,7 +787,7 @@ def _set_train_params(self, **kwargs): self.compute_loss = kwargs['compute_loss'] self.running_training_loss = 0 - def train(self, sentences, total_examples=None, total_words=None, + def train(self, sentences=None, input_streams=None, total_examples=None, total_words=None, epochs=None, start_alpha=None, end_alpha=None, word_count=0, queue_factor=2, report_delay=1.0, compute_loss=False, callbacks=()): """Update the model's neural weights from a sequence of sentences. @@ -810,11 +815,14 @@ def train(self, sentences, total_examples=None, total_words=None, or :class:`~gensim.models.word2vec.LineSentence` in :mod:`~gensim.models.word2vec` module for such examples. See also the `tutorial on data streaming in Python `_. - total_examples : int, optional - Count of sentences. Used to decay the `alpha` learning rate. - total_words : int, optional - Count of raw words in sentences. Used to decay the `alpha` learning rate. - epochs : int, optional + input_streams : list or tuple of iterable of iterables + The tuple or list of `sentences`-like arguments. Use it if you have multiple input streams. It is possible + to process streams in parallel, using `workers` parameter. + total_examples : int + Count of sentences. + total_words : int + Count of raw words in sentences. + epochs : int Number of iterations (epochs) over the corpus. start_alpha : float, optional Initial learning rate. 
If supplied, replaces the starting `alpha` from the constructor, @@ -842,16 +850,17 @@ def train(self, sentences, total_examples=None, total_words=None, Examples -------- >>> from gensim.models import Word2Vec - >>> sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] + >>> input_streams = [[["cat", "say", "meow"], ["dog", "say", "woof"]]] >>> >>> model = Word2Vec(min_count=1) - >>> model.build_vocab(sentences) # prepare the model vocabulary - >>> model.train(sentences, total_examples=model.corpus_count, epochs=model.iter) # train word vectors + >>> model.build_vocab(input_streams=input_streams) # prepare the model vocabulary + >>> model.train(input_streams=input_streams, + >>> total_examples=model.corpus_count, epochs=model.iter) # train word vectors (1, 30) """ return super(Word2Vec, self).train( - sentences, total_examples=total_examples, total_words=total_words, + sentences=sentences, input_streams=input_streams, total_examples=total_examples, total_words=total_words, epochs=epochs, start_alpha=start_alpha, end_alpha=end_alpha, word_count=word_count, queue_factor=queue_factor, report_delay=report_delay, compute_loss=compute_loss, callbacks=callbacks) @@ -1446,6 +1455,40 @@ def __iter__(self): i += self.max_sentence_length +def _scan_vocab_worker(stream, progress_queue, max_vocab_size=None, trim_rule=None): + """Do an initial scan of all words appearing in stream. + + Note: This function can not be Word2VecVocab's method because + of multiprocessing synchronization specifics in Python. + """ + min_reduce = 1 + vocab = defaultdict(int) + checked_string_types = 0 + sentence_no = -1 + total_words = 0 + for sentence_no, sentence in enumerate(stream): + if not checked_string_types: + if isinstance(sentence, string_types): + log_msg = "Each 'sentences' item should be a list of words (usually unicode strings). " \ + "First item here is instead plain %s." 
% type(sentence) + progress_queue.put(log_msg) + + checked_string_types += 1 + + for word in sentence: + vocab[word] += 1 + + if max_vocab_size and len(vocab) > max_vocab_size: + utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule) + min_reduce += 1 + + total_words += len(sentence) + + progress_queue.put((total_words, sentence_no + 1)) + progress_queue.put(None) + return vocab + + class Word2VecVocab(utils.SaveLoad): """Vocabulary used by :class:`~gensim.models.word2vec.Word2Vec`.""" def __init__( @@ -1461,9 +1504,7 @@ def __init__( self.max_final_vocab = max_final_vocab self.ns_exponent = ns_exponent - def scan_vocab(self, sentences, progress_per=10000, trim_rule=None): - """Do an initial scan of all words appearing in sentences.""" - logger.info("collecting all words and their counts") + def _scan_vocab_singlestream(self, sentences, progress_per, trim_rule): sentence_no = -1 total_words = 0 min_reduce = 1 @@ -1491,12 +1532,60 @@ def scan_vocab(self, sentences, progress_per=10000, trim_rule=None): utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule) min_reduce += 1 + corpus_count = sentence_no + 1 + self.raw_vocab = vocab + return total_words, corpus_count + + def _scan_vocab_multistream(self, input_streams, workers, trim_rule): + manager = multiprocessing.Manager() + progress_queue = manager.Queue() + + logger.info("Scanning vocab in %i processes.", min(workers, len(input_streams))) + + workers = min(workers, len(input_streams)) + pool = multiprocessing.Pool(processes=workers) + + worker_max_vocab_size = self.max_vocab_size // workers if self.max_vocab_size else None + results = [ + pool.apply_async(_scan_vocab_worker, + (stream, progress_queue, worker_max_vocab_size, trim_rule) + ) for stream in input_streams + ] + pool.close() + + unfinished_tasks = len(results) + total_words = 0 + total_sentences = 0 + while unfinished_tasks > 0: + report = progress_queue.get() + if report is None: + unfinished_tasks -= 1 + logger.info("scan vocab task finished, processed %i sentences and %i words;" + " awaiting finish of %i more tasks", total_sentences, total_words, unfinished_tasks) + elif isinstance(report, string_types): + logger.warning(report) + else: + num_words, num_sentences = report + total_words += num_words + total_sentences += num_sentences + + self.raw_vocab = reduce(utils.merge_counts, [res.get() for res in results]) + if self.max_vocab_size: + utils.trim_vocab_by_freq(self.raw_vocab, self.max_vocab_size, trim_rule=trim_rule) + return total_words, total_sentences + + def scan_vocab(self, sentences=None, input_streams=None, progress_per=10000, workers=None, trim_rule=None): + logger.info("collecting all words and their counts") + if sentences is not None: + total_words, corpus_count = self._scan_vocab_singlestream(sentences, progress_per, trim_rule) + else: + total_words, corpus_count = self._scan_vocab_multistream(input_streams, workers, trim_rule) + logger.info( "collected %i word types from a corpus of %i raw words and %i sentences", - len(vocab), total_words, sentence_no + 1 + len(self.raw_vocab), total_words, corpus_count ) - corpus_count = sentence_no + 1 - self.raw_vocab = vocab + return total_words, corpus_count def sort_vocab(self, wv): diff --git a/gensim/test/test_doc2vec.py b/gensim/test/test_doc2vec.py index 559e166d4f..c2aeb71a81 100644 --- a/gensim/test/test_doc2vec.py +++ b/gensim/test/test_doc2vec.py @@ -9,7 +9,7 @@ """ -from __future__ import with_statement +from __future__ import with_statement, division import logging import unittest @@ -298,6 +298,37 @@ def 
test_training(self): model2 = doc2vec.Doc2Vec(corpus, size=100, min_count=2, iter=20, workers=1) self.models_equal(model, model2) + def test_multistream_training(self): + """Test doc2vec multistream training.""" + input_streams = [list_corpus[:len(list_corpus) // 2], list_corpus[len(list_corpus) // 2:]] + + model = doc2vec.Doc2Vec(size=100, min_count=2, iter=20, workers=1, seed=42) + model.build_vocab(input_streams=input_streams, workers=1) + self.assertEqual(model.docvecs.doctag_syn0.shape, (300, 100)) + model.train(input_streams=input_streams, total_examples=model.corpus_count, epochs=model.iter) + self.model_sanity(model) + + # build vocab and train in one step; must be the same as above + model2 = doc2vec.Doc2Vec(input_streams=input_streams, size=100, min_count=2, iter=20, workers=1, seed=42) + + # check resulting vectors; note that order of words may be different + for word in model.wv.index2word: + self.assertEqual(model.wv.most_similar(word, topn=5), model2.wv.most_similar(word, topn=5)) + + def test_multistream_build_vocab(self): + # Expected vocab + model = doc2vec.Doc2Vec(min_count=0) + model.build_vocab(list_corpus) + singlestream_vocab = model.vocabulary.raw_vocab + + # Multistream vocab + model2 = doc2vec.Doc2Vec(min_count=0) + input_streams = [list_corpus[:len(list_corpus) // 2], list_corpus[len(list_corpus) // 2:]] + model2.build_vocab(input_streams=input_streams, workers=2) + multistream_vocab = model2.vocabulary.raw_vocab + + self.assertEqual(singlestream_vocab, multistream_vocab) + def test_dbow_hs(self): """Test DBOW doc2vec training.""" model = doc2vec.Doc2Vec(list_corpus, dm=0, hs=1, negative=0, min_count=2, iter=20) @@ -413,7 +444,6 @@ def models_equal(self, model, model2): # check docvecs self.assertEqual(len(model.docvecs.doctags), len(model2.docvecs.doctags)) self.assertEqual(len(model.docvecs.offset2doctag), len(model2.docvecs.offset2doctag)) - self.assertTrue(np.allclose(model.docvecs.doctag_syn0, model2.docvecs.doctag_syn0)) def test_delete_temporary_training_data(self): """Test doc2vec model after delete_temporary_training_data""" diff --git a/gensim/test/test_fasttext.py b/gensim/test/test_fasttext.py index a2ffcfb0fa..545d75b1e9 100644 --- a/gensim/test/test_fasttext.py +++ b/gensim/test/test_fasttext.py @@ -1,5 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from __future__ import division import logging import unittest @@ -80,6 +81,52 @@ def test_training(self): oov_vec = model['minor'] # oov word self.assertEqual(len(oov_vec), 10) + def test_multistream_training(self): + input_streams = [sentences[:len(sentences) // 2], sentences[len(sentences) // 2:]] + model = FT_gensim(size=5, min_count=1, hs=1, negative=0, seed=42, workers=1) + model.build_vocab(input_streams=input_streams, workers=2) + self.model_sanity(model) + + model.train(input_streams=input_streams, total_examples=model.corpus_count, epochs=model.iter) + sims = model.most_similar('graph', topn=10) + + self.assertEqual(model.wv.syn0.shape, (12, 5)) + self.assertEqual(len(model.wv.vocab), 12) + self.assertEqual(model.wv.syn0_vocab.shape[1], 5) + self.assertEqual(model.wv.syn0_ngrams.shape[1], 5) + self.model_sanity(model) + + # test querying for "most similar" by vector + graph_vector = model.wv.syn0norm[model.wv.vocab['graph'].index] + sims2 = model.most_similar(positive=[graph_vector], topn=11) + sims2 = [(w, sim) for w, sim in sims2 if w != 'graph'] # ignore 'graph' itself + self.assertEqual(sims, sims2) + + # build vocab and train in one step; must be the same as above + model2 = 
FT_gensim(input_streams=input_streams, size=5, min_count=1, hs=1, negative=0, seed=42, workers=1) + self.models_equal(model, model2) + + # verify oov-word vector retrieval + invocab_vec = model['minors'] # invocab word + self.assertEqual(len(invocab_vec), 5) + + oov_vec = model['minor'] # oov word + self.assertEqual(len(oov_vec), 5) + + def test_multistream_build_vocab(self): + # Expected vocab + model = FT_gensim(size=5, min_count=1, hs=1, negative=0, seed=42) + model.build_vocab(list_corpus) + singlestream_vocab = model.vocabulary.raw_vocab + + # Multistream vocab + model2 = FT_gensim(size=5, min_count=1, hs=1, negative=0, seed=42) + input_streams = [list_corpus[:len(list_corpus) // 2], list_corpus[len(list_corpus) // 2:]] + model2.build_vocab(input_streams=input_streams, workers=2) + multistream_vocab = model2.vocabulary.raw_vocab + + self.assertEqual(singlestream_vocab, multistream_vocab) + def models_equal(self, model, model2): self.assertEqual(len(model.wv.vocab), len(model2.wv.vocab)) self.assertEqual(model.num_ngram_vectors, model2.num_ngram_vectors) diff --git a/gensim/test/test_utils.py b/gensim/test/test_utils.py index 0df0d6efc2..5b265d0f77 100644 --- a/gensim/test/test_utils.py +++ b/gensim/test/test_utils.py @@ -120,6 +120,29 @@ def test_sample_dict(self): self.assertTrue(True) +class TestTrimVocabByFreq(unittest.TestCase): + def test_trim_vocab(self): + d = {"word1": 5, "word2": 1, "word3": 2} + expected_dict = {"word1": 5, "word3": 2} + utils.trim_vocab_by_freq(d, topk=2) + self.assertEqual(d, expected_dict) + + d = {"word1": 5, "word2": 2, "word3": 2, "word4": 1} + expected_dict = {"word1": 5, "word2": 2, "word3": 2} + utils.trim_vocab_by_freq(d, topk=2) + self.assertEqual(d, expected_dict) + + +class TestMergeDicts(unittest.TestCase): + def test_merge_dicts(self): + d1 = {"word1": 5, "word2": 1, "word3": 2} + d2 = {"word1": 2, "word3": 3, "word4": 10} + + res_dict = utils.merge_counts(d1, d2) + expected_dict = {"word1": 7, "word2": 1, "word3": 5, "word4": 10} + self.assertEqual(res_dict, expected_dict) + + class TestWindowing(unittest.TestCase): arr10_5 = np.array([ diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py index 411d200676..c2ee97062d 100644 --- a/gensim/test/test_word2vec.py +++ b/gensim/test/test_word2vec.py @@ -7,7 +7,7 @@ """ Automated tests for checking transformation algorithms (the models package). 
""" - +from __future__ import division import logging import unittest @@ -166,6 +166,20 @@ def testMaxFinalVocab(self): self.assertEqual(reported_values['num_retained_words'], 4) self.assertEqual(model.vocabulary.effective_min_count, 3) + def testMultiStreamBuildVocab(self): + # Expected vocab + model = word2vec.Word2Vec(min_count=0) + model.build_vocab(sentences) + singlestream_vocab = model.vocabulary.raw_vocab + + # Multistream vocab + model = word2vec.Word2Vec(min_count=0) + input_streams = [sentences[:len(sentences) // 2], sentences[len(sentences) // 2:]] + model.build_vocab(input_streams=input_streams, workers=2) + multistream_vocab = model.vocabulary.raw_vocab + + self.assertEqual(singlestream_vocab, multistream_vocab) + def testOnlineLearning(self): """Test that the algorithm is able to add new words to the vocabulary and to a trained model when using a sorted vocabulary""" @@ -480,6 +494,30 @@ def testTraining(self): model2 = word2vec.Word2Vec(sentences, size=2, min_count=1, hs=1, negative=0) self.models_equal(model, model2) + def testMultistreamTraining(self): + """Test word2vec multistream training.""" + # build vocabulary, don't train yet + input_streams = [sentences[:len(sentences) // 2], sentences[len(sentences) // 2:]] + model = word2vec.Word2Vec(size=2, min_count=1, hs=1, negative=0, workers=1, seed=42) + model.build_vocab(input_streams=input_streams) + + self.assertTrue(model.wv.syn0.shape == (len(model.wv.vocab), 2)) + self.assertTrue(model.syn1.shape == (len(model.wv.vocab), 2)) + + model.train(input_streams=input_streams, total_examples=model.corpus_count, epochs=model.iter) + sims = model.most_similar('graph', topn=10) + + # test querying for "most similar" by vector + graph_vector = model.wv.syn0norm[model.wv.vocab['graph'].index] + sims2 = model.most_similar(positive=[graph_vector], topn=11) + sims2 = [(w, sim) for w, sim in sims2 if w != 'graph'] # ignore 'graph' itself + self.assertEqual(sims, sims2) + + # build vocab and train in one step; must be the same as above + model2 = word2vec.Word2Vec(input_streams=input_streams, size=2, min_count=1, hs=1, negative=0, + workers=1, seed=42) + self.models_equal(model, model2) + def testScoring(self): """Test word2vec scoring.""" model = word2vec.Word2Vec(sentences, size=2, min_count=1, hs=1, negative=0) diff --git a/gensim/utils.py b/gensim/utils.py index 8b8d7ea107..ec02cf4bb2 100644 --- a/gensim/utils.py +++ b/gensim/utils.py @@ -33,12 +33,13 @@ import sys import subprocess import inspect +import heapq import numpy as np import numbers import scipy.sparse -from six import iterkeys, iteritems, u, string_types, unichr +from six import iterkeys, iteritems, itervalues, u, string_types, unichr from six.moves import xrange from smart_open import smart_open @@ -1739,6 +1740,50 @@ def prune_vocab(vocab, min_reduce, trim_rule=None): return result +def trim_vocab_by_freq(vocab, topk, trim_rule=None): + """Retain `topk` most frequent words in `vocab`. + If there are more words with the same frequency as `topk`-th one, they will be kept. + Modifies `vocab` in place, returns nothing. + + Parameters + ---------- + vocab : dict + Input dictionary. + topk : int + Number of words with highest frequencies to keep. + trim_rule : function, optional + Function for trimming entities from vocab, default behaviour is `vocab[w] <= min_count`. 
+ + """ + if topk >= len(vocab): + return + + min_count = heapq.nlargest(topk, itervalues(vocab))[-1] + prune_vocab(vocab, min_count, trim_rule=trim_rule) + + +def merge_counts(dict1, dict2): + """Merge `dict1` of (word, freq1) and `dict2` of (word, freq2) into `dict1` of (word, freq1+freq2). + Parameters + ---------- + dict1 : dict of (str, int) + First dictionary. + dict2 : dict of (str, int) + Second dictionary. + Returns + ------- + result : dict + Merged dictionary with sum of frequencies as values. + """ + for word, freq in iteritems(dict2): + if word in dict1: + dict1[word] += freq + else: + dict1[word] = freq + + return dict1 + + def qsize(queue): """Get the (approximate) queue size where available.
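A quick usage example for the two new helpers, mirroring the unit tests added above (assuming a gensim build that includes this patch):

```python
from gensim import utils

# merge_counts sums frequencies into its first argument and returns it.
partial_1 = {"cat": 3, "dog": 1}
partial_2 = {"cat": 2, "fish": 4}
merged = utils.merge_counts(partial_1, partial_2)
print(merged)                # {'cat': 5, 'dog': 1, 'fish': 4}
print(merged is partial_1)   # True -- dict1 is modified in place

# trim_vocab_by_freq keeps the topk most frequent words in place; words tied
# with the topk-th frequency survive as well.
vocab = {"cat": 5, "dog": 2, "fish": 2, "ant": 1}
utils.trim_vocab_by_freq(vocab, topk=2)
print(vocab)                 # {'cat': 5, 'dog': 2, 'fish': 2}
```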