Skip to content

Commit

Permalink
Merge pull request #466 from gojomo/unicode_err_tolerance
Browse files Browse the repository at this point in the history
{load|intersect}_word2vec_format: allow non-strict unicode error handling
  • Loading branch information
piskvorky committed Sep 24, 2015
2 parents 485d024 + a8a8f21 commit 022bb30
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def save_word2vec_format(self, fname, fvocab=None, binary=False):
fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join("%f" % val for val in row))))

@classmethod
def load_word2vec_format(cls, fname, fvocab=None, binary=False, norm_only=True, encoding='utf8'):
def load_word2vec_format(cls, fname, fvocab=None, binary=False, norm_only=True, encoding='utf8', unicode_errors='strict'):
"""
Load the input-hidden weight matrix from the original C word2vec-tool format.
Expand Down Expand Up @@ -1051,12 +1051,12 @@ def add_word(word, weights):
break
if ch != b'\n': # ignore newlines in front of words (some binary files have)
word.append(ch)
word = utils.to_unicode(b''.join(word), encoding=encoding)
word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors)
weights = fromstring(fin.read(binary_len), dtype=REAL)
add_word(word, weights)
else:
for line_no, line in enumerate(fin):
parts = utils.to_unicode(line.rstrip(), encoding=encoding).split(" ")
parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ")
if len(parts) != vector_size + 1:
raise ValueError("invalid vector on line %s (is this really the text format?)" % (line_no))
word, weights = parts[0], list(map(REAL, parts[1:]))
Expand All @@ -1073,7 +1073,7 @@ def add_word(word, weights):
result.init_sims(norm_only)
return result

def intersect_word2vec_format(self, fname, binary=False, encoding='utf8'):
def intersect_word2vec_format(self, fname, binary=False, encoding='utf8', unicode_errors='strict'):
"""
Merge the input-hidden weight matrix from the original C word2vec-tool format
given, where it intersects with the current vocabulary. (No words are added to the
Expand Down Expand Up @@ -1101,15 +1101,15 @@ def intersect_word2vec_format(self, fname, binary=False, encoding='utf8'):
break
if ch != b'\n': # ignore newlines in front of words (some binary files have)
word.append(ch)
word = utils.to_unicode(b''.join(word), encoding=encoding)
word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors)
weights = fromstring(fin.read(binary_len), dtype=REAL)
if word in self.vocab:
overlap_count += 1
self.syn0[self.vocab[word].index] = weights
self.syn0_lockf[self.vocab[word].index] = 0.0 # lock it
else:
for line_no, line in enumerate(fin):
parts = utils.to_unicode(line.rstrip(), encoding=encoding).split(" ")
parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ")
if len(parts) != vector_size + 1:
raise ValueError("invalid vector on line %s (is this really the text format?)" % (line_no))
word, weights = parts[0], list(map(REAL, parts[1:]))
Expand Down

0 comments on commit 022bb30

Please sign in to comment.