$$\Large\boxed{\text{AME 5202 Deep Learning, Even Semester 2026}}$$

$$\large\text{Theme}: \underline{\text{Training a simple single-head attention model}}$$

---

Load essential libraries

---

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
torch.manual_seed(0)
import matplotlib.pyplot as plt
plt.style.use('dark_background')
%matplotlib inline
import sys

torch.manual_seed(0)

<torch._C.Generator at 0x24a9fd306f0>

---

Mount Google Drive folder if running Google Colab

---

In [2]:
## Mount Google drive folder if running in Colab
if('google.colab' in sys.modules):
    from google.colab import drive
    drive.mount('/content/drive', force_remount = True)
    DIR = '/content/drive/MyDrive/Colab Notebooks/MAHE/MSIS Coursework/EvenSem2026MAHE'
    DATA_DIR = DIR+'/Data/'
else:
    DATA_DIR = 'Data/'

---

Setup vocabulary including polysemous words (same sound different meaning)

---

In [2]:
# Setup vocabulary
vocab = [
    "[PAD]", "[MASK]", "[UNK]",

    "I", "you", "we", "they",

    # Polysemous words
    "bear",        # animal / tolerate
    "run",         # move / operate
    "bank",        # river / finance
    "charge",      # legal / electrical

    # Nouns
    "river", "road", "field", "court", "battery",
    "money", "load", "power", "side", "shore", "swam",

    # Function words
    "to", "the", "a", "other", "across", "with", "in", "on", "of", "get", "cannot"
]

vocab_size = len(vocab)
print(f'Vocabulary size is {vocab_size}')

Vocabulary size is 33


---

Create word-to-index and index-to-word dictionaries

---

In [3]:
# word-to-index and index-to-word dictionarties
word_to_idx = {w: i for i, w in enumerate(vocab)}
idx_to_word = {i: w for w, i in word_to_idx.items()}
print(f'word-to-index dictionarty:\n {word_to_idx}')
print(f'index-to-word dictionarty:\n {idx_to_word}')

word-to-index dictionarty:
 {'[PAD]': 0, '[MASK]': 1, '[UNK]': 2, 'I': 3, 'you': 4, 'we': 5, 'they': 6, 'bear': 7, 'run': 8, 'bank': 9, 'charge': 10, 'river': 11, 'road': 12, 'field': 13, 'court': 14, 'battery': 15, 'money': 16, 'load': 17, 'power': 18, 'side': 19, 'shore': 20, 'swam': 21, 'to': 22, 'the': 23, 'a': 24, 'other': 25, 'across': 26, 'with': 27, 'in': 28, 'on': 29, 'of': 30, 'get': 31, 'cannot': 32}
index-to-word dictionarty:
 {0: '[PAD]', 1: '[MASK]', 2: '[UNK]', 3: 'I', 4: 'you', 5: 'we', 6: 'they', 7: 'bear', 8: 'run', 9: 'bank', 10: 'charge', 11: 'river', 12: 'road', 13: 'field', 14: 'court', 15: 'battery', 16: 'money', 17: 'load', 18: 'power', 19: 'side', 20: 'shore', 21: 'swam', 22: 'to', 23: 'the', 24: 'a', 25: 'other', 26: 'across', 27: 'with', 28: 'in', 29: 'on', 30: 'of', 31: 'get', 32: 'cannot'}


---

Create training data

---

In [4]:
# Training data
training_data = [
    (["I", "swam", "across", "the", "river", "to", "the", "other", "[MASK]"], "shore"),
    (["they", "went", "to", "the", "bank", "to", "get", "[MASK]"], "money"),
    (["I", "saw", "a", "bear", "in", "the", "[MASK]"], "field"),
    (["I", "cannot", "bear", "the", "[MASK]"], "load"),
    (["they", "run", "across", "the", "[MASK]"], "field"),
    (["the", "battery", "can", "run", "with", "[MASK]"], "power"),
    (["the", "court", "will", "charge", "them", "with", "[MASK]"], "money"),
    (["the", "battery", "has", "a", "charge", "of", "[MASK]"], "power"),
]

---

A simple encoder function

---

In [5]:
# Encoder function
def encode(sentence):
    return torch.tensor([word_to_idx.get(w, word_to_idx["[UNK]"]) for w in sentence])

---

Testing the encoder function

---

In [8]:
sentence = ["I", "swam", "across", "the", "river", "to", "the", "other", "[MASK]"]
input_ids = encode(sentence)
print(input_ids)

tensor([ 3, 21, 26, 23, 11, 22, 23, 25,  1])


---

An example demonstrating how to create embeddings which are random and trainable vector representations of words that are stored in a simple lookup table

---

In [9]:
embedding_size = 8
embed = nn.Embedding(vocab_size, embedding_size)
# The initial random embeddings matrix corresponding to all
# words in the vocabulary
print(embed.weight.data)
#print('----')

# Extract the embeddings of the words in the sentence
print(embed(input_ids))

tensor([[-1.1258e+00, -1.1524e+00, -2.5058e-01, -4.3388e-01,  8.4871e-01,
          6.9201e-01, -3.1601e-01, -2.1152e+00],
        [ 3.2227e-01, -1.2633e+00,  3.4998e-01,  3.0813e-01,  1.1984e-01,
          1.2377e+00,  1.1168e+00, -2.4728e-01],
        [-1.3527e+00, -1.6959e+00,  5.6665e-01,  7.9351e-01,  5.9884e-01,
         -1.5551e+00, -3.4136e-01,  1.8530e+00],
        [ 7.5019e-01, -5.8550e-01, -1.7340e-01,  1.8348e-01,  1.3894e+00,
          1.5863e+00,  9.4630e-01, -8.4368e-01],
        [-6.1358e-01,  3.1593e-02, -4.9268e-01,  2.4841e-01,  4.3970e-01,
          1.1241e-01,  6.4079e-01,  4.4116e-01],
        [-1.0231e-01,  7.9244e-01, -2.8967e-01,  5.2507e-02,  5.2286e-01,
          2.3022e+00, -1.4689e+00, -1.5867e+00],
        [-6.7309e-01,  8.7283e-01,  1.0554e+00,  1.7784e-01, -2.3034e-01,
         -3.9175e-01,  5.4329e-01, -3.9516e-01],
        [-4.4622e-01,  7.4402e-01,  1.5210e+00,  3.4105e+00, -1.5312e+00,
         -1.2341e+00,  1.8197e+00, -5.5153e-01],
        [-5.6925

---

Creating linear operators that are random and trainable (equivalent to weights)

---

In [10]:
W = nn.Linear(8, 33)
#print(W.weight.data)

# Training sample-0
input_data = training_data[0][0]
print(input_data)
print('-----')

# Encoded input data
print(encode(input_data))
print('-----')

# Embeddings corresponding to the tokens in the input data
X = embed(encode(input_data))
print(X)
print('-----')

# The linear operator W applied to the embeddings of the [MASK] token
print(W(X[-1]))

['I', 'swam', 'across', 'the', 'river', 'to', 'the', 'other', '[MASK]']
-----
tensor([ 3, 21, 26, 23, 11, 22, 23, 25,  1])
-----
tensor([[ 0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463, -0.8437],
        [ 0.3140,  0.2133, -0.1201,  0.3605, -0.3140, -1.0787,  0.2408, -1.3962],
        [ 1.1711,  0.0975,  0.9634,  0.8403, -1.2537,  0.9868, -0.4947, -1.2830],
        [-1.0787, -0.7209,  1.4708,  0.2756,  0.6668, -0.9944, -1.1894, -1.1959],
        [-0.2460,  2.3025, -1.8817, -0.0497, -1.0450, -0.9565,  0.0335,  0.7101],
        [-0.0661, -0.3584, -1.5616, -0.3546,  1.0811,  0.1315,  1.5735,  0.7814],
        [-1.0787, -0.7209,  1.4708,  0.2756,  0.6668, -0.9944, -1.1894, -1.1959],
        [-0.0744, -1.0922,  0.3920,  0.5945,  0.6623, -1.2063,  0.6074, -0.5472],
        [ 0.3223, -1.2633,  0.3500,  0.3081,  0.1198,  1.2377,  1.1168, -0.2473]],
       grad_fn=<EmbeddingBackward0>)
-----
tensor([ 0.5514,  0.2155,  0.0192,  0.0889,  0.3303,  0.8560, -0.2483, -0.3949,
        -

---

A tiny transformer class implementing the single-head self-attention model

---

In [6]:
# Tiny transformer class
class TinyTransformer(nn.Module):
  def __init__(self, vocab_size, d_model):
    super().__init__()

    # Initialize embeddings for the words in the vocabulary
    self.embed = nn.Embedding(vocab_size, d_model)

    # Initialize the query, key, and value linear operators
    self.W_Q = nn.Linear(d_model, d_model, bias = False)
    self.W_K = nn.Linear(d_model, d_model, bias = False)
    self.W_V = nn.Linear(d_model, d_model, bias = False)

    # Initialize the linear operator for the output layer
    self.output = nn.Linear(d_model, vocab_size)
        
  def forward(self, input_ids, mask_index):
    # Extract embeddings for the words in the sentence
    X = self.embed(input_ids)                 

    # Calculate the query, key, and value representations of the words
    Q = self.W_Q(X) # same as torch.matmul(X, W_Q.data.weights)
    K = self.W_K(X)
    V = self.W_V(X)

    # Calculate pairwise scaled self similarities a.k.a. the attention scores
    S = F.softmax((Q @ K.T)/math.sqrt(K.size(-1)), dim = -1)

    # Calculate the updated embeddings of the words weighted using the attention scores
    X_context = S @ V   # Y = Self-Attention(X)            

    # Extract the updated embedding for the missing [MASK] word
    mask_embedding = X_context[mask_index]

    # Calculate the raw scores (also called logits) for the missing [MASK] word
    # w.r.t. all the other words in the vocabulary
    z = self.output(mask_embedding)

    return z, S

---

Define loss function

---

In [7]:
logits = torch.tensor([1.0, 0.5, 6.7, 9.0, -10], dtype = torch.float64)
target_idx = torch.tensor(2, dtype = torch.long)
loss = F.cross_entropy(logits, target_idx)
print(f'Loss = {loss}')
probs = F.softmax(logits)
print(probs)
loss = -torch.log(probs[target_idx])
print(f'Loss = {loss}')

Loss = 2.3960351717727737
tensor([3.0475e-04, 1.8484e-04, 9.1078e-02, 9.0843e-01, 5.0898e-09],
       dtype=torch.float64)
Loss = 2.3960351717727737


  probs = F.softmax(logits)


In [8]:
# Loss function
def loss_fn(logits, target_idx):
  probs = F.softmax(logits, dim = -1) # softmax-activated scores
  eps = 1e-09
  loss = -torch.log(probs[target_idx] + eps)
  #print(f'loss = {loss}')
  # Much quicker way of calculating the CCE loss using PyTorch
  #loss = F.cross_entropy(logits, torch.tensor(target_idx, dtype = torch.long))
  #print(f'loss = {loss}')
  return loss

---

Calculating the loss for a simple training sample

---

In [27]:
data = training_data[0]
print(data)
print('----')

# Call enocder function
input_ids = encode(data[0])
print(input_ids)
print('----')

# Get the position of the missing [MASK] word
mask_index = data[0].index("[MASK]")
print(mask_index)
print('----')

(['I', 'swam', 'across', 'the', 'river', 'to', 'the', 'other', '[MASK]'], 'shore')
----
tensor([ 3, 21, 26, 23, 11, 22, 23, 25,  1])
----
8
----


In [30]:
data = training_data[0]
print(data)
print('----')

# Call enocder function
input_ids = encode(data[0])
print(input_ids)
print('----')

# Get the position of the missing [MASK] word
mask_index = data[0].index("[MASK]")
print(mask_index)
print('----')

# Get the target word index
target_idx = word_to_idx[data[1]]
print(target_idx)
print('----')

# Apply the tiny transformer model
model = TinyTransformer(vocab_size, 8)
logits, _ = model(input_ids, mask_index)
loss = loss_fn(logits, target_idx)
print(loss)


(['I', 'swam', 'across', 'the', 'river', 'to', 'the', 'other', '[MASK]'], 'shore')
----
tensor([ 3, 21, 26, 23, 11, 22, 23, 25,  1])
----
8
----
20
----
tensor(3.6908, grad_fn=<NegBackward0>)


---

Training for the parameters of the model

---

In [9]:
model = TinyTransformer(vocab_size = vocab_size, d_model = 8)
model.parameters()

<generator object Module.parameters at 0x0000024A9B47E180>

In [None]:
# Initialize model
model = TinyTransformer(vocab_size = vocab_size, d_model = 8)

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr = 0.05)

# Optimization loop
num_epochs = 500
for epoch in range(num_epochs):
  loss_epoch = 0.0
  random.shuffle(training_data)

  for sentence, target_word in training_data:
    input_ids = encode(sentence)
    mask_index = sentence.index("[MASK]")
    target_id = word_to_idx[target_word]
    
    # Zero out the gradients
    optimizer.zero_grad()

    # Forward propagation
    logits, _ = model(input_ids, mask_index)
    loss = loss_fn(logits, target_id)

    # Backward propagation and optimization
    loss.backward()
    optimizer.step()model = TinyTransformer(vocab_size = vocab_size, d_model = 8)
    loss_epoch += loss.item()
  
  # Print the loss every 10 epochs
  if epoch % 10 == 0:
    print(f"Epoch {epoch:3d} | Loss: {loss_epoch:.4f}")

Epoch   0 | Loss: 27.8455
Epoch  10 | Loss: 4.3334
Epoch  20 | Loss: 4.0115
Epoch  30 | Loss: 0.2210
Epoch  40 | Loss: 0.0207
Epoch  50 | Loss: 0.0107
Epoch  60 | Loss: 0.0068
Epoch  70 | Loss: 0.0048
Epoch  80 | Loss: 0.0036
Epoch  90 | Loss: 0.0028
Epoch 100 | Loss: 0.0022
Epoch 110 | Loss: 0.0018
Epoch 120 | Loss: 0.0015
Epoch 130 | Loss: 0.0013
Epoch 140 | Loss: 0.0011
Epoch 150 | Loss: 0.0010
Epoch 160 | Loss: 0.0008
Epoch 170 | Loss: 0.0007
Epoch 180 | Loss: 0.0007
Epoch 190 | Loss: 0.0006
Epoch 200 | Loss: 0.0005
Epoch 210 | Loss: 0.0005
Epoch 220 | Loss: 0.0004
Epoch 230 | Loss: 0.0004
Epoch 240 | Loss: 0.0004
Epoch 250 | Loss: 0.0003
Epoch 260 | Loss: 0.0003
Epoch 270 | Loss: 0.0003
Epoch 280 | Loss: 0.0003
Epoch 290 | Loss: 0.0002
Epoch 300 | Loss: 0.0002
Epoch 310 | Loss: 0.0002
Epoch 320 | Loss: 0.0002
Epoch 330 | Loss: 0.0002
Epoch 340 | Loss: 0.0002
Epoch 350 | Loss: 0.0002
Epoch 360 | Loss: 0.0001
Epoch 370 | Loss: 0.0001
Epoch 380 | Loss: 0.0001
Epoch 390 | Loss: 0.0001

---

Prediction function

---

In [33]:
def predict(sentence):
  input_ids = encode(sentence)
  mask_index = sentence.index("[MASK]")

  with torch.no_grad():
    logits, attention = model(input_ids, mask_index)
    pred_id = logits.argmax().item()

  return idx_to_word[pred_id], attention[mask_index]

---

Test the model

---

In [34]:
# Model testing
test_data = [
    ["I", "ran", "along", "the", "bank", "to", "the", "[MASK]"],
    ["the", "battery", "cannot", "run", "without", "[MASK]"],
    ["the", "court", "will", "charge", "the", "bank", "with", "[MASK]"],
    ["the", "bear", "on", "the", "bank", "could", "not", "bear", "the", "[MASK]"],
    ["I", "swam", "across", "the", "[MASK]", "to", "get", "the", "battery"]
]

for sentence in test_data:
  prediction, attention = predict(sentence)
  print("\nSentence:", " ".join(sentence))
  print("Prediction:", prediction)
  print("Attention from [MASK]:")
  for w, a in zip(sentence, attention):
    print(f"  {w:10s} → {a:.3f}")


Sentence: I ran along the bank to the [MASK]
Prediction: money
Attention from [MASK]:
  I          → 0.000
  ran        → 0.000
  along      → 0.000
  the        → 0.000
  bank       → 1.000
  to         → 0.000
  the        → 0.000
  [MASK]     → 0.000

Sentence: the battery cannot run without [MASK]
Prediction: load
Attention from [MASK]:
  the        → 0.000
  battery    → 0.000
  cannot     → 1.000
  run        → 0.000
  without    → 0.000
  [MASK]     → 0.000

Sentence: the court will charge the bank with [MASK]
Prediction: money
Attention from [MASK]:
  the        → 0.000
  court      → 0.023
  will       → 0.000
  charge     → 0.000
  the        → 0.000
  bank       → 0.977
  with       → 0.000
  [MASK]     → 0.000

Sentence: the bear on the bank could not bear the [MASK]
Prediction: money
Attention from [MASK]:
  the        → 0.000
  bear       → 0.000
  on         → 0.000
  the        → 0.000
  bank       → 1.000
  could      → 0.000
  not        → 0.000
  bear       → 0.000
