In [1]:
import tensorflow as tf
import torch.optim as optim
import torch
import torch.nn as nn
import pickle

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Attention
from tensorflow.keras.models import Model
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'当前设备为',device)

当前设备为 cpu


In [3]:
# 编码器
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dropout, hidden_form = 'concat'):
        super(Encoder, self).__init__()
        # 定义嵌入层
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # 定义LSTM层
        self.rnn = nn.LSTM(emb_dim, hidden_dim, dropout=dropout, num_layers=2,
                          batch_first=True, bidirectional=True)
        self.hidden_form = hidden_form

    def forward(self, token_seq):
        # token_seq: [batch_size, seq_len]
        # embedded: [batch_size, seq_len, emb_dim]
        embedded = self.embedding(token_seq)
        # outputs: [batch_size, seq_len, hidden_dim * 2]
        # hidden: [4, batch_size, hidden_dim]
        outputs, (h_n, c_n)  = self.rnn(embedded)
        if (self.hidden_form == 'concat'):
            hidden_concat = torch.cat([h_n[0], h_n[1]], dim=1)
            hidden_concat = hidden_concat.unsqueeze(0).repeat(2, 1, 1)
            return hidden_concat, outputs
        elif (self.hidden_form == 'add'):
            hidden_sum = h_n.sum(dim=0)
            hidden_sum = hidden_sum.unsqueeze(0).repeat(2, 1, 1)
            return hidden_sum, outputs

In [4]:
# Attention
class Attention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, enc_output, dec_output):
        a_t = torch.bmm(enc_output, dec_output.permute(0, 2, 1))
        # 1.计算 结合解码token和编码token，关联的权重
        a_t = torch.softmax(a_t, dim=1)
        # 2.计算 关联权重和编码token 贡献值
        c_t = torch.bmm(a_t.permute(0, 2, 1), enc_output)
        return c_t

In [5]:
# 解码器
class Decoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dropout, hidden_form = 'concat'):
        super(Decoder, self).__init__()
        # 定义嵌入层
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # 定义LSTM层
        self.rnn = nn.LSTM(emb_dim, hidden_dim * 2, dropout=dropout,
                          num_layers=2,batch_first=True)
        # 定义线性层
        self.fc = nn.Linear(hidden_dim * 2, input_dim)  # 解码词典中词汇概率
        # attention层
        self.atteniton = Attention()
        # attention结果转换线性层
        self.atteniton_fc = nn.Linear(hidden_dim * 4, hidden_dim * 2)
        self.hidden_form = hidden_form

    def forward(self, token_seq, hidden_state, enc_output):
        # token_seq: [batch_size, seq_len]
        # embedded: [batch_size, seq_len, emb_dim]
        embedded = self.embedding(token_seq)
        # outputs: [batch_size, seq_len, hidden_dim * 2]
        # hidden: [2, batch_size, hidden_dim * 2]
        if (self.hidden_form == 'add'):
            hidden_state = hidden_state.repeat(1, 1, 2)
        dec_output, (h_n, c_n) = self.rnn(embedded, (hidden_state, torch.zeros_like(hidden_state)))
        
        # attention运算
        c_t = self.atteniton(enc_output, dec_output)
        # [attention, dec_output]
        cat_output = torch.cat((c_t, dec_output), dim=-1)
        # 线性运算
        out = torch.tanh(self.atteniton_fc(cat_output))

        # out: [batch_size, seq_len, hidden_dim * 2]
        logits = self.fc(out)
        return logits, h_n

In [10]:
# 加载数据
def load_data(file_in_path,file_out_path):
    """
    读取记录数据并返回数据集合
    """
    encode_data = open_data(file_in_path)
    decode_data = open_data(file_out_path)
    
    # 编码数据与解码数据长度是否一致
    assert len(encode_data) == len(decode_data), '上下联原始数据长度不一致'
    return encode_data, decode_data

def open_data(file_path):
    """
    打开文件录入行信息
    """
    data = []
    with open(file_path,encoding='utf-8') as f1:
            # 读取记录行
            lines = f1.read().split('\n')
            for line in lines:
                if line == ' ':
                    continue
                tokens = line.split()
                data.append(tokens)
    return data

def words_to_vocab(words_list):
        """
        从输入的单词列表中构建一个词汇表(vocabulary)
        """
        no_repeat_tokens = set()
        for word in words_list:
            no_repeat_tokens.update(list(word))  
        tokens = ['PAD','UNK'] + list(no_repeat_tokens)
        vocabs = { tk:i for i, tk in enumerate(tokens)}
        return vocabs

def dump_vocab(path,data_in_vocab,data_out_vocab):
    with open(path,'wb') as f:
        pickle.dump((data_in_vocab, data_out_vocab),f)

In [14]:
def data_process(enc_voc, dec_voc):
    def batch_process(data):
        """
        批次数据处理并返回
        """
        enc_ids, dec_ids, labels = [],[],[]
        for enc,dec in data:
            # token -> token index
            enc_idx = [enc_voc[tk] for tk in enc]
            dec_idx = [dec_voc[tk] for tk in dec]

            enc_ids.append(torch.tensor(enc_idx).long())
            dec_ids.append(torch.tensor(dec_idx[:-1]).long())
            labels.append(torch.tensor(dec_idx[1:]).long())
        # 构建张量
        enc_input = pad_sequence(enc_ids, batch_first=True)
        dec_input = pad_sequence(dec_ids, batch_first=True)
        targets = pad_sequence(labels, batch_first=True)
        return enc_input, dec_input, targets
    return batch_process


In [7]:
class Seq2Seq(nn.Module):
    def __init__(self,
                 enc_emb_size,
                 dec_emb_size,
                 emb_dim,
                 hidden_size,
                 dropout=0.5,
                 hidden_form = 'concat'
                 ):

        super().__init__()
        self.encoder = Encoder(enc_emb_size, emb_dim, hidden_size, dropout=dropout, hidden_form = hidden_form)
        self.decoder = Decoder(dec_emb_size, emb_dim, hidden_size, dropout=dropout, hidden_form = hidden_form)

    def forward(self, enc_input, dec_input):
        encoder_state, outputs = self.encoder(enc_input)
        output,hidden = self.decoder(dec_input, encoder_state, outputs)

        return output,hidden

In [8]:
# 测试
train_data_in_path = './couplet/train/in.txt'
train_data_out_path = './couplet/train/out.txt'

test_data_in_path = './couplet/test/in.txt'
test_data_out_path = './couplet/test/out.txt'

# 路径参数
model_path='./model/couplet.pt',
logs_path = 'D:/logs/'

# 配置文件
hidden_form1 = 'concat'
lr1=1e-3
epochs=10


In [11]:
train_data_in ,train_data_out = load_data(train_data_in_path,train_data_out_path)
test_data_in ,test_data_out = load_data(test_data_in_path,test_data_out_path)

# 构建词汇表
train_data_in_vocab  = words_to_vocab(train_data_in)
train_data_out_vocab = words_to_vocab(train_data_out)
test_data_in_vocab  = words_to_vocab(test_data_in)
test_data_out_vocab = words_to_vocab(test_data_out)

dump_vocab('./couplet/train/vocab.bin',train_data_in_vocab, train_data_out_vocab)
dump_vocab('./couplet/test/vocab.bin',test_data_in_vocab, test_data_out_vocab)

print(f'训练集上联',len(train_data_in))
print(f'训练集下联',len(train_data_out))
print(f'测试集上联',len(test_data_in))
print(f'测试集下联',len(test_data_out))

训练集上联 770492
训练集下联 770492
测试集上联 4001
测试集下联 4001


In [12]:
writer = SummaryWriter(logs_path)

model = Seq2Seq(
        enc_emb_size=len(train_data_in),
        dec_emb_size=len(train_data_out),
        emb_dim=200,
        hidden_size=256,
        dropout=0.4,
        hidden_form = hidden_form1
    )
model.to(device)


# 优化器、损失函数
optimizer = optim.AdamW(model.parameters(), weight_decay=0.01, lr=lr1)
criterion = nn.CrossEntropyLoss()

In [15]:
dataloader = DataLoader(
        list(zip(train_data_in,train_data_out)),
        batch_size=32,
        shuffle=True,
        collate_fn=data_process(train_data_in_vocab, train_data_out_vocab)
)
loss_cnt = 0
for epoch in range(epochs):
    model.train()
    tpbar = tqdm(dataloader)
    for enc_input, dec_input, targets in tpbar:
        enc_input = enc_input.to(device)
        dec_input = dec_input.to(device)
        targets = targets.to(device)

        # 前向传播
        logits, _ = model(enc_input, dec_input)

        # 计算损失
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tpbar.set_description(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')
        writer.add_scalar('Loss/train', loss.item(), loss_cnt)
        loss_cnt += 1

Epoch 1, Loss: 2.0608:  15%|█▌        | 3673/24078 [7:29:35<41:12:35,  7.27s/it]

: 