In [1]:
import torch

In [2]:
batches = 4
text_length = 8
# Seed the RNG so that Restart & Run All regenerates the same weights every time.
SEED = 0
torch.manual_seed(SEED)
attention_weights = torch.rand((batches, text_length))
print(attention_weights)

tensor([[0.1683, 0.7441, 0.6654, 0.1558, 0.3393, 0.0990, 0.7204, 0.1018],
        [0.4469, 0.1852, 0.8273, 0.4945, 0.9406, 0.5152, 0.1824, 0.6835],
        [0.4573, 0.5556, 0.0690, 0.8547, 0.5191, 0.1388, 0.9449, 0.2114],
        [0.1963, 0.8547, 0.8503, 0.3807, 0.0762, 0.2196, 0.2189, 0.7326]])


In [3]:
# Row-wise maximum: the largest weight in each row and the position where it occurs.
attention_weights.max(dim=1)

torch.return_types.max(
values=tensor([0.7441, 0.9406, 0.9449, 0.8547]),
indices=tensor([1, 4, 6, 1]))

In [4]:
# Broadcast each row's maximum across its row so it can be compared element-wise.
# keepdim/expand_as take the target shape from the tensor itself instead of
# repeating the (batches, text_length) constants from the setup cell.
max_expanded = (
    attention_weights
    .max(dim=1, keepdim=True)
    .values
    .expand_as(attention_weights)
)
print(max_expanded)

tensor([[0.7441, 0.7441, 0.7441, 0.7441, 0.7441, 0.7441, 0.7441, 0.7441],
        [0.9406, 0.9406, 0.9406, 0.9406, 0.9406, 0.9406, 0.9406, 0.9406],
        [0.9449, 0.9449, 0.9449, 0.9449, 0.9449, 0.9449, 0.9449, 0.9449],
        [0.8547, 0.8547, 0.8547, 0.8547, 0.8547, 0.8547, 0.8547, 0.8547]])


In [5]:
# 1 where an element equals its row's maximum, 0 everywhere else; the boolean
# mask is cast back to the weights' dtype so the result stays a float tensor.
discrete_att_w = (attention_weights == max_expanded).to(attention_weights.dtype)
print(discrete_att_w)

tensor([[0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.]])


# Wrapping the discretization in a reusable function

In [6]:
def discretize_att_w(attention_weights):
    '''
    Discretize soft attention weights into one-hot (hard) attention.

    Along the last dimension, every row becomes all zeros except for a single 1
    at the position of that row's maximum weight. When several positions tie
    for the maximum, the first one wins, so each row always contains exactly
    one 1.

    PARAMS
    -----
    attention_weights: Attention weights of one batch.
    - torch.Tensor. Size == (batches, max_encoding_steps). Extra leading
      dimensions are also accepted; the maximum is always taken over the
      last dimension.

    RETURNS
    -----
    discrete_att_w: Discretized attention weights of one batch.
    - torch.Tensor with the same size, dtype and device as `attention_weights`.
      It is not attached to the autograd graph (hard attention has no useful
      gradient).
    '''
    # argmax (instead of `== row max`) guarantees a single 1 per row: the
    # equality test would mark every tied position, e.g. turn an all-zero
    # padded row into all ones.
    max_indices = torch.argmax(attention_weights, dim=-1, keepdim=True)
    discrete_att_w = torch.zeros_like(attention_weights).scatter_(-1, max_indices, 1)

    return discrete_att_w