Merge pull request #758 from ufal/w2cv
Word2vec module
jindrahelcl committed Sep 19, 2018
2 parents 98d9db9 + 62b45e9 commit 5e22640
Showing 8 changed files with 138 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -330,7 +330,7 @@ max-args=8

# Argument names that match this expression will be ignored. Default to name
# with leading underscore
-ignored-argument-names=_.*
+ignored-argument-names=_.*|args|kwargs

# Maximum number of locals for function / method body
max-locals=15
1 change: 1 addition & 0 deletions neuralmonkey/tf_manager.py
@@ -88,6 +88,7 @@ def __init__(self,
        init_op = tf.global_variables_initializer()
        for sess in self.sessions:
            sess.run(init_op)
+
        self.saver = tf.train.Saver(max_to_keep=None,
                                    var_list=[g for g in tf.global_variables()
                                              if "reward_" not in g.name])
1 change: 1 addition & 0 deletions neuralmonkey/util/__init__.py
@@ -0,0 +1 @@
from .word2vec import Word2Vec, get_word2vec_initializer, word2vec_vocabulary
95 changes: 95 additions & 0 deletions neuralmonkey/util/word2vec.py
@@ -0,0 +1,95 @@
"""Word2vec plug-in module.
This module provides functionality needed to work with word2vec files.
"""

from typing import Callable, List

import numpy as np
from typeguard import check_argument_types

from neuralmonkey.vocabulary import (
Vocabulary, is_special_token, SPECIAL_TOKENS)


class Word2Vec:

def __init__(self, path: str, encoding: str = "utf-8") -> None:
"""Load the word2vec file."""
check_argument_types()

# Create the vocabulary object, load the words and vectors from the
# file

self.vocab = Vocabulary()
embedding_vectors = [] # type: List[np.ndarray]

with open(path, encoding=encoding) as f_data:

header = next(f_data)
emb_size = int(header.split()[1])

# Add zero embeddings for padding, start, and end token
embedding_vectors.append(np.zeros(emb_size))
embedding_vectors.append(np.zeros(emb_size))
embedding_vectors.append(np.zeros(emb_size))
# Add placeholder for embedding of the unknown symbol
embedding_vectors.append(None)

for line in f_data:
fields = line.split()
word = fields[0]
vector = np.fromiter((float(x) for x in fields[1:]),
dtype=np.float)

assert vector.shape[0] == emb_size

# Embedding of unknown token should be at index 3 to match the
# vocabulary implementation
if is_special_token(word):
embedding_vectors[SPECIAL_TOKENS.index(word)] = vector
else:
self.vocab.add_word(word)
embedding_vectors.append(vector)

assert embedding_vectors[3] is not None
assert emb_size is not None

self.embedding_matrix = np.stack(embedding_vectors)

@property
def vocabulary(self) -> Vocabulary:
"""Get a vocabulary object generated from this word2vec instance."""
return self.vocab

@property
def embeddings(self) -> np.ndarray:
"""Get the embedding matrix."""
return self.embedding_matrix


def get_word2vec_initializer(w2v: Word2Vec) -> Callable:
"""Create a word2vec initializer.
A higher-order function that can be called from configuration.
"""
check_argument_types()

def init(shape: List[int], **kwargs) -> np.ndarray:
if shape != list(w2v.embeddings.shape):
raise ValueError(
"Shapes of model and word2vec embeddings do not match. "
"Word2Vec shape: {}, Should have been: {}"
.format(w2v.embeddings.shape, shape))
return w2v.embeddings

return init


def word2vec_vocabulary(w2v: Word2Vec) -> Vocabulary:
"""Return the vocabulary from a word2vec object.
This is a helper method used from configuration.
"""
check_argument_types()
return w2v.vocabulary
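
For orientation, here is a minimal sketch (not part of the commit) of how the new module can be exercised directly from Python, using the sample embedding file added further down in tests/data/sample.w2v. It assumes the repository layout from this commit and is run from the repository root; the exact number of vocabulary entries depends on how Vocabulary pre-populates its special tokens.

# Illustrative use of the new neuralmonkey.util.word2vec module.
from neuralmonkey.util import Word2Vec, get_word2vec_initializer

# Load the 5-dimensional sample embeddings shipped with the tests.
w2v = Word2Vec("tests/data/sample.w2v")

# One embedding row per vocabulary entry: the four special tokens
# (<pad>, <s>, </s>, <unk>) plus every non-special word from the file.
print(w2v.embeddings.shape)   # expected: (num_entries, 5)
print(len(w2v.vocabulary))    # should equal num_entries

# The initializer only accepts a shape identical to the embedding matrix,
# which is why tests/language-model.ini below sets embedding_size=5.
init = get_word2vec_initializer(w2v)
embeddings = init(list(w2v.embeddings.shape))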
14 changes: 7 additions & 7 deletions neuralmonkey/vocabulary.py
@@ -25,15 +25,15 @@
END_TOKEN = "</s>"
UNK_TOKEN = "<unk>"

-_SPECIAL_TOKENS = [PAD_TOKEN, START_TOKEN, END_TOKEN, UNK_TOKEN]
+SPECIAL_TOKENS = [PAD_TOKEN, START_TOKEN, END_TOKEN, UNK_TOKEN]

PAD_TOKEN_INDEX = 0
START_TOKEN_INDEX = 1
END_TOKEN_INDEX = 2
UNK_TOKEN_INDEX = 3


-def _is_special_token(word: str) -> bool:
+def is_special_token(word: str) -> bool:
    """Check whether word is a special token (such as <pad> or <s>).

    Arguments:
@@ -42,7 +42,7 @@ def _is_special_token(word: str) -> bool:
    Returns:
        True if the word is special, False otherwise.
    """
-    return word in _SPECIAL_TOKENS
+    return word in SPECIAL_TOKENS


# pylint: disable=unused-argument
@@ -295,7 +295,7 @@ def __init__(self, tokenized_text: List[str] = None,
        self.word_to_index = {}  # type: Dict[str, int]
        self.index_to_word = []  # type: List[str]
        self.word_count = {}  # type: Dict[str, int]
-        self.alphabet = {tok for tok in _SPECIAL_TOKENS}
+        self.alphabet = {tok for tok in SPECIAL_TOKENS}

        # flag if the word count are in use
        self.correct_counts = False
@@ -340,7 +340,7 @@ def add_word(self, word: str, occurences: int = 1) -> None:
            self.word_to_index[word] = len(self.index_to_word)
            self.index_to_word.append(word)
            self.word_count[word] = 0
-            if word not in _SPECIAL_TOKENS:
+            if not is_special_token(word):
                self.add_characters(word)
        self.word_count[word] += occurences

@@ -424,7 +424,7 @@ def truncate(self, size: int) -> None:
        for word in words_by_freq:
            if len(words_to_delete) == to_delete:
                break
-            if not _is_special_token(word):
+            if not is_special_token(word):
                words_to_delete.append(word)

        # sort by index ... bigger indices needs to be removed first
@@ -452,7 +452,7 @@ def truncate_by_min_freq(self, min_freq: int) -> None:
        # ignoring special tokens
        infreq_word_count = sum([1 for w in self.word_count
                                 if self.word_count[w] < min_freq
-                                 and not _is_special_token(w)])
+                                 and not is_special_token(w)])
        log("Removing {} infrequent (<{}) words from vocabulary".format(
            infreq_word_count, min_freq))
        new_size = len(self) - infreq_word_count
2 changes: 1 addition & 1 deletion tests/bpe.ini
@@ -54,7 +54,7 @@ class=vocabulary.from_dataset
datasets=[<train_data>]
series_ids=["source_bpe", "target_bpe"]
max_size=209
save_file="tests/output/bpe_vocabulary.pickle"
save_file="tests/outputs/bpe_vocabulary.pickle"
overwrite=True

[encoder_input]
20 changes: 20 additions & 0 deletions tests/data/sample.w2v
@@ -0,0 +1,20 @@
19 5
</s> 0.001334 0.001473 -0.001277 -0.001093 0.000456
the 0.076659 -0.080726 -0.019973 0.107462 -0.001254
, 0.244330 -0.123618 -0.140095 0.069506 -0.157485
. 0.155178 -0.112344 -0.096781 0.072533 -0.158435
<unk> 0.250020 -0.214136 -0.157372 0.202529 -0.043517
to 0.239173 0.068217 -0.120347 0.012056 -0.177319
of 0.213521 -0.094603 -0.038220 0.144919 0.074206
and 0.275814 -0.116529 -0.074056 0.017852 -0.082605
a 0.185463 -0.024054 -0.041432 0.047992 -0.136972
in 0.113042 -0.098896 -0.031943 0.067912 -0.062604
&quot; 0.021184 -0.133253 -0.212236 0.150041 -0.139301
@-@ 0.195248 -0.109222 -0.044442 0.014353 -0.233359
&apos;s 0.165106 -0.007308 0.050040 0.076553 -0.062446
that 0.110352 -0.056008 -0.089858 0.127519 -0.024820
for 0.248112 -0.001108 -0.045288 0.056190 0.031153
is 0.124510 -0.045628 0.037280 0.154971 -0.127219
on 0.090296 -0.071128 -0.041744 0.118817 -0.137508
it 0.129336 0.001850 -0.059334 0.132171 -0.151346
was 0.148002 0.052893 0.127481 -0.001830 -0.091121
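
The sample file uses the plain-text word2vec layout that the loader above expects: a header line "<word count> <dimensionality>", then one word and its vector components per line. Because these vectors are 5-dimensional, the decoder's embedding_size in tests/language-model.ini changes from 9 to 5 below. A quick, purely illustrative consistency check of the format (not part of the commit):

# Illustrative sanity check of the plain-text word2vec layout.
with open("tests/data/sample.w2v", encoding="utf-8") as f_data:
    num_words, dim = (int(x) for x in next(f_data).split())
    rows = [line.split() for line in f_data]

assert num_words == len(rows)                    # 19 rows in this sample
assert all(len(row) == 1 + dim for row in rows)  # word + 5 components each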
19 changes: 12 additions & 7 deletions tests/language-model.ini
@@ -28,27 +28,32 @@ name="perplexity"

[train_data]
class=dataset.load_dataset_from_files
s_target="tests/data/train.tc.de"
s_target="tests/data/train.tc.en"
lazy=True

[val_data]
class=dataset.load_dataset_from_files
s_target="tests/data/val.tc.de"
s_target="tests/data/val.tc.en"

-[decoder_vocabulary]
-class=vocabulary.from_wordlist
-path="tests/outputs/vocab/decoder_vocab.tsv"
+[word2vec]
+class=util.Word2Vec
+path="tests/data/sample.w2v"
+
+[w2v_init]
+class=util.get_word2vec_initializer
+w2v=<word2vec>

[decoder]
class=decoders.decoder.Decoder
name="decoder"
encoders=[]
rnn_size=8
-embedding_size=9
+embedding_size=5
dropout_keep_prob=0.5
data_id="target"
max_output_len=10
-vocabulary=<decoder_vocabulary>
+vocabulary=<word2vec.vocabulary>
+initializers=[("word_embeddings", <w2v_init>)]

[trainer]
class=trainers.cross_entropy_trainer.CrossEntropyTrainer
