Character-level text generation on the Shakespeare corpus, adapted from the TensorFlow tutorial: https://www.tensorflow.org/text/tutorials/text_generation

In [None]:
import tensorflow as tf

import numpy as np
import os
import time
import sys
sys.path.append('../train')
from src.shakespeare_model import (
    OneStepModel, ShakespeareModel, split_input_target, generate_batch_dataset
)

# 1. Get Shakespeare data

In [None]:
path_to_file = tf.keras.utils.get_file(
    'shakespeare.txt',
    'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt'
)
# Read raw bytes and decode explicitly as UTF-8, so the result does not depend
# on the platform's default text encoding. The context manager guarantees the
# file handle is closed.
with open(path_to_file, 'rb') as corpus_file:
    text = corpus_file.read().decode(encoding='utf-8')
# Length of the decoded text is the number of characters (not bytes) in it
print(f'Length of text: {len(text)} characters')

# Unique characters in the whole corpus define the model vocabulary
vocab = sorted(set(text))
print(f'{len(vocab)} unique characters')

In [None]:
# Split between train and validation sets.
# Hard-coded split point, approximately 90% of the downloaded corpus.
split_index = 1003857
# Guard against a corpus of a different length, which would otherwise
# silently produce an empty (or truncated) validation split.
assert split_index < len(text), 'split_index must fall inside the corpus'
text_train = text[:split_index]
text_valid = text[split_index:]

# Truncate the training text for a quick smoke-training run.
# Set SMOKE_TRAIN_CHARS = None to train on the full training split.
SMOKE_TRAIN_CHARS = 20000
if SMOKE_TRAIN_CHARS is not None:
    text_train = text_train[:SMOKE_TRAIN_CHARS]
# Number of distinct characters in the (possibly truncated) training text
print(f'{len(set(text_train))} unique characters')

# 2. Process text
Tokenize and vectorize: convert strings into a numerical representation.

In [None]:
# Map each character to an integer id. The vocabulary is built from the full
# corpus, so the same lookup layer can encode both train and validation text.
ids_from_chars = tf.keras.layers.StringLookup(vocabulary=list(vocab), mask_token=None)

# Demonstrate the mapping on a few example words: split each word into its
# UTF-8 characters, then look up the id of every character.
example_text = ['this', 'giggle', 'is', 'so', 'silver']
example_chars = tf.strings.unicode_split(example_text, input_encoding='UTF-8')
example_ids = ids_from_chars(example_chars)
print(example_ids)

This is the numerical representation of our example text. We also need a way to convert it back into text:

In [None]:
# Inverse lookup: reuse the exact vocabulary of ids_from_chars so that every
# id maps back to the character it was produced from.
id_vocabulary = ids_from_chars.get_vocabulary()
chars_from_ids = tf.keras.layers.StringLookup(
    vocabulary=id_vocabulary,
    mask_token=None,
    invert=True,
)

print(chars_from_ids(example_ids))

In [None]:
# Recover each example word as a single string by joining its characters.
recovered_chars = chars_from_ids(example_ids)
tf.strings.reduce_join(recovered_chars, axis=-1).numpy()

In [None]:
def text_from_ids(ids):
    """Convert a tensor of character ids back into strings.

    Each id is mapped to its character with ``chars_from_ids``; the characters
    are then concatenated along the last axis, so every row of ids becomes
    one string.
    """
    chars = chars_from_ids(ids)
    return tf.strings.reduce_join(chars, axis=-1)

# 3. Prediction task: train on the corpus and predict the next character
Use the `tf.data.Dataset.from_tensor_slices` function to convert the text vector into a stream of character indices.

In [None]:
# Length of each training sequence (presumably in characters; see
# generate_batch_dataset for the exact windowing)
SEQ_LENGTH = 100
# Number of sequences per batch
BATCH_SIZE = 64
# Shuffle buffer size. tf.data is designed to work with possibly infinite
# sequences, so it does not shuffle the entire sequence in memory; instead it
# shuffles elements within a buffer of this size.
BUFFER_SIZE = 10000

train_dataset_batch = generate_batch_dataset(
    batch_size=BATCH_SIZE,
    buffer_size=BUFFER_SIZE,
    seq_length=SEQ_LENGTH,
    ids_from_chars=ids_from_chars,
    text=text_train,
)

In [None]:
# Model hyperparameters
vocab_size = len(ids_from_chars.get_vocabulary())  # size of the StringLookup vocabulary
embedding_dim = 256  # dimension of the character embeddings
rnn_units = 1024  # number of units in the recurrent layer

model = ShakespeareModel(
    rnn_units=rnn_units,
    embedding_dim=embedding_dim,
    vocab_size=vocab_size,
)

In [None]:
# Sanity check: run the (not yet trained) model on a single batch and
# inspect the shape of its output.
first_batch = train_dataset_batch.take(1)
for input_example_batch, target_example_batch in first_batch:
    print(input_example_batch)
    example_batch_predictions = model(input_example_batch)
    predictions_shape = example_batch_predictions.shape
    print(f'{predictions_shape}  # (batch_size, sequence_length, vocab_size)')

In [None]:
# Print the layer-by-layer summary of the model (parameter counts, output shapes).
model.summary()

In [None]:
# Draw one id per time step from the untrained model's output scores for the
# first example of the batch, to see what the model "predicts" before training.
first_example_logits = example_batch_predictions[0]
sampled_indices = tf.squeeze(
    tf.random.categorical(first_example_logits, num_samples=1),
    axis=-1,
).numpy()

input_text = text_from_ids(input_example_batch[0]).numpy()
predicted_text = text_from_ids(sampled_indices).numpy()
print("Input:\n", input_text)
print()
print("Next Char Predictions:\n", predicted_text)

### Train

In [None]:
# Cross-entropy over integer targets. from_logits=True because the model
# outputs unnormalized scores (the sampling cell above also treats them as logits).
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Loss of the untrained model on the example batch, as a sanity check
# before training.
example_batch_mean_loss = loss(target_example_batch, example_batch_predictions)
print(f'Prediction shape: {example_batch_predictions.shape}  # (batch_size, sequence_length, vocab_size)')
print(f'Mean loss: {example_batch_mean_loss}')

model.compile(optimizer='adam', loss=loss)

# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Checkpoint file name pattern; Keras substitutes {epoch} when saving
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# NOTE(review): under Keras 3 (TF >= 2.16), weights-only checkpoint paths must
# end in '.weights.h5'; confirm against the installed TF/Keras version.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True
)

EPOCHS = 15
history = model.fit(train_dataset_batch, epochs=EPOCHS, callbacks=[checkpoint_callback])

In [None]:
# Wrap the trained model in a step-by-step generator and sample text from it.
one_step_model = OneStepModel(model, chars_from_ids, ids_from_chars)

# Number of generate_one_step calls (presumably one character per call)
NUM_GENERATED_CHARS = 1000
# perf_counter is a monotonic clock meant for measuring elapsed time
start = time.perf_counter()
states = None  # no recurrent state before the first step
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for _ in range(NUM_GENERATED_CHARS):
    # Feed the previous output back in and carry the recurrent state forward
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)

result = tf.strings.join(result)
end = time.perf_counter()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print('\nRun time:', end - start)

In [None]:
# Generate again from the same prompt, reusing the OneStepModel built in the
# previous cell instead of re-creating an identical one. If generate_one_step
# is a tf.function (not visible here), this run reuses its existing trace, so
# the run time reflects a warm generator.
start = time.perf_counter()
states = None  # no recurrent state before the first step
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for _ in range(1000):
    # Feed the previous output back in and carry the recurrent state forward
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)

result = tf.strings.join(result)
end = time.perf_counter()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print('\nRun time:', end - start)

In [None]:
# NOTE(review): per the TF docs this is a process-wide switch that makes
# tf.function-decorated code run eagerly (slower, but debuggable from Python).
# It is presumably required by generate_one_sentence below — confirm. Nothing
# in this notebook switches it back off, so it is still in effect when the
# model is saved; consider tf.config.run_functions_eagerly(False) afterwards.
tf.config.run_functions_eagerly(True)

In [None]:
# Generate text starting from a 'GREMIO:' prompt. generate_one_sentence comes
# from the project's OneStepModel; its stopping rule is not visible here.
one_step_model.generate_one_sentence('GREMIO:\n')

In [None]:
# Export the generator as a TensorFlow SavedModel under the path 'one_step'.
tf.saved_model.save(one_step_model, 'one_step')
# To reload it later: one_step_reloaded = tf.saved_model.load('one_step')

## Validation process
### Metrics

In [None]:
# Dependencies for the validation metrics. nltk.download fetches the 'punkt'
# tokenizer data (may need network access); presumably used by src.metrics.
import nltk
nltk.download('punkt')
from src.metrics import calculate_bleu_score, calculate_rouge_score

In [None]:
# BLEU between gibberish-looking (presumably untrained-model) output and a real
# line from the play. calculate_bleu_score is the name imported from
# src.metrics; calculate_bleu is not defined anywhere in this notebook.
# NOTE(review): confirm the expected argument order in src.metrics.
calculate_bleu_score('GREMIO:\nYNO3hUCnGpSdRC;aW?', 'GREMIO:\nGood morrow, neighbour Baptista.')

In [None]:
# BLEU between two gibberish-looking generated samples. calculate_bleu_score is
# the name imported from src.metrics; calculate_bleu is not defined anywhere
# in this notebook.
# NOTE(review): confirm the expected argument order in src.metrics.
calculate_bleu_score('GREMIO:\nYNO3hUCnGpSdRC;aW?', 'GREMIO:\n&-;npwld tat wat e, the tchero ilthe bee nrecurmwnuree dihet, sut\n')

In [None]:
# Validation dataset, built with the same lookup layer and batching constants
# as the training dataset.
# NOTE(review): unlike text_train, text_valid is not truncated for smoke runs,
# and it is passed the shuffle buffer size as well; shuffling is not needed
# for evaluation — confirm how generate_batch_dataset uses buffer_size.
valid_dataset_batch = generate_batch_dataset(
    batch_size=BATCH_SIZE,
    buffer_size=BUFFER_SIZE,
    seq_length=SEQ_LENGTH,
    ids_from_chars=ids_from_chars,
    text=text_valid,
)