In [39]:
import tensorflow as tf
import numpy as np
import os
import time
import re
import re
import nltk
from nltk.corpus import stopwords

In [41]:
# Location of the training corpus. Set the SCIENCE_TEXTBOOK_PATH environment
# variable to point at your own copy of the text file; when it is unset (or
# empty) the original local path is used, so existing behaviour is unchanged.
path_to_file = (
    os.environ.get("SCIENCE_TEXTBOOK_PATH")
    or r"C:\Users\relee\OneDrive\Desktop\science-textbook-grade-5.txt"
)

# Download the NLTK stop-word corpus and keep the English list as a set so the
# per-word membership test used during cleaning is fast.
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\relee\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping corpora\stopwords.zip.


In [43]:
import os

# Directory where checkpoints will be saved (relative to the current working
# directory); created up front so the callback never writes into a missing folder.
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Checkpoint file name pattern. "{epoch}" is a placeholder that the Keras
# callback fills in with the epoch number, giving one file per epoch.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")

# Callback passed to model.fit(); save_weights_only=True stores just the
# model weights rather than the full model.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True
)

In [45]:
# Load the raw corpus: try UTF-8 first and fall back to Latin-1 when the file
# contains bytes that are not valid UTF-8.
try:
    with open(path_to_file, 'r', encoding='utf-8') as f:
        text = f.read()
except UnicodeDecodeError:
    with open(path_to_file, 'r', encoding='latin-1') as f:
        text = f.read()

# Flatten newlines to spaces, then replace every run of characters outside
# letters, digits, space and . , ! ? ' with a single space.
text = re.sub(r'[^a-zA-Z0-9 .,!?\']+', ' ', re.sub(r'\n+', ' ', text))

# Drop English stop words. The comparison is case-insensitive, but a token
# with punctuation attached (e.g. "the,") is not recognised as a stop word.
filtered_words = [
    token for token in text.split()
    if token.lower() not in stop_words
]

# Rebuild the cleaned corpus as a single space-separated string.
text = re.sub(r'\s+', ' ', ' '.join(filtered_words)).strip()

# Whitespace tokens of the cleaned corpus, and the sorted set of distinct
# tokens that serves as the model vocabulary.
words = text.split()
vocab = sorted(set(words))

print(f'Total words: {len(words)}')
print(f'Unique words: {len(vocab)}')

Total words: 22815
Unique words: 5238


In [47]:
# StringLookup layers for word IDs.
# ids_from_words maps a word string to an integer id; no mask token is reserved.
# words_from_ids is the inverse mapping, built from the forward layer's own
# vocabulary (which includes its out-of-vocabulary entry) so that ids round-trip.
ids_from_words = tf.keras.layers.StringLookup(vocabulary=vocab, mask_token=None)
words_from_ids = tf.keras.layers.StringLookup(vocabulary=ids_from_words.get_vocabulary(), invert=True, mask_token=None)

# Convert the whole corpus (list of word strings) to ids, then expose it as a
# dataset that yields one id at a time.
all_word_ids = ids_from_words(words)
ids_dataset = tf.data.Dataset.from_tensor_slices(all_word_ids)

# Helper used to turn predicted ids back into readable text
def text_from_ids(ids):
    """Look up the word for every id and join them with spaces along the last axis."""
    decoded_words = words_from_ids(ids)
    return tf.strings.reduce_join(decoded_words, axis=-1, separator=' ')

In [49]:
# Number of input words per training example.
seq_length = 20  # word-based sequence length
# Chunk the id stream into consecutive, non-overlapping windows of
# seq_length + 1 ids; the extra id lets each window be split into an input
# and a target shifted by one word. A trailing partial window is dropped.
sequences = ids_dataset.batch(seq_length + 1, drop_remainder=True)

def split_input_target(sequence):
    """Split one window into an (input, target) pair for next-word prediction.

    The input is every element except the last; the target is every element
    except the first, i.e. the input shifted one position to the left.
    """
    return sequence[:-1], sequence[1:]

# Turn every (seq_length + 1)-id window into an (input, target) training pair.
dataset = sequences.map(split_input_target)

In [51]:
# Number of (input, target) pairs per training batch.
BATCH_SIZE = 64
# Size of the shuffle buffer, in examples.
BUFFER_SIZE = 10000

# Shuffle the examples, group them into fixed-size batches (dropping a final
# partial batch), and prefetch so input preparation overlaps with training.
# NOTE: this cell rebinds `dataset`. Re-running it without first re-running
# the cell that builds the unbatched dataset would batch already-batched data.
dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    # tf.data.AUTOTUNE is the current name for the deprecated
    # tf.data.experimental.AUTOTUNE alias.
    .prefetch(tf.data.AUTOTUNE)
)

In [53]:
class MyModel(tf.keras.Model):
    """Word-level language model: embedding -> GRU -> dense layer over the vocabulary."""

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        """Create the three layers.

        Args:
            vocab_size: number of distinct ids; used as the embedding input
                size and as the width of the output layer.
            embedding_dim: size of each word embedding vector.
            rnn_units: number of GRU units.
        """
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(rnn_units, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        """Run a forward pass.

        Args:
            inputs: batch of id sequences.
            states: initial GRU state; when None, an all-zero state is built.
            return_state: when True, also return the GRU state after the pass.
            training: forwarded to every layer.

        Returns:
            The dense layer's output for every position, or a tuple of
            (output, new GRU state) when return_state is True.
        """
        if states is None:
            # Build the zero state explicitly, reading the batch size from the
            # input tensor at run time.
            states = [tf.zeros((tf.shape(inputs)[0], self.gru.units))]

        embedded = self.embedding(inputs, training=training)
        rnn_output, new_states = self.gru(embedded, initial_state=states, training=training)
        logits = self.dense(rnn_output, training=training)

        if return_state:
            return logits, new_states
        return logits
        
# Size of the id space, taken from the lookup layer's own vocabulary rather
# than from `vocab` so that every id the layer can emit is covered.
vocab_size = len(ids_from_words.get_vocabulary())
# Width of each word embedding vector.
embedding_dim = 256
# Number of GRU units.
rnn_units = 1024

model = MyModel(vocab_size, embedding_dim, rnn_units)
# from_logits=True: the loss is told to expect unnormalised scores rather than
# probabilities.
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer='adam', loss=loss)

In [55]:
# Sanity check: push a single batch through the untrained model and print the
# output shape, which should read (batch size, sequence length, vocabulary size).
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape)

(64, 20, 5239)


In [56]:
# Print the layer-by-layer summary of the model (run after the forward pass
# above, so the layers have already been called once).
model.summary()

In [59]:
# Train for a single pass over the dataset; the checkpoint callback is invoked
# during training, and the returned History object is kept in `history`.
history = model.fit(dataset, epochs=1, callbacks=[checkpoint_callback])

[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 3s/step - loss: 8.3872


In [61]:
class OneStep(tf.keras.Model):
    """Wraps a trained language model to sample one next word per call."""

    def __init__(self, model, words_from_ids, ids_from_words, temperature=1.0):
        """Store the model and lookup layers and pre-compute the "[UNK]" mask.

        Args:
            model: the language model; called with `inputs`, `states` and
                `return_state=True`.
            words_from_ids: layer mapping ids back to word strings.
            ids_from_words: layer mapping word strings to ids.
            temperature: divisor applied to the logits before sampling.
        """
        super().__init__()
        self.temperature = temperature
        self.model = model
        self.words_from_ids = words_from_ids
        self.ids_from_words = ids_from_words

        # Build a dense vector that is -inf at the id of "[UNK]" and 0
        # elsewhere; adding it to the logits keeps "[UNK]" from being sampled.
        unk_indices = self.ids_from_words(['[UNK]'])[:, None]
        vocabulary_length = len(ids_from_words.get_vocabulary())
        blocked_values = [-float('inf')] * len(unk_indices)
        unk_mask = tf.SparseTensor(
            values=blocked_values,
            indices=unk_indices,
            dense_shape=[vocabulary_length]
        )
        self.prediction_mask = tf.sparse.to_dense(unk_mask)

    # Deliberately left as a plain Python method: it is not decorated with
    # @tf.function.
    def generate_one_step(self, inputs, states=None):
        """Sample the next word for each prompt string in `inputs`.

        Args:
            inputs: string tensor; each element is split on whitespace into words.
            states: recurrent state from the previous call, or None to start fresh.

        Returns:
            A tuple (predicted words, new recurrent state).
        """
        # Strings -> words -> ids, padded to a rectangular tensor.
        token_ids = self.ids_from_words(tf.strings.split(inputs)).to_tensor()

        logits, states = self.model(inputs=token_ids, states=states, return_state=True)

        # Only the final position matters; scale by temperature and apply the
        # "[UNK]" mask.
        last_step_logits = logits[:, -1, :] / self.temperature
        last_step_logits = last_step_logits + self.prediction_mask

        # Draw one id per batch element from the resulting distribution.
        sampled_ids = tf.random.categorical(last_step_logits, num_samples=1)
        sampled_ids = tf.squeeze(sampled_ids, axis=-1)

        return self.words_from_ids(sampled_ids), states

In [63]:
one_step_model = OneStep(model, words_from_ids, ids_from_words)

start = time.time()
states = None
# Number of words to sample after the prompt.
num_words_to_generate = 10
# NOTE: English stop words were filtered out of the training corpus, so prompt
# words such as "the" and "is" are not in the vocabulary and are looked up as
# the out-of-vocabulary id.
next_input = tf.constant(['the mitochondria is the'])
result = [next_input]

# Feed each sampled word back in as the next input, carrying the recurrent
# state forward between steps.
for _ in range(num_words_to_generate):
    next_input, states = one_step_model.generate_one_step(next_input, states=states)
    result.append(next_input)

# Join the prompt and every generated word into a single string per batch element.
result = tf.strings.join(result, separator=' ')
end = time.time()

print(result.numpy()[0].decode('utf-8'))
print('\n' + '_'*80)
print(f"Run time: {end - start}")

the mitochondria is the pumps sailors AO ml removes world. K soaks sugar. satellite

________________________________________________________________________________
Run time: 0.43857359886169434
