In [13]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/chinese-couplets/couplet/vocabs
/kaggle/input/chinese-couplets/couplet/test/out.txt
/kaggle/input/chinese-couplets/couplet/test/in.txt
/kaggle/input/chinese-couplets/couplet/test/.in.txt.swp
/kaggle/input/chinese-couplets/couplet/test/.out.txt.swp
/kaggle/input/chinese-couplets/couplet/train/out.txt
/kaggle/input/chinese-couplets/couplet/train/in.txt


In [14]:
import tensorboard
print("TensorBoard 版本:", tensorboard.__version__)

TensorBoard 版本: 2.18.0


#### 带有attention的seq2seq模型

In [15]:
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [16]:
# 编码器
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dropout, hidden_form = 'concat'):
        super(Encoder, self).__init__()
        # 定义嵌入层
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # 定义LSTM层
        self.rnn = nn.LSTM(emb_dim, hidden_dim, dropout=dropout, num_layers=2,
                          batch_first=True, bidirectional=True)
        self.hidden_form = hidden_form

    def forward(self, token_seq):
        # token_seq: [batch_size, seq_len]
        # embedded: [batch_size, seq_len, emb_dim]
        embedded = self.embedding(token_seq)
        # outputs: [batch_size, seq_len, hidden_dim * 2]
        # hidden: [4, batch_size, hidden_dim]
        outputs, (h_n, c_n)  = self.rnn(embedded)
        if (self.hidden_form == 'concat'):
            hidden_concat = torch.cat([h_n[0], h_n[1]], dim=1)
            hidden_concat = hidden_concat.unsqueeze(0).repeat(2, 1, 1)
            return hidden_concat, outputs
        elif (self.hidden_form == 'add'):
            hidden_sum = h_n.sum(dim=0)
            hidden_sum = hidden_sum.unsqueeze(0).repeat(2, 1, 1)
            return hidden_sum, outputs

# 测试
hidden_form1 = 'concat'
# hidden_form1 = 'add'

input_dim1 = 200
emb_dim1 = 100
hidden_dim1 = 256
dropout1 = 0.5
batch_size1 = 4
seq_len1 = 10

encoder1 = Encoder(input_dim1, emb_dim1, hidden_dim1, dropout1, hidden_form = hidden_form1)
token_seq1 = torch.randint(0, input_dim1, (batch_size1, seq_len1))
print(f'token_seq1.shape : {token_seq1.shape}')
hidden_state1, outputs1 = encoder1(token_seq1)
print(f'hidden_state1.shape : {hidden_state1.shape}')
print(f'outputs1.shape : {outputs1.shape}')

token_seq1.shape : torch.Size([4, 10])
hidden_state1.shape : torch.Size([2, 4, 512])
outputs1.shape : torch.Size([4, 10, 512])


In [17]:
# attention
class Attention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, enc_output, dec_output):
        # a_t = h_t @ h_s
        a_t = torch.bmm(enc_output, dec_output.permute(0, 2, 1))
        # 1.计算 结合解码token和编码token，关联的权重
        a_t = torch.softmax(a_t, dim=1)
        # 2.计算 关联权重和编码token 贡献值
        c_t = torch.bmm(a_t.permute(0, 2, 1), enc_output)
        return c_t

In [18]:
# 解码器
class Decoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dropout, hidden_form = 'concat'):
        super(Decoder, self).__init__()
        # 定义嵌入层
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # 定义LSTM层
        self.rnn = nn.LSTM(emb_dim, hidden_dim * 2, dropout=dropout,
                          num_layers=2,batch_first=True)
        # 定义线性层
        self.fc = nn.Linear(hidden_dim * 2, input_dim)  # 解码词典中词汇概率
        # attention层
        self.atteniton = Attention()
        # attention结果转换线性层
        self.atteniton_fc = nn.Linear(hidden_dim * 4, hidden_dim * 2)
        self.hidden_form = hidden_form

    def forward(self, token_seq, hidden_state, enc_output):
        # token_seq: [batch_size, seq_len]
        # embedded: [batch_size, seq_len, emb_dim]
        embedded = self.embedding(token_seq)
        # outputs: [batch_size, seq_len, hidden_dim * 2]
        # hidden: [2, batch_size, hidden_dim * 2]
        if (self.hidden_form == 'add'):
            hidden_state = hidden_state.repeat(1, 1, 2)
        dec_output, (h_n, c_n) = self.rnn(embedded, (hidden_state, torch.zeros_like(hidden_state)))
        
        # attention运算
        c_t = self.atteniton(enc_output, dec_output)
        # [attention, dec_output]
        cat_output = torch.cat((c_t, dec_output), dim=-1)
        # 线性运算
        out = torch.tanh(self.atteniton_fc(cat_output))

        # out: [batch_size, seq_len, hidden_dim * 2]
        logits = self.fc(out)
        return logits, h_n

# 测试
input_dim2 = 200
emb_dim2 = 100
hidden_dim2 = 256
dropout2 = 0.5
batch_size2 = 4
seq_len2 = 10

decoder2 = Decoder(input_dim2, emb_dim2, hidden_dim2, dropout2, hidden_form = hidden_form1)
token_seq2 = torch.randint(0, input_dim2, (batch_size2, seq_len2))
print(f'token_seq2.shape : {token_seq2.shape}')
logits2, hidden_state2= decoder2(token_seq2, hidden_state1, outputs1)
print(f'logits2.shape : {logits2.shape}')
print(f'hidden_state2.shape : {hidden_state2.shape}')

token_seq2.shape : torch.Size([4, 10])
logits2.shape : torch.Size([4, 10, 200])
hidden_state2.shape : torch.Size([2, 4, 512])


In [19]:
class Seq2Seq(nn.Module):

    def __init__(self,
                 enc_emb_size,
                 dec_emb_size,
                 emb_dim,
                 hidden_size,
                 dropout=0.5,
                 hidden_form = 'concat'
                 ):

        super().__init__()
        # encoder
        self.encoder = Encoder(enc_emb_size, emb_dim, hidden_size, dropout=dropout, hidden_form = hidden_form)
        # decoder
        self.decoder = Decoder(dec_emb_size, emb_dim, hidden_size, dropout=dropout, hidden_form = hidden_form)

    def forward(self, enc_input, dec_input):
        # encoder last hidden state
        encoder_state, outputs = self.encoder(enc_input)
        output,hidden = self.decoder(dec_input, encoder_state, outputs)

        return output,hidden

# 测试
seq2seq = Seq2Seq(
        enc_emb_size=input_dim2,
        dec_emb_size=input_dim2,
        emb_dim=emb_dim2,
        hidden_size=hidden_dim2,
        dropout=dropout2,
        hidden_form = hidden_form1
)

logits,_ = seq2seq(
    enc_input=torch.randint(0, input_dim1, (batch_size1, seq_len1)),
    dec_input=torch.randint(0, input_dim2, (batch_size2, seq_len2))
)
print(logits.shape)

torch.Size([4, 10, 200])


#### 数据处理

In [20]:
def read_data(enc_data_file, dec_data_file):
    """
    读取训练数据返回数据集合
    """
    enc_data,dec_data = [],[]

    with open(enc_data_file) as f1:
        # 读取记录行
        lines = f1.read().split('\n')
        for line in lines:
            if line == ' ':
                continue
            enc_tks = line.split()
            enc_data.append(enc_tks)

    with open(dec_data_file) as f2:
        # 读取记录行
        lines = f2.read().split('\n')
        for line in lines:
            if line == ' ':
                continue
            dec_tks = line.split()
            dec_tks = ['BOS'] + dec_tks + ['EOS']
            dec_data.append(dec_tks)

    # 断言
    assert len(enc_data) == len(dec_data), '编码数据与解码数据长度不一致！'

    return enc_data, dec_data

# 测试
enc_data_file1 = '/kaggle/input/chinese-couplets/couplet/train/in.txt'
dec_data_file1 = '/kaggle/input/chinese-couplets/couplet/train/out.txt'

enc_data_file2 = '/kaggle/input/chinese-couplets/couplet/test/in.txt'
dec_data_file2 = '/kaggle/input/chinese-couplets/couplet/test/out.txt'

enc_data1, dec_data1 = read_data(enc_data_file1, dec_data_file1)
print(len(enc_data1))
print(len(dec_data1))
print(enc_data1[:5])
print(dec_data1[:5])
print()

enc_data2, dec_data2 = read_data(enc_data_file2, dec_data_file2)
print(len(enc_data2))
print(len(dec_data2))
print(enc_data2[:5])
print(dec_data2[:5])

770492
770492
[['晚', '风', '摇', '树', '树', '还', '挺'], ['愿', '景', '天', '成', '无', '墨', '迹'], ['丹', '枫', '江', '冷', '人', '初', '去'], ['忽', '忽', '几', '晨', '昏', '，', '离', '别', '间', '之', '，', '疾', '病', '间', '之', '，', '不', '及', '终', '年', '同', '静', '好'], ['闲', '来', '野', '钓', '人', '稀', '处']]
[['BOS', '晨', '露', '润', '花', '花', '更', '红', 'EOS'], ['BOS', '万', '方', '乐', '奏', '有', '于', '阗', 'EOS'], ['BOS', '绿', '柳', '堤', '新', '燕', '复', '来', 'EOS'], ['BOS', '茕', '茕', '小', '儿', '女', '，', '孱', '羸', '若', '此', '，', '娇', '憨', '若', '此', '，', '更', '烦', '二', '老', '费', '精', '神', 'EOS'], ['BOS', '兴', '起', '高', '歌', '酒', '醉', '中', 'EOS']]

4001
4001
[['腾', '飞', '上', '铁', '，', '锐', '意', '改', '革', '谋', '发', '展', '，', '勇', '当', '千', '里', '马'], ['风', '弦', '未', '拨', '心', '先', '乱'], ['花', '梦', '粘', '于', '春', '袖', '口'], ['晋', '世', '文', '章', '昌', '二', '陆'], ['一', '句', '相', '思', '吟', '岁', '月']]
[['BOS', '和', '谐', '南', '供', '，', '安', '全', '送', '电', '保', '畅', '通', '，', '争', '做', '领', '头', '羊', 'EOS'], ['BOS', '夜', '幕', '已', '沉

In [21]:
def words_to_vocab(words_list):
        no_repeat_tokens = set()

        for word in words_list:
            no_repeat_tokens.update(list(word))  

        tokens = ['PAD','UNK'] + list(no_repeat_tokens)

        vocab = { tk:i for i, tk in enumerate(tokens)}

        return vocab

# 测试
import random

enc_vocab1 = words_to_vocab(enc_data1)
dec_vocab1 = words_to_vocab(dec_data1)

enc_vocab2 = words_to_vocab(enc_data2)
dec_vocab2 = words_to_vocab(dec_data2)

enc_keys1 = random.sample(list(enc_vocab1.keys()), 5)
random_enc_elements1 = {key: enc_vocab1[key] for key in enc_keys1}
print(f'random_enc_elements1:\n{random_enc_elements1}\n')

dec_keys1 = random.sample(list(dec_vocab1.keys()), 5)
random_dec_elements1 = {key: dec_vocab1[key] for key in dec_keys1}
print(f'random_dec_elements1:\n{random_dec_elements1}\n')

enc_keys2 = random.sample(list(enc_vocab2.keys()), 5)
random_enc_elements2 = {key: enc_vocab2[key] for key in enc_keys2}
print(f'random_enc_elements2:\n{random_enc_elements2}\n')

dec_keys2 = random.sample(list(dec_vocab2.keys()), 5)
random_dec_elements2 = {key: dec_vocab2[key] for key in dec_keys2}
print(f'random_dec_elements2:\n{random_dec_elements2}\n')

import pickle
with open('vocab1.bin','wb') as f:
    pickle.dump((enc_vocab1, dec_vocab1),f)
with open('vocab2.bin','wb') as f:
    pickle.dump((enc_vocab2, dec_vocab2),f)

random_enc_elements1:
{'崕': 731, '概': 1919, '軽': 6464, '跟': 6885, '窿': 7102}

random_dec_elements1:
{'臻': 3176, '陵': 3200, '胎': 3144, '吼': 2850, '妟': 6682}

random_enc_elements2:
{'德': 1346, '耘': 2981, '毫': 995, '软': 497, '渰': 2891}

random_dec_elements2:
{'藉': 2878, '辞': 923, '忑': 1867, '葩': 2152, '念': 2227}



In [22]:
def get_proc(enc_voc, dec_voc):

    # 嵌套函数定义
    # 外部函数变量生命周期会延续到内部函数调用结束 （闭包）

    def batch_proc(data):
        """
        批次数据处理并返回
        """
        enc_ids, dec_ids, labels = [],[],[]
        for enc,dec in data:
            # token -> token index
            enc_idx = [enc_voc[tk] for tk in enc]
            dec_idx = [dec_voc[tk] for tk in dec]

            # encoder_input
            enc_ids.append(torch.tensor(enc_idx))
            # decoder_input
            dec_ids.append(torch.tensor(dec_idx[:-1]))
            # label
            labels.append(torch.tensor(dec_idx[1:]))

        # 数据转换张量 [batch, max_token_len]
        # 用批次中最长token序列构建张量
        enc_input = pad_sequence(enc_ids, batch_first=True)
        dec_input = pad_sequence(dec_ids, batch_first=True)
        targets = pad_sequence(labels, batch_first=True)

        # 返回数据都是模型训练和推理的需要
        return enc_input, dec_input, targets

    # 返回回调函数
    return batch_proc

# 测试
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

enc_input1, dec_input1, targets1 = [],[],[]

dataset1 = list(zip(enc_data1, dec_data1))

dataloader1 = DataLoader(
        dataset1,
        batch_size=2,
        shuffle=True,
        collate_fn=get_proc(enc_vocab1, dec_vocab1)
)

for enc_input1, dec_input1, targets1 in dataloader1:
    print(f'enc_input1.shape : {enc_input1.shape}')
    print(f'dec_input1.shape : {dec_input1.shape}')
    print(f'targets1.shape :   {targets1.shape}')
    break

enc_input1.shape : torch.Size([2, 19])
dec_input1.shape : torch.Size([2, 20])
targets1.shape :   torch.Size([2, 20])


#### 模型训练

In [24]:
import torch
import json
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
train_loss_cnt = 0

dataloader2 = DataLoader(
        dataset1,
        batch_size=512,
        shuffle=True,
        collate_fn=get_proc(enc_vocab1, dec_vocab1)
)

model = Seq2Seq(
        enc_emb_size=len(enc_vocab1),
        dec_emb_size=len(dec_vocab1),
        emb_dim=200,
        hidden_size=256,
        dropout=0.4,
        hidden_form = hidden_form1
    )
model.to(device)

# 优化器、损失
optimizer = optim.AdamW(model.parameters(), weight_decay=0.01, lr=1e-3)
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# 训练
for epoch in range(10):
    model.train()
    tpbar = tqdm(dataloader2)
    for enc_input, dec_input, targets in tpbar:
        enc_input = enc_input.to(device)
        dec_input = dec_input.to(device)
        targets = targets.to(device)

        # 前向传播
        logits, _ = model(enc_input, dec_input)

        # 计算损失
        # CrossEntropyLoss需要将logits和targets展平
        # logits: [batch_size, seq_len, vocab_size]
        # targets: [batch_size, seq_len]
        # 展平为 [batch_size * seq_len, vocab_size] 和 [batch_size * seq_len]
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tpbar.set_description(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')
        writer.add_scalar('Loss/train', loss.item(), train_loss_cnt)
        train_loss_cnt += 1

torch.save(model.state_dict(), 'seq2seq_state.bin')

Epoch 1, Loss: 1.6375: 100%|██████████| 1505/1505 [04:50<00:00,  5.19it/s]
Epoch 2, Loss: 1.5420: 100%|██████████| 1505/1505 [04:49<00:00,  5.20it/s]
Epoch 3, Loss: 1.2387: 100%|██████████| 1505/1505 [04:51<00:00,  5.17it/s]
Epoch 4, Loss: 1.2604: 100%|██████████| 1505/1505 [04:52<00:00,  5.15it/s]
Epoch 5, Loss: 1.2163: 100%|██████████| 1505/1505 [04:51<00:00,  5.17it/s]
Epoch 6, Loss: 1.2318: 100%|██████████| 1505/1505 [04:50<00:00,  5.18it/s]
Epoch 7, Loss: 1.1362: 100%|██████████| 1505/1505 [04:51<00:00,  5.16it/s]
Epoch 8, Loss: 1.2142: 100%|██████████| 1505/1505 [04:50<00:00,  5.18it/s]
Epoch 9, Loss: 1.1876: 100%|██████████| 1505/1505 [04:50<00:00,  5.18it/s]
Epoch 10, Loss: 1.0951: 100%|██████████| 1505/1505 [04:51<00:00,  5.16it/s]


#### seq2seq attention版的推理实现

In [25]:
 # 创建解码器反向字典
dvoc_inv = {v:k for k,v in dec_vocab1.items()}

# 随机选取测试样本
rnd_idx = random.randint(0, len(enc_data2))
enc_input = enc_data2[rnd_idx]
dec_output = dec_data2[rnd_idx]

enc_idx = torch.tensor([[enc_vocab1[tk] for tk in enc_input]])

print(f'enc_idx: {enc_idx}')
print(f'enc_idx.shape: {enc_idx.shape}')

# 最大解码长度
max_dec_len = len(dec_output)
print(f'max_dec_len: {max_dec_len}')

print('enc_input：',''.join(enc_input))
print("dec_output：", ''.join(dec_output))

enc_idx: tensor([[ 302,   22,  550, 1501, 5016, 3517,  978]])
enc_idx.shape: torch.Size([1, 7])
max_dec_len: 9
enc_input： 咬文嚼字常添瘦
dec_output： BOS忍气吞声每犯愁EOS


In [39]:
# 推理
enc_idx = enc_idx.to(device)
model.eval()
with torch.no_grad():
    # 编码器输出
    hidden_state, enc_outputs = model.encoder(enc_idx)

    # 解码器输入 shape [1,1]
    dec_input = torch.tensor([[dec_vocab1['BOS']]])
    dec_input = dec_input.to(device)

    # 循环decoder
    dec_tokens = []
    while True:
        if len(dec_tokens) >= max_dec_len:
            break
        logits, hidden_state = model.decoder(dec_input, hidden_state, enc_outputs)
        next_token = torch.argmax(logits, dim=-1)
        if dvoc_inv[next_token.squeeze().item()] == 'EOS':
            break
        dec_tokens.append(next_token.squeeze().item())
        dec_input = next_token

print("dec_eval：", ''.join([dvoc_inv[tk] for tk in dec_tokens]))

dec_eval： 把酒临风不减肥梅自
