In [None]:
#输入定义
import torch
import torch.nn as nn
import math

# Hyperparameters shared by the cells below (max_len is read by Encoder/Decoder).
max_len, d_model = 512, 768   # longest supported sequence / model width
batch_size, n_head = 128, 8   # batch size / number of attention heads

In [None]:
# 1. Token embedding: integer token ids (batch_size, seq_len) --> (batch_size, seq_len, d_model)
class TokenEmbedding(nn.Embedding):
    """Learned lookup table mapping token ids to d_model-dimensional vectors.

    This class *is* an ``nn.Embedding``; the inherited ``forward`` performs the
    lookup using ``self.weight``.

    Args:
        vocab_size: number of distinct token ids (rows of the table).
        d_model: embedding dimension (columns of the table).
        padding_idx: optional id whose vector stays zero and receives no
            gradient; ``None`` (default) disables this.
    """

    def __init__(self, vocab_size, d_model, padding_idx=None):
        # The parent constructor needs the table size; calling it with no
        # arguments raises TypeError. No nested nn.Embedding is needed because
        # the inherited forward would never use it.
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)


In [None]:
# 2. Positional encoding: (batch_size, seq_len, d_model) --> (batch_size, seq_len, d_model)
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position signals to an embedded sequence.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: feature size of the inputs (odd values are supported).
        max_len: longest sequence the table covers.
        device: device the table is placed on.
    """

    def __init__(self, d_model, max_len, device):
        super(PositionalEncoding, self).__init__()
        encoding = torch.zeros(max_len, d_model)

        pos = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        _2i = torch.arange(0, d_model, step=2)  # (ceil(d_model / 2),)
        angles = pos / 10000 ** (_2i / d_model)  # (max_len, ceil(d_model / 2))

        encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A non-persistent buffer moves with the module on .to()/.cuda(), never
        # requires grad, and stays out of state_dict (matching the old attribute).
        self.register_buffer("encoding", encoding.to(device), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encoding for its first ``seq_len`` positions.

        Args:
            x: (batch_size, seq_len, d_model) embedded inputs.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.encoding.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.encoding.size(0)}"
            )
        return x + self.encoding[:seq_len, :]

In [None]:
# 3. Multi-head attention: q (batch_size, len_q, d_model), k/v (batch_size, len_k, d_model) --> (batch_size, len_q, d_model)
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``n_head`` parallel heads.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of heads; each head works in ``d_model // n_head`` dims.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, x):
        """(batch_size, length, d_model) --> (batch_size, n_head, length, d_k)."""
        batch_size, length, _ = x.shape
        return x.view(batch_size, length, self.n_head, self.d_k).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        The arguments are the *inputs* to the query/key/value projections (in
        encoder self-attention all three are the same tensor). ``k``/``v`` may
        have a different length from ``q``, as in encoder-decoder attention.

        Args:
            q: (batch_size, len_q, d_model).
            k, v: (batch_size, len_k, d_model).
            mask: optional tensor broadcastable to (batch_size, n_head, len_q, len_k);
                positions where it equals 0/False are excluded.

        Returns:
            (batch_size, len_q, d_model).
        """
        batch_size, len_q, _ = q.shape
        # Each of k and v is split with its own length, not the query length.
        q = self._split_heads(self.W_q(q))
        k = self._split_heads(self.W_k(k))
        v = self._split_heads(self.W_v(v))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (batch_size, n_head, len_q, len_k)
        if mask is not None:
            # Use the most negative finite value instead of -inf. A fully masked
            # query row (e.g. a padding position) would otherwise give NaN after
            # softmax, and 0 * NaN in attn @ v can spread it to other positions
            # in later layers.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = self.softmax(scores)
        context = torch.matmul(attn, v)  # (batch_size, n_head, len_q, d_k)

        # Merge heads: (batch_size, len_q, n_head, d_k) --> (batch_size, len_q, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, len_q, self.d_model)
        return self.combine(context)




In [None]:
# 4. Layer normalization over the last (feature) dimension
class LayerNorm(nn.Module):
    """Normalize each feature vector to zero mean / unit variance, then apply a
    learned per-feature scale (``gamma``) and shift (``beta``).

    Uses the biased (population) variance. ``eps`` is added before the square
    root for numerical stability.
    """

    def __init__(self, d_model, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension (size d_model)."""
        mu = x.mean(-1, keepdim=True)
        sigma_sq = x.var(-1, unbiased=False, keepdim=True)
        centered = x - mu
        normalized = centered / (sigma_sq + self.eps).sqrt()
        return self.gamma * normalized + self.beta

In [None]:
# 5. Position-wise feed-forward network: (batch_size, seq_len, d_model) --> (batch_size, seq_len, d_model)
class FeedForward(nn.Module):
    """Two linear maps with a ReLU between them, applied at every position.

    Computes ``linear2(dropout(relu(linear1(x))))``.

    Args:
        d_model: input/output feature size.
        d_ff: hidden size of the inner layer.
        drpout: dropout probability after the ReLU. The misspelled keyword is
            kept so that existing keyword callers still work.
    """

    def __init__(self, d_model, d_ff=2048, drpout=0.1):
        super(FeedForward, self).__init__()
        # linear1 is created before linear2 so that seeded initialisation is unchanged.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drpout)

    def forward(self, x):
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)

In [None]:
# 6. Transformer embedding layer: token ids (batch_size, seq_len) --> (batch_size, seq_len, d_model)
class TransformerEmbedding(nn.Module):
    """Token embedding plus sinusoidal positional encoding, followed by dropout.

    Args:
        max_len: longest sequence the positional table supports.
        vocab_size: number of token ids.
        d_model: embedding dimension.
        device: device on which the positional table is built.
        dropout: dropout probability applied to the combined embedding
            (default 0.1, the value previously hard-coded).
    """

    def __init__(self, max_len, vocab_size, d_model, device, dropout=0.1):
        super(TransformerEmbedding, self).__init__()
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.position_encoding = PositionalEncoding(d_model, max_len, device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        token_emb = self.token_embedding(x)
        # PositionalEncoding.forward returns its input plus the position table.
        # Apply it to the embedded tokens (not the raw ids) and do not add
        # token_emb a second time.
        emb = self.position_encoding(token_emb)
        return self.dropout(emb)

In [None]:
# 7. Encoder layer (post-norm residuals): (batch_size, seq_len, d_model) --> (batch_size, seq_len, d_model)
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, n_head, d_ff=2048, dropout=0.1):
        super(EncoderLayer, self).__init__()
        # Self-attention sub-layer.
        self.mha = MultiHeadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        # Feed-forward sub-layer.
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers; ``mask`` is passed unchanged to self-attention."""
        attended = self.norm1(x + self.dropout1(self.mha(x, x, x, mask)))
        return self.norm2(attended + self.dropout2(self.ffn(attended)))



In [None]:
# 8. Decoder layer (post-norm residuals): (batch_size, tgt_len, d_model) --> (batch_size, tgt_len, d_model)
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, optional encoder-decoder
    attention, then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, n_head, d_ff=2048, dropout=0.1):
        super(DecoderLayer, self).__init__()
        # Self-attention over the target sequence.
        self.self_mha = MultiHeadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        # Attention from target positions to the encoder output.
        self.enc_dec_mha = MultiHeadAttention(d_model, n_head)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)
        # Position-wise feed-forward.
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm3 = LayerNorm(d_model)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_out, t_mask, s_mask):
        """Apply the decoder block.

        Args:
            x: decoder hidden states.
            enc_out: encoder output, or ``None`` to skip cross-attention.
            t_mask: mask for target self-attention.
            s_mask: mask for target-to-source attention.
        """
        hidden = self.norm1(x + self.dropout1(self.self_mha(x, x, x, t_mask)))
        if enc_out is not None:
            cross = self.enc_dec_mha(hidden, enc_out, enc_out, s_mask)
            hidden = self.norm2(hidden + self.dropout2(cross))
        return self.norm3(hidden + self.dropout3(self.ffn(hidden)))


In [None]:
# 9. Transformer encoder: source token ids (batch_size, src_len) --> (batch_size, src_len, d_model)
class Encoder(nn.Module):
    """Embedding followed by a stack of ``num_layers`` encoder layers.

    The positional table size comes from the notebook-level ``max_len`` constant.
    """

    def __init__(self, num_layers, d_model, n_head, enc_vocab_size, d_ff=2048, dropout=0.1, device='cpu'):
        super(Encoder, self).__init__()
        self.embedding = TransformerEmbedding(max_len, enc_vocab_size, d_model, device)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(num_layers)
        )

    def forward(self, x, s_mask):
        """Embed ``x`` and pass it through every layer with the same ``s_mask``."""
        hidden = self.embedding(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, s_mask)
        return hidden
    
# 10. Transformer decoder: target token ids (batch_size, tgt_len) --> logits (batch_size, tgt_len, dec_vocab_size)
class Decoder(nn.Module):
    """Embedding, a stack of ``num_layers`` decoder layers, then a linear
    projection to ``dec_vocab_size`` outputs per position.

    The positional table size comes from the notebook-level ``max_len`` constant.
    """

    def __init__(self, num_layers, d_model, n_head, dec_vocab_size, d_ff=2048, dropout=0.1, device='cpu'):
        # Must name Decoder here: super(Encoder, self) raises TypeError because a
        # Decoder instance is not an Encoder.
        super(Decoder, self).__init__()
        self.embedding = TransformerEmbedding(max_len, dec_vocab_size, d_model, device)
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, dec_vocab_size)

    def forward(self, x, enc, t_mask, s_mask):
        """Decode ``x`` against encoder output ``enc``.

        Args:
            x: target token ids.
            enc: encoder output passed to every layer.
            t_mask: target self-attention mask.
            s_mask: target-to-source attention mask.

        Returns:
            Output of ``self.fc`` (one score per vocabulary entry per position).
        """
        hidden = self.embedding(x)
        for layer in self.layers:
            hidden = layer(hidden, enc, t_mask, s_mask)
        return self.fc(hidden)

In [None]:
# 11. Transformer model: encoder-decoder that builds its own attention masks
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Args:
        src_idx: padding token id in source sequences.
        trg_idx: padding token id in target sequences.
        enc_num_layers, dec_num_layers: depth of encoder / decoder stacks.
        d_model, n_head, d_ff, dropout: layer hyperparameters.
        enc_vocab_size, dec_vocab_size: source / target vocabulary sizes.
        device: device passed to the sub-modules.
    """

    def __init__(self, src_idx, trg_idx, enc_num_layers, dec_num_layers, d_model, n_head, enc_vocab_size, dec_vocab_size, d_ff=2048, dropout=0.1, device='cpu'):
        super(Transformer, self).__init__()
        self.encoder = Encoder(enc_num_layers, d_model, n_head, enc_vocab_size, d_ff, dropout, device)
        self.decoder = Decoder(dec_num_layers, d_model, n_head, dec_vocab_size, d_ff, dropout, device)
        self.src_idx = src_idx
        self.trg_idx = trg_idx
        self.device = device

    def make_casual_mask(self, q, k):
        """Return a (len_q, len_k) boolean lower-triangular mask: query i may see keys 0..i.

        The name keeps its original "casual" spelling for existing callers; the
        mask it builds is a causal mask.
        """
        len_q, len_k = q.size(1), k.size(1)
        # Build on the inputs' device so it can be AND-ed with the pad mask even
        # if the model was moved after construction.
        return torch.tril(torch.ones((len_q, len_k), device=q.device)).bool()

    def make_pad_mask(self, q, k, pad_idx_q, pad_idx_k):
        """Return a (batch, 1, len_q, len_k) mask that is True where neither the
        query token nor the key token is padding."""
        # q_mask: (batch, 1, len_q, 1)
        q_mask = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)
        # k_mask: (batch, 1, 1, len_k)
        k_mask = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)
        # Broadcasting expands both to (batch, 1, len_q, len_k) before the logical AND.
        return q_mask & k_mask

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src`` and decode ``tgt`` against it.

        ``src_mask``/``tgt_mask`` are accepted for backward compatibility but
        ignored. All masks are rebuilt from ``src_idx``/``trg_idx``.

        Returns:
            Whatever ``self.decoder`` returns for ``tgt``.
        """
        enc_self_mask = self.make_pad_mask(src, src, self.src_idx, self.src_idx)
        dec_self_mask = self.make_pad_mask(tgt, tgt, self.trg_idx, self.trg_idx) & self.make_casual_mask(tgt, tgt)
        cross_mask = self.make_pad_mask(tgt, src, self.trg_idx, self.src_idx)
        enc_out = self.encoder(src, enc_self_mask)
        # Decoder.forward takes (x, enc, t_mask, s_mask): the target self-attention
        # mask comes before the target-to-source mask.
        return self.decoder(tgt, enc_out, dec_self_mask, cross_mask)