Skip to content

Commit

Permalink
refactor: start of tagger decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
jimthompson5802 committed Apr 14, 2020
1 parent 4e8dca2 commit ecfa703
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions ludwig/models/modules/sequence_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,15 @@ def __init__(
self.regularize = regularize
self.attention = attention

def __call__(
def call(
self,
output_feature,
targets,
hidden,
hidden_size,
regularizer,
is_timeseries=False
**kwargs
# output_feature,
# targets,
# hidden_size,
# regularizer,
# is_timeseries=False
):
logger.debug(' hidden shape: {0}'.format(hidden.shape))
if len(hidden.shape) != 3:
Expand Down Expand Up @@ -215,20 +216,20 @@ def __call__(
)
logger.debug(' logits: {0}'.format(logits))

if is_timeseries:
probabilities_sequence = tf.zeros_like(logits)
predictions_sequence = tf.reshape(logits, [-1, sequence_length])
else:
probabilities_sequence = tf.nn.softmax(
logits,
name='probabilities_{}'.format(output_feature['name'])
)
predictions_sequence = tf.argmax(
logits,
-1,
name='predictions_{}'.format(output_feature['name']),
output_type=tf.int32
)
# if is_timeseries:
# probabilities_sequence = tf.zeros_like(logits)
# predictions_sequence = tf.reshape(logits, [-1, sequence_length])
# else:
# probabilities_sequence = tf.nn.softmax(
# logits,
# name='probabilities_{}'.format(output_feature['name'])
# )
# predictions_sequence = tf.argmax(
# logits,
# -1,
# name='predictions_{}'.format(output_feature['name']),
# output_type=tf.int32
# )

predictions_sequence_length = sequence_length_3D(hidden)

Expand Down

0 comments on commit ecfa703

Please sign in to comment.