In [None]:
# Path to the poetry corpus file
poetry_data_path = "./data/poetry.txt"
# A poem is discarded if it contains any of these tokens
DISALLOWED_WORDS = ['（', '）', '(', ')', '__', '《', '》', '【', '】', '[', ']']
# Vocabulary size: 3000 characters are used for generation, the space placeholder included
WORD_NUM = 3000
# Rarely seen characters are replaced by this space placeholder
UNKONW_CHAR = " "
# Context length: the next character is predicted from the previous 6,
# e.g. "寒随穷律变，" is used to predict "春"
TRAIN_NUM = 6

In [None]:
# Poems that pass every filter below
poetrys = []
# Every character of every kept poem, in reading order (input to the frequency count)
all_word = []

with open(poetry_data_path, encoding="utf-8") as f:
    for line in f:
        # The poem body is the field right after the first ":" (presumably the
        # lines look like "title:body"). Lines without a colon, e.g. blank or
        # malformed ones, are skipped instead of raising IndexError.
        fields = line.split(":")
        if len(fields) < 2:
            continue
        poetry = fields[1].replace(" ", "")

        # Discard the poem if it contains any token listed in DISALLOWED_WORDS.
        if any(dis_word in poetry for dis_word in DISALLOWED_WORDS):
            continue

        # Keep only five-character verse: at least two clauses (12 characters
        # including punctuation), a full-width comma in the 6th position, and a
        # length of 6*k + 1 (the extra 1 is presumably the trailing newline
        # read from the file).
        if len(poetry) < 12 or poetry[5] != '，' or (len(poetry) - 1) % 6 != 0:
            continue

        # Record the characters for the frequency statistics and keep the poem.
        all_word.extend(poetry)
        poetrys.append(poetry)

In [None]:
    # Report how many poems were kept and how many characters they contain in total.
    print("一共有：{}首诗，一共有{}个字符".format(len(poetrys),len(all_word)))

In [None]:
from collections import Counter
# Tally how often each character occurs in the kept poems.
counter = Counter(all_word)
# (character, count) pairs ordered from most to least frequent; ties keep
# the order in which the characters were first seen.
word_count = counter.most_common()
most_num_word, _ = zip(*word_count)
# Vocabulary: the WORD_NUM - 1 most frequent characters, followed by the
# placeholder character in the last slot.
use_words = (*most_num_word[:WORD_NUM - 1], UNKONW_CHAR)

In [None]:
# Show the last 20 vocabulary entries (the final one is the placeholder character).
print(use_words[-20:])

In [None]:
# Character -> id lookup, e.g. {'，': 0, '。': 1, '\n': 2, '不': 3, '人': 4, '山': 5, ...}
word_id_dict = dict(zip(use_words, range(len(use_words))))

# Id -> character lookup, e.g. {0: '，', 1: '。', 2: '\n', 3: '不', 4: '人', 5: '山', ...}
id_word_dict = dict(enumerate(use_words))

In [None]:
# Show the first 10 entries of each lookup table.
print(list(word_id_dict.items())[0:10])
print(list(id_word_dict.items())[0:10])

In [None]:
import numpy as np
def word_to_one_hot(word):
    """Encode a single character as a one-hot vector.

    A character that is not a key of ``word_id_dict`` is encoded as
    ``UNKONW_CHAR`` instead.

    :param word: the character to encode
    :type word: str
    :return: array of length ``WORD_NUM`` that is zero everywhere except
        for a 1 at the character's id
    :rtype: numpy.ndarray
    """
    encoded = np.zeros(WORD_NUM)
    # Out-of-vocabulary characters share the placeholder's slot.
    lookup_key = word if word in word_id_dict else UNKONW_CHAR
    encoded[word_id_dict[lookup_key]] = 1
    return encoded

def phrase_to_one_hot(phrase):
    """Encode every character of a phrase as a one-hot vector.

    :param phrase: the phrase to encode
    :type phrase: str
    :return: one ``word_to_one_hot`` result per character, in phrase order
    :rtype: list
    """
    return [word_to_one_hot(character) for character in phrase]

In [None]:
# Sanity check: display the one-hot vector produced for the full-width comma.
word_to_one_hot("，")

In [None]:
# Sanity check: display the list of one-hot vectors for a two-character phrase.
phrase_to_one_hot("，。")

In [None]:
# Shuffle the kept poems in place.
# NOTE(review): no random seed is set in the visible cells, so the order
# (and therefore the training-sample order) differs from run to run.
np.random.shuffle(poetrys)

In [None]:
# Sliding-window training pairs: each X is TRAIN_NUM consecutive characters of
# a poem and the matching Y is the single character that follows them.
X_train_word = []
Y_train_word = []

for poetry in poetrys:
    # Only start a window where a following character exists, so that
    # poetry[i+TRAIN_NUM] can never run past the end of the poem (previously
    # only the newline check below kept this from raising IndexError).
    for i in range(len(poetry) - TRAIN_NUM):
        X = poetry[i:i+TRAIN_NUM]
        Y = poetry[i+TRAIN_NUM]
        # A newline in the window or the target ends this poem's samples.
        if "\n" in X or "\n" in Y:
            break
        X_train_word.append(X)
        Y_train_word.append(Y)

In [None]:
# Number of training windows that were collected.
len(X_train_word)

In [None]:
# First five input windows (TRAIN_NUM characters each).
X_train_word[:5]

In [None]:
# The target characters that follow the first five input windows.
Y_train_word[:5]

In [None]:
import keras
from keras.callbacks import LambdaCallback,ModelCheckpoint
from keras.models import Input, Model
from keras.layers import  Dropout, Dense,SimpleRNN 
from keras.optimizers import Adam
from keras.utils import plot_model

def build_model():
    """Build and compile the two-layer SimpleRNN next-character model.

    The input is a window of ``TRAIN_NUM`` vectors of width ``WORD_NUM``; the
    output is a softmax over ``WORD_NUM`` entries. A notice and the model
    summary are printed as a side effect.

    :return: the compiled model
    :rtype: keras.models.Model
    """
    print('building model')
    # One WORD_NUM-wide vector per context position.
    inputs = Input(shape=(TRAIN_NUM, WORD_NUM))

    # First recurrent layer returns its whole output sequence so that the
    # second recurrent layer can consume it.
    hidden = SimpleRNN(512, return_sequences=True)(inputs)
    hidden = Dropout(0.6)(hidden)

    hidden = SimpleRNN(256)(hidden)
    hidden = Dropout(0.6)(hidden)

    # One softmax score per vocabulary entry.
    outputs = Dense(WORD_NUM, activation='softmax')(hidden)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(lr=0.001),
        metrics=['accuracy'],
    )
    model.summary()
    # Draw the model diagram (disabled)
    # plot_model(model, to_file='model.png', show_shapes=True, expand_nested=True, dpi=500)
    return model

In [None]:
# Build the network; build_model() also prints the layer summary.
model = build_model()

In [None]:
import math
def get_batch(batch_size = 32):
    """Endlessly yield one-hot encoded training batches.

    The samples in ``X_train_word`` / ``Y_train_word`` are walked in order,
    ``batch_size`` at a time; after the last (possibly smaller) batch the
    generator starts over from the first sample.

    :param batch_size: number of samples per batch, defaults to 32
    :type batch_size: int, optional
    :raises ValueError: if ``batch_size`` is not positive or there are no
        training samples (the loop below would otherwise spin forever
        without yielding anything)
    :yield: ``(X, Y)`` where ``X`` stacks ``phrase_to_one_hot`` of each input
        window and ``Y`` stacks ``word_to_one_hot`` of each target character
    :rtype: tuple of numpy.ndarray
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got {!r}".format(batch_size))
    if not X_train_word:
        raise ValueError("no training samples available: X_train_word is empty")

    # Number of batches that make up one pass over the data.
    steps = math.ceil(len(X_train_word) / batch_size)
    while True:
        for i in range(steps):
            start = i * batch_size
            stop = start + batch_size
            X_train_batch = []
            Y_train_batch = []
            for x, y in zip(X_train_word[start:stop], Y_train_word[start:stop]):
                X_train_batch.append(phrase_to_one_hot(x))
                Y_train_batch.append(word_to_one_hot(y))
            yield np.array(X_train_batch), np.array(Y_train_batch)

In [None]:
def predict_next(x):
    """Return the id of the most probable next character for one input window.

    :param x: model input; the caller in this file passes an array reshaped
        to ``(1, TRAIN_NUM, WORD_NUM)``
    :type x: numpy.ndarray
    :return: index of the largest value in the first row of
        ``model.predict(x)``; this may be the id of the placeholder character
    :rtype: int
    """
    # Only the first (and, for a single window, only) prediction row is used.
    return np.argmax(model.predict(x)[0])

def generate_sample_result(epoch, logs):
    """Generate a sample five-character poem and append it to ``out/out.txt``.

    Starting from a fixed six-character seed clause, the model is asked for
    the next character repeatedly until the text is 24 characters long
    (4 clauses of 6 characters, punctuation included). Runs on every call,
    i.e. after every epoch when used as an ``on_epoch_end`` callback.

    :param epoch: index of the epoch that just finished (unused)
    :type epoch: int
    :param logs: training logs passed in by the callback (unused)
    """
    # Local import so the function does not rely on a notebook-level import.
    import os

    predict_sen = "一朝春夏改，"
    # Sliding context window fed to the model; always TRAIN_NUM characters.
    predict_data = predict_sen
    while len(predict_sen) < 24:
        X_data = np.array(phrase_to_one_hot(predict_data)).reshape(1, TRAIN_NUM, WORD_NUM)
        # Predict the character that follows the current window.
        next_char = id_word_dict[predict_next(X_data)]
        predict_sen = predict_sen + next_char
        # Slide the window: "寒随穷律变，" -> "随穷律变，春"
        predict_data = predict_data[1:] + next_char

    # Append the generated poem to the output file. The original wrote an
    # undefined name (write_data), which raised NameError; the directory is
    # created on demand so the first write does not fail either.
    os.makedirs('out', exist_ok=True)
    with open('out/out.txt', 'a', encoding='utf-8') as f:
        f.write(predict_sen + '\n')

In [None]:
batch_size = 2048
# Batches per epoch: one pass over X_train_word, final partial batch included.
steps_per_epoch = math.ceil(len(X_train_word) / batch_size)
training_callbacks = [
    ModelCheckpoint("poetry_model.hdf5", verbose=1, monitor='val_loss', period=1),
    # generate_sample_result is invoked at the end of every epoch to write a sample poem
    LambdaCallback(on_epoch_end=generate_sample_result),
]
model.fit_generator(
    generator=get_batch(batch_size),
    verbose=True,
    steps_per_epoch=steps_per_epoch,
    epochs=30,
    callbacks=training_callbacks,
)