In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [2]:
# Download the Shakespeare text corpus (keras.utils.get_file reuses a cached copy if present).
shakespeare_url = 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt'
file_path = keras.utils.get_file('shakespeare.txt', shakespeare_url)

In [5]:
# Read the raw bytes and decode them explicitly as UTF-8 (binary mode keeps any
# line endings exactly as stored). The context manager guarantees the file
# handle is closed instead of being leaked until garbage collection.
with open(file_path, 'rb') as corpus_file:
    text = corpus_file.read().decode(encoding='utf-8')
len(text)

1115394

In [12]:
# The corpus has over a million characters; show only the first 1000.
preview = text[:1000]
print(preview)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [14]:
# Vocabulary: every distinct character in the corpus, in sorted order.
unique_chars = set(text)
vocab = sorted(unique_chars)
print(vocab)
len(vocab)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


65

In [15]:
# Forward lookup: character -> its position in the sorted vocabulary.
char2index = dict(zip(vocab, range(len(vocab))))
# Inverse lookup: index -> character, as an array so it can be fancy-indexed.
index2char = np.array(vocab)

# Encode the whole corpus as a sequence of integer character ids.
text_as_int = np.array([char2index[ch] for ch in text])
text_as_int

array([18, 47, 56, ..., 45,  8,  0])

### 训练样本的创建

将整个文本拆分成 `seq_length+1` 个字符的文本块，每个输入序列和目标序列均有 `seq_length` 个字符。

目标序列为输入序列向右顺移一个字符。

例如，文本块为 `"Hello"` 时：

- 输入序列：`"Hell"`
- 目标序列：`"ello"`

In [16]:
seq_length = 100

# Group the id stream into chunks of seq_length + 1 ids: seq_length input ids
# plus one extra id so the target can be the input shifted by one position.
# The trailing partial chunk is dropped.
text_as_int_dataset = (
    tf.data.Dataset.from_tensor_slices(text_as_int)
    .batch(seq_length + 1, drop_remainder=True)
)
text_as_int_dataset

<BatchDataset shapes: (101,), types: tf.int64>

In [17]:
def split_input_target(chunk):
    """Split one chunk into an (input, target) pair offset by one position.

    The input is every element except the last; the target is every element
    except the first, so target[i] is the element that follows input[i].

    Example: "Hello" -> ("Hell", "ello")
    """
    input_seq = chunk[:-1]
    target_seq = chunk[1:]
    return input_seq, target_seq

# Turn every (seq_length + 1)-id chunk into an (input, target) pair of seq_length ids each.
dataset = text_as_int_dataset.map(map_func=split_input_target)
dataset

<MapDataset shapes: ((100,), (100,)), types: (tf.int64, tf.int64)>

In [18]:
# Shuffle buffer size and training batch size. BATCH_SIZE must match the fixed
# batch dimension (64) that the stateful model in the next cell is built with.
BUFFER_SIZE = 10000
BATCH_SIZE = 64

# Rebuild from the un-batched chunk dataset instead of reassigning `dataset`
# from itself: the old `dataset = dataset.shuffle(...).batch(...)` batched a
# second time whenever the cell was re-run, producing (64, 64, 100) elements.
dataset = (
    text_as_int_dataset
    .map(split_input_target)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)
dataset

<BatchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>

In [19]:
embedding_dimension = 256  # size of each character embedding vector
rnn_units = 1024           # number of GRU units

vocab_size = len(vocab)

model = keras.Sequential()
# The batch dimension is fixed at 64 because the GRU is stateful, and stateful
# RNNs need a known batch size. Keep this in sync with the dataset batch size.
model.add(keras.layers.Embedding(vocab_size, embedding_dimension,
                                 batch_input_shape=[64, None]))
model.add(keras.layers.GRU(rnn_units,
                           return_sequences=True,
                           stateful=True,
                           recurrent_initializer='glorot_uniform'))
# No activation: the layer emits one raw logit per vocabulary entry.
model.add(keras.layers.Dense(vocab_size))

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (64, None, 256)           16640     
_________________________________________________________________
gru (GRU)                    (64, None, 1024)          3938304   
_________________________________________________________________
dense (Dense)                (64, None, 65)            66625     
Total params: 4,021,569
Trainable params: 4,021,569
Non-trainable params: 0
_________________________________________________________________


![text generation RNN](text_generation_training.png)

In [20]:
def loss(labels, logits):
    """Sparse categorical cross-entropy between integer labels and raw logits.

    The model's final Dense layer has no softmax, so its outputs are
    unnormalized logits; ``from_logits=True`` tells Keras to treat them as such.
    """
    return keras.losses.sparse_categorical_crossentropy(
        y_true=labels, y_pred=logits, from_logits=True)

# Adam optimizer with the cross-entropy-on-logits loss function `loss`.
model.compile(loss=loss, optimizer='adam')

In [21]:
# Number of full passes over the batched training dataset.
EPOCHS = 10
history = model.fit(dataset, epochs=EPOCHS)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [22]:
# Save the trained model; the .h5 suffix selects the HDF5 format in tf.keras.
saved_model_path = '4-14-00-35.h5'
model.save(saved_model_path)