# Implement Transformer Model

I want to implement a Transformer model in PyTorch, following the architecture shown in the figure below.

<img src="transformer.png" alt="Transformer Image">
<!-- <div style="text-align: center;">
    <img src="transformer.png" alt="Transformer Image" style="width: 20%; height: auto;">
</div> -->

## Encoder

### Sinusoidal Position Embedding (absolute, not rotary/RoPE)

In [1]:
import copy
import math

import torch
import torch.nn.functional as F
from torch import nn


# Default hyper-parameters for this notebook:
#   batch_size - number of sequences per batch
#   max_len    - maximum length of a sequence
#   d_model    - embedding size (model width)
batch_size, max_len, d_model = 64, 1000, 512


class PositionEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    This is the absolute sinusoidal encoding (not rotary embedding / RoPE):

        pe[pos, 2i]   = sin(pos * 10000^(-2i / d_model))
        pe[pos, 2i+1] = cos(pos * 10000^(-2i / d_model))

    Args:
        d_model: embedding width; must be even so sin/cos columns pair up.
        max_len: number of positions precomputed; longer inputs are rejected.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model, max_len, dropout):
        super().__init__()

        assert d_model % 2 == 0, "d_model should be devisible by 2."

        position = torch.arange(max_len).reshape(-1, 1)  # [max_len, 1]
        # 10000^(-2i / d_model), computed via exp/log.
        inv_freq = torch.exp(-math.log(10000.0) * torch.arange(0, d_model, 2) / d_model)
        angles = position * inv_freq  # [max_len, d_model // 2]
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)  # even columns
        pe[:, 1::2] = torch.cos(angles)  # odd columns
        # A buffer (not a Parameter): follows .to()/.cuda() and is saved in
        # state_dict, but is not updated by the optimizer.
        self.register_buffer("pe", pe.unsqueeze(0))  # [1, max_len, d_model]
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return dropout(x + pe) for x of shape [batch_size, seq_len, d_model].

        Raises:
            ValueError: if seq_len exceeds the precomputed max_len.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            # Without this check the slice below is silently truncated and the
            # addition fails with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

In [2]:
# # test
# m = PositionEmbedding(d_model=d_model, max_len=max_len, dropout=0.5)
# x = torch.randn(10, 100, 512)
# m(x).shape

### Implement Multi-Head Attention

In [3]:
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, dropout: nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: [..., len_q, d_k]
        key: [..., len_k, d_k]
        value: [..., len_k, d_v]
        mask: None, or a tensor broadcastable to [..., len_q, len_k];
            positions where ``mask == 0`` receive (effectively) zero weight.
        dropout: optional module applied to the attention weights.

    Returns:
        A tuple (output [..., len_q, d_v], attention weights [..., len_q, len_k]).
    """
    q_d = query.size(-1)
    k_d = key.size(-1)
    assert q_d == k_d, "q_d should equal to k_d"

    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(k_d)
    if mask is not None:
        # masked_fill is out-of-place, so the result must be reassigned or the
        # mask has no effect. The dtype's most negative finite value is used
        # instead of a literal -1e9, which overflows in float16.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)

    return p_attn @ value, p_attn


class MultiAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, concatenate heads, project out.

    Args:
        d_model: model width; split evenly into ``h`` heads of width ``d_model // h``.
        h: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, h, dropout=0.1):
        super().__init__()
        assert d_model % h == 0, "d_model should be devisible by h"

        self.d_h = d_model // h  # per-head width
        self.h = h  # heads num
        # Four d_model -> d_model projections: indices 0-2 for query/key/value,
        # index 3 (the last) for the output projection.
        self.linearlist = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(4)]
        )
        self.p_attn = None  # attention weights from the most recent forward call
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        query: [batch_size, len_q, d_model]
        key:   [batch_size, len_k, d_model]
        value: [batch_size, len_k, d_model]
        mask:  None, or a tensor broadcastable to [batch_size, h, len_q, len_k]
               (entries equal to 0 are blocked).

        Returns: [batch_size, len_q, d_model]
        """
        batch_size = query.size(0)

        # zip stops after three pairs, so only the Q/K/V projections run here.
        # Each result is split into heads: [batch_size, h, seq_len, d_h].
        query, key, value = [
            linear(x).view(batch_size, -1, self.h, self.d_h).transpose(1, 2)
            for linear, x in zip(self.linearlist, (query, key, value))
        ]

        mat_attention, self.p_attn = attention(query, key, value, mask, self.dropout)
        # Merge heads: [batch_size, h, len_q, d_h] -> [batch_size, len_q, h * d_h].
        # h * d_h is this module's own width; it must not depend on any global d_model.
        mat_attention = (
            mat_attention.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.h * self.d_h)
        )

        return self.linearlist[-1](mat_attention)

In [4]:
# # test
# d_model, h = 12, 4
# x1 = torch.randn(3, 5, d_model)
# x2 = torch.randn(3, 5, d_model)
# x3 = torch.randn(3, 5, d_model)

# m = MultiAttention(d_model, h)
# r = m(x1, x1, x1)

# r.shape

### Implement the Residual Connection (Add & Norm)

In [5]:
class Residual(nn.Module):
    """Residual connection around a sublayer, in pre-norm form.

    Computes ``x + dropout(subnetwork(norm(x)))``. The Encoder and Decoder in
    this notebook apply their own final LayerNorm, which is what this
    pre-norm arrangement relies on.

    Args:
        d_model: size of the last dimension, normalized by LayerNorm.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, subnetwork):
        """Apply ``subnetwork`` to the normalized input and add it back onto ``x``.

        subnetwork: callable mapping a tensor to one of the same shape.
        """
        # The skip path stays untouched. x is added exactly once, and only the
        # sublayer output goes through dropout.
        return x + self.dropout(subnetwork(self.norm(x)))

In [6]:
# # test
# size = 4
# m = Residual(size, dropout=0.1)
# x = torch.randn(5, size)
# m(x, nn.Linear(size, size)).shape

### Implement a Feed-Forward Network

In [7]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Linear(d_model -> d_ff), then ReLU, then dropout, then Linear(d_ff -> d_model).
    """

    def __init__(self, d_model, d_ff, dropout: float = 0.1):
        super().__init__()
        # Creation order and attribute names are unchanged, so initialization
        # under a fixed seed and the state_dict keys stay the same.
        self.linear_a = nn.Linear(d_model, d_ff)  # expand
        self.linear_b = nn.Linear(d_ff, d_model)  # project back
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map [..., d_model] to [..., d_model] through the hidden d_ff layer."""
        hidden = F.relu(self.linear_a(x))
        hidden = self.dropout(hidden)
        return self.linear_b(hidden)

In [8]:
# # test
# d_model = 7
# m = FeedForward(d_model, 5, 0.1)
# x = torch.randn(6, d_model)
# m(x).shape

In [9]:
class EncoderLayer(nn.Module):
    """One encoder block: a self-attention sublayer, then a feed-forward sublayer.

    Each sublayer is wrapped in its own copy of ``res``.

    Args:
        d_model: model width (read by Encoder to size its final LayerNorm).
        attn: self-attention module, called as ``attn(q, k, v, mask)``.
        ffn: feed-forward module applied by the second sublayer.
        res: residual wrapper. It is deep-copied once per sublayer, so the two
            sublayers get independent LayerNorm parameters.
    """

    def __init__(
        self,
        d_model: int,
        attn: MultiAttention,
        ffn: FeedForward,
        res: Residual
    ):
        super().__init__()
        self.attn = attn
        self.ffn = ffn
        # Listing the same instance twice would make both sublayers share one
        # LayerNorm. The attribute name (spelling included) is kept for
        # state_dict compatibility.
        self.resduial = nn.ModuleList(
            [copy.deepcopy(res) for _ in range(2)]
        )
        self.d_model = d_model

    def forward(self, x, mask):
        """Run self-attention then feed-forward; ``mask`` goes to the attention call unchanged."""
        x = self.resduial[0](x, lambda y: self.attn(y, y, y, mask))
        return self.resduial[1](x, self.ffn)

In [10]:
# # test
# d_model = 128
# attn = MultiAttention(d_model=d_model, h=4, dropout=0.1)
# ffn = FeedForward(d_model=d_model, d_ff=256)
# res = Residual(d_model=d_model, dropout=0.1)
# m = EncoderLayer(d_model=d_model, attn=attn, ffn=ffn, res=res)

# x = torch.randn(3, 400, d_model) # [batch_size, seq_len, d_model]

# m(x, mask=None).shape

### Stack EncoderLayers into an Encoder

In [11]:
class Encoder(nn.Module):
    """Stack of ``N`` independent copies of ``encodelayer``, followed by a final LayerNorm.

    Args:
        encodelayer: template layer exposing ``d_model``; it is deep-copied N times.
        N: number of layers.
    """

    def __init__(self, encodelayer: EncoderLayer, N):
        super().__init__()
        # Deep copies give every layer its own parameters. Repeating one
        # instance would tie all N layers' weights, effectively applying a
        # single layer N times.
        self.encodelayers = nn.ModuleList(
            [copy.deepcopy(encodelayer) for _ in range(N)]
        )
        self.normlayer = nn.LayerNorm(encodelayer.d_model)

    def forward(self, x, mask):
        """Pass ``x`` through every layer with the same ``mask``, then normalize."""
        for layer in self.encodelayers:
            x = layer(x, mask)

        return self.normlayer(x)

In [12]:
# # test
# d_model = 128
# attn = MultiAttention(d_model=d_model, h=4, dropout=0.1)
# ffn = FeedForward(d_model=d_model, d_ff=256)
# res = Residual(d_model=d_model, dropout=0.1)
# m = Encoder(EncoderLayer(d_model=d_model, attn=attn, ffn=ffn, res=res), N=3)

# x = torch.randn(3, 400, d_model) # [batch_size, seq_len, d_model]
# m(x, mask=None).shape

## Decoder

### Implement a DecoderLayer

In [13]:
class DecoderLayer(nn.Module):
    """One decoder block with three sublayers, each wrapped in its own copy of ``res``.

    The sublayers are:
      1. self-attention over the target, using ``tgt_mask``;
      2. cross-attention from the target onto ``memory`` (the encoder output),
         using ``src_mask``;
      3. the feed-forward network.

    Args:
        d_model: model width (read by Decoder to size its final LayerNorm).
        attn: attention module. It is deep-copied twice, once for
            self-attention and once for cross-attention, so they do not share weights.
        ffn: feed-forward module.
        res: residual wrapper, deep-copied once per sublayer.
        dropout: stored as ``self.dropout``; not used in ``forward``.
    """

    def __init__(
        self,
        d_model: int,
        attn: MultiAttention,
        ffn: FeedForward,
        res: Residual,
        dropout=0.1
    ):
        super().__init__()
        # Index 0: self-attention; index 1: cross-attention.
        self.attnlist = nn.ModuleList(
            [copy.deepcopy(attn) for _ in range(2)]
        )
        self.ffn = ffn
        self.reslist = nn.ModuleList(
            [copy.deepcopy(res) for _ in range(3)]
        )
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

    def forward(self, x, memory, src_mask, tgt_mask):
        """Apply self-attention, cross-attention onto ``memory``, then feed-forward."""
        x = self.reslist[0](x, lambda y: self.attnlist[0](y, y, y, tgt_mask))
        # Cross-attention uses its own module (attnlist[1]), not the self-attention one.
        x = self.reslist[1](x, lambda y: self.attnlist[1](y, memory, memory, src_mask))
        x = self.reslist[2](x, self.ffn)
        return x

In [14]:
# # test
# d_model = 128
# attn = MultiAttention(d_model=d_model, h=4, dropout=0.1)
# ffn = FeedForward(d_model=d_model, d_ff=256)
# res = Residual(d_model=d_model, dropout=0.1)
# m = DecoderLayer(d_model=d_model, attn=attn, ffn=ffn, res=res)

# x = torch.randn(3, 400, d_model) # [batch_size, tgt_len, d_model]
# memory = torch.randn(3, 50, d_model) # encoder output: [batch_size, src_len, d_model]
# m(x, memory, src_mask=None, tgt_mask=None).shape

### Stack DecoderLayers into a Decoder

In [15]:
class Decoder(nn.Module):
    """Stack of ``N`` independent copies of ``decoderlayer``, followed by a final LayerNorm.

    Args:
        decoderlayer: template layer exposing ``d_model``; it is deep-copied N times.
        N: number of layers.
    """

    def __init__(self, decoderlayer: DecoderLayer, N):
        super().__init__()
        # Deep copies give every layer its own parameters. Repeating one
        # instance would tie all N layers' weights.
        self.decoderlayers = nn.ModuleList(
            [copy.deepcopy(decoderlayer) for _ in range(N)]
        )
        self.norm = nn.LayerNorm(decoderlayer.d_model)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run every layer against the same ``memory`` and masks, then normalize."""
        for decoderlayer in self.decoderlayers:
            x = decoderlayer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

### Generator Layer

Project the hidden states from `d_model` to the vocabulary size and return log-probabilities.

In [16]:
class Generator(nn.Module):
    """Output head: a linear map from d_model to vocab, followed by log-softmax."""

    def __init__(self, d_model: int, vocab: int):
        super().__init__()
        self.linear = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Turn [..., d_model] into log-probabilities of shape [..., vocab]."""
        logits = self.linear(x)
        return logits.log_softmax(dim=-1)

### Embeddings

In [17]:
class Embeddings(nn.Module):
    """Token-id lookup table whose output is scaled by sqrt(d_model)."""

    def __init__(self, d_model, vocab):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab, embedding_dim=d_model)
        self.d_model = d_model

    def forward(self, x):
        """Map integer ids of any shape to vectors of shape [*x.shape, d_model]."""
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

## Assembling all modules together

In [18]:
# Design note (English translation of the string below): when implementing this,
# a common question is whether __init__ should receive nn.Module objects or more
# concrete parameters such as d_model. After some experimentation, a good
# practice seems to be to consider whether the model structure is likely to be
# adjusted later. Simple, stable components such as Embedding or Linear need not
# be passed in and can be defined inside __init__. In general it is preferable
# to spell out the components a module uses in __init__, especially custom ones.
# Transformer implementations, however, often do not follow such a convention.
"""
在实现的时候往往遇到这样一个问题, 在__init__方法中应该传入nn.Module对象还是传入更加具体的参数比如d_model。
经过尝试, 我认为比较好的一个实践是考量后续对模型结构是否会进行调整。对于一些较为稳定且简单的结构, 比如Embedding, Linear可以不显式传入模型,
在__init__中定义, 一般情况下, 我更青睐于在__init__中写清楚会用的组件, 尤其是一些自定义的模型。
但是, Transformer很多时候都没有这种规范.
"""
class EncoderDecoder(nn.Module):
    """Encoder-decoder Transformer.

    ``forward`` returns decoder hidden states. It does NOT apply the generator;
    call ``self.generator(out)`` to get log-probabilities over the vocabulary.

    Args:
        d_model: model width.
        pos_embed: positional encoding applied after token embedding
            (the same module is used for source and target).
        encode: encoder stack, called as ``encode(src, mask=...)``.
        decode: decoder stack, called as ``decode(tgt, memory, src_mask=..., tgt_mask=...)``.
        gen: output head, stored as ``self.generator``.
        embed: token embedding used for BOTH source and target, so its
            vocabulary must cover the ids of both.
    """

    def __init__(
        self, d_model: int,
        pos_embed: PositionEmbedding,
        encode: Encoder,
        decode: Decoder,
        gen: Generator,
        embed: Embeddings,
    ):
        super().__init__()
        self.encode = encode
        self.decode = decode
        self.src_embed = embed
        self.tgt_embed = embed
        self.generator = gen
        self.d_model = d_model

        # pre-process: token embedding followed by positional encoding
        self.src_process = nn.Sequential(self.src_embed, pos_embed)
        self.tgt_process = nn.Sequential(self.tgt_embed, pos_embed)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src`` and decode ``tgt`` against it.

        Args:
            src: source token ids, [batch_size, src_len].
            tgt: target token ids, [batch_size, tgt_len].
            src_mask: optional mask for attention over source positions
                (entries equal to 0 are blocked). None attends everywhere.
            tgt_mask: optional mask for decoder self-attention. When None, a
                causal (lower-triangular) mask is built, so position i only
                attends to positions <= i.

        Returns:
            Decoder output, [batch_size, tgt_len, d_model].
        """
        if tgt_mask is None:
            tgt_len = tgt.size(1)
            tgt_mask = torch.tril(
                torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device)
            )
        src = self.src_process(src)
        tgt = self.tgt_process(tgt)
        memory = self.encode(src, mask=src_mask)
        return self.decode(tgt, memory, src_mask=src_mask, tgt_mask=tgt_mask)

In [19]:
# test: build a full-size model (N encoder layers and N decoder layers)
d_model = 512
N = 6
src_vocab = 100
tgt_vocab = 100
attn = MultiAttention(d_model, h=8)
ffn = FeedForward(d_model, d_ff=2048, dropout=0.1)
res = Residual(d_model, dropout=0.1)
transformer = EncoderDecoder(
    d_model=d_model,
    pos_embed=PositionEmbedding(d_model, max_len=5000, dropout=0.1),
    encode=Encoder(EncoderLayer(d_model, attn, ffn, res), N),
    # The decoder gets its own sublayer modules, so it never shares
    # parameters with the encoder.
    decode=Decoder(
        DecoderLayer(
            d_model,
            MultiAttention(d_model, h=8),
            FeedForward(d_model, d_ff=2048, dropout=0.1),
            Residual(d_model, dropout=0.1),
            dropout=0.1,
        ),
        N,
    ),
    gen=Generator(d_model, tgt_vocab),
    # One embedding table serves both source and target ids, so it must
    # cover both vocabularies.
    embed=Embeddings(d_model, max(src_vocab, tgt_vocab)),
)

In [20]:
# test: run one forward pass on random token ids and check the output shape.
# The same tensor is fed as src and tgt, so ids must be valid in both vocabularies.
x = torch.randint(0, min(src_vocab, tgt_vocab), (10, 1000))
with torch.no_grad():  # no autograd graph is needed for a shape check
    out = transformer(x, x)  # call the module (not .forward) so hooks run
out.shape

torch.Size([10, 1000, 512])