In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# import packages

In [16]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import json
import sys

# 添加自定义模块路径（根据实际存放位置修改路径）
sys.path.append('/kaggle/input/encoderdecoderattenmodel/pytorch/default/1')
from EncoderDecoderAttenModel import Seq2Seq

# 数据预处理

In [17]:
from itertools import islice
import torch
from torch.utils.data import Dataset

class CoupletDataset(Dataset):
    def __init__(self, enc_file, dec_file, max_len=50, max_samples=50000):"
        # 高效读取前max_samples行数据
        self.enc_data = self._read_lines(enc_file, max_samples)
        self.dec_data = self._read_lines(dec_file, max_samples)
        
        # 数据一致性校验
        self._validate_data()
        
        # 构建词汇表
        self.char2idx, self.idx2char = self._build_vocab()
        self.max_len = max_len

    def _read_lines(self, file_path, max_lines):
        """ 内存优化的数据读取方法 """
        lines = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in islice(f, max_lines):  # 逐行读取避免内存压力
                line = line.strip()
                if line:  # 过滤空行
                    lines.append(line)
        return lines

    def _validate_data(self):
        """ 数据验证 """
        if len(self.enc_data) != len(self.dec_data):
            raise ValueError(
                f"数据不匹配！上联数：{len(self.enc_data)}, 下联数：{len(self.dec_data)}")
        
        min_length = 3  # 最小有效对联长度（如：上联为2字可能有错误）
        error_samples = []
        for i, (enc, dec) in enumerate(zip(self.enc_data, self.dec_data)):
            if len(enc) < min_length or len(dec) < min_length:
                error_samples.append(i)
        if error_samples:
            print(f"警告：发现{len(error_samples)}条异常数据（行号：{error_samples[:5]}...）")

    def _build_vocab(self):
        """ 优化后的词汇表构建 """
        char_counter = set()
        for enc, dec in zip(self.enc_data, self.dec_data):
            char_counter.update(enc)
            char_counter.update(dec)
        
        # 创建词典（添加特殊标记）
        char2idx = {
            '<PAD>': 0, 
            '<SOS>': 1,   # Start of Sequence
            '<EOS>': 2,   # End of Sequence
            '<UNK>': 3    # Unknown token（添加）
        }
        # 按字符频率排序（可选优化）
        for idx, char in enumerate(char_counter, start=4):
            char2idx[char] = idx
            
        idx2char = {v:k for k, v in char2idx.items()}
        return char2idx, idx2char

    def __len__(self):
        return len(self.enc_data)

    def __getitem__(self, idx):
        """ 添加处理优化和异常防御 """
        # 获取原始文本
        enc_text = self.enc_data[idx][:self.max_len]  # 预截断
        dec_text = self.dec_data[idx][:self.max_len-1]  # 保留EOS位置
        
        # 编码器输入处理
        enc_indices = [
            self.char2idx.get(c, self.char2idx['<UNK>']) 
            for c in enc_text
        ]
        # 填充处理
        enc_padded = enc_indices + [self.char2idx['<PAD>']] * (self.max_len - len(enc_indices))
        
        # 解码器输入处理（含EOS标记）
        dec_indices = (
            [self.char2idx['<SOS>']] +
            [self.char2idx.get(c, self.char2idx['<UNK>']) for c in dec_text] +
            [self.char2idx['<EOS>']]
        )
        # 截断或填充
        dec_padded = dec_indices[:self.max_len] + [self.char2idx['<PAD>']] * (self.max_len - len(dec_indices))
        
        return (
            torch.LongTensor(enc_padded),
            torch.LongTensor(dec_padded)
        )

    def analyze(self):
        """ 数据集分析工具 """
        stats = {
            'total_pairs': len(self),
            'vocab_size': len(self.char2idx),
            'max_length': self.max_len,
            'enc_lengths': [len(s) for s in self.enc_data],
            'dec_lengths': [len(s) for s in self.dec_data]
        }
        print(f"数据集分析：")
        print(f"- 样本总数：{stats['total_pairs']}")
        print(f"- 词汇表大小：{stats['vocab_size']}")
        print(f"- 上联平均长度：{sum(stats['enc_lengths'])/len(self):.1f}")
        print(f"- 下联平均长度：{sum(stats['dec_lengths'])/len(self):.1f}")
        print(f"- 最大允许长度：{self.max_len}")


In [20]:
# 检查并设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# 数据集路径（Kaggle输入目录）
dataset = CoupletDataset(
    enc_file='/kaggle/input/chinese-couplets/couplet/train/in.txt',
    dec_file='/kaggle/input/chinese-couplets/couplet/train/out.txt'
)

dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

Using device: cuda
警告：发现147条异常数据（行号：[167, 634, 688, 1248, 1603]...）


# 模型训练

In [21]:
# 初始化模型
model = Seq2Seq(
    enc_emb_size=len(dataset.char2idx),
    dec_emb_size=len(dataset.char2idx),
    emb_dim=256,
    hidden_size=512,
    dropout=0.3
).to(device)

# 训练参数设置
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=0)
writer = SummaryWriter('/kaggle/working/logs')  # tensorboard日志目录

for epoch in range(20):
    total_loss = 0
    total_correct = 0
    total_tokens = 0
    
    for batch_idx, (enc_inputs, dec_inputs) in enumerate(dataloader):
        enc_inputs = enc_inputs.to(device)
        dec_inputs = dec_inputs.to(device)
        
        # 前向传播
        outputs, _ = model(enc_inputs, dec_inputs[:, :-1])
        
        # 计算损失
        loss = criterion(
            outputs.view(-1, outputs.size(-1)),
            dec_inputs[:, 1:].contiguous().view(-1)
        )
        
        # 计算准确率
        preds = outputs.argmax(dim=-1)
        targets = dec_inputs[:, 1:].contiguous().view(-1)
        mask = targets != 0  # 忽略padding部分
        correct = (preds.view(-1)[mask] == targets[mask]).sum().item()
        
        total_correct += correct
        total_tokens += mask.sum().item()
        
        # 记录训练指标
        writer.add_scalar('Loss/train_batch', loss.item(), epoch*len(dataloader)+batch_idx)
        writer.add_scalar('Accuracy/train_batch', correct/mask.sum().item(), epoch*len(dataloader)+batch_idx)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        
        total_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f'Epoch:{epoch+1} | Batch:{batch_idx} | Loss:{loss.item():.4f}')
    
    # 记录epoch指标
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = total_correct / total_tokens if total_tokens >0 else 0
    writer.add_scalar('Loss/train_epoch', epoch_loss, epoch)
    writer.add_scalar('Accuracy/train_epoch', epoch_acc, epoch)
    
    print(f'Epoch {epoch+1} Complete | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')

# 保存模型和词汇表
torch.save(model.state_dict(), '/kaggle/working/couplet_model.pth')
with open('/kaggle/working/vocab.json', 'w') as f:
    json.dump(dataset.char2idx, f)

writer.close()
print("Training complete and models saved!")

Epoch:1 | Batch:0 | Loss:8.6732
Epoch:1 | Batch:100 | Loss:3.5198
Epoch 1 Complete | Loss: 3.7209 | Acc: 0.4910
Epoch:2 | Batch:0 | Loss:3.4952
Epoch:2 | Batch:100 | Loss:3.3289
Epoch 2 Complete | Loss: 3.3310 | Acc: 0.5297
Epoch:3 | Batch:0 | Loss:3.2104
Epoch:3 | Batch:100 | Loss:3.0578
Epoch 3 Complete | Loss: 3.0800 | Acc: 0.5471
Epoch:4 | Batch:0 | Loss:2.8649
Epoch:4 | Batch:100 | Loss:2.7508
Epoch 4 Complete | Loss: 2.7592 | Acc: 0.5726
Epoch:5 | Batch:0 | Loss:2.4187
Epoch:5 | Batch:100 | Loss:2.4248
Epoch 5 Complete | Loss: 2.4270 | Acc: 0.6007
Epoch:6 | Batch:0 | Loss:2.1532
Epoch:6 | Batch:100 | Loss:2.2045
Epoch 6 Complete | Loss: 2.1974 | Acc: 0.6193
Epoch:7 | Batch:0 | Loss:1.9283
Epoch:7 | Batch:100 | Loss:1.9666
Epoch 7 Complete | Loss: 1.9994 | Acc: 0.6370
Epoch:8 | Batch:0 | Loss:1.7416
Epoch:8 | Batch:100 | Loss:1.8291
Epoch 8 Complete | Loss: 1.7946 | Acc: 0.6588
Epoch:9 | Batch:0 | Loss:1.4696
Epoch:9 | Batch:100 | Loss:1.5674
Epoch 9 Complete | Loss: 1.5757 | Acc:

# 推理实现

In [32]:
class CoupletInfer:
    def __init__(self, model_path, vocab_path, max_len=50):
        # 加载词汇表
        with open(vocab_path) as f:
            self.char2idx = json.load(f)
        self.idx2char = {v:k for k,v in self.char2idx.items()}
        
        # 加载模型
        self.model = torch.load(model_path)
        self.model.eval()
        self.max_len = max_len

    def decode(self, text):
        # 文本 -> 编码
        enc_ids = [self.char2idx.get(c, 3) for c in text[:self.max_len]]
        enc_tensor = torch.LongTensor(enc_ids).unsqueeze(0)
        
        # 模型推理
        enc_out, hidden = self.model.encoder(enc_tensor)
        dec_ids = [1]  # SOS=1
        
        for _ in range(self.max_len):
            dec_tensor = torch.LongTensor([dec_ids[-1]])
            dec_out, hidden = self.model.decoder(dec_tensor, hidden, enc_out)
            next_id = dec_out.argmax().item()
            if next_id == 2: break  # EOS=2
            dec_ids.append(next_id)
        
        # 结果转换
        return ''.join([self.idx2char[i] for i in dec_ids[1:]])


In [33]:
infer = CoupletInfer(
    model_path="/kaggle/working/couplet_model.pth",
    vocab_path="/kaggle/working/vocab.json"
)

print(infer.decode("春风送暖"))  # 输出：大地回春

  self.model = torch.load(model_path)


AttributeError: 'collections.OrderedDict' object has no attribute 'eval'