In [1]:
from poem.genre import Genre

# Root folder of the per-dynasty poem files (relative path; `~` is expanded in the next cell)
dataset_directory = 'data/Poetry/诗歌数据集'

# Genre whose poems are extracted, validated and used for training.
# Other options: Genre.QIJUE, Genre.WULV, Genre.QILV
current_genre = Genre.WUJUE

In [2]:
# Get all file names
import os

def get_all_files(base_dir, suffix=None):
    """
    Recursively collect the paths of all files under ``base_dir``.

    ``os.walk`` yields entries in an arbitrary, filesystem-dependent order, so the
    result is sorted to keep positional lookups (e.g. ``poem_files[16]``) and the
    order of later concatenations reproducible across machines.

    :param base_dir: string, the directory to walk
    :param suffix: string or tuple of strings, optional; when given, only file names
                   ending with it are kept (e.g. ".txt")
    :return: list of string, the sorted file paths
    """
    all_files = [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(base_dir)
        for file in files
        if suffix is None or file.endswith(suffix)
    ]
    return sorted(all_files)

# Resolve `~` in the configured dataset path, then list every file below it
base_dir = os.path.expanduser(dataset_directory)
poem_files = get_all_files(base_dir)  # entries look like {dataset_directory}/XXX.txt

### Define the checking functions for a single poem file

In [3]:
# Pick one file, by its position in `poem_files`, for the single-file walkthrough below
demo_file = poem_files[16]

In [4]:
# Read the content of one file
#
# Load it into a pd.Series that keeps only the poem content column

import pandas as pd

def read_file_to_pandas(file_path: str, genre_name: str) -> pd.Series:
    """
    Read one poem file and return the poem contents belonging to a given genre.

    :param file_path: string, path to the file; it is parsed with ``pd.read_csv`` and must
                      contain the columns "体裁" (genre) and "内容" (content)
    :param genre_name: string, matched literally (not as a regular expression) as a
                       substring of the "体裁" column
    :return: pd.Series, the "内容" values of the matching rows, keeping the original row index
    :raises FileNotFoundError: if ``file_path`` does not exist
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File {file_path} does not exist.")

    df = pd.read_csv(file_path)
    # regex=False: genre names are plain text, so characters such as "(" or "+" must not be
    # interpreted as regex syntax. Missing genres (NaN -> "nan" after astype) never match.
    genre_mask = df["体裁"].astype(str).str.contains(genre_name, na=False, regex=False)
    return df.loc[genre_mask, '内容'].copy()

one_dynasty_poems = read_file_to_pandas(
    file_path=demo_file,
    genre_name=current_genre.genre_name,
)
one_dynasty_poems.shape  # (number of poems of the current genre in the demo file,)

(2711,)

In [5]:
# Define checking functions for a single poem text

DEFAULT_VALID_PUNCTATIONS = set("！，？。")

def check_poem_punctuation(text: str, positions: list[int], valid_punctuations=None) -> bool:
    """
    Check the punctuation in the fixed positions of the single poem text.

    Positions that fall outside ``text`` are skipped, so this check alone does not
    enforce a length; combine it with ``check_poem_length``.

    :param text: string, the single poem text
    :param positions: list, the fixed (0-based) positions to check
    :param valid_punctuations: set, optional, the valid punctuation characters;
                               defaults to DEFAULT_VALID_PUNCTATIONS

    :return: bool, whether the text is valid. If False, it means there are other characters
             found in the fixed positions.
    """
    if valid_punctuations is None:
        valid_punctuations = DEFAULT_VALID_PUNCTATIONS

    chars_at_fixed_positions = {text[i] for i in positions if i < len(text)}
    return chars_at_fixed_positions <= valid_punctuations

def check_poem_length(text: str, expected_length=24) -> bool:
    """
    Check if the poem text has the expected length.

    :param text: string, the single poem text
    :param expected_length: int, the expected number of characters (24 = 4 lines x (5 + 1))

    :return: bool, whether the text has the expected length
    """
    return len(text) == expected_length

def check_poem(text: str, rows: int, cols: int, valid_punctuations=None) -> bool:
    """
    Check if the poem text is a well-formed poem of ``rows`` lines with ``cols`` characters each.

    A valid text
      * is a string of exactly ``rows * (cols + 1)`` characters,
      * has a valid punctuation mark at the end of every line, and
      * has no punctuation mark inside a line.

    :param text: string, the single poem text (non-strings such as float('nan') are invalid)
    :param rows: int, number of rows in the poem.
                 E.g. 4 for 绝句
    :param cols: int, number of characters in a single poem row, not including the punctuation.
                 E.g. 5 for 五言
    :param valid_punctuations: set, optional, the valid punctuation characters;
                               defaults to DEFAULT_VALID_PUNCTATIONS

    :return: bool, whether the text is valid
    """
    # Missing contents show up as non-string values, e.g. float('nan')
    if not isinstance(text, str):
        return False

    if valid_punctuations is None:
        valid_punctuations = DEFAULT_VALID_PUNCTATIONS

    line_width = cols + 1  # characters of one row plus its closing punctuation
    punctuation_positions = [(i + 1) * line_width - 1 for i in range(rows)]

    if not check_poem_length(text, expected_length=rows * line_width):
        return False
    if not check_poem_punctuation(text, punctuation_positions, valid_punctuations):
        return False

    # Punctuation inside a row means the row is not `cols` real characters long,
    # even though the total length and row endings look right.
    row_chars = {ch for i, ch in enumerate(text) if i % line_width != cols}
    return not (row_chars & valid_punctuations)

In [6]:
# Demonstrate the checking functions on one poem text
check_poem(one_dynasty_poems.iloc[0], current_genre.rows, current_genre.cols)

True

In [7]:
# Demonstrate the checking functions
from functools import partial

def check_poems(poem_texts: pd.Series, genre: Genre):
    """
    Validate every poem in a Series against the row/column layout of ``genre``.

    Each text is first cut to ``genre.length`` characters, then checked with ``check_poem``.

    :param poem_texts: pd.Series, the poem texts; a MultiIndex is fine
    :param genre: Genre(Enum), provides ``rows``, ``cols`` and ``length``
    :return: pd.Series of bool, True for valid poems, indexed like ``poem_texts``
    """
    def is_valid(text):
        return check_poem(text, rows=genre.rows, cols=genre.cols)

    truncated_texts = poem_texts.str[:genre.length]
    return truncated_texts.apply(is_valid)

def report_check_results(mask: pd.Series):
    """
    Summarize a boolean validity mask.

    :param mask: pd.Series of bool
    :return: tuple of (number of entries, number of True entries, share of True as "xx.xx%")
    """
    total = len(mask)
    passed = mask.sum()
    share_text = "{:.2f}%".format(mask.mean() * 100)
    return total, passed, share_text

# Validate the demo file's poems and summarize them as (total, valid, valid share)
mask = check_poems(poem_texts=one_dynasty_poems, genre=current_genre)
report_check_results(mask)

(2711, 2710, '99.96%')

### Apply the checking functions to all poem files

In [8]:
# Read all files into pandas Series, with MultiIndex

def extract_dynasty_from_filename(file_path: str) -> str:
    """
    Use the part of the file's base name before its first dot as the dynasty label.

    E.g. "data/Poetry/诗歌数据集/唐.txt" -> "唐".

    :param file_path: string, the file path
    :return: string, the dynasty label
    """
    file_name = os.path.basename(file_path)
    dynasty, _sep, _rest = file_name.partition('.')
    return dynasty

# poem_files: list[str], format #{dataset_directory}/XXX.txt

# Read all files
# Read the current genre's poems from every file (same order as poem_files)
list_of_poems = [
    read_file_to_pandas(path, current_genre.genre_name)
    for path in poem_files
]

# Label each file with the dynasty taken from its file name
dynasty_list = list(map(extract_dynasty_from_filename, poem_files))

# Stack the per-file Series into one Series; the dynasty becomes the outer MultiIndex level
all_dynasty_poems = pd.concat(list_of_poems, keys=dynasty_list)

In [9]:
# Apply the checking functions to all dynasty poems
# Validate every poem of every dynasty and summarize as (total, valid, valid share)
mask = check_poems(poem_texts=all_dynasty_poems, genre=current_genre)
report_check_results(mask)

(37081, 37012, '99.81%')

In [10]:
# Get the cleaned dataset using mask values
# Keep only the poems that passed validation, trimmed to the genre's exact length
cleaned_poems = (
    all_dynasty_poems[mask]
    .str[:current_genre.length]
)
cleaned_poems.shape

(37012,)

### Build the TextVectorization layer

In [11]:
from keras import layers

# Character-level tokenizer: every character, punctuation included, is its own token.
# standardize=None keeps the text untouched; each poem maps to `current_genre.length` ids.
text_vectorization = layers.TextVectorization(
    output_mode="int",
    split='character',
    standardize=None,
    output_sequence_length=current_genre.length,
)
text_vectorization.adapt(cleaned_poems)

2025-09-02 11:11:52.571937: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
# Demonstrate the text_vectorization

# Inspect the learned vocabulary (list position == token id)
vocabulary = text_vectorization.get_vocabulary()
print('Vocabulary size:', text_vectorization.vocabulary_size())
print('Vocabulary samples:', ''.join(vocabulary[:20]))

# Round-trip the first cleaned poem: text -> token ids -> text
encoded = text_vectorization(cleaned_poems.iloc[0])
decoded = [vocabulary[token_id] for token_id in encoded.numpy()]
print('Encoded:', encoded.numpy())
print('Decoded:', ''.join(decoded))

Vocabulary size: 6350
Vocabulary samples: [UNK]，。不人山风一花无来何云有日月春水中
Encoded: [1152 1152  948  466   65    2 1074  177  740   41  604    3 1103   64
    6   16 1230    2  945 2176  183    7   23    3]
Decoded: 脉脉广川流，驱马历长洲。鹊飞山月曙，蝉噪野风秋。


In [13]:
# Encode all the poems
train_token_ids = text_vectorization(cleaned_poems)  # one row of token ids per cleaned poem
print('shape of train dataset:', train_token_ids.shape)

shape of train dataset: (37012, 24)


In [14]:
# Now we can save the vocabulary of text_vectorization.
# Create the output directory in pure Python (instead of `!mkdir -p`) so the cell also
# runs where no POSIX shell is available; exist_ok makes re-running the cell harmless.
os.makedirs('models', exist_ok=True)

def save_vocabulary(vocabulary: list[str], genre: "Genre", directory: str = 'models') -> str:
    """
    Write the vocabulary to ``{directory}/{genre.genre_name}_vocabulary.txt``, one token per line.

    :param vocabulary: list of string, the tokens in id order (line i holds token id i)
    :param genre: Genre(Enum), only its ``genre_name`` is used to build the file name
    :param directory: string, output directory; created if it does not exist yet
    :return: string, the path of the written file
    """
    os.makedirs(directory, exist_ok=True)
    vocab_path = os.path.join(directory, f'{genre.genre_name}_vocabulary.txt')
    with open(vocab_path, 'w', encoding='utf-8') as f:
        # NOTE(review): a token containing '\n' would break the one-token-per-line format
        f.writelines(f"{token}\n" for token in vocabulary)
    return vocab_path

save_vocabulary(vocabulary=vocabulary, genre=current_genre)

### Build and train a model with an LSTM layer

In [15]:
# Prepare the train dataset
# Next-token prediction: inputs drop each poem's last token, targets drop its first,
# so target[t] is the token that follows input[t]
train_sequences, target_sequences = train_token_ids[:, :-1], train_token_ids[:, 1:]

train_sequences.shape, target_sequences.shape

(TensorShape([37012, 23]), TensorShape([37012, 23]))

In [16]:
# Build a simple LSTM Decoder model

import keras
from keras import models, layers

class Config:
    """Hyper-parameters of the LSTM decoder and its training run."""

    # training
    batch_size = 256
    epochs = 50

    # model; the vocabulary size comes from the `vocabulary` list built earlier in the notebook
    vocab_size = len(vocabulary)
    embedding_dim = 100
    lstm_units = 512
    dropout_rate = 0.1

def build_model(config: Config) -> keras.Model:
    """
    Build a next-token model: token ids -> embedding -> LSTM -> dropout -> softmax.

    :param config: Config, provides vocab_size, embedding_dim, lstm_units and dropout_rate
    :return: keras.Model named "lstm_decoder"; it maps an int32 batch of id sequences of any
             length to a probability distribution over the vocabulary at every position
    """
    token_ids = keras.Input(shape=(None,), dtype="int32", name="inputs")

    embedding = layers.Embedding(input_dim=config.vocab_size, output_dim=config.embedding_dim, name="embedding")
    lstm = layers.LSTM(config.lstm_units, return_sequences=True, name="lstm")  # one output per step
    dropout = layers.Dropout(config.dropout_rate, name="dropout")
    head = layers.Dense(config.vocab_size, activation="softmax", name="output")

    probabilities = head(dropout(lstm(embedding(token_ids))))
    return models.Model(inputs=token_ids, outputs=probabilities, name="lstm_decoder")

# Instantiate the hyper-parameters, build the network and show its layer table
config = Config()
model = build_model(config=config)
model.summary()

In [17]:
model.compile(
    loss="sparse_categorical_crossentropy",  # targets are integer token ids, not one-hot vectors
    optimizer="adam",
    metrics=["accuracy"],
)

# Keep the History object so the per-epoch loss/accuracy can be inspected or plotted later
# (assigning it also keeps the bare History repr out of the cell output).
history = model.fit(
    train_sequences,
    target_sequences,
    batch_size=config.batch_size,
    epochs=config.epochs,
    verbose=2,
)

Epoch 1/50


KeyboardInterrupt: 

In [80]:
# Save the trained model.
# NOTE: the file name records the configured epoch count (config.epochs), not the number of
# epochs that actually completed; an interrupted fit is still saved under the configured count.
model_path = 'models/{}_lstm_model-epoch{}.keras'.format(current_genre.genre_name, config.epochs)
model.save(model_path)

### Generate text with the trained model

In [19]:
# Demonstrate text generation using PoemGenerator class

from poem.generator import PoemGenerator

# Wire the tokenizer and the trained network into the generator for the current genre
poem_generator = PoemGenerator(
    genre=current_genre,
    vectorization_model=text_vectorization,
    generation_model=model,
)

# Continue the prompt "海外"; temperature=0 presumably means greedy decoding (TODO confirm in PoemGenerator)
poem_generator.generate("海外", temperature=0)

'海外。。。。。。。。。。。。。。。。。。。。。。'