Remove topk support for DocumentInteractionAttention.
PiperOrigin-RevId: 352923862
ramakumar1729 committed Jan 29, 2021
1 parent ce7c2b6 commit 66416aa
Showing 2 changed files with 14 additions and 46 deletions.
tensorflow_ranking/python/keras/layers.py: 24 changes (4 additions, 20 deletions)
@@ -343,7 +343,6 @@ def __init__(self,
                num_heads: int,
                head_size: int,
                num_layers: int = 1,
-               topk: Optional[int] = None,
                dropout_rate: float = 0.5,
                name: Optional[str] = None,
                **kwargs: Dict[Any, Any]):
@@ -353,23 +352,15 @@ def __init__(self,
       num_heads: Number of attention heads (see `MultiHeadAttention` for more
         details on this argument).
       head_size: Size of each attention head.
-      num_layers: Number of self-attention layers.
-      topk: top-k positions to attend over. If None, attends over entire list.
+      num_layers: Number of cross-document attention layers.
       dropout_rate: Dropout probability.
       name: Name of the layer.
       **kwargs: keyword arguments.
-
-    Raises:
-      ValueError: If topk is not None or not a positive integer.
     """
-    if topk is not None:
-      if topk <= 0 or not isinstance(topk, int):
-        raise ValueError('topk should be either None or a positive integer.')
     super().__init__(name=name, **kwargs)
     self._num_heads = num_heads
     self._head_size = head_size
     self._num_layers = num_layers
-    self._topk = topk
     self._dropout_rate = dropout_rate

     # This projects input to head_size, so that this layer can be applied
@@ -420,19 +411,13 @@ def call(self,
       mask = tf.ones(shape=(batch_size, list_size), dtype=tf.bool)
     input_tensor = self._input_projection(inputs, training=training)

-    q_mask = tf.cast(mask, dtype=tf.int32)
-    k_mask = q_mask[:, :self._topk] if self._topk else q_mask
-    attention_mask = nlp_modeling_layers.SelfAttentionMask()([q_mask, k_mask])
+    mask = tf.cast(mask, dtype=tf.int32)
+    attention_mask = nlp_modeling_layers.SelfAttentionMask()([mask, mask])

     for attention_layer, dropout_layer, norm_layer in self._attention_layers:
-      # k_tensor, the keys and values attended over, is truncated when topk is
-      # specified. Note that the output shape is unchanged, as that is
-      # determined by query_tensor.
-      k_tensor = (
-          input_tensor[:, :self._topk, :] if self._topk else input_tensor)
       output = attention_layer(
           query=input_tensor,
-          value=k_tensor,
+          value=input_tensor,
           attention_mask=attention_mask,
           training=training)
       output = dropout_layer(output, training=training)
@@ -447,7 +432,6 @@ def get_config(self):
         'num_heads': self._num_heads,
         'head_size': self._head_size,
         'num_layers': self._num_layers,
-        'topk': self._topk,
         'dropout_rate': self._dropout_rate,
     })
     return config
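
For reference, a minimal usage sketch of the layer after this commit (illustrative only: the input values, mask, and hyperparameters below are made up and are not part of the change). The `topk` argument is gone, so every document attends over the entire list.

import tensorflow as tf

from tensorflow_ranking.python.keras import layers

# Hypothetical inputs: a batch of 2 lists, 3 documents per list, 5 features each.
example_inputs = tf.random.uniform(shape=(2, 3, 5))
# Boolean mask marking which list entries are valid (padded entries are False).
example_mask = tf.constant([[True, True, False], [True, True, True]])

# No `topk` argument anymore; attention always spans the whole list.
din_layer = layers.DocumentInteractionAttention(
    num_heads=2, head_size=4, num_layers=1, dropout_rate=0.5)
output = din_layer(inputs=example_inputs, training=False, mask=example_mask)
print(output.shape)  # (2, 3, 4): one head_size-dimensional vector per document.
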
tensorflow_ranking/python/keras/layers_test.py: 36 changes (10 additions, 26 deletions)
@@ -14,7 +14,6 @@

 # Lint as: python3
 """Tests for Keras layers in TF-Ranking."""
-from absl.testing import parameterized
 import tensorflow as tf

 from tensorflow_ranking.python.keras import layers
@@ -214,8 +213,7 @@ def test_call_raise_error(self):
       layers.RestoreList()(flattened_logits=flattened_logits_2d, mask=mask)


-class DocumentInteractionAttentionLayerTest(tf.test.TestCase,
-                                            parameterized.TestCase):
+class DocumentInteractionAttentionLayerTest(tf.test.TestCase):

   def setUp(self):
     super().setUp()
@@ -230,32 +228,29 @@ def setUp(self):
     self._num_layers = 2
     self._dropout_rate = 0.5

-  def _get_din_layer(self, topk=None):
+  def _get_din_layer(self):
     return layers.DocumentInteractionAttention(
         num_heads=self._num_heads,
         head_size=self._head_size,
         num_layers=self._num_layers,
-        topk=topk,
         dropout_rate=self._dropout_rate)

-  @parameterized.named_parameters(('topk_none', None), ('topk', 1))
-  def test_serialization(self, topk):
+  def test_serialization(self):
     # Check save and restore config.
-    layer = self._get_din_layer(topk)
+    layer = self._get_din_layer()
     serialized = tf.keras.layers.serialize(layer)
     loaded = tf.keras.layers.deserialize(serialized)
     self.assertAllEqual(loaded.get_config(), layer.get_config())

-  @parameterized.named_parameters(('topk_none', None), ('topk', 1))
-  def test_deterministic_inference_behavior(self, topk):
-    din_layer = self._get_din_layer(topk)
+  def test_deterministic_inference_behavior(self):
+    din_layer = self._get_din_layer()
     output_1 = din_layer(inputs=self._inputs, training=False, mask=self._mask)
     output_2 = din_layer(inputs=self._inputs, training=False, mask=self._mask)
     self.assertAllClose(output_1, output_2)

-  def test_call_topk_none(self):
+  def test_call(self):
     tf.random.set_seed(1)
-    din_layer = self._get_din_layer(topk=None)
+    din_layer = self._get_din_layer()
     output = din_layer(inputs=self._inputs, training=False, mask=self._mask)
     self.assertEqual(output.shape.as_list(), [2, 3, self._head_size])

@@ -265,19 +260,8 @@ def test_call_topk_none(self):
                                              [-1., 1.]]])
     self.assertAllClose(expected_output, output)

-  def test_call_topk(self):
-    tf.random.set_seed(1)
-    din_layer = self._get_din_layer(topk=1)
-    output = din_layer(inputs=self._inputs, training=False, mask=self._mask)
-    self.assertEqual(output.shape.as_list(), [2, 3, self._head_size])
-    expected_output = tf.convert_to_tensor([[[-1., 1.], [-1., 1.], [-1., 1.]],
-                                            [[-1., 0.99999994], [-1., 1.],
-                                             [-1., 1.]]])
-    self.assertAllClose(expected_output, output)
-
-  @parameterized.named_parameters(('topk_none', None), ('topk', 1))
-  def test_no_effect_circular_padding(self, topk):
-    din_layer = self._get_din_layer(topk)
+  def test_no_effect_circular_padding(self):
+    din_layer = self._get_din_layer()
     output_1 = din_layer(inputs=self._inputs, training=False, mask=self._mask)
     circular_padded_inputs = tf.constant(
         [[[2., 1.], [2., 0.], [2., 1.]], [[1., 0.], [1., 0.], [1., 0.]]],
