In [14]:
import json
import torch
from typing import Dict

# ====================== 1. 加载词表 ======================
def load_vocab(file_path: str) -> Dict[str, int]:
	"""加载词表映射文件"""
	with open(file_path, 'r', encoding='utf-8') as f:
		return json.load(f)

# 加载词表映射文件
token2idx = load_vocab('token2idx.json')
idx2token = {int(k): v for k, v in load_vocab('idx2token.json').items()}
vocab_size = len(token2idx)

# ====================== 2. 加载模型 ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 模型结构定义（必须与训练时一致）
class PoemRNN(torch.nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, 256)
        self.lstm = torch.nn.LSTM(256, 256, num_layers=2, batch_first=True)
        self.fc = torch.nn.Linear(256, vocab_size)
    
    def forward(self, x, hidden=None):
        x = self.embed(x)
        output, hidden = self.lstm(x, hidden)
        return self.fc(output), hidden

# 实例化并加载模型权重
model = PoemRNN(vocab_size).to(device)
model.load_state_dict(torch.load('poem_rnn.pth', map_location=device))
model.eval()

print("数据预加载完成")
print(f"- 词表大小: {vocab_size}")
print(f"- 设备: {device}")
print(f"- 模型架构: {model.__class__.__name__}")

数据预加载完成
- 词表大小: 105569
- 设备: cpu
- 模型架构: PoemRNN


### 使用温度和TOP_P

In [15]:
import thulac
import torch
import random
import os
import json
from typing import List, Dict

class PoemGenerator:
    def __init__(self, 
                 token2idx: Dict[str, int], 
                 idx2token: Dict[int, str],
                 model: torch.nn.Module,
                 device: str = "cpu",
                 temperature: float = 1.0,
                 top_p: float = 0.9):
        
        self.token2idx = token2idx
        self.idx2token = idx2token
        self.model = model.to(device)
        self.device = device
        self.thulac = thulac.thulac(seg_only=True)
        
        associations_path: str = "../word_associations.json"  # 指定关联词库路径
        self.associations = self._load_associations(associations_path)
        
        # 温度和top_p参数
        self.temperature = temperature
        self.top_p = top_p

    def _load_associations(self, path: str) -> Dict:
        """从 JSON 文件加载关联词库"""
        if not os.path.exists(path):
            print(f"警告: 未找到关联词库文件 {path}，将使用空字典。")
            return {}
        
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def segment_with_thulac(self, text: str) -> List[str]:
        """THULAC分词"""
        result = self.thulac.cut(text)
        return [word for word, _ in result]

    def get_associated_words(self, word: str, top_k: int = 5) -> List[str]:
        """获取输入词的关联词列表"""
        if word not in self.associations:
            return []
        # 按权重排序后返回前 top_k 个词
        related_words = sorted(
            self.associations[word].items(),
            key=lambda x: x[1],  # 按权重排序
            reverse=True
        )
        return [w for w, _ in related_words[:top_k]]

    def generate_poem(self, prompt: str, max_len: int = 32) -> str:
        """
        生成流程：
        1. THULAC分词 -> 2. 本地联想 -> 3. 模型生成
        """
        # 1. 分词
        segmented_words = self.segment_with_thulac(prompt)
        print(f"分词结果: {segmented_words}")
        
        # 2. 语义联想
        all_candidates = []
        for word in segmented_words:
            all_candidates.append(word)
            all_candidates.extend(self.get_associated_words(word))
            
        print(f"联想候选词: {all_candidates}")

        # 过滤无效词
        valid_words = list(set(w for w in all_candidates if w in self.token2idx))
        print(f"有效候选词: {valid_words}")
        
        # 3. 模型生成
        if not valid_words:
            valid_words = list(self.token2idx.keys())[:10]  # 回退机制
            
        input_ids = [self.token2idx[valid_words[0]]]
        generated = [valid_words[0]]
        hidden = None
        
        for _ in range(max_len - 1):
            with torch.no_grad():
                inputs = torch.tensor([input_ids[-1:]]).to(self.device)
                output, hidden = self.model(inputs, hidden)
            
            # 动态调整k值
            k = min(50, len(self.token2idx))
            top_probs, top_indices = torch.topk(output[0, -1], k=k)
            
            # 使用温度控制概率分布
            probs = torch.nn.functional.softmax(top_probs / self.temperature, dim=0)
            
            # 使用top_p采样
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=0)
            sorted_indices_to_keep = sorted_indices[cumulative_probs <= self.top_p]
            
            # 从top_p内的词中选择
            valid_indices = sorted_indices_to_keep.tolist()
            
            # 优先选择候选词
            valid_indices.extend(
                [idx.item() for idx in top_indices if self.idx2token[idx.item()] in valid_words]
            )
            
            # 回退到普通生成
            if not valid_indices:
                valid_indices = top_indices[:min(10, k)].tolist()
            
            next_idx = random.choice(valid_indices)
            next_word = self.idx2token[next_idx]
            generated.append(next_word)
            input_ids.append(next_idx)
            
            # 随机注入候选词
            if random.random() > 0.5 and valid_words:
                injected_word = random.choice(valid_words)
                generated.append(injected_word)
                input_ids.append(self.token2idx[injected_word])
        
        # 格式化为四行
        poem = "".join(generated).replace("<UNK>", "").replace("<EOS>", "")
        
        # 清理特殊标记
        poem = poem.replace("<START>", "").replace("<PAD>", "").strip()
        
        lines = [poem[i:i+5] for i in range(0, min(len(poem), 20), 5)]
        return "\n".join(lines[:4])

# ====================== 使用示例 ======================
if __name__ == "__main__":
    # 初始化生成器
    generator = PoemGenerator(
        token2idx=token2idx,
        idx2token=idx2token,
        model=model,
        device="cuda" if torch.cuda.is_available() else "cpu",
        temperature=0.7,
        top_p=0.9
    )
    
    # 生成古诗
    poem = generator.generate_poem("春眠不觉晓", max_len=40)
    print("生成结果：")
    print(poem)


Model loaded succeed
分词结果: ['春眠', '不', '觉晓']
联想候选词: ['春眠', '不觉', '失晓殊', '可喜', '红日', '晓', '不', '与', '在', '我', '为', '去', '觉晓', '一声', '梦', '鶑', '啼帘', '筛半枕']
有效候选词: ['春眠', '与', '不', '为', '一声', '梦', '晓', '鶑', '去', '不觉', '我', '在', '红日', '可喜']
生成结果：
春眠香岩春
眠东归日为
东归日南宗
我手笔我分
