In [14]:
import numpy as np
import collections
import torch
from torch.autograd import Variable
import torch.optim as optim

import rnn

start_token = 'G'
end_token = 'E'
batch_size = 64

device = torch.device("cuda")
device

device(type='cuda')

In [15]:
def process_poems1(file_name):
    """

    :param file_name:
    :return: poems_vector  have tow dimmention ,first is the poem, the second is the word_index
    e.g. [[1,2,3,4,5,6,7,8,9,10],[9,6,3,8,5,2,7,4,1]]

    """
    poems = []
    with open(file_name, "r", encoding='utf-8', ) as f:
        for line in f.readlines():
            try:
                title, content = line.strip().split(':')
                # content = content.replace(' ', '').replace('，','').replace('。','')
                content = content.replace(' ', '')
                if '_' in content or '(' in content or '（' in content or '《' in content or '[' in content or \
                                start_token in content or end_token in content:
                    continue
                if len(content) < 5 or len(content) > 80:
                    continue
                content = start_token + content + end_token
                poems.append(content)
            except ValueError as e:
                print("error")
                pass
    # 按诗的字数排序
    poems = sorted(poems, key=lambda line: len(line))
    # print(poems)
    # 统计每个字出现次数
    all_words = []
    for poem in poems:
        all_words += [word for word in poem]
    counter = collections.Counter(all_words)  # 统计词和词频。
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])  # 排序
    words, _ = zip(*count_pairs)
    words = words[:len(words)] + (' ',)
    word_int_map = dict(zip(words, range(len(words))))
    poems_vector = [list(map(word_int_map.get, poem)) for poem in poems]
    return poems_vector, word_int_map, words

def process_poems2(file_name):
    """
    :param file_name:
    :return: poems_vector  have tow dimmention ,first is the poem, the second is the word_index
    e.g. [[1,2,3,4,5,6,7,8,9,10],[9,6,3,8,5,2,7,4,1]]

    """
    poems = []
    with open(file_name, "r", encoding='utf-8', ) as f:
        # content = ''
        for line in f.readlines():
            try:
                line = line.strip()
                if line:
                    content = line.replace(' '' ', '').replace('，','').replace('。','')
                    if '_' in content or '(' in content or '（' in content or '《' in content or '[' in content or \
                                    start_token in content or end_token in content:
                        continue
                    if len(content) < 5 or len(content) > 80:
                        continue
                    # print(content)
                    content = start_token + content + end_token
                    poems.append(content)
                    # content = ''
            except ValueError as e:
                # print("error")
                pass
    # 按诗的字数排序
    poems = sorted(poems, key=lambda line: len(line))
    # print(poems)
    # 统计每个字出现次数
    all_words = []
    for poem in poems:
        all_words += [word for word in poem]
    counter = collections.Counter(all_words)  # 统计词和词频。
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])  # 排序
    words, _ = zip(*count_pairs)
    words = words[:len(words)] + (' ',)
    word_int_map = dict(zip(words, range(len(words))))
    poems_vector = [list(map(word_int_map.get, poem)) for poem in poems]
    return poems_vector, word_int_map, words

In [16]:
def generate_batch(batch_size, poems_vec, word_to_int):
    n_chunk = len(poems_vec) // batch_size
    x_batches = []
    y_batches = []
    for i in range(n_chunk):
        start_index = i * batch_size
        end_index = start_index + batch_size
        x_data = poems_vec[start_index:end_index]
        y_data = []
        for row in x_data:
            y  = row[1:]
            y.append(row[-1])
            y_data.append(y)
        """
        x_data             y_data
        [6,2,4,6,9]       [2,4,6,9,9]
        [1,4,2,8,5]       [4,2,8,5,5]
        """
        # print(x_data[0])
        # print(y_data[0])
        # exit(0)
        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches

In [17]:
def run_training():
    # 处理数据集
    # poems_vector, word_to_int, vocabularies = process_poems2('./tangshi.txt')
    poems_vector, word_to_int, vocabularies = process_poems1('./poems.txt')
    # 生成batch
    print("finish  loadding data")
    BATCH_SIZE = 100

    torch.manual_seed(5)
    word_embedding = rnn.word_embedding( vocab_length= len(word_to_int) + 1 , embedding_dim= 100).to(device)
    rnn_model = rnn.RNN_model(
        batch_sz = BATCH_SIZE,
        vocab_len = len(word_to_int) + 1,
        word_embedding = word_embedding,
        embedding_dim= 100, 
        lstm_hidden_dim=128
    ).to(device)

    # optimizer = optim.Adam(rnn_model.parameters(), lr= 0.001)
    optimizer=optim.RMSprop(rnn_model.parameters(), lr=0.01)

    loss_fun = torch.nn.NLLLoss()
    # rnn_model.load_state_dict(torch.load('./poem_generator_rnn'))  # if you have already trained your model you can load it by this line.

    for epoch in range(30):
        batches_inputs, batches_outputs = generate_batch(BATCH_SIZE, poems_vector, word_to_int)
        n_chunk = len(batches_inputs)
        for batch in range(n_chunk):
            batch_x = batches_inputs[batch]
            batch_y = batches_outputs[batch] # (batch , time_step)

            loss = 0
            for index in range(BATCH_SIZE):
                x = np.array(batch_x[index], dtype = np.int64)
                y = np.array(batch_y[index], dtype = np.int64)
                x = Variable(torch.from_numpy(np.expand_dims(x,axis=1))).to(device)
                y = Variable(torch.from_numpy(y )).to(device)
                pre = rnn_model(x)
                loss += loss_fun(pre , y)
                if index == 0:
                    _, pre = torch.max(pre, dim=1)
                    print('prediction', pre.data.tolist()) # the following three lines can print the output and the prediction
                    print('b_y       ', y.data.tolist())   # And you need to take a screenshot and then past it to your homework paper.
                    print('*' * 30)
            loss  = loss  / BATCH_SIZE
            if batch % 50 == 0:
                print("epoch  ",epoch,'batch number',batch,"loss is: ", loss.data.tolist())
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(rnn_model.parameters(), 1)
            optimizer.step()

            if batch % 30 ==0:
                torch.save(rnn_model.state_dict(), './poem_generator_rnn')
                print("finish  save model")

In [18]:
def to_word(predict, vocabs):  # 预测的结果转化成汉字
    sample = np.argmax(predict)

    if sample >= len(vocabs):
        sample = len(vocabs) - 1

    return vocabs[sample]


def pretty_print_poem(poem):  # 令打印的结果更工整
    shige=[]
    for w in poem:
        if w == start_token or w == end_token:
            break
        shige.append(w)
    poem_sentences = poem.split('。')
    for s in poem_sentences:
        if s != '' and len(s) > 10:
            print(s + '。')


def gen_poem(begin_word):
    # poems_vector, word_int_map, vocabularies = process_poems2('./tangshi.txt')  #  use the other dataset to train the network
    poems_vector, word_int_map, vocabularies = process_poems1('./poems.txt')
    word_embedding = rnn.word_embedding(vocab_length=len(word_int_map) + 1, embedding_dim=100).to(device)
    rnn_model = rnn.RNN_model(batch_sz=64, vocab_len=len(word_int_map) + 1, word_embedding=word_embedding,
                                   embedding_dim=100, lstm_hidden_dim=128).to(device)

    rnn_model.load_state_dict(torch.load('./poem_generator_rnn', map_location=device, weights_only=True))

    # 指定开始的字

    poem = start_token+begin_word
    word = begin_word
    while word != end_token:
        input = np.array([word_int_map[w] for w in poem],dtype= np.int64)
        input = Variable(torch.from_numpy(input)).to(device)
        output = rnn_model(input, is_test=True)
        word = to_word(output.data.tolist()[-1], vocabularies)
        poem += word
        # print(word)
        # print(poem)
        if len(poem) > 100:
            break
    return poem[1:]

In [19]:
run_training()  # 如果不是训练阶段 ，请注销这一行 。 网络训练时间很长。

finish  loadding data
inital  linear weight 
prediction [1847, 2057, 2057, 2237, 5930, 3236, 3236]
b_y        [28, 546, 104, 718, 1, 3, 3]
******************************
epoch   0 batch number 0 loss is:  8.717957496643066
finish  save model
prediction [1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
b_y        [267, 2888, 267, 2045, 0, 26, 1067, 26, 349, 1, 3, 3]
******************************


  torch.nn.utils.clip_grad_norm(rnn_model.parameters(), 1)


prediction [950, 66, 66, 99, 68, 99, 99, 99, 66, 99, 99, 99, 25, 69]
b_y        [617, 564, 80, 601, 132, 0, 9, 184, 2007, 373, 305, 1, 3, 3]
******************************
prediction [215, 268, 3344, 3344, 3344, 3344, 3344, 3344, 3344, 3344, 3344, 3344, 3344, 3344, 3344, 3344, 3344, 3344]
b_y        [125, 9, 3862, 695, 16, 1721, 1721, 0, 373, 591, 246, 695, 900, 80, 431, 1, 3, 3]
******************************
prediction [99, 99, 99, 99, 99, 99, 3, 3, 257, 257, 57, 15, 15, 15, 99, 808, 99, 99, 99, 80, 0, 3]
b_y        [1539, 1058, 3996, 64, 0, 793, 977, 550, 381, 1, 3867, 3285, 972, 1287, 0, 1541, 1541, 1779, 1779, 1, 3, 3]
******************************
prediction [47, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 3]
b_y        [194, 348, 66, 89, 73, 0, 574, 21, 855, 101, 1013, 1, 293, 162, 1283, 117, 383, 0, 25, 74, 4, 253, 611, 1, 3, 3]
******************************
prediction [79, 23, 183, 183, 183, 82, 82, 82, 82, 183, 82, 82, 183, 183, 82, 82, 183, 82, 

In [20]:
pretty_print_poem(gen_poem("日"))
pretty_print_poem(gen_poem("红"))
pretty_print_poem(gen_poem("山"))
pretty_print_poem(gen_poem("夜"))
pretty_print_poem(gen_poem("湖"))
pretty_print_poem(gen_poem("海"))
pretty_print_poem(gen_poem("月"))
pretty_print_poem(gen_poem("君"))

inital  linear weight 
日暮清明月，江山望月明。
云山千里外，山色一庭秋。
日月连天雨，春风入洞庭。
云山千里外，山色一庭秋。
不见东风起，应知不见人。
何当一片石，不是有清风。
inital  linear weight 
inital  linear weight 
山川不可见，不见白云归。
一片云中在，东风吹不开。
云山千里外，山色一庭秋。
日月连天雨，春风入洞庭。
山川如有意，山水更何如。
inital  linear weight 
夜雨无人见，春风不可闻。
不知青草里，不见白云归。
日月连天雨，春风入夜闻。
云山千里外，山色一庭春。
日月无人识，山川不可闻。
何当一片石，不是在南州。
inital  linear weight 
湖上春风起，春风入洞天。
云山千里外，山色万人归。
日月连天雨，春风入洞庭。
云山千里外，山色一庭春。
日月连天外，春风吹水流。
何当一片石，不是在南山。
inital  linear weight 
海上东山下，春风入洞庭。
云山千里外，山色一庭秋。
日月连天雨，春风入洞庭。
云山千里外，山色一庭秋。
不见东风起，应知不见人。
何当一片石，不是有清风。
inital  linear weight 
inital  linear weight 
君王不可见，不复更相违。
云里山川外，山深洞里秋。
云山千里外，山色一庭秋。
日月连天地，春风入夜闻。
何当一片石，一片一相逢。


In [21]:
def check_poem_format(file_name):
    """
    检查文件中每一行是否符合“标题:诗句”的格式，
    并输出有问题的行号和原因。
    """
    forbidden_chars = ["_", "(", "（", "《", "[", "G", "E"]  # 可根据需要调整

    with open(file_name, "r", encoding="utf-8") as f:
        line_number = 0
        error_found = False
        for line in f:
            line_number += 1
            line = line.strip()
            if not line:  # 跳过空行
                continue

            # 检查是否有且仅有一个冒号
            parts = line.split(":")
            if len(parts) != 2:
                print(f"Line {line_number}: 格式错误，冒号数量不为1 -> {line}")
                error_found = True
                continue

            title, content = parts
            if not title:
                print(f"Line {line_number}: 标题为空 -> {line}")
                error_found = True
            if not content:
                print(f"Line {line_number}: 内容为空 -> {line}")
                error_found = True

        if not error_found:
            print("所有行格式均符合要求！")

# 示例调用：
check_poem_format("poems.txt")

所有行格式均符合要求！
