# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils.rnn import pad_sequence
import json
from tqdm import tqdm

# Path configuration.
# Root directory of the couplet dataset; the train/ and test/ sub-directories
# each hold an in.txt (source side) and an out.txt (target side).
INPUT_DIR = '/kaggle/input/chinese-couplets/couplet/'
TRAIN_IN_PATH = os.path.join(INPUT_DIR, 'train/in.txt')
TRAIN_OUT_PATH = os.path.join(INPUT_DIR, 'train/out.txt')
TEST_IN_PATH = os.path.join(INPUT_DIR, 'test/in.txt')
TEST_OUT_PATH = os.path.join(INPUT_DIR, 'test/out.txt')
# Directory for everything this script writes.
OUTPUT_DIR = '/kaggle/working/'
# File that receives the model state dict.
MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, 'seq2seq_concat.pth')
# Log directory handed to the TensorBoard SummaryWriter.
LOG_DIR = os.path.join(OUTPUT_DIR, 'runs/concat_attention')

# Create the output directory up front; exist_ok makes re-runs harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

class Vocabulary:
    """Token <-> id mapping with four reserved special tokens.

    Attributes:
        vocab: dict mapping token string to integer id.
        inv_vocab: dict mapping integer id back to token string.
        pad_token_id / unk_token_id / bos_token_id / eos_token_id: ids of the
            special tokens, looked up in ``vocab`` at construction time.
    """

    pad_token = '<PAD>'
    unk_token = '<UNK>'
    bos_token = '<BOS>'
    eos_token = '<EOS>'

    def __init__(self, vocab_dict):
        """Wrap an existing token -> id dict.

        Raises:
            KeyError: if any of the four special tokens is missing.
        """
        self.vocab = vocab_dict
        self.inv_vocab = {v: k for k, v in vocab_dict.items()}
        self.pad_token_id = self.vocab[self.pad_token]
        self.unk_token_id = self.vocab[self.unk_token]
        self.bos_token_id = self.vocab[self.bos_token]
        self.eos_token_id = self.vocab[self.eos_token]

    def __getitem__(self, token):
        """Return the id of ``token``, or the ``<UNK>`` id if it is unknown."""
        return self.vocab.get(token, self.vocab[self.unk_token])

    @classmethod
    def from_files(cls, file_paths):
        """Build a vocabulary from whitespace-tokenised UTF-8 text files.

        Every whitespace-separated token on every line of every file is
        added. Ids are assigned by sorted token order, so the result is
        deterministic for a given set of files.

        Args:
            file_paths: iterable of paths to read.

        Returns:
            A new ``Vocabulary`` containing the special tokens plus all
            tokens found in the files.
        """
        # A set keeps memory proportional to the number of distinct tokens
        # rather than to the total corpus size.
        tokens = {cls.pad_token, cls.unk_token, cls.bos_token, cls.eos_token}
        for file_path in file_paths:
            with open(file_path, 'r', encoding='utf-8') as f:
                # The file must actually be iterated line by line; str.split()
                # with no argument also discards the trailing newline.
                for line in f:
                    tokens.update(line.split())
        return cls({tk: i for i, tk in enumerate(sorted(tokens))})

class CoupletDataset(Dataset):
    """Paired source/target lines, converted to id tensors on access.

    Both files are read fully into memory. Line ``i`` of ``in_path`` is the
    source for line ``i`` of ``out_path``. Tokens are whitespace-separated
    on both sides.
    """

    def __init__(self, in_path, out_path, vocab):
        """Load both files.

        Args:
            in_path: path of the source-side text file (UTF-8).
            out_path: path of the target-side text file (UTF-8).
            vocab: object supporting ``vocab[token] -> id``.

        Raises:
            ValueError: if the two files have a different number of lines.
        """
        self.enc_data = self._read_lines(in_path)
        self.dec_data = self._read_lines(out_path)
        self.vocab = vocab
        if len(self.enc_data) != len(self.dec_data):
            raise ValueError(
                f'line count mismatch: {len(self.enc_data)} source lines '
                f'vs {len(self.dec_data)} target lines'
            )

    @staticmethod
    def _read_lines(path):
        """Return the stripped lines of ``path``, closing the file afterwards."""
        with open(path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.enc_data)

    def __getitem__(self, idx):
        """Return ``(enc_ids, dec_ids)`` as 1-D long tensors.

        The target side is wrapped in ``<BOS>`` ... ``<EOS>``.
        """
        # Split on whitespace exactly like the target side; iterating the raw
        # string would turn every separating space into its own token.
        enc = self.enc_data[idx].split()
        dec = ['<BOS>'] + self.dec_data[idx].split() + ['<EOS>']
        enc_ids = [self.vocab[tk] for tk in enc]
        dec_ids = [self.vocab[tk] for tk in dec]
        # An explicit dtype keeps an empty line from producing a float tensor.
        return (
            torch.tensor(enc_ids, dtype=torch.long),
            torch.tensor(dec_ids, dtype=torch.long),
        )

class ConcatAttention(nn.Module):
    """Additive ("concat") attention over a sequence of encoder states.

    Each encoder state is scored as ``V(tanh(W([hidden; encoder_state])))``
    and the scores are normalised with a softmax along the source axis.
    """

    def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
        """Create the two linear layers.

        Args:
            encoder_hidden_dim: feature size of each encoder output vector.
            decoder_hidden_dim: feature size of the decoder state vector.
        """
        super().__init__()
        self.encoder_hidden_dim = encoder_hidden_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        joint_dim = encoder_hidden_dim + decoder_hidden_dim
        self.W = nn.Linear(joint_dim, decoder_hidden_dim)
        self.V = nn.Linear(decoder_hidden_dim, 1)

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights and the resulting context vector.

        Args:
            hidden: decoder state, indexed as (batch, decoder_hidden_dim).
            encoder_outputs: indexed as (batch, src_len, encoder_hidden_dim).

        Returns:
            ``(attention, context)`` where ``attention`` is (batch, src_len)
            and ``context`` is (batch, 1, encoder_hidden_dim).
        """
        src_len = encoder_outputs.size(1)
        # Broadcast the single decoder state across every source position.
        tiled_hidden = hidden.unsqueeze(1).expand(-1, src_len, -1)
        features = torch.cat((tiled_hidden, encoder_outputs), dim=2)
        scores = self.V(self.W(features).tanh()).squeeze(2)
        attention = scores.softmax(dim=1)
        # Weighted sum of the encoder states.
        context = torch.bmm(attention.unsqueeze(1), encoder_outputs)
        return attention, context

class Encoder(nn.Module):
    """Embedding layer followed by a single-layer bidirectional GRU."""

    def __init__(self, input_dim, emb_dim, hidden_dim):
        """Build the layers.

        Args:
            input_dim: number of rows in the embedding table (vocabulary size).
            emb_dim: embedding vector size.
            hidden_dim: GRU hidden size per direction.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=emb_dim)
        gru_options = dict(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            bidirectional=True,
            batch_first=True,
        )
        self.rnn = nn.GRU(**gru_options)

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: integer tensor of token ids, batch dimension first.

        Returns:
            The ``(outputs, hidden)`` pair produced by the GRU. Because the
            GRU is bidirectional, ``outputs`` has ``2 * hidden_dim`` features
            and ``hidden`` has one entry per direction along its first axis.
        """
        token_vectors = self.embedding(x)
        return self.rnn(token_vectors)

class Decoder(nn.Module):
    """GRU decoder with additive attention over the encoder outputs.

    ``encoder_hidden_dim`` is the per-direction hidden size of a
    bidirectional encoder: every layer here is sized for encoder outputs
    with ``encoder_hidden_dim * 2`` features.
    """

    def __init__(self, output_dim, emb_dim, encoder_hidden_dim, decoder_hidden_dim):
        """Build the layers.

        Args:
            output_dim: target vocabulary size (embedding rows and logits).
            emb_dim: embedding vector size.
            encoder_hidden_dim: per-direction encoder hidden size.
            decoder_hidden_dim: hidden size of the decoder GRU.
        """
        super().__init__()
        self.encoder_hidden_dim = encoder_hidden_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(
            input_size=emb_dim + encoder_hidden_dim * 2,
            hidden_size=decoder_hidden_dim,
            batch_first=True
        )
        self.fc_out = nn.Linear(encoder_hidden_dim * 2 + decoder_hidden_dim + emb_dim, output_dim)
        self.attention = ConcatAttention(encoder_hidden_dim * 2, decoder_hidden_dim)

    def forward(self, x, encoder_outputs, hidden, teacher_forcing_ratio=1.0):
        """Decode step by step, starting from the token in column 0 of ``x``.

        Args:
            x: (batch, trg_len) target token ids. Column 0 is the start token.
            encoder_outputs: (batch, src_len, encoder_hidden_dim * 2).
            hidden: initial GRU state, (1, batch, decoder_hidden_dim).
            teacher_forcing_ratio: in training mode, the probability per step
                of feeding the ground-truth token ``x[:, t]`` as the next
                input instead of the model's own argmax. Ignored in eval
                mode, where the decoder always feeds back its own prediction.

        Returns:
            (batch, trg_len, output_dim) logits. ``outputs[:, t]`` for
            ``t >= 1`` is the prediction for target position ``t``. Column 0
            is never written and stays zero, so a loss should compare
            ``outputs[:, 1:]`` with ``x[:, 1:]``.
        """
        batch_size = x.size(0)
        trg_len = x.size(1)

        # Allocate directly on the right device instead of CPU + copy.
        outputs = torch.zeros(batch_size, trg_len, self.fc_out.out_features, device=x.device)
        step_input = x[:, 0].unsqueeze(1)

        for t in range(1, trg_len):
            embedded = self.embedding(step_input)
            _, context = self.attention(hidden[-1], encoder_outputs)
            rnn_input = torch.cat([embedded, context], dim=2)
            _, hidden = self.rnn(rnn_input, hidden)
            prediction = self.fc_out(
                torch.cat([embedded.squeeze(1), context.squeeze(1), hidden[-1]], dim=1)
            )
            outputs[:, t] = prediction

            # Without teacher forcing the target tokens beyond column 0 would
            # never be used, and training would run on the model's own
            # (initially random) predictions only.
            use_teacher = self.training and torch.rand(1).item() < teacher_forcing_ratio
            if use_teacher:
                step_input = x[:, t].unsqueeze(1)
            else:
                step_input = prediction.argmax(1).unsqueeze(1)

        return outputs

class Seq2Seq(nn.Module):
    """Bidirectional-GRU encoder plus attention decoder.

    Also holds a small "bridge" layer that turns the encoder's two final
    direction states into the decoder's single initial state.
    """

    def __init__(self, enc_vocab_size, dec_vocab_size, emb_dim, encoder_hidden_dim, decoder_hidden_dim):
        """Build encoder, decoder and bridge.

        Args:
            enc_vocab_size: source vocabulary size.
            dec_vocab_size: target vocabulary size.
            emb_dim: embedding size used on both sides.
            encoder_hidden_dim: per-direction hidden size of the encoder GRU.
            decoder_hidden_dim: hidden size of the decoder GRU.
        """
        super().__init__()
        self.encoder = Encoder(enc_vocab_size, emb_dim, encoder_hidden_dim)
        # Pass the per-direction size: Decoder multiplies encoder_hidden_dim
        # by 2 itself when sizing its layers, so doubling here as well would
        # make every decoder layer expect twice the real feature width.
        self.decoder = Decoder(dec_vocab_size, emb_dim, encoder_hidden_dim, decoder_hidden_dim)
        # Projects [forward; backward] final encoder states to the decoder size.
        self.bridge = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)

    def forward(self, src, trg):
        """Encode ``src`` and decode against ``trg``.

        Args:
            src: (batch, src_len) source token ids.
            trg: (batch, trg_len) target token ids.

        Returns:
            Whatever the decoder returns for ``trg``: a tensor of logits.
        """
        encoder_outputs, hidden = self.encoder(src)
        # The encoder hidden state is (2, batch, enc_hid), one entry per
        # direction. The decoder GRU has one layer and one direction, so it
        # needs (1, batch, dec_hid).
        merged = torch.cat((hidden[-2], hidden[-1]), dim=1)
        decoder_hidden = torch.tanh(self.bridge(merged)).unsqueeze(0)
        return self.decoder(trg, encoder_outputs, decoder_hidden)

def custom_collate(batch, vocab):
    """Pad a list of ``(enc_ids, dec_ids)`` pairs into two batch tensors.

    Both sides are right-padded with ``vocab.pad_token_id``. Padding the
    target side with PAD rather than EOS matters when the loss is built with
    ``ignore_index=pad_token_id``: only then are the padded positions
    excluded from the loss.

    Args:
        batch: non-empty sequence of ``(enc_ids, dec_ids)`` 1-D tensors.
        vocab: object exposing ``pad_token_id``.

    Returns:
        ``(padded_enc, padded_dec)``, shaped (batch, max_enc_len) and
        (batch, max_dec_len).
    """
    enc_seqs, dec_seqs = zip(*batch)
    pad_id = vocab.pad_token_id
    # pad_sequence does the length bookkeeping and keeps the input dtype.
    padded_enc = pad_sequence(list(enc_seqs), batch_first=True, padding_value=pad_id)
    padded_dec = pad_sequence(list(dec_seqs), batch_first=True, padding_value=pad_id)
    return padded_enc, padded_dec

def train_model():
    """Train the seq2seq model and keep the checkpoint with the best validation loss.

    Builds the vocabulary from the train and test files, trains for 10
    epochs with AdamW, logs train/validation loss to TensorBoard under
    ``LOG_DIR`` and writes the best state dict to ``MODEL_SAVE_PATH``.

    Returns:
        ``(model, vocab)``: the model as it stands after the last epoch and
        the vocabulary it was trained with.
    """
    vocab = Vocabulary.from_files([TRAIN_IN_PATH, TRAIN_OUT_PATH, TEST_IN_PATH, TEST_OUT_PATH])

    def collate(batch):
        return custom_collate(batch, vocab)

    train_dataset = CoupletDataset(TRAIN_IN_PATH, TRAIN_OUT_PATH, vocab)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate)
    # The test split is used as the validation set.
    test_dataset = CoupletDataset(TEST_IN_PATH, TEST_OUT_PATH, vocab)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Seq2Seq(
        enc_vocab_size=len(vocab.vocab),
        dec_vocab_size=len(vocab.vocab),
        emb_dim=256,
        encoder_hidden_dim=256,
        decoder_hidden_dim=512
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    writer = SummaryWriter(LOG_DIR)

    best_loss = float('inf')
    try:
        for epoch in range(10):
            model.train()
            total_loss = 0.0
            for src, trg in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
                src, trg = src.to(device), trg.to(device)
                # The decoder writes the prediction for target position t
                # into output[:, t] and leaves column 0 untouched, so feed
                # the full target and score columns 1.. against trg[:, 1:].
                logits = model(src, trg)
                # Slicing makes the tensors non-contiguous; reshape copies
                # when needed, whereas view would raise.
                flat_logits = logits[:, 1:].reshape(-1, logits.size(-1))
                flat_target = trg[:, 1:].reshape(-1)
                loss = criterion(flat_logits, flat_target)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                total_loss += loss.item()

            val_loss = evaluate_model(model, test_loader, vocab)
            writer.add_scalars('Loss', {'train': total_loss/len(train_loader), 'val': val_loss}, epoch)
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(model.state_dict(), MODEL_SAVE_PATH)
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()

    return model, vocab

def evaluate_model(model, data_loader, vocab):
    """Return the mean per-batch cross-entropy of ``model`` over ``data_loader``.

    The model is switched to eval mode and left there. Padding positions
    (``vocab.pad_token_id``) are excluded from the loss.

    Args:
        model: callable as ``model(src, trg)`` returning logits shaped
            (batch, trg_len, vocab_size), where column ``t`` holds the
            prediction for target position ``t``.
        data_loader: iterable of ``(src, trg)`` id-tensor batches.
        vocab: object exposing ``pad_token_id``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    # Run on whatever device the model lives on rather than on a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')
    # Built locally: the training criterion is not visible from here.
    loss_fn = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)

    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for src, trg in data_loader:
            src, trg = src.to(run_device), trg.to(run_device)
            logits = model(src, trg)
            # Skip column 0 (the start token, never predicted). reshape is
            # needed because the sliced tensors are not contiguous.
            flat_logits = logits[:, 1:].reshape(-1, logits.size(-1))
            flat_target = trg[:, 1:].reshape(-1)
            total_loss += loss_fn(flat_logits, flat_target).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('data_loader yielded no batches')
    return total_loss / num_batches

def generate_couplet(input_text, model, vocab, max_len=50):
    """Greedily decode a response line for ``input_text``.

    Args:
        input_text: source line. If it contains whitespace it is split on
            whitespace, otherwise every character is one token. The training
            files are whitespace-separated, while a raw couplet line such as
            "天增岁月人增寿" has no separators.
        model: callable as ``model(src, trg)`` returning logits shaped
            (1, trg_len, vocab_size), where column ``t >= 1`` is the
            prediction for position ``t``.
        vocab: object supporting ``vocab[token]`` and exposing
            ``bos_token_id``, ``eos_token_id`` and ``inv_vocab``.
        max_len: maximum number of tokens to generate.

    Returns:
        The generated tokens joined by single spaces (empty string for empty
        input). Generation stops at the first ``<EOS>``.
    """
    model.eval()
    text = input_text.strip()
    if any(ch.isspace() for ch in text):
        enc_tokens = text.split()
    else:
        enc_tokens = list(text)
    if not enc_tokens:
        return ''

    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        enc_ids = [vocab[tk] for tk in enc_tokens]
        src_tensor = torch.tensor([enc_ids], dtype=torch.long, device=run_device)
        # In eval mode the decoder feeds its own argmax back in at every
        # step and only reads column 0 of the target. A BOS-filled target
        # of length max_len + 1 therefore yields max_len greedy steps, and
        # going through model(...) keeps the encoder-to-decoder hand-off
        # identical to training.
        trg_tensor = torch.full(
            (1, max_len + 1), vocab.bos_token_id, dtype=torch.long, device=run_device
        )
        logits = model(src_tensor, trg_tensor)
        predicted_ids = logits[0, 1:].argmax(dim=-1).tolist()

    decoded_tokens = []
    for token_id in predicted_ids:
        if token_id == vocab.eos_token_id:
            break
        decoded_tokens.append(vocab.inv_vocab[token_id])
    return ' '.join(decoded_tokens)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Train the model; the best checkpoint is written to MODEL_SAVE_PATH.
train_model()

# Inference needs a model and a vocabulary at module scope. This script does
# not rely on train_model()'s return value: it rebuilds the vocabulary from
# the same files and reloads the best checkpoint from disk. The
# hyper-parameters must match the ones used inside train_model().
vocab = Vocabulary.from_files([TRAIN_IN_PATH, TRAIN_OUT_PATH, TEST_IN_PATH, TEST_OUT_PATH])
model = Seq2Seq(
    enc_vocab_size=len(vocab.vocab),
    dec_vocab_size=len(vocab.vocab),
    emb_dim=256,
    encoder_hidden_dim=256,
    decoder_hidden_dim=512
).to(device)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))

# Example inference
test_input = "天增岁月人增寿"
print(f"Input: {test_input}")
prediction = generate_couplet(test_input, model, vocab)
print(f"Prediction: {prediction}")

# Save the vocabulary (id -> token; JSON turns the integer ids into string keys).
# An explicit UTF-8 encoding is required because ensure_ascii=False writes the
# Chinese tokens verbatim.
with open(os.path.join(OUTPUT_DIR, 'vocab.json'), 'w', encoding='utf-8') as f:
    json.dump(vocab.inv_vocab, f, ensure_ascii=False)