<a href="https://colab.research.google.com/github/syedmahmoodiagents/transformers/blob/main/Attention_for_next_word.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
sentence = ["The", "cat", "sat", "on", "the", "mat"]
# Word -> id mapping over the lower-cased sentence words plus two extra candidate
# next words. The unique words are sorted before numbering because the iteration
# order of a set of strings changes between interpreter runs (hash randomisation),
# which would otherwise give each kernel restart a different vocabulary.
vocab = {word: i for i, word in enumerate(sorted({w.lower() for w in sentence} | {"is", "sleeping"}))}


In [None]:
# Display the word -> id mapping.
vocab

{'is': 0, 'sat': 1, 'the': 2, 'cat': 3, 'on': 4, 'sleeping': 5, 'mat': 6}

In [None]:
# Reverse lookup (id -> word), used to turn a predicted id back into text.
inv_vocab = dict(zip(vocab.values(), vocab.keys()))

In [None]:
# Encode the sentence as one batch of token ids: the nested list gives shape (1, seq_len).
indices = torch.tensor([[vocab[token.lower()] for token in sentence]])

In [None]:
class NextWordModel(nn.Module):
    """Single-layer self-attention model that scores candidates for the next word.

    Token ids are embedded, passed through one ``nn.MultiheadAttention`` layer
    used as self-attention, and the attention output at the last sequence
    position is projected to one logit per vocabulary entry. No attention mask
    and no positional encoding are applied.

    Args:
        vocab_size: Number of embedding rows and number of output logits.
        embed_dim: Width of the token embeddings and of the attention layer.
        num_heads: Number of attention heads.
    """

    def __init__(self, vocab_size, embed_dim=16, num_heads=1):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.fc = nn.Linear(in_features=embed_dim, out_features=vocab_size)

    def forward(self, x):
        """Compute next-word logits for a batch of token-id sequences.

        Args:
            x: Tensor of token ids, shape (batch, seq_len).

        Returns:
            Tuple ``(logits, attn_weights)``: ``logits`` has shape
            (batch, vocab_size); ``attn_weights`` is the second value returned
            by the attention layer.
        """
        token_vectors = self.embedding(x)  # (batch, seq_len, embed_dim)
        # Self-attention: the same embeddings act as query, key and value.
        context, attn_weights = self.attn(
            query=token_vectors, key=token_vectors, value=token_vectors
        )
        # Only the last position's context vector is used to score the next word.
        final_position = context[:, -1, :]
        return self.fc(final_position), attn_weights


In [None]:
# Fix the RNG seed before building the model so the random initial weights, and
# therefore the untrained prediction computed from them, are the same on every run.
torch.manual_seed(0)
model = NextWordModel(len(vocab))

In [None]:
# Display the encoded sentence (batch of token ids).
indices

tensor([[2, 3, 1, 4, 2, 6]])

In [None]:
# Inference only: no backward pass follows, so skip autograd bookkeeping.
with torch.no_grad():
    logits, attn_weights = model(indices)

In [None]:
# Normalise the logits over the vocabulary axis into probabilities.
probs = logits.softmax(dim=-1)

In [None]:
# Id of the most probable word, as a plain Python int.
pred_idx = probs.argmax(dim=-1).item()

In [None]:
# Display the predicted token id.
pred_idx

3

In [None]:
# Decode the predicted id back to its word.
inv_vocab[pred_idx]

'cat'

# Using a Training Loop

In [None]:
sentence = ["The", "cat", "sat", "on", "the", "mat"]

# Add possible next words to vocab.
# The unique words are sorted before numbering because the iteration order of a
# set of strings changes between interpreter runs (hash randomisation), which
# would otherwise give each kernel restart a different word -> id mapping.
vocab = {word: i for i, word in enumerate(sorted({w.lower() for w in sentence} | {"is", "sleeping"}))}
inv_vocab = {i: w for w, i in vocab.items()}

# let's predict "is" after the sentence
target_word = "is"

# Encode inputs and targets
indices = torch.tensor([vocab[word.lower()] for word in sentence]).unsqueeze(0)  # (1, seq_len)
target_idx = torch.tensor([vocab[target_word]])  # (1,)

In [None]:
class TinyNextWordModel(nn.Module):
    """Embedding -> self-attention -> linear head, used to predict the next word.

    Same architecture as ``NextWordModel``: one embedding table, one
    ``nn.MultiheadAttention`` layer applied as self-attention, and a linear
    layer producing one logit per vocabulary entry from the last position.

    Args:
        vocab_size: Number of embedding rows and number of output logits.
        embed_dim: Width of the token embeddings and of the attention layer.
        num_heads: Number of attention heads.
    """

    def __init__(self, vocab_size, embed_dim=16, num_heads=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim=embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads=num_heads, batch_first=True)
        self.fc = nn.Linear(embed_dim, out_features=vocab_size)

    def forward(self, x):
        """Return ``(logits, attn_weights)`` for token ids ``x`` of shape (batch, seq_len).

        ``logits`` has shape (batch, vocab_size) and is computed from the
        attention output at the last sequence position only.
        """
        hidden = self.embedding(x)  # (batch, seq_len, embed_dim)
        # Query, key and value are all the same embedded sequence (self-attention).
        attended, attn_weights = self.attn(hidden, hidden, hidden)
        last_token_state = attended[:, -1, :]  # predict next word from last token
        logits = self.fc(last_token_state)
        return logits, attn_weights



In [None]:
# Fix the RNG seed before building the model so the initial weights, and with
# them the training log, are reproducible and re-running this cell restarts
# training from the same initialisation.
torch.manual_seed(0)
model = TinyNextWordModel(len(vocab))
# Adam over all model parameters; cross-entropy between the logits and the target id.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Fit the model on the single (sentence -> target word) example.
epochs = 20
for epoch in range(epochs):
    # Clear gradients left over from the previous step before computing new ones.
    optimizer.zero_grad()

    logits, _ = model(indices)
    loss = loss_fn(logits, target_idx)
    loss.backward()
    optimizer.step()

    # Report the loss on every fifth epoch.
    if not (epoch + 1) % 5:
        print(f"Epoch {epoch+1}/{epochs} - Loss: {loss.item():.4f}")

Epoch 5/20 - Loss: 0.8972
Epoch 10/20 - Loss: 0.0079
Epoch 15/20 - Loss: 0.0000
Epoch 20/20 - Loss: 0.0000


In [None]:
# Inference only: training is finished, so skip autograd bookkeeping.
with torch.no_grad():
    logits, attn_weights = model(indices)

In [None]:
# Normalise the logits over the vocabulary axis into probabilities.
probs = logits.softmax(dim=-1)

In [None]:
# Id of the most probable word, as a plain Python int.
pred_idx = probs.argmax(dim=-1).item()

In [None]:
# Decode the predicted id back to its word and report it.
predicted_word = inv_vocab[pred_idx]
print("Predicted next word:", predicted_word)

Predicted next word: is
