## Use an LSTM to generate poems

The procedure:

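1. Map each character to a vector with an embedding layer;
2. Model the sequence with a decoder-only LSTM;
3. Generate text by sampling from the predicted distribution;
4. Train on all of the data, which overfits the training set.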

In [1]:
%pip install pandas

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Note: you may need to restart the kernel to use updated packages.


### 1. Download The Dataset

In [2]:
%%script echo skipping

# Optional: set the proxy
%env all_proxy=socks5://127.0.0.1:7897

!mkdir -p data
!git clone https://github.com/xiu-ze/Poetry.git data/Poetry

skipping


In [3]:
# Get all files
import os

def get_all_files(base_dir):
    all_files = []
    for root, dirs, files in os.walk(base_dir):
        for file in files:
            all_files.append(os.path.join(root, file))
    return all_files

base_dir = os.path.expanduser('data/Poetry/诗歌数据集')
poem_files = get_all_files(base_dir)
poem_files

['data/Poetry/诗歌数据集/秦.csv',
 'data/Poetry/诗歌数据集/先秦.csv',
 'data/Poetry/诗歌数据集/隋.csv',
 'data/Poetry/诗歌数据集/辽.csv',
 'data/Poetry/诗歌数据集/当代.csv',
 'data/Poetry/诗歌数据集/明_1.csv',
 'data/Poetry/诗歌数据集/明_2.csv',
 'data/Poetry/诗歌数据集/明_3.csv',
 'data/Poetry/诗歌数据集/清_3.csv',
 'data/Poetry/诗歌数据集/清_2.csv',
 'data/Poetry/诗歌数据集/明_4.csv',
 'data/Poetry/诗歌数据集/元.csv',
 'data/Poetry/诗歌数据集/清_1.csv',
 'data/Poetry/诗歌数据集/南北朝.csv',
 'data/Poetry/诗歌数据集/宋_1.csv',
 'data/Poetry/诗歌数据集/宋_2.csv',
 'data/Poetry/诗歌数据集/唐.csv',
 'data/Poetry/诗歌数据集/近现代.csv',
 'data/Poetry/诗歌数据集/宋_3.csv',
 'data/Poetry/诗歌数据集/汉.csv',
 'data/Poetry/诗歌数据集/金.csv',
 'data/Poetry/诗歌数据集/魏晋.csv']

### 2. Clean the dataset

1. Truncate each poem to its first 24 characters;
2. Check that the fixed positions (5, 11, 17, 23) hold valid punctuation, and drop the poems where they do not.
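The rule behind these two steps can be illustrated on single strings. The following is a minimal sketch, not the notebook's implementation: the helper name `is_well_formed` is hypothetical, and it assumes a five-character quatrain is laid out as four groups of five characters, each followed by one punctuation mark.

```python
# 0-based positions that should hold punctuation in a 24-character quatrain:
# each of the four lines is 5 characters followed by 1 punctuation mark.
PUNCT_POSITIONS = (5, 11, 17, 23)
VALID_PUNCT = set("！，？。")

def is_well_formed(poem: str) -> bool:
    """Return True if the first 24 characters look like a five-character quatrain."""
    head = poem[:24]
    if len(head) < 24:
        return False  # too short to fill all four lines
    return all(head[i] in VALID_PUNCT for i in PUNCT_POSITIONS)

good = "脉脉广川流，驱马历长洲。鹊飞山月曙，蝉噪野风秋。"
bad = "琅玕经雨青，染绿山溪水。“萧萧我凭栏”，袅袅清歌"

print(is_well_formed(good))
print(is_well_formed(bad))
```

In the second string the quotation marks shift the text, so ordinary characters land on positions 17 and 23 and the poem is rejected.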

In [4]:
# Read dataset from one file
import pandas as pd

project_root = os.path.abspath('.')

def read_file_to_df(file_path):
    file = os.path.join(project_root, file_path)
    if not os.path.exists(file):
        raise FileNotFoundError(f"File {file} does not exist.")

    df = pd.read_csv(file)
    filter_by_wujue = df[df["体裁"].astype(str).str.contains("五言绝句", na=False)].copy()
    wujue_content = filter_by_wujue['内容']
    return wujue_content

In [5]:
# Transform the pandas Series to numpy array and check
import numpy as np

def check_punctuation(poem_texts, positions=[5, 11, 17, 23]):
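    """
    Check the punctuation at the fixed positions of the poem texts.
    Prints the characters found at those positions and reports any that are not valid punctuation.

    :param poem_texts: np.ndarray, the array of poem texts with shape (n, 24)
    :param positions: list of int, the column indices expected to hold punctuation
    :return: set, the invalid characters found at those positions
    """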

    # Check the fixed location values
    chars_at_fixed_positions = set(poem_texts[:, positions].reshape(-1))
    print('Characters found at fixed positions:', ''.join(chars_at_fixed_positions))

    # Check the invalid characters
    valid_chars = set("！，？。")
    invalid_chars = chars_at_fixed_positions - valid_chars
    print('Find invalid characters:', ''.join(invalid_chars))

    return invalid_chars

def clean_poem_texts(poem_texts):
    """
    Clean the poem texts by truncating them to 24 characters and dropping the poems
    that have invalid characters at the punctuation positions.

    :param poem_texts: pd.Series, the series of poem texts
    :return: ndarray, cleaned poem texts with shape (n, 24)
    """
    # Truncate every poem to at most 24 characters
    poems_truncated = poem_texts.map(lambda x: x[:24])

    # Report and drop the poems shorter than 24 characters
    poems_small = poems_truncated[poems_truncated.str.len() < 24]
    if not poems_small.empty:
        print('Count of poems with size less than 24:', len(poems_small))
        poems_truncated = poems_truncated[poems_truncated.str.len() == 24]

    # Transform the pandas Series to numpy array
    poem_numpy = np.array(
        poems_truncated.map(list).to_list()
    )
    print('shape of numpy array', poem_numpy.shape)

    # Check the punctuation at the fixed positions
    invalid_chars = check_punctuation(poem_numpy)

    # Flag the poems that have an invalid character at any fixed position
    is_abnormal = np.isin(poem_numpy[:, [5, 11, 17, 23]], list(invalid_chars)).any(axis=1)

    # Report the abnormal items
    abnormal_items = poem_numpy[is_abnormal]
    abnormal_count = len(abnormal_items)
    print('abnormal count:', abnormal_count)
    if abnormal_count > 0:
        print('abnormal item: ', ''.join(abnormal_items[0]))

    # Remove the abnormal items
    poems_removed_invalid = poem_numpy[~is_abnormal]

    print('===== After removing the abnormal items =====')
    print('shape of numpy array', poems_removed_invalid.shape)
    check_punctuation(poems_removed_invalid)

    # Return the cleaned poems as a NumPy array of shape (n, 24)
    return poems_removed_invalid

In [6]:
# Read all files and clean the dataset

all_poems = []

for file in poem_files:
    print(f"Processing file: {file}")
    poem_texts = read_file_to_df(file)
    if poem_texts.empty:
        print("No valid poems found in this file.\n")
        continue

    result = clean_poem_texts(poem_texts)
    all_poems.append(result)
    print()

# Concatenate all cleaned poems into a single array
train_poems_numpy = np.concatenate(all_poems, axis=0)
train_poems_numpy.shape

Processing file: data/Poetry/诗歌数据集/秦.csv
No valid poems found in this file.

Processing file: data/Poetry/诗歌数据集/先秦.csv
No valid poems found in this file.

Processing file: data/Poetry/诗歌数据集/隋.csv
shape of numpy array (71, 24)
Characters found at fixed positions: ，？。
Find invalid characters: 
abnormal count: 0
===== After removing the abnormal items =====
shape of numpy array (71, 24)
Characters found at fixed positions: ，？。
Find invalid characters: 

Processing file: data/Poetry/诗歌数据集/辽.csv
No valid poems found in this file.

Processing file: data/Poetry/诗歌数据集/当代.csv
shape of numpy array (1109, 24)
Characters found at fixed positions: 华文宵帽冕流？山，栏！鼠大歌边」轻对不。闲传：；
Find invalid characters: 华文宵帽鼠大冕流歌边」轻对山不栏闲传：；
abnormal count: 13
abnormal item:  琅玕经雨青，染绿山溪水。“萧萧我凭栏”，袅袅清歌
===== After removing the abnormal items =====
shape of numpy array (1096, 24)
Characters found at fixed positions: ！，？。
Find invalid characters: 

Processing file: data/Poetry/诗歌数据集/明_1.csv
shape of numpy array (3712, 24)
...

(37012, 24)

### 3. Dataset to token id sequences

In [7]:
from keras import layers

tv = layers.TextVectorization(
    max_tokens=10000,
    standardize=None,
    split=None,  # feed the 2-D character array directly
    output_mode="int",
    output_sequence_length=24
)
tv.adapt(train_poems_numpy)


In [8]:
# Demo usage of tv

# Print the vocabulary
print('Vocabulary size:', tv.vocabulary_size())
print('Vocabulary:', ''.join(tv.get_vocabulary()[:20]))

# Encode
encoded = tv(train_poems_numpy[0])

# Decode
vocab = tv.get_vocabulary()
decoded = [vocab[i] for i in encoded.numpy()]
print('Encoded:', encoded.numpy())
print('Decoded:', ''.join(decoded))

Vocabulary size: 6350
Vocabulary: [UNK]，。不人山风一花无来何云有日月春水中
Encoded: [1152 1152  948  466   65    2 1074  177  740   41  604    3 1103   64
    6   16 1230    2  945 2176  183    7   23    3]
Decoded: 脉脉广川流，驱马历长洲。鹊飞山月曙，蝉噪野风秋。
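The output above is consistent with a frequency-ordered character vocabulary in which index 0 is reserved for padding and index 1 for unknown characters, so the comma and full stop get ids 2 and 3. The mapping can be sketched without Keras as follows. This is an illustration only: `build_vocab`, `encode`, and `decode` are hypothetical helpers, and ties between equally frequent characters may be ordered differently from `TextVectorization`.

```python
from collections import Counter

def build_vocab(poems):
    """Build a frequency-ordered character vocabulary with two reserved slots."""
    counts = Counter(ch for poem in poems for ch in poem)
    # Index 0 is reserved for padding and index 1 for unknown characters,
    # mirroring the layout TextVectorization uses in "int" output mode.
    return ["", "[UNK]"] + [ch for ch, _ in counts.most_common()]

def encode(text, vocab):
    """Map each character to its id, falling back to 1 ([UNK])."""
    index = {ch: i for i, ch in enumerate(vocab)}
    return [index.get(ch, 1) for ch in text]

def decode(ids, vocab):
    """Map ids back to characters and join them."""
    return "".join(vocab[i] for i in ids)

toy_poems = ["脉脉广川流，驱马历长洲。鹊飞山月曙，蝉噪野风秋。"]
toy_vocab = build_vocab(toy_poems)
ids = encode("山月，", toy_vocab)
print(ids)
print(decode(ids, toy_vocab))
```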


In [9]:
# Encode all the poems
train_token_ids = tv(train_poems_numpy)
print('shape of train_token_ids:', train_token_ids.shape)

shape of train_token_ids: (37012, 24)


### 4. Build the LSTM Decoder model

In [10]:
# Prepare the train dataset
train_sequences = train_token_ids[:, :-1]
target_sequences = train_token_ids[:, 1:]

train_sequences.shape, target_sequences.shape

(TensorShape([37012, 23]), TensorShape([37012, 23]))
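The two slices set up next-token prediction: the input drops the last token and the target drops the first, so position `t` of the target is the token that follows position `t` of the input. A small sketch on characters instead of token ids (the helper name `make_pairs` is hypothetical):

```python
def make_pairs(tokens):
    """Split one sequence into (input, target) for next-token training."""
    return tokens[:-1], tokens[1:]

tokens = list("脉脉广川流，")
inputs, targets = make_pairs(tokens)

# At step t the model has read inputs[:t + 1] and should predict targets[t].
for t, target in enumerate(targets):
    print(f"after {''.join(inputs[:t + 1])!r} predict {target!r}")
```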

In [11]:
# Build a simple LSTM Decoder model

import keras
from keras import models, layers

def build_model(vocab_size):
    inputs = keras.Input(shape=(None,), dtype="int32", name="inputs")
    x = layers.Embedding(input_dim=vocab_size, output_dim=100, name="embedding")(inputs)
    x = layers.LSTM(512, return_sequences=True, name="lstm")(x)
    x = layers.Dropout(0.1, name="dropout")(x)
    outputs = layers.Dense(vocab_size, activation="softmax", name="output")(x)

    return models.Model(inputs=inputs, outputs=outputs, name="lstm_decoder")

model = build_model(tv.vocabulary_size())
model.summary()
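As a rough cross-check on what `model.summary()` reports, the trainable parameter count can be worked out by hand. The sketch below assumes the vocabulary size of 6350 printed earlier and the usual Keras LSTM layout (four gates, each with an input kernel, a recurrent kernel, and a bias). The function name `lstm_decoder_param_count` is hypothetical.

```python
def lstm_decoder_param_count(vocab_size, embed_dim=100, units=512):
    """Count trainable parameters of the Embedding -> LSTM -> Dense stack."""
    embedding = vocab_size * embed_dim
    # Four gates, each with an input kernel, a recurrent kernel, and a bias.
    lstm = 4 * (units * (embed_dim + units) + units)
    # The Dropout layer has no parameters; Dense has a kernel and a bias.
    dense = units * vocab_size + vocab_size
    return {
        "embedding": embedding,
        "lstm": lstm,
        "output": dense,
        "total": embedding + lstm + dense,
    }

for name, count in lstm_decoder_param_count(6350).items():
    print(f"{name:>9}: {count:,}")
```

Under these assumptions the output layer holds most of the parameters, because its size scales with the vocabulary.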

In [12]:
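# Define the generation and sampling functions
def generate(prompt, max_length=24, temperature=1.0):
    """
    Generate a poem that starts with the given prompt.

    Args:
        prompt: The opening characters of the poem.
        max_length: Total length of the poem in characters, prompt included.
        temperature: Sampling temperature; 0 approximates greedy decoding.

    Returns:
        The generated poem as a string.
    """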
    prompt_inputs = list(prompt)
    generated = tv(prompt_inputs)[:len(prompt)].numpy().tolist()
    while len(generated) < max_length:
        input_sequence = np.array(generated).reshape(1, -1)
        predictions = model.predict(input_sequence, verbose=0)[0]
        next_token_id = sample(predictions[-1], temperature)
        generated.append(next_token_id)
    return ''.join(tv.get_vocabulary()[token_id] for token_id in generated)

def sample(predictions, temperature=1.0, eps1=1e-20, eps2=1e-9):
    p = np.asarray(predictions, dtype=np.float64)

    # eps1 avoids log(0); eps2 avoids division by zero when temperature is 0
    logits = np.log(p + eps1) / (float(temperature) + eps2)

    # Subtract the max logit to prevent overflow
    logits -= np.max(logits)

    q = np.exp(logits)
    q /= q.sum()
    return int(np.random.choice(len(q), p=q))


generate("海外", temperature=0)


'海外薛歃嗅怏踵七毂吊枑穸臬痼埭多多受豁薾豁薾蕨噀'
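To see what the temperature does to a distribution, the same rescaling can be applied to a toy probability vector without drawing a sample. This is a standard-library sketch of the arithmetic in `sample`, not a replacement for it. The name `apply_temperature` and the three-element distribution are made up for illustration.

```python
import math

def apply_temperature(probs, temperature, eps1=1e-20, eps2=1e-9):
    """Rescale a probability vector the same way `sample` does before drawing."""
    logits = [math.log(p + eps1) / (temperature + eps2) for p in probs]
    top = max(logits)  # subtract the max so exp() cannot overflow
    weights = [math.exp(value - top) for value in logits]
    total = sum(weights)
    return [w / total for w in weights]

probs = [0.6, 0.3, 0.1]
for temperature in (0, 0.5, 1.0, 1.5):
    rescaled = apply_temperature(probs, temperature)
    print(temperature, [round(q, 3) for q in rescaled])
```

At temperature 0 nearly all of the mass collapses onto the most likely token, so sampling behaves like greedy decoding. Temperatures above 1 flatten the distribution and make unlikely characters more frequent.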

In [13]:
# Define a callback that prints sample poems at selected epochs (powers of two from the start, mirrored toward the end)
class PoetryGenerateCallback(keras.callbacks.Callback):
    def __init__(self, epochs):
        super().__init__()
        self.generating_epochs = self._get_generating_epochs(epochs)
        self.generated_poems = {}

    def on_epoch_end(self, epoch, logs=None):
        epoch += 1
        if epoch not in self.generating_epochs:
            return
        poems = self.generate_poems()
        self.generated_poems[epoch] = {
            'poems': poems,
            'logs': logs
        }

        self.print_poems(poems)

    @staticmethod
    def generate_poems():
        temperatures = [0, 0.5, 1.0, 1.5]
        generated_texts = [
            generate('海外', max_length=24, temperature=temp)
            for temp in temperatures
        ]
        return [{ 'temperature': temperature, 'text': text } for temperature, text in zip(temperatures, generated_texts)]

    @staticmethod
    def print_poems(poems):
        for item in poems:
            print(f"temperature {item['temperature']:.1f}: {item['text']}")

    @staticmethod
    def _get_generating_epochs(epochs):
        if epochs % 2 != 0:
            print("Warning: epochs should be an even number.")

        mid_epoch = epochs // 2
        left_generating_epochs = [2**i for i in range(0, int(np.log2(mid_epoch)) + 1)]
        if mid_epoch not in left_generating_epochs:
            left_generating_epochs.append(mid_epoch)

        right_generating_epochs = [1 + epochs - e for e in left_generating_epochs][::-1]
        return left_generating_epochs + right_generating_epochs
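The schedule built by `_get_generating_epochs` is easier to read from a concrete run. The standalone sketch below copies the same logic, using `math.log2` in place of `np.log2`, so it can be tried without Keras. The function name `generating_epochs` is hypothetical.

```python
import math

def generating_epochs(epochs):
    """Epochs at which sample poems are generated (1-based)."""
    mid = epochs // 2
    # Powers of two up to the midpoint: dense early, sparse later.
    left = [2 ** i for i in range(int(math.log2(mid)) + 1)]
    if mid not in left:
        left.append(mid)
    # Mirror the left half so sampling gets dense again near the end.
    right = [1 + epochs - e for e in left][::-1]
    return left + right

print(generating_epochs(50))
```

For `epochs = 50` this gives `[1, 2, 4, 8, 16, 25, 26, 35, 43, 47, 49, 50]`, so poems are sampled most often at the start and end of training.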

In [14]:
epochs = 50
poetry_callback = PoetryGenerateCallback(epochs)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"]
)
model.fit(
    train_sequences,
    target_sequences,
    batch_size=256,
    epochs=epochs,
    callbacks=[],  # pass [poetry_callback] to print sample poems during training
    verbose=2
)

Epoch 1/50
145/145 - 15s - 103ms/step - accuracy: 0.0961 - loss: 6.5784
Epoch 2/50
145/145 - 13s - 93ms/step - accuracy: 0.1420 - loss: 5.8895
Epoch 3/50
145/145 - 14s - 93ms/step - accuracy: 0.1775 - loss: 5.7291
Epoch 4/50
145/145 - 14s - 93ms/step - accuracy: 0.1918 - loss: 5.6292
Epoch 5/50
145/145 - 13s - 93ms/step - accuracy: 0.1953 - loss: 5.5501
Epoch 6/50
145/145 - 13s - 93ms/step - accuracy: 0.1980 - loss: 5.4718
Epoch 7/50
145/145 - 14s - 93ms/step - accuracy: 0.2021 - loss: 5.3731
Epoch 8/50
145/145 - 14s - 93ms/step - accuracy: 0.2073 - loss: 5.2727
Epoch 9/50
145/145 - 13s - 93ms/step - accuracy: 0.2118 - loss: 5.1933
Epoch 10/50
145/145 - 13s - 93ms/step - accuracy: 0.2155 - loss: 5.1308
Epoch 11/50
145/145 - 13s - 93ms/step - accuracy: 0.2193 - loss: 5.0736
Epoch 12/50
145/145 - 14s - 93ms/step - accuracy: 0.2235 - loss: 5.0187
Epoch 13/50
145/145 - 14s - 93ms/step - accuracy: 0.2271 - loss: 4.9678
Epoch 14/50
145/145 - 13s - 93ms/step - accuracy: 0.2310 - loss: 4.9187


<keras.src.callbacks.history.History at 0x7f4d0b96fef0>
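One way to read the loss values in the log is to convert them to perplexity. `sparse_categorical_crossentropy` is an average negative log-likelihood in nats, so `exp(loss)` gives the effective number of equally likely characters the model is choosing between. The sketch below uses two loss values copied from the log above and assumes the vocabulary size of 6350 as the uniform-guessing baseline. The helper name `perplexity` is hypothetical.

```python
import math

def perplexity(cross_entropy_nats):
    """Convert a mean cross-entropy (natural log) into perplexity."""
    return math.exp(cross_entropy_nats)

vocab_size = 6350
# Loss of a model that assigns equal probability to every token.
uniform_loss = math.log(vocab_size)
print(f"uniform baseline: loss {uniform_loss:.4f}, perplexity {perplexity(uniform_loss):.0f}")

# Loss values copied from the training log above.
for epoch, loss in [(1, 6.5784), (14, 4.9187)]:
    print(f"epoch {epoch:>2}: loss {loss:.4f}, perplexity {perplexity(loss):.1f}")
```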

In [15]:
# Print the generated poems for analysis
# poetry_callback.generated_poems

# Or save the model
model.save('lstm_poetry_model.keras')

### 5. Using the model to generate poems

In [None]:
# Reload the model
model = keras.models.load_model('lstm_poetry_model.keras', compile=False)

In [18]:
poems = PoetryGenerateCallback.generate_poems()
PoetryGenerateCallback.print_poems(poems)

temperature 0.0: 海外一天地，江南万里来。西风吹不尽，吹笛一枝秋。
temperature 0.5: 海外草云深，山空云雾重。不知天外景，风雨暗中流。
temperature 1.0: 海外春气晚，夕阳春水发。途人骨束堆，竹间江头绿。
temperature 1.5: 海外负银川，较游穷怨望。楚江海轮行，呼綵七茅屋。
