In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


A clean and readable version of the PyTorch source, to help understand the notation used in the Transformer modules.

https://github.com/pytorch/pytorch/blob/4bf90558e0cbafbf03fa7e4285367f12658bde54/torch/nn/modules/transformer.py#L296

# 1.1 torch.nn.MultiheadAttention

## a: core usage

~~~
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True)
~~~

In [2]:
multihead_attn = torch.nn.MultiheadAttention(embed_dim=5, num_heads=1, batch_first=True)

In [3]:
# batch_size = 1
# sequence_length = 4
# embedding_size = 5

In [4]:
x = torch.randn((1,4,5))
print(x)

tensor([[[-0.2691, -0.0816,  0.9442,  0.9448, -0.1601],
         [ 0.3016, -0.2399,  0.7206, -0.4855, -1.1777],
         [-0.5336,  0.1804,  0.3119,  1.3854, -0.2336],
         [ 1.3802,  0.3300, -0.5571,  0.8665,  1.1283]]])


In [5]:
attn_output, attn_output_weights = multihead_attn(x,x,x)

In [6]:
print(attn_output)
print(attn_output.shape)

tensor([[[ 0.0664,  0.0354, -0.1694,  0.0943, -0.1532],
         [ 0.0416,  0.0203, -0.1636,  0.0986, -0.1422],
         [ 0.0560,  0.0305, -0.1760,  0.0884, -0.1525],
         [ 0.0432,  0.0319, -0.1664,  0.0943, -0.1445]]],
       grad_fn=<TransposeBackward0>)
torch.Size([1, 4, 5])


In [7]:
print(attn_output_weights)
print(attn_output_weights.shape)

tensor([[[0.2333, 0.2215, 0.2412, 0.3040],
         [0.2337, 0.3121, 0.2497, 0.2045],
         [0.2793, 0.2419, 0.2438, 0.2350],
         [0.3042, 0.2723, 0.2123, 0.2112]]], grad_fn=<DivBackward0>)
torch.Size([1, 4, 4])


In [8]:
# The result differs from attn_output (compare the two printed tensors): here the
# weights multiply the raw input x, whereas nn.MultiheadAttention multiplies them
# with a learned projection of `value` and then applies an output projection
# (per the PyTorch MultiheadAttention docs).
attn_output_weights.bmm(x)

tensor([[[ 0.2949,  0.0717,  0.2857,  0.7105, -0.0115],
         [ 0.1802,  0.0186,  0.4095,  0.5924, -0.2326],
         [ 0.1921,  0.0407,  0.3831,  0.6878, -0.1213],
         [ 0.1785,  0.0179,  0.4319,  0.6323, -0.1806]]],
       grad_fn=<BmmBackward0>)

## b understand attn_mask in multihead_attention

* attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape $(L, S)$ or $(N \cdot \text{num\_heads}, L, S)$, where $N$ is the batch size, $L$ is the target sequence length, and $S$ is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported.

https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
~~~
def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.
    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.
    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
            and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
            shape :math:`(Nt, Ns)`.
        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
            have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    if attn_mask is not None:
        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
    else:
        attn = torch.bmm(q, k.transpose(-2, -1))

    attn = softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn
~~~

In [9]:
def generate_subsequent_mask(tgt_sz, src_sz):
    """Build an additive causal mask of shape (tgt_sz, src_sz).

    Entry (i, j) is 0.0 when j <= i (target position i may attend to source
    position j) and -inf otherwise, so the mask can be added to attention scores.
    """
    # The lower triangle (diagonal included) marks the positions that stay visible.
    allowed = torch.ones(tgt_sz, src_sz).tril().bool()
    blank = torch.zeros(tgt_sz, src_sz, dtype=torch.float32)
    return blank.masked_fill(~allowed, float('-inf'))

In [10]:
# Additive float mask: 0.0 where attention is allowed, -inf where it is blocked.
attn_mask = generate_subsequent_mask(4,4)
print(attn_mask)
# Casting to bool maps 0.0 -> False and -inf -> True, so True marks the blocked
# positions (compare the two printed tensors).
attn_mask = attn_mask.bool()
print(attn_mask)

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])


In [11]:
attn_output, attn_output_weights = multihead_attn(x,x,x, attn_mask=attn_mask)
print(attn_output)
print(attn_output_weights)

tensor([[[ 0.0339,  0.0954, -0.2361,  0.0237, -0.1749],
         [-0.0650,  0.0281, -0.0852,  0.1471, -0.0764],
         [ 0.0305,  0.0017, -0.2087,  0.0633, -0.1567],
         [ 0.0432,  0.0319, -0.1664,  0.0943, -0.1445]]],
       grad_fn=<TransposeBackward0>)
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4281, 0.5719, 0.0000, 0.0000],
         [0.3651, 0.3162, 0.3187, 0.0000],
         [0.3042, 0.2723, 0.2123, 0.2112]]], grad_fn=<DivBackward0>)


In [12]:
#attn_output_weights.bmm(x)

## c: understand key_padding_mask in multihead_attention

* key_padding_mask – If specified, a mask of shape $(N, S)$ indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”).

With batch_size = 3 and seq_length = 4, the tokens look like:
~~~
[
    ['a','b','c','<PAD>'],
    ['a','b','c','d'],
    ['a','b','<PAD>','<PAD>']
]

~~~

key_padding_mask.shape = (3, 4)

~~~
padding_mask = torch.tensor([
    [False, False, False, True],
    [False, False, False, False],
    [False, False, True, True]
])
print(padding_mask)
~~~

In [13]:
padding_mask = torch.tensor([
    [False, False, True,True]])

In [14]:
padding_mask.shape

torch.Size([1, 4])

In [15]:
attn_output, attn_output_weights = multihead_attn(x,x,x, key_padding_mask=padding_mask)
print(attn_output)
print(attn_output_weights)

tensor([[[-0.0503,  0.0381, -0.1076,  0.1288, -0.0910],
         [-0.0650,  0.0281, -0.0852,  0.1471, -0.0764],
         [-0.0464,  0.0408, -0.1137,  0.1238, -0.0950],
         [-0.0478,  0.0398, -0.1115,  0.1256, -0.0936]]],
       grad_fn=<TransposeBackward0>)
tensor([[[0.5129, 0.4871, 0.0000, 0.0000],
         [0.4281, 0.5719, 0.0000, 0.0000],
         [0.5359, 0.4641, 0.0000, 0.0000],
         [0.5276, 0.4724, 0.0000, 0.0000]]], grad_fn=<DivBackward0>)


## d: combined use of attn_mask and key_padding_mask

In [16]:
attn_output, attn_output_weights = multihead_attn(x,x,x, attn_mask=attn_mask, key_padding_mask=padding_mask)
print(attn_output)
print(attn_output_weights)

tensor([[[ 0.0339,  0.0954, -0.2361,  0.0237, -0.1749],
         [-0.0650,  0.0281, -0.0852,  0.1471, -0.0764],
         [-0.0464,  0.0408, -0.1137,  0.1238, -0.0950],
         [-0.0478,  0.0398, -0.1115,  0.1256, -0.0936]]],
       grad_fn=<TransposeBackward0>)
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4281, 0.5719, 0.0000, 0.0000],
         [0.5359, 0.4641, 0.0000, 0.0000],
         [0.5276, 0.4724, 0.0000, 0.0000]]], grad_fn=<DivBackward0>)


# 1.2 nn.TransformerEncoderLayer

~~~
class TransformerEncoderLayer(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", layer_norm_eps=1e-5):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)
        self.activation = _get_activation_fn(activation)
        
    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
~~~

Based on the PyTorch implementation source code, **src_mask** is what is called **attn_mask** in a MultiheadAttention module, and **src_key_padding_mask** is equivalent to **key_padding_mask** in a MultiheadAttention module.

In [17]:
encoder_layer = nn.TransformerEncoderLayer(d_model=5, nhead=1, batch_first=True)

~~~
forward(src, src_mask=None, src_key_padding_mask=None, is_causal=False)
~~~

In [18]:
src = torch.randn(1,4,5)
print(src)

tensor([[[-1.4174,  0.3715,  0.8087,  0.5961, -1.0753],
         [ 0.2875,  0.9670,  0.2216,  1.1212,  0.2438],
         [-0.9028, -1.3876,  0.4233, -1.0591, -0.5709],
         [ 0.3651,  1.9209,  0.3671, -1.0275, -0.5568]]])


In [19]:
out = encoder_layer(src)
print(out)
print(out.shape)

tensor([[[-1.2363,  0.9163,  0.5926,  0.9229, -1.1955],
         [-0.4635,  1.1635, -0.9957,  1.2424, -0.9467],
         [-0.3861, -1.3372,  1.7499, -0.0021, -0.0246],
         [-0.1120,  1.8660, -0.0859, -1.0653, -0.6028]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 4, 5])


In [20]:
out = encoder_layer(src, src_mask=attn_mask, src_key_padding_mask=padding_mask)
print(out)

tensor([[[-1.1862,  0.8848,  0.5392,  0.9972, -1.2349],
         [-0.3258,  1.0354, -1.1770,  1.3124, -0.8450],
         [-0.0124, -1.5519,  1.5700, -0.2541,  0.2485],
         [ 0.1095,  1.8375, -0.3342, -1.1221, -0.4907]]],
       grad_fn=<NativeLayerNormBackward0>)


# 1.3 torch.nn.TransformerEncoder

~~~
class TransformerEncoder(Module):
    """
    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

~~~

In [21]:
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

~~~
forward(src, mask=None, src_key_padding_mask=None, is_causal=None)
~~~

In [22]:
# Call the module instance rather than .forward(): nn.Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
out = transformer_encoder(src)
print(out)

tensor([[[-0.8970,  1.1202,  0.1338,  1.0171, -1.3741],
         [-0.3626,  1.1703, -1.1551,  1.1995, -0.8520],
         [ 0.2637, -1.7185,  1.3872,  0.1926, -0.1250],
         [-0.1441,  1.9542, -0.4394, -0.5676, -0.8032]]],
       grad_fn=<NativeLayerNormBackward0>)


In [23]:
# Call the module instance rather than .forward(): nn.Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
# `mask` is forwarded to every layer as src_mask (the attention mask).
output = transformer_encoder(src, mask=attn_mask, src_key_padding_mask=padding_mask)
print(output)

tensor([[[-0.5272,  1.0591,  0.0050,  1.0461, -1.5830],
         [ 0.4179,  0.6345, -1.7489,  1.0971, -0.4006],
         [ 0.6252, -1.7875,  1.1223, -0.2571,  0.2970],
         [ 0.7973,  1.4865, -1.0211, -1.0280, -0.2348]]],
       grad_fn=<NativeLayerNormBackward0>)


## 1.4 torch.nn.TransformerDecoderLayer

![](https://pytorch.org/tutorials/_images/seq2seq.png)

~~~
class TransformerDecoderLayer(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", layer_norm_eps=1e-5):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
    
    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
        """
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt
~~~

In [24]:
decoder_layer = nn.TransformerDecoderLayer(d_model=5, nhead=1, batch_first=True)

In [25]:
# batch_size = 1
# sequence_length = 4
# embedding_size = 5
memory = torch.rand(1, 4, 5)
print(memory)
print(memory.shape)

tensor([[[0.5490, 0.0093, 0.4613, 0.8061, 0.5491],
         [0.7317, 0.2648, 0.1612, 0.0341, 0.6122],
         [0.2431, 0.7019, 0.5981, 0.7685, 0.7452],
         [0.6860, 0.6298, 0.2847, 0.3592, 0.7690]]])
torch.Size([1, 4, 5])


In [26]:
# batch_size = 1
# sequence_length = 6
# embedding_size = 5
tgt = torch.rand(1, 6, 5)
print(tgt)
print(tgt.shape)

tensor([[[0.7859, 0.2923, 0.8150, 0.3474, 0.6866],
         [0.7447, 0.6915, 0.9891, 0.1155, 0.9180],
         [0.7315, 0.1757, 0.1024, 0.7563, 0.4602],
         [0.3719, 0.6723, 0.1913, 0.7401, 0.6185],
         [0.7635, 0.4613, 0.5866, 0.3904, 0.1410],
         [0.0569, 0.6474, 0.7203, 0.5985, 0.1828]]])
torch.Size([1, 6, 5])


In [27]:
# Call the module instance rather than .forward(): nn.Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
# The output keeps the target's sequence length (see the printed shape).
out = decoder_layer(tgt, memory)
print(out)
print(out.shape)

tensor([[[ 0.6976, -1.5439,  0.7686, -0.8375,  0.9152],
         [-0.0777, -0.1555,  0.1359, -1.5240,  1.6213],
         [ 0.0332, -0.9220, -1.1599,  1.6181,  0.4306],
         [-0.6357,  0.4516, -1.6471,  0.8808,  0.9504],
         [ 1.8324,  0.2090, -0.4004, -0.6065, -1.0346],
         [-1.0780,  1.2596,  0.6040,  0.4954, -1.2810]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 6, 5])


### A tgt_mask and tgt_key_padding_mask

The roles of tgt_mask and tgt_key_padding_mask are the same as those of src_mask and src_key_padding_mask in the encoder.

In [28]:
tgt_mask = generate_subsequent_mask(6,6)
print(tgt_mask.shape)
print(tgt_mask)

torch.Size([6, 6])
tensor([[0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0.]])


Assume the last three tokens are `<PAD>`.

In [29]:
tgt_padding_mask = torch.tensor([
    [False, False, False, True,True,True]])
print(tgt_padding_mask.shape)
print(tgt_padding_mask)

torch.Size([1, 6])
tensor([[False, False, False,  True,  True,  True]])


In [30]:
# Call the module instance rather than .forward(): nn.Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
out = decoder_layer(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_padding_mask)
print(out)
print(out.shape)

tensor([[[ 0.4119, -1.5630,  0.3293, -0.5742,  1.3961],
         [ 0.0660,  0.0369,  0.0420, -1.6507,  1.5058],
         [ 0.5523, -1.0356, -1.2190,  1.4375,  0.2648],
         [-0.7665, -0.0370, -1.4042,  1.1377,  1.0700],
         [ 1.8948,  0.0544, -0.3442, -0.7892, -0.8157],
         [-1.2370,  0.9065,  0.5855,  0.9385, -1.1935]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 6, 5])


### B **pay attention to the role of memory_mask and memory_key_padding_mask in the model**

~~~
self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)
~~~

In the context of multi-head attention, memory_mask and memory_key_padding_mask mask out certain positions in the attention weights so that they do not contribute to the output.

In [31]:
multihead_attn = torch.nn.MultiheadAttention(embed_dim=5, num_heads=1, batch_first=True)

In [32]:
multihead_attn(tgt, memory, memory)

(tensor([[[-0.0215, -0.2030, -0.0060, -0.1130,  0.0062],
          [-0.0240, -0.2012, -0.0025, -0.1108,  0.0066],
          [-0.0192, -0.2047, -0.0097, -0.1152,  0.0058],
          [-0.0222, -0.2025, -0.0055, -0.1125,  0.0062],
          [-0.0211, -0.2035, -0.0079, -0.1137,  0.0058],
          [-0.0225, -0.2023, -0.0057, -0.1123,  0.0060]]],
        grad_fn=<TransposeBackward0>),
 tensor([[[0.2664, 0.2103, 0.2888, 0.2344],
          [0.2749, 0.2026, 0.2936, 0.2289],
          [0.2608, 0.2196, 0.2799, 0.2397],
          [0.2707, 0.2113, 0.2858, 0.2322],
          [0.2708, 0.2182, 0.2752, 0.2358],
          [0.2739, 0.2146, 0.2813, 0.2302]]], grad_fn=<DivBackward0>))

In [33]:
padding_mask = torch.tensor([
    [False, False, True,True]])
print(padding_mask.shape)

torch.Size([1, 4])


In [34]:
memory_mask = generate_subsequent_mask(6,4)
print(memory_mask)
print(memory_mask.shape)

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
torch.Size([6, 4])


In [35]:
multihead_attn(tgt, memory, memory, attn_mask=memory_mask, key_padding_mask=padding_mask)

(tensor([[[-0.1731, -0.1094,  0.1429,  0.0015,  0.0132],
          [-0.0675, -0.1704,  0.0132, -0.0815, -0.0047],
          [-0.0593, -0.1752,  0.0031, -0.0879, -0.0061],
          [-0.0640, -0.1725,  0.0089, -0.0842, -0.0053],
          [-0.0620, -0.1736,  0.0065, -0.0858, -0.0056],
          [-0.0638, -0.1726,  0.0086, -0.0844, -0.0054]]],
        grad_fn=<TransposeBackward0>),
 tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5756, 0.4244, 0.0000, 0.0000],
          [0.5429, 0.4571, 0.0000, 0.0000],
          [0.5617, 0.4383, 0.0000, 0.0000],
          [0.5537, 0.4463, 0.0000, 0.0000],
          [0.5608, 0.4392, 0.0000, 0.0000]]], grad_fn=<DivBackward0>))

In practice, memory_mask is usually left as None.

In [36]:
# Call the module instance rather than .forward(): nn.Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
out = decoder_layer(
    tgt,
    memory,
    tgt_mask=tgt_mask,
    tgt_key_padding_mask=tgt_padding_mask,
    memory_mask=memory_mask,
    memory_key_padding_mask=padding_mask,
)
print(out)
print(out.shape)

tensor([[[ 0.3237, -1.3233,  0.5419, -0.9428,  1.4005],
         [ 0.2189,  0.1617,  0.3480, -1.8708,  1.1423],
         [ 0.6429, -0.9477, -1.4305,  1.1287,  0.6066],
         [-0.5365,  0.1947, -1.6440,  0.9944,  0.9914],
         [ 1.9187, -0.0513, -0.5537, -0.3866, -0.9272],
         [-1.5761,  0.8436,  0.7901,  0.7387, -0.7964]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 6, 5])


## 1.5 torch.nn.TransformerDecoder

~~~
class TransformerDecoder(Module):
    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output
~~~

## 1.6 nn.Transformer

~~~
class Transformer(Module):
    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: str = "relu", custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5) -> None:
        super(Transformer, self).__init__()

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps)
            encoder_norm = LayerNorm(d_model, eps=layer_norm_eps)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps)
            decoder_norm = LayerNorm(d_model, eps=layer_norm_eps)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead
        
    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).

        Shape:
            - src: :math:`(S, N, E)`.
            - tgt: :math:`(T, N, E)`.
            - src_mask: :math:`(S, S)`.
            - tgt_mask: :math:`(T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
            positions will be unchanged. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.

            - output: :math:`(T, N, E)`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is same as the input sequence
            (i.e. target) length of the decode.

            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number

        Examples:
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """

        if src.size(1) != tgt.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

~~~

In [37]:
transformer_model = nn.Transformer(d_model=5, nhead=1, num_encoder_layers=1,batch_first=True)

In [38]:
# Call the module instance rather than .forward(): nn.Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
output = transformer_model(src=src, tgt=tgt)
print(output)
print(output.shape)

tensor([[[-0.2965,  0.7936,  0.8327,  0.4980, -1.8278],
         [-1.2585,  0.3553,  1.5216,  0.3167, -0.9351],
         [-0.3392,  0.7469,  1.0419,  0.3218, -1.7714],
         [-1.0772,  0.3735,  1.2926,  0.6688, -1.2577],
         [-0.7090,  0.6794,  1.1446,  0.4681, -1.5831],
         [-1.3177,  0.5440,  1.1414,  0.7098, -1.0775]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 6, 5])


In [44]:
src_mask = generate_subsequent_mask(4,4)
tgt_mask = generate_subsequent_mask(6,6)
memory_mask = generate_subsequent_mask(6,4)

In [40]:
src_key_padding_mask = torch.tensor([[False, False, True,True]])
print(src_key_padding_mask)
print(src_key_padding_mask.shape)

tensor([[False, False,  True,  True]])
torch.Size([1, 4])


In [41]:
tgt_key_padding_mask = torch.tensor([
    [False, False, False, True,True,True]])
print(tgt_key_padding_mask.shape)
print(tgt_key_padding_mask)

torch.Size([1, 6])
tensor([[False, False, False,  True,  True,  True]])


In [42]:
memory_key_padding_mask = src_key_padding_mask
print(memory_key_padding_mask)
print(memory_key_padding_mask.shape)

tensor([[False, False,  True,  True]])
torch.Size([1, 4])


In [45]:
# Call the module instance rather than .forward(): nn.Module.__call__ runs any
# registered hooks, which a direct .forward() call silently skips.
# NOTE(review): the *_mask arguments here are float (-inf/0.0) masks while the
# *_key_padding_mask arguments are bool; recent PyTorch releases appear to warn
# about mixing mask dtypes -- confirm, and consider using one dtype for both.
output = transformer_model(
    src=src,
    tgt=tgt,
    src_mask=src_mask,
    tgt_mask=tgt_mask,
    memory_mask=memory_mask,
    src_key_padding_mask=src_key_padding_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask,
)
print(output)
print(output.shape)

tensor([[[-1.2855,  0.4926,  1.5218,  0.1471, -0.8760],
         [-1.0737,  0.8774,  1.4166, -0.2065, -1.0139],
         [-1.1779,  0.2661,  1.6428,  0.1709, -0.9019],
         [-1.1278,  0.4578,  1.5637,  0.1336, -1.0273],
         [-0.6210,  0.1721,  1.7180,  0.0088, -1.2779],
         [-0.2900, -0.3071,  1.9049, -0.2431, -1.0648]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 6, 5])


## appendix: a case study to show how to use mask and padding in translation in NLP

~~~
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)
~~~

How to generate the attention mask and padding mask matrices in this context:

~~~
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
~~~

1. Pay attention to memory_mask and memory_key_padding_mask in the training loop.
2. Pay attention to tgt_input and tgt_out in this context.

~~~
def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))
~~~

### answer

In the context of a translation task:
1. the memory_mask = None
2. memory_key_padding_mask = src_padding_mask