Commit

Removed superfluous wrapping of decoder_input_state into list + minor renaming
w4nderlust committed Jun 8, 2020
1 parent fd71433 commit e913caa
Showing 1 changed file with 34 additions and 37 deletions.
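
In practical terms, the renamed prepare_encoder_output_state (previously _prepare_decoder_input_state) now hands back the prepared encoder state as-is: a single [batch_size, state_size] tensor for rnn/gru decoder cells, or a two-element list of such tensors for lstm cells, instead of always wrapping the result in a list. A minimal sketch of the old and new return conventions, using made-up placeholder tensors rather than code from the repository:

import tensorflow as tf

batch_size, state_size = 4, 256
h = tf.zeros([batch_size, state_size])  # placeholder "hidden" state
c = tf.zeros([batch_size, state_size])  # placeholder "cell" state

# Before this commit the prepared state was always wrapped in a list,
# e.g. [h] even for a gru/rnn decoder cell.
old_style_gru_state = [h]

# After this commit the state is passed back unwrapped: a plain tensor
# for gru/rnn cells, and an [h, c] pair only when the decoder is an lstm.
new_style_gru_state = h
new_style_lstm_state = [h, c]
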
ludwig/models/modules/sequence_decoders.py (34 additions, 37 deletions)
@@ -124,7 +124,7 @@ def __init__(

     def _logits_training(self, inputs, target, training=None):
         input = inputs['hidden']  # shape [batch_size, seq_size, state_size]
-        encoder_end_state = self._prepare_decoder_input_state(inputs)
+        encoder_end_state = self.prepare_encoder_output_state(inputs)

         logits = self.decoder_teacher_forcing(
             input,
@@ -133,29 +133,29 @@ def _logits_training(self, inputs, target, training=None):
         )
         return logits  # shape = [b, s, c]

-    def _prepare_decoder_input_state(self, inputs):
+    def prepare_encoder_output_state(self, inputs):

         if 'encoder_output_state' in inputs:
-            decoder_input_state = inputs['encoder_output_state']
+            encoder_output_state = inputs['encoder_output_state']
         else:
             hidden = inputs['hidden']
             if len(hidden.shape) == 3:  # encoder_output is a sequence
                 # reduce_sequence returns a [b, h]
-                decoder_input_state = reduce_sequence(
+                encoder_output_state = reduce_sequence(
                     hidden,
                     self.reduce_input if self.reduce_input else 'sum'
                 )
             elif len(hidden.shape) == 2:
                 # this returns a [b, h]
-                decoder_input_state = hidden
+                encoder_output_state = hidden
             else:
                 raise ValueError("Only works for 1d or 2d encoder_output")

         # now we have to deal with the fact that the state needs to be a list
         # in case of lstm or a tensor otherwise
         if (self.cell_type == 'lstm' and
-                isinstance(decoder_input_state, list)):
-            if len(decoder_input_state) == 2:
+                isinstance(encoder_output_state, list)):
+            if len(encoder_output_state) == 2:
                 # this may be a unidirectional lstm or a bidirectional gru / rnn
                 # there is no way to tell
                 # If it is a unidirectional lstm, pass will work fine
@@ -164,15 +164,15 @@ def _prepare_decoder_input_state(self, inputs):
                 # which is weird and may lead to poor performance
                 # todo try to find a way to distinguish among these two cases
                 pass
-            elif len(decoder_input_state) == 4:
+            elif len(encoder_output_state) == 4:
                 # the encoder was a bidirectional lstm
                 # a good strategy is to average the 2 h and the 2 c vectors
-                decoder_input_state = [
+                encoder_output_state = [
                     average(
-                        [decoder_input_state[0], decoder_input_state[2]]
+                        [encoder_output_state[0], encoder_output_state[2]]
                     ),
                     average(
-                        [decoder_input_state[1], decoder_input_state[3]]
+                        [encoder_output_state[1], encoder_output_state[3]]
                     )
                 ]
             else:
@@ -183,25 +183,25 @@ def _prepare_decoder_input_state(self, inputs):
                 # "encoder_output_state has length different than 2 or 4. "
                 # "Please doublecheck your encoder"
                 # )
-                average_state = average(decoder_input_state)
-                decoder_input_state = [average_state, average_state]
+                average_state = average(encoder_output_state)
+                encoder_output_state = [average_state, average_state]

         elif (self.cell_type == 'lstm' and
-              not isinstance(decoder_input_state, list)):
-            decoder_input_state = [decoder_input_state, decoder_input_state]
+              not isinstance(encoder_output_state, list)):
+            encoder_output_state = [encoder_output_state, encoder_output_state]

         elif (self.cell_type != 'lstm' and
-              isinstance(decoder_input_state, list)):
+              isinstance(encoder_output_state, list)):
             # here we have a couple options,
             # either reuse part of the input encoder state,
             # or just use its output
-            if len(decoder_input_state) == 2:
+            if len(encoder_output_state) == 2:
                 # using h and ignoring c
-                decoder_input_state = decoder_input_state[0]
-            elif len(decoder_input_state) == 4:
+                encoder_output_state = encoder_output_state[0]
+            elif len(encoder_output_state) == 4:
                 # using average of hs and ignoring cs
-                decoder_input_state + average(
-                    [decoder_input_state[0], decoder_input_state[2]]
+                encoder_output_state + average(
+                    [encoder_output_state[0], encoder_output_state[2]]
                 )
             else:
                 # no idea how lists of length different than 2 or 4
@@ -211,37 +211,33 @@ def _prepare_decoder_input_state(self, inputs):
                 # "encoder_output_state has length different than 2 or 4. "
                 # "Please doublecheck your encoder"
                 # )
-                decoder_input_state = average(decoder_input_state)
+                encoder_output_state = average(encoder_output_state)
             # this returns a [b, h]
             # decoder_input_state = reduce_sequence(eo, self.reduce_input)

         elif (self.cell_type != 'lstm' and
-              not isinstance(decoder_input_state, list)):
+              not isinstance(encoder_output_state, list)):
             # do nothing, we are good
             pass

         # at this point decoder_input_state is either a [b,h]
         # or a list([b,h], [b,h]) if the decoder cell is an lstm
         # but h may not be the same as the decoder state size,
         # so we may need to project
-        if isinstance(decoder_input_state, list):
-            for i in range(len(decoder_input_state)):
-                if (decoder_input_state[i].shape[1] !=
+        if isinstance(encoder_output_state, list):
+            for i in range(len(encoder_output_state)):
+                if (encoder_output_state[i].shape[1] !=
                         self.state_size):
-                    decoder_input_state[i] = self.project(
-                        decoder_input_state[i]
+                    encoder_output_state[i] = self.project(
+                        encoder_output_state[i]
                     )
         else:
-            if decoder_input_state.shape[1] != self.state_size:
-                decoder_input_state = self.project(
-                    decoder_input_state
+            if encoder_output_state.shape[1] != self.state_size:
+                encoder_output_state = self.project(
+                    encoder_output_state
                 )

-        # make sure we are passing back the state tensors in a list
-        if not isinstance(decoder_input_state, list):
-            decoder_input_state = [decoder_input_state]
-
-        return decoder_input_state
+        return encoder_output_state

     def build_decoder_initial_state(self, batch_size, encoder_state, dtype):
         decoder_initial_state = self.decoder_rnncell.get_initial_state(
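
The projection step in the hunk above exists because the encoder's hidden size h may differ from the decoder cell's state_size; each state tensor is then mapped to the expected width before being handed to the decoder. A rough, self-contained illustration of that idea (the project layer here is a plain Dense layer chosen for the sketch, not necessarily the exact layer behind self.project):

import tensorflow as tf

state_size = 256
project = tf.keras.layers.Dense(state_size)  # stand-in for self.project

encoder_h = tf.zeros([4, 512])         # [batch_size, encoder_hidden_size]
if encoder_h.shape[1] != state_size:
    encoder_h = project(encoder_h)     # projected to [batch_size, state_size]
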
@@ -569,7 +565,7 @@ def call(self, inputs, training=None, mask=None):
         # form dependent on cell_type
         # lstm: list([batch_size, state_size], [batch_size, state_size])
         # rnn, gru: [batch_size, state_size]
-        encoder_output_state = self._prepare_decoder_input_state(
+        encoder_output_state = self.prepare_encoder_output_state(
            inputs[LOGITS]
        )

@@ -688,6 +684,7 @@ def _predictions_eval(
            output_type=tf.int64
        )

+        # todo tf2: deal with spurious 0s in predictions
        generated_sequence_lengths = sequence_length_2D(predictions)
        last_predictions = tf.gather_nd(
            predictions,
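
Taken together, the branching in prepare_encoder_output_state amounts to coercing whatever the encoder produced into the shape the decoder cell expects. The sketch below is a simplified, hypothetical restatement of that logic with plain tensors and a local average helper (every branch assigns the averaged value); it is not the Ludwig implementation itself:

import tensorflow as tf

def average(tensors):
    # element-wise mean of a list of equally shaped tensors
    return tf.add_n(tensors) / len(tensors)

def normalize_encoder_state(state, cell_type):
    # Coerce an encoder state into the form an RNN decoder cell expects:
    # an [h, c] list for lstm cells, a single [batch_size, h] tensor otherwise.
    if cell_type == 'lstm':
        if isinstance(state, list):
            if len(state) == 4:
                # bidirectional lstm: average forward/backward h and c
                return [average([state[0], state[2]]),
                        average([state[1], state[3]])]
            if len(state) == 2:
                # unidirectional lstm (or bidirectional gru/rnn): pass through
                return state
            # unknown layout: fall back to the mean, reused for both h and c
            s = average(state)
            return [s, s]
        # single tensor: reuse it as both h and c
        return [state, state]
    # non-lstm decoder cells take a single tensor
    if isinstance(state, list):
        if len(state) == 2:
            return state[0]                       # keep h, drop c
        if len(state) == 4:
            return average([state[0], state[2]])  # average the two h's
        return average(state)                     # unknown layout: mean
    return state

# e.g. a bidirectional lstm encoder feeding an lstm decoder:
fw_h = fw_c = bw_h = bw_c = tf.zeros([4, 256])
h, c = normalize_encoder_state([fw_h, fw_c, bw_h, bw_c], 'lstm')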
