In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/testdate/in.txt
/kaggle/input/couplet/vocab.bin
/kaggle/input/couplet/encoder.json
/kaggle/input/couplet/decoder.json


In [2]:
import torch
import torch.nn as nn

# 编码器
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dropout):
        super(Encoder, self).__init__()
        # 定义嵌入层
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # 定义GRU层
        self.rnn = nn.GRU(emb_dim, hidden_dim,dropout=dropout, 
                          batch_first=True, bidirectional=True)

    def forward(self, token_seq):
        # token_seq: [batch_size, seq_len]
        # embedded: [batch_size, seq_len, emb_dim]
        embedded = self.embedding(token_seq)

        # outputs: [batch_size, seq_len, hidden_dim * 2]
        # hidden: [2, batch_size, hidden_dim]
        outputs, hidden = self.rnn(embedded)

        # 返回，Encoder最后一个时间步的隐藏状态(拼接)
        # return outputs[:, -1, :]
        # 返回最后一个时间步的隐藏状态(拼接)
        # return torch.cat((hidden[0], hidden[1]), dim=1)
        # 返回最后一个时间步的隐状态（相加）
        return hidden.sum(dim=0)

# 解码器
class Decoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dropout):
        super(Decoder, self).__init__()
        # 定义嵌入层
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # 定义GRU层
        self.rnn = nn.GRU(emb_dim, hidden_dim , dropout=dropout,
                          batch_first=True)
        # 定义线性层
        self.fc = nn.Linear(hidden_dim , input_dim)  # 解码词典中词汇概率

    def forward(self, token_seq, hidden_state):
        # token_seq: [batch_size, seq_len]
        # embedded: [batch_size, seq_len, emb_dim]
        embedded = self.embedding(token_seq)

        # outputs: [batch_size, seq_len, hidden_dim * 2]
        # hidden: [1, batch_size, hidden_dim * 2]
        outputs, hidden = self.rnn(embedded, hidden_state.unsqueeze(0))

        # logits: [batch_size, seq_len, input_dim]
        logits = self.fc(outputs)
        return logits, hidden

class Seq2Seq(nn.Module):

    def __init__(self,
                 enc_emb_size, 
                 dec_emb_size,
                 emb_dim,
                 hidden_size,
                 dropout=0.5,
                 ):
        
        super().__init__()

        # encoder
        self.encoder = Encoder(enc_emb_size, emb_dim, hidden_size, dropout=dropout)
        # decoder
        self.decoder = Decoder(dec_emb_size, emb_dim, hidden_size, dropout=dropout)


    def forward(self, enc_input, dec_input):
        # encoder last hidden state
        encoder_state = self.encoder(enc_input)
        output,hidden = self.decoder(dec_input, encoder_state)

        return output,hidden

In [3]:
from torch.nn.utils.rnn import pad_sequence
def get_proc(enc_voc, dec_voc):

    # 嵌套函数定义
    # 外部函数变量生命周期会延续到内部函数调用结束 （闭包）

    def batch_proc(data):
        """
        批次数据处理并返回
        """
        enc_ids, dec_ids, labels = [],[],[]
        for enc,dec in data:
            # token -> token index
            enc_idx = [enc_voc[tk] for tk in enc]
            dec_idx = [dec_voc[tk] for tk in dec]

            # encoder_input
            enc_ids.append(torch.tensor(enc_idx))
            # decoder_input
            dec_ids.append(torch.tensor(dec_idx[:-1]))
            # label
            labels.append(torch.tensor(dec_idx[1:]))

        
        # 数据转换张量 [batch, max_token_len]
        # 用批次中最长token序列构建张量
        enc_input = pad_sequence(enc_ids, batch_first=True)
        dec_input = pad_sequence(dec_ids, batch_first=True)
        targets = pad_sequence(labels, batch_first=True)

        # 返回数据都是模型训练和推理的需要
        return enc_input, dec_input, targets

    # 返回回调函数
    return batch_proc   

In [4]:
import pickle
import torch
import json
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

device = torch.device('cuda')

# 加载训练数据
with open('/kaggle/input/couplet/vocab.bin','rb') as f:
    evoc,dvoc = pickle.load(f)

with open('/kaggle/input/couplet/encoder.json') as f:
    enc_data = json.load(f)
with open('/kaggle/input/couplet/decoder.json') as f:
    dec_data = json.load(f)

ds = list(zip(enc_data,dec_data))
dl = DataLoader(ds, batch_size=256, shuffle=True, collate_fn=get_proc(evoc, dvoc))

# 构建训练模型
# 模型构建
model = Seq2Seq(
    enc_emb_size=len(evoc),
    dec_emb_size=len(dvoc),
    emb_dim=100,
    hidden_size=120,
    dropout=0.5,
)
model.to(device)

# 优化器、损失
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# 训练
for epoch in range(20):
    model.train()
    tpbar = tqdm(dl)
    for enc_input, dec_input, targets in tpbar:
        enc_input = enc_input.to(device)
        dec_input = dec_input.to(device)
        targets = targets.to(device)

        # 前向传播 
        logits, _ = model(enc_input, dec_input)

        # 计算损失
        # CrossEntropyLoss需要将logits和targets展平
        # logits: [batch_size, seq_len, vocab_size]
        # targets: [batch_size, seq_len]
        # 展平为 [batch_size * seq_len, vocab_size] 和 [batch_size * seq_len]
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tpbar.set_description(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')

torch.save(model.state_dict(), '/kaggle/working/seq2seq_state_add.bin')

Epoch 1, Loss: 1.8525: 100%|██████████| 3010/3010 [02:04<00:00, 24.17it/s]
Epoch 2, Loss: 1.6026: 100%|██████████| 3010/3010 [02:02<00:00, 24.53it/s]
Epoch 3, Loss: 1.6621: 100%|██████████| 3010/3010 [02:03<00:00, 24.45it/s]
Epoch 4, Loss: 1.6112: 100%|██████████| 3010/3010 [02:02<00:00, 24.47it/s]
Epoch 5, Loss: 1.5575: 100%|██████████| 3010/3010 [02:03<00:00, 24.34it/s]
Epoch 6, Loss: 1.3853: 100%|██████████| 3010/3010 [02:02<00:00, 24.52it/s]
Epoch 7, Loss: 1.6961: 100%|██████████| 3010/3010 [02:02<00:00, 24.53it/s]
Epoch 8, Loss: 1.5129: 100%|██████████| 3010/3010 [02:03<00:00, 24.39it/s]
Epoch 9, Loss: 1.4281: 100%|██████████| 3010/3010 [02:02<00:00, 24.53it/s]
Epoch 10, Loss: 1.4399: 100%|██████████| 3010/3010 [02:02<00:00, 24.50it/s]
Epoch 11, Loss: 1.5815: 100%|██████████| 3010/3010 [02:02<00:00, 24.48it/s]
Epoch 12, Loss: 1.4112: 100%|██████████| 3010/3010 [02:03<00:00, 24.41it/s]
Epoch 13, Loss: 1.4614: 100%|██████████| 3010/3010 [02:02<00:00, 24.55it/s]
Epoch 14, Loss: 1.580

In [12]:
import torch
import pickle

# 加载训练好的模型和词典
state_dict = torch.load('/kaggle/working/seq2seq_state_add.bin')
with open('/kaggle/input/couplet/vocab.bin','rb') as f:
    evoc,dvoc = pickle.load(f)

model = Seq2Seq(
    enc_emb_size=len(evoc),
    dec_emb_size=len(dvoc),
    emb_dim=100,
    hidden_size=120,
    dropout=0.5,
)
model.load_state_dict(state_dict)

# 创建解码器反向字典
dvoc_inv = {v:k for k,v in dvoc.items()}

def test(in_file):
    with open(in_file,'r',encoding='utf-8') as f, open('/kaggle/working/test_out_add.txt','w',encoding='utf-8') as w:
        lines = f.read().split('\n')
        for line in lines:
            #空数据排除
            if line == '':
                continue
            # 用户输入
            enc_idx = torch.tensor([[evoc.get(tk, evoc['UNK']) for tk in line.split(' ')[:-1]]])

            # 推理
            # 最大解码长度=输入长度
            max_dec_len = len(line.split(' ')[:-1])

            model.eval()
            with torch.no_grad():
                # 编码器
                hidden_state = model.encoder(enc_idx)

                # 解码器输入 shape [1,1]
                dec_input = torch.tensor([[dvoc['BOS']]])

                # 循环decoder
                dec_tokens = []
                while True:
                    if len(dec_tokens) >= max_dec_len:
                        break
                    # 解码器 
                    # logits: [1,1,dec_voc_size]
                    logits,hidden_state = model.decoder(dec_input, hidden_state)
                    # logits,hidden_state = model.decoder(dec_input, hidden_state, enc_outputs)
                    
                    # 下个token index
                    next_token = torch.argmax(logits, dim=-1)

                    if dvoc_inv[next_token.squeeze().item()] == 'EOS':
                        break
                    # 收集每次token_index 【解码集合】
                    dec_tokens.append(next_token.squeeze().item())
                    # decoder的下一个输入 = token_index
                    dec_input = next_token
                    hidden_state = hidden_state.view(1, -1)

            # 输出解码结果
            w.write(''.join([dvoc_inv[tk] for tk in dec_tokens]))
            w.write('\n')
    return

  state_dict = torch.load('/kaggle/working/seq2seq_state_add.bin')


In [13]:
test('/kaggle/input/testdate/in.txt')