In [80]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import numpy as np
import os
import time


来自官方教程： https://www.tensorflow.org/tutorials/text/text_generation?hl=en

与官方教程的差别：训练数据换成了中文诗歌语料（data/poetry.txt）。



In [81]:
# Configuration constants shared by the cells below.
# Length of each training chunk cut from the token stream (input and target are
# each SEQ_LEN - 1 long after the one-step shift); also the default number of
# tokens sampled by NLG_RNN_TF.generate_text.
SEQ_LEN = 13
# Number of sequences per training batch.
BATCH_SIZE = 64
# Size of the tf.data shuffle buffer.
BUFFER_SIZE = 10000
# Vocabulary size handed to the data loader and used to size the embedding and output layers.
VOCAB_SIZE = 5000
# Width of the embedding vectors.
EMBEDDING_DIM = 256
# Number of units in the GRU layer.
RNN_UNITS = 1024
# NOTE(review): not referenced by the visible cells (_test_train passes epochs=1).
EPOCHS = 10

In [82]:
class NLG_RNN_TF(object):
    """Token-level text generator: Embedding -> stateful GRU -> Dense logits.

    The Keras model is created lazily by ``init_model`` for a fixed batch
    size. ``train`` builds it with the training batch size, while
    ``load_model`` rebuilds it with batch size 1 and restores the latest
    checkpoint. ``generate_text`` feeds a batch of one, so it is meant to be
    used after ``load_model`` (a stateful model built for a different batch
    size is expected to reject that input).
    """

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        """Store the layer sizes; no TensorFlow objects are created here.

        Args:
            vocab_size: Number of distinct token ids (embedding rows and
                output logits).
            embedding_dim: Width of the embedding vectors.
            rnn_units: Number of units in the GRU layer.
        """
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.rnn_units = rnn_units
        # Populated by init_model(), which train() and load_model() call.
        self.model = None

    def init_model(self, batch_size):
        """(Re)create ``self.model`` for the given fixed batch size.

        The GRU is stateful, so the batch dimension is fixed through
        ``batch_input_shape``; the sequence dimension is left open.

        Args:
            batch_size: Batch dimension baked into the model's input shape.
        """
        self.model = tf.keras.Sequential([
            tf.keras.layers.Embedding(self.vocab_size, self.embedding_dim,
                                      batch_input_shape=[batch_size, None]),
            tf.keras.layers.GRU(self.rnn_units,
                                return_sequences=True,
                                stateful=True,
                                recurrent_initializer='glorot_uniform'),
            tf.keras.layers.Dense(self.vocab_size)
        ])

    def train(self, batch_size, dataset, epochs =  1, checkpoint_dir = './rnn_tf_checkpoint'):
        """Build a fresh model and fit it on ``dataset``.

        Weights-only checkpoints are written through a ``ModelCheckpoint``
        callback using the file pattern ``ckpt_{epoch}`` inside
        ``checkpoint_dir``.

        Args:
            batch_size: Batch size the model is built for; it should match
                the batch size of ``dataset``.
            dataset: Dataset passed straight to ``model.fit``.
            epochs: Number of epochs to train.
            checkpoint_dir: Directory that receives the checkpoints.

        Returns:
            The object returned by ``model.fit`` (the training history).
        """
        self.init_model(batch_size)

        def loss(labels, logits):
            # The Dense layer emits raw logits, hence from_logits=True.
            return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

        self.model.compile(optimizer='adam', loss=loss)

        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(checkpoint_dir, "ckpt_{epoch}"),
            save_weights_only=True)

        history = self.model.fit(dataset, epochs=epochs, callbacks=[checkpoint_callback])
        return history

    def load_model(self, checkpoint_dir = './rnn_tf_checkpoint' ):
        """Rebuild the model for batch size 1 and restore the latest checkpoint.

        Args:
            checkpoint_dir: Directory searched for the most recent checkpoint.

        Raises:
            IOError: If no checkpoint is found in ``checkpoint_dir``. The
                current ``self.model`` is left untouched in that case.
        """
        # Resolve the checkpoint first: latest_checkpoint() yields None when
        # nothing is found, and failing here avoids leaving behind a freshly
        # initialised (untrained) model.
        latest = tf.train.latest_checkpoint(checkpoint_dir)
        if latest is None:
            raise IOError(
                "no checkpoint found in {!r}; run train() first".format(checkpoint_dir))

        self.init_model(1)
        self.model.load_weights(latest)
        self.model.build(tf.TensorShape([1, None]))
        self.model.summary()

    def generate_text(self, start_string, seq_len = SEQ_LEN, word_to_index = None, index_to_word = None,
                      temperature = 1.0):
        """Sample ``seq_len`` tokens that continue ``start_string``.

        Args:
            start_string: Non-empty seed text; every element must be a key of
                ``word_to_index``.
            seq_len: Number of tokens to sample.
            word_to_index: Mapping from token to integer id (required).
            index_to_word: Mapping from integer id to token (required).
            temperature: Positive divisor applied to the logits before
                sampling. Lower values give more predictable text, higher
                values more surprising text.

        Returns:
            ``start_string`` followed by the sampled tokens joined together.

        Raises:
            ValueError: If a vocabulary mapping is missing, ``start_string``
                is empty, or ``temperature`` is not positive.
            RuntimeError: If no model has been built or loaded yet.
            KeyError: If ``start_string`` contains a token that is not in
                ``word_to_index``.
        """
        if word_to_index is None or index_to_word is None:
            raise ValueError("word_to_index and index_to_word are both required")
        if not start_string:
            raise ValueError("start_string must contain at least one token")
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        if self.model is None:
            raise RuntimeError("model is not initialised; call load_model() first")

        try:
            seed_ids = [word_to_index[token] for token in start_string]
        except KeyError as err:
            raise KeyError(
                "start_string contains a token missing from word_to_index: {}".format(err))

        # Add the batch dimension: shape (1, len(start_string)).
        input_eval = tf.expand_dims(seed_ids, 0)
        generated = []

        # Start from a clean recurrent state for every generated sample.
        self.model.reset_states()
        for _ in range(seq_len):
            logits = self.model(input_eval)
            # Remove the batch dimension.
            logits = tf.squeeze(logits, 0)

            # Sample from the categorical distribution given by the scaled
            # logits and keep the draw for the last time step.
            logits = logits / temperature
            predicted_id = tf.random.categorical(logits, num_samples=1)[-1, 0].numpy()

            # Only the new token is fed back; the stateful GRU carries the
            # context from the previous steps.
            input_eval = tf.expand_dims([predicted_id], 0)

            generated.append(index_to_word[predicted_id])

        return start_string + ''.join(generated)

In [83]:
def prepare_data():
    """Build the shuffled, batched (input, target) training dataset.

    Loads "data/poetry.txt" through TextDataLoader with a vocabulary of
    VOCAB_SIZE, flattens ``loader.data`` into one index stream, cuts it into
    chunks of SEQ_LEN, and turns each chunk into an input/target pair where
    the target is the input shifted by one position. Pairs are shuffled with
    a buffer of BUFFER_SIZE and grouped into batches of BATCH_SIZE; incomplete
    chunks and batches are dropped.

    Returns:
        A tf.data.Dataset yielding (input, target) batches.
    """
    from data import TextDataLoader

    corpus = TextDataLoader()
    corpus.load_data("data/poetry.txt", VOCAB_SIZE)

    # corpus.data is iterated as a sequence of sequences (presumably one list
    # of word indices per line -- confirm against TextDataLoader).
    token_stream = np.array([token for row in corpus.data for token in row])

    def to_input_and_target(window):
        # Input drops the last element, target drops the first.
        return window[:-1], window[1:]

    windows = tf.data.Dataset.from_tensor_slices(token_stream).batch(SEQ_LEN, drop_remainder=True)
    return (
        windows
        .map(to_input_and_target)
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE, drop_remainder=True)
    )
    
    

In [84]:
def _test_train():
    """Smoke test: train a fresh model for one epoch on the prepared dataset."""
    training_set = prepare_data()

    generator = NLG_RNN_TF(
        vocab_size=VOCAB_SIZE,
        embedding_dim=EMBEDDING_DIM,
        rnn_units=RNN_UNITS,
    )
    generator.train(batch_size=BATCH_SIZE, dataset=training_set, epochs=1)
    
def _test_gen():
    """Smoke test: restore the latest checkpoint and print ten sampled texts.

    Every sample is seeded with "春" and generated with seq_len=SEQ_LEN, using
    the vocabulary mappings from a TextDataLoader loaded on the same corpus.
    """
    from data import TextDataLoader

    vocab = TextDataLoader()
    vocab.load_data("data/poetry.txt", VOCAB_SIZE)

    generator = NLG_RNN_TF(VOCAB_SIZE, EMBEDDING_DIM, RNN_UNITS)
    generator.load_model()

    sampling_kwargs = dict(
        seq_len=SEQ_LEN,
        word_to_index=vocab.word_to_index,
        index_to_word=vocab.index_to_word,
    )
    for _ in range(10):
        print(generator.generate_text("春", **sampling_kwargs))
    

In [85]:
if __name__ == "__main__":
    # Training is disabled here; uncomment the next line to train for one
    # epoch and write checkpoints before generating.
    #_test_train()
    # Restore the latest checkpoint and print sampled text.
    _test_gen()

























































































































































Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_9 (Embedding)      (1, None, 256)            1280000   
_________________________________________________________________
gru_9 (GRU)                  (1, None, 1024)           3938304   
_________________________________________________________________
dense_9 (Dense)              (1, None, 5000)           5125000   
Total params: 10,343,304
Trainable params: 10,343,304
Non-trainable params: 0
_________________________________________________________________
春界披锦，汗漫烧。
相将旧咸
春簴马，非尔怯所贤。
故人复
春灵滤池塘。
何须成贾傅，怪
春乌握栏巾起，光华露气呵。

春初暖，窗外月犹重，南风燕一
春砝红尘。
闻浑吾山上，，此
春色临坛。
瓶沽孤片月，助息
春香虽甘食野累自常。
何曾欢
春鸟风来晓蝶飞。
小里春山里
春蕊文兰带开垂玉轴，落照玉皇
