refactor: working h3 weighted sum encoder
jimthompson5802 committed Jun 12, 2020
1 parent faa4d66 commit bea03e5
62 changes: 28 additions & 34 deletions ludwig/models/modules/h3_encoders.py
@@ -23,6 +23,7 @@
 from ludwig.models.modules.fully_connected_modules import FCStack
 from ludwig.models.modules.recurrent_modules import RecurrentStack
 from ludwig.models.modules.reduction_modules import reduce_sum, reduce_sequence
+from ludwig.models.modules.initializer_modules import get_initializer
 
 logger = logging.getLogger(__name__)
 
@@ -246,7 +247,7 @@ def call(
         return {'encoder_output': hidden}
 
 
-class H3WeightedSum:
+class H3WeightedSum(Layer):
 
     def __init__(
             self,
@@ -308,6 +309,8 @@ def __init__(
                is greater than 0).
         :type regularize: Boolean
         """
+        super(H3WeightedSum, self).__init__()
+
         self.should_softmax = should_softmax
 
         self.h3_embed = H3Embed(
@@ -316,13 +319,13 @@ def __init__(
             dropout=dropout_rate,
             initializer=weights_initializer,
             regularize=weights_regularizer,
-            reduce_output=None,
+            reduce_output='None',
         )
 
-        self.weights = tf.get_variable(
-            'weights',
-            [19, 1],
-            initializer=weights_initializer
+        # renamed from self.weights to self.h3_weights due to an
+        # apparent conflict with the `weights` attribute in the superclass
+        self.h3_weights = tf.Variable(
+            get_initializer(weights_initializer)([19, 1])
         )
 
         self.fc_stack = FCStack(
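As an aside on the weight-creation change above: the TF1 `tf.get_variable(name, shape, initializer=...)` call is replaced by the TF2 idiom of materializing an initial value and wrapping it in a plain `tf.Variable`. A minimal sketch of that pattern, using `tf.keras.initializers.get` as a stand-in for Ludwig's `get_initializer` helper (whose exact behavior isn't shown in this diff):

import tensorflow as tf

# TF2 pattern: an initializer is a callable that, given a shape,
# returns a concrete tensor; tf.Variable then wraps that value.
initializer = tf.keras.initializers.get('glorot_uniform')
h3_weights = tf.Variable(initializer([19, 1]))

print(h3_weights.shape)  # (19, 1) -- one scalar weight per H3 component

The [19, 1] shape gives one learnable scalar per embedded H3 component, which later broadcasts against the embedding dimension in the weighted sum.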
@@ -343,56 +346,47 @@ def __init__(
             default_dropout_rate=dropout_rate,
         )
 
-    def __call__(
+    def call(
             self,
-            input_vector,
-            regularizer,
-            dropout_rate,
-            is_training=True
+            inputs,
+            training=None,
+            mask=None
     ):
         """
         :param input_vector: The input vector fed into the encoder.
                Shape: [batch x 19], type tf.int8
         :type input_vector: Tensor
-        :param regularizer: The regularizer to use for the weights
-               of the encoder.
-        :type regularizer:
-        :param dropout_rate: Tensor (tf.float) of the probability of dropout
-        :type dropout_rate: Tensor
-        :param is_training: Tensor (tf.bool) specifying if in training mode
-               (important for dropout)
-        :type is_training: Tensor
-        """
+        :param training: bool specifying if in training mode (important for dropout)
+        :type training: bool
+        :param mask: bool specifying masked values
+        :type mask: bool
+        """
         # ================ Embeddings ================
-        embedded_h3, embedding_size = self.h3_embed(
+        input_vector = inputs
+        embedded_h3 = self.h3_embed(
             input_vector,
-            regularizer,
-            dropout_rate,
-            is_training=is_training
+            training=training,
+            mask=mask
         )
 
         # ================ Weighted Sum ================
         if self.should_softmax:
-            weights = tf.nn.softmax(self.weights)
+            weights = tf.nn.softmax(self.h3_weights)
         else:
-            weights = self.weights
+            weights = self.h3_weights
 
-        hidden = reduce_sum(embedded_h3 * weights)
+        hidden = reduce_sum(embedded_h3['encoder_output'] * weights)
 
         # ================ FC Stack ================
-        hidden_size = hidden.shape.as_list()[-1]
         logger.debug('  flatten hidden: {0}'.format(hidden))
 
         hidden = self.fc_stack(
             hidden,
-            hidden_size,
-            regularizer=regularizer,
-            dropout_rate=dropout_rate,
-            is_training=is_training
+            training=training,
+            mask=mask
         )
-        hidden_size = hidden.shape.as_list()[-1]
 
-        return hidden, hidden_size
+        return {'encoder_output': hidden}
 
 
 class H3RNN:
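Taken together, the hunks above move the encoder onto the TF2 Keras protocol: subclass `Layer`, call `super().__init__()`, accept `(inputs, training, mask)` in `call`, and return an `{'encoder_output': ...}` dict like the other encoders in this file. Below is a self-contained sketch of the resulting weighted-sum behavior, with a plain `Embedding` standing in for `H3Embed` and the FC stack omitted; all names and sizes here are illustrative, not Ludwig's:

import tensorflow as tf


class ToyWeightedSum(tf.keras.layers.Layer):
    """Simplified stand-in for H3WeightedSum: embed 19 integer
    components, then collapse them with one learned weight per slot."""

    def __init__(self, vocab_size=512, embedding_size=10, should_softmax=True):
        super(ToyWeightedSum, self).__init__()
        self.should_softmax = should_softmax
        self.embed = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.slot_weights = tf.Variable(tf.random.normal([19, 1]))

    def call(self, inputs, training=None, mask=None):
        embedded = self.embed(inputs)  # [batch, 19, embedding_size]
        if self.should_softmax:
            # Normalize over the 19 slots (axis 0); note that tf.nn.softmax's
            # default axis of -1 would turn every size-1 row of a
            # [19, 1] tensor into 1.0.
            weights = tf.nn.softmax(self.slot_weights, axis=0)
        else:
            weights = self.slot_weights
        # [19, 1] broadcasts against [batch, 19, embedding_size]
        hidden = tf.reduce_sum(embedded * weights, axis=1)
        return {'encoder_output': hidden}


encoder = ToyWeightedSum()
output = encoder(tf.zeros([4, 19], dtype=tf.int32))
print(output['encoder_output'].shape)  # (4, 10)

The dict-style return is what lets downstream combiners treat every encoder uniformly, at the cost of the explicit indexing seen above in `embedded_h3['encoder_output']`.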
