In [14]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/chinese-couplets/couplet/vocabs
/kaggle/input/chinese-couplets/couplet/test/out.txt
/kaggle/input/chinese-couplets/couplet/test/in.txt
/kaggle/input/chinese-couplets/couplet/test/.in.txt.swp
/kaggle/input/chinese-couplets/couplet/test/.out.txt.swp
/kaggle/input/chinese-couplets/couplet/train/out.txt
/kaggle/input/chinese-couplets/couplet/train/in.txt


## 数据处理

In [15]:
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [16]:
def read_data(enc_data_file, dec_data_file):
    """
    读取训练数据返回数据集合
    """
    enc_data,dec_data = [],[]

    with open(enc_data_file) as f1:
        # 读取记录行
        lines = f1.read().split('\n')
        for line in lines:
            if line == ' ':
                continue
            enc_tks = line.split()
            enc_data.append(enc_tks)

    with open(dec_data_file) as f2:
        # 读取记录行
        lines = f2.read().split('\n')
        for line in lines:
            if line == ' ':
                continue
            dec_tks = line.split()
            dec_tks = ['BOS'] + dec_tks + ['EOS']
            dec_data.append(dec_tks)

    # 断言
    assert len(enc_data) == len(dec_data), '编码数据与解码数据长度不一致！'

    return enc_data, dec_data

# 测试
enc_data_file1 = '/kaggle/input/chinese-couplets/couplet/train/in.txt'
dec_data_file1 = '/kaggle/input/chinese-couplets/couplet/train/out.txt'

enc_data_file2 = '/kaggle/input/chinese-couplets/couplet/test/in.txt'
dec_data_file2 = '/kaggle/input/chinese-couplets/couplet/test/out.txt'

enc_data1, dec_data1 = read_data(enc_data_file1, dec_data_file1)
print(len(enc_data1))
print(len(dec_data1))
print(enc_data1[:5])
print(dec_data1[:5])
print()

enc_data2, dec_data2 = read_data(enc_data_file2, dec_data_file2)
print(len(enc_data2))
print(len(dec_data2))
print(enc_data2[:5])
print(dec_data2[:5])

770492
770492
[['晚', '风', '摇', '树', '树', '还', '挺'], ['愿', '景', '天', '成', '无', '墨', '迹'], ['丹', '枫', '江', '冷', '人', '初', '去'], ['忽', '忽', '几', '晨', '昏', '，', '离', '别', '间', '之', '，', '疾', '病', '间', '之', '，', '不', '及', '终', '年', '同', '静', '好'], ['闲', '来', '野', '钓', '人', '稀', '处']]
[['BOS', '晨', '露', '润', '花', '花', '更', '红', 'EOS'], ['BOS', '万', '方', '乐', '奏', '有', '于', '阗', 'EOS'], ['BOS', '绿', '柳', '堤', '新', '燕', '复', '来', 'EOS'], ['BOS', '茕', '茕', '小', '儿', '女', '，', '孱', '羸', '若', '此', '，', '娇', '憨', '若', '此', '，', '更', '烦', '二', '老', '费', '精', '神', 'EOS'], ['BOS', '兴', '起', '高', '歌', '酒', '醉', '中', 'EOS']]

4001
4001
[['腾', '飞', '上', '铁', '，', '锐', '意', '改', '革', '谋', '发', '展', '，', '勇', '当', '千', '里', '马'], ['风', '弦', '未', '拨', '心', '先', '乱'], ['花', '梦', '粘', '于', '春', '袖', '口'], ['晋', '世', '文', '章', '昌', '二', '陆'], ['一', '句', '相', '思', '吟', '岁', '月']]
[['BOS', '和', '谐', '南', '供', '，', '安', '全', '送', '电', '保', '畅', '通', '，', '争', '做', '领', '头', '羊', 'EOS'], ['BOS', '夜', '幕', '已', '沉

In [17]:
def words_to_vocab(words_list):
        no_repeat_tokens = set()

        for word in words_list:
            no_repeat_tokens.update(list(word))  

        tokens = ['PAD','UNK'] + list(no_repeat_tokens)

        vocab = { tk:i for i, tk in enumerate(tokens)}

        return vocab

# 测试
import random

enc_vocab1 = words_to_vocab(enc_data1)
dec_vocab1 = words_to_vocab(dec_data1)

enc_vocab2 = words_to_vocab(enc_data2)
dec_vocab2 = words_to_vocab(dec_data2)

enc_keys1 = random.sample(list(enc_vocab1.keys()), 5)
random_enc_elements1 = {key: enc_vocab1[key] for key in enc_keys1}
print(f'random_enc_elements1:\n{random_enc_elements1}\n')

dec_keys1 = random.sample(list(dec_vocab1.keys()), 5)
random_dec_elements1 = {key: dec_vocab1[key] for key in dec_keys1}
print(f'random_dec_elements1:\n{random_dec_elements1}\n')

enc_keys2 = random.sample(list(enc_vocab2.keys()), 5)
random_enc_elements2 = {key: enc_vocab2[key] for key in enc_keys2}
print(f'random_enc_elements2:\n{random_enc_elements2}\n')

dec_keys2 = random.sample(list(dec_vocab2.keys()), 5)
random_dec_elements2 = {key: dec_vocab2[key] for key in dec_keys2}
print(f'random_dec_elements2:\n{random_dec_elements2}\n')

import pickle
with open('vocab1.bin','wb') as f:
    pickle.dump((enc_vocab1, dec_vocab1),f)
with open('vocab2.bin','wb') as f:
    pickle.dump((enc_vocab2, dec_vocab2),f)

random_enc_elements1:
{'楝': 6521, '缭': 2234, '寅': 3760, '墅': 2527, '赉': 3902}

random_dec_elements1:
{'缈': 6551, '裏': 1733, '炆': 7502, '辘': 2013, '孮': 7035}

random_enc_elements2:
{'幼': 1423, '伙': 2760, '避': 1556, '燎': 2198, '笙': 1992}

random_dec_elements2:
{'窥': 719, '伤': 1392, '始': 1476, '绽': 1972, '诊': 1489}



## transformer模型构建

In [18]:
# 位置编码矩阵
class PositionalEncoding(nn.Module):

    def __init__(self, emb_size, dropout=0.1, maxlen=5000):
        super().__init__()
        # 行缩放指数值
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        # 位置编码索引 (5000,1)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        # 编码矩阵 (5000, emb_size)
        pos_embdding = torch.zeros((maxlen, emb_size))
        pos_embdding[:, 0::2] = torch.sin(pos * den)
        pos_embdding[:, 1::2] = torch.cos(pos * den)
        # 添加和batch对应维度 (1, 5000, emb_size)
        pos_embdding = pos_embdding.unsqueeze(0)
        # dropout
        self.dropout = nn.Dropout(dropout)
        # 注册当前矩阵不参与参数更新
        self.register_buffer('pos_embedding', pos_embdding)

    def forward(self, token_embdding):
        token_len = token_embdding.size(1)  # token长度
        # (1, token_len, emb_size)
        add_emb = self.pos_embedding[:, :token_len, :] + token_embdding
        return self.dropout(add_emb)

In [19]:
class Seq2SeqTransformer(nn.Module):

    def __init__(self, d_model, nhead, num_enc_layers, num_dec_layers,
                 dim_forward, dropout, enc_voc_size, dec_voc_size):
        super().__init__()
        # transformer
        self.transformer = nn.Transformer(d_model=d_model,
                                          nhead=nhead,
                                          num_encoder_layers=num_enc_layers,
                                          num_decoder_layers=num_dec_layers,
                                          dim_feedforward=dim_forward,
                                          dropout=dropout,
                                          batch_first=True)
        # encoder input embedding
        self.enc_emb = nn.Embedding(enc_voc_size, d_model)
        # decoder input embedding
        self.dec_emb = nn.Embedding(dec_voc_size, d_model)
        # predict generate linear
        self.predict = nn.Linear(d_model, dec_voc_size)  # token预测基于解码器词典
        # positional encoding
        self.pos_encoding = PositionalEncoding(d_model, dropout)

    def forward(self, enc_inp, dec_inp, tgt_mask, enc_pad_mask, dec_pad_mask):
        # multi head attention之前基于位置编码embedding生成
        enc_emb = self.pos_encoding(self.enc_emb(enc_inp))
        dec_emb = self.pos_encoding(self.dec_emb(dec_inp))
        # 调用transformer计算
        outs = self.transformer(src=enc_emb, tgt=dec_emb, tgt_mask=tgt_mask,
                         src_key_padding_mask=enc_pad_mask,
                         tgt_key_padding_mask=dec_pad_mask)
        # 推理
        return self.predict(outs)

    # 推理环节使用方法
    def encode(self, enc_inp):
        enc_emb = self.pos_encoding(self.enc_emb(enc_inp))
        return self.transformer.encoder(enc_emb)

    def decode(self, dec_inp, memory, dec_mask):
        dec_emb = self.pos_encoding(self.dec_emb(dec_inp))
        return self.transformer.decoder(dec_emb, memory, dec_mask)

In [20]:
def get_proc(enc_voc, dec_voc):

    # 嵌套函数定义
    # 外部函数变量生命周期会延续到内部函数调用结束 （闭包）

    def batch_proc(data):
        """
        批次数据处理并返回
        """
        enc_ids, dec_ids, labels = [],[],[]
        for enc,dec in data:
            # token -> token index
            enc_idx = [enc_voc[tk] for tk in enc]
            dec_idx = [dec_voc[tk] for tk in dec]

            # encoder_input
            enc_ids.append(torch.tensor(enc_idx))
            # decoder_input
            dec_ids.append(torch.tensor(dec_idx[:-1]))
            # label
            labels.append(torch.tensor(dec_idx[1:]))

        # 数据转换张量 [batch, max_token_len]
        # 用批次中最长token序列构建张量
        enc_input = pad_sequence(enc_ids, batch_first=True)
        dec_input = pad_sequence(dec_ids, batch_first=True)
        targets = pad_sequence(labels, batch_first=True)

        # 返回数据都是模型训练和推理的需要
        return enc_input, dec_input, targets

    # 返回回调函数
    return batch_proc

## 模型训练

In [29]:
from torch.nn.utils.rnn import pad_sequence
 
# 超参数配置
D_MODEL = 256    # 嵌入维度
NHEAD = 4        # 注意力头数
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
DIM_FEEDFORWARD = D_MODEL * 4
DROPOUT = 0.3
# BATCH_SIZE = 32
BATCH_SIZE = 512
MAX_LEN = 50     # 最大对联长度
EPOCHS = 10

In [22]:
# 创建数据加载器（需自行实现Dataset类）
class CoupletDataset(torch.utils.data.Dataset):
    def __init__(self, enc_data, dec_data):
        self.data = list(zip(enc_data, dec_data))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        enc, dec = self.data[idx]
        return enc, dec

In [30]:
# 数据加载
# enc_data1 = enc_data1[:1024]
# dec_data1 = dec_data1[:1024]
dataset = CoupletDataset(enc_data1, dec_data1)
dataloader = torch.utils.data.DataLoader(
    dataset, 
    batch_size=BATCH_SIZE,
    collate_fn=get_proc(enc_vocab1, dec_vocab1),
    shuffle=True
)

In [32]:
import math
# 初始化模型
model = Seq2SeqTransformer(
    d_model=D_MODEL,
    nhead=NHEAD,
    num_enc_layers=NUM_ENCODER_LAYERS,
    num_dec_layers=NUM_DECODER_LAYERS,
    dim_forward=DIM_FEEDFORWARD,
    dropout=DROPOUT,
    enc_voc_size=len(enc_vocab1),
    dec_voc_size=len(dec_vocab1)
).to(device)
 
# 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略填充位

In [33]:
# 生成掩码
def create_mask(size, device):
    mask = torch.triu(torch.ones(size, size) * float('-inf'), diagonal=1)
    return mask.to(device)

In [38]:
from tqdm import tqdm
# 训练循环
for epoch in range(EPOCHS):
    model.train()
    tpbar = tqdm(dataloader)
    
    for enc_input, dec_input, targets in tpbar:
        # 数据移至设备
        enc_input = enc_input.to(device)
        dec_input = dec_input.to(device)
        targets = targets.to(device)
        
        # 创建掩码
        tgt_mask = create_mask(dec_input.size(1), device)
        enc_pad_mask = (enc_input == enc_vocab1['PAD'])
        dec_pad_mask = (dec_input == dec_vocab1['PAD'])
        
        # 前向传播
        outputs = model(
            enc_inp=enc_input,
            dec_inp=dec_input,
            tgt_mask=tgt_mask,
            enc_pad_mask=enc_pad_mask,
            dec_pad_mask=dec_pad_mask
        )
        
        # 计算损失（targets需要是(batch_size, seq_len)）
        loss = criterion(
            outputs.view(-1, len(dec_vocab1)),
            targets.view(-1)
        )
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tpbar.set_description(f'Epoch {epoch+1}/{EPOCHS} | Loss: {loss.item():.4f}')


Epoch 1/10 | Loss: 4.7934: 100%|██████████| 1505/1505 [04:47<00:00,  5.24it/s]
Epoch 2/10 | Loss: 4.6006: 100%|██████████| 1505/1505 [04:48<00:00,  5.21it/s]
Epoch 3/10 | Loss: 4.3728: 100%|██████████| 1505/1505 [04:49<00:00,  5.20it/s]
Epoch 4/10 | Loss: 4.2023: 100%|██████████| 1505/1505 [04:47<00:00,  5.23it/s]
Epoch 5/10 | Loss: 4.1045: 100%|██████████| 1505/1505 [04:49<00:00,  5.21it/s]
Epoch 6/10 | Loss: 3.9753: 100%|██████████| 1505/1505 [04:48<00:00,  5.22it/s]
Epoch 7/10 | Loss: 4.0236: 100%|██████████| 1505/1505 [04:47<00:00,  5.24it/s]
Epoch 8/10 | Loss: 3.8615: 100%|██████████| 1505/1505 [04:47<00:00,  5.23it/s]
Epoch 9/10 | Loss: 3.9718: 100%|██████████| 1505/1505 [04:49<00:00,  5.21it/s]
Epoch 10/10 | Loss: 3.8202: 100%|██████████| 1505/1505 [04:48<00:00,  5.21it/s]


In [62]:
# 随机选取测试样本
rnd_idx = random.randint(0, len(enc_data2))
enc_input = enc_data2[rnd_idx]
dec_output = dec_data2[rnd_idx]

enc_idx = torch.tensor([[enc_vocab2[tk] for tk in enc_input]])

print(f'enc_idx: {enc_idx}')
print(f'enc_idx.shape: {enc_idx.shape}')

# 最大解码长度
max_dec_len = len(dec_output)
print(f'max_dec_len: {max_dec_len}')

print('enc_input：',''.join(enc_input))
print("dec_output：", ''.join(dec_output))

enc_idx: tensor([[2514, 2024, 1140, 2155, 2862]])
enc_idx.shape: torch.Size([1, 5])
max_dec_len: 7
enc_input： 岸影几家柳
dec_output： BOS莺声百啭歌EOS


In [69]:
def generate_couplet(model, enc_input, enc_vocab, dec_vocab, max_len=50):
    model.eval()

    # 处理输入张量（避免重复构造）
    if not isinstance(enc_input, torch.Tensor):
        # 若输入是列表/数组，直接构造到目标设备
        enc_input = torch.tensor(enc_input, dtype=torch.long, device=device)
    else:
        # 若已经是张量，转移到设备并确保类型正确
        enc_input = enc_input.to(dtype=torch.long, device=device)
    
    # 添加batch维度（若输入是1D）
    if enc_input.dim() == 1:
        enc_input = enc_input.unsqueeze(0)  # (1, seq_len)
    
    # 编码
    memory = model.encode(enc_input)
    
    # 初始化解码输入（使用传入的dec_vocab参数）
    bos_token = dec_vocab['BOS']
    dec_input = torch.tensor([[bos_token]], dtype=torch.long, device=device)  # (1,1)
    generated = []
    
    for _ in range(max_len):
        # 动态生成解码掩码
        current_seq_len = dec_input.size(1)
        dec_mask = create_mask(current_seq_len, device)
        
        # 解码（参数名dec_inp需与类定义一致）
        output = model.decode(
            dec_inp=dec_input,
            memory=memory,
            dec_mask=dec_mask
        )
        
        # 预测下一个token
        next_token = output.argmax(dim=-1)[:, -1:]  # (1,1)
        
        # 检查EOS（使用传入的dec_vocab参数）
        if next_token.item() == dec_vocab['EOS']:
            break
        
        generated.append(next_token.item())
        dec_input = torch.cat([dec_input, next_token], dim=-1)
    
    return generated

# 使用示例（确保传入正确的词汇表）
generated_dec_idx = generate_couplet(
    model,
    enc_idx,          # 确保enc_idx是张量或列表
    enc_vocab2,       # 编码器词汇表
    dec_vocab2        # 解码器词汇表（包含BOS/EOS）
)

# 转换索引到token
generated_dec = [dec_vocab2.get(idx, 'UNK') for idx in generated_dec_idx]
print("dec_eval：", ''.join(generated_dec))

dec_eval： UNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNKUNK
