In [1]:
import torch
import torch.nn as nn

# 1.1 torch.nn.MultiheadAttention

## a: core usage

~~~
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True)
~~~

In [2]:
multihead_attn = torch.nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)

In [3]:
x = torch.randn((1,3,4))
print(x)

tensor([[[ 0.4028, -0.8381, -0.9777, -0.3953],
         [-0.3942,  0.4389,  1.2422, -0.9700],
         [-1.1927,  1.3185, -1.0513, -1.5547]]])


In [4]:
attn_output, attn_output_weights = multihead_attn(x,x,x)

In [5]:
print(attn_output)

tensor([[[ 0.0881,  0.2321, -0.0496, -0.1493],
         [ 0.1222,  0.2689, -0.1273, -0.2522],
         [ 0.2165,  0.3881, -0.3285, -0.5227]]], grad_fn=<TransposeBackward0>)


In [6]:
print(attn_output_weights)

tensor([[[0.3571, 0.3111, 0.3318],
         [0.2507, 0.3507, 0.3986],
         [0.1129, 0.2516, 0.6355]]], grad_fn=<DivBackward0>)


In [7]:
attn_output_weights.matmul(x)

tensor([[[-0.3745,  0.2747, -0.3115, -0.9587],
         [-0.5127,  0.4693, -0.2285, -1.0590],
         [-0.8116,  0.8536, -0.4660, -1.2767]]], grad_fn=<UnsafeViewBackward0>)

## b understand attn_mask in multihead_attention

* attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S)(L,S) or (N\cdot\text{num\_heads}, L, S)(N⋅num_heads,L,S), where NN is the batch size, LL is the target sequence length, and SS is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported. 

https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
~~~
def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.
    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.
    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
            and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
            shape :math:`(Nt, Ns)`.
        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
            have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    if attn_mask is not None:
        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
    else:
        attn = torch.bmm(q, k.transpose(-2, -1))

    attn = softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn
~~~

In [7]:
attn_mask = nn.Transformer.generate_square_subsequent_mask(3)
print(attn_mask)
attn_mask = attn_mask.bool()
print(attn_mask)
#attn_mask = torch.randint(0,2,[3,3]).bool()
#print(attn_mask)

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])
tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])


In [8]:
attn_output, attn_output_weights = multihead_attn(x,x,x, attn_mask=attn_mask)
print(attn_output)
print(attn_output_weights)

tensor([[[ 0.0901,  0.0645,  0.0768,  0.0714],
         [ 0.1172,  0.4016,  0.0096,  0.0223],
         [-0.0686,  0.2491, -0.1330, -0.1129]]], grad_fn=<TransposeBackward0>)
tensor([[[1.0000, 0.0000, 0.0000],
         [0.4010, 0.5990, 0.0000],
         [0.1628, 0.1985, 0.6386]]], grad_fn=<DivBackward0>)


## c: understand key_padding_mask in multihead_attention

* key_padding_mask – If specified, a mask of shape (N, S)(N,S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). 

batch_size = 3，seq_length_ =4，token looks like
~~~
[
    [‘a’,'b','<PAD>'],
    [‘a’,'b','c'],
    [‘a’,'<PAD>','<PAD>']
]

~~~

key_padding_mask.shape = （3,4）

~~~
padding_mask = torch.tensor([
    [False, False, True],
    [False, False, False],
    [False, True, True]
])
print(padding_mask)
~~~

In [9]:
padding_mask = torch.tensor([
    [False, False, True]])

In [10]:
padding_mask.shape

torch.Size([1, 3])

In [11]:
attn_output, attn_output_weights = multihead_attn(x,x,x, key_padding_mask=padding_mask)
print(attn_output)
print(attn_output_weights)

tensor([[[0.1128, 0.3470, 0.0205, 0.0303],
         [0.1172, 0.4016, 0.0096, 0.0223],
         [0.1149, 0.3737, 0.0152, 0.0264]]], grad_fn=<TransposeBackward0>)
tensor([[[0.4980, 0.5020, 0.0000],
         [0.4010, 0.5990, 0.0000],
         [0.4506, 0.5494, 0.0000]]], grad_fn=<DivBackward0>)


# d: mix use of attn_mask and padding mask

In [12]:
attn_output, attn_output_weights = multihead_attn(x,x,x, attn_mask=attn_mask, key_padding_mask=padding_mask)
print(attn_output)
print(attn_output_weights)

tensor([[[0.0901, 0.0645, 0.0768, 0.0714],
         [0.1172, 0.4016, 0.0096, 0.0223],
         [0.1149, 0.3737, 0.0152, 0.0264]]], grad_fn=<TransposeBackward0>)
tensor([[[1.0000, 0.0000, 0.0000],
         [0.4010, 0.5990, 0.0000],
         [0.4506, 0.5494, 0.0000]]], grad_fn=<DivBackward0>)


# 1.2 nn.TransformerEncoderLayer

In [13]:
encoder_layer = nn.TransformerEncoderLayer(d_model=4, nhead=1, batch_first=True)

In [14]:
src = torch.randn(1,3,4)
print(src)

tensor([[[ 0.8674, -1.5524,  0.4267, -0.1954],
         [-2.1481, -0.5348, -1.4551, -0.3789],
         [-1.9665,  0.4252, -0.1385, -0.3649]]])


In [15]:
out = encoder_layer(src)
print(out)

tensor([[[ 0.9521, -1.6815,  0.3756,  0.3538],
         [-0.9838,  0.6514, -0.9635,  1.2959],
         [-1.5548,  1.0577, -0.1639,  0.6610]]],
       grad_fn=<NativeLayerNormBackward0>)


In [16]:
print(out.size())

torch.Size([1, 3, 4])


In [17]:
attn_mask = torch.randint(0,2,[3,3]).bool()
print(attn_mask)

tensor([[ True,  True, False],
        [False,  True, False],
        [False, False, False]])


In [18]:
out = encoder_layer(src, src_mask=attn_mask)
print(out)

tensor([[[ 0.7793, -1.7167,  0.5236,  0.4138],
         [-0.9969,  1.1466, -0.9914,  0.8417],
         [-1.5465,  1.1288, -0.1426,  0.5603]]],
       grad_fn=<NativeLayerNormBackward0>)


# 1.3 torch.nn.TransformerEncoder

In [19]:
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

In [20]:
out = transformer_encoder(src)
print(out)

tensor([[[ 1.0504, -1.6393,  0.4283,  0.1606],
         [-0.7235,  0.7281, -1.2160,  1.2115],
         [-1.2552,  1.0471, -0.7041,  0.9122]]],
       grad_fn=<NativeLayerNormBackward0>)
