In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Lazily enumerate every file below the Kaggle input directory and print its full path.
input_file_paths = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_file_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/homework8-1/__results__.html
/kaggle/input/homework8-1/model_cat.pth
/kaggle/input/homework8-1/__notebook__.ipynb
/kaggle/input/homework8-1/__output__.json
/kaggle/input/homework8-1/custom.css
/kaggle/input/chinese-couplets/couplet/vocabs
/kaggle/input/chinese-couplets/couplet/test/out.txt
/kaggle/input/chinese-couplets/couplet/test/in.txt
/kaggle/input/chinese-couplets/couplet/test/.in.txt.swp
/kaggle/input/chinese-couplets/couplet/test/.out.txt.swp
/kaggle/input/chinese-couplets/couplet/train/out.txt
/kaggle/input/chinese-couplets/couplet/train/in.txt


#### 3. 编写 seq2seq（attention 版）的推理实现。

In [4]:
# 定义模型
import torch.nn as nn
import torch

class Encoder(nn.Module):
    """Single-layer bidirectional GRU encoder over embedded token indices.

    Args:
        input_size: Number of rows in the embedding table.
        emb_size: Dimension of each token embedding.
        hidden_size: Hidden size of each GRU direction.
        dropout: Accepted for signature compatibility; this module does not use it.
    """

    def __init__(self, input_size, emb_size, hidden_size, dropout):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=emb_size)
        self.rnn = nn.GRU(
            input_size=emb_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, enc_idxs):
        """Encode a batch of index sequences.

        Args:
            enc_idxs: Token indices, [batch_size, seq_len].

        Returns:
            A tuple ``(outputs, final_state)`` where ``outputs`` is
            [batch_size, seq_len, hidden_size * 2] and ``final_state`` is the
            forward and backward final hidden states concatenated,
            [batch_size, hidden_size * 2].
        """
        # last_hidden: [num_layers * 2, batch_size, hidden_size]
        rnn_outputs, last_hidden = self.rnn(self.embedding(enc_idxs))
        forward_final = last_hidden[0]
        backward_final = last_hidden[1]
        final_state = torch.cat((forward_final, backward_final), dim=1)
        return rnn_outputs, final_state

class Attention(nn.Module):
    """Parameter-free dot-product attention of decoder states over encoder states.

    No padding mask is applied: every encoder position takes part in the softmax.
    """

    def __init__(self):
        super().__init__()

    def forward(self, enc_outputs, dec_outputs):
        """Compute one context vector per decoder time step.

        Args:
            enc_outputs: [batch_size, enc_seq_len, hidden_size * 2]
            dec_outputs: [batch_size, dec_seq_len, hidden_size * 2]

        Returns:
            Context vectors, [batch_size, dec_seq_len, hidden_size * 2].
        """
        # scores[b, i, j] is the dot product of encoder step i with decoder step j.
        scores = torch.bmm(enc_outputs, dec_outputs.transpose(1, 2))
        # Normalise over the encoder axis so each decoder step's weights sum to 1.
        weights = scores.softmax(dim=1)
        # Weighted sum of encoder outputs for every decoder step.
        return torch.bmm(weights.transpose(1, 2), enc_outputs)
    
class Decoder(nn.Module):
    """GRU decoder that mixes its outputs with attention context before projecting to logits.

    Args:
        input_size: Number of rows in the embedding table and width of the output logits.
        emb_size: Dimension of each token embedding.
        hidden_size: Per-direction encoder hidden size; the decoder GRU runs at
            ``hidden_size * 2``.
        dropout: Accepted for signature compatibility; this module does not use it.
    """

    def __init__(self, input_size, emb_size, hidden_size, dropout):
        super().__init__()
        # Width of the decoder state: twice the per-direction hidden size.
        state_size = hidden_size * 2
        self.embedding = nn.Embedding(input_size, emb_size)
        self.rnn = nn.GRU(emb_size, state_size, batch_first=True)
        self.attention = Attention()
        # Projects [context ; decoder output] back down to the state width.
        self.attention_fc = nn.Linear(state_size * 2, state_size)
        self.act = nn.Tanh()
        self.fc = nn.Linear(state_size, input_size)

    def forward(self, dec_idxs, h_0, enc_outputs):
        """Run the decoder over ``dec_idxs`` starting from hidden state ``h_0``.

        Args:
            dec_idxs: Decoder input indices, [batch_size, seq_len].
            h_0: Initial hidden state, [batch_size, hidden_size * 2].
            enc_outputs: Encoder outputs, [batch_size, enc_seq_len, hidden_size * 2].

        Returns:
            A tuple ``(logits, h_n)`` where ``logits`` is
            [batch_size, seq_len, input_size] and ``h_n`` is the GRU's final
            hidden state, [num_layers, batch_size, hidden_size * 2], which
            step-by-step inference feeds back in as the next ``h_0``.
        """
        initial_hidden = h_0.unsqueeze(0)  # add the num_layers axis expected by nn.GRU
        dec_outputs, h_n = self.rnn(self.embedding(dec_idxs), initial_hidden)
        context = self.attention(enc_outputs, dec_outputs)  # [batch_size, seq_len, hidden_size * 2]
        combined = torch.cat((context, dec_outputs), dim=2)  # [batch_size, seq_len, hidden_size * 4]
        mixed = self.act(self.attention_fc(combined))  # [batch_size, seq_len, hidden_size * 2]
        logits = self.fc(mixed)  # [batch_size, seq_len, input_size]
        return logits, h_n
    
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper pairing ``Encoder`` with the attention ``Decoder``.

    Args:
        enc_input_size: Vocabulary size passed to the encoder.
        dec_input_size: Vocabulary size passed to the decoder.
        emb_size: Embedding dimension shared by both sides.
        hidden_size: Per-direction encoder hidden size, also passed to the decoder.
        dropout: Forwarded to both sub-modules.
    """

    def __init__(self, enc_input_size, dec_input_size, emb_size, hidden_size, dropout=0.3):
        super().__init__()
        shared_config = dict(emb_size=emb_size, hidden_size=hidden_size, dropout=dropout)
        self.encoder = Encoder(enc_input_size, **shared_config)
        self.decoder = Decoder(dec_input_size, **shared_config)

    def forward(self, enc_idxs, dec_idxs):
        """Encode ``enc_idxs`` and decode ``dec_idxs`` against the result.

        Returns:
            The decoder's ``(logits, h_n)`` tuple.
        """
        enc_outputs, enc_final_state = self.encoder(enc_idxs)
        return self.decoder(dec_idxs, enc_final_state, enc_outputs)

In [5]:
import pickle
import torch

def infer(model, enc_vocab, dec_vocab, dec_vocab_reverse, input_sentence, max_len=None):
    """Greedily decode an output sentence for ``input_sentence``, one token per step.

    The sentence is split into characters, encoded once, and the decoder is then
    called repeatedly with the previous prediction and hidden state until it
    predicts ``'</s>'`` or the length limit is reached.

    Args:
        model: Object exposing ``eval()``, ``encoder(idxs)`` and
            ``decoder(dec_idxs, h_0, enc_outputs)``. It is switched to eval mode.
        enc_vocab: Mapping from input character to index; must contain ``'UNK'``.
        dec_vocab: Mapping from output token to index; must contain ``'<s>'``
            and ``'</s>'``.
        dec_vocab_reverse: Mapping from output index back to token.
        input_sentence: The sentence to respond to.
        max_len: Maximum number of tokens to generate. Defaults to the number of
            characters in ``input_sentence``.

    Returns:
        The generated tokens joined into a string. The end-of-sequence token is
        never included. Returns ``""`` for an empty input or a non-positive limit.

    Raises:
        KeyError: If a required special token is missing from a vocabulary.
    """
    chars = list(input_sentence)
    limit = len(chars) if max_len is None else max_len
    if not chars or limit <= 0:
        # An empty index tensor cannot be embedded, and there is nothing to generate.
        return ""

    # Convert the input sentence into an index sequence, mapping unknown characters to UNK.
    unk_idx = enc_vocab['UNK']
    input_idxs = [enc_vocab.get(ch, unk_idx) for ch in chars]
    bos_idx = dec_vocab['<s>']
    eos_idx = dec_vocab['</s>']

    model.eval()
    # Build input tensors on the model's own device so inference also works when
    # the model has been moved off the CPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')

    generated = []
    with torch.no_grad():
        # Encode the whole input once; batch size is 1.
        enc_in = torch.tensor([input_idxs], dtype=torch.long, device=device)
        enc_outputs, hidden_state = model.encoder(enc_in)
        # The decoder starts from the begin-of-sequence token: batch 1, sequence length 1.
        dec_inputs = torch.tensor([[bos_idx]], dtype=torch.long, device=device)
        while len(generated) < limit:
            # logits: [1, 1, vocab_size] for a single decoding step.
            logits, h_n = model.decoder(dec_inputs, hidden_state, enc_outputs)
            next_idx = logits[0, -1].argmax().item()
            # Stop BEFORE recording the token so '</s>' never reaches the output.
            if next_idx == eos_idx:
                break
            generated.append(next_idx)
            # Feed the prediction and the final hidden state into the next step.
            dec_inputs = torch.tensor([[next_idx]], dtype=torch.long, device=device)
            hidden_state = h_n.squeeze(0)  # drop the num_layers axis: [1, hidden]
    # Fall back to 'UNK' for an index missing from the reverse map instead of
    # letting str.join fail on None.
    return "".join(dec_vocab_reverse.get(idx, 'UNK') for idx in generated)

In [8]:
# Load the vocabulary: one token per line, preceded by the reserved PAD and UNK entries.
# The file contains Chinese text, so decode it explicitly as UTF-8 instead of
# relying on the platform's default encoding.
with open('/kaggle/input/chinese-couplets/couplet/vocabs', encoding='utf-8') as f:
    word_list = ['PAD', 'UNK'] + [word.strip() for word in f]
vocab = {word:i for i, word in enumerate(word_list)}
vocab_reverse = {v: k for k, v in vocab.items()}

# Rebuild the model and restore the trained weights.
# NOTE(review): emb_size / hidden_size presumably match the values used when the
# checkpoint was trained; load_state_dict fails on a shape mismatch.
emb_size = 120
hidden_size = 512
model = Seq2Seq(len(vocab), len(vocab), emb_size, hidden_size)
# map_location='cpu' lets a checkpoint saved from a GPU session load in a CPU-only
# session; the model itself is constructed on the CPU.
model.load_state_dict(
    torch.load("/kaggle/input/homework8-1/model_cat.pth", map_location='cpu', weights_only=True)
)

In [9]:
# Run inference on a sample first line and show the input alongside the generated reply.
input_sentence = "彩屏如画，望秀美崤函，花团锦簇"
output_sentence = infer(
    model=model,
    enc_vocab=vocab,
    dec_vocab=vocab,
    dec_vocab_reverse=vocab_reverse,
    input_sentence=input_sentence,
)
print(input_sentence, output_sentence, sep="\n")

彩屏如画，望秀美崤函，花团锦簇
雅韵生辉，赏心香诗句，锦簇花团
