Imports


In [None]:
import os
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Embedding
import numpy as np
import glob

Get text and create sequences

In [None]:
# Download the Shakespeare corpus (get_file returns the path of the local copy).
path_text = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

# Read the raw bytes inside a context manager so the file handle is closed
# deterministically instead of being left for the garbage collector.
with open(path_text, 'rb') as text_file:
    text = text_file.read().decode(encoding='utf-8')

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
char2idx = {c: i for i, c in enumerate(chars)}  # character -> integer id
idx2char = np.array(chars)                      # integer id -> character

# The whole corpus encoded as one vector of integer ids.
text_as_int = np.array([char2idx[c] for c in text], dtype=np.int32)

# Each example is seq_length input characters plus one extra character that
# serves as the prediction target, hence the "+ 1" below.
seq_length = 100
examples_per_epoch = len(text_as_int) // (seq_length + 1)

# Stream the ids one at a time, then group them into consecutive,
# non-overlapping chunks of seq_length + 1 ids; a trailing partial chunk is dropped.
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)

Split each sequence into an (input, target) pair, then shuffle and batch the pairs for training

In [None]:
def split_input_target(chunk):
    """Split one chunk into (input sequence, target).

    Everything except the last element of ``chunk`` becomes the input; the
    last element alone is the target.
    """
    return chunk[:-1], chunk[-1]

# Input-pipeline settings.
BATCH_SIZE = 64       # examples per training batch
BUFFER_SIZE = 10000   # size of the shuffle buffer

# Map every chunk to an (input, target) pair, shuffle the pairs, then group
# them into fixed-size batches (an incomplete final batch is dropped).
dataset = (
    sequences
    .map(split_input_target)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)

Create model

In [None]:
# Layer sizes.
embedding_dim = 256
lstm_units = 512

# Assemble the network one layer at a time.
model = Sequential()
# Turn each character id into a dense vector of size embedding_dim.
model.add(Embedding(vocab_size, embedding_dim))
# The first LSTM emits its output at every time step so a second LSTM can be stacked on it.
model.add(LSTM(lstm_units, return_sequences=True))
# The second LSTM emits only its final output.
model.add(LSTM(lstm_units))
# Softmax over the vocabulary: one probability per possible target character.
model.add(Dense(vocab_size, activation='softmax'))

Train model, store weights as checkpoint file

In [None]:
# Targets are integer character ids, so the sparse variant of categorical
# cross-entropy is used.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Directory that receives the checkpoint files; created if it does not exist.
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# File-name template containing an {epoch} placeholder (zero-padded to two digits).
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch:02d}.weights.h5")

# Save the weights only (not the full model) at every epoch, whether or not
# the epoch improved on earlier ones.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_freq='epoch',
    save_best_only=False,
    save_weights_only=True,
)


In [None]:
# Number of epochs to train for.
EPOCH = 10

# Train on the batched dataset; the callback writes a weights file each epoch.
model.fit(
    dataset,
    callbacks=[checkpoint_callback],
    epochs=EPOCH,
)

In [None]:
# Print a summary of the model's layers.
model.summary()

Generate text

In [None]:
def latest_checkpoint(directory=None):
    """Return the path of the highest-epoch ``*.weights.h5`` file, or None.

    Files are ranked by the integer that immediately precedes ``.weights.h5``
    in their name rather than by plain string order, so ``ckpt_100`` correctly
    ranks after ``ckpt_99``. Files without such a number rank lowest; ties are
    broken by path.

    Args:
        directory: Folder to search. Defaults to the notebook-level
            ``checkpoint_dir`` when omitted.

    Returns:
        The path of the selected weights file, or ``None`` if the folder
        contains no ``*.weights.h5`` file.
    """
    import re  # local import: only this helper needs it

    if directory is None:
        directory = checkpoint_dir

    def epoch_key(path):
        # Pull the epoch number out of e.g. "ckpt_07.weights.h5".
        found = re.search(r"(\d+)\.weights\.h5$", os.path.basename(path))
        return (int(found.group(1)) if found else -1, path)

    # Escape the directory so characters such as "[" in it are not treated as
    # glob wildcards.
    weight_files = glob.glob(os.path.join(glob.escape(directory), "*.weights.h5"))
    return max(weight_files, key=epoch_key) if weight_files else None
# Restore the most recent checkpoint into the model. Fail with a clear message
# when there is nothing to load instead of handing None to load_weights.
weights_to_load = latest_checkpoint()
if weights_to_load is None:
    raise FileNotFoundError(
        f"No '*.weights.h5' checkpoint found in {checkpoint_dir!r}; "
        "run the training cell first."
    )
model.load_weights(weights_to_load)

def generate_text(seed_text, num_generate=500, temperature=0.7):
    """Generate text by sampling one character at a time from the model.

    The seed is encoded with ``char2idx`` and forced to exactly ``seq_length``
    ids: a longer seed keeps only its last ``seq_length`` characters, a
    shorter one is left-padded. There is no dedicated padding token, so id 0
    (the character that sorts first in the vocabulary) is used both as
    padding and for seed characters that are missing from ``char2idx``.

    Args:
        seed_text: Starting text; returned unchanged as the prefix of the result.
        num_generate: Number of characters to sample.
        temperature: Sampling temperature. Values below 1 sharpen the
            predicted distribution, values above 1 flatten it. Values at or
            below 1e-8 are clamped to 1e-8, which effectively always picks
            the most likely character.

    Returns:
        ``seed_text`` followed by the ``num_generate`` sampled characters.
    """
    pad_idx = 0
    input_ids = [char2idx.get(c, pad_idx) for c in seed_text]
    if len(input_ids) < seq_length:
        input_ids = [pad_idx] * (seq_length - len(input_ids)) + input_ids
    else:
        input_ids = input_ids[-seq_length:]

    generated = []
    for _ in range(num_generate):
        x = np.array([input_ids], dtype=np.int32)  # shape (1, seq_length)
        # Call the model directly: for a single small batch this avoids the
        # per-call overhead of model.predict inside the loop.
        preds = np.asarray(model(x, training=False))[0].astype(np.float64)

        # Temperature scaling in log space.
        logits = np.log(preds + 1e-8) / max(1e-8, temperature)
        # Shift so the largest logit is 0 before exponentiating. Without this,
        # a small temperature makes every exp() underflow to 0 and the
        # normalisation below becomes 0/0 (NaN probabilities).
        logits -= np.max(logits)
        exp_logits = np.exp(logits)
        probs = exp_logits / np.sum(exp_logits)

        next_id = int(np.random.choice(vocab_size, p=probs))
        generated.append(idx2char[next_id])

        # Slide the context window forward by one character.
        input_ids = input_ids[1:] + [next_id]
    return seed_text + ''.join(generated)

In [None]:
# Sampling settings. SEED is the text prompt the generation starts from.
SEED = "MENENIUS:"
NUM_GENERATE = 500
TEMPERATURE = 0.7

sample = generate_text(
    seed_text=SEED,
    num_generate=NUM_GENERATE,
    temperature=TEMPERATURE,
)

# Emit the banner, the sample and the footer in one call; joining with a
# newline separator yields the same output as three separate prints.
print("\n--- Generated sample ---\n", sample, "\n--- End sample ---\n", sep="\n")