In [11]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import os
import numpy as np
from typing import List, Dict

# ====================== 1. 加载词表 ======================
def load_vocab(file_path: str) -> Dict[str, int]:
    """加载词表映射文件"""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# ====================== 2. 定义 Transformer 模型 ======================
class PositionalEncoding(nn.Module):
    """位置编码，为模型提供序列位置信息"""
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x: [batch_size, seq_len, embed_dim]
        return x + self.pe[:x.size(1), :]

class PoemTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, ff_dim=512, dropout=0.1):
        super(PoemTransformer, self).__init__()
        
        # 词向量层
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)  # 临时设为0，主程序中修正
        
        # 位置编码
        self.positional_encoding = nn.Parameter(torch.rand(1, 5000, embed_size))
        
        # 完整Transformer结构
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True
        )
        
        # 输出层
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def generate_square_subsequent_mask(self, sz):
        """生成自回归掩码"""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
        
    def forward(self, src, tgt):
        # 嵌入和位置编码
        src = self.dropout(self.embed(src) + self.positional_encoding[:, :src.size(1), :])
        tgt = self.dropout(self.embed(tgt) + self.positional_encoding[:, :tgt.size(1), :])
        
        # 生成掩码
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        src_mask = self.generate_square_subsequent_mask(src.size(1)).to(src.device)
        
        # 通过Transformer
        output = self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        
        return self.fc_out(output), None

# ====================== 3. 古诗生成器（含叠词优化） ======================
class PoemGenerator:
    def __init__(self,
                 token2idx: Dict[str, int],
                 idx2token: Dict[int, str],
                 model: torch.nn.Module,
                 device: str = "cpu"):

        self.token2idx = token2idx
        self.idx2token = idx2token
        self.model = model.to(device)
        self.device = device
        self.associations = self._load_associations("word_associations.json")
        
        # 初始化分词器
        try:
            import thulac
            self.thulac = thulac.thulac(seg_only=True)
        except ImportError:
            print("警告: 未安装THULAC分词器，将使用简单分词方法")
            self.thulac = None

    def _load_associations(self, path: str) -> Dict:
        """加载关联词库"""
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        print(f"警告: 未找到关联词库文件 {path}，将使用空字典。")
        return {}

    def segment_with_thulac(self, text: str) -> List[str]:
        """分词方法"""
        if self.thulac:
            result = self.thulac.cut(text)
            return [word for word, _ in result]
        return list(text)  # 回退到单字分词

    def get_associated_words(self, word: str, top_k: int = 5) -> List[str]:
        """获取输入词的关联词列表"""
        if word not in self.associations:
            return []
        related_words = sorted(
            self.associations[word].items(),
            key=lambda x: x[1],
            reverse=True
        )
        return [w for w, _ in related_words[:top_k]]

    def generate_poem(self, prompt: str, max_len: int = 32, temperature=0.6, max_repeat=2) -> str:
        """
        生成古诗，含叠词抑制
        max_repeat: 允许的最大连续相同字符数（默认2）
        """
        # 1. 分词与联想
        segmented_words = self.segment_with_thulac(prompt)
        print(f"分词结果: {segmented_words}")
        
        # 收集候选词
        all_candidates = []
        for word in segmented_words:
            all_candidates.append(word)
            all_candidates.extend(self.get_associated_words(word, top_k=5))
        
        # 过滤有效词，补充主题相关常用字
        valid_words = list(set(w for w in all_candidates if w in self.token2idx))
        if not valid_words:
            valid_words = [w for w in list(self.token2idx.keys())[:10] if w not in ['<PAD>', '<START>']]
        
        # 根据主题补充候选词（减少无关字符）
        theme_related = {
            '春天': ['春', '风', '花', '柳', '燕', '啼', '暖', '芽'],
            '秋天': ['秋', '霜', '叶', '月', '雁', '寒', '枫', '露'],
            '夏天': ['夏', '荷', '蝉', '雨', '荫', '热', '蛙', '莲'],
            '冬天': ['冬', '雪', '寒', '梅', '冰', '霜', '风', '松']
        }.get(prompt, ['日', '云', '山', '水', '天', '地', '人', '心'])
        
        valid_words.extend([w for w in theme_related if w in self.token2idx and w not in valid_words])
        valid_words = list(set(valid_words))  # 去重
        print(f"有效候选词: {valid_words}")

        # 2. 初始化输入
        input_ids = [self.token2idx[valid_words[0]]] if valid_words else [self.token2idx['<START>']]
        generated = [self.idx2token[idx] for idx in input_ids]

        # 3. 生成诗句（核心逻辑）
        self.model.eval()
        with torch.no_grad():
            src = torch.tensor([input_ids]).to(self.device)
            tgt = src.clone()
            
            for _ in range(max_len - len(input_ids)):
                # 模型预测
                output, _ = self.model(src, tgt)
                next_probs = output[:, -1, :]  # 最后一个位置的预测
                
                # 处理重复字符：惩罚连续重复的字符
                last_char = generated[-1] if generated else ''
                repeat_count = 1
                # 统计当前连续重复次数
                for c in reversed(generated[:-1]):
                    if c == last_char:
                        repeat_count += 1
                    else:
                        break
                
                # 如果超过最大允许重复次数，降低重复字符的概率
                if repeat_count >= max_repeat and last_char in self.token2idx:
                    repeat_idx = self.token2idx[last_char]
                    next_probs[0, repeat_idx] *= 0.1  # 惩罚重复字符
                    next_probs = F.softmax(next_probs, dim=-1)  # 重新归一化
                
                # 温度采样
                probs = F.softmax(next_probs / temperature, dim=-1)
                next_idx = torch.multinomial(probs, 1).item()
                next_word = self.idx2token.get(next_idx, '')
                
                # 过滤无效字符
                invalid_tokens = ['<END>', '<UNK>', '<PAD>', '<START>']
                if next_word in invalid_tokens:
                    continue
                
                # 二次检查：避免新增重复
                if generated and next_word == generated[-1] and repeat_count >= max_repeat:
                    # 从候选词中选一个不同的字符
                    alternatives = [w for w in valid_words if w != next_word and w not in invalid_tokens]
                    if alternatives:
                        next_word = random.choice(alternatives)
                        next_idx = self.token2idx[next_word]
                
                # 添加到结果
                generated.append(next_word)
                input_ids.append(next_idx)
                tgt = torch.tensor([input_ids]).to(self.device)
                
                # 满足长度则停止
                if len(generated) >= 20:  # 4句x5字=20
                    break

        # 4. 格式化输出
        poem = "".join(generated)
        # 清理无效标记
        for token in invalid_tokens:
            poem = poem.replace(token, "")
        
        # 按5字一句分割
        lines = []
        for i in range(0, min(len(poem), 20), 5):
            line = poem[i:i+5]
            if len(line) == 5:
                lines.append(line)
        
        # 不足4句则补充
        if len(lines) < 4 and valid_words:
            for _ in range(4 - len(lines)):
                filler = "".join(random.sample(valid_words, min(5, len(valid_words))))
                lines.append(filler[:5])
        
        return "\n".join(lines[:4])

# ====================== 4. 模型加载与主程序 ======================
def load_model(vocab_size, model_path=None, device='cpu'):
    """加载模型并处理权重兼容问题"""
    model = PoemTransformer(vocab_size).to(device)
    if model_path and os.path.exists(model_path):
        try:
            model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
            print(f"成功加载预训练模型: {model_path}")
        except Exception as e:
            print(f"加载模型失败: {e}")
            print("将使用随机初始化权重")
    else:
        print("警告: 未找到预训练模型，将使用随机初始化权重")
    return model

if __name__ == "__main__":
    # 设置设备
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 加载词表
    print("正在加载词表...")
    token2idx = load_vocab('token2idx.json')
    idx2token = {int(k): v for k, v in load_vocab('idx2token.json').items()}
    vocab_size = len(token2idx)
    print(f"词表加载完成，大小: {vocab_size}")
    
    # 加载模型并修正padding_idx
    print("正在加载模型...")
    model = load_model(vocab_size, 'poem_transformer.pth', device)
    model.embed.padding_idx = token2idx.get('<PAD>', 0)  # 动态修正padding_idx
    print("模型加载完成")
    
    # 初始化生成器
    print("正在初始化生成器...")
    generator = PoemGenerator(
        token2idx=token2idx,
        idx2token=idx2token,
        model=model,
        device=device
    )
    print("生成器初始化完成")
    
    # 生成古诗（可修改prompt为"春天"/"夏天"/"冬天"）
    print("\n=== 开始生成古诗 ===")
    poem = generator.generate_poem("春天", max_len=40, max_repeat=2)  # max_repeat控制连续重复
    print("生成结果：")
    print(poem)

正在加载词表...
词表加载完成，大小: 105569
正在加载模型...
成功加载预训练模型: poem_transformer.pth
模型加载完成
正在初始化生成器...
警告: 未找到关联词库文件 word_associations.json，将使用空字典。
Model loaded succeed
生成器初始化完成

=== 开始生成古诗 ===
分词结果: ['春天']
有效候选词: ['风', '春', '柳', '燕', '暖', '花', '芽', '春天', '啼']
生成结果：
风辞聘风潮
花燕暖芽啼
风花春暖燕
风春春天花


In [8]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
from typing import List, Dict, Tuple
import jieba
import re
from collections import defaultdict

class EnhancedPoemGenerator:
    def __init__(self, token2idx: Dict[str, int], idx2token: Dict[int, str], device='cpu'):
        self.token2idx = token2idx
        self.idx2token = idx2token
        self.device = device
        
        # 增强的语言资源
        self.rhyme_groups = self._init_enhanced_rhyme_groups()
        self.pingze_map = self._init_enhanced_pingze_map()
        self.thesaurus = self._init_enhanced_thesaurus()
        self.theme_vectors = self._init_theme_vectors()
        
        # 格律模板
        self.meter_templates = {
            '五言绝句': {
                'patterns': [
                    ['仄', '仄', '平', '平', '仄'],
                    ['平', '平', '仄', '仄', '平'],
                    ['平', '平', '平', '仄', '仄'],
                    ['仄', '仄', '仄', '平', '平']
                ],
                'rhyme_pos': [1, 3]  # 第二、四句押韵
            },
            '七言绝句': {
                'patterns': [
                    ['平', '平', '仄', '仄', '平', '平', '仄'],
                    ['仄', '仄', '平', '平', '仄', '仄', '平'],
                    ['仄', '仄', '平', '平', '平', '仄', '仄'],
                    ['平', '平', '仄', '仄', '仄', '平', '平']
                ],
                'rhyme_pos': [1, 3]
            }
        }

    def _init_enhanced_rhyme_groups(self) -> Dict[str, List[str]]:
        """初始化押韵分组"""
        return {
            'a': ['花', '家', '华', '霞', '涯', '沙', '茶', '麻', '纱'],
            'o': ['歌', '多', '河', '波', '罗', '梭', '柯', '戈', '磨'],
            'e': ['车', '斜', '嗟', '些', '爷', '椰', '茄', '靴', '蝎'],
            'i': ['枝', '时', '丝', '迟', '诗', '知', '痴', '池', '脂'],
            'u': ['无', '图', '湖', '孤', '壶', '途', '酥', '糊', '乌'],
            'ong': ['空', '红', '风', '中', '同', '东', '通', '工', '蓬'],
            'an': ['山', '间', '闲', '还', '颜', '关', '湾', '环', '班'],
            'ang': ['长', '香', '光', '阳', '堂', '芳', '昌', '央', '刚'],
            'ai': ['来', '开', '台', '才', '哉', '埃', '该', '孩', '灾'],
            'ei': ['飞', '归', '辉', '威', '非', '妃', '肥', '扉', '菲']
        }

    def _init_enhanced_pingze_map(self) -> Dict[str, str]:
        """初始化平仄映射"""
        pingze = defaultdict(lambda: '仄')
        
        # 第一声（阴平）
        pingze.update({
            '春':'平', '风':'平', '秋':'平', '天':'平', '空':'平',
            '山':'平', '花':'平', '飞':'平', '开':'平', '江':'平',
            '烟':'平', '芳':'平', '清':'平', '新':'平', '声':'平'
        })
        
        # 第二声（阳平）
        pingze.update({
            '年':'平', '来':'平', '时':'平', '人':'平', '明':'平',
            '长':'平', '流':'平', '晴':'平', '头':'平', '寒':'平',
            '林':'平', '门':'平', '前':'平', '行':'平', '情':'平'
        })
        
        # 特殊入声字处理
        pingze.update({
            '月':'仄', '日':'仄', '雪':'仄', '白':'仄', '竹':'仄',
            '石':'仄', '玉':'仄', '色':'仄', '绿':'仄', '落':'仄'
        })
        
        return pingze

    def _init_enhanced_thesaurus(self) -> Dict[str, List[str]]:
        """初始化同类词库"""
        return {
            '季节': ['春', '夏', '秋', '冬', '晨', '夕', '朝', '夜'],
            '天气': ['风', '雨', '雪', '霜', '露', '雾', '雷', '电'],
            '植物': ['花', '草', '树', '柳', '梅', '兰', '竹', '菊'],
            '动物': ['鸟', '燕', '莺', '雁', '鹤', '鹊', '鸥', '鸠'],
            '山水': ['山', '水', '江', '河', '湖', '海', '溪', '峰'],
            '建筑': ['楼', '台', '亭', '阁', '桥', '寺', '塔', '轩'],
            '情感': ['愁', '思', '忆', '念', '忧', '喜', '悲', '欢']
        }

    def _init_theme_vectors(self) -> Dict[str, List[float]]:
        """初始化主题词向量"""
        return {
            '春': [0.9, 0.1, 0.3], '夏': [0.8, 0.2, 0.4],
            '秋': [0.7, 0.3, 0.5], '冬': [0.6, 0.4, 0.6],
            '风': [0.1, 0.9, 0.2], '花': [0.3, 0.8, 0.1],
            '月': [0.2, 0.3, 0.8], '日': [0.8, 0.3, 0.2],
            '山': [0.7, 0.1, 0.4], '水': [0.3, 0.6, 0.1],
            '情': [0.1, 0.4, 0.9], '志': [0.4, 0.1, 0.9]
        }

    def generate_poem(self, model, theme_word: str, poem_type: str = '五言绝句', 
                     max_len: int = 32, strict: bool = True, temperature: float = 1.0) -> str:
        """
        增强的主题约束古诗生成（优化为API接口）
        :param model: 训练好的模型
        :param theme_word: 主题词
        :param poem_type: 诗体类型
        :param max_len: 最大长度
        :param strict: 是否严格模式
        :param temperature: 生成温度
        :return: 生成的诗（格式化字符串）
        """
        model.eval()
        
        # 分词处理
        try:
            words = list(jieba.cut(theme_word))
        except Exception:
            words = list(theme_word)
        
        # 扩展主题相关词汇
        theme_words = []
        for word in words:
            if word in self.token2idx:
                theme_words.extend(self._expand_theme(word))
            else:
                theme_words.append(self._get_synonym_for_unknown_word(word))

        theme_set = set(theme_words)
        
        # 初始化生成
        start_token = self.token2idx.get(theme_word, self.token2idx['<START>'])
        input_idx = torch.tensor([start_token]).unsqueeze(0).to(self.device)
        generated = [self.idx2token.get(start_token, theme_word)]
        
        # 约束生成过程
        for step in range(max_len - 1):
            with torch.no_grad():
                output, _ = model(input_idx, None)
            
            # 应用温度调节
            logits = output[0, -1] / temperature
            
            # 主题约束采样
            top_k = min(50, len(self.token2idx))
            top_indices = torch.topk(logits, top_k).indices.tolist()
            
            # 优先选择主题相关词
            valid_indices = [
                i for i in top_indices 
                if self.idx2token[i] in theme_set or 
                self._is_theme_related(self.idx2token[i], theme_words)
            ]
            
            # 回退机制
            if not valid_indices:
                valid_indices = top_indices[:10]
            
            # 采样下一个token
            probs = torch.softmax(logits[valid_indices], -1)
            next_idx = torch.multinomial(probs, 1).item()
            next_token_idx = valid_indices[next_idx]
            
            # 结束条件
            if next_token_idx == self.token2idx.get('<END>', -1):
                break
                
            generated.append(self.idx2token[next_token_idx])
            input_idx = torch.tensor([next_token_idx]).unsqueeze(0).to(self.device)
        
        # 规范化处理并返回格式化结果
        return self.normalize_poem(''.join(generated[1:]), poem_type, strict, theme_word)

    def normalize_poem(self, raw_text: str, poem_type: str, strict: bool, theme_word: str) -> str:
        """规范化生成的诗句，返回格式化文本"""
        lines = self._basic_formatting(raw_text, poem_type)
        
        if theme_word:
            lines = self._enhance_theme(lines, theme_word)
        
        if strict:
            lines = self._strict_normalization(lines, poem_type)
        
        title = self._generate_title(lines, theme_word)
        
        # 格式化输出（添加空格和换行）
        formatted_lines = []
        for line in lines:
            line = line.replace('，', '').replace('。', '')
            spaced_line = ' '.join(list(line))
            formatted_lines.append(spaced_line)
        
        return f"{title}\n\n" + "\n".join(formatted_lines)

    # 辅助方法（保持原有逻辑不变）
    def _basic_formatting(self, text: str, poem_type: str) -> List[str]:
        char_per_line = 5 if '五言' in poem_type else 7
        cleaned = re.sub(r'[^\u4e00-\u9fa5]', '', text)
        
        lines = []
        current_line = []
        
        for char in cleaned:
            if len(current_line) < char_per_line:
                current_line.append(char)
            else:
                punctuation = '，' if len(lines) % 2 == 0 else '。'
                lines.append(''.join(current_line) + punctuation)
                current_line = [char]
        
        if current_line:
            lines.append(''.join(current_line) + ('。' if len(lines) % 2 == 1 else '。'))
        
        while len(lines) < 4:
            placeholder = random.choice(['望', '观', '见', '看']) + '〇' * (char_per_line - 1)
            lines.append(placeholder + '。')
        
        return lines[:4]

    def _enhance_theme(self, lines: List[str], theme_word: str) -> List[str]:
        if not any(theme_word in line for line in lines):
            replace_pos = random.choice([0, len(lines)-1])
            line = lines[replace_pos]
            if len(line) >= 2:
                pos = random.randint(0, len(line)-2)
                lines[replace_pos] = line[:pos] + theme_word + line[pos+1:]
        
        season_words = [w for w in self.thesaurus['季节'] if any(w in line for line in lines)]
        if season_words:
            primary_season = season_words[0]
            for i, line in enumerate(lines):
                for season in self.thesaurus['季节']:
                    if season != primary_season and season in line:
                        lines[i] = line.replace(season, primary_season)
        
        return lines

    def _strict_normalization(self, lines: List[str], poem_type: str) -> List[str]:
        template = self.meter_templates.get(poem_type, {})
        patterns = template.get('patterns', [])
        rhyme_pos = template.get('rhyme_pos', [])
        
        if len(lines) >= 2 and rhyme_pos:
            rhyme_group = None
            for pos in rhyme_pos:
                if pos < len(lines):
                    line = lines[pos]
                    rhyme_char = line[-2] if line[-1] == '。' else line[-1]
                    rhyme_group = self._find_rhyme_group(rhyme_char) or rhyme_group
			
            if rhyme_group:
                for pos in rhyme_pos:
                    if pos < len(lines):
                        line = lines[pos]
                        last_char = line[-2] if line[-1] == '。' else line[-1]
                        if self._find_rhyme_group(last_char) != rhyme_group:
                            candidates = self._get_rhyme_candidates(rhyme_group)
                            if candidates:
                                new_char = random.choice(candidates)
                                lines[pos] = line[:-1] + new_char + line[-1]
		
        if len(lines) == len(patterns):
            for i in range(len(lines)):
                line = lines[i]
                pattern = patterns[i]
                
                new_line = []
                for j in range(min(len(line)-1, len(pattern))):
                    char = line[j]
                    expected = pattern[j]
                    
                    if self.pingze_map[char] != expected:
                        synonyms = self._find_synonyms(char)
                        for syn in synonyms:
                            if self.pingze_map[syn] == expected:
                                char = syn
                                break
                    
                    new_line.append(char)
                
                new_line.append(line[-1])
                lines[i] = ''.join(new_line)
		
        common_chars = set(self.pingze_map.keys()).union(set(self.thesaurus.keys()))
        for i, line in enumerate(lines):
            new_line = []
            for char in line:
                if char not in common_chars and char not in ['，', '。', '〇']:
                    synonyms = self._find_synonyms(char)
                    char = synonyms[0] if synonyms else '〇'
                new_line.append(char)
            lines[i] = ''.join(new_line)
        
        return lines

    def _generate_title(self, lines: List[str], theme_word: str) -> str:
        if theme_word:
            suffixes = ['吟', '颂', '赋', '词', '曲', '谣', '诗', '歌', '叹', '篇', '章']
            return theme_word + random.choice(suffixes)
        
        first_line = lines[0].strip('，。')
        if len(first_line) >= 2:
            suffixes = ['即景', '有感', '杂咏', '偶成', '抒怀', '寄情', '漫兴', '遣怀']
            return first_line[:2] + random.choice(suffixes)
        return first_line[0] + '吟'

    def _find_rhyme_group(self, char: str) -> str:
        for group, chars in self.rhyme_groups.items():
            if char in chars:
                return group
        return ''

    def _get_rhyme_candidates(self, group: str) -> List[str]:
        return self.rhyme_groups.get(group, [])

    def _find_synonyms(self, char: str) -> List[str]:
        for category, words in self.thesaurus.items():
            if char in words:
                return words
        return [char]

    def _expand_theme(self, word: str) -> List[str]:
        """扩展主题相关词汇"""
        expanded = [word]
        for category, words in self.thesaurus.items():
            if word in words:
                expanded.extend(words)
                break
        return expanded

    def _is_theme_related(self, word: str, theme_words: List[str]) -> bool:
        """判断词语是否与主题相关"""
        if word in theme_words:
            return True
        
        # 检查是否在同一主题类别中
        for category, words in self.thesaurus.items():
            if word in words and any(t in words for t in theme_words):
                return True
        
        return False

    def _get_synonym_for_unknown_word(self, word: str) -> str:
        """为未知词获取同义词"""
        for category, words in self.thesaurus.items():
            if word in words:
                return random.choice(words)
        return word
    



In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import re
import json
from collections import defaultdict
from typing import List, Dict, Tuple

# ====================== Transformer模型 ======================
class PoemTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=128, num_heads=4, num_layers=2, ff_dim=256, dropout=0.1):
        super().__init__()
        # 词嵌入层，padding_idx使用词表中的<PAD>索引（与前面代码保持一致）
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=2)  # 前面代码中<PAD>固定为2
        self.pos_encoder = nn.Parameter(torch.randn(1, 100, embed_size))
        
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True
        )
        
        self.fc_out = nn.Linear(embed_size, vocab_size)
        
    def generate_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return mask.float().masked_fill(mask == 0, float('-inf'))
        
    def forward(self, src, tgt):
        src = self.embed(src) + self.pos_encoder[:, :src.size(1)]
        tgt = self.embed(tgt) + self.pos_encoder[:, :tgt.size(1)]
        
        tgt_mask = self.generate_mask(tgt.size(1)).to(tgt.device)
        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc_out(output), None

# ====================== 古诗生成器 ======================
class PoemGenerator:
    def __init__(self, token2idx, idx2token, rhyme_dict=None, theme_words=None, device='cpu'):
        # 直接使用前面代码的词表映射（核心：保持与token2idx.json/idx2token.json一致）
        self.token2idx = token2idx  # 来自token2idx.json
        self.idx2token = idx2token  # 来自idx2token.json
        self.device = device
        
        # 押韵字典和主题词库（可选，不影响词表统一性）
        self.rhyme_dict = rhyme_dict or self._build_default_rhyme_dict()
        self.theme_words = theme_words or self._build_default_theme_words()
        
        # 通用词库严格从词表中提取（确保只用词表中的字）
        self.common_words = [char for char in token2idx 
                            if char not in ['<START>', '<END>', '<PAD>']]  # 排除特殊标记

    def _build_default_rhyme_dict(self) -> Dict[str, str]:
        rhyme_dict = defaultdict(str)
        categories = {
            'a': ['花', '家', '霞', '沙', '茶', '麻', '涯', '瓜', '华'],
            'o': ['歌', '多', '河', '波', '罗', '梭', '柯', '戈', '磨'],
            'i': ['衣', '期', '池', '知', '时', '丝', '诗', '棋', '词'],
            'u': ['无', '图', '湖', '孤', '壶', '途', '苏', '书', '珠'],
            'an': ['山', '间', '闲', '还', '颜', '关', '湾', '环', '班'],
            'ang': ['长', '香', '光', '阳', '堂', '芳', '昌', '央', '刚'],
            'eng': ['风', '声', '灯', '僧', '升', '生', '星', '耕', '更'],
            'ong': ['东', '同', '风', '中', '空', '红', '通', '工', '蓬']
        }
        for rhyme, chars in categories.items():
            for char in chars:
                rhyme_dict[char] = rhyme
        return rhyme_dict

    def _build_default_theme_words(self) -> Dict[str, List[str]]:
        return {
            '春': ['春', '花', '柳', '燕', '莺', '风', '雨', '暖', '芽', '绿'],
            '夏': ['夏', '荷', '蝉', '蛙', '阳', '热', '雨', '莲', '塘', '炎'],
            '秋': ['秋', '月', '菊', '雁', '霜', '叶', '凉', '风', '黄', '收'],
            '冬': ['冬', '雪', '梅', '冰', '寒', '风', '凌', '松', '白', '冷'],
            '月': ['月', '夜', '明', '光', '影', '秋', '江', '湖', '镜', '望'],
            '山': ['山', '峰', '岭', '石', '云', '雾', '高', '青', '翠', '险'],
            '水': ['水', '河', '湖', '海', '波', '浪', '流', '清', '碧', '深']
        }

    def generate_poem(self, model, theme_word: str, poem_type: str = '五言绝句', 
                     max_len: int = 32, strict: bool = True, temperature: float = 1.0) -> str:
        model.eval()
        
        # 每句字数（五言/七言）
        chars_per_line = 5 if '五言' in poem_type else 7
        lines_count = 4
        
        # 主题词严格限制在词表内
        if theme_word not in self.token2idx:
            # 如果主题词不在词表，从词表中找近似词（确保不超出词表范围）
            theme_word = random.choice(self.common_words[:5])
        
        # 主题相关词仅从词表中筛选
        theme_related = self.theme_words.get(theme_word, [theme_word])
        theme_related = [w for w in theme_related if w in self.token2idx]  # 过滤词表外的字
        theme_related.extend(self.common_words)
        theme_set = set(theme_related)
        
        # 初始化生成（严格使用词表中的索引）
        start_token = self.token2idx[theme_word]  # 直接从词表取索引
        input_idx = torch.tensor([[start_token]]).to(self.device)
        generated = [self.idx2token[start_token]]  # 从词表取对应字
        
        # 押韵控制（仅用词表内的字）
        rhyme_char = None
        if strict:
            rhyme_pool = [char for char in self.rhyme_dict if char in self.token2idx]
            if rhyme_pool:
                rhyme_char = random.choice(rhyme_pool)
                rhyme_category = self.rhyme_dict[rhyme_char]
        
        # 生成逻辑（全程限制在词表内）
        current_line = 0
        current_char = 0
        
        while len(generated) < max_len and current_line < lines_count:
            with torch.no_grad():
                output, _ = model(input_idx, input_idx)
            
            # 温度调节
            logits = output[0, -1] / temperature
            probs = F.softmax(logits, dim=-1)
            
            # 候选词严格来自词表
            top_k = 50
            top_indices = torch.topk(probs, top_k).indices.tolist()
            valid_indices = [i for i in top_indices if self.idx2token[i] in theme_set]
            
            # 押韵字也必须在词表内
            if strict and rhyme_char and (current_line == 1 or current_line == 3) and current_char == chars_per_line - 1:
                rhyme_candidates = [i for i in valid_indices if self.rhyme_dict.get(self.idx2token[i]) == rhyme_category]
                if rhyme_candidates:
                    valid_indices = rhyme_candidates
            
            if not valid_indices:
                valid_indices = top_indices[:10]
            
            # 采样（确保结果在词表内）
            probs = F.softmax(logits[valid_indices], dim=-1)
            next_idx = torch.multinomial(probs, 1).item()
            next_token_idx = valid_indices[next_idx]
            next_token = self.idx2token[next_token_idx]  # 从词表取字
            
            if next_token in ['<END>', '<PAD>']:
                break
                
            generated.append(next_token)
            input_idx = torch.tensor([[next_token_idx]]).to(self.device)
            
            # 更新计数
            current_char += 1
            if current_char >= chars_per_line:
                current_char = 0
                current_line += 1
        
        # 格式化输出（仅包含词表内的字）
        return self.format_poem(generated, chars_per_line, lines_count)
    
    def format_poem(self, tokens: List[str], chars_per_line: int, lines_count: int) -> str:
        poem = []
        current_line = ""
        
        for token in tokens:
            if token in ['<START>', '<END>', '<PAD>']:
                continue
            current_line += token
            if len(current_line) == chars_per_line:
                poem.append(current_line)
                current_line = ""
            if len(poem) >= lines_count:
                break
        
        # 不足时用词表中的字补充
        while len(poem) < lines_count:
            filler = "".join(random.sample(self.common_words, min(chars_per_line, len(self.common_words))))
            poem.append(filler[:chars_per_line])
        
        return "\n".join(poem)

# ====================== 词表加载（与前面代码完全一致） ======================
def load_vocab(vocab_path: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """完全复用前面代码的词表加载逻辑，确保使用同一个token2idx.json和idx2token.json"""
    try:
        with open(vocab_path, 'r', encoding='utf-8') as f:
            token2idx = json.load(f)  # 加载前面代码的token2idx.json
        idx2token = {v: k for k, v in token2idx.items()}  # 保持与前面代码一致的反向映射
        print(f"成功加载词表（与前面代码共用），大小: {len(token2idx)}")
        return token2idx, idx2token
    except FileNotFoundError:
        print(f"未找到词表文件，自动生成默认词表（与前面代码格式一致）")
        # 自动生成时，格式与前面代码完全相同（特殊标记索引固定）
        default_chars = "春夏秋冬风雨雪花草树木山石水云日月星天地人"
        token2idx = {
            "<START>": 0,  # 与前面代码一致
            "<END>": 1,    # 与前面代码一致
            "<PAD>": 2     # 与前面代码一致
        }
        for i, char in enumerate(default_chars, 3):
            token2idx[char] = i
        idx2token = {v: k for k, v in token2idx.items()}
        # 自动保存为与前面代码相同的文件，确保后续复用
        with open(vocab_path, 'w', encoding='utf-8') as f:
            json.dump(token2idx, f, ensure_ascii=False)
        return token2idx, idx2token

def load_rhyme_dict(rhyme_path: str) -> Dict[str, str]:
    try:
        with open(rhyme_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"未找到押韵字典，使用默认值")
        return {}

def load_theme_words(theme_path: str) -> Dict[str, List[str]]:
    try:
        with open(theme_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"未找到主题词库，使用默认值")
        return {}

# ====================== 主函数（确保词表统一） ======================
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    
    
    token2idx, idx2token = load_vocab('token2idx.json')  
    
    # 加载其他辅助文件（不影响词表统一性）
    rhyme_dict = load_rhyme_dict("rhyme_dict.json")
    theme_words = load_theme_words("theme_words.json")
    
    # 初始化模型（基于统一的词表大小）
    vocab_size = len(token2idx)
    model = PoemTransformer(vocab_size).to(device)
    
    # 初始化生成器（严格传入统一的词表）
    generator = PoemGenerator(
        token2idx=token2idx,
        idx2token=idx2token,
        rhyme_dict=rhyme_dict,
        theme_words=theme_words,
        device=device
    )
    
    # 测试生成（结果完全基于统一词表）
    print("\n===== 生成测试（基于统一词表） =====")
    test_cases = [
        {"theme": "春", "type": "五言绝句", "temp": 0.8},
        {"theme": "月", "type": "七言绝句", "temp": 0.9},
        {"theme": "山", "type": "五言绝句", "temp": 1.0},
        {"theme": "水", "type": "七言绝句", "temp": 1.0}
    ]
    
    for case in test_cases:
        print(f"\n----- 主题: {case['theme']} | 诗体: {case['type']} -----")
        poem = generator.generate_poem(
            model=model,
            theme_word=case['theme'],
            poem_type=case['type'],
            temperature=case['temp']
        )
        print(poem)

if __name__ == "__main__":
    main()

使用设备: cpu
成功加载词表（与前面代码共用），大小: 105569
未找到押韵字典，使用默认值
未找到主题词库，使用默认值

===== 生成测试（基于统一词表） =====

----- 主题: 春 | 诗体: 五言绝句 -----
春望夫双羽
庶买得逆耳
宁独冷凝觌
荻花风喷雪

----- 主题: 月 | 诗体: 七言绝句 -----
谬以湘江祖帐薪
不动携策种豆伴
不得闲吠二十篇
昔初选调剥皮蒙

----- 主题: 山 | 诗体: 五言绝句 -----
气含星分深
小龙咽津莫
外貌囘麦秋
默有浓于当

----- 主题: 水 | 诗体: 七言绝句 -----
侐能荐真细事功
野烧稠叠客于怀
怀今公清日下唯
飞霙逐境听逢逢


####温度+押韵

In [10]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import os
import numpy as np
from typing import List, Dict

# ====================== 1. 加载词表 ======================
def load_vocab(file_path: str) -> Dict[str, int]:
    """加载词表映射文件"""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# ====================== 2. 定义 Transformer 模型 ======================
class PoemTransformer(nn.Module):
    """Transformer模型（保持原结构，确保兼容性）"""
    def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, ff_dim=512, dropout=0.1):
        super(PoemTransformer, self).__init__()
        
        # 词向量层
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)  # 主程序中修正padding_idx
        
        # 位置编码
        self.positional_encoding = nn.Parameter(torch.rand(1, 5000, embed_size))
        
        # 完整Transformer结构
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True
        )
        
        # 输出层
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def generate_square_subsequent_mask(self, sz):
        """生成自回归掩码"""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
        
    def forward(self, src, tgt):
        # 嵌入和位置编码
        src = self.dropout(self.embed(src) + self.positional_encoding[:, :src.size(1), :])
        tgt = self.dropout(self.embed(tgt) + self.positional_encoding[:, :tgt.size(1), :])
        
        # 生成掩码
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        src_mask = self.generate_square_subsequent_mask(src.size(1)).to(src.device)
        
        # 通过Transformer
        output = self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        
        return self.fc_out(output), None

# ====================== 3. 古诗生成器（支持押韵+五言/七言） ======================
class PoemGenerator:
    def __init__(self,
                 token2idx: Dict[str, int],
                 idx2token: Dict[int, str],
                 model: torch.nn.Module,
                 device: str = "cpu"):

        self.token2idx = token2idx
        self.idx2token = idx2token
        self.model = model.to(device)
        self.device = device
        self.associations = self._load_associations("word_associations.json")
        
        # 押韵字典（简化版平水韵，按韵母分类）
        self.rhyme_dict = self._build_rhyme_dict()
        
        # 初始化分词器
        try:
            import thulac
            self.thulac = thulac.thulac(seg_only=True)
        except ImportError:
            print("警告: 未安装THULAC分词器，将使用简单分词方法")
            self.thulac = None

    def _build_rhyme_dict(self) -> Dict[str, str]:
        """构建简化版押韵字典（按现代拼音韵母分类）"""
        rhyme_dict = {}
        # 基础韵部（覆盖常用字）
        rhyme_categories = {
            'a': ['花', '家', '霞', '沙', '茶', '麻', '涯', '瓜', '华', '芽', '佳', '斜'],
            'o': ['歌', '多', '河', '波', '罗', '柯', '戈', '磨', '蓑', '荷', '婆'],
            'i': ['衣', '期', '池', '知', '时', '丝', '诗', '棋', '词', '啼', '溪', '西'],
            'u': ['无', '图', '湖', '孤', '壶', '途', '苏', '书', '珠', '浮', '奴'],
            'an': ['山', '间', '闲', '还', '颜', '关', '湾', '环', '班', '丹', '残', '天'],
            'ang': ['长', '香', '光', '阳', '堂', '芳', '昌', '央', '刚', '桑', '忙'],
            'eng': ['风', '声', '灯', '僧', '升', '生', '星', '耕', '更', '情', '城'],
            'ong': ['东', '同', '中', '空', '红', '通', '工', '蓬', '浓', '松', '龙']
        }
        for rhyme, chars in rhyme_categories.items():
            for char in chars:
                rhyme_dict[char] = rhyme
        return rhyme_dict

    def _load_associations(self, path: str) -> Dict:
        """加载关联词库"""
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        print(f"警告: 未找到关联词库文件 {path}，将使用空字典。")
        return {}

    def segment_with_thulac(self, text: str) -> List[str]:
        """分词方法"""
        if self.thulac:
            result = self.thulac.cut(text)
            return [word for word, _ in result]
        return list(text)  # 回退到单字分词

    def get_associated_words(self, word: str, top_k: int = 5) -> List[str]:
        """获取输入词的关联词列表"""
        if word not in self.associations:
            return []
        related_words = sorted(
            self.associations[word].items(),
            key=lambda x: x[1],
            reverse=True
        )
        return [w for w, _ in related_words[:top_k]]

    def generate_poem(self, 
                     prompt: str, 
                     style: str = "五言绝句",  # 新增：支持"五言绝句"或"七言绝句"
                     max_repeat: int = 2,
                     temperature: float = 0.6) -> str:
        """
        生成古诗，支持押韵和句式切换
        style: 句式，"五言绝句"（5字/句）或"七言绝句"（7字/句）
        """
        # 1. 确定句式参数
        chars_per_line = 5 if "五言" in style else 7
        lines_count = 4  # 绝句固定4句
        total_chars = chars_per_line * lines_count  # 总字数
        
        # 2. 分词与联想
        segmented_words = self.segment_with_thulac(prompt)
        print(f"分词结果: {segmented_words}")
        
        # 收集候选词
        all_candidates = []
        for word in segmented_words:
            all_candidates.append(word)
            all_candidates.extend(self.get_associated_words(word, top_k=5))
        
        # 过滤有效词，补充主题相关常用字
        valid_words = list(set(w for w in all_candidates if w in self.token2idx))
        if not valid_words:
            valid_words = [w for w in list(self.token2idx.keys())[:10] if w not in ['<PAD>', '<START>']]
        
        # 根据主题补充候选词
        theme_related = {
            '春天': ['春', '风', '花', '柳', '燕', '啼', '暖', '芽'],
            '秋天': ['秋', '霜', '叶', '月', '雁', '寒', '枫', '露'],
            '夏天': ['夏', '荷', '蝉', '雨', '荫', '热', '蛙', '莲'],
            '冬天': ['冬', '雪', '寒', '梅', '冰', '霜', '风', '松']
        }.get(prompt, ['日', '云', '山', '水', '天', '地', '人', '心'])
        
        valid_words.extend([w for w in theme_related if w in self.token2idx and w not in valid_words])
        valid_words = list(set(valid_words))  # 去重
        print(f"有效候选词: {valid_words}")

        # 3. 选择韵部（随机选一个包含候选词的韵部）
        candidate_rhymes = set()
        for word in valid_words:
            if word in self.rhyme_dict:
                candidate_rhymes.add(self.rhyme_dict[word])
        if not candidate_rhymes:
            candidate_rhymes = set(self.rhyme_dict.values())  # 兜底：用所有韵部
        target_rhyme = random.choice(list(candidate_rhymes))  # 目标韵部
        rhyme_words = [w for w in valid_words if self.rhyme_dict.get(w) == target_rhyme]
        if not rhyme_words:
            rhyme_words = [w for w in self.rhyme_dict if self.rhyme_dict[w] == target_rhyme]  # 兜底
        print(f"目标韵部: {target_rhyme}, 押韵候选字: {rhyme_words[:5]}")

        # 4. 初始化输入
        input_ids = [self.token2idx[valid_words[0]]] if valid_words else [self.token2idx['<START>']]
        generated = [self.idx2token[idx] for idx in input_ids]

        # 5. 生成诗句（核心逻辑）
        self.model.eval()
        with torch.no_grad():
            src = torch.tensor([input_ids]).to(self.device)
            tgt = src.clone()
            
            while len(generated) < total_chars:
                # 当前位置信息（第几句、句中第几个字）
                current_total = len(generated)
                current_line = current_total // chars_per_line  # 0-3
                current_pos_in_line = current_total % chars_per_line  # 0到chars_per_line-1
                
                # 模型预测
                output, _ = self.model(src, tgt)
                next_probs = output[:, -1, :]  # 最后一个位置的预测
                
                # 处理重复字符
                last_char = generated[-1] if generated else ''
                repeat_count = 1
                for c in reversed(generated[:-1]):
                    if c == last_char:
                        repeat_count += 1
                    else:
                        break
                if repeat_count >= max_repeat and last_char in self.token2idx:
                    repeat_idx = self.token2idx[last_char]
                    next_probs[0, repeat_idx] *= 0.1  # 惩罚重复字符
                    next_probs = F.softmax(next_probs, dim=-1)
                
                # 押韵控制：第二、四句末尾字必须押韵（绝句规则）
                need_rhyme = (current_line in [1, 3]) and (current_pos_in_line == chars_per_line - 1)
                if need_rhyme:
                    # 只保留押韵候选字
                    rhyme_indices = [self.token2idx[w] for w in rhyme_words if w in self.token2idx]
                    if rhyme_indices:
                        # 过滤非押韵字的概率
                        mask = torch.ones_like(next_probs)
                        for idx in rhyme_indices:
                            mask[0, idx] = 0  # 押韵字保留概率
                        next_probs = next_probs * (1 - mask)  # 非押韵字概率置0
                        next_probs = F.softmax(next_probs, dim=-1)
                
                # 温度采样
                probs = F.softmax(next_probs / temperature, dim=-1)
                next_idx = torch.multinomial(probs, 1).item()
                next_word = self.idx2token.get(next_idx, '')
                
                # 过滤无效字符
                invalid_tokens = ['<END>', '<UNK>', '<PAD>', '<START>']
                if next_word in invalid_tokens:
                    continue
                
                # 重复检查（二次保险）
                if generated and next_word == generated[-1] and repeat_count >= max_repeat:
                    alternatives = [w for w in valid_words if w != next_word and w not in invalid_tokens]
                    if alternatives:
                        next_word = random.choice(alternatives)
                        next_idx = self.token2idx[next_word]
                
                # 添加到结果
                generated.append(next_word)
                input_ids.append(next_idx)
                tgt = torch.tensor([input_ids]).to(self.device)
        
        # 6. 格式化输出（按句式分割）
        poem = "".join(generated)
        for token in invalid_tokens:
            poem = poem.replace(token, "")
        
        # 按句分割
        lines = []
        for i in range(0, min(len(poem), total_chars), chars_per_line):
            line = poem[i:i+chars_per_line]
            if len(line) == chars_per_line:
                lines.append(line)
        
        # 不足4句则补充
        if len(lines) < lines_count and valid_words:
            for _ in range(lines_count - len(lines)):
                filler = "".join(random.sample(valid_words, min(chars_per_line, len(valid_words))))
                lines.append(filler[:chars_per_line])
        
        return "\n".join(lines[:lines_count])

# ====================== 4. 模型加载与主程序 ======================
def load_model(vocab_size, model_path=None, device='cpu'):
    """加载模型并处理权重兼容问题"""
    model = PoemTransformer(vocab_size).to(device)
    if model_path and os.path.exists(model_path):
        try:
            model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
            print(f"成功加载预训练模型: {model_path}")
        except Exception as e:
            print(f"加载模型失败: {e}")
            print("将使用随机初始化权重")
    else:
        print("警告: 未找到预训练模型，将使用随机初始化权重")
    return model

if __name__ == "__main__":
    # 设置设备
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 加载词表
    print("正在加载词表...")
    token2idx = load_vocab('token2idx.json')
    idx2token = {int(k): v for k, v in load_vocab('idx2token.json').items()}
    vocab_size = len(token2idx)
    print(f"词表加载完成，大小: {vocab_size}")
    
    # 加载模型并修正padding_idx
    print("正在加载模型...")
    model = load_model(vocab_size, 'poem_transformer.pth', device)
    model.embed.padding_idx = token2idx.get('<PAD>', 0)  # 动态修正padding_idx
    print("模型加载完成")
    
    # 初始化生成器
    print("正在初始化生成器...")
    generator = PoemGenerator(
        token2idx=token2idx,
        idx2token=idx2token,
        model=model,
        device=device
    )
    print("生成器初始化完成")
    
    # 生成古诗（测试不同句式和主题）
    print("\n=== 生成五言绝句（春天主题） ===")
    poem = generator.generate_poem("春天", style="五言绝句", temperature=0.6)
    print(poem)
    
    print("\n=== 生成七言绝句（月亮主题） ===")
    poem = generator.generate_poem("月亮", style="七言绝句", temperature=0.7)
    print(poem)

正在加载词表...
词表加载完成，大小: 105569
正在加载模型...
成功加载预训练模型: poem_transformer.pth
模型加载完成
正在初始化生成器...
警告: 未找到关联词库文件 word_associations.json，将使用空字典。
Model loaded succeed
生成器初始化完成

=== 生成五言绝句（春天主题） ===
分词结果: ['春天']
有效候选词: ['风', '春', '柳', '燕', '暖', '花', '芽', '春天', '啼']
目标韵部: eng, 押韵候选字: ['风']
风其语俯仰
之间白登围
振旅有力许
由万象森罗

=== 生成七言绝句（月亮主题） ===
分词结果: ['月亮']
有效候选词: ['云', '天', '花藏', '人', '豪富', '日', '心', '遥隔', '<UNK>', '水', '分丛', '地', '以官', '山', '<END>', '上苍']
目标韵部: an, 押韵候选字: ['天', '山']
云纹生秋向成歌
荆插天空寒啸虎
苒去除起石饥涎
苒一登水塘分别
