Skip to content

Commit

Permalink
Fix passing of keyword args to Dense layers in create_tower
Browse files Browse the repository at this point in the history
Current behavior: kwargs are passed to tf.keras.Sequential.add, so they
are not passed on to tf.keras.layers.Dense as intended. For example,
when passing a keyword argument such as `use_bias=False` to create_tower,
it throws an exception:

Traceback (most recent call last):
  File "/Users/brussell/development/ranking/tensorflow_ranking/python/keras/layers_test.py", line 33, in test_create_tower_with_kwargs
    tower = layers.create_tower([3, 2, 1], 1, activation='relu', use_bias=False)
  File "/Users/brussell/development/ranking/tensorflow_ranking/python/keras/layers.py", line 70, in create_tower
    model.add(tf.keras.layers.Dense(units=layer_width), **kwargs)
  File "/usr/local/anaconda3/lib/python3.9/site-packages/tensorflow/python/trackable/base.py", line 205, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/usr/local/anaconda3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 61, in error_handler
    return fn(*args, **kwargs)
TypeError: add() got an unexpected keyword argument 'use_bias'
test_create_tower_with_kwargs

Fix: This PR fixes the behavior by moving the closing parenthesis so that
**kwargs is expanded inside the tf.keras.layers.Dense constructor call
rather than being passed to tf.keras.Sequential.add.
  • Loading branch information
b4russell committed Dec 8, 2022
1 parent dfae631 commit 1400f70
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tensorflow_ranking/python/keras/layers.py
Expand Up @@ -57,7 +57,7 @@ def create_tower(hidden_layer_dims: List[int],
dropout: When not `None`, the probability we will drop out a given
coordinate.
name: Name of the Keras layer.
**kwargs: Keyword arguments for every `tf.keras.Dense` layers.
**kwargs: Keyword arguments for every `tf.keras.layers.Dense` layer.
Returns:
A `tf.keras.Sequential` object.
Expand All @@ -67,13 +67,13 @@ def create_tower(hidden_layer_dims: List[int],
if input_batch_norm:
model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment))
for layer_width in hidden_layer_dims:
model.add(tf.keras.layers.Dense(units=layer_width), **kwargs)
model.add(tf.keras.layers.Dense(units=layer_width, **kwargs))
if use_batch_norm:
model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment))
model.add(tf.keras.layers.Activation(activation=activation))
if dropout:
model.add(tf.keras.layers.Dropout(rate=dropout))
model.add(tf.keras.layers.Dense(units=output_units), **kwargs)
model.add(tf.keras.layers.Dense(units=output_units, **kwargs))
return model


Expand Down
4 changes: 4 additions & 0 deletions tensorflow_ranking/python/keras/layers_test.py
Expand Up @@ -28,6 +28,10 @@ def test_create_tower(self):
outputs = tower(inputs)
self.assertAllEqual([2, 3, 1], outputs.get_shape().as_list())

def test_create_tower_with_bias_kwarg(self):
tower = layers.create_tower([3, 2], 1, use_bias=False)
tower_layers_bias = [tower.get_layer(name).use_bias for name in ['dense_1', 'dense_2']]
self.assertAllEqual([False, False], tower_layers_bias)

class FlattenListTest(tf.test.TestCase):

Expand Down

0 comments on commit 1400f70

Please sign in to comment.