Change signature of tfr.keras.layers.* with mask to list_mask.
PiperOrigin-RevId: 353292810
ramakumar1729 committed Jan 29, 2021
1 parent 66416aa commit 8d37f9f
Showing 4 changed files with 67 additions and 79 deletions.
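In short, every `tfr.keras.layers` call method that accepted a `mask` keyword now accepts `list_mask` instead. A minimal before/after sketch of a call site (tensor values are illustrative; `RestoreList` is one of the renamed layers in the diff below):

```python
import tensorflow as tf
import tensorflow_ranking as tfr

# Two queries, three candidate documents each; False marks padding.
list_mask = tf.constant([[True, True, False], [True, False, False]])
flattened_logits = tf.range(6, dtype=tf.float32)  # [batch_size * list_size]

# Before this commit: RestoreList()(flattened_logits, mask=list_mask)
# After this commit:
logits = tfr.keras.layers.RestoreList()(flattened_logits, list_mask=list_mask)
print(logits.shape)  # (2, 3)
```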
13 changes: 4 additions & 9 deletions tensorflow_ranking/examples/keras/antique_kpl_din.py
@@ -291,16 +291,14 @@ def _embedding(x):
   # Document interaction attention layer.
   if FLAGS.use_document_interaction:
     concat_tensor = tfr.keras.layers.ConcatFeatures()(
-        context_features=preprocessed_context_features,
-        example_features=preprocessed_example_features,
-        mask=mask)
+        preprocessed_context_features, preprocessed_example_features, mask)
     din_layer = tfr.keras.layers.DocumentInteractionAttention(
         num_heads=FLAGS.num_attention_heads,
         head_size=FLAGS.head_size,
         num_layers=FLAGS.num_attention_layers,
         dropout_rate=FLAGS.dropout_rate)
     preprocessed_example_features["document_interaction_embedding"] = din_layer(
-        inputs=concat_tensor, mask=mask)
+        inputs=concat_tensor, list_mask=mask)

   return preprocessed_context_features, preprocessed_example_features

@@ -313,9 +311,7 @@ def create_ranking_model() -> tf.keras.Model:

   (flattened_context_features,
    flattened_example_features) = tfr.keras.layers.FlattenList()(
-       context_features=context_features,
-       example_features=example_features,
-       mask=mask)
+       context_features, example_features, mask)

   # Concatenate flattened context and example features along `list_size` dim.
   context_input = [
@@ -345,8 +341,7 @@ def create_ranking_model() -> tf.keras.Model:
   dnn.add(tf.keras.layers.Dropout(rate=FLAGS.dropout_rate))
   dnn.add(tf.keras.layers.Dense(units=1))

-  logits = tfr.keras.layers.RestoreList()(
-      flattened_logits=dnn(input_layer), mask=mask)
+  logits = tfr.keras.layers.RestoreList()(dnn(input_layer), mask)

   return tf.keras.Model(
       inputs=dict(
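For context, the updated example now passes `context_features`, `example_features`, and `mask` positionally. A self-contained sketch of the flatten, score, restore pattern this file implements, with toy feature names and shapes (`query_emb` and `doc_emb` are illustrative stand-ins for the ANTIQUE features):

```python
import tensorflow as tf
import tensorflow_ranking as tfr

# Toy inputs following the [batch_size, ...] / [batch_size, list_size, ...]
# conventions documented in layers.py below.
context_features = {"query_emb": tf.random.uniform([2, 4])}
example_features = {"doc_emb": tf.random.uniform([2, 3, 4])}
mask = tf.constant([[True, True, False], [True, False, False]])

flattened_context_features, flattened_example_features = (
    tfr.keras.layers.FlattenList()(context_features, example_features, mask))

# Score each flattened (query, document) pair with a small DNN.
input_layer = tf.concat([
    tf.keras.layers.Flatten()(flattened_context_features["query_emb"]),
    tf.keras.layers.Flatten()(flattened_example_features["doc_emb"]),
], axis=1)
flattened_logits = tf.keras.layers.Dense(units=1)(input_layer)

# Restore [batch_size * list_size, 1] logits to [batch_size, list_size].
logits = tfr.keras.layers.RestoreList()(flattened_logits, mask)
```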
9 changes: 4 additions & 5 deletions tensorflow_ranking/examples/tf_ranking_tfrecord.py
@@ -228,17 +228,16 @@ def _transform_fn(features, mode):
   # Document interaction attention layer.
   if FLAGS.use_document_interaction:
     training = (mode == tf.estimator.ModeKeys.TRAIN)
-    concat_tensor = tfr.keras.layers.ConcatFeatures()(
-        context_features=context_features,
-        example_features=example_features,
-        mask=mask)
+    concat_tensor = tfr.keras.layers.ConcatFeatures()(context_features,
+                                                      example_features,
+                                                      mask)
     din_layer = tfr.keras.layers.DocumentInteractionAttention(
         num_heads=FLAGS.num_attention_heads,
         head_size=FLAGS.head_size,
         num_layers=FLAGS.num_attention_layers,
         dropout_rate=FLAGS.dropout_rate)
     example_features["document_interaction_embedding"] = din_layer(
-        inputs=concat_tensor, training=training, mask=mask)
+        inputs=concat_tensor, training=training, list_mask=mask)

   return context_features, example_features

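The estimator example follows the same pattern. A sketch of the document-interaction path under toy shapes (`query_emb` and `doc_emb` are illustrative; only `num_heads`, `head_size`, `num_layers`, and `dropout_rate` are constructor arguments confirmed by the diff):

```python
import tensorflow as tf
import tensorflow_ranking as tfr

# Toy inputs: two queries, three documents each, 4-dim embeddings.
context_features = {"query_emb": tf.random.uniform([2, 4])}
example_features = {"doc_emb": tf.random.uniform([2, 3, 4])}
mask = tf.constant([[True, True, False], [True, False, False]])

# [batch_size, list_size, concatenated_feature_dim]
concat_tensor = tfr.keras.layers.ConcatFeatures()(
    context_features, example_features, mask)

din_layer = tfr.keras.layers.DocumentInteractionAttention(
    num_heads=1, head_size=16, num_layers=1, dropout_rate=0.5)
# `training=False` disables dropout; the mask keyword is now `list_mask`.
embedding = din_layer(inputs=concat_tensor, training=False, list_mask=mask)
print(embedding.shape)  # (2, 3, 16): [batch_size, list_size, head_size]
```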
65 changes: 34 additions & 31 deletions tensorflow_ranking/python/keras/layers.py
@@ -70,17 +70,17 @@ def __init__(self,

   def call(
       self, context_features: Dict[str, tf.Tensor],
-      example_features: Dict[str, tf.Tensor],
-      mask: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]]:
+      example_features: Dict[str, tf.Tensor], list_mask: tf.Tensor
+  ) -> Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]]:
     """Call FlattenList layer to flatten context_features and example_features.

     Args:
       context_features: A map of context features to 2D tensors of shape
         [batch_size, feature_dim].
       example_features: A map of example features to 3D tensors of shape
         [batch_size, list_size, feature_dim].
-      mask: A Tensor of shape [batch_size, list_size] to mask out the invalid
-        examples.
+      list_mask: A Tensor of shape [batch_size, list_size] to mask out the
+        invalid examples.

     Returns:
       A tuple of (flattened_context_features, flattened_example_features) where
@@ -93,8 +93,8 @@ def call(
     """
     if not example_features:
       raise ValueError('Need a valid example feature.')
-    batch_size = tf.shape(mask)[0]
-    list_size = tf.shape(mask)[1]
+    batch_size = tf.shape(list_mask)[0]
+    list_size = tf.shape(list_mask)[1]
     # Expand context features to be of [batch_size, list_size, ...].
     flattened_context_features = {}
     for name, tensor in context_features.items():
@@ -105,7 +105,7 @@ def call(

     nd_indices = None
     if self._circular_padding:
-      nd_indices, _ = utils.padded_nd_indices(is_valid=mask)
+      nd_indices, _ = utils.padded_nd_indices(is_valid=list_mask)

     flattened_example_features = {}
     for name, tensor in example_features.items():
@@ -171,15 +171,16 @@ def __init__(self,
     super().__init__(name=name, **kwargs)
     self._by_scatter = by_scatter

-  def call(self, flattened_logits: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
+  def call(self, flattened_logits: tf.Tensor,
+           list_mask: tf.Tensor) -> tf.Tensor:
     """Restores listwise shape of flattened_logits.

     Args:
       flattened_logits: A `Tensor` of predicted logits for each pair of query
         and documents, 1D tensor of shape [batch_size * list_size] or 2D tensor
         of shape [batch_size * list_size, 1].
-      mask: A boolean `Tensor` of shape [batch_size, list_size] to mask out the
-        invalid examples.
+      list_mask: A boolean `Tensor` of shape [batch_size, list_size] to mask out
+        the invalid examples.

     Returns:
       A `Tensor` of shape [batch_size, list_size].
@@ -189,19 +190,20 @@ def call(self, flattened_logits: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
         2D with shape [batch_size * list_size, 1].
     """
     try:
-      logits = tf.reshape(flattened_logits, shape=tf.shape(mask))
+      logits = tf.reshape(flattened_logits, shape=tf.shape(list_mask))
     except:
       raise ValueError('`flattened_logits` needs to be either '
                        '1D of [batch_size * list_size] or '
                        '2D of [batch_size * list_size, 1].')
     if self._by_scatter:
-      nd_indices, _ = utils.padded_nd_indices(is_valid=mask)
-      counts = tf.scatter_nd(nd_indices, tf.ones_like(logits), tf.shape(mask))
-      logits = tf.scatter_nd(nd_indices, logits, tf.shape(mask))
+      nd_indices, _ = utils.padded_nd_indices(is_valid=list_mask)
+      counts = tf.scatter_nd(nd_indices, tf.ones_like(logits),
+                             tf.shape(list_mask))
+      logits = tf.scatter_nd(nd_indices, logits, tf.shape(list_mask))
       return tf.where(
           tf.math.greater(counts, 0.), logits / counts, tf.math.log(_EPSILON))
     else:
-      return tf.where(mask, logits, tf.math.log(_EPSILON))
+      return tf.where(list_mask, logits, tf.math.log(_EPSILON))

   def get_config(self):
     config = super().get_config()
@@ -261,16 +263,16 @@ def call(
       self,
       context_features: Dict[str, tf.Tensor],
       example_features: Dict[str, tf.Tensor],
-      mask: [tf.Tensor],
+      list_mask: [tf.Tensor],
   ) -> tf.Tensor:
     """Call method for ConcatFeatures layer.

     Args:
       context_features: A dict of `Tensor`s with shape [batch_size, ...].
       example_features: A dict of `Tensor`s with shape [batch_size, list_size,
         ...].
-      mask: A boolean tensor of shape [batch_size, list_size], which is True for
-        a valid example and False for invalid one.
+      list_mask: A boolean tensor of shape [batch_size, list_size], which is
+        True for a valid example and False for invalid one.

     Returns:
       A `Tensor` of shape [batch_size, list_size, ...].
@@ -279,7 +281,7 @@ def call(
         flattened_example_features) = self._flatten_list(
             context_features=context_features,
             example_features=example_features,
-            mask=mask)
+            list_mask=list_mask)
     # Concatenate flattened context and example features along `list_size` dim.
     context_input = [
         tf.keras.layers.Flatten()(flattened_context_features[name])
@@ -292,8 +294,8 @@ def call(
     flattened_concat_features = tf.concat(context_input + example_input, 1)

     # Reshape to 3D.
-    batch_size = tf.shape(mask)[0]
-    list_size = tf.shape(mask)[1]
+    batch_size = tf.shape(list_mask)[0]
+    list_size = tf.shape(list_mask)[1]
     return utils.reshape_first_ndims(flattened_concat_features, 1,
                                      [batch_size, list_size])

@@ -329,13 +331,13 @@ class DocumentInteractionAttention(tf.keras.layers.Layer):
   ```python
   # Batch size = 2, list_size = 3.
   inputs = [[[1., 1.], [1., 0.], [1., 1.]], [[0., 0.], [0., 0.], [0., 0.]]]
-  mask = [[True, True, False], [True, False, False]]
+  list_mask = [[True, True, False], [True, False, False]]
   dia_layer = DocumentInteractionAttention(
       num_heads=1, head_size=64, num_layers=1, topk=1)
   dia_output = dia_layer(
       inputs=inputs,
       training=False,
-      mask=mask)
+      list_mask=list_mask)
   ```
   """

@@ -392,27 +394,28 @@ def __init__(self,
   def call(self,
            inputs: tf.Tensor,
            training: bool = True,
-           mask: Optional[tf.Tensor] = None) -> tf.Tensor:
+           list_mask: Optional[tf.Tensor] = None) -> tf.Tensor:
     """Calls the document interaction layer to apply cross-document attention.

     Args:
       inputs: A tensor of shape [batch_size, list_size, feature_dims].
       training: Whether in training or inference mode.
-      mask: A boolean tensor of shape [batch_size, list_size], which is True for
-        a valid example and False for invalid one. If this is `None`, then all
-        examples are treated as valid.
+      list_mask: A boolean tensor of shape [batch_size, list_size], which is
+        True for a valid example and False for invalid one. If this is `None`,
+        then all examples are treated as valid.

     Returns:
       A tensor of shape [batch_size, list_size, head_size].
     """
     batch_size = tf.shape(inputs)[0]
     list_size = tf.shape(inputs)[1]
-    if mask is None:
-      mask = tf.ones(shape=(batch_size, list_size), dtype=tf.bool)
+    if list_mask is None:
+      list_mask = tf.ones(shape=(batch_size, list_size), dtype=tf.bool)
     input_tensor = self._input_projection(inputs, training=training)

-    mask = tf.cast(mask, dtype=tf.int32)
-    attention_mask = nlp_modeling_layers.SelfAttentionMask()([mask, mask])
+    list_mask = tf.cast(list_mask, dtype=tf.int32)
+    attention_mask = nlp_modeling_layers.SelfAttentionMask()(
+        [list_mask, list_mask])

     for attention_layer, dropout_layer, norm_layer in self._attention_layers:
       output = attention_layer(
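A note on the semantics this rename preserves: in `RestoreList`, positions where `list_mask` is False are filled with `tf.math.log(_EPSILON)`, a very negative logit, so padded documents sort last; with `by_scatter=True`, logits are instead scattered back and averaged over repeated indices, which matters when circular padding duplicated examples during flattening. A small illustrative check (values are arbitrary):

```python
import tensorflow as tf
import tensorflow_ranking as tfr

flattened_logits = tf.constant([0., 1., 2., 3., 4., 5.])
list_mask = tf.constant([[True, True, False], [True, False, False]])

logits = tfr.keras.layers.RestoreList()(flattened_logits, list_mask)
# Row 0 keeps logits 0. and 1.; row 1 keeps 3.; each masked slot holds
# log of a small epsilon, i.e. a large negative constant.
print(logits.numpy())
```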
