In [60]:
# Generating Shakespeare Text Using a Character RNN

import numpy as np
import tensorflow as tf
from tensorflow import keras

# step 1: load the data
shakespeare_url = "https://homl.info/shakespeare" # shortcut URL
filepath = keras.utils.get_file("shakespeare.txt", shakespeare_url)
with open(filepath) as f:
    shakespeare_text = f.read()

# step 2: tokenization
tokenizer = keras.preprocessing.text.Tokenizer(char_level=True)
#tokenizer.fit_on_texts(shakespeare_text) old line with text input as a list.
tokenizer.fit_on_texts(shakespeare_text)

In [None]:
max_id = len(tokenizer.word_index) # number of distinct characters
dataset_size = tokenizer.document_count # total number of characters

print(max_id, dataset_size)

In [None]:
[encoded] = np.array(tokenizer.texts_to_sequences([shakespeare_text])) - 1
train_size = dataset_size * 90 // 100
dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])
n_steps = 100
window_length = n_steps + 1 # target = input shifted 1 character ahead
dataset = dataset.window(window_length, shift=1, drop_remainder=True)

dataset = dataset.flat_map(lambda window: window.batch(window_length))
batch_size = 32
dataset = dataset.shuffle(10000).batch(batch_size)
dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))

dataset = dataset.map(
    lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))
dataset = dataset.prefetch(1)

In [59]:
model = keras.models.Sequential([
    keras.layers.GRU(128, return_sequences=True, input_shape=[None, max_id],
                     dropout=0.2, recurrent_dropout=0.2),
    keras.layers.GRU(128, return_sequences=True,
                     dropout=0.2, recurrent_dropout=0.2),
    keras.layers.TimeDistributed(keras.layers.Dense(max_id,
                                                    activation="softmax"))
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")

history = model.fit(dataset,steps_per_epoch=train_size // batch_size, epochs=3)

Epoch 1/3
   20/31370 [..............................] - ETA: 1:28:33 - loss: 3.5134 

KeyboardInterrupt: 

<FlatMapDataset shapes: (None,), types: tf.int64>
