Static compile of word2vec #233

Merged: 5 commits, Sep 13, 2014
MANIFEST.in (2 changes: 1 addition & 1 deletion)

@@ -8,5 +8,5 @@ include COPYING
include COPYING.LESSER
include ez_setup.py
include gensim/models/voidptr.h
+include gensim/models/word2vec_inner.c
include gensim/models/word2vec_inner.pyx
-include gensim_addons/models/word2vec_inner.pyx
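Shipping the pregenerated word2vec_inner.c alongside the .pyx is what makes the "static" compile possible: setup.py can build the extension with an ordinary C compiler, so end users need neither Cython nor the pyximport fallback removed below. The PR's setup.py changes are not part of this excerpt; the following is only a minimal sketch of that pattern, with the paths taken from the MANIFEST above and everything else assumed.

# Minimal sketch of a static-compile setup.py (not the PR's actual file):
# build the extension from the pregenerated C source shipped via MANIFEST.in,
# so installing users need a C compiler but not Cython.
from setuptools import setup, Extension

import numpy

setup(
    name='gensim',
    ext_modules=[
        Extension(
            'gensim.models.word2vec_inner',
            sources=['gensim/models/word2vec_inner.c'],  # pre-cythonized source
            include_dirs=[numpy.get_include(), 'gensim/models'],  # voidptr.h lives here
        ),
    ],
)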
gensim/models/word2vec.py (231 changes: 112 additions & 119 deletions)

@@ -71,9 +71,9 @@
except ImportError:
    from Queue import Queue

-from numpy import exp, dot, zeros, outer, random, dtype, get_include, float32 as REAL,\
-    uint32, seterr, array, uint8, vstack, argsort, fromstring, sqrt, newaxis, ndarray, empty, sum as np_sum,\
-    prod
+from numpy import exp, dot, zeros, outer, random, dtype, float32 as REAL,\
+    uint32, seterr, array, uint8, vstack, argsort, fromstring, sqrt, newaxis,\
+    ndarray, empty, sum as np_sum, prod

logger = logging.getLogger("gensim.models.word2vec")

@@ -84,124 +84,117 @@


try:
-    raise ImportError  # ignore for now
-    from gensim_addons.models.word2vec_inner import train_sentence_sg, train_sentence_cbow, FAST_VERSION
+    from gensim.models.word2vec_inner import train_sentence_sg, train_sentence_cbow, FAST_VERSION
except ImportError:
-    try:
-        # try to compile and use the faster cython version
-        import pyximport
-        models_dir = os.path.dirname(__file__) or os.getcwd()
-        pyximport.install(setup_args={"include_dirs": [models_dir, get_include()]})
-        from gensim.models.word2vec_inner import train_sentence_sg, train_sentence_cbow, FAST_VERSION
-    except:
-        # failed... fall back to plain numpy (20-80x slower training than the above)
-        FAST_VERSION = -1
+    # failed... fall back to plain numpy (20-80x slower training than the above)
+    FAST_VERSION = -1

    def train_sentence_sg(model, sentence, alpha, work=None):
        """
        Update skip-gram model by training on a single sentence.

        The sentence is a list of Vocab objects (or None, where the corresponding
        word is not in the vocabulary). Called internally from `Word2Vec.train()`.

        This is the non-optimized, Python version. If you have cython installed, gensim
        will use the optimized version from word2vec_inner instead.

        """
        if model.negative:
            # precompute negative labels
            labels = zeros(model.negative + 1)
            labels[0] = 1.0

        for pos, word in enumerate(sentence):
            if word is None:
                continue  # OOV word in the input sentence => skip
            reduced_window = random.randint(model.window)  # `b` in the original word2vec code

            # now go over all words from the (reduced) window, predicting each one in turn
            start = max(0, pos - model.window + reduced_window)
            for pos2, word2 in enumerate(sentence[start : pos + model.window + 1 - reduced_window], start):
                # don't train on OOV words and on the `word` itself
                if word2 and not (pos2 == pos):
                    l1 = model.syn0[word2.index]
                    neu1e = zeros(l1.shape)

                    if model.hs:
                        # work on the entire tree at once, to push as much work into numpy's C routines as possible (performance)
                        l2a = deepcopy(model.syn1[word.point])  # 2d matrix, codelen x layer1_size
                        fa = 1.0 / (1.0 + exp(-dot(l1, l2a.T)))  # propagate hidden -> output
                        ga = (1 - word.code - fa) * alpha  # vector of error gradients multiplied by the learning rate
                        model.syn1[word.point] += outer(ga, l1)  # learn hidden -> output
                        neu1e += dot(ga, l2a)  # save error

                    if model.negative:
                        # use this word (label = 1) + `negative` other random words not from this sentence (label = 0)
                        word_indices = [word.index]
                        while len(word_indices) < model.negative + 1:
                            w = model.table[random.randint(model.table.shape[0])]
                            if w != word.index:
                                word_indices.append(w)
                        l2b = model.syn1neg[word_indices]  # 2d matrix, k+1 x layer1_size
                        fb = 1. / (1. + exp(-dot(l1, l2b.T)))  # propagate hidden -> output
                        gb = (labels - fb) * alpha  # vector of error gradients multiplied by the learning rate
                        model.syn1neg[word_indices] += outer(gb, l1)  # learn hidden -> output
                        neu1e += dot(gb, l2b)  # save error

                    model.syn0[word2.index] += neu1e  # learn input -> hidden

        return len([word for word in sentence if word is not None])
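One subtlety in the loop above: reduced_window (b) is drawn uniformly from [0, window), so the effective context is window - b words on each side. Averaged over draws, this weights context words by their proximity to the center word, matching the behavior of the original word2vec.c. A standalone illustration (the sentence and positions are made up for the demo):

# Demo of the reduced-window context selection used by train_sentence_sg above.
from numpy import random

sentence = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
window = 3
pos = 4                                  # center word: 'jumps'
b = random.randint(window)               # `b` in the original word2vec code
start = max(0, pos - window + b)
context = [w for pos2, w in enumerate(sentence[start : pos + window + 1 - b], start)
           if pos2 != pos]
print(b, context)                        # e.g. b=1 -> ['brown', 'fox', 'over', 'the']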

    def train_sentence_cbow(model, sentence, alpha, work=None, neu1=None):
        """
        Update CBOW model by training on a single sentence.

        The sentence is a list of Vocab objects (or None, where the corresponding
        word is not in the vocabulary). Called internally from `Word2Vec.train()`.

        This is the non-optimized, Python version. If you have cython installed, gensim
        will use the optimized version from word2vec_inner instead.

        """
        if model.negative:
            # precompute negative labels
            labels = zeros(model.negative + 1)
            labels[0] = 1.

        for pos, word in enumerate(sentence):
            if word is None:
                continue  # OOV word in the input sentence => skip
            reduced_window = random.randint(model.window)  # `b` in the original word2vec code
            start = max(0, pos - model.window + reduced_window)
            window_pos = enumerate(sentence[start : pos + model.window + 1 - reduced_window], start)
            word2_indices = [word2.index for pos2, word2 in window_pos if (word2 is not None and pos2 != pos)]
            l1 = np_sum(model.syn0[word2_indices], axis=0)  # 1 x layer1_size
            if word2_indices and model.cbow_mean:
                l1 /= len(word2_indices)
            neu1e = zeros(l1.shape)

            if model.hs:
                l2a = model.syn1[word.point]  # 2d matrix, codelen x layer1_size
                fa = 1. / (1. + exp(-dot(l1, l2a.T)))  # propagate hidden -> output
                ga = (1. - word.code - fa) * alpha  # vector of error gradients multiplied by the learning rate
                model.syn1[word.point] += outer(ga, l1)  # learn hidden -> output
                neu1e += dot(ga, l2a)  # save error

            if model.negative:
                # use this word (label = 1) + `negative` other random words not from this sentence (label = 0)
                word_indices = [word.index]
                while len(word_indices) < model.negative + 1:
                    w = model.table[random.randint(model.table.shape[0])]
                    if w != word.index:
                        word_indices.append(w)
                l2b = model.syn1neg[word_indices]  # 2d matrix, k+1 x layer1_size
                fb = 1. / (1. + exp(-dot(l1, l2b.T)))  # propagate hidden -> output
                gb = (labels - fb) * alpha  # vector of error gradients multiplied by the learning rate
                model.syn1neg[word_indices] += outer(gb, l1)  # learn hidden -> output
                neu1e += dot(gb, l2b)  # save error

            model.syn0[word2_indices] += neu1e  # learn input -> hidden, here for all words in the window separately

        return len([word for word in sentence if word is not None])
# end of plain numpy implementation
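Since the fallback sets FAST_VERSION = -1 while the compiled module exports its own value, callers can check at runtime which implementation was picked. A short usage sketch (the toy corpus is made up; both train_sentence_* functions are internal and are invoked via Word2Vec.train()):

# Check which word2vec implementation gensim selected at import time.
from gensim.models import word2vec

if word2vec.FAST_VERSION == -1:
    print('plain numpy fallback in use (20-80x slower training)')
else:
    print('compiled word2vec_inner extension in use')

# Training goes through the Word2Vec class, which calls train_sentence_sg or
# train_sentence_cbow internally; a toy corpus just to exercise the code path:
sentences = [['first', 'sentence'], ['second', 'sentence']]
model = word2vec.Word2Vec(sentences, min_count=1)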


class Vocab(object):