# Next-Word Prediction with a Small Transformer (Keras/TensorFlow)

In [1]:

!pip -q install tensorflow


In [2]:

import numpy as np, tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras import layers, Model

tf.random.set_seed(42)
np.random.seed(42)
print("TF:", tf.__version__)


TF: 2.19.0


In [8]:

corpus_text = "Alice was beginning to get very tired of sitting by her sister on the bank.".lower()
tokenizer = Tokenizer(oov_token="<oov>")
tokenizer.fit_on_texts([corpus_text])
word_index = tokenizer.word_index
word_index


{'<oov>': 1,
 'alice': 2,
 'was': 3,
 'beginning': 4,
 'to': 5,
 'get': 6,
 'very': 7,
 'tired': 8,
 'of': 9,
 'sitting': 10,
 'by': 11,
 'her': 12,
 'sister': 13,
 'on': 14,
 'the': 15,
 'bank': 16}

In [9]:
id_to_word = {i:w for w,i in word_index.items()}
id_to_word

{1: '<oov>',
 2: 'alice',
 3: 'was',
 4: 'beginning',
 5: 'to',
 6: 'get',
 7: 'very',
 8: 'tired',
 9: 'of',
 10: 'sitting',
 11: 'by',
 12: 'her',
 13: 'sister',
 14: 'on',
 15: 'the',
 16: 'bank'}

In [12]:
vocab_size = len(word_index) + 1
tokens = tokenizer.texts_to_sequences([corpus_text])[0]
tokens


[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]

In [16]:
sequences = [tokens[:i] for i in range(2, len(tokens))]
sequences

[[2, 3],
 [2, 3, 4],
 [2, 3, 4, 5],
 [2, 3, 4, 5, 6],
 [2, 3, 4, 5, 6, 7],
 [2, 3, 4, 5, 6, 7, 8],
 [2, 3, 4, 5, 6, 7, 8, 9],
 [2, 3, 4, 5, 6, 7, 8, 9, 10],
 [2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
 [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
 [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]

In [17]:
max_len = max(len(s) for s in sequences)
max_len

14

In [18]:
padded = pad_sequences(sequences, maxlen=max_len, padding='pre')
padded

array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  3],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  3,  4],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  3,  4,  5],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  3,  4,  5,  6],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  2,  3,  4,  5,  6,  7],
       [ 0,  0,  0,  0,  0,  0,  0,  2,  3,  4,  5,  6,  7,  8],
       [ 0,  0,  0,  0,  0,  0,  2,  3,  4,  5,  6,  7,  8,  9],
       [ 0,  0,  0,  0,  0,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 0,  0,  0,  0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 0,  0,  0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
       [ 0,  0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
       [ 0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]],
      dtype=int32)

In [19]:
X = padded[:, :-1]
y = padded[:, -1]

In [20]:
X

array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  3],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  3,  4],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  3,  4,  5],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  2,  3,  4,  5,  6],
       [ 0,  0,  0,  0,  0,  0,  0,  2,  3,  4,  5,  6,  7],
       [ 0,  0,  0,  0,  0,  0,  2,  3,  4,  5,  6,  7,  8],
       [ 0,  0,  0,  0,  0,  2,  3,  4,  5,  6,  7,  8,  9],
       [ 0,  0,  0,  0,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 0,  0,  0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 0,  0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
       [ 0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]], dtype=int32)

In [21]:
y

array([ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15], dtype=int32)

In [22]:

class PositionalEmbedding(layers.Layer):
    def __init__(self, vocab_size, d_model, max_len):
        super().__init__()
        self.token_emb = layers.Embedding(vocab_size, d_model)
        self.pos_emb = layers.Embedding(max_len, d_model)
    def call(self, x):
        seq_len = tf.shape(x)[1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        pos = self.pos_emb(positions)
        return self.token_emb(x) + pos[tf.newaxis, :, :]

def causal_mask(seq_len):
    i = tf.range(seq_len)[:, None]
    j = tf.range(seq_len)[None, :]
    return tf.cast(i >= j, dtype=tf.bool)

class TransformerDecoderBlock(layers.Layer):
    def __init__(self, d_model, num_heads, dff):
        super().__init__()
        self.ln1 = layers.LayerNormalization(epsilon=1e-6)
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ln2 = layers.LayerNormalization(epsilon=1e-6)
        self.ffn = tf.keras.Sequential([layers.Dense(dff, activation='relu'),
                                        layers.Dense(d_model)])
    def call(self, x, mask=None):
        h = self.ln1(x)
        attn = self.mha(h, h, attention_mask=mask)
        x = x + attn
        h2 = self.ln2(x)
        x = x + self.ffn(h2)
        return x


In [24]:

d_model=64; num_heads=2; dff=128; num_layers=2
seq_len = X.shape[1]
inputs = layers.Input(shape=(seq_len,), dtype=tf.int32)
x = PositionalEmbedding(vocab_size, d_model, max_len)(inputs)
mask = causal_mask(seq_len)[tf.newaxis, ...]
for _ in range(num_layers):
    x = TransformerDecoderBlock(d_model,num_heads,dff)(x,mask=mask)
x = layers.LayerNormalization(epsilon=1e-6)(x)
x_last = x[:, -1, :]
outputs = layers.Dense(vocab_size, activation='softmax')(x_last)
model = Model(inputs, outputs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()


In [25]:

model.fit(X, y, epochs=10, verbose=1)


Epoch 1/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 9s/step - accuracy: 0.0769 - loss: 3.2079
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.0769 - loss: 2.6853
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 56ms/step - accuracy: 0.1538 - loss: 2.2201
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step - accuracy: 0.3077 - loss: 1.9785
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.5385 - loss: 1.7008
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 60ms/step - accuracy: 0.5385 - loss: 1.4928
Epoch 7/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.6154 - loss: 1.3020
Epoch 8/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step - accuracy: 0.8462 - loss: 1.1096
Epoch 9/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s

<keras.src.callbacks.history.History at 0x7b5c60a135c0>

In [26]:

def generate_next_words(seed, num_words=5):
    text = seed.lower()
    for _ in range(num_words):
        seq = tokenizer.texts_to_sequences([text])[0]
        seq = seq[-seq_len:]
        seq = pad_sequences([seq], maxlen=seq_len, padding='pre')
        probs = model.predict(seq, verbose=0)[0]
        next_id = int(np.argmax(probs))
        next_word = id_to_word.get(next_id, "<unk>")
        text += " " + next_word
    return text

print(generate_next_words("alice was", num_words=5))


alice was beginning was beginning get very
