In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import random
import numpy as np

# Seed every RNG this notebook uses (PyTorch, NumPy, Python's `random`) from one constant
# so a Restart & Run All reproduces the same results.
SEED = 12046
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [2]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    For position ``pos`` and even column index ``2j``::

        PE[pos, 2j]     = sin(pos / 10000 ** (2j / d_model))
        PE[pos, 2j + 1] = cos(pos / 10000 ** (2j / d_model))

    Args:
        d_model: embedding width. Odd widths are supported; the last column is then a sine.
        max_len: number of positions precomputed (longest sequence that can be encoded).
        dropout: dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Wavelength divisor 10000^(2j / d_model) for every even column index 2j.
        div_term = 10000.0 ** (torch.arange(0, d_model, 2, dtype=torch.float) / d_model)
        # Divide (not multiply): multiplying yields enormous frequencies and aliased encodings.
        angles = position / div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer (not a parameter): not trained, but follows .to(device).
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Add position encodings to ``x``.

        Args:
            x: tensor of shape (batch_size, seq_len, d_model) with seq_len <= max_len.

        Returns:
            Tensor of the same shape, after dropout.
        """
        x = x + self.pe[:, : x.size(1)]  # broadcast over the batch dimension
        return self.dropout(x)

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected with separate linear layers. The model
    dimension is split into ``num_heads`` heads of width ``d_k = d_model // num_heads``,
    all heads attend in parallel, and the heads are concatenated and passed through
    an output projection.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head width

        # Input projections; each packs all heads into one d_model x d_model layer.
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are concatenated.
        self.wo = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute softmax(q k^T / sqrt(d_k)) v for every head at once.

        Args:
            q: (batch_size, num_heads, q_len, d_k).
            k, v: (batch_size, num_heads, kv_len, d_k).
            mask: optional tensor broadcastable to (batch_size, num_heads, q_len, kv_len);
                positions where ``mask == 0`` receive no attention.

        Returns:
            outputs: (batch_size, num_heads, q_len, d_k).
            attn_weights: (batch_size, num_heads, q_len, kv_len).
        """
        # Keeping heads as a leading batch-like dim lets a single matmul score every head;
        # the last two dims line up for the (q_len x d_k) @ (d_k x kv_len) product.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Most negative finite value of the scores' dtype. A literal -1e9 overflows
            # (and masked_fill raises) when scores are float16, e.g. under autocast.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        outputs = torch.matmul(attn_weights, v)

        return outputs, attn_weights

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_k)."""
        # d_model -> num_heads * d_k
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: queries, (batch_size, q_len, d_model).
            k, v: keys and values, (batch_size, kv_len, d_model).
            mask: optional, broadcastable to (batch_size, num_heads, q_len, kv_len);
                0/False entries are masked out.

        Returns:
            outputs: (batch_size, q_len, d_model).
            attn_weights: (batch_size, num_heads, q_len, kv_len).
        """
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        batch_size = q.size(0)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        outputs, attn_weights = self.scaled_dot_product_attention(q, k, v, mask)

        # Merge heads: (batch_size, num_heads, q_len, d_k) -> (batch_size, q_len, d_model).
        # contiguous() is required because view() cannot follow a transpose directly.
        outputs = outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        outputs = self.wo(outputs)
        return outputs, attn_weights


In [4]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP on the last dimension: ``fc2(dropout(relu(fc1(x))))``.

    Args:
        d_model: input and output width.
        d_ff: hidden width.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Expand to d_ff, then project back to d_model. fc1 is built before fc2 so
        # parameter initialisation draws from the RNG in a fixed order.
        self.fc1, self.fc2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(self.dropout(hidden))

In [5]:
class Encoder(nn.Module):
    """One post-LN Transformer encoder layer: self-attention, then a feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: model width.
        num_heads: number of attention heads (must divide ``d_model``).
        d_ff: hidden width of the position-wise feed-forward network.
        dropout: dropout probability for the residual branches and inside the FFN.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-layers are built in this order so parameter init consumes the RNG identically.
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)

        self.layer_norm1, self.layer_norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the layer.

        Args:
            x: input of shape (batch_size, seq_len, d_model).
            mask: optional attention mask forwarded to self-attention; entries equal
                to 0 are not attended to.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended, _ = self.self_attn(x, x, x, mask)
        hidden = self.layer_norm1(x + self.dropout1(attended))  # residual + norm
        return self.layer_norm2(hidden + self.dropout2(self.ffn(hidden)))

In [6]:
class Decoder(nn.Module):
    """One post-LN Transformer decoder layer.

    The layer runs three sub-layers in order: masked self-attention over the target,
    cross-attention to the encoder output, and a feed-forward network. Each is wrapped
    as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: model width.
        num_heads: number of attention heads (must divide ``d_model``).
        d_ff: hidden width of the position-wise feed-forward network.
        dropout: dropout probability for the residual branches and inside the FFN.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)

        # LayerNorm
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Apply the layer.

        Args:
            x: target-side input, (batch_size, tgt_len, d_model).
            enc_output: encoder output, (batch_size, src_len, d_model).
            src_mask: mask over source keys for cross-attention, e.g.
                (batch_size, 1, 1, src_len) as built by ``Transformer.generate_mask``.
            tgt_mask: mask for target self-attention (padding + future positions), e.g.
                (batch_size, 1, tgt_len, tgt_len) as built by ``Transformer.generate_mask``.

        Returns:
            Tensor of shape (batch_size, tgt_len, d_model).
        """
        # Masked self-attention: queries and keys are both target positions, so the
        # causal/padding target mask applies here (previously src_mask was passed,
        # which let tokens see the future and broke when src_len != tgt_len).
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.layer_norm1(x + self.dropout1(attn_output))

        # Cross-attention: keys are encoder (source) positions, so mask source padding.
        cross_attn_output, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.layer_norm2(x + self.dropout2(cross_attn_output))

        # Feed Forward Network
        ffn_output = self.ffn(x)
        x = self.layer_norm3(x + self.dropout3(ffn_output))

        return x

In [8]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing per-position target-vocabulary logits.

    Args:
        src_size: source vocabulary size (rows of the source embedding table).
        tgt_size: target vocabulary size (embedding rows and output logits width).
        d_model: model width.
        num_heads: attention heads per layer.
        d_ff: hidden width of each feed-forward network.
        num_layers: number of encoder layers and of decoder layers.
        max_len: longest sequence the positional encoding supports.
        dropout: dropout probability used throughout.
    """

    def __init__(self, src_size, tgt_size, d_model=512, num_heads=8, d_ff=2048, num_layers=6, max_len=5000, dropout=0.1):
        super().__init__()
        # embedding layers
        self.src_emb = nn.Embedding(src_size, d_model)
        self.tgt_emb = nn.Embedding(tgt_size, d_model)

        # positional encoding (shared by source and target; it has no parameters)
        self.pos_emb = PositionalEncoding(d_model, max_len, dropout)

        # encoder and decoder layers
        self.encoder_layers = nn.ModuleList([
            Encoder(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            Decoder(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # projection to target-vocabulary logits (no softmax applied)
        self.fc_out = nn.Linear(d_model, tgt_size)

        self.d_model = d_model

    def generate_mask(self, src, tgt, pad_seq=0):
        """Build boolean attention masks (True = may attend, False = masked).

        Args:
            src: source token ids, (batch_size, src_len).
            tgt: target token ids, (batch_size, tgt_len).
            pad_seq: token id treated as padding.

        Returns:
            src_mask: (batch_size, 1, 1, src_len); hides padded source keys.
            tgt_mask: (batch_size, 1, tgt_len, tgt_len); hides padded target keys and
                future positions.
        """
        src_mask = (src != pad_seq).unsqueeze(1).unsqueeze(2)
        tgt_pad_mask = (tgt != pad_seq).unsqueeze(1).unsqueeze(2)

        # Lower-triangular causal mask. It is created on tgt's device; a CPU-only mask
        # would make the `&` below fail for CUDA inputs.
        tgt_len = tgt.size(1)
        future_mask = torch.tril(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device)
        )
        tgt_mask = tgt_pad_mask & future_mask  # broadcasts to (batch_size, 1, tgt_len, tgt_len)

        return src_mask, tgt_mask

    def encode(self, src, src_mask=None):
        """Embed ``src`` (scaled by sqrt(d_model)), add positions, run the encoder stack.

        Returns the encoder output, (batch_size, src_len, d_model).
        """
        memory = self.pos_emb(self.src_emb(src) * math.sqrt(self.d_model))
        for enc_layer in self.encoder_layers:
            memory = enc_layer(memory, src_mask)
        return memory

    def decode(self, tgt, memory, src_mask=None, tgt_mask=None):
        """Embed ``tgt``, add positions, and run the decoder stack against ``memory``.

        Returns decoder hidden states, (batch_size, tgt_len, d_model). Exposed
        separately so autoregressive inference can reuse one ``encode`` result.
        """
        out = self.pos_emb(self.tgt_emb(tgt) * math.sqrt(self.d_model))
        for dec_layer in self.decoder_layers:
            out = dec_layer(out, memory, src_mask, tgt_mask)
        return out

    def forward(self, src, tgt):
        """Return target logits of shape (batch_size, tgt_len, tgt_size)."""
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        memory = self.encode(src, src_mask)
        return self.fc_out(self.decode(tgt, memory, src_mask, tgt_mask))