In [8]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

# Load the Tiny Shakespeare corpus through TFDS. `with_info=True` makes the call
# return the dataset metadata as a second value; `as_supervised=False` keeps
# each element as a feature dict (the next cell reads its 'text' entry).
dataset, info = tfds.load(
    'tiny_shakespeare',
    with_info=True,
    as_supervised=False,
)

In [9]:
# Take the first element of the training split and decode its 'text' feature
# into a regular Python string.
train_record = next(iter(dataset['train']))
text = train_record['text'].numpy().decode('utf-8')

# Character-level vocabulary with lookups in both directions:
#   char2idx: character -> integer id
#   idx2char: integer id -> character (NumPy array, indexable by id)
vocab = sorted(set(text))
char2idx = dict(zip(vocab, range(len(vocab))))
idx2char = np.array(vocab)

# The whole corpus encoded as integer ids.
text_as_int = np.array([char2idx[ch] for ch in text])

# Each training chunk holds seq_length + 1 characters so that it can later be
# split into an input and a one-step-shifted target of seq_length each.
seq_length = 100
examples_per_epoch = len(text) // (seq_length + 1)

# Stream of single character ids, grouped into fixed-size chunks; the trailing
# partial chunk is dropped.
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)

In [10]:
def split_input_target(chunk):
    """Split one chunk into a next-character (input, target) pair.

    Args:
        chunk: Any sliceable sequence (tensor, array, list, string).

    Returns:
        A tuple ``(input_text, target_text)`` where ``input_text`` is the chunk
        without its last element and ``target_text`` is the chunk without its
        first element, i.e. the same sequence shifted one step ahead.
    """
    return chunk[:-1], chunk[1:]

# Turn every (seq_length + 1)-character chunk into an (input, target) pair.
# Note: this rebinds `dataset`, which until now held the result of tfds.load.
dataset = sequences.map(map_func=split_input_target)

In [11]:
BATCH_SIZE = 64      # sequences per training batch
BUFFER_SIZE = 10000  # number of elements held in the shuffle buffer

# Build the training pipeline from `sequences` rather than re-wrapping
# `dataset` in place: the cell is then idempotent, so running it a second time
# does not stack another shuffle/batch on top of already-batched data (which
# would hand rank-3 batches to model.fit).
# drop_remainder=True keeps every batch at exactly BATCH_SIZE elements; the
# model is built with a fixed batch size, so a short final batch must not occur.
dataset = (
    sequences
    .map(split_input_target)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)  # non-deprecated spelling of tf.data.experimental.AUTOTUNE
)

In [12]:
import tensorflow as tf

# Model hyperparameters.
vocab_size = len(vocab)  # number of distinct characters; also the number of output logits
embedding_dim = 256      # width of each character embedding vector
rnn_units = 1024         # number of units in the LSTM layer


def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    """Assemble the character-level model: Embedding -> stateful LSTM -> Dense.

    Args:
        vocab_size: Number of distinct token ids; sizes both the embedding
            table and the output layer.
        embedding_dim: Dimensionality of the embedding vectors.
        rnn_units: Number of LSTM units.
        batch_size: Fixed batch size baked into the input shape. The LSTM is
            stateful, so the batch dimension cannot be left unspecified.

    Returns:
        A ``tf.keras.Model`` that maps ids of shape ``(batch_size, T)`` to
        logits of shape ``(batch_size, T, vocab_size)`` (no softmax applied).
    """
    token_ids = tf.keras.Input(batch_shape=(batch_size, None))

    embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    recurrent = tf.keras.layers.LSTM(
        rnn_units,
        return_sequences=True,  # emit an output for every time step
        stateful=True,          # carry the hidden state over between calls
        recurrent_initializer='glorot_uniform',
    )
    to_logits = tf.keras.layers.Dense(vocab_size)

    logits = to_logits(recurrent(embedding(token_ids)))
    return tf.keras.Model(token_ids, logits)


model = build_model(vocab_size, embedding_dim, rnn_units, BATCH_SIZE)

In [13]:
def loss(labels, logits):
    """Sparse categorical cross-entropy evaluated on raw logits.

    Args:
        labels: Integer class ids (the target characters).
        logits: Unnormalised model outputs; ``from_logits=True`` tells Keras
            that no softmax has been applied yet.

    Returns:
        The loss tensor produced by
        ``tf.keras.losses.sparse_categorical_crossentropy``.
    """
    return tf.keras.losses.sparse_categorical_crossentropy(
        y_true=labels,
        y_pred=logits,
        from_logits=True,
    )


model.compile(loss=loss, optimizer='adam')

In [14]:
import os

EPOCHS = 10

# Directory that receives the per-epoch weight files; created up front so the
# checkpoint callback can write into it.
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# "{epoch}" is a placeholder that Keras fills in with the epoch number when it
# writes each file.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")

# Save the weights only (no optimizer state or architecture).
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
)

history = model.fit(
    dataset,
    epochs=EPOCHS,
    callbacks=[checkpoint_callback],
)

Epoch 1/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m200s[0m 1s/step - loss: 2.8897
Epoch 2/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m214s[0m 1s/step - loss: 1.8758
Epoch 3/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m213s[0m 1s/step - loss: 1.6164
Epoch 4/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m216s[0m 1s/step - loss: 1.4861
Epoch 5/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m216s[0m 1s/step - loss: 1.4078
Epoch 6/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m222s[0m 1s/step - loss: 1.3530
Epoch 7/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m218s[0m 1s/step - loss: 1.3073
Epoch 8/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m214s[0m 1s/step - loss: 1.2720
Epoch 9/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2977s[0m 19s/step - loss: 1.2394
Epoch 10/10
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m194s[0m 

In [18]:
import os
import re

# Locate the newest checkpoint written by the training cell.
# File names follow "ckpt_<epoch>.weights.h5". The pattern is matched against
# the whole name so that only the epoch number is used for ordering: a plain
# search for the first run of digits would also hit the "5" in ".h5" (or digits
# elsewhere in an unrelated file name) and could rank a stray "*.weights.h5"
# file as the latest checkpoint.
checkpoint_pattern = re.compile(r'ckpt_(\d+)\.weights\.h5')
checkpoint_files = sorted(
    (name for name in os.listdir(checkpoint_dir) if checkpoint_pattern.fullmatch(name)),
    key=lambda name: int(checkpoint_pattern.fullmatch(name).group(1)),
)
latest_checkpoint = checkpoint_files[-1] if checkpoint_files else None

# Rebuild the network with batch_size=1: the model fixes its batch size at
# construction time, and text generation feeds one sequence at a time.
# Note: this rebinds `model`, replacing the training model.
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)

# Restore the trained weights when a checkpoint exists.
if latest_checkpoint:
    latest_checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)
    model.load_weights(latest_checkpoint_path)
    print(f"Loaded weights from {latest_checkpoint_path}")
else:
    # The model keeps its freshly initialised weights in this case, so any
    # text generated afterwards comes from an untrained network.
    print("No valid checkpoint found.")

# Declare the inference input shape: one sequence of arbitrary length.
model.build(tf.TensorShape([1, None]))

Loaded weights from ./training_checkpoints/ckpt_10.weights.h5


In [22]:
def generate_text(model, start_string, num_generate=1000, temperature=1.0):
    """Sample text from the model one character at a time.

    The prompt is fed through the model once; after that only the most recently
    sampled character is fed back, and the model's stateful recurrent layer
    carries the context forward between steps.

    Args:
        model: Model built for a batch size of 1 that maps character ids of
            shape ``(1, T)`` to logits of shape ``(1, T, vocab_size)``.
        start_string: Non-empty prompt. Every character must be a key of the
            module-level ``char2idx`` mapping.
        num_generate: Number of characters to sample. Defaults to 1000.
        temperature: Positive divisor applied to the logits before sampling.
            Values below 1.0 make the output more predictable, values above
            1.0 more random. The default of 1.0 leaves the logits unchanged.

    Returns:
        ``start_string`` followed by the ``num_generate`` sampled characters.

    Raises:
        ValueError: If ``start_string`` is empty or ``temperature`` is not
            positive.
        KeyError: If ``start_string`` contains a character missing from
            ``char2idx``.
    """
    if not start_string:
        raise ValueError("start_string must contain at least one character")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # Stateful recurrent layers keep their hidden state between calls. Clear it
    # so that this sample depends only on start_string and not on whatever was
    # generated by a previous call with the same model.
    for layer in getattr(model, 'layers', ()):
        if getattr(layer, 'stateful', False):
            layer.reset_states()

    # Encode the prompt and add the batch dimension -> shape (1, len(prompt)).
    input_eval = tf.expand_dims([char2idx[s] for s in start_string], 0)

    text_generated = []
    for _ in range(num_generate):
        predictions = model(input_eval)

        # Only the last time step predicts the next character, so sample from
        # that step alone instead of drawing a sample for every position.
        last_logits = predictions[:, -1, :] / temperature
        predicted_id = tf.random.categorical(last_logits, num_samples=1)[0, 0].numpy()

        # Feed just the sampled character back in; the recurrent state already
        # encodes everything seen so far.
        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])

    return start_string + ''.join(text_generated)

print(generate_text(model, start_string=u"QUEEN: So, lets end this"))

QUEEN: So, lets end this to ever more of you;
And she is so I, no. How oway's good deess!
Emer, when thou broke! Darkncif Milerey's liquoth.
Mark him, we his eyes was too much aprain?
In play apparelly impress'd by me; but thou shalt be
In pain and woman tyrant madam in peace.

Second Murderer:
O my Lord of Wariame.

Stird Augard:
In the Warwick where thou werp death it seems our king,
And even so noble Romeo!' thy sun
Ere now grief to their event, look to know
you rohe many flant and than that places
To meer myself or each other thee decline,
With that can hear you born him hither; before thy pry God in heaven,
I heard you see good Blancaster; they have too Norfolk'd business
From such apputy to his uncle woeld so till
Than climate youe dare only their conversation,
Resterbert my names and high deserts, past Lord Angelo,
Who had so made it first? Where he and unknessed change
Was weap's revoil upon thee. But camal addly be gone.

HENRY BORINGBROKE:
Good, worrior is a respect pay the n