In [4]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/couplet-test-data/out.txt
/kaggle/input/couplet-test-data/in.txt


In [1]:
import torch
import torch.nn as nn
import math
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from collections import OrderedDict
from torch.utils.data import Dataset
from tqdm import tqdm

# 位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size, dropout, maxlen=5000):
        super().__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros(maxlen, emb_size)
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # 添加批次维度
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        # positional encoding注册为不需要作为模型参数的缓冲中
        self.register_buffer('pos_embedding', pos_embedding)
    def forward(self, token_embedding):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# 将输入索引张量转换为token embedding张量
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size
    def forward(self, tokens):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
    
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size, nhead, src_vocab_size, tgt_vocab_size, 
                 dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.transformer = nn.Transformer(d_model=emb_size,
                                          nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=True)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_token_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_token_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout)
    def forward(self, src, tgt, src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask):
        src_emb = self.positional_encoding(self.src_token_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_token_emb(tgt))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, 
                                src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask)
        return self.generator(outs)
    def encode(self, src, src_mask):
        return self.transformer.encoder(self.positional_encoding(self.src_token_emb(src)), src_mask)
    def decode(self, tgt, memory, tgt_mask):
        return self.transformer.decoder(self.positional_encoding(self.tgt_token_emb(tgt)), memory, tgt_mask)

# 自定义数据集
class Seq2SeqDataset(Dataset):
    def __init__(self, encode_datas, decode_datas):
        super().__init__()
        self.encode_datas = encode_datas
        self.decode_datas = decode_datas
        self.encode_vocab = self.build_vocab(encode_datas, fill_mask = ['PAD', 'EOS', 'UNK'])
        self.decode_vocab = self.build_vocab(decode_datas, fill_mask = ['PAD', 'BOS', 'EOS', 'UNK'])
    def __getitem__(self, index):
        enc = list(self.encode_datas[index]) + ['EOS']
        dec = ['BOS'] + list(self.decode_datas[index]) + ['EOS']
        e = [self.encode_vocab.get(tk, self.encode_vocab['UNK']) for tk in enc]
        d = [self.decode_vocab.get(tk, self.decode_vocab['UNK']) for tk in dec]
        return e,d
    def __len__(self):
        return len(self.encode_datas)
    # 构建词汇表
    def build_vocab(self, datas, fill_mask):
        vocab = OrderedDict({msk: idx for idx, msk in enumerate(fill_mask)})
        for item in datas:
            for token in list(item):
                vocab[token] = vocab.get(token, len(vocab))
        return vocab

def read_couplet(path):
    datas = []
    with open(path, encoding='utf-8') as f:
        lines = f.readlines()
        for l in lines:
            datas.append(l.strip().split())
    return datas

def build_dataloader(dataset, batch_size, shuffle = False):
    def collate_batch(batch):
        encode_list, decode_list = [], []
        for encode, decode in batch:
            encode_list.append(torch.tensor(encode, dtype=torch.int64))
            decode_list.append(torch.tensor(decode, dtype=torch.int64))
        encode_list = pad_sequence(encode_list, batch_first=True, padding_value=0)
        decode_list = pad_sequence(decode_list, batch_first=True, padding_value=0)
        return encode_list, decode_list
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_batch)
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
def create_mask(src, tgt):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len)).type(torch.bool)
    src_padding_mask = src == 0
    tgt_padding_mask = tgt == 0
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
# 训练模型
def train_epoch(epoch, train_dl, model, loss_fn, optimizer):
    model.train()
    losses = 0
    train_bar = tqdm(train_dl)
    for src, tgt in train_bar:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)
        tgt_input = tgt[:, :-1]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = src_mask.to(DEVICE), tgt_mask.to(DEVICE), src_padding_mask.to(DEVICE), tgt_padding_mask.to(DEVICE)
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
        optimizer.zero_grad()
        tgt_output = tgt[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_output.reshape(-1))
        loss.backward()
        optimizer.step()
        losses += loss.item()
        train_bar.set_description(f"Epoch {epoch + 1}, Loss: {loss.item()}")
    avg_loss = losses/len(train_dl)
    print(f'epoch={epoch + 1}, avg_loss={avg_loss}')
    return losses/len(train_dl)

In [2]:
encode_datas = read_couplet('/kaggle/input/couplet-test-data/in.txt')
decode_datas = read_couplet('/kaggle/input/couplet-test-data/out.txt')
dataset = Seq2SeqDataset(encode_datas, decode_datas)
# 初始化超参数
LR = 0.0001
SRC_VOCAB_SIZE = len(dataset.encode_vocab)
TGT_VOCAB_SIZE = len(dataset.decode_vocab)
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 2048
BATCH_SIZE = 64
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
NUM_EPOCHS = 60
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 模型初始化
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, 
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
transformer = transformer.to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(transformer.parameters(), lr=LR)
# 训练模型
train_dl = build_dataloader(dataset, BATCH_SIZE)
for epoch in range(NUM_EPOCHS):
    train_epoch(epoch, train_dl, transformer, loss_fn, optimizer)
torch.save(transformer.state_dict(), 'model_new.pth')

Epoch 1, Loss: 2.8482983112335205: 100%|██████████| 63/63 [00:06<00:00, 10.09it/s]


epoch=1, avg_loss=3.1825558533744207


Epoch 2, Loss: 2.74883770942688: 100%|██████████| 63/63 [00:05<00:00, 11.34it/s]  


epoch=2, avg_loss=2.6722925466204446


Epoch 3, Loss: 2.684967041015625: 100%|██████████| 63/63 [00:05<00:00, 11.17it/s] 


epoch=3, avg_loss=2.6173970207335455


Epoch 4, Loss: 2.6365303993225098: 100%|██████████| 63/63 [00:05<00:00, 11.05it/s]


epoch=4, avg_loss=2.5777441547030495


Epoch 5, Loss: 2.602055311203003: 100%|██████████| 63/63 [00:05<00:00, 10.96it/s] 


epoch=5, avg_loss=2.5468029294695174


Epoch 6, Loss: 2.5356009006500244: 100%|██████████| 63/63 [00:05<00:00, 10.88it/s]


epoch=6, avg_loss=2.512404835413373


Epoch 7, Loss: 2.4306538105010986: 100%|██████████| 63/63 [00:05<00:00, 10.75it/s]


epoch=7, avg_loss=2.4639988815973677


Epoch 8, Loss: 2.3530731201171875: 100%|██████████| 63/63 [00:05<00:00, 10.62it/s]


epoch=8, avg_loss=2.412877457482474


Epoch 9, Loss: 2.242774724960327: 100%|██████████| 63/63 [00:06<00:00, 10.47it/s] 


epoch=9, avg_loss=2.3617190709189764


Epoch 10, Loss: 2.1533074378967285: 100%|██████████| 63/63 [00:06<00:00, 10.31it/s]


epoch=10, avg_loss=2.315885753858657


Epoch 11, Loss: 2.074043035507202: 100%|██████████| 63/63 [00:06<00:00, 10.13it/s] 


epoch=11, avg_loss=2.258560712375338


Epoch 12, Loss: 1.9774694442749023: 100%|██████████| 63/63 [00:06<00:00, 10.06it/s]


epoch=12, avg_loss=2.2029240017845515


Epoch 13, Loss: 1.8835729360580444: 100%|██████████| 63/63 [00:06<00:00,  9.91it/s]


epoch=13, avg_loss=2.14374479982588


Epoch 14, Loss: 1.7768335342407227: 100%|██████████| 63/63 [00:06<00:00,  9.71it/s]


epoch=14, avg_loss=2.092683220666552


Epoch 15, Loss: 1.7136436700820923: 100%|██████████| 63/63 [00:06<00:00,  9.47it/s]


epoch=15, avg_loss=2.0452617387922984


Epoch 16, Loss: 1.6420565843582153: 100%|██████████| 63/63 [00:06<00:00,  9.30it/s]


epoch=16, avg_loss=2.002429349081857


Epoch 17, Loss: 1.5596915483474731: 100%|██████████| 63/63 [00:06<00:00,  9.09it/s]


epoch=17, avg_loss=1.9445921144788227


Epoch 18, Loss: 1.4509209394454956: 100%|██████████| 63/63 [00:07<00:00,  8.94it/s]


epoch=18, avg_loss=1.8783104608929346


Epoch 19, Loss: 1.358810544013977: 100%|██████████| 63/63 [00:07<00:00,  8.84it/s] 


epoch=19, avg_loss=1.821823307446071


Epoch 20, Loss: 1.2800512313842773: 100%|██████████| 63/63 [00:07<00:00,  9.00it/s]


epoch=20, avg_loss=1.7703121824870034


Epoch 21, Loss: 1.1821765899658203: 100%|██████████| 63/63 [00:06<00:00,  9.17it/s]


epoch=21, avg_loss=1.7197475509038047


Epoch 22, Loss: 1.1165297031402588: 100%|██████████| 63/63 [00:06<00:00,  9.35it/s]


epoch=22, avg_loss=1.6713562238784063


Epoch 23, Loss: 1.0431289672851562: 100%|██████████| 63/63 [00:06<00:00,  9.43it/s]


epoch=23, avg_loss=1.6218054237819852


Epoch 24, Loss: 0.9686458110809326: 100%|██████████| 63/63 [00:06<00:00,  9.47it/s]


epoch=24, avg_loss=1.5664033303185114


Epoch 25, Loss: 0.8697195053100586: 100%|██████████| 63/63 [00:06<00:00,  9.48it/s]


epoch=25, avg_loss=1.5096670502708072


Epoch 26, Loss: 0.832181990146637: 100%|██████████| 63/63 [00:06<00:00,  9.45it/s] 


epoch=26, avg_loss=1.4528756264656308


Epoch 27, Loss: 0.7546163201332092: 100%|██████████| 63/63 [00:06<00:00,  9.41it/s]


epoch=27, avg_loss=1.3982609027907962


Epoch 28, Loss: 0.6864866614341736: 100%|██████████| 63/63 [00:06<00:00,  9.34it/s]


epoch=28, avg_loss=1.3500013587966797


Epoch 29, Loss: 0.6386058330535889: 100%|██████████| 63/63 [00:06<00:00,  9.29it/s]


epoch=29, avg_loss=1.300939505062406


Epoch 30, Loss: 0.5676818490028381: 100%|██████████| 63/63 [00:06<00:00,  9.27it/s]


epoch=30, avg_loss=1.2572637030056544


Epoch 31, Loss: 0.5394834280014038: 100%|██████████| 63/63 [00:06<00:00,  9.22it/s]


epoch=31, avg_loss=1.2087750624096583


Epoch 32, Loss: 0.5003144145011902: 100%|██████████| 63/63 [00:06<00:00,  9.22it/s]


epoch=32, avg_loss=1.1575375123629494


Epoch 33, Loss: 0.43488609790802: 100%|██████████| 63/63 [00:06<00:00,  9.27it/s]  


epoch=33, avg_loss=1.1044778937385196


Epoch 34, Loss: 0.40678152441978455: 100%|██████████| 63/63 [00:06<00:00,  9.27it/s]


epoch=34, avg_loss=1.0411704300888


Epoch 35, Loss: 0.37820684909820557: 100%|██████████| 63/63 [00:06<00:00,  9.31it/s]


epoch=35, avg_loss=0.9972410921066527


Epoch 36, Loss: 0.3285517394542694: 100%|██████████| 63/63 [00:06<00:00,  9.34it/s]


epoch=36, avg_loss=0.9385514841193244


Epoch 37, Loss: 0.2885262668132782: 100%|██████████| 63/63 [00:06<00:00,  9.33it/s]


epoch=37, avg_loss=0.8842848960369353


Epoch 38, Loss: 0.26354658603668213: 100%|██████████| 63/63 [00:06<00:00,  9.29it/s]


epoch=38, avg_loss=0.8313889550784278


Epoch 39, Loss: 0.24906182289123535: 100%|██████████| 63/63 [00:06<00:00,  9.35it/s]


epoch=39, avg_loss=0.7905436148719182


Epoch 40, Loss: 0.22239446640014648: 100%|██████████| 63/63 [00:06<00:00,  9.33it/s]


epoch=40, avg_loss=0.7387698141355363


Epoch 41, Loss: 0.20700079202651978: 100%|██████████| 63/63 [00:06<00:00,  9.37it/s]


epoch=41, avg_loss=0.6988023141073803


Epoch 42, Loss: 0.19877149164676666: 100%|██████████| 63/63 [00:06<00:00,  9.38it/s]


epoch=42, avg_loss=0.6596921500232484


Epoch 43, Loss: 0.1709120124578476: 100%|██████████| 63/63 [00:06<00:00,  9.37it/s]


epoch=43, avg_loss=0.6134581066786297


Epoch 44, Loss: 0.15418429672718048: 100%|██████████| 63/63 [00:06<00:00,  9.39it/s]


epoch=44, avg_loss=0.5807452686722316


Epoch 45, Loss: 0.1533908247947693: 100%|██████████| 63/63 [00:06<00:00,  9.38it/s] 


epoch=45, avg_loss=0.5405890061741784


Epoch 46, Loss: 0.15430134534835815: 100%|██████████| 63/63 [00:06<00:00,  9.38it/s]


epoch=46, avg_loss=0.510881804757648


Epoch 47, Loss: 0.12811340391635895: 100%|██████████| 63/63 [00:06<00:00,  9.39it/s]


epoch=47, avg_loss=0.47812619781683363


Epoch 48, Loss: 0.12954308092594147: 100%|██████████| 63/63 [00:06<00:00,  9.39it/s]


epoch=48, avg_loss=0.44675822745239924


Epoch 49, Loss: 0.11758356541395187: 100%|██████████| 63/63 [00:06<00:00,  9.40it/s]


epoch=49, avg_loss=0.41497572345866096


Epoch 50, Loss: 0.10317942500114441: 100%|██████████| 63/63 [00:06<00:00,  9.39it/s]


epoch=50, avg_loss=0.3851585473333086


Epoch 51, Loss: 0.11148757487535477: 100%|██████████| 63/63 [00:06<00:00,  9.37it/s]


epoch=51, avg_loss=0.3624423752937998


Epoch 52, Loss: 0.09324676543474197: 100%|██████████| 63/63 [00:06<00:00,  9.38it/s]


epoch=52, avg_loss=0.33284033231792


Epoch 53, Loss: 0.09522426873445511: 100%|██████████| 63/63 [00:06<00:00,  9.41it/s]


epoch=53, avg_loss=0.30486582089511177


Epoch 54, Loss: 0.08078301697969437: 100%|██████████| 63/63 [00:06<00:00,  9.40it/s]


epoch=54, avg_loss=0.2804211563770733


Epoch 55, Loss: 0.07565359026193619: 100%|██████████| 63/63 [00:06<00:00,  9.40it/s]


epoch=55, avg_loss=0.26027760465466787


Epoch 56, Loss: 0.07522259652614594: 100%|██████████| 63/63 [00:06<00:00,  9.40it/s]


epoch=56, avg_loss=0.24225775779239714


Epoch 57, Loss: 0.059968963265419006: 100%|██████████| 63/63 [00:06<00:00,  9.41it/s]


epoch=57, avg_loss=0.2243859630728525


Epoch 58, Loss: 0.0752275139093399: 100%|██████████| 63/63 [00:06<00:00,  9.43it/s] 


epoch=58, avg_loss=0.20570194650264012


Epoch 59, Loss: 0.06385817378759384: 100%|██████████| 63/63 [00:06<00:00,  9.40it/s]


epoch=59, avg_loss=0.19258426780265475


Epoch 60, Loss: 0.05539719760417938: 100%|██████████| 63/63 [00:06<00:00,  9.39it/s]


epoch=60, avg_loss=0.18152227997779846


In [17]:
# 推理
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src, src_mask = src.to(DEVICE), src_mask.to(DEVICE)
    memory = model.encode(src, src_mask)
    # decode的第一个token：BOS
    ys = torch.ones(1,1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.shape[1])).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1,1).fill_(next_word).to(DEVICE)], dim=1)
        if next_word == dataset.decode_vocab['EOS']:
            break
    return ys
def translate(model, src_sentence):
    model.eval()
    src = torch.tensor([dataset.encode_vocab[tk] for tk in list(src_sentence) + ['EOS']]).reshape(1, -1)
    num_tokens = src.shape[1]
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
    tgt_tokens = greedy_decode(model, src, src_mask, num_tokens, dataset.decode_vocab['BOS']).reshape(-1)
    return ''.join([decode_vocab_rev[tk] for tk in list(tgt_tokens.cpu().numpy())][1:-1])

In [18]:
decode_vocab_rev = {v:k for k, v in dataset.decode_vocab.items()}
# transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, 
#                                      NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
# transformer = transformer.to(DEVICE)
# transformer.load_state_dict(torch.load("kaggle_output/model_new.pth"))
sen_enc = '陆才吟岁，心定方知时日快'
sen_dec = translate(transformer, sen_enc)
print(sen_enc)
print(sen_dec)

陆才吟岁，心定方知时日快
果真情人常是，应当达民族
