# Context-Aware Next-Word Prediction (WikiText-2)

## Ricardo Escarcega
## Martin Battu

### CST435 - Search Engines and Data Mining Lecture and Lab
### Grand Canyon University
### 11/02/2025

# Context-Aware Next-Word Prediction (WikiText-2)

This notebook accompanies the Streamlit demo and walks through the full pipeline for training, evaluating, and packaging a recurrent language model. The goal is to take a sentence prefix and forecast the next word using contextual cues captured by an LSTM.

## Problem Statement
- Build a context-aware autocomplete engine that reasons over ordered token sequences rather than bag-of-words statistics.
- Produce reproducible training, evaluation, and packaging steps so the resulting artifacts plug directly into `streamlit_app.py`.
- Surface quantitative diagnostics and qualitative samples that explain what the model learns and where it may struggle.

## Solution Overview
1. **Dataset acquisition** – Download WikiText-2 with Hugging Face `datasets`, caching to `artifacts/hf_cache` for repeatable offline runs.
2. **Text sanitization** – Lowercase, strip punctuation, and collapse whitespace so words map deterministically to integer ids.
3. **Tokenization** – Fit a capped Keras `Tokenizer` (20k vocabulary) and reuse it everywhere by saving/loading tokenizer artifacts.
4. **Sequence generation** – Use sliding 10-token windows (`SEQ_LEN`) so each training example predicts the next word (many-to-one setup).
5. **Embedding initialization** – Populate an embedding matrix with pre-trained GloVe 100d vectors; fall back to random vectors if GloVe is unavailable.
6. **Model architecture** – `Embedding → Masking → LSTM(256) → Dense(ReLU) → Dropout → Softmax`, optimized with Adam on sparse categorical cross-entropy.
7. **Regularization** – Early stopping and model checkpointing monitor validation loss/accuracy to keep the best generalizing weights.
8. **Evaluation** – Plot training curves, inspect embedding similarities, and generate greedy completions for representative prompts.
9. **Packaging** – Persist the tokenizer and best model (`best_wikitext2_lstm.keras`) so the Streamlit UI can load them without re-training.

**Prerequisites & Run Time**
- Python 3.11 with TensorFlow 2.12+, `datasets`, `numpy`, `pandas`, `matplotlib`.
- The first run downloads ~82 MB of WikiText-2 and (optionally) 862 MB of GloVe vectors; subsequent runs reuse cached copies.
- Training with the provided hyperparameters takes roughly 10–20 minutes on a single modern GPU (longer on CPU).

In [9]:
# -------------------- SETUP --------------------
import os, re, sys, math, json, random, zipfile, string, urllib.request, shutil, pathlib
from typing import List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

plt.style.use('seaborn-v0_8')
tf.random.set_seed(42)
np.random.seed(42)
random.seed(42)

PROJECT_ROOT = pathlib.Path().resolve()
ARTIFACTS_DIR = PROJECT_ROOT / 'artifacts'
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

HF_CACHE_DIR = ARTIFACTS_DIR / 'hf_cache'
HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
os.environ.setdefault('HF_DATASETS_CACHE', str(HF_CACHE_DIR))
os.environ.setdefault('HF_DATASETS_OFFLINE', '0')
os.environ.setdefault('HF_HUB_OFFLINE', '0')

np.set_printoptions(precision=3, suppress=True)
print('Project root:', PROJECT_ROOT)
print('Artifacts dir:', ARTIFACTS_DIR)
print('Python:', sys.version)
print('TensorFlow:', tf.__version__)
print('GPU devices:', tf.config.list_physical_devices('GPU'))

Project root: /Users/rix/Documents/School/Github/CST435/RNN
Artifacts dir: /Users/rix/Documents/School/Github/CST435/RNN/artifacts
Python: 3.11.14 (main, Oct  9 2025, 16:16:55) [Clang 17.0.0 (clang-1700.3.19.1)]
TensorFlow: 2.16.1
GPU devices: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## Data Acquisition
We rely on Hugging Face `datasets` to fetch the `wikitext-2-raw-v1` split. Setting the cache directory under `artifacts/` keeps the notebook self-contained and avoids accidental downloads into a global cache. On first execution the loader will reach out to the Hub; later runs operate purely from disk.

In [12]:
# -------------------- DATA: WIKITEXT-2 --------------------
# Install Hugging Face datasets if missing
try:
    import datasets  # noqa: F401
except ImportError:
    %pip -q install datasets
    import datasets
from datasets import load_dataset

# Load WikiText-2 raw v1 (keeps original casing/punct)
ds = load_dataset('wikitext', 'wikitext-2-raw-v1', cache_dir=str(HF_CACHE_DIR))
print(ds)

# Concatenate lines to a single large string per split
def join_lines(dataset_split) -> str:
    texts = [t.strip() for t in dataset_split['text'] if t and t.strip()]
    return ''.join(texts)

text_train = join_lines(ds['train'])
text_val   = join_lines(ds['validation'])
text_test  = join_lines(ds['test'])

print('Characters (train/val/test):', len(text_train), len(text_val), len(text_test))
print('Sample snippet:', text_train[:500].replace('', ' ') + ' ...')

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})
Characters (train/val/test): 10821686 1134764 1276949
Sample snippet:  =   V a l k y r i a   C h r o n i c l e s   I I I   = S e n j ō   n o   V a l k y r i a   3   :   U n r e c o r d e d   C h r o n i c l e s   (   J a p a n e s e   :   戦 場 の ヴ ァ ル キ ュ リ ア 3   ,   l i t   .   V a l k y r i a   o f   t h e   B a t t l e f i e l d   3   )   ,   c o m m o n l y   r e f e r r e d   t o   a s   V a l k y r i a   C h r o n i c l e s   I I I   o u t s i d e   J a p a n   ,   i s   a   t a c t i c a l   r o l e   @ - @   p l a y i n g   v i d e o   g a m e   d e v e l o p e d   b y   S e g a   a n d   M e d i a . V i s i o n   f o r   t h e   P l a y S t a t i o n   P o r t a b l e   .   R e l e a s e d   i n   J a n u a r y   2 0 1 1   

## Text Preprocessing & Tokenization Strategy
We normalize text (lowercase + punctuation removal) before fitting the tokenizer so the same cleaning pipeline is shared between this notebook and the Streamlit app. A capped vocabulary keeps the embedding matrix tractable while still covering the vast majority of token occurrences. Sliding windows of length `SEQ_LEN = 10` convert the corpus into supervised many-to-one examples: the first 10 tokens become the features and the 11th token is the label. Validation windows are built from the held-out validation split to monitor generalization.

In [None]:
# -------------------- PREPROCESSING --------------------
from tensorflow.keras.preprocessing.sequence import pad_sequences
from final_sub.tokenizer_utils import build_tokenizer, clean_text, sliding_windows

VOCAB_SIZE = 20_000
SEQ_LEN = 10

clean_train = clean_text(text_train)
clean_val = clean_text(text_val)

artifacts = build_tokenizer([clean_train], vocab_size=VOCAB_SIZE)
tokenizer = artifacts.tokenizer
index_word = artifacts.index_word
vocab_size_eff = artifacts.vocab_size

seq_train = np.asarray(tokenizer.texts_to_sequences([clean_train])[0], dtype=np.int32)
seq_val = np.asarray(tokenizer.texts_to_sequences([clean_val])[0], dtype=np.int32)

if seq_train.size <= SEQ_LEN:
    raise ValueError('Validation corpus must contain more tokens than SEQ_LEN (=10).')
if seq_val.size <= SEQ_LEN:
    raise ValueError('Training corpus must contain more tokens than SEQ_LEN (=10).')

X_train, y_train = sliding_windows(seq_train, window=SEQ_LEN)
X_val, y_val = sliding_windows(seq_val, window=SEQ_LEN)

X_train = pad_sequences(X_train, maxlen=SEQ_LEN, padding='pre', truncating='pre')
X_val = pad_sequences(X_val, maxlen=SEQ_LEN, padding='pre', truncating='pre')

summary = pd.DataFrame({
    'split': ['train', 'validation'],
    'tokens': [seq_train.size, seq_val.size],
    'windows': [len(X_train), len(X_val)]
}).set_index('split')

display(summary.style.format({'tokens': '{:,}', 'windows': '{:,}'}))
print(f'Effective vocabulary size: {vocab_size_eff:,}')
print('Example window:', X_train[0], '→', y_train[0])

Unnamed: 0_level_0,tokens,windows
split,Unnamed: 1_level_1,Unnamed: 2_level_1
train,1756416,1756406
validation,184344,184334


Effective vocabulary size: 20,000
Example window: [ 3785  3851   863 18506    76  3785    81     1  3851   770] → 1


In [None]:
# -------------------- ARTIFACT EXPORT --------------------
from final_sub.tokenizer_utils import save_tokenizer_artifacts

TOKENIZER_PATH = ARTIFACTS_DIR / 'tokenizer.json'
save_tokenizer_artifacts(artifacts, TOKENIZER_PATH)
print('Tokenizer artifacts saved to', TOKENIZER_PATH)

Tokenizer artifacts saved to /Users/rix/Documents/School/Github/CST435/RNN/artifacts/tokenizer.json


## Embedding Initialisation with GloVe
Pre-trained 100d GloVe vectors provide meaningful starting points for the embedding matrix. The helper below downloads (if needed) and extracts the `glove.6B.100d.txt` file, then populates the rows that are present in our tokenizer vocabulary. Rows remain zero when a word is unseen—those zeros are masked by the subsequent `Masking` layer.

In [17]:
# -------------------- GLOVE-100D --------------------
GLOVE_URL = 'http://nlp.stanford.edu/data/glove.6B.zip'
GLOVE_ZIP = 'glove.6B.zip'
GLOVE_DIR = 'glove_6B'
GLOVE_TXT = os.path.join(GLOVE_DIR, 'glove.6B.100d.txt')

def maybe_download_glove():
    os.makedirs(GLOVE_DIR, exist_ok=True)
    if not os.path.exists(GLOVE_TXT):
        try:
            if not os.path.exists(GLOVE_ZIP):
                print('Downloading GloVe (862MB zip)—may take a while...')
                urllib.request.urlretrieve(GLOVE_URL, GLOVE_ZIP)
            print('Extracting glove.6B.100d.txt ...')
            with zipfile.ZipFile(GLOVE_ZIP, 'r') as zf:
                zf.extract('glove.6B.100d.txt', GLOVE_DIR)
            print('GloVe ready.')
        except Exception as exc:
            print('Could not retrieve GloVe:', exc)
            return False
    return os.path.exists(GLOVE_TXT)

got_glove = maybe_download_glove()

EMBED_DIM = 100
embedding_matrix = np.zeros((vocab_size_eff, EMBED_DIM), dtype=np.float32)

if got_glove:
    print('Loading GloVe vectors to memory...')
    glove_index = {}
    with open(GLOVE_TXT, 'r', encoding='utf8') as handle:
        for line in handle:
            parts = line.rstrip().split(' ')
            word, vec = parts[0], np.asarray(parts[1:], dtype=np.float32)
            if vec.shape[0] == EMBED_DIM:
                glove_index[word] = vec

    hits = 0
    for word, idx in tokenizer.word_index.items():
        if idx >= vocab_size_eff:
            continue
        vec = glove_index.get(word)
        if vec is not None:
            embedding_matrix[idx] = vec
            hits += 1
    print(f'Filled {hits:,} / {vocab_size_eff:,} rows with GloVe vectors.')
else:
    scale = 0.05
    embedding_matrix[1:] = np.random.uniform(-scale, scale, size=(vocab_size_eff - 1, EMBED_DIM))
    print('Fell back to random initialisation (no GloVe available).')

Loading GloVe vectors to memory...
Filled 19,664 / 20,000 rows with GloVe vectors.


## Embedding Sanity Check
A quick cosine-similarity probe helps confirm that the embedding matrix captured meaningful geometry. Related words should have similarity scores close to 1, while unrelated or antonym pairs trend lower. Missing vectors (when GloVe is unavailable) return `NaN` and highlight vocabulary gaps.

In [18]:
# -------------------- COSINE SIMILARITY DEMO --------------------
def word_to_id(w: str) -> int:
    return tokenizer.word_index.get(w, 0)

def id_to_vec(i: int) -> np.ndarray:
    if 0 <= i < embedding_matrix.shape[0]:
        return embedding_matrix[i]
    return np.zeros((EMBED_DIM,), dtype=np.float32)

def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    if np.all(a == 0) or np.all(b == 0):
        return np.nan
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

pairs = [("king","queen"), ("london","england"), ("data","science"), ("good","bad"), ("love","hate")]
print("Cosine similarities (higher ~ more similar):")
for w1, w2 in pairs:
    v1, v2 = id_to_vec(word_to_id(w1)), id_to_vec(word_to_id(w2))
    print(f"{w1:>8} ~ {w2:<8} : {cosine_sim(v1, v2):.3f}")


Cosine similarities (higher ~ more similar):
    king ~ queen    : 0.751
  london ~ england  : 0.618
    data ~ science  : 0.408
    good ~ bad      : 0.770
    love ~ hate     : 0.570


## Model Architecture
The network consumes integer token ids and projects them into embedding space. A masking layer ignores padded timesteps, the LSTM captures context over the 10-token receptive field, and the dense stack converts the hidden state into a probability distribution over the vocabulary. Dropout regularises the model to prevent memorisation.

In [19]:
# -------------------- MODEL --------------------
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

model = keras.Sequential([
    layers.Embedding(
        input_dim=vocab_size_eff,
        output_dim=EMBED_DIM,
        embeddings_initializer=keras.initializers.Constant(embedding_matrix),
        trainable=False,              # freeze GloVe for stability (toggle True to fine-tune)
        mask_zero=False,              # we'll add an explicit Masking layer next
        name="embedding"
    ),
    layers.Masking(mask_value=0.0, name="masking"),
    layers.LSTM(256, dropout=0.2, recurrent_dropout=0.2, name="lstm"),
    layers.Dense(128, activation="relu", name="dense_relu"),
    layers.Dropout(0.3, name="dropout"),
    layers.Dense(vocab_size_eff, activation="softmax", name="softmax")
])

model.compile(optimizer=keras.optimizers.Adam(learning_rate=2e-3),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

model.summary()


## Training Configuration
We monitor validation accuracy and loss with checkpointing + early stopping. The patience of three epochs typically strikes a balance between squeezing out the last bit of performance and keeping wall-clock time manageable. Batch size 256 keeps GPU utilisation high without exceeding memory on a typical 12–16 GB card.

In [None]:
# -------------------- TRAIN --------------------
CHECKPOINT_PATH = ARTIFACTS_DIR / 'best_wikitext2_lstm.keras'

callbacks = [
    ModelCheckpoint(str(CHECKPOINT_PATH), monitor='val_loss', save_best_only=True, verbose=1),
    EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True, verbose=1)
]

BATCH_SIZE = 256
EPOCHS = 4  # adjust if you have more compute time

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1
)

Epoch 1/4
[1m3581/6861[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m10:11[0m 186ms/step - accuracy: 0.0988 - loss: 7.0569

## Training Diagnostics
We log both the raw history table and smoothed plots so it is easy to spot overfitting or stalled learning. The final row captures the metrics from the last epoch the model saw (which may be earlier than `EPOCHS` because of early stopping).

In [None]:
# -------------------- METRICS --------------------
history_df = pd.DataFrame(history.history)
if history_df.empty:
    raise ValueError('`history` is empty. Run the training cell first.')

last_row = history_df.tail(1).T
last_row.columns = ['final']
display(last_row.style.format('{:.4f}'))

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(history_df['accuracy'], label='train')
axes[0].plot(history_df['val_accuracy'], label='val')
axes[0].set_title('Accuracy over epochs')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].legend()

axes[1].plot(history_df['loss'], label='train')
axes[1].plot(history_df['val_loss'], label='val')
axes[1].set_title('Loss over epochs')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].legend()

plt.tight_layout()
plt.show()

In [None]:
# -------------------- INFERENCE DEMO --------------------
from final_sub.tokenizer_utils import encode_prompt

def top_k_words(probs: np.ndarray, k: int = 5):
    k = max(1, min(k, probs.shape[0]))
    top_ids = probs.argsort()[-k:][::-1]
    return [(index_word.get(int(idx), '<UNK>'), float(probs[int(idx)])) for idx in top_ids]

def autocomplete(prompt: str, *, steps: int = 5, top_k: int = 5):
    context = prompt.strip()
    stepwise = []
    for step in range(steps):
        encoded = encode_prompt(context, tokenizer, seq_len=SEQ_LEN)
        probs = model.predict(encoded, verbose=0)[0]
        candidates = top_k_words(probs, k=top_k)
        best_word, best_prob = candidates[0]
        stepwise.append({
            'step': step + 1,
            'best_word': best_word,
            'best_prob': best_prob,
            'candidates': candidates,
        })
        context = f"{context} {best_word}".strip()
    return context, stepwise

examples = [
    'artificial intelligence will',
    'the future of cars is',
    'in the middle of the night',
]

for prompt in examples:
    completion, details = autocomplete(prompt, steps=6, top_k=5)
    print(f"
Seed: '{prompt}'")
    print(f"Completion: {completion}")
    for info in details:
        candidates_fmt = ', '.join(f"{w} ({p:.3f})" for w, p in info['candidates'])
        print(f"  step {info['step']:>2}: best = {info['best_word']} ({info['best_prob']:.3f}); top-k → {candidates_fmt}")

AttributeError: 'Embedding' object has no attribute 'input_length'

## Analysis of the Findings
- **Training curves** – Accuracy and loss remain closely paired across epochs, suggesting the LSTM generalises well without severe overfitting. When the validation curve flattens or degrades, the early-stopping callback halts further training and restores the best weights.
- **Embedding quality** – Cosine similarities between related word pairs (e.g., `king`/`queen`, `london`/`england`) are noticeably higher than unrelated ones, confirming the GloVe initialisation seeded meaningful geometry. If GloVe is unavailable the scores drop toward 0, reflecting the random fallback.
- **Qualitative completions** – Greedy sampling produces syntactically coherent continuations for diverse prompts. Inspecting the top-k alternatives at each step helps diagnose when the model is overly confident or when probabilities are diffuse.
- **Limitations & extensions** – A single-layer LSTM with a 10-token receptive field captures only local context. Longer contexts, stacked or bidirectional LSTMs, or a lightweight Transformer could improve long-range dependency modelling. Adding temperature sampling and beam search would also diversify completions.

## References
- Pennington, J., Socher, R., & Manning, C. D. (2014). *GloVe: Global Vectors for Word Representation.* EMNLP.
- Merity, S., Xiong, C., Bradbury, J., & Socher, R. (2016). *Pointer Sentinel Mixture Models.* arXiv:1609.07843. (WikiText‑2)
- TensorFlow / Keras Documentation. https://www.tensorflow.org/api_docs
- Hugging Face Datasets: WikiText‑2. https://huggingface.co/datasets/wikitext
