In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/chinese-couplets/couplet/vocabs
/kaggle/input/chinese-couplets/couplet/test/out.txt
/kaggle/input/chinese-couplets/couplet/test/in.txt
/kaggle/input/chinese-couplets/couplet/test/.in.txt.swp
/kaggle/input/chinese-couplets/couplet/test/.out.txt.swp
/kaggle/input/chinese-couplets/couplet/train/out.txt
/kaggle/input/chinese-couplets/couplet/train/in.txt


1. 使用中文对联数据集训练带有attention的seq2seq模型，利用tensorboard跟踪。https://www.kaggle.com/datasets/jiaminggogogo/chinese-couplets

In [2]:
# 1、处理数据
import pickle
def get_data_list(in_path, out_path):
    with open(in_path) as in_file, open(out_path) as out_file:
        enc_data, dec_data = [], []
        for line in list(zip(in_file, out_file)):
            enc_data.append(line[0].strip().split())
            dec_data.append(['<s>'] + line[1].strip().split() + ['</s>'])
        return enc_data, dec_data

In [3]:
# 训练数据:770491
train_enc_data, train_dec_data = get_data_list('/kaggle/input/chinese-couplets/couplet/train/in.txt', 
                                               '/kaggle/input/chinese-couplets/couplet/train/out.txt')
# 测试数据:4000
test_enc_data, test_dec_data = get_data_list('/kaggle/input/chinese-couplets/couplet/test/in.txt', 
                                               '/kaggle/input/chinese-couplets/couplet/test/out.txt')

In [4]:
# 加载字典
with open('/kaggle/input/chinese-couplets/couplet/vocabs') as f:
    word_list = ['PAD', 'UNK'] + [word.strip() for word in f]
    vocab = {word:i for i, word in enumerate(word_list)}

In [5]:
# 2、定义模型
import torch.nn as nn
import torch

class Encoder(nn.Module):
    def __init__(self, input_size, emb_size, hidden_size, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_size, emb_size)
        self.rnn = nn.GRU(emb_size, hidden_size, batch_first=True, bidirectional=True)
    
    def forward(self, enc_idxs):
        embedded = self.embedding(enc_idxs)
        # output: [batch_size, seq_len, hidden_size * 2]
        # h_n: [num_layers * 2, batch_size, hidden_size]
        outputs, h_n = self.rnn(embedded)
        # 返回值: [batch_size, hidden_size * 2]
        return outputs, torch.cat((h_n[0], h_n[1]), dim=1)

class Attention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, enc_outputs, dec_outputs):
        # enc_outputs: [batch_size, enc_seq_len, hidden_size * 2]
        # dec_outputs: [batch_size, dec_seq_len, hidden_size * 2]
        a_t = torch.bmm(enc_outputs, dec_outputs.permute(0, 2, 1)) # [batch_size, enc_seq_len, dec_seq_len]
        a_t = torch.softmax(a_t, dim=1) # [batch_size, enc_seq_len, dec_seq_len]
        c_t = torch.bmm(a_t.permute(0, 2, 1), enc_outputs) # [batch_size, dec_seq_len, hidden_size * 2]
        return c_t
    
class Decoder(nn.Module):
    def __init__(self, input_size, emb_size, hidden_size, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_size, emb_size)
        self.rnn = nn.GRU(emb_size, hidden_size * 2, batch_first=True)
        self.attention = Attention()
        self.attention_fc = nn.Linear(hidden_size * 4, hidden_size * 2)
        self.act = nn.Tanh()
        self.fc = nn.Linear(hidden_size * 2, input_size)

    def forward(self, dec_idxs, h_0, enc_outputs):
        embedded = self.embedding(dec_idxs)
        # dec_output: [batch_size, seq_len, hidden_size * 2]
        # h_n: [num_layers, batch_size, hidden_size * 2]，返回最后一个时间步的隐藏状态，用于进行推理
        dec_outputs, h_n = self.rnn(embedded, h_0.unsqueeze(0))
        c_t = self.attention(enc_outputs, dec_outputs) # [batch_size, seq_len, hidden_size * 2]
        cat_outputs = torch.cat((c_t, dec_outputs), dim=2) # [batch_size, seq_len, hidden_size * 4]
        outputs = self.attention_fc(cat_outputs) # [batch_size, seq_len, hidden_size * 2]
        outputs = self.act(outputs) # [batch_size, seq_len, hidden_size * 2]
        logits = self.fc(outputs) # [batch_size, seq_len, input_size]
        return logits, h_n
    
class Seq2Seq(nn.Module):
    def __init__(self, enc_input_size, dec_input_size, emb_size, hidden_size, dropout=0.3):
        super().__init__()
        self.encoder = Encoder(enc_input_size, emb_size, hidden_size, dropout)
        self.decoder = Decoder(dec_input_size, emb_size, hidden_size, dropout)

    def forward(self, enc_idxs, dec_idxs):
        enc_outputs, h_0 = self.encoder(enc_idxs)
        outputs, h_n = self.decoder(dec_idxs, h_0, enc_outputs)
        return outputs, h_n

# if __name__ == "__main__":
#     seq2seq = Seq2Seq(200, 300, 70, 128)
#     enc_idxs = torch.randint(0, 200, (3, 10))
#     dec_idxs = torch.randint(0, 300, (3, 12))
#     outputs, h_n = seq2seq(enc_idxs, dec_idxs)
#     print(outputs.shape) # [3, 12, 300]
#     print(h_n.shape) # [1, 3, 256]

In [6]:
# 3、模型训练
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
# from EncoderDecoderModel import Seq2Seq
from torch.utils.tensorboard import SummaryWriter

2025-04-27 05:20:27.780870: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745731227.965673      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745731228.018752      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [7]:
# 定义格式化函数
def format_batch(enc_vocab, dec_vocab):
    def format_batch_fn(batch):
        enc_ids, dec_ids, target_ids = [], [], []
        for enc_line, dec_line in batch:
            enc_input = [enc_vocab.get(token, enc_vocab['UNK']) for token in enc_line]
            dec_input = [dec_vocab.get(token, dec_vocab['UNK']) for token in dec_line]
            enc_ids.append(torch.tensor(enc_input))
            dec_ids.append(torch.tensor(dec_input[:-1]))
            target_ids.append(torch.tensor(dec_input[1:]))  # 目标是输入序列的偏移
        enc_inputs = pad_sequence(enc_ids, batch_first=True, padding_value=enc_vocab['PAD'])
        dec_inputs = pad_sequence(dec_ids, batch_first=True, padding_value=dec_vocab['PAD'])
        targets = pad_sequence(target_ids, batch_first=True, padding_value=dec_vocab['PAD'])
        return enc_inputs, dec_inputs, targets
    return format_batch_fn

In [8]:
train_dataloader = DataLoader(list(zip(train_enc_data, train_dec_data)), batch_size=256, shuffle=True, collate_fn=format_batch(vocab, vocab))
test_dataloader = DataLoader(list(zip(test_enc_data, test_dec_data)), batch_size=256, shuffle=False, collate_fn=format_batch(vocab, vocab))

In [9]:
writer = SummaryWriter(log_dir='/kaggle/working/runs/cat')
emb_size = 100
hidden_size = 512
epochs = 10
model = Seq2Seq(len(vocab), len(vocab), emb_size, hidden_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [10]:
for epoch in range(epochs):
        # 训练模型
    model.train()
    train_bar = tqdm(train_dataloader)
    for i, (enc_inputs, dec_inputs, targets) in enumerate(train_bar):
        enc_inputs, dec_inputs, targets = enc_inputs.to(device), dec_inputs.to(device), targets.to(device)
        out, _ = model(enc_inputs, dec_inputs)
        # targets: [batch_size, seq_len]
        # out: [batch_size, seq_len, vocab_size]
        loss = loss_fn(out.view(-1, len(vocab)), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        writer.add_scalar("loss", loss.item(), epoch * len(train_dataloader) + i)
        train_bar.set_description(f"Epoch {epoch + 1}, Loss: {loss.item()}")
    # 测试模型
    model.eval()
    with torch.no_grad():
        total = 0
        correct = 0
        for enc_inputs, dec_inputs, targets in test_dataloader:
            enc_inputs, dec_inputs, targets = enc_inputs.to(device), dec_inputs.to(device), targets.to(device)
            out, _ = model(enc_inputs, dec_inputs)
            pred = torch.argmax(out, dim=-1)
            # 标记非填充位置
            non_padding_mask = (targets != 0)
            correct += (pred[non_padding_mask] == targets[non_padding_mask]).sum().item()
            total += non_padding_mask.sum().item()
        accuracy = 100 * correct / total
        print(f"Epoch {epoch + 1}, Accuracy: {accuracy}, Total: {total}, Correct: {correct}")
        writer.add_scalar("accuracy", accuracy, epoch)
writer.close()
# 保存模型
torch.save(model.state_dict(), "model_cat.pth")

Epoch 1, Loss: 1.318755030632019: 100%|██████████| 3010/3010 [11:42<00:00,  4.29it/s]


Epoch 1, Accuracy: 33.59089158482573, Total: 41544, Correct: 13955


Epoch 2, Loss: 1.5449485778808594: 100%|██████████| 3010/3010 [11:49<00:00,  4.24it/s]


Epoch 2, Accuracy: 35.95224340458309, Total: 41544, Correct: 14936


Epoch 3, Loss: 1.2648987770080566: 100%|██████████| 3010/3010 [11:50<00:00,  4.23it/s]


Epoch 3, Accuracy: 36.91267090313884, Total: 41544, Correct: 15335


Epoch 4, Loss: 1.1247605085372925: 100%|██████████| 3010/3010 [11:53<00:00,  4.22it/s]


Epoch 4, Accuracy: 37.3627960716349, Total: 41544, Correct: 15522


Epoch 5, Loss: 1.2235298156738281: 100%|██████████| 3010/3010 [11:51<00:00,  4.23it/s]


Epoch 5, Accuracy: 37.278548045445795, Total: 41544, Correct: 15487


Epoch 6, Loss: 1.1288634538650513: 100%|██████████| 3010/3010 [11:51<00:00,  4.23it/s]


Epoch 6, Accuracy: 37.280955131908335, Total: 41544, Correct: 15488


Epoch 7, Loss: 1.0872397422790527: 100%|██████████| 3010/3010 [11:51<00:00,  4.23it/s]


Epoch 7, Accuracy: 37.396495282110536, Total: 41544, Correct: 15536


Epoch 8, Loss: 1.058830976486206: 100%|██████████| 3010/3010 [11:53<00:00,  4.22it/s]


Epoch 8, Accuracy: 37.1317157712305, Total: 41544, Correct: 15426


Epoch 9, Loss: 1.142983317375183: 100%|██████████| 3010/3010 [11:51<00:00,  4.23it/s]


Epoch 9, Accuracy: 36.96562680531485, Total: 41544, Correct: 15357


Epoch 10, Loss: 1.1397030353546143: 100%|██████████| 3010/3010 [11:52<00:00,  4.22it/s]


Epoch 10, Accuracy: 36.6527055651839, Total: 41544, Correct: 15227
