Skip to content

Commit

Permalink
Removed redundant code and modified embedding_modules file (#655)
Browse files · Browse the repository at this point in the history
  • Loading branch information
soovam123 committed Mar 17, 2020
1 parent 567bdce commit 5ffe45c
Showing 1 changed file with 73 additions and 72 deletions.
145 changes: 73 additions & 72 deletions ludwig/models/modules/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,46 @@ def embedding_matrix(
return embeddings, embedding_size


def embedding_matrix_on_device(
        vocab,
        embedding_size,
        regularizer,
        representation='dense',
        embeddings_trainable=True,
        pretrained_embeddings=None,
        force_embedding_size=False,
        embeddings_on_cpu=False,
        initializer=None
):
    """Build an embedding matrix, optionally pinned to CPU memory.

    Thin wrapper around ``embedding_matrix`` that places the variable on
    ``/cpu:0`` when ``embeddings_on_cpu`` is True (useful when the vocabulary
    is too large for GPU memory); otherwise placement is left to TensorFlow.

    :param vocab: vocabulary passed through to ``embedding_matrix``.
    :param embedding_size: requested embedding dimensionality.
    :param regularizer: regularizer applied to the embedding variable.
    :param representation: 'dense' or the other representations supported
        by ``embedding_matrix``.
    :param embeddings_trainable: whether the embeddings are trainable.
    :param pretrained_embeddings: optional pretrained embeddings to load.
    :param force_embedding_size: force the requested embedding size.
    :param embeddings_on_cpu: if True, create the variable under
        ``tf.device('/cpu:0')``.
    :param initializer: initializer for the embedding variable.
    :return: tuple ``(embeddings, embedding_size)`` as returned by
        ``embedding_matrix`` (embedding_size may differ from the request).
    """
    # Single construction path; the only difference between the CPU and
    # default cases is the device context, so avoid duplicating the call.
    def _build():
        return embedding_matrix(
            vocab,
            embedding_size,
            representation=representation,
            embeddings_trainable=embeddings_trainable,
            pretrained_embeddings=pretrained_embeddings,
            force_embedding_size=force_embedding_size,
            initializer=initializer,
            regularizer=regularizer
        )

    if embeddings_on_cpu:
        with tf.device('/cpu:0'):
            embeddings, embedding_size = _build()
    else:
        embeddings, embedding_size = _build()

    logger.debug('  embeddings: {0}'.format(embeddings))

    return embeddings, embedding_size


class Embed:
def __init__(
self,
Expand Down Expand Up @@ -118,30 +158,17 @@ def __call__(
if not self.regularize:
regularizer = None

if self.embeddings_on_cpu:
with tf.device('/cpu:0'):
embeddings, embedding_size = embedding_matrix(
self.vocab,
self.embedding_size,
representation=self.representation,
embeddings_trainable=self.embeddings_trainable,
pretrained_embeddings=self.pretrained_embeddings,
force_embedding_size=self.force_embedding_size,
initializer=self.initializer,
regularizer=regularizer
)
else:
embeddings, embedding_size = embedding_matrix(
self.vocab,
self.embedding_size,
representation=self.representation,
embeddings_trainable=self.embeddings_trainable,
pretrained_embeddings=self.pretrained_embeddings,
force_embedding_size=self.force_embedding_size,
initializer=self.initializer,
regularizer=regularizer
)
logger.debug(' embeddings: {0}'.format(embeddings))
embeddings, embedding_size = embedding_matrix_on_device(
self.vocab,
self.embedding_size,
regularizer,
self.representation,
self.embeddings_trainable,
self.pretrained_embeddings,
self.force_embedding_size,
self.embeddings_on_cpu,
self.initializer
)

embedded = tf.nn.embedding_lookup(embeddings, input_ids,
name='embeddings_lookup')
Expand Down Expand Up @@ -190,30 +217,17 @@ def __call__(
if not self.regularize:
regularizer = None

if self.embeddings_on_cpu:
with tf.device('/cpu:0'):
embeddings, embedding_size = embedding_matrix(
self.vocab,
self.embedding_size,
representation=self.representation,
embeddings_trainable=self.embeddings_trainable,
pretrained_embeddings=self.pretrained_embeddings,
force_embedding_size=self.force_embedding_size,
initializer=self.initializer,
regularizer=regularizer
)
else:
embeddings, embedding_size = embedding_matrix(
self.vocab,
self.embedding_size,
representation=self.representation,
embeddings_trainable=self.embeddings_trainable,
pretrained_embeddings=self.pretrained_embeddings,
force_embedding_size=self.force_embedding_size,
initializer=self.initializer,
regularizer=regularizer
)
logger.debug(' embeddings: {0}'.format(embeddings))
embeddings, embedding_size = embedding_matrix_on_device(
self.vocab,
self.embedding_size,
regularizer,
self.representation,
self.embeddings_trainable,
self.pretrained_embeddings,
self.force_embedding_size,
self.embeddings_on_cpu,
self.initializer
)

signed_input = tf.cast(tf.sign(tf.abs(input_ids)), tf.int32)
multiple_hot_indexes = tf.multiply(
Expand Down Expand Up @@ -280,30 +294,17 @@ def __call__(
if not self.regularize:
regularizer = None

if self.embeddings_on_cpu:
with tf.device('/cpu:0'):
embeddings, embedding_size = embedding_matrix(
self.vocab,
self.embedding_size,
representation=self.representation,
embeddings_trainable=self.embeddings_trainable,
pretrained_embeddings=self.pretrained_embeddings,
force_embedding_size=self.force_embedding_size,
initializer=self.initializer,
regularizer=regularizer
)
else:
embeddings, embedding_size = embedding_matrix(
self.vocab,
self.embedding_size,
representation=self.representation,
embeddings_trainable=self.embeddings_trainable,
pretrained_embeddings=self.pretrained_embeddings,
force_embedding_size=self.force_embedding_size,
initializer=self.initializer,
regularizer=regularizer
)
logger.debug(' embeddings: {0}'.format(embeddings))
embeddings, embedding_size = embedding_matrix_on_device(
self.vocab,
self.embedding_size,
regularizer,
self.representation,
self.embeddings_trainable,
self.pretrained_embeddings,
self.force_embedding_size,
self.embeddings_on_cpu,
self.initializer
)

multiple_hot_indexes = tf.multiply(
input_sparse,
Expand Down

0 comments on commit 5ffe45c

Please sign in to comment.