Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #758 from ufal/w2cv
Word2vec module
- Loading branch information
Showing
8 changed files
with
138 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .word2vec import Word2Vec, get_word2vec_initializer, word2vec_vocabulary |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
"""Word2vec plug-in module. | ||
This module provides functionality needed to work with word2vec files. | ||
""" | ||
|
||
from typing import Callable, List | ||
|
||
import numpy as np | ||
from typeguard import check_argument_types | ||
|
||
from neuralmonkey.vocabulary import ( | ||
Vocabulary, is_special_token, SPECIAL_TOKENS) | ||
|
||
|
||
class Word2Vec:
    """Vocabulary and embedding matrix loaded from a word2vec text file.

    The first line of the file is a header whose second field is read as the
    embedding size. Every following line holds a word followed by the
    components of its vector.
    """

    def __init__(self, path: str, encoding: str = "utf-8") -> None:
        """Load the word2vec file.

        Arguments:
            path: Path to the word2vec file in the text format.
            encoding: Encoding used to open the file.

        Raises:
            AssertionError: If a vector does not have the size announced in
                the header, or if the file provides no embedding for the
                unknown token.
        """
        check_argument_types()

        # Create the vocabulary object, load the words and vectors from the
        # file
        self.vocab = Vocabulary()
        embedding_vectors = []  # type: List[np.ndarray]

        with open(path, encoding=encoding) as f_data:
            # Only the second header field (the embedding size) is used.
            header = next(f_data)
            emb_size = int(header.split()[1])

            # Add zero embeddings for padding, start, and end token
            embedding_vectors.extend(np.zeros(emb_size) for _ in range(3))
            # Add placeholder for embedding of the unknown symbol
            embedding_vectors.append(None)

            for line in f_data:
                fields = line.split()
                if not fields:
                    # Skip blank lines (e.g. trailing newlines) instead of
                    # failing on the missing word field.
                    continue

                word = fields[0]
                # The builtin ``float`` is what the removed ``np.float``
                # alias referred to, so the resulting dtype is unchanged.
                vector = np.fromiter((float(x) for x in fields[1:]),
                                     dtype=float)

                assert vector.shape[0] == emb_size, (
                    "Vector size does not match the size in the header")

                # Embedding of unknown token should be at index 3 to match
                # the vocabulary implementation
                if is_special_token(word):
                    embedding_vectors[SPECIAL_TOKENS.index(word)] = vector
                else:
                    self.vocab.add_word(word)
                    embedding_vectors.append(vector)

        assert embedding_vectors[3] is not None, (
            "The file contains no embedding for the unknown token")

        self.embedding_matrix = np.stack(embedding_vectors)

    @property
    def vocabulary(self) -> Vocabulary:
        """Get a vocabulary object generated from this word2vec instance."""
        return self.vocab

    @property
    def embeddings(self) -> np.ndarray:
        """Get the embedding matrix."""
        return self.embedding_matrix
|
||
|
||
def get_word2vec_initializer(w2v: Word2Vec) -> Callable:
    """Create a word2vec initializer.

    A higher-order function that can be called from configuration.

    Arguments:
        w2v: The word2vec object whose embedding matrix the initializer
            returns.

    Returns:
        A function that takes the requested shape (and ignores any further
        keyword arguments) and returns the embedding matrix of ``w2v``.
    """
    check_argument_types()

    def init(shape: List[int], **kwargs) -> np.ndarray:
        """Return the word2vec embeddings after validating the shape.

        Raises:
            ValueError: If ``shape`` differs from the embedding matrix shape.
        """
        # Normalize both sides to lists so that an equal shape passed as a
        # tuple is not rejected merely because of its container type.
        if list(shape) != list(w2v.embeddings.shape):
            raise ValueError(
                "Shapes of model and word2vec embeddings do not match. "
                "Word2Vec shape: {}, Should have been: {}"
                .format(w2v.embeddings.shape, shape))
        return w2v.embeddings

    return init
|
||
|
||
def word2vec_vocabulary(w2v: Word2Vec) -> Vocabulary:
    """Extract the vocabulary held by a word2vec object.

    A small helper meant to be referenced from configuration.

    Arguments:
        w2v: The word2vec object to read the vocabulary from.

    Returns:
        The value of the ``vocabulary`` property of ``w2v``.
    """
    check_argument_types()
    vocabulary = w2v.vocabulary
    return vocabulary
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
19 5 | ||
</s> 0.001334 0.001473 -0.001277 -0.001093 0.000456 | ||
the 0.076659 -0.080726 -0.019973 0.107462 -0.001254 | ||
, 0.244330 -0.123618 -0.140095 0.069506 -0.157485 | ||
. 0.155178 -0.112344 -0.096781 0.072533 -0.158435 | ||
<unk> 0.250020 -0.214136 -0.157372 0.202529 -0.043517 | ||
to 0.239173 0.068217 -0.120347 0.012056 -0.177319 | ||
of 0.213521 -0.094603 -0.038220 0.144919 0.074206 | ||
and 0.275814 -0.116529 -0.074056 0.017852 -0.082605 | ||
a 0.185463 -0.024054 -0.041432 0.047992 -0.136972 | ||
in 0.113042 -0.098896 -0.031943 0.067912 -0.062604 | ||
" 0.021184 -0.133253 -0.212236 0.150041 -0.139301 | ||
@-@ 0.195248 -0.109222 -0.044442 0.014353 -0.233359 | ||
's 0.165106 -0.007308 0.050040 0.076553 -0.062446 | ||
that 0.110352 -0.056008 -0.089858 0.127519 -0.024820 | ||
for 0.248112 -0.001108 -0.045288 0.056190 0.031153 | ||
is 0.124510 -0.045628 0.037280 0.154971 -0.127219 | ||
on 0.090296 -0.071128 -0.041744 0.118817 -0.137508 | ||
it 0.129336 0.001850 -0.059334 0.132171 -0.151346 | ||
was 0.148002 0.052893 0.127481 -0.001830 -0.091121 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters