In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn 
import torch.nn.functional as F
import math

# Embedding

## Token Embedding

In [225]:
class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_len, d_model):
        # vocab_len 是有多少个词，d_model 是词嵌入的维度
        super(TokenEmbedding, self).__init__(vocab_len, d_model, padding_idx=1)

## Position Embedding
![title](image/positional_encoding.jpg)

In [226]:
class PositionEmbedding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        # max_len 是词的最大长度
        super(PositionEmbedding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        self.encoding.requires_grad = False  # 不进行梯度更新

        pos = torch.arange(0, max_len).unsqueeze_(1).float()
        _2i = torch.arange(0, d_model, 2).float()   # 上标上的2i

        self.encoding[:, 0::2] = torch.sin(pos / 10000 ** (_2i / d_model))
        self.encoding[:, 1::2] = torch.cos(pos / 10000 ** (_2i / d_model))

    def forward(self, x):
        seq_len = x.shape[1]
        return self.encoding[:seq_len, :]

## Total Embedding

In [227]:
# 整合一下
class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, drop_prob):
        super(TransformerEmbedding, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionEmbedding(d_model, max_len)
        self.dropout = nn.Dropout(p = drop_prob)
    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        # print("token shape:", tok_emb.shape)
        # print("pos shape:", pos_emb.shape)
        return self.dropout(tok_emb + pos_emb)

# Layer Norm
![layer](image/layer_norm.jpg)

In [228]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-10):
        super(LayerNorm, self).__init__()
        self.d_model = d_model
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(self.d_model)) # nn.Parameter 让参数可学习
        self.beta = nn.Parameter(torch.zeros(self.d_model))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True)
        norm = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * norm + self.beta
        return out

# FFN
![layer](image/positionwise_feed_forward.jpg)


In [229]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ffn, drop_prob=0.1):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ffn)
        self.fc2 = nn.Linear(d_ffn, d_model)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x) # dropout 在前是为了减少神经元的连接，不然做完fc2连接完了再dropout就没有意义了
        x = self.fc2(x)
        return x

# Multihead Attention

In [230]:
class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super(MultiheadAttention, self).__init__()
        
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        
        self.n_head = n_head
        self.d_model = d_model
        self.n_d = d_model // n_head

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        batch_size, token_size, d_model = q.shape
        
        # Linear projections
        q = self.w_q(q).view(batch_size, -1, self.n_head, self.n_d)
        k = self.w_k(k).view(batch_size, -1, self.n_head, self.n_d)
        v = self.w_v(v).view(batch_size, -1, self.n_head, self.n_d)
        # shape is (batch_size, token_size, d_model)
        # print("q shape", q.shape)
        
        # Transpose for attention dot product
        q = q.transpose(1, 2)  # (batch_size, n_head, seq_len, n_d)
        k = k.transpose(1, 2)  # (batch_size, n_head, seq_len, n_d)
        v = v.transpose(1, 2)  # (batch_size, n_head, seq_len, n_d)
        # (batch_size, n_head, token_size, n_d)
        # print("q shape", q.shape)

        score = q @ k.permute(0, 1, 3, 2) / math.sqrt(self.n_d)
        # (batch_size, n_head, token_size, token_size)
        # print("score shape", score.shape)
        if mask is not None:
            # mask = torch.tril(torch.ones(token_size, token_size, dtype=bool))
            score = score.masked_fill(mask == 0, float('-inf'))
        print("after mask score shape", score.shape, "v shape", v.shape)
        score = self.softmax(score) @ v
        # (batch_size, n_head, token_size, n_d)
        # print("after softmax score shape", score.shape)
        score = score.permute(0, 2, 1, 3).contiguous().view(batch_size, token_size, d_model)
        out = self.w_combine(score)
        return out

In [231]:
# test multiheadattention
x = torch.randn((2, 3, 20))

MH = MultiheadAttention(20, 4)
y = MH(x, x, x)
print(y.shape)

after mask score shape torch.Size([2, 4, 3, 3]) v shape torch.Size([2, 4, 3, 5])
torch.Size([2, 3, 20])


# Encoder Layer
![enc-dec](image/enc_dec.jpg)

In [232]:
class EncodingLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, n_head, drop_prob=0.1):
        super(EncodingLayer, self).__init__()
        self.attention = MultiheadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.drop1 = nn.Dropout(drop_prob)

        self.ffn = FeedForward(d_model, ffn_hidden, drop_prob)
        self.norm2 = LayerNorm(d_model)
        self.drop2 = nn.Dropout(drop_prob)

    def forward(self, x, mask=None):
        _x = x  # resnet 提前保存一下
        x = self.attention(x, x, x, mask)
        x = self.drop1(x)
        x = self.norm1(x + _x)

        _x = x
        x = self.ffn(x)
        x = self.drop2(x)
        x = self.norm2(x + _x)
        return x

In [233]:
x = torch.randn((2, 3, 20))
EL = EncodingLayer(20, 10, 4, 0.2)
y = EL(x)
print(y)

after mask score shape torch.Size([2, 4, 3, 3]) v shape torch.Size([2, 4, 3, 5])
tensor([[[-1.4251, -0.0715,  0.2319, -0.9279,  0.7561,  0.5588, -1.2054,
           2.0926,  0.2089, -0.3476,  1.5660, -0.0668, -0.3974,  0.2857,
           0.6410, -1.9899,  0.2132, -0.5885, -0.6619,  1.1278],
         [ 0.7191,  0.4261, -0.2634,  0.8636, -0.3541,  0.2164, -0.9546,
           1.1670,  0.2811, -0.3747,  0.5911, -0.6887,  0.5172,  0.6444,
           0.0257, -3.2548,  0.2132, -0.2142, -0.9057,  1.3450],
         [ 0.4652,  0.5323, -1.5483, -0.7145, -0.9961,  0.9341, -0.8949,
           1.1572, -0.6694,  0.2622,  2.3892,  0.0673, -0.0305,  0.2442,
          -1.7625,  0.1880,  0.0925, -0.0085, -0.8794,  1.1721]],

        [[ 0.4705,  0.6534, -0.4684, -0.3649, -1.1773,  0.7850, -0.7311,
           0.6769, -0.5894,  0.8563, -0.4199,  0.7542, -0.5458,  1.1924,
           1.0694, -2.2038, -0.8464,  0.1349, -1.1128,  1.8669],
         [-2.2884,  0.3297,  0.3089, -0.3971, -0.3399,  0.1480, -1.8170,


# Decoder Layer

In [234]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_head, ffn_hidden, drop_prob=0.1) -> None:
        super(DecoderLayer, self).__init__()
        self.attention = MultiheadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.drop1 = nn.Dropout(drop_prob)

        self.cross_attention = MultiheadAttention(d_model, n_head)
        self.norm2 = LayerNorm(d_model)
        self.drop2 = nn.Dropout(drop_prob)

        self.ffn = FeedForward(d_model, ffn_hidden, drop_prob)
        self.norm3 = LayerNorm(d_model)
        self.drop3 = nn.Dropout(drop_prob)
    def forward(self, dec, enc, token_mask, dec_mask):
        _dec = dec
        dec = self.attention(dec, dec, dec, token_mask)
        dec = self.drop1(dec)
        dec = self.norm1(dec + _dec)

        if enc is not None:
            _dec = dec
            dec = self.cross_attention(dec, enc, enc, dec_mask)
            dec = self.drop2(dec)
            dec = self.norm2(dec + _dec)

        _dec = dec
        dec = self.ffn(dec)
        dec = self.drop3(dec)
        dec = self.norm3(dec + _dec)
        return dec

In [235]:
x = torch.randn((2, 3, 20))
DL = DecoderLayer(20, 4, 10, 0.2)
y = DL(x, x, None, None)
print(y)

after mask score shape torch.Size([2, 4, 3, 3]) v shape torch.Size([2, 4, 3, 5])
after mask score shape torch.Size([2, 4, 3, 3]) v shape torch.Size([2, 4, 3, 5])
tensor([[[ 0.0332, -0.1682,  0.0918, -2.6525,  0.7405, -0.6335,  0.5861,
          -0.0096, -1.0309,  0.6493, -0.1162,  1.0375,  1.1636,  0.4502,
          -1.9318,  0.7453,  0.3558,  1.1932, -0.7706,  0.2667],
         [ 0.3320, -0.6758, -0.0838, -1.3168,  2.1527,  0.1473, -1.2233,
           0.2695, -2.0551,  0.5421,  1.2470, -1.2797,  0.2074,  0.8121,
          -0.3774, -0.0571,  1.3036, -0.3322,  0.0838,  0.3036],
         [-0.9416,  0.2964, -0.7549,  0.5917, -0.0484, -0.4286,  0.1047,
           0.0933,  0.1726,  0.9945, -0.8170,  1.8015, -0.0493, -1.0429,
          -2.0494,  1.7465, -0.9998, -0.7180,  0.9404,  1.1083]],

        [[ 0.2906,  0.8368, -0.3392,  1.0948,  1.6970, -1.2326,  0.5624,
          -1.5848, -0.7892, -0.7397,  0.1584, -1.1686,  1.5829, -0.5894,
           0.4029,  1.1360, -0.8587,  0.5866,  0.2448, -1

# Transformer
![total](image/The_transformer_encoder_decoder_stack.png)

In [236]:
class Encoder(nn.Module):
    def __init__(
        self, 
        vocab_size,
        d_model, 
        max_len,
        ffn_hidden, 
        n_head, 
        drop_prob,
        n_layers
    ):
        super(Encoder, self).__init__()

        self.embedding = TransformerEmbedding(vocab_size, d_model, max_len, drop_prob)
        self.layers = nn.ModuleList(
            [
                EncodingLayer(d_model, ffn_hidden, n_head, drop_prob)
                for _ in range(n_layers)
            ]
        )
    
    def forward(self, x, enc_mask):
        x = self.embedding(x)
        print("encoder embedding shape", x.shape)
        for layer in self.layers:
            x = layer(x, enc_mask)
        return x

In [237]:
class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        max_len,
        ffn_hidden,
        n_head,
        drop_prob,
        n_layers
    ):
        super(Decoder, self).__init__()

        self.embedding = TransformerEmbedding(
            vocab_size, d_model, max_len, drop_prob
        )

        self.layers = nn.ModuleList(
            [
                DecoderLayer(d_model, n_head, ffn_hidden, drop_prob)
                for _ in range(n_layers)
            ]
        )

        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, dec, enc, token_mask, dec_mask):
        dec = self.embedding(dec)
        print("decoder embedding shape", dec.shape)
        for layer in self.layers:
            dec = layer(dec, enc, token_mask, dec_mask)

        dec = self.fc(dec)
        
        return dec


## Mask 机制

In [238]:
class Transformer(nn.Module):
    def __init__(
        self, 
        src_pad_idx, 
        tgt_pad_idx,
        enc_voc_size, 
        dec_voc_size,
        max_len,
        d_model,
        n_head,
        ffn_hidden,
        n_layers,
        drop_prob
    ):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            enc_voc_size, d_model, max_len, ffn_hidden, n_head, drop_prob, n_layers
        )
        self.decoder = Decoder(
            dec_voc_size, d_model, max_len, ffn_hidden, n_head, drop_prob, n_layers
        )

        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx

    def pad_mask(self, q, k, pad_idx_q, pad_idx_k):
        '''
        attention padding mask, 用于屏蔽padding的token, 使其不参与attention计算
        pad_idx_q: query的padding token
        pad_idx_k: key的padding token
        '''
        len_q = q.shape[1]  # q shape is (batch_size, q_token_size)
        len_k = k.shape[1]  # k shape is (batch_size, k_token_size)
        # attention 结束的shape 是 batch_size, q_token_size, k_token_size
        
        q = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)
        q = q.repeat(1, 1, 1, len_k)
        k = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)
        k = k.repeat(1, 1, len_q, 1)

        mask = q & k
        return mask
    
    def causal_mask(self, q, k):
        len_q, len_k = q.shape[1], k.shape[1]
        mask = (torch.tril(torch.ones(len_q, len_k, dtype=bool)))
        return mask
    
    def forward(self, src, tgt):
        src_mask = self.pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        tgt_mask_1 = self.pad_mask(tgt, tgt, self.tgt_pad_idx, self.tgt_pad_idx)
        tgt_mask_2 = self.causal_mask(tgt, tgt)
        tgt_mask = tgt_mask_1 * tgt_mask_2
        src_tgt_mask = self.pad_mask(tgt, src, self.tgt_pad_idx, self.src_pad_idx)
        # print("src_mask shape:", src_mask.shape)
        # print("tgt mask_1 shape:", tgt_mask_1.shape)
        # print("tgt mask_2 shape:", tgt_mask_2.shape)
        # print("tgt_mask shape:", tgt_mask.shape)
        # print("src_tgt_mask shape:", src_tgt_mask.shape)
        enc = self.encoder(src, src_mask)
        ouput = self.decoder(tgt, enc, tgt_mask, src_tgt_mask)
        return ouput

In [239]:
enc_voc_size = 6000
dec_voc_size = 8000
src_pad_idx = 1
tgt_pad_idx = 1
tgt_sos_idx = 2
batch_size = 32
max_len = 1024
d_model = 512
n_layers = 3
n_head = 2
ffn_hidden = 1024
drop_prob = 0.1
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"

In [240]:
model = Transformer(src_pad_idx=src_pad_idx,
                    tgt_pad_idx=tgt_pad_idx,
                    d_model=d_model,
                    enc_voc_size=enc_voc_size,
                    dec_voc_size=dec_voc_size,
                    max_len=max_len,
                    ffn_hidden=ffn_hidden,
                    n_head=n_head,
                    n_layers=n_layers,
                    drop_prob=drop_prob,
                    )

def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform(m.weight.data)
        
model.apply(initialize_weights)
src = torch.load("tensor_src.pt")
tgt = torch.load("tensor_tgt.pt")
print("src shape", src.shape, "tgt shape", tgt.shape)
res = model(src, tgt)
print("src shape", src.shape, "tgt shape", tgt.shape)

  nn.init.kaiming_uniform(m.weight.data)


src shape torch.Size([128, 36]) tgt shape torch.Size([128, 38])
src_mask shape: torch.Size([128, 1, 36, 36])
tgt mask_1 shape: torch.Size([128, 1, 38, 38])
tgt mask_2 shape: torch.Size([38, 38])
tgt_mask shape: torch.Size([128, 1, 38, 38])
src_tgt_mask shape: torch.Size([128, 1, 38, 36])
encoder embedding shape torch.Size([128, 36, 512])
after mask score shape torch.Size([128, 2, 36, 36]) v shape torch.Size([128, 2, 36, 256])
after mask score shape torch.Size([128, 2, 36, 36]) v shape torch.Size([128, 2, 36, 256])
after mask score shape torch.Size([128, 2, 36, 36]) v shape torch.Size([128, 2, 36, 256])
decoder embedding shape torch.Size([128, 38, 512])
after mask score shape torch.Size([128, 2, 38, 38]) v shape torch.Size([128, 2, 38, 256])
after mask score shape torch.Size([128, 2, 38, 36]) v shape torch.Size([128, 2, 36, 256])
after mask score shape torch.Size([128, 2, 38, 38]) v shape torch.Size([128, 2, 38, 256])
after mask score shape torch.Size([128, 2, 38, 36]) v shape torch.Siz

# PyTorch API

In [8]:
transformer = nn.Transformer(
    d_model=d_model,
    nhead=n_head,
    num_encoder_layers=n_layers,
    num_decoder_layers=n_layers,
    dim_feedforward=ffn_hidden,
    dropout=drop_prob
)

def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform(m.weight.data)
transformer.apply(initialize_weights)

x = torch.randn((32, 10, 512))
y = torch.randn((32, 10, 512))
res = transformer(x, y)
print(res.shape)

torch.Size([32, 10, 512])


  nn.init.kaiming_uniform(m.weight.data)
