# 1. 使用中文对联数据集训练带有attention的seq2seq模型，利用tensorboard跟踪。
# https://www.kaggle.com/datasets/jiaminggogogo/chinese-couplets
# 2. 尝试encoder hidden state不同的返回形式（concat和add）
# 3. 编写并实现seq2seq attention版的推理实现。


In [None]:
# --------------------process--------------------
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# Root directory of the Kaggle runtime; the input and working paths are built from it.
current_dir = "/kaggle/"


def read_data(in_file, out_file):
    """Read two parallel text files into aligned lists of stripped lines.

    Lines are paired by position. Each line is stripped of surrounding
    whitespace and kept as a plain string, so iterating over an entry later
    yields one character per token.

    Args:
        in_file: Path to the encoder-side (input) text file, UTF-8 encoded.
        out_file: Path to the decoder-side (output) text file, UTF-8 encoded.

    Returns:
        A tuple ``(enc_data, dec_data)`` of two lists of equal length.
        Pairing stops at the end of the shorter file.
    """
    enc_data, dec_data = [], []

    # Context managers close both handles even if reading fails part-way;
    # the previous version left both files open.
    with open(in_file, encoding="utf-8") as in_, open(
        out_file, encoding="utf-8"
    ) as out_:
        for enc, dec in zip(in_, out_):
            enc_data.append(enc.strip())
            dec_data.append(dec.strip())

    assert len(enc_data) == len(dec_data), "输入输出数据长度不一致"
    return enc_data, dec_data


def get_proc(vocab):
    """Build a ``DataLoader`` collate function bound to ``vocab``.

    ``vocab`` is captured by the returned closure, so it stays alive for as
    long as the collate function is in use.

    Args:
        vocab: Object supporting ``vocab[token]`` lookups that return integer
            token indices, including the ``"<s>"`` and ``"</s>"`` markers.

    Returns:
        A function that turns a list of ``(enc, dec)`` pairs into a tuple of
        padded tensors ``(enc_input, dec_input, targets)``.
    """

    def _to_indices(tokens):
        # Wrap every sequence with the start and end markers.
        return [vocab["<s>"], *(vocab[tk] for tk in tokens), vocab["</s>"]]

    def batch_proc(data):
        """Convert one batch of ``(enc, dec)`` pairs into padded index tensors."""
        indexed = [(_to_indices(src), _to_indices(tgt)) for src, tgt in data]

        # Encoder sees the full wrapped sequence. The decoder input drops the
        # trailing end marker and the label drops the leading start marker,
        # so the label is the decoder input shifted left by one position.
        enc_ids = [torch.tensor(src) for src, _ in indexed]
        dec_ids = [torch.tensor(tgt[:-1]) for _, tgt in indexed]
        labels = [torch.tensor(tgt[1:]) for _, tgt in indexed]

        # Each group is padded to its longest sequence: [batch, max_len].
        return tuple(
            pad_sequence(group, batch_first=True)
            for group in (enc_ids, dec_ids, labels)
        )

    return batch_proc


class Vocabulary:
    """Token-to-index mapping with an ``<unk>`` fallback for unseen tokens.

    Attributes are plain and public: ``vocab`` is the ``dict`` mapping token
    to index, ``unk_token`` is the literal ``"<unk>"`` and ``unk_index`` is
    the index returned for any token missing from ``vocab``.
    """

    def __init__(self, vocab):
        """Wrap a token-to-index mapping.

        Args:
            vocab: Mapping from token string to integer index. It is copied,
                so the caller's mapping is never modified.
        """
        self.vocab = dict(vocab)
        self.unk_token = "<unk>"
        if self.unk_token not in self.vocab:
            # Register <unk> as a real entry. Previously the fallback index
            # was len(vocab) without a matching entry, so it fell outside any
            # table sized by len(self.vocab).
            self.vocab[self.unk_token] = max(self.vocab.values(), default=-1) + 1
        self.unk_index = self.vocab[self.unk_token]

    @classmethod
    def from_file(cls, vocab_file):
        """Build a vocabulary from a UTF-8 file holding one token per line.

        Blank lines are skipped and repeated tokens keep their first
        position, so indices are always contiguous ``0..len-1``. Special
        tokens missing from the file are placed in front in the order
        ``<pad>, <unk>, <s>, </s>``; a missing ``<pad>`` therefore gets
        index 0, the value ``pad_sequence`` pads with by default.

        Args:
            vocab_file: Path to the vocabulary file.

        Returns:
            A new ``Vocabulary`` instance.
        """
        with open(vocab_file, encoding="utf-8") as f:
            words = [tk.strip() for tk in f if tk.strip()]

        special_tokens = ["<pad>", "<unk>", "<s>", "</s>"]
        present = set(words)
        missing = [token for token in special_tokens if token not in present]

        # dict.fromkeys removes duplicates while keeping first-seen order;
        # duplicates would otherwise leave gaps between len(vocab) and the
        # largest index.
        ordered = dict.fromkeys(missing + words)
        vocab_dict = {word: idx for idx, word in enumerate(ordered)}
        return cls(vocab_dict)

    def __getitem__(self, token):
        """Return the index of ``token``, or ``unk_index`` if it is unknown."""
        return self.vocab.get(token, self.unk_index)


# --------- Smoke test ---------
# Load the vocabulary from the dataset's vocab file.
vocab_file = f"{current_dir}/input/chinese-couplets/couplet/vocabs"
vocab = Vocabulary.from_file(vocab_file)
# Read the parallel training files (encoder side and decoder side).
enc_train_file = f"{current_dir}/input/chinese-couplets/couplet/train/in.txt"
dec_train_file = f"{current_dir}/input/chinese-couplets/couplet/train/out.txt"
enc_data, dec_data = read_data(enc_train_file, dec_train_file)
print(f"enc len is :{len(enc_data)}")
print(f"dec len is :{len(dec_data)}")
print(f"词汇数量: {len(vocab.vocab)}")
# Pair each encoder sequence with its decoder sequence to form the samples.
dataset = list(zip(enc_data, dec_data))
dataloader = DataLoader(
    dataset,
    batch_size=2,  # batch size
    shuffle=True,  # shuffle the samples
    collate_fn=get_proc(vocab),
)
# Cache the processed data as JSON so later cells can reload it.
with open(f"{current_dir}/working/encoder.json", "w", encoding="utf-8") as f:
    json.dump(enc_data, f)
with open(f"{current_dir}/working/decoder.json", "w", encoding="utf-8") as f:
    json.dump(dec_data, f)

# # Alternative (disabled): write one JSON document per line
# with open("encoder.json", "w", encoding="utf-8") as f:
#     for enc in enc_data:
#         str_json = json.dumps(enc)
#         f.write(str_json + "\n")
print("数据加载和处理完成。")

In [None]:
# --------------------EncoderDecoderAttentionModel--------------------

import os
import torch
import torch.nn as nn


# Encoder
class Encoder(nn.Module):
    """Bidirectional GRU encoder producing an initial decoder state.

    ``state_type`` selects how the two directions are merged:

    * ``"concat"``: directions are concatenated, so the returned state and
      outputs have ``hidden_dim * 2`` features.
    * ``"add"``: directions are summed, so the returned state and outputs
      have ``hidden_dim`` features.
    """

    def __init__(
        self, input_dim, emb_dim, hidden_dim, num_layers=2, state_type="concat"
    ):
        """Create the embedding and GRU layers.

        Args:
            input_dim: Vocabulary size of the embedding table.
            emb_dim: Embedding dimension.
            hidden_dim: GRU hidden size per direction.
            num_layers: Number of stacked GRU layers.
            state_type: ``"concat"`` or ``"add"`` (checked in ``forward``).
        """
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim)

        self.rnn = nn.GRU(
            emb_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        self.state_type = state_type
        self.hidden_dim = hidden_dim

    def forward(self, token_seq):
        """Encode a batch of token indices.

        Args:
            token_seq: Integer tensor of shape ``[batch_size, seq_len]``.

        Returns:
            ``(hidden, outputs)`` where ``hidden`` is ``[batch_size, D]`` and
            ``outputs`` is ``[batch_size, seq_len, D]``, with
            ``D = hidden_dim * 2`` for ``"concat"`` and ``D = hidden_dim``
            for ``"add"``.

        Raises:
            ValueError: If ``state_type`` is neither ``"concat"`` nor ``"add"``.
        """
        # embedded: [batch_size, seq_len, emb_dim]
        embedded = self.embedding(token_seq)

        # outputs: [batch_size, seq_len, hidden_dim * 2]
        # hidden : [num_layers * 2, batch_size, hidden_dim]
        outputs, hidden = self.rnn(embedded)

        # The last two entries of h_n are the top layer's forward and
        # backward final states. The backward one has read the whole
        # sequence, unlike the backward half of outputs[:, -1, :], which has
        # only seen the last token.
        fwd_final, bwd_final = hidden[-2], hidden[-1]

        if self.state_type == "concat":
            hidden = torch.cat((fwd_final, bwd_final), dim=-1)
        elif self.state_type == "add":
            # Was `self.state_dict == "add"`, which compared a bound method
            # to a string and therefore never matched.
            hidden = fwd_final + bwd_final
            outputs = outputs[..., : self.hidden_dim] + outputs[..., self.hidden_dim :]
        else:
            raise ValueError("state_type must be 'concat' or 'add'")

        return hidden, outputs


# Decoder Attention 机制
class Attention(nn.Module):
    """Parameter-free dot-product attention over the encoder outputs."""

    def __init__(self):
        super().__init__()

    def forward(self, enc_output, dec_output):
        """Compute one context vector per decoder step.

        Args:
            enc_output: Encoder states, ``[batch, src_len, dim]``.
            dec_output: Decoder states, ``[batch, tgt_len, dim]``.

        Returns:
            Context tensor ``c_t`` of shape ``[batch, tgt_len, dim]``. Each
            row is a weighted average of the encoder states.
        """
        # scores[b, s, t] = dot(enc_output[b, s], dec_output[b, t])
        scores = torch.bmm(enc_output, dec_output.permute(0, 2, 1))

        # Normalise over source positions (dim=1) so that, for every decoder
        # step, the weights over the encoder tokens sum to 1. Normalising
        # over the last axis mixed decoder steps together and gave every
        # weight the value 1 when decoding a single step.
        a_t = torch.softmax(scores, dim=1)

        # Weighted sum of encoder states: [batch, tgt_len, dim]
        c_t = torch.bmm(a_t.permute(0, 2, 1), enc_output)

        return c_t


# Decoder
class Decoder(nn.Module):
    """GRU decoder that attends over the encoder outputs at every step.

    With ``state_type="concat"`` the internal width is ``hidden_dim * 2`` so
    that it matches the concatenated bidirectional encoder state; for any
    other value the width is ``hidden_dim``.
    """

    def __init__(
        self, input_dim, emb_dim, hidden_dim, dropout=0.5, state_type="concat"
    ):
        """Create the decoder layers.

        Args:
            input_dim: Vocabulary size (embedding rows and logits width).
            emb_dim: Embedding dimension.
            hidden_dim: Encoder hidden size per direction.
            dropout: Dropout probability applied before the output layer.
            state_type: Encoder state merge mode; ``"concat"`` doubles the
                decoder width.
        """
        super().__init__()
        width = hidden_dim * 2 if state_type == "concat" else hidden_dim

        # Token embedding table
        self.embedding = nn.Embedding(input_dim, emb_dim)

        # Single-layer unidirectional GRU
        self.rnn = nn.GRU(emb_dim, width, batch_first=True)

        # Projection from the fused state to vocabulary logits
        self.fc = nn.Linear(width, input_dim)

        # Dot-product attention over the encoder outputs
        self.attention = Attention()

        # Fuses [context ; decoder state] back down to the decoder width
        self.attention_fc = nn.Linear(width * 2, width)

        self.dropout = nn.Dropout(dropout)

    def forward(self, token_seq, hidden_state, enc_output):
        """Decode a batch of token indices.

        Args:
            token_seq: Integer tensor ``[batch_size, seq_len]``.
            hidden_state: Initial state ``[batch_size, width]``; a leading
                layer axis is added before it is passed to the GRU.
            enc_output: Encoder outputs handed to the attention layer.

        Returns:
            ``(logits, hidden)`` where ``logits`` is
            ``[batch_size, seq_len, input_dim]`` and ``hidden`` is the GRU's
            final state ``[1, batch_size, width]``.
        """
        # [batch_size, seq_len, emb_dim]
        embedded = self.embedding(token_seq)

        # The GRU expects its state as [num_layers, batch_size, width].
        initial = hidden_state.unsqueeze(0)
        dec_output, hidden = self.rnn(embedded, initial)

        # Context from attention, fused with the decoder state.
        context = self.attention(enc_output, dec_output)
        fused = self.attention_fc(torch.cat((context, dec_output), dim=-1))

        # tanh -> dropout -> vocabulary projection
        logits = self.fc(self.dropout(torch.tanh(fused)))
        return logits, hidden


# Seq2Seq Attention Model
class Seq2Seq(nn.Module):
    """Encoder-decoder model with attention, wired from ``Encoder`` and ``Decoder``."""

    def __init__(
        self,
        enc_emb_size,
        dec_emb_size,
        emb_dim,
        hidden_size,
        dropout=0.5,
        state_type="concat",
    ):
        """Build the encoder and decoder sub-modules.

        Args:
            enc_emb_size: Vocabulary size of the encoder embedding.
            dec_emb_size: Vocabulary size of the decoder embedding and logits.
            emb_dim: Embedding dimension shared by both sides.
            hidden_size: GRU hidden size passed to both sub-modules.
            dropout: Dropout probability used in the decoder.
            state_type: Encoder state merge mode, forwarded to both sides.
        """
        super().__init__()

        self.encoder = Encoder(
            enc_emb_size,
            emb_dim,
            hidden_size,
            state_type=state_type,
        )
        self.decoder = Decoder(
            dec_emb_size,
            emb_dim,
            hidden_size,
            dropout,
            state_type=state_type,
        )

    def forward(self, enc_input, dec_input):
        """Run the encoder, then the decoder seeded with the encoder state.

        Args:
            enc_input: Encoder token indices.
            dec_input: Decoder token indices.

        Returns:
            The decoder's ``(logits, hidden)`` pair.
        """
        initial_state, memory = self.encoder(enc_input)
        logits, final_state = self.decoder(dec_input, initial_state, memory)
        return logits, final_state


# -------------------- Shape smoke test --------------------
input_dim = 200  # vocabulary size used for both encoder and decoder here
emb_dim = 256  # embedding dimension
hidden_dim = 256  # hidden size
dropout = 0.5  # dropout probability
batch_size = 5  # batch size
seq_len = 10  # sequence length

# Exercise the encoder on random token indices.
encoder = Encoder(input_dim, emb_dim, hidden_dim, state_type="concat")
token_seq = torch.randint(0, input_dim, (batch_size, seq_len))  # random token sequence
hidden_state = encoder(token_seq)
print("Encoder output shape:", hidden_state[0].shape)  # shape of the returned hidden state
print("Encoder output shape:", hidden_state[1].shape)  # shape of the encoder outputs

# Exercise the decoder, seeded with the encoder's state and outputs.
decoder = Decoder(input_dim, emb_dim, hidden_dim, dropout, state_type="concat")
token_seq = torch.randint(0, input_dim, (batch_size, seq_len))  # random token sequence
logits, hidden = decoder(token_seq, hidden_state[0], hidden_state[1])
print("Decoder logits shape:", logits.shape)  # shape of the decoder logits
print("Decoder hidden state shape:", hidden.shape)  # shape of the decoder hidden state

# Exercise the full Seq2Seq model.
seq2seq = Seq2Seq(
    enc_emb_size=input_dim,
    dec_emb_size=input_dim,
    emb_dim=emb_dim,
    hidden_size=hidden_dim,
    dropout=dropout,
)
logits, hidden = seq2seq(
    enc_input=torch.randint(0, input_dim, (batch_size, seq_len)),
    dec_input=torch.randint(0, input_dim, (batch_size, seq_len)),
)
print("Seq2Seq logits shape:", logits.shape)  # shape of the Seq2Seq logits
print("Seq2Seq hidden state shape:", hidden.shape)  # shape of the returned hidden state
# NOTE(review): the label below says "encoder output", but `hidden` is the
# decoder hidden state returned by Seq2Seq.forward; this prints hidden[0].
print("Seq2Seq encoder output shape:", hidden[0].shape)  # shape of the first slice of `hidden`

In [None]:
# --------------------seq2seq_train--------------------
import os
import json
import pickle as pkl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


# Training run: logs the per-step loss to TensorBoard and saves the weights.
writer = SummaryWriter(log_dir=os.path.join(current_dir, "logs"))
# Global step for TensorBoard. It must be a monotonically increasing integer;
# it was previously incremented by the loss value itself.
train_loss_cnt = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load the vocabulary and the cached training data.
# The directory is "couplet" (as in the data-preparation cell), not "couplett".
vocab_file = f"{current_dir}/input/chinese-couplets/couplet/vocabs"
vocab = Vocabulary.from_file(vocab_file)

with open(f"{current_dir}/working/encoder.json", encoding="utf-8") as f:
    enc_data = json.load(f)
with open(f"{current_dir}/working/decoder.json", encoding="utf-8") as f:
    dec_data = json.load(f)

ds = list(zip(enc_data, dec_data))
dl = DataLoader(ds, batch_size=256, shuffle=True, collate_fn=get_proc(vocab))

# Build the model.
model = Seq2Seq(
    enc_emb_size=len(vocab.vocab),
    dec_emb_size=len(vocab.vocab),
    emb_dim=200,
    hidden_size=250,
    dropout=0.5,
    state_type="concat",
).to(device)

# Optimizer and loss.
# NOTE(review): padded label positions are included in the loss; consider
# `ignore_index` once the padding index is confirmed against the vocabulary.
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 2. Train.
for epoch in range(20):
    model.train()
    tpbar = tqdm(dl)
    for enc_input, dec_input, targets in tpbar:
        enc_input = enc_input.to(device)
        dec_input = dec_input.to(device)
        targets = targets.to(device)

        # Forward pass
        logits, _ = model(enc_input, dec_input)
        # Flatten to [batch * seq_len, vocab] against [batch * seq_len]
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        tpbar.set_description(f"Epoch [{epoch + 1}], Loss: {loss_value:.4f}")
        writer.add_scalar("Loss/train_step", loss_value, train_loss_cnt)
        train_loss_cnt += 1

# Flush pending events to disk before finishing.
writer.close()

# 3. Save the trained weights.
torch.save(model.state_dict(), f"{current_dir}/working/seq2seq_concat_state.bin")

In [None]:
# --------------------seq2seq_infer--------------------
import torch
import pickle
import random
import os

# Greedy decoding of one random test couplet with the trained model.
enc_test_file = f"{current_dir}/input/chinese-couplets/couplet/test/in.txt"
dec_test_file = f"{current_dir}/input/chinese-couplets/couplet/test/out.txt"
enc_data, dec_data = read_data(enc_test_file, dec_test_file)

# Load the trained weights. map_location lets a checkpoint saved from GPU
# tensors load into the CPU model built below.
state_dict = torch.load(
    f"{current_dir}/working/seq2seq_concat_state.bin", map_location="cpu"
)
# The directory is "couplet" (as in the data-preparation cell), not "couplett".
vocab_file = f"{current_dir}/input/chinese-couplets/couplet/vocabs"

vocab = Vocabulary.from_file(vocab_file)
model = Seq2Seq(
    enc_emb_size=len(vocab.vocab),
    dec_emb_size=len(vocab.vocab),
    emb_dim=200,
    hidden_size=250,
    dropout=0.5,
    state_type="concat",
)
model.load_state_dict(state_dict)

# Reverse mapping: index -> token.
dvoc_inv = {v: k for k, v in vocab.vocab.items()}

# Pick a random test sample.
rnd_idx = random.randint(0, len(enc_data) - 1)
enc_input = enc_data[rnd_idx]
dec_output = dec_data[rnd_idx]

# Encode exactly as the training collate function does: wrap with <s>/</s>
# and go through Vocabulary.__getitem__ so unseen characters map to <unk>
# instead of raising KeyError.
enc_tokens = [vocab["<s>"], *(vocab[tk] for tk in enc_input), vocab["</s>"]]
enc_idx = torch.tensor([enc_tokens])
print(enc_idx.shape)

# Generate at most as many tokens as the input line has characters.
max_dec_len = len(enc_input)
model.eval()  # disable dropout
with torch.no_grad():
    hidden_state, enc_outputs = model.encoder(enc_idx)
    # Start decoding from the <s> token.
    dec_input = torch.tensor([[vocab["<s>"]]])
    dec_tokens = []
    while len(dec_tokens) < max_dec_len:
        logits, hidden_state = model.decoder(dec_input, hidden_state, enc_outputs)
        # The decoder returns [1, batch, dim] but expects [batch, dim] as input.
        hidden_state = hidden_state.view(1, -1)
        # Greedy choice: highest-scoring token.
        next_token = logits.argmax(dim=-1)
        next_id = next_token.squeeze().item()
        if dvoc_inv.get(next_id) == "</s>":
            break
        dec_tokens.append(next_id)
        # Feed the prediction back in as the next decoder input.
        dec_input = next_token

# Print the input line, the prediction and the reference.
print("上联：", "".join(enc_input))
print(
    "模型预测下联：",
    "".join([dvoc_inv.get(tk, vocab.unk_token) for tk in dec_tokens]),
)
print("真实下联：", "".join(dec_output))