Commit 7d74e0b

Reintroduced dependencies concat logic
w4nderlust committed Mar 22, 2020
1 parent 27cba17 commit 7d74e0b
Showing 1 changed file with 19 additions and 31 deletions.
50 changes: 19 additions & 31 deletions ludwig/features/base_feature.py
@@ -212,55 +212,55 @@ def postprocess_results(
     def populate_defaults(input_feature):
         pass
 
-    # todo tf2: adapt for tf2
-    def concat_dependencies(self, hidden, final_hidden):
+    def concat_dependencies(self, hidden, other_features_hidden):
         if len(self.dependencies) > 0:
             dependencies_hidden = []
             for dependency in self.dependencies:
                 # the dependent feature is ensured to be present in final_hidden
                 # because we did the topological sort of the features before
-                dependency_final_hidden = final_hidden[dependency]
+                dependency_final_hidden = other_features_hidden[dependency]
 
+                # todo tf2: test all 4 branches, for now only vector x vector is tested
                 if len(hidden.shape) > 2:
-                    if len(dependency_final_hidden[0].shape) > 2:
+                    if len(dependency_final_hidden.shape) > 2:
                         # matrix matrix -> concat
-                        dependencies_hidden.append(dependency_final_hidden[0])
+                        assert hidden.shape[1] == dependency_final_hidden.shape[1]
+                        dependencies_hidden.append(dependency_final_hidden)
                     else:
                         # matrix vector -> tile concat
-                        sequence_max_length = tf.shape(hidden)[1]
+                        sequence_max_length = hidden.shape[1]
                         multipliers = tf.concat(
-                            [[1], tf.expand_dims(sequence_max_length, -1), [1]],
+                            [[1], sequence_max_length[:, tf.newaxis], [1]],
                             0
                         )
                         tiled_representation = tf.tile(
-                            tf.expand_dims(dependency_final_hidden[0], 1),
+                            tf.expand_dims(dependency_final_hidden, 1),
                             multipliers
                         )
 
+                        # todo tf2: modify this with TF2 mask mechanics
                         sequence_length = sequence_length_3D(hidden)
                         mask = tf.sequence_mask(
                             sequence_length,
                             sequence_max_length
                         )
                         tiled_representation = tf.multiply(
                             tiled_representation,
-                            tf.cast(tf.expand_dims(mask, -1), dtype=tf.float32)
+                            tf.cast(mask[:, tf.newaxis], dtype=tf.float32)
                         )
 
                         dependencies_hidden.append(tiled_representation)
 
                 else:
-                    if len(dependency_final_hidden[0].shape) > 2:
+                    if len(dependency_final_hidden.shape) > 2:
                         # vector matrix -> reduce concat
                         dependencies_hidden.append(
-                            reduce_sequence(dependency_final_hidden[0],
+                            reduce_sequence(dependency_final_hidden,
                                             self.reduce_dependencies)
                         )
                     else:
                         # vector vector -> concat
-                        dependencies_hidden.append(dependency_final_hidden[0])
-
-                    # hidden_size += dependency_final_hidden[1]
+                        dependencies_hidden.append(dependency_final_hidden)
 
             try:
                 hidden = tf.concat([hidden] + dependencies_hidden, -1)
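
The four branches above cover the rank combinations of a feature's own hidden tensor and each dependency's hidden tensor: matrix x matrix concatenates directly, matrix x vector tiles the vector along the sequence axis, vector x matrix reduces the sequence first, and vector x vector concatenates directly. The diff's own todo notes that only the vector x vector branch is tested so far. Below is a rough standalone sketch of the matrix x vector case, not part of the commit: the shapes, names, and the seq_lengths stand-in are invented, and the mask is expanded on the last axis so it broadcasts over the feature dimension.

    import tensorflow as tf

    batch, seq, h1, h2 = 2, 4, 3, 5
    hidden = tf.random.normal([batch, seq, h1])        # this feature: [batch, seq, h1]
    dependency_hidden = tf.random.normal([batch, h2])  # dependency:   [batch, h2]
    seq_lengths = tf.constant([4, 2])                  # valid timesteps per example

    # tile the dependency vector across the sequence axis: [batch, seq, h2]
    tiled = tf.tile(tf.expand_dims(dependency_hidden, 1), [1, seq, 1])

    # zero out padded timesteps before concatenating
    mask = tf.sequence_mask(seq_lengths, seq)                    # [batch, seq], bool
    tiled = tiled * tf.cast(mask[:, :, tf.newaxis], tf.float32)  # broadcast over h2

    combined = tf.concat([hidden, tiled], -1)  # [batch, seq, h1 + h2]
    print(combined.shape)                      # (2, 4, 8)
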
@@ -348,12 +348,11 @@ def prepare_decoder_inputs(
                 self.reduce_input
             )
 
-        # ================ Adding Dependencies ================
-        # todo tf2 reintroduce this
-        # feature_hidden = self.concat_dependencies(
-        #     feature_hidden,
-        #     other_output_features
-        # )
+        # ================ Concat Dependencies ================
+        feature_hidden = self.concat_dependencies(
+            feature_hidden,
+            other_output_features
+        )
 
         # ================ Output-wise Fully Connected ================
         feature_hidden = self.output_specific_fully_connected(
@@ -363,15 +362,4 @@ def prepare_decoder_inputs(
         )
         other_output_features[self.feature_name] = feature_hidden
 
-        # ================ Outputs ================
-        # train_mean_loss, eval_loss, output_tensors = self.build_output(
-        #     feature_hidden,
-        #     feature_hidden_size,
-        #     **kwargs
-        # )
-        #
-        # loss_weight = float(self.loss['weight'])
-        # weighted_train_mean_loss = train_mean_loss * loss_weight
-        # weighted_eval_loss = eval_loss * loss_weight
-
         return feature_hidden
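
For context on how this code path is exercised: in a Ludwig model definition, an output feature lists the features it depends on under dependencies, and reduce_dependencies controls how a sequence-shaped dependency is collapsed before concatenation. A minimal hypothetical definition, with invented feature names:

    # Hypothetical Ludwig model definition; feature names are invented.
    model_definition = {
        'input_features': [
            {'name': 'utterance', 'type': 'text'}
        ],
        'output_features': [
            {'name': 'intent', 'type': 'category'},
            {
                'name': 'slots',
                'type': 'sequence',
                'dependencies': ['intent'],    # intent's hidden state is concatenated in
                'reduce_dependencies': 'sum'   # how a sequence dependency is collapsed
            }
        ]
    }
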
