# Next Word Prediction with RNN, LSTM, GRU, Transformer, and BERT
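This notebook demonstrates how different sequence models predict the **next word** in text. The RNN, LSTM, GRU, and Transformer are trained from scratch on a tiny five-sentence toy corpus, so their predictions are illustrative rather than a benchmark. BERT uses pretrained weights and fills in a masked word instead of continuing a prefix: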
- RNN
- LSTM
- GRU
- Transformer
- BERT (masked word prediction)


In [1]:
!pip install tensorflow transformers datasets --quiet

In [2]:

import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, SimpleRNN, LSTM, GRU, Dense, Input, Dropout
from tensorflow.keras import Model
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization

import numpy as np
from transformers import BertTokenizer, TFBertForMaskedLM


## Data Preparation

In [3]:

# Sample small corpus for quick demo
corpus = [
    "the sun rises in the east",
    "the moon shines at night",
    "the stars twinkle in the sky",
    "the earth revolves around the sun",
    "the wind blows softly"
]

# Seed Python, NumPy and TensorFlow once so that the model cells below give
# reproducible weights/predictions on a fresh "Restart & Run All"
# (GPU kernels may still add small nondeterminism).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Fit a word-level vocabulary; id 0 is reserved for padding, hence the +1.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)
total_words = len(tokenizer.word_index) + 1

# Create input sequences: every prefix of length >= 2 of each sentence.
# The last token of a prefix is the target; the tokens before it are the context.
input_sequences = []
for line in corpus:
    token_list = tokenizer.texts_to_sequences([line])[0]
    for i in range(1, len(token_list)):
        n_gram_sequence = token_list[:i + 1]
        input_sequences.append(n_gram_sequence)

# Left-pad so the target always sits in the final column.
max_seq_len = max(len(seq) for seq in input_sequences)
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_seq_len, padding="pre"))

X, y = input_sequences[:, :-1], input_sequences[:, -1]
y = tf.keras.utils.to_categorical(y, num_classes=total_words)

# Sanity checks: one context row and one one-hot target per n-gram.
assert X.shape == (len(input_sequences), max_seq_len - 1)
assert y.shape == (len(input_sequences), total_words)

print("Vocabulary size:", total_words)
print("Max sequence length:", max_seq_len)


Vocabulary size: 19
Max sequence length: 6


In [4]:

# Inverse vocabulary: token id -> word.
reverse_word_index = dict((index, word) for word, index in tokenizer.word_index.items())

def predict_next_word(model, seed_text, max_len):
    """Return the model's most likely next word after ``seed_text``.

    Args:
        model: trained Keras model whose ``predict`` returns one probability
            row (over the vocabulary) per input sequence.
        seed_text: context string, tokenized with the global ``tokenizer``;
            words outside the fitted vocabulary are skipped because the
            tokenizer has no ``oov_token``.
        max_len: padded length of a full training n-gram (context + target),
            so the model input is ``max_len - 1`` ids wide.

    Returns:
        The predicted word, or ``"?"`` when the winning id has no word
        (e.g. the padding id 0).
    """
    seed_ids = tokenizer.texts_to_sequences([seed_text])[0]
    # Pad/truncate on the left, exactly like the training data, so the most
    # recent words line up with the model's fixed input width.
    model_input = pad_sequences([seed_ids], maxlen=max_len - 1, padding="pre")
    probabilities = model.predict(model_input, verbose=0)
    best_id = np.argmax(probabilities[0])
    return reverse_word_index.get(best_id, "?")


## RNN Model

In [5]:

# Vanilla RNN over learned word embeddings.
# An explicit Input layer replaces Embedding's `input_length` argument, which
# Keras 3 deprecates; this form works on both Keras 2 and Keras 3.
rnn_model = Sequential([
    Input(shape=(max_seq_len - 1,)),
    Embedding(total_words, 64),
    SimpleRNN(32),  # NOTE: 32 units here, while the LSTM/GRU cells use 64
    Dense(total_words, activation='softmax')
])
rnn_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
rnn_model.fit(X, y, epochs=200, verbose=0)

print("Prediction (RNN):", predict_next_word(rnn_model, "the sun rises", max_seq_len))




Prediction (RNN): in


## LSTM Model

In [6]:

# LSTM over learned word embeddings.
# An explicit Input layer replaces Embedding's `input_length` argument, which
# Keras 3 deprecates; this form works on both Keras 2 and Keras 3.
lstm_model = Sequential([
    Input(shape=(max_seq_len - 1,)),
    Embedding(total_words, 64),
    LSTM(64),
    Dense(total_words, activation='softmax')
])
lstm_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
lstm_model.fit(X, y, epochs=200, verbose=0)

print("Prediction (LSTM):", predict_next_word(lstm_model, "the sun rises", max_seq_len))


Prediction (LSTM): in


## GRU Model

In [7]:

# GRU over learned word embeddings.
# An explicit Input layer replaces Embedding's `input_length` argument, which
# Keras 3 deprecates; this form works on both Keras 2 and Keras 3.
gru_model = Sequential([
    Input(shape=(max_seq_len - 1,)),
    Embedding(total_words, 64),
    GRU(64),
    Dense(total_words, activation='softmax')
])
gru_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
gru_model.fit(X, y, epochs=200, verbose=0)

print("Prediction (GRU):", predict_next_word(gru_model, "the sun rises", max_seq_len))


Prediction (GRU): in


## Transformer Model (Minimal Encoder)

In [11]:
class TransformerBlock(tf.keras.layers.Layer):
    """One post-norm Transformer encoder block (self-attention + feed-forward).

    Args:
        embed_dim: width of the incoming token vectors and of the block output.
        num_heads: number of attention heads.
        ff_dim: hidden width of the position-wise feed-forward network.
        rate: dropout rate applied to the attention and feed-forward outputs.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = Sequential([Dense(ff_dim, activation="relu"), Dense(embed_dim)])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        # `training=None` lets Keras supply the flag: True inside fit(),
        # False inside predict(), so dropout is active only while training.
        attn_output = self.att(inputs, inputs)  # self-attention: query = key/value
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)  # residual + norm
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)  # residual + norm


class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Self-attention followed by mean pooling is order-invariant on its own; the
    position term is what lets the model use word order.

    Args:
        maxlen: number of positions (the model's input length).
        vocab_size: number of token ids, including the padding id 0.
        embed_dim: embedding width.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_emb = Embedding(vocab_size, embed_dim)
        self.pos_emb = Embedding(maxlen, embed_dim)

    def call(self, x):
        positions = tf.range(start=0, limit=tf.shape(x)[-1], delta=1)
        # (seq, dim) position vectors broadcast over the batch axis.
        return self.token_emb(x) + self.pos_emb(positions)


class MeanPooling(tf.keras.layers.Layer):
    """Average over axis 1 (the sequence axis of the (batch, seq, dim) tensors used here)."""

    def call(self, inputs):
        return tf.reduce_mean(inputs, axis=1)


inputs = Input(shape=(max_seq_len - 1,))
embedding_layer = TokenAndPositionEmbedding(max_seq_len - 1, total_words, 64)(inputs)
transformer_block = TransformerBlock(64, 2, 64)
# No hard-coded `training=` here: fixing it to False at graph-construction time
# can pin dropout to inference mode even during fit() (it does in tf.keras 2.x).
x = transformer_block(embedding_layer)
x = MeanPooling()(x)  # NOTE: padded positions are averaged in too (no masking)
outputs = Dense(total_words, activation="softmax")(x)

transformer_model = Model(inputs=inputs, outputs=outputs)
transformer_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
transformer_model.fit(X, y, epochs=200, verbose=0)

print("Prediction (Transformer):", predict_next_word(transformer_model, "the sun rises", max_seq_len))

Prediction (Transformer): in


## BERT (Masked Language Model)

In [13]:
# Pretrained BERT masked-LM head plus its matching tokenizer.
# `from_pt=True` loads the PyTorch checkpoint (pytorch_model.bin) and converts it.
# NOTE(review): this presumably requires torch in the environment.
bert_model = TFBertForMaskedLM.from_pretrained("bert-base-uncased", from_pt=True)
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# BERT is bidirectional: rather than continuing a prefix, it fills in [MASK].
# The bert_* names keep the Transformer cell's `inputs`/`outputs` intact.
text = "the sun rises in the [MASK]"
bert_inputs = bert_tokenizer(text, return_tensors="tf")
bert_outputs = bert_model(**bert_inputs)
predictions = bert_outputs.logits  # vocabulary logits at every token position
# tf.where yields [batch, position] index pairs; take the position of the first [MASK].
mask_token_index = tf.where(bert_inputs["input_ids"] == bert_tokenizer.mask_token_id)[0, 1]
predicted_id = tf.argmax(predictions[0, mask_token_index]).numpy()
predicted_token = bert_tokenizer.decode([predicted_id])
print("Prediction (BERT):", predicted_token)

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

All PyTorch model weights were used when initializing TFBertForMaskedLM.

All the weights of TFBertForMaskedLM were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMaskedLM for predictions without further training.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.


Prediction (BERT): sky
