support both old and new fastText model #1319
Changes from 8 commits
@@ -140,6 +140,7 @@ class FastText(Word2Vec):
    def initialize_word_vectors(self):
        self.wv = FastTextKeyedVectors()
        self.new_format = False

    @classmethod
    def train(cls, ft_path, corpus_file, output_file=None, model='cbow', size=100, alpha=0.025, window=5, min_count=5,
@@ -256,7 +257,14 @@ def load_binary_data(self, model_binary_file):
        self.load_vectors(f)

    def load_model_params(self, file_handle):
        (dim, ws, epoch, minCount, neg, _, loss, model, bucket, minn, maxn, _, t) = self.struct_unpack(file_handle, '@12i1d')
        magic, v= self.struct_unpack(file_handle, '@2i')
        if magic == 793712314: # newer format

I think it'd be good to store this value in a global variable.

            self.new_format = True
            dim, ws, epoch, minCount, neg, _, loss, model, bucket, minn, maxn, _, t = self.struct_unpack(file_handle, '@12i1d')
        else: # older format
Better to set self.new_format = False explicitly here.
            dim = magic
            ws = v
            epoch, minCount, neg, _, loss, model, bucket, minn, maxn, _, t = self.struct_unpack(file_handle, '@10i1d')
        # Parameters stored by [Args::save](https://github.com/facebookresearch/fastText/blob/master/src/args.cc)
        self.size = dim
        self.window = ws
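As an aside, the suggestion above could look something like the following sketch; the constant name FASTTEXT_FILEFORMAT_MAGIC is an assumption for illustration, not part of this diff:

```python
import struct

# Hypothetical module-level constant for the magic number written at the
# start of newer fastText .bin files (the literal checked in the diff above).
FASTTEXT_FILEFORMAT_MAGIC = 793712314

def is_new_format(file_handle):
    """Read the first two int32s of a .bin header (consumes 8 bytes) and
    report whether the file uses the newer format."""
    magic, version = struct.unpack('@2i', file_handle.read(8))
    return magic == FASTTEXT_FILEFORMAT_MAGIC
```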
@@ -270,26 +278,34 @@ def load_model_params(self, file_handle):
        self.wv.max_n = maxn
        self.sample = t

    def load_dict(self, file_handle):
        (vocab_size, nwords, _) = self.struct_unpack(file_handle, '@3i')
    def load_dict(self, file_handle, encoding='utf8'):
        vocab_size, nwords, _ = self.struct_unpack(file_handle, '@3i')

I'd prefer keeping the changes related to the issue with the French wiki in a separate PR. We don't want those changes to block this PR from being merged.

Separated the PRs. Thanks :)
        # Vocab stored by [Dictionary::save](https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc)
        assert len(self.wv.vocab) == nwords, 'mismatch between vocab sizes'
        assert len(self.wv.vocab) == vocab_size, 'mismatch between vocab sizes'
        ntokens, = self.struct_unpack(file_handle, '@q')
        if len(self.wv.vocab) != vocab_size:
            logger.warnings("If you are loading any model other than pretrained vector wiki.fr, ")
            logger.warnings("Please report to gensim or fastText.")
        ntokens= self.struct_unpack(file_handle, '@1q')
        if self.new_format:
            pruneidx_size = self.struct_unpack(file_handle, '@q')
```cpp
void Dictionary::load(std::istream& in) {
  words_.clear();
  std::fill(word2int_.begin(), word2int_.end(), -1);
  in.read((char*) &size_, sizeof(int32_t));
  in.read((char*) &nwords_, sizeof(int32_t));
  in.read((char*) &nlabels_, sizeof(int32_t));
  in.read((char*) &ntokens_, sizeof(int64_t));
  in.read((char*) &pruneidx_size_, sizeof(int64_t));
  for (int32_t i = 0; i < size_; i++) {
    char c;
    entry e;
    while ((c = in.get()) != 0) {
      e.word.push_back(c);
    }
    in.read((char*) &e.count, sizeof(int64_t));
    in.read((char*) &e.type, sizeof(entry_type));
    words_.push_back(e);
    word2int_[find(e.word)] = i;
  }
  pruneidx_.clear();
  for (int32_t i = 0; i < pruneidx_size_; i++) {
    int32_t first;
    int32_t second;
    in.read((char*) &first, sizeof(int32_t));
    in.read((char*) &second, sizeof(int32_t));
    pruneidx_[first] = second;
  }
  initTableDiscard();
  initNgrams();
}
```

I'm not sure if it's present in the models we're loading or not though?

```python
for j in range(pruneidx_size):
    _, _ = self.struct_unpack(file_handle, '@2i')
```

Presence (or absence) of these two lines doesn't affect the code.

I don't see those two lines in the code. Do you mean …

It's -1 actually. Added these two lines in the new commit. I think we should not use an assert statement here, as some models might have non-negative values, so adding these two lines should be sufficient.
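A rough Python counterpart of the dictionary header reads in Dictionary::load above, using the struct module directly; this is a sketch for orientation only, and the helper names are made up:

```python
import struct

def read_dict_header(f, new_format):
    # Mirrors the reads at the top of Dictionary::load above.
    size, nwords, nlabels = struct.unpack('@3i', f.read(12))
    ntokens, = struct.unpack('@q', f.read(8))
    pruneidx_size = -1  # reported as -1 for unpruned models, per the thread
    if new_format:
        pruneidx_size, = struct.unpack('@q', f.read(8))
    return size, nwords, nlabels, ntokens, pruneidx_size

def skip_pruneidx(f, pruneidx_size):
    # The (first, second) int32 pairs are stored after the vocab entries;
    # skipping them matches the snippet in the thread above.
    for _ in range(max(pruneidx_size, 0)):
        struct.unpack('@2i', f.read(8))
```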
        for i in range(nwords):
            word_bytes = b''
            char_byte = file_handle.read(1)
            # Read vocab word
            while char_byte != b'\x00':
                word_bytes += char_byte
                char_byte = file_handle.read(1)
            word = word_bytes.decode('utf8')
            count, _ = self.struct_unpack(file_handle, '@ib')
            _ = self.struct_unpack(file_handle, '@i')
            assert self.wv.vocab[word].index == i, 'mismatch between gensim word index and fastText word index'
            self.wv.vocab[word].count = count
            word = word_bytes.decode(encoding)
            count, _ = self.struct_unpack(file_handle, '@qb')
            if word in self.wv.vocab:
                # skip loading info about words in bin file which are not present in vec file
                # handling mismatch in vocab_size in vec and bin files (ref: wiki.fr)
                assert self.wv.vocab[word].index == i, 'mismatch between gensim word index and fastText word index'
                self.wv.vocab[word].count = count
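Extracted as a standalone helper, the byte-by-byte read above amounts to consuming a NUL-terminated string; a minimal sketch (the helper name is hypothetical, not part of the diff):

```python
def _read_word(file_handle, encoding='utf8'):
    # fastText writes each vocab word as raw bytes followed by a NUL byte,
    # so accumulate until the terminator and then decode.
    word_bytes = b''
    char_byte = file_handle.read(1)
    while char_byte != b'\x00':
        word_bytes += char_byte
        char_byte = file_handle.read(1)
    return word_bytes.decode(encoding)
```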
    def load_vectors(self, file_handle):
        if self.new_format:
            _ = self.struct_unpack(file_handle,'@?')

Can you please add a comment clarifying what this is for?

        num_vectors, dim = self.struct_unpack(file_handle, '@2q')
        # Vectors stored by [Matrix::save](https://github.com/facebookresearch/fastText/blob/master/src/matrix.cc)
        assert self.size == dim, 'mismatch between model sizes'
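If memory serves, the extra bool read here corresponds to the quant_ flag that newer fastText versions write in FastText::saveModel; that is an assumption from reading the C++ source, not something stated in this PR, so the clarifying comment might read something like:

```python
    def load_vectors(self, file_handle):
        if self.new_format:
            # Newer fastText versions write an extra bool before the vectors;
            # assumption: this is the `quant_` flag from FastText::saveModel,
            # indicating whether the stored matrix is quantized.
            _ = self.struct_unpack(file_handle, '@?')
```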
Code style - keep 1 space before and after `=`. Also, `version` would be preferable to `v`.
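For completeness, with this change both old- and new-format .bin files should load through the same wrapper entry point. A usage sketch (the file paths are placeholders; load_fasttext_format is the wrapper's existing classmethod):

```python
from gensim.models.wrappers.fasttext import FastText

# Old-format model: the header starts directly with dim/ws, no magic number.
old_model = FastText.load_fasttext_format('old_model')  # old_model.bin + old_model.vec

# New-format model, e.g. the pretrained wiki.fr vectors discussed above.
new_model = FastText.load_fasttext_format('wiki.fr')
print(new_model.wv['bonjour'][:5])
```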