In [40]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import math

In [41]:
class PositionalEncoding(nn.Module):

    def __init__(self, emb_size, dropout, maxlen=5000):
        super().__init__()
        # 行缩放指数值
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        # 位置编码索引 (5000,1)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        # 编码矩阵 (5000, emb_size)
        pos_embdding = torch.zeros((maxlen, emb_size))
        pos_embdding[:, 0::2] = torch.sin(pos * den)
        pos_embdding[:, 1::2] = torch.cos(pos * den)
        # 添加和batch对应维度 (1, 5000, emb_size)
        pos_embdding = pos_embdding.unsqueeze(-2)
        # dropout
        self.dropout = nn.Dropout(dropout)
        # 注册当前矩阵不参与参数更新
        self.register_buffer('pos_embedding', pos_embdding)

    def forward(self, token_embdding):
        token_len = token_embdding.size(1)  # token长度
        # (1, token_len, emb_size)
        add_emb = self.pos_embedding[:token_len, :] + token_embdding
        return self.dropout(add_emb)

class Seq2SeqTransformer(nn.Module):

    def __init__(self, d_model, nhead, num_enc_layers, num_dec_layers, 
                 dim_forward, dropout, enc_voc_size, dec_voc_size):
        super().__init__()
        # transformer
        self.transformer = nn.Transformer(d_model=d_model,
                                          nhead=nhead,
                                          num_encoder_layers=num_enc_layers,
                                          num_decoder_layers=num_dec_layers,
                                          dim_feedforward=dim_forward,
                                          dropout=dropout,
                                          batch_first=True)
        # encoder input embedding
        self.enc_emb = nn.Embedding(enc_voc_size, d_model)
        # decoder input embedding
        self.dec_emb = nn.Embedding(dec_voc_size, d_model)
        # predict generate linear
        self.predict = nn.Linear(d_model, dec_voc_size)  # token预测基于解码器词典
        # positional encoding
        self.pos_encoding = PositionalEncoding(d_model, dropout)

    def forward(self, enc_inp, dec_inp, tgt_mask, enc_pad_mask, dec_pad_mask):
        # multi head attention之前基于位置编码embedding生成
        enc_emb = self.pos_encoding(self.enc_emb(enc_inp))
        dec_emb = self.pos_encoding(self.dec_emb(dec_inp))
        # 调用transformer计算
        outs = self.transformer(src=enc_emb, tgt=dec_emb, tgt_mask=tgt_mask,
                         src_key_padding_mask=enc_pad_mask, 
                         tgt_key_padding_mask=dec_pad_mask)
        # 推理
        return self.predict(outs)
    
    # 推理环节使用方法
    def encode(self, enc_inp):
        enc_emb = self.pos_encoding(self.enc_emb(enc_inp))
        return self.transformer.encoder(enc_emb)
    
    def decode(self, dec_inp, memory, dec_mask):
        dec_emb = self.pos_encoding(self.dec_emb(dec_inp))
        return self.transformer.decoder(dec_emb, memory, dec_mask)

In [42]:
corpus= "人生得意须尽欢，莫使金樽空对月"
chs = list(corpus)

enc_tokens, dec_tokens = [],[]

for i in range(1,len(chs)):
    enc = chs[:i]
    dec = ['<s>'] + chs[i:] + ['</s>']
    enc_tokens.append(enc)
    dec_tokens.append(dec)

tokens = ['PAD','UNK','<s>','</s>'] + chs
vocab = { tk:i for i, tk in enumerate(tokens)}

In [None]:
from torch.nn.utils.rnn import pad_sequence
def get_proc(evocab, pad_idx=0):
    # 嵌套函数定义
    # 外部函数变量生命周期会延续到内部函数调用结束 （闭包）
    def batch_proc(data):
        """
        批次数据处理并返回
        """
        enc_ids, dec_ids, labels = [],[],[]
        for enc,dec in data:
            # token -> token index
            enc_idx = [vocab[tk] for tk in enc]
            dec_idx = [vocab[tk] for tk in dec]

            # encoder_input
            enc_ids.append(torch.tensor(enc_idx))
            # decoder_input
            dec_ids.append(torch.tensor(dec_idx[:-1]))
            # label
            labels.append(torch.tensor(dec_idx[1:]))

        
        # 数据转换张量 [batch, max_token_len]
        # 用批次中最长token序列构建张量
        enc_input = pad_sequence(enc_ids, batch_first=True)
        dec_input = pad_sequence(dec_ids, batch_first=True)
        targets = pad_sequence(labels, batch_first=True)

        # 生成编码器填充掩码 (batch_size, 1, src_len)
        enc_pad_mask = (enc_input == pad_idx).transpose(0, 1).unsqueeze(1)
        
        # 生成解码器填充掩码 (batch_size, 1, tgt_len)
        dec_pad_mask = (dec_input == pad_idx).transpose(0, 1).unsqueeze(1)
        
        # 生成目标序列的因果掩码 (tgt_len, tgt_len)
        tgt_len = dec_input.size(0)
        tgt_mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()

        # 返回数据都是模型训练和推理的需要
        return enc_input, dec_input, targets, tgt_mask, enc_pad_mask, dec_pad_mask

    # 返回回调函数
    return batch_proc



In [44]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

# 加载训练数据
enc_data = enc_tokens
dec_data = dec_tokens

ds = list(zip(enc_data,dec_data))
dl = DataLoader(ds, batch_size=2, shuffle=True, collate_fn=get_proc(vocab))


In [47]:
# 构建训练模型
# 模型构建
model = Seq2SeqTransformer(
    d_model = 32, 
    nhead = 8, 
    num_enc_layers = 8, 
    num_dec_layers = 8, 
    dim_forward = 2048, 
    dropout = 0.1, 
    enc_voc_size = len(vocab),
    dec_voc_size = len(vocab)
)

# 优化器、损失
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# 训练

for epoch in range(20):
    model.train()
    tpbar = tqdm(dl)
    for enc_input, dec_input, targets, tgt_mask, enc_pad_mask, dec_pad_mask in tpbar:

        # 前向传播 
        logits, _ = model(enc_input, dec_input, tgt_mask, enc_pad_mask, dec_pad_mask)#(enc_input, dec_input)

        # 计算损失
        # CrossEntropyLoss需要将logits和targets展平
        # logits: [batch_size, seq_len, vocab_size]
        # targets: [batch_size, seq_len]
        # 展平为 [batch_size * seq_len, vocab_size] 和 [batch_size * seq_len]
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tpbar.set_description(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')

torch.save(model.state_dict(), './Seq2SeqTransformer.bin')

  0%|          | 0/7 [01:58<?, ?it/s]


RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0