# Predicting the Next Word: Deep dive into RNN

Data preparation

In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Sample token list
rnn_tokens = ['the', 'sun', 'rises', 'in', 'the', 'east', 'the', 'sun', 'sets', 'in', 'the', 'west']
word2idx = {word: i for i, word in enumerate(set(rnn_tokens))}
idx2word = {i: word for word, i in word2idx.items()}

# Create training sequences
sequences = [(word2idx[rnn_tokens[i]], word2idx[rnn_tokens[i+1]]) for i in range(len(rnn_tokens) - 1)]

class WordDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        x, y = self.pairs[idx]
        return torch.tensor(x), torch.tensor(y)

dataset = WordDataset(sequences)
loader = DataLoader(dataset, batch_size=2, shuffle=True)


Implementing RNN

In [2]:
class SimpleRNN:
    def __init__(self, vocab_size, embedding_dim=10, hidden_dim=16):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # Embedding matrix (vocab_size x embedding_dim)
        self.embed = torch.randn(vocab_size, embedding_dim, requires_grad=True)

        # RNN weights
        self.Wxh = torch.randn(embedding_dim, hidden_dim, requires_grad=True)
        self.Whh = torch.randn(hidden_dim, hidden_dim, requires_grad=True)
        self.bh = torch.zeros(hidden_dim, requires_grad=True)

        # Output layer
        self.Why = torch.randn(hidden_dim, vocab_size, requires_grad=True)
        self.by = torch.zeros(vocab_size, requires_grad=True)

        # Track all parameters
        self.params = [self.embed, self.Wxh, self.Whh, self.bh, self.Why, self.by]

    def forward(self, x_idx, h_prev):
        # Get embedding for input word
        x_embed = self.embed[x_idx]  # (batch, embedding_dim)

        # RNN cell: h_t = tanh(Wxh * x + Whh * h + b)
        h_new = torch.tanh(x_embed @ self.Wxh + h_prev @ self.Whh + self.bh)

        # Output logits
        logits = h_new @ self.Why + self.by

        return logits, h_new

Training

In [3]:
# Initialize model
vocab_size = len(word2idx)
model = SimpleRNN(vocab_size)

# Optimizer
optimizer = torch.optim.SGD(model.params, lr=0.01)

# Loss function
loss_fn = torch.nn.CrossEntropyLoss()

# One training loop
for epoch in range(5):
    total_loss = 0
    for x_batch, y_batch in loader:
        h = torch.zeros(x_batch.size(0), model.hidden_dim)  # Initial hidden state

        logits, h_new = model.forward(x_batch, h)  # One step RNN

        loss = loss_fn(logits, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

Epoch 1, Loss: 39.5475
Epoch 2, Loss: 38.8811
Epoch 3, Loss: 34.5613
Epoch 4, Loss: 30.8859
Epoch 5, Loss: 27.9686


In [37]:
model.vocab_size

7

In [39]:
model.hidden_dim

16

In [40]:
model.Whh

tensor([[ 0.8674,  1.3481,  0.5546, -0.6112,  1.0983,  0.4165,  1.5105,  0.8281,
          0.9538, -0.5975, -1.6062,  0.9321, -1.7609, -0.3320,  0.0941, -1.4093],
        [ 0.5549, -1.2707,  0.7994, -0.1094, -0.8315, -0.8924, -1.4619, -0.9835,
          0.8612, -0.6127, -0.6205,  2.7250, -1.0205,  2.9810, -1.0786, -0.2745],
        [ 1.4266, -0.9023, -0.6765, -0.7307,  0.1673, -1.2404, -0.9419, -1.2360,
         -1.5414,  2.1274,  1.9435, -1.6444, -0.6260, -0.8259,  2.3127,  0.9850],
        [ 0.9234, -0.4823, -0.9553, -1.3625, -1.5849, -2.3466, -0.3243,  0.3011,
         -0.2848, -1.4307, -0.4501, -0.2041, -0.1597, -0.8758,  0.0830,  0.1121],
        [-0.7074, -0.5669, -0.3143, -0.1969, -0.8108,  0.7616,  0.0605,  1.0848,
          2.1813, -1.6685,  0.4222,  1.0246,  1.0018,  0.8114,  1.7322,  0.7136],
        [ 0.1161, -1.0974,  0.9627,  0.0634,  0.4236,  1.9618, -0.0185, -0.7070,
         -1.4340,  1.3353, -2.8229,  0.5834,  0.0860, -0.1003, -0.7557, -0.7512],
        [ 0.2721, -0.2

generate text with step by step execution

In [4]:
def generate_text(start_word, model, word2idx, idx2word, max_len=10):
    model.eval = lambda: None  # dummy to simulate eval mode
    idx = word2idx[start_word]
    words = [start_word]
    h = torch.zeros(1, model.hidden_dim)

    for _ in range(max_len):
        logits, h = model.forward(torch.tensor([idx]), h)
        idx = torch.argmax(logits, dim=1).item()
        words.append(idx2word[idx])

    return ' '.join(words)

print(generate_text("the", model, word2idx, idx2word, max_len=10))

the rises west sets in west the rises west sets rises


## Step by step execution for deeper understanding

In [5]:
import torch
import torch.nn.functional as F

In [6]:
tokens = ['the', 'sun', 'rises', 'in', 'the', 'east']
word2idx = {word: i for i, word in enumerate(set(tokens))}
idx2word = {i: word for word, i in word2idx.items()}
vocab_size = len(word2idx)

In [8]:
print(word2idx)
print(idx2word)
print(vocab_size)

{'in': 0, 'the': 1, 'sun': 2, 'rises': 3, 'east': 4}
{0: 'in', 1: 'the', 2: 'sun', 3: 'rises', 4: 'east'}
5


In [9]:
sequences = [(word2idx[tokens[i]], word2idx[tokens[i + 1]]) for i in range(len(tokens) - 1)]
print("Training pairs (indices):", sequences)

Training pairs (indices): [(1, 2), (2, 3), (3, 0), (0, 1), (1, 4)]


In [10]:
x_batch = torch.tensor([sequences[0][0], sequences[1][0]])  # inputs: 'the', 'sun'
y_batch = torch.tensor([sequences[0][1], sequences[1][1]])  # targets: 'sun', 'rises'
print("\nx_batch (word indices):", x_batch)
print("y_batch (target indices):", y_batch)


x_batch (word indices): tensor([1, 2])
y_batch (target indices): tensor([2, 3])


In [11]:
embedding_dim = 5
hidden_dim = 4

In [12]:
embedding = torch.randn(vocab_size, embedding_dim)

In [13]:
embedding

tensor([[ 1.6945, -0.7766, -0.9094, -0.7208, -0.3058],
        [ 0.4009,  1.3930,  1.3562, -0.4178,  0.1735],
        [-0.2296,  1.0821,  0.8929,  1.3625, -0.9358],
        [ 1.2505, -0.3129,  0.5570,  0.9088, -0.1406],
        [ 0.7679, -0.7273, -0.5810,  0.8826, -0.7667]])

In [14]:
embedding.shape

torch.Size([5, 5])

In [15]:
# RNN weights
Wxh = torch.randn(embedding_dim, hidden_dim)
Whh = torch.randn(hidden_dim, hidden_dim)
bh = torch.zeros(hidden_dim)

In [16]:
Wxh

tensor([[ 0.4577, -0.8893,  0.4595, -0.5254],
        [-0.5251, -1.9444,  0.6279, -1.8165],
        [-0.0880, -0.3282,  0.5785, -2.1097],
        [ 0.3468, -0.2600,  0.6679,  0.9620],
        [-1.4054,  0.7063,  1.2270,  2.1211]])

In [17]:
Whh

tensor([[ 0.2006,  0.6899, -0.8724,  0.6135],
        [-1.2007,  0.1932,  0.3728,  0.0152],
        [ 0.7185,  0.0255,  1.3050, -0.7606],
        [-0.4692,  1.4785,  0.7028,  0.9181]])

In [18]:
bh

tensor([0., 0., 0., 0.])

In [19]:
# Output layer weights
Why = torch.randn(hidden_dim, vocab_size)
by = torch.zeros(vocab_size)

In [20]:
Why

tensor([[ 1.7660, -0.4987,  1.3216,  0.0227, -0.0946],
        [-0.2630, -0.2602, -0.5839, -0.3715,  0.1578],
        [ 1.9842, -0.6254,  0.9436,  0.3605,  0.7815],
        [ 2.0140,  0.0899, -0.9838, -1.4917, -0.6184]])

In [21]:
by

tensor([0., 0., 0., 0., 0.])

In [22]:
x_embed = embedding[x_batch]  # shape: (batch_size, embedding_dim)
print("\nEmbeddings:\n", x_embed)
print("Shape of embeddings:", x_embed.shape)


Embeddings:
 tensor([[ 0.4009,  1.3930,  1.3562, -0.4178,  0.1735],
        [-0.2296,  1.0821,  0.8929,  1.3625, -0.9358]])
Shape of embeddings: torch.Size([2, 5])


In [24]:
batch_size = x_batch.shape[0]
print("\nInitial batch size:\n", batch_size)
h_prev = torch.zeros(batch_size, hidden_dim)
print("\nInitial hidden state:\n", h_prev)


Initial batch size:
 2

Initial hidden state:
 tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]])


In [25]:
h_t = torch.tanh(x_embed @ Wxh + h_prev @ Whh + bh)
print("\nUpdated hidden state:\n", h_t)
print("Shape of hidden state:", h_t.shape)


Updated hidden state:
 tensor([[-0.7841, -0.9972,  0.9444, -1.0000],
        [ 0.7763, -0.9967,  0.6923, -0.9997]])
Shape of hidden state: torch.Size([2, 4])


In [26]:
logits = h_t @ Why + by
print("\nLogits:\n", logits)
print("Shape of logits:", logits.shape)


Logits:
 tensor([[-1.2626, -0.0301,  1.4209,  2.1846,  1.2733],
        [ 0.9934, -0.6506,  3.2447,  2.1287,  0.9286]])
Shape of logits: torch.Size([2, 5])


In [27]:
loss = F.cross_entropy(logits, y_batch)
print("\nCross entropy loss:", loss.item())


Cross entropy loss: 1.5084009170532227


In [28]:
predicted_indices = torch.argmax(logits, dim=1)
predicted_words = [idx2word[i.item()] for i in predicted_indices]
print("\nPredicted next word indices:", predicted_indices)
print("Predicted words:", predicted_words)


Predicted next word indices: tensor([3, 2])
Predicted words: ['rises', 'sun']


In [31]:
print("Wxh (input to hidden weights):", Wxh.shape)
print("Whh (hidden to hidden weights):", Whh.shape)
print("bh (hidden bias):", bh.shape)
print("Why (hidden to output weights):", Why.shape)
print("by (output bias):", by.shape)

Wxh (input to hidden weights): torch.Size([5, 4])
Whh (hidden to hidden weights): torch.Size([4, 4])
bh (hidden bias): torch.Size([4])
Why (hidden to output weights): torch.Size([4, 5])
by (output bias): torch.Size([5])


## issues with RNNs

- In an RNN, gradients are propagated backward through time during training.
- This involves **repeated multiplications** of gradient values at each time step.

- If the recurrent weights are **small (<1)**, these repeated multiplications cause gradients to:
  → **Shrink exponentially** with each time step.
  → Eventually become **very close to zero** — this is known as the **vanishing gradient problem**.

- What this means:
  → **Earlier time steps receive almost no learning signal**.
  → The model struggles to **retain and learn long-term dependencies** in sequences.
  → RNNs end up relying mostly on **recent inputs**, forgetting information from earlier in the sequence.

## LSTM
How Do LSTMs Help?
LSTMs have cell states and gating mechanisms (forget, input, output) that control the flow of information.

This allows the network to retain important information and ignore unimportant details.

As a result, gradients don’t vanish/explode as easily, making LSTMs better at long-term memory.