In [1]:
from tensorflow.keras.utils import get_file

# Fetch the Shakespeare corpus; `file` holds the value returned by get_file,
# which the next cell passes straight to open().
file = get_file(
    fname='shakespeare.txt',
    origin='https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt',
)

In [2]:
from tensorflow import strings

# Read the whole corpus as bytes and decode it explicitly as UTF-8.
# The context manager guarantees the file handle is closed once the text is in memory.
with open(file, 'rb') as corpus_file:
    text = corpus_file.read().decode(encoding='UTF-8')

# Every distinct character in the corpus, in sorted order (sorted() already returns a list).
vocabulary = sorted(set(text))

# The corpus split into individual UTF-8 characters.
chars = strings.unicode_split(text, input_encoding='UTF-8')

In [3]:
from tensorflow.keras.layers.experimental.preprocessing import StringLookup

# Forward lookup: character -> integer id, with no mask token reserved.
ids_from_chars = StringLookup(
    vocabulary=vocabulary,
    mask_token=None,
)
ids = ids_from_chars(chars)

# Inverse lookup: integer id -> character. It is built from the forward layer's
# own vocabulary so both directions agree on the id assignment.
chars_from_ids = StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(),
    invert=True,
    mask_token=None,
)

# Round trip: `chars` is rebound to the characters recovered from `ids`.
chars = chars_from_ids(ids)

def text_from_ids(ids):
    """Map ids back to characters and join them along the last axis.

    Args:
        ids: tensor of ids understood by `chars_from_ids`.

    Returns:
        The result of `strings.reduce_join` over the looked-up characters,
        i.e. the characters of the last axis concatenated together.
    """
    return strings.reduce_join(chars_from_ids(ids), axis=-1)
chars

<tf.Tensor: shape=(1115394,), dtype=string, numpy=array([b'F', b'i', b'r', ..., b'g', b'.', b'\n'], dtype=object)>

In [4]:
from tensorflow import data

# Number of characters in each input (and each target) sequence.
SEQUENCE_LENGTH = 100
# NOTE(review): despite its name this is the *chunk length* handed to the first
# `.batch()` call below: one character longer than SEQUENCE_LENGTH so that the
# input and the target can be offset by one. The training batch size is the
# literal 64 used in the second `.batch()` call.
BATCH_SIZE = SEQUENCE_LENGTH + 1
# Number of whole chunks that fit in the corpus; not referenced by the visible cells.
EXAMPLES_PER_EPOCH = len(text)//BATCH_SIZE

def split_dataset(sequence):
    """Turn one sequence into an (input, target) pair offset by one element.

    Args:
        sequence: any sliceable sequence.

    Returns:
        A tuple `(sequence[:-1], sequence[1:])`: everything but the last
        element, and everything but the first element.
    """
    return sequence[:-1], sequence[1:]

# One dataset element per character id.
raw_dataset = data.Dataset.from_tensor_slices(ids)

# Group the id stream into fixed-length chunks (dropping the short tail) and
# turn each chunk into an (input, target) pair.
sequences = (
    raw_dataset
    .batch(BATCH_SIZE, drop_remainder=True)
    .map(split_dataset)
)

# Shuffle the pairs, group them into training batches of 64 and prefetch.
datasets = (
    sequences
    .shuffle(10_000)
    .batch(64, drop_remainder=True)
    .prefetch(data.AUTOTUNE)
)
datasets

<PrefetchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>

In [5]:
# Show the first (input, target) pair of one batch to confirm the one-character offset.
for input_text, output_text in datasets.take(1):
    first_input = text_from_ids(input_text[0]).numpy()
    first_output = text_from_ids(output_text[0]).numpy()
    print(f"""
        input: {first_input}
        vs
        output: {first_output}
    """)


        input: b'ot in the giving vein to-day.\n\nBUCKINGHAM:\nWhy, then resolve me whether you will or no.\n\nKING RICHAR'
        vs
        output: b't in the giving vein to-day.\n\nBUCKINGHAM:\nWhy, then resolve me whether you will or no.\n\nKING RICHARD'
    


In [6]:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Embedding, GRU

class MyModel(Model):
    """Character-level language model: Embedding -> GRU -> Dense.

    The final dense layer is created without an activation, so the model
    emits raw scores (logits), one per vocabulary entry, for every position
    of the sequence returned by the GRU.
    """

    def __init__(self, vocabulary_size, embedding_dimension=256, rnn_units=1024):
        """Create the embedding, recurrent and output layers.

        Args:
            vocabulary_size: number of distinct ids; used both as the
                embedding input size and as the width of the output layer.
            embedding_dimension: size of each embedding vector.
            rnn_units: number of units in the GRU layer.
        """
        # super() already binds the instance; passing `self` again would hand
        # it to Model.__init__ as an extra positional argument.
        super().__init__()
        self.embedding = Embedding(vocabulary_size, embedding_dimension)
        # return_sequences: one output per position; return_state: also hand
        # back the final state so generation can carry it between calls.
        self.gru = GRU(rnn_units, return_sequences=True, return_state=True)
        self.dense = Dense(vocabulary_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        """Run the model on a batch of id sequences.

        Args:
            inputs: tensor of ids fed to the embedding layer.
            states: initial GRU state; when None the GRU's own initial state
                for the embedded input is used.
            return_state: when True, also return the GRU state after the
                last position.
            training: forwarded to every layer.

        Returns:
            The dense layer's output, or a `(output, states)` tuple when
            `return_state` is True.
        """
        x = self.embedding(inputs, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)

        if return_state:
            return x, states
        return x

# Size the model from the lookup layer's vocabulary, not the raw character set.
model = MyModel(vocabulary_size=len(ids_from_chars.get_vocabulary()))

In [7]:
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow import random, squeeze, exp

# The model outputs raw scores, so the loss must be told to expect logits.
loss = SparseCategoricalCrossentropy(from_logits=True)

# Sanity check before training: run one batch through the model, sample a
# character per position for the first example and compare it with the target.
for input_batch, expected_batch in datasets.take(1):
    batch_logits = model(input_batch)
    example_loss = loss(expected_batch, batch_logits)

    # Sample ids from the first example's logits, then drop the sample axis.
    sampled_ids = random.categorical(batch_logits[0], num_samples=1)
    actual_output = text_from_ids(squeeze(sampled_ids, axis=-1))
    expected_output = text_from_ids(expected_batch[0])

    print(f"""
        expected: {expected_output}
        actual: {actual_output}
        loss: {example_loss} {exp(example_loss)}
    """)


        expected: b"n which doth control't.\n\nBRUTUS:\nHas said enough.\n\nSICINIUS:\nHas spoken like a traitor, and shall an"
        actual: b"?H\nfBkt;uwxVu:ttnzKlOC:s,iKoGAxrrT$XViKwCCQd CrUtyDpoe,LofAA'W DZ[UNK]MzpTk3bbpuzR!lBoDWFuJj[UNK]UNP.H&E:?oq"
        loss: 4.189743518829346 66.005859375
    


In [8]:
from tensorflow.keras.callbacks import ModelCheckpoint

import pathlib
import tempfile

# Keep the checkpoints in a 'generator-1' directory *inside* the freshly created
# temporary directory. Join path components with `/`; string concatenation would
# glue the name onto the temp directory's own name and point outside of it.
checkpoint_root = pathlib.Path(tempfile.mkdtemp()) / 'generator-1'
# '{epoch}' is left as a literal placeholder in the checkpoint file name.
checkpoint_prefix = str(checkpoint_root / 'checkpoint-{epoch}')
# Save only the weights. Training is run without validation data, so there is
# no 'val_loss' metric to watch; monitor the training loss instead.
checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_prefix,
    monitor='loss',
    save_weights_only=True,
)

# `loss` is the from-logits sparse categorical cross-entropy defined earlier.
model.compile(optimizer='adam', loss=loss)

In [9]:
# Train on the batched dataset, writing a checkpoint through the callback.
history = model.fit(
    x=datasets,
    epochs=20,
    callbacks=[checkpoint_callback],
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [10]:
from tensorflow import SparseTensor, sparse, constant, argmax

# Id(s) of the '[UNK]' token, shaped as a column of sparse indices.
mask_ids = ids_from_chars(['[UNK]'])[:, None]

# Dense vector over the vocabulary: -inf at the '[UNK]' position(s) and zero
# elsewhere, so adding it to the logits makes '[UNK]' impossible to sample.
mask = sparse.to_dense(
    SparseTensor(
        indices=mask_ids,
        values=[-float('inf')] * len(mask_ids),
        dense_shape=[len(ids_from_chars.get_vocabulary())],
    )
)

def generate_one_step(inputs, states=None, temperature=1.0):
    """Sample the next character for each input string.

    Args:
        inputs: UTF-8 strings to continue (presumably a 1-D string tensor,
            as produced by `constant([...])` - confirm against callers).
        states: GRU state returned by a previous call, or None to start
            from the model's initial state.
        temperature: positive divisor applied to the logits before
            sampling. Values below 1 make sampling more conservative,
            values above 1 more random; the default 1.0 leaves the logits
            unchanged.

    Returns:
        A `(predicted, states)` tuple: the sampled character(s) and the GRU
        state to pass to the next call.

    Raises:
        ValueError: if `temperature` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Strings -> characters -> ids; the ragged lookup result is densified.
    input_chars = strings.unicode_split(inputs, input_encoding='UTF-8')
    input_ids = ids_from_chars(input_chars).to_tensor()

    logits, states = model(input_ids, states, return_state=True)
    # Only the scores for the last position predict the *next* character.
    logits = logits[:, -1, :]
    logits = logits / temperature
    # The mask holds -inf for '[UNK]', so that token can never be drawn.
    logits = logits + mask

    sampled_ids = random.categorical(logits, num_samples=1)
    sampled_ids = squeeze(sampled_ids, axis=-1)
    predicted = chars_from_ids(sampled_ids)

    return predicted, states

In [12]:
# Seed the generator with a prompt and start from the model's initial state.
next_char = constant(['ROMEO'])
states = None
result = [next_char]

# Each step feeds back only the newly sampled character together with the
# carried-over GRU state, and collects the character for the final text.
for i in range(1_000):
    next_char, states = generate_one_step(next_char, states=states)
    result += [next_char]

# Concatenate the prompt and all sampled characters element-wise.
generated = strings.join(result)
print(generated)

tf.Tensor([b"ROMEO:\nWhy do you remuire? You bring him truel and lig!\n\nJULIET:\n'Tis but possible.\n\nDot him:\nYour prate-pierce loving friends,' in this same perfect with holy\nwhomat, you will purchase good.\nWhy standing courts upon him; 'What bear almost\nGiven in the windsworms slain here to thy daughter.\nThen lups are to beat more prizent and spider'd\nCrevill be out from me a hundred apes,\nYour sleep doth quit it o' the subjects of my virthes;\nAnd spun in love.\n\nRIVERS:\nMadam, I will not do it; yet I'll give you:\nI shall tell no thee from the sea, and his new gown,\nMy fear's son is mine offressed with her behoved.\nMy father Was get a stabbed to her,\nAnd he shall bear them true death with the hallow,\nWould say he-wounds, with any spirit comes;\nHaste you found my ben, i' the chboited heart;\nAnd even himself I teed thee deep trample;\nHe cannot, but die with the officers,\nAnd fetch the usuries act of heavier than a wall\nO mosalties; open with him in that son-fooli