## Use an LSTM to generate poems

The procedure:

1. Embedding layer;
2. Decoder-only LSTM;
3. Sampling for generation;
4. Train on all the data (no validation split), which overfits the training data.
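
Step 4 trains on every poem, so nothing is held out to measure how far the model overfits. The following is a minimal sketch of one way to set aside a validation subset by shuffling row indices. The helper name `split_indices`, the 10% fraction, and the item count of 1000 are illustrative assumptions, not part of this notebook.

```python
import random

def split_indices(n_items, val_fraction=0.1, seed=0):
    """Shuffle 0..n_items-1 and split into (train, validation) index lists."""
    indices = list(range(n_items))
    # A seeded generator keeps the split reproducible between runs
    random.Random(seed).shuffle(indices)
    n_val = int(n_items * val_fraction)
    return indices[n_val:], indices[:n_val]

# Hypothetical dataset of 1000 poems
train_idx, val_idx = split_indices(1000)
print(len(train_idx), len(val_idx))
```

The two index lists could then select rows from the token-id array, with the validation rows used only for evaluation.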

### 1. Download the dataset

In [None]:
# Optional: set the proxy
%env all_proxy=socks5://127.0.0.1:7897

!mkdir -p data
!git clone https://github.com/xiu-ze/Poetry.git data/Poetry

In [6]:
# Get all files
import os

def get_all_files(base_dir):
    all_files = []
    for root, dirs, files in os.walk(base_dir):
        for file in files:
            all_files.append(os.path.join(root, file))
    return all_files

# "诗歌数据集" means "poetry dataset"; the path has no "~", so expanduser is not needed
base_dir = 'data/Poetry/诗歌数据集'
poem_files = get_all_files(base_dir)
poem_files

['data/Poetry/诗歌数据集/秦.csv',
 'data/Poetry/诗歌数据集/先秦.csv',
 'data/Poetry/诗歌数据集/隋.csv',
 'data/Poetry/诗歌数据集/辽.csv',
 'data/Poetry/诗歌数据集/当代.csv',
 'data/Poetry/诗歌数据集/明_1.csv',
 'data/Poetry/诗歌数据集/明_2.csv',
 'data/Poetry/诗歌数据集/明_3.csv',
 'data/Poetry/诗歌数据集/清_3.csv',
 'data/Poetry/诗歌数据集/清_2.csv',
 'data/Poetry/诗歌数据集/明_4.csv',
 'data/Poetry/诗歌数据集/元.csv',
 'data/Poetry/诗歌数据集/清_1.csv',
 'data/Poetry/诗歌数据集/南北朝.csv',
 'data/Poetry/诗歌数据集/宋_1.csv',
 'data/Poetry/诗歌数据集/宋_2.csv',
 'data/Poetry/诗歌数据集/唐.csv',
 'data/Poetry/诗歌数据集/近现代.csv',
 'data/Poetry/诗歌数据集/宋_3.csv',
 'data/Poetry/诗歌数据集/汉.csv',
 'data/Poetry/诗歌数据集/金.csv',
 'data/Poetry/诗歌数据集/魏晋.csv']

### 2. Clean the dataset

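1. Keep only the five-character quatrains (五言绝句, wujue);
2. Truncate each poem to its first 24 characters: four lines of five characters, each followed by one punctuation mark;
3. Check the punctuation at the fixed positions (indices 5, 11, 17, 23) and remove the poems that have any other character there.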
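
The fixed-position check can be sketched on plain strings before looking at the pandas and NumPy version. This is only an illustration: the helper `is_well_formed` is a hypothetical name, and the two sample strings are copied from the cell outputs further down in this notebook.

```python
PUNCT_POSITIONS = (5, 11, 17, 23)   # index of the mark that ends each line
VALID_PUNCT = set("！，？。")

def is_well_formed(poem):
    """Return True if the first 24 characters have valid punctuation at the fixed positions."""
    text = poem[:24]
    if len(text) < 24:
        return False
    return all(text[i] in VALID_PUNCT for i in PUNCT_POSITIONS)

good = "脉脉广川流，驱马历长洲。鹊飞山月曙，蝉噪野风秋。"
bad = "琅玕经雨青，染绿山溪水。“萧萧我凭栏”，袅袅清歌"

print(is_well_formed(good))
print(is_well_formed(bad))
```

In the second string the quotation marks shift the text, so ordinary characters land on indices 17 and 23 and the poem is rejected.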

In [16]:
# Read dataset from one file
import pandas as pd

project_root = os.path.abspath('.')

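def read_file_to_df(file_path):
    """Read one CSV file and return the text of its five-character quatrains as a pandas Series."""
    file = os.path.join(project_root, file_path)
    if not os.path.exists(file):
        raise FileNotFoundError(f"File {file} does not exist.")

    df = pd.read_csv(file)
    # "体裁" is the genre column; "五言绝句" is the five-character quatrain (wujue)
    is_wujue = df["体裁"].astype(str).str.contains("五言绝句", na=False)
    # "内容" is the column that holds the poem text
    return df.loc[is_wujue, "内容"]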

In [31]:
# Transform the pandas Series to numpy array and check
import numpy as np

def check_punctuation(poem_texts, positions=(5, 11, 17, 23)):
    """
    Check the punctuation at the fixed positions of the poem texts.
    Prints the characters found at those positions and any that are not valid punctuation.

    :param poem_texts: np.ndarray of shape (n, 24), one character per cell
    :param positions: column indices where punctuation is expected
    :return: set, the invalid characters found at those positions
    """

    # Collect every character that occurs at the fixed positions
    chars_at_fixed_positions = set(poem_texts[:, list(positions)].reshape(-1))
    print('Characters found at fixed positions:', ''.join(chars_at_fixed_positions))

    # Check the invalid characters
    valid_chars = set("！，？。")
    invalid_chars = chars_at_fixed_positions - valid_chars
    print('Find invalid characters:', ''.join(invalid_chars))

    return invalid_chars

def clean_poem_texts(poem_texts):
    """
    Clean the poem texts: truncate each poem to 24 characters, then drop poems
    that are too short or have invalid punctuation at the fixed positions.

    :param poem_texts: pd.Series, the series of poem texts
    :return: ndarray, cleaned poem texts with shape (n, 24)
    """
    # Truncate every poem to its first 24 characters
    poems_truncated = poem_texts.str[:24]

    # Drop the poems that are shorter than 24 characters
    is_short = poems_truncated.str.len() < 24
    if is_short.any():
        print('Count of poems with size less than 24:', int(is_short.sum()))
    poems_truncated = poems_truncated[poems_truncated.str.len() == 24]

    # Transform the pandas Series to numpy array
    poem_numpy = np.array(
        poems_truncated.map(list).to_list()
    )
    print('shape of numpy array', poem_numpy.shape)

    # Check the punctuation at the fixed positions
    invalid_chars = check_punctuation(poem_numpy)

    # Flag the poems that have an invalid character at any fixed position
    is_abnormal = np.isin(
        poem_numpy[:, [5, 11, 17, 23]], list(invalid_chars)
    ).any(axis=1)
    abnormal_items = poem_numpy[is_abnormal]
    abnormal_count = len(abnormal_items)
    print('abnormal count:', abnormal_count)
    if abnormal_count > 0:
        print('abnormal item: ', ''.join(abnormal_items[0]))

    # Remove the abnormal poems
    poems_removed_invalid = poem_numpy[~is_abnormal]

    print('===== After removing the abnormal items =====')
    print('shape of numpy array', poems_removed_invalid.shape)
    check_punctuation(poems_removed_invalid)

    # Return the cleaned poems as an ndarray of shape (n, 24)
    return poems_removed_invalid

In [36]:
# Read all files and clean the dataset

all_poems = []

for file in poem_files:
    print(f"Processing file: {file}")
    poem_texts = read_file_to_df(file)
    if poem_texts.empty:
        print("No valid poems found in this file.\n")
        continue

    result = clean_poem_texts(poem_texts)
    all_poems.append(result)
    print()

# Concatenate all cleaned poems into a single array
train_poems_numpy = np.concatenate(all_poems, axis=0)
train_poems_numpy.shape

Processing file: data/Poetry/诗歌数据集/秦.csv
No valid poems found in this file.

Processing file: data/Poetry/诗歌数据集/先秦.csv
No valid poems found in this file.

Processing file: data/Poetry/诗歌数据集/隋.csv
shape of numpy array (71, 24)
Characters found at fixed positions: 。，？
Find invalid characters: 
abnormal count: 0
===== After removing the abnormal items =====
shape of numpy array (71, 24)
Characters found at fixed positions: 。，？
Find invalid characters: 

Processing file: data/Poetry/诗歌数据集/辽.csv
No valid poems found in this file.

Processing file: data/Poetry/诗歌数据集/当代.csv
shape of numpy array (1109, 24)
Characters found at fixed positions: 」？轻鼠冕；帽闲宵不文对边山大华歌！栏流。，：传
Find invalid characters: 传大轻鼠华冕歌；帽闲宵栏流不文对边：山」
abnormal count: 13
abnormal item:  琅玕经雨青，染绿山溪水。“萧萧我凭栏”，袅袅清歌
===== After removing the abnormal items =====
shape of numpy array (1096, 24)
Characters found at fixed positions: ！。，？
Find invalid characters: 

Processing file: data/Poetry/诗歌数据集/明_1.csv
shape of numpy array (3712, 24)
[output truncated]

(37012, 24)

### 3. Convert the dataset to token ID sequences

In [56]:
from keras import layers

tv = layers.TextVectorization(
    max_tokens=10000,
    standardize=None,
    split=None,  # feed the 2-D character array directly
    output_mode="int",
    output_sequence_length=24
)
tv.adapt(train_poems_numpy)

In [57]:
# Demo usage of tv

# Print the vocabulary
print('Vocabulary size:', tv.vocabulary_size())
print('Vocabulary:', ''.join(tv.get_vocabulary()[:20]))

# Encode
encoded = tv(train_poems_numpy[0])

# Decode
vocab = tv.get_vocabulary()
decoded = [vocab[i] for i in encoded.numpy()]
print('Encoded:', encoded.numpy())
print('Decoded:', ''.join(decoded))

Vocabulary size: 6350
Vocabulary: [UNK]，。不人山风一花无来何云有日月春水中
Encoded: [1152 1152  948  466   65    2 1074  177  740   41  604    3 1103   64
    6   16 1230    2  945 2176  183    7   23    3]
Decoded: 脉脉广川流，驱马历长洲。鹊飞山月曙，蝉噪野风秋。


In [58]:
# Encode all the poems
train_token_ids = tv(train_poems_numpy)
print('shape of train_token_ids:', train_token_ids.shape)

shape of train_token_ids: (37012, 24)
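
To make the character-to-id mapping concrete, here is a rough pure-Python sketch of a frequency-ordered vocabulary with id 0 reserved for padding and id 1 for `[UNK]`, matching the layout printed above. The helpers `build_vocab`, `encode`, and `decode` are hypothetical names, and the order of equally frequent characters may differ from what `TextVectorization` produces.

```python
from collections import Counter

def build_vocab(poems):
    """Map each character to an id: 0 = padding, 1 = [UNK], then by descending frequency."""
    counts = Counter(ch for poem in poems for ch in poem)
    # most_common() keeps first-seen order for ties, which is an assumption of
    # this sketch and not necessarily the tie order Keras uses
    tokens = ["", "[UNK]"] + [ch for ch, _ in counts.most_common()]
    return tokens, {ch: i for i, ch in enumerate(tokens)}

def encode(text, token_to_id):
    # Characters missing from the vocabulary map to the [UNK] id
    return [token_to_id.get(ch, 1) for ch in text]

def decode(ids, tokens):
    return "".join(tokens[i] for i in ids)

tokens, token_to_id = build_vocab(["山山山月月风", "山月花"])
ids = encode("山月云", token_to_id)
print(ids)                  # [2, 3, 1]
print(decode(ids, tokens))  # 山月[UNK]
```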


### 4. Build the LSTM decoder model

In [44]:
# Prepare the train dataset
train_sequences = train_token_ids[:, :-1]
target_sequences = train_token_ids[:, 1:]

train_sequences.shape, target_sequences.shape

(TensorShape([37012, 23]), TensorShape([37012, 23]))
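
The two slices above set up next-token prediction: at every position the model sees the sequence so far and is trained to output the character that follows. A small sketch on the first line of the poem decoded earlier, using characters instead of ids for readability (`shift_for_next_token` is a hypothetical helper name):

```python
def shift_for_next_token(tokens):
    """Split one sequence into (inputs, targets) offset by one position."""
    return tokens[:-1], tokens[1:]

tokens = list("脉脉广川流，")
inputs, targets = shift_for_next_token(tokens)

for step, (x, y) in enumerate(zip(inputs, targets)):
    print(f"step {step}: input {x!r} -> target {y!r}")
```

Both halves have one element fewer than the original sequence, which is why the shapes above are `(37012, 23)` rather than `(37012, 24)`.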

In [45]:
# Build a simple LSTM Decoder model

import keras
from keras import models, layers

def build_model(vocab_size):
    inputs = keras.Input(shape=(None,), dtype="int32", name="inputs")
    x_embedded = layers.Embedding(
        input_dim=vocab_size, output_dim=100, name="embedding"
    )(inputs)
    x_lstm_output = layers.LSTM(
        128, return_sequences=True, name="lstm"
    )(x_embedded)
    outputs = layers.Dense(
        vocab_size, activation="softmax", name="output"
    )(x_lstm_output)

    return models.Model(inputs=inputs, outputs=outputs, name="lstm_decoder")

model = build_model(tv.vocabulary_size())
model.summary()
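
As a sanity check on the model size, the parameter counts can be worked out by hand. This sketch assumes the vocabulary size of 6350 printed in section 3 and the Keras defaults for these layers (biases enabled); the helper names are hypothetical.

```python
def embedding_params(vocab_size, embed_dim):
    # One embed_dim-sized vector per token
    return vocab_size * embed_dim

def lstm_params(input_dim, units):
    # Four gates, each with input weights, recurrent weights, and a bias
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, output_dim):
    # Weight matrix plus one bias per output unit
    return input_dim * output_dim + output_dim

vocab_size, embed_dim, units = 6350, 100, 128  # vocab_size is assumed from the earlier output

n_embedding = embedding_params(vocab_size, embed_dim)
n_lstm = lstm_params(embed_dim, units)
n_dense = dense_params(units, vocab_size)
total = n_embedding + n_lstm + n_dense

print("embedding:", n_embedding)  # 635000
print("lstm:     ", n_lstm)       # 117248
print("output:   ", n_dense)      # 819150
print("total:    ", total)        # 1571398
```

Under these assumptions the output layer holds more than half of the parameters, because it scales with the vocabulary size.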


In [68]:
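# Define the generation function
def generate(prompt, max_length=24, temperature=1.0):
    """
    Generate a poem that starts with the given prompt.

    Uses the global `tv` (TextVectorization layer) and `model`.

    Args:
        prompt: the opening characters of the poem.
        max_length: total number of characters to generate, prompt included.
        temperature: sampling temperature; values near 0 approach greedy decoding.

    Returns:
        A generated poem as a string.
    """
    # tv pads its output to 24 tokens, so keep only the prompt's own ids
    generated = tv(list(prompt))[:len(prompt)].numpy().tolist()
    vocab = tv.get_vocabulary()
    while len(generated) < max_length:
        input_sequence = np.array(generated).reshape(1, -1)
        predictions = model.predict(input_sequence, verbose=0)[0]
        # Sample the next token from the distribution at the last time step
        next_token_id = sample(predictions[-1], temperature)
        generated.append(next_token_id)
    return ''.join(vocab[token_id] for token_id in generated)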

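def sample(predictions, temperature=1.0, eps1=1e-20, eps2=1e-9):
    """Sample a token id from `predictions` after temperature scaling."""
    p = np.asarray(predictions, dtype=np.float64)

    # eps1 avoids log(0); eps2 avoids division by zero when temperature is 0,
    # in which case sampling collapses to picking the most likely token
    logits = np.log(p + eps1) / (float(temperature) + eps2)

    # Subtract the max logit so that exp() cannot overflow
    logits -= np.max(logits)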

    q = np.exp(logits)
    q /= q.sum()
    return int(np.random.choice(len(q), p=q))


generate("海外", temperature=0)

'海外春风雨，山花一叶开。不知春水上，不见一枝花。'
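
To see what the temperature does inside `sample`, the rescaled distribution can be computed without drawing from it. The sketch below mirrors the same formula in pure Python on a made-up three-token distribution; `apply_temperature` is a hypothetical helper, not a function used elsewhere in this notebook.

```python
import math

def apply_temperature(probs, temperature, eps1=1e-20, eps2=1e-9):
    """Return the rescaled distribution that sample() would draw from."""
    logits = [math.log(p + eps1) / (temperature + eps2) for p in probs]
    peak = max(logits)
    # Shifting by the largest logit keeps exp() from overflowing
    weights = [math.exp(value - peak) for value in logits]
    total = sum(weights)
    return [w / total for w in weights]

probs = [0.6, 0.3, 0.1]  # illustrative values only
for t in (0, 0.5, 1.0, 1.5):
    print(t, [round(q, 3) for q in apply_temperature(probs, t)])
```

A temperature of 0 puts essentially all the mass on the most likely token, 1.0 leaves the distribution almost unchanged, values below 1 sharpen it, and values above 1 flatten it.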

In [69]:
# Define a callback that prints sample poems at epochs 1, 2, 4, 8, ... and at the final epoch
end_epoch = 10

class PoetryGenerateCallback(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()

        self.next_print_epoch = 1

    def on_epoch_end(self, epoch, logs=None):
        epoch += 1
        if epoch != self.next_print_epoch and epoch != end_epoch:
            return

        print(f"Generating poems at epoch {epoch}:\n")
        self._print_generated_poems()
        self.next_print_epoch *= 2

    @staticmethod
    def _print_generated_poems():
        temperatures = [0, 0.5, 1.0, 1.5]
        generated_texts = [
            generate('海外', max_length=24, temperature=temp)
            for temp in temperatures
        ]

        for temp, text in zip(temperatures, generated_texts):
            print(f"temperature {temp}:{text}\n")
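
The callback prints on a doubling schedule plus the final epoch. A small simulation of the condition in `on_epoch_end` shows which epochs trigger output (`printing_epochs` is a hypothetical helper written only for this check):

```python
def printing_epochs(end_epoch):
    """Return the 1-based epochs at which the callback would print."""
    printed = []
    next_print_epoch = 1
    for epoch in range(1, end_epoch + 1):
        # Same test as in on_epoch_end: skip unless scheduled or final
        if epoch != next_print_epoch and epoch != end_epoch:
            continue
        printed.append(epoch)
        next_print_epoch *= 2
    return printed

print(printing_epochs(10))  # [1, 2, 4, 8, 10]
```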

In [70]:
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"]
)
model.fit(
    train_sequences,
    target_sequences,
    batch_size=64,
    epochs=end_epoch,  # keep in sync with the callback's final-epoch check
    callbacks=[PoetryGenerateCallback()],
    verbose=2
)

Generating poems at epoch 1:

temperature 0:海外春风起，山花一叶开。不知春色里，不见一枝花。

temperature 0.5:海外青江水，山舟雁已多。青风吹不尽，不得一枝红。

temperature 1.0:海外采嶂冷，濯旗苍草流。非织彫陈在，佛下春亭城。

temperature 1.5:海外甘污拙，计随曹荐琴。兼杰卢沟始，羞橙九月中。

579/579 - 156s - 270ms/step - accuracy: 0.2378 - loss: 4.9476


<keras.src.callbacks.history.History at 0x146f78440>

### 5. Use the model to generate poems

In [71]:
generate('海外')

'海外安来日，春风伴菜间。水茂宋鞭发，多久坐钓香。'