# 导包

In [77]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
from torch.utils.data import Dataset,DataLoader

# Seed PyTorch's global random number generator so that runs are repeatable.
torch.manual_seed(42)

<torch._C.Generator at 0x7f8269d05650>

# positional encoding

In [78]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The encoding table is computed once for ``max_len`` positions and stored
    in a non-trainable buffer named ``pe`` of shape ``(1, max_len, d_model)``.
    Even feature columns hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the feature (last) dimension. Both even and odd
            values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine block is trimmed to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence position and whose
                last dimension has size ``d_model``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {table_len}"
            )
        return x + self.pe[:, :seq_len, :]

# Scaled Dot-Product Attention

In [79]:
class ScaledDotProductAttention(nn.Module):
    """Attention computed as ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query, key, value: Tensors whose last dimension is the per-head
                feature size and whose second-to-last dimension is the
                sequence position.
            mask: Optional tensor; score positions where ``mask == 0`` are
                replaced by ``-1e9`` before the softmax.

        Returns:
            A tuple ``(output, attn_weights)``; the weights are returned
            after dropout has been applied.
        """
        scale = math.sqrt(query.size(-1))
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            # A large negative score becomes (numerically) zero weight.
            logits = logits.masked_fill(mask == 0, -1e9)

        weights = self.dropout(torch.softmax(logits, dim=-1))
        return torch.matmul(weights, value), weights

# Multi-Head Attention

In [80]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``ScaledDotProductAttention``.

    Args:
        d_model: Feature size of the inputs and of the output.
        n_heads: Number of heads; must divide ``d_model`` evenly.
        dropout: Dropout probability handed to the attention module.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model%n_heads==0,"d_model必须能被n_heads整除"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # feature size of a single head

        # Input projections for query/key/value and the output projection.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(dropout)

    def _split_heads(self, projected, batch_size):
        # (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
        return projected.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, and merge the heads.

        Args:
            query, key, value: Tensors with the batch in dimension 0 and
                ``d_model`` features in the last dimension.
            mask: Optional mask forwarded unchanged to the attention module.

        Returns:
            Tensor reshaped to ``(batch, -1, d_model)`` and passed through
            the output projection.
        """
        batch_size = query.size(0)

        heads_q = self._split_heads(self.w_q(query), batch_size)
        heads_k = self._split_heads(self.w_k(key), batch_size)
        heads_v = self._split_heads(self.w_v(value), batch_size)

        context, _ = self.attention(heads_q, heads_k, heads_v, mask)

        # (batch, n_heads, seq, d_k) -> (batch, seq, d_model)
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.w_o(merged)

# Position-wise Feed-Forward

In [81]:
class PositionwiseForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        hidden = self.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

# Encoder Layer

In [82]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each wrapped in
    dropout, a residual connection, and a LayerNorm applied after the sum.

    Args:
        d_model: Feature size of the layer input and output.
        n_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability used throughout the layer.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is passed to self-attention."""
        # Sub-layer 1: self-attention (query, key and value are all ``x``).
        attended = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: position-wise feed-forward network.
        transformed = self.feed_forward(x)
        return self.norm2(x + self.dropout(transformed))


# Decoder Layer

In [83]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention over the encoder
    output, then feed-forward. Each sub-layer is followed by dropout, a
    residual connection, and a LayerNorm applied after the sum.

    Args:
        d_model: Feature size of the layer input and output.
        n_heads: Number of attention heads in both attention sub-layers.
        d_ff: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability used throughout the layer.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        # Dropout on the sub-layer output, residual sum, then LayerNorm.
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Run the three sub-layers.

        Args:
            x: Decoder-side input to this layer.
            enc_output: Encoder output used as key and value in cross-attention.
            src_mask: Mask applied in cross-attention.
            tgt_mask: Mask applied in self-attention.
        """
        self_ctx = self.self_attn(x, x, x, tgt_mask)
        x = self._add_and_norm(self.norm1, x, self_ctx)

        cross_ctx = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self._add_and_norm(self.norm2, x, cross_ctx)

        ff_out = self.feed_forward(x)
        return self._add_and_norm(self.norm3, x, ff_out)


# 编码器&解码器

In [84]:
class Encoder(nn.Module):
    """Token embedding, positional encoding, and a stack of encoder layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding / hidden feature size.
        n_layers: Number of ``EncoderLayer`` blocks.
        n_heads: Number of attention heads per layer.
        d_ff: Hidden size of each layer's feed-forward network.
        max_len: Number of positions precomputed by the positional encoding.
        dropout: Dropout probability.
    """

    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(EncoderLayer(d_model, n_heads, d_ff, dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode token ids ``x``; ``mask`` is handed to every layer."""
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        scaled = self.embedding(x) * math.sqrt(self.d_model)
        hidden = self.dropout(self.pos_encoding(scaled))
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden
    
class Decoder(nn.Module):
    """Token embedding, positional encoding, and a stack of decoder layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding / hidden feature size.
        n_layers: Number of ``DecoderLayer`` blocks.
        n_heads: Number of attention heads per layer.
        d_ff: Hidden size of each layer's feed-forward network.
        max_len: Number of positions precomputed by the positional encoding.
        dropout: Dropout probability.
    """

    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(DecoderLayer(d_model, n_heads, d_ff, dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask=None, memory_mask=None):
        """Decode token ids ``x`` against the encoder output ``memory``.

        Args:
            x: Target token ids, [batch_size, tgt_seq_len].
            memory: Encoder output, [batch_size, src_seq_len, d_model].
            tgt_mask: Third positional argument; see the note below.
            memory_mask: Fourth positional argument; see the note below.

        Returns:
            Tensor of shape [batch_size, tgt_seq_len, d_model].

        NOTE(review): the two masks are forwarded positionally to
        ``DecoderLayer.forward(x, enc_output, src_mask, tgt_mask)``. The value
        received here as ``tgt_mask`` therefore masks cross-attention, and
        ``memory_mask`` masks self-attention. Callers that pass
        ``(src_mask, tgt_mask)`` positionally get the conventional behaviour;
        callers using these keyword names literally get the masks swapped.
        """
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        scaled = self.embedding(x) * math.sqrt(self.d_model)
        hidden = self.dropout(self.pos_encoding(scaled))

        for block in self.layers:
            hidden = block(hidden, memory, tgt_mask, memory_mask)

        return hidden


# 完整的Transformer模型

In [85]:
class Transformer(nn.Module):
    """Encoder-decoder model with a final linear projection to the target vocabulary.

    Args:
        src_voacb_size: Source vocabulary size (encoder embedding rows).
        tgt_vocab_size: Target vocabulary size (decoder embedding rows and
            width of the output projection).
        d_model: Hidden feature size shared by encoder and decoder.
        n_layers: Number of layers in each of the encoder and decoder.
        n_heads: Number of attention heads per layer.
        d_ff: Hidden size of the feed-forward networks.
        max_len: Number of positions precomputed by the positional encodings.
        dropout: Dropout probability.
    """

    def __init__(self, src_voacb_size, tgt_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout=0.1):
        super().__init__()
        shared = (d_model, n_layers, n_heads, d_ff, max_len, dropout)
        self.encoder = Encoder(src_voacb_size, *shared)
        self.decoder = Decoder(tgt_vocab_size, *shared)
        # Attribute name kept as-is: it is part of the module's state_dict keys.
        self.liner = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Return per-position scores over the target vocabulary.

        The encoder output and both masks are passed positionally to the
        decoder, in the order ``(src_mask, tgt_mask)``.
        """
        memory = self.encoder(src, src_mask)
        decoded = self.decoder(tgt, memory, src_mask, tgt_mask)
        return self.liner(decoded)

# 掩码的生成

In [86]:
def create_mask(src, tgt, pad_idx):
    """Build the boolean attention masks for a source/target batch.

    Args:
        src: Source token ids, indexed as [batch_size, src_len].
        tgt: Target token ids, indexed as [batch_size, tgt_len].
        pad_idx: Token id that marks padding.

    Returns:
        ``(src_mask, tgt_mask)`` where ``src_mask`` has shape
        [batch_size, 1, 1, src_len] and ``tgt_mask`` has shape
        [batch_size, 1, tgt_len, tgt_len]. ``True`` marks a position that
        is kept.
    """
    # Source padding mask: [batch_size, 1, 1, src_len]
    src_mask = src.ne(pad_idx).unsqueeze(1).unsqueeze(2)

    # Target padding mask: [batch_size, 1, tgt_len, 1]
    tgt_keep = tgt.ne(pad_idx).unsqueeze(1).unsqueeze(3)

    # Lower-triangular (no look-ahead) mask: [1, tgt_len, tgt_len]
    steps = tgt.size(1)
    causal = torch.ones((steps, steps), device=tgt.device).tril().bool().unsqueeze(0)

    # Combined target mask: [batch_size, 1, tgt_len, tgt_len]
    return src_mask, tgt_keep & causal

# 数据处理与训练流程

In [87]:
def train_transformer(model,dataloder,optimizer,criterion,pad_idx,device,n_epochs):
    """Train ``model`` with teacher forcing and print the mean loss of each epoch.

    Args:
        model: Module called as ``model(src, tgt_input, src_mask, tgt_mask)``
            whose output's last dimension holds the per-class scores.
        dataloder: Iterable yielding ``(src, tgt)`` batches of token ids,
            indexed as [batch, seq_len].
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss called with scores flattened to (N, classes) and
            targets flattened to (N,).
        pad_idx: Padding token id, used to build the attention masks.
        device: Device the batches are moved to.
        n_epochs: Number of passes over ``dataloder``.
    """
    model.train()
    for epoch in range(n_epochs):
        total_loss = 0.0
        n_batches = 0
        for src, tgt in dataloder:
            src, tgt = src.to(device), tgt.to(device)

            # Teacher forcing: the decoder reads tokens [0, T-1) and is
            # trained to predict tokens [1, T).
            tgt_input = tgt[:, :-1]
            tgt_expected = tgt[:, 1:]

            # Masks are built from the decoder input, not the full target, so
            # their length matches the attention scores. They are created
            # from tensors already on ``device``.
            src_mask, tgt_mask = create_mask(src, tgt_input, pad_idx)

            optimizer.zero_grad()
            output = model(src, tgt_input, src_mask, tgt_mask)

            loss = criterion(
                output.contiguous().view(-1, output.size(-1)),
                tgt_expected.contiguous().view(-1)
            )

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        # Report once per epoch; max() guards against an empty loader.
        avg_loss = total_loss / max(n_batches, 1)
        print(f'Epoch{epoch+1}/{n_epochs},Average Loss:{avg_loss:.4f}')