In [1]:
import torch
import math
import torch.nn.functional as F
from torch.nn import Parameter

Step 1: Prepare the vocabulary and input data

In [2]:
tokens = ['the', 'sun', 'rises', 'in', 'the', 'east']
vocab = sorted(set(tokens))

In [3]:
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(vocab)

In [4]:
vocab_size

5

In [5]:
word2idx

{'east': 0, 'in': 1, 'rises': 2, 'sun': 3, 'the': 4}

In [6]:
idx2word

{0: 'east', 1: 'in', 2: 'rises', 3: 'sun', 4: 'the'}

In [7]:
inputs = [['the', 'sun', 'rises'], ['sun', 'rises', 'in']]
targets = ['in', 'the']

In [8]:
X = torch.tensor([[word2idx[tok] for tok in seq] for seq in inputs])
Y = torch.tensor([word2idx[tok] for tok in targets])

In [9]:
X

tensor([[4, 3, 2],
        [3, 2, 1]])

In [10]:
Y

tensor([1, 4])

Step 2: Model configuration

In [11]:
embed_dim = 8
num_heads = 2
head_dim = embed_dim // num_heads
seq_len = X.shape[1]
batch_size = X.shape[0]

In [12]:
head_dim

4

In [13]:
seq_len

3

In [14]:
batch_size

2

Learnable parameters

In [15]:
embedding_matrix = Parameter(torch.randn(vocab_size, embed_dim))

In [16]:
embedding_matrix

Parameter containing:
tensor([[-0.8164,  0.3393, -0.1962,  0.0686,  1.8330, -1.5634, -0.6825,  1.2403],
        [-0.4157, -2.8158, -0.4880,  0.9893, -0.2380,  2.2371,  0.6282,  2.2796],
        [ 1.2535,  0.3401, -0.0283, -0.7157, -0.8639, -1.7989, -0.3398,  0.6687],
        [-0.0123, -0.3947, -0.0860,  0.5334,  2.1976, -0.2144, -0.2776, -0.4011],
        [ 0.5143,  1.3045,  1.6402, -0.1149, -0.5363, -0.8710,  1.6064, -0.6619]],
       requires_grad=True)

In [28]:
embedding_matrix.shape

torch.Size([5, 8])

In [17]:
W_q = Parameter(torch.randn(embed_dim, embed_dim))
W_k = Parameter(torch.randn(embed_dim, embed_dim))
W_v = Parameter(torch.randn(embed_dim, embed_dim))

In [18]:
W_q

Parameter containing:
tensor([[-0.0829, -0.1412, -3.1731, -1.4237,  1.5461,  0.6367,  0.0915, -1.0132],
        [-0.0592, -1.1897,  0.1291,  2.0549,  1.4506, -0.6190,  0.3577, -0.3925],
        [ 0.2888, -1.0396, -0.3008,  0.9865,  2.4813, -1.2502,  0.1523, -0.0751],
        [-1.0965,  0.9943, -0.2630,  0.4882,  0.5478,  0.4314,  0.4888,  1.3243],
        [ 1.4615,  0.1642, -0.3472, -2.4003,  1.5152,  0.1648,  1.0022, -0.6238],
        [ 0.3859,  0.4654, -0.1196,  0.8294,  0.3183,  0.5795, -0.4047,  0.3449],
        [ 0.8210, -0.5370, -0.8223,  1.0911,  0.3955, -0.8302, -1.6158,  0.8659],
        [ 0.2091,  0.7154,  1.6497,  2.4070,  1.2231,  0.2212,  0.2300, -0.0381]],
       requires_grad=True)

In [19]:
W_q.shape

torch.Size([8, 8])

In [20]:
W1 = Parameter(torch.randn(embed_dim, embed_dim))
b1 = Parameter(torch.zeros(embed_dim))
W2 = Parameter(torch.randn(embed_dim, embed_dim))
b2 = Parameter(torch.zeros(embed_dim))
W_out = Parameter(torch.randn(embed_dim, vocab_size))
b_out = Parameter(torch.zeros(vocab_size))

In [22]:
pos_embedding = Parameter(torch.randn(seq_len, embed_dim))

In [23]:
optimizer = torch.optim.Adam([
    embedding_matrix, pos_embedding, W_q, W_k, W_v,
    W1, b1, W2, b2, W_out, b_out
], lr=0.01)

Training one row

In [24]:
optimizer.zero_grad()

In [25]:
# Embedding + positional encoding
embedded = embedding_matrix[X] + pos_embedding 

In [26]:
embedded

tensor([[[ 1.4510,  1.0351,  1.9332,  2.2591, -1.1639,  0.0343,  1.8987,
           0.1107],
         [ 0.2814, -2.5231,  0.2681,  1.0828,  1.8583, -0.5441, -0.9493,
          -0.2009],
         [ 0.4394,  1.7338, -0.9409, -2.2016, -0.7979, -3.4903, -0.5249,
           0.4646]],

        [[ 0.9244, -0.6640,  0.2071,  2.9074,  1.5701,  0.6909,  0.0147,
           0.3715],
         [ 1.5472, -1.7883,  0.3258, -0.1663, -1.2033, -2.1286, -1.0114,
           0.8689],
         [-1.2298, -1.4221, -1.4007, -0.4967, -0.1719,  0.5457,  0.4431,
           2.0755]]], grad_fn=<AddBackward0>)

In [27]:
embedded.shape

torch.Size([2, 3, 8])

In [29]:
# Project to Q, K, V
Q = embedded @ W_q
K = embedded @ W_k
V = embedded @ W_v

In [59]:
W_q.shape

torch.Size([8, 8])

In [30]:
Q

tensor([[[ -2.2061,  -2.3157,  -6.6248,   8.2316,   8.9131,  -2.8827,  -2.3213,
            3.3478],
         [  0.7008,   4.1779,  -1.7151, -11.2236,   0.0549,   2.6075,   3.2635,
           -0.0420],
         [ -0.8436,  -4.4766,   1.5845,   0.5003,  -2.3060,  -2.1825,   1.0087,
           -5.1491]],

        [[ -0.5143,   4.1722,  -3.8729,  -3.3425,   5.6317,   2.7242,   2.6552,
            2.4163],
         [ -2.9747,   1.3817,  -2.2572,  -3.5271,  -1.3228,   1.2132,   0.9596,
           -2.0028],
         [  1.0835,   4.3006,   7.3249,   3.5491,  -5.0849,   2.0132,  -1.7091,
            1.8517]]], grad_fn=<UnsafeViewBackward0>)

In [31]:
Q.shape

torch.Size([2, 3, 8])

In [32]:
# Reshape for multi-head attention
def reshape(x):
    return x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

In [33]:
Qh, Kh, Vh = map(reshape, (Q, K, V))

In [34]:
Qh.shape

torch.Size([2, 2, 3, 4])

In [35]:
Qh

tensor([[[[ -2.2061,  -2.3157,  -6.6248,   8.2316],
          [  0.7008,   4.1779,  -1.7151, -11.2236],
          [ -0.8436,  -4.4766,   1.5845,   0.5003]],

         [[  8.9131,  -2.8827,  -2.3213,   3.3478],
          [  0.0549,   2.6075,   3.2635,  -0.0420],
          [ -2.3060,  -2.1825,   1.0087,  -5.1491]]],


        [[[ -0.5143,   4.1722,  -3.8729,  -3.3425],
          [ -2.9747,   1.3817,  -2.2572,  -3.5271],
          [  1.0835,   4.3006,   7.3249,   3.5491]],

         [[  5.6317,   2.7242,   2.6552,   2.4163],
          [ -1.3228,   1.2132,   0.9596,  -2.0028],
          [ -5.0849,   2.0132,  -1.7091,   1.8517]]]],
       grad_fn=<TransposeBackward0>)

In [36]:
# Causal mask to prevent attending to future tokens
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

In [37]:
mask

tensor([[[[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]]]])

In [38]:
# Compute scaled dot-product attention
scores = (Qh @ Kh.transpose(-2, -1)) / math.sqrt(head_dim)

In [39]:
scores

tensor([[[[ 25.2680, -11.2582,   9.3553],
          [  9.8371, -10.7304,  10.1255],
          [-13.9345,  11.5954, -15.3323]],

         [[ -8.7544, -12.1842,  31.9708],
          [-14.7860,   2.7654,   7.0745],
          [ 17.9632,   7.2769, -15.7331]]],


        [[[ -7.3364,   0.1837, -33.9196],
          [ -2.7744,  -8.3566, -22.2177],
          [ 16.7993,  -9.9005,  21.8704]],

         [[-13.9629,   6.5034,  19.9367],
          [  4.2783,   2.4653,  -4.0848],
          [  6.6419, -13.4783,  -6.3278]]]], grad_fn=<DivBackward0>)

In [40]:
scores = scores.masked_fill(mask == 0, float('-inf'))

In [41]:
scores

tensor([[[[ 25.2680,     -inf,     -inf],
          [  9.8371, -10.7304,     -inf],
          [-13.9345,  11.5954, -15.3323]],

         [[ -8.7544,     -inf,     -inf],
          [-14.7860,   2.7654,     -inf],
          [ 17.9632,   7.2769, -15.7331]]],


        [[[ -7.3364,     -inf,     -inf],
          [ -2.7744,  -8.3566,     -inf],
          [ 16.7993,  -9.9005,  21.8704]],

         [[-13.9629,     -inf,     -inf],
          [  4.2783,   2.4653,     -inf],
          [  6.6419, -13.4783,  -6.3278]]]], grad_fn=<MaskedFillBackward0>)

In [42]:
attn_weights = F.softmax(scores, dim=-1)

In [43]:
attn_weights

tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00],
          [1.0000e+00, 1.1685e-09, 0.0000e+00],
          [8.1755e-12, 1.0000e+00, 2.0204e-12]],

         [[1.0000e+00, 0.0000e+00, 0.0000e+00],
          [2.3852e-08, 1.0000e+00, 0.0000e+00],
          [9.9998e-01, 2.2855e-05, 2.3220e-15]]],


        [[[1.0000e+00, 0.0000e+00, 0.0000e+00],
          [9.9625e-01, 3.7503e-03, 0.0000e+00],
          [6.2359e-03, 1.5824e-14, 9.9376e-01]],

         [[1.0000e+00, 0.0000e+00, 0.0000e+00],
          [8.5972e-01, 1.4028e-01, 0.0000e+00],
          [1.0000e+00, 1.8276e-09, 2.3298e-06]]]], grad_fn=<SoftmaxBackward0>)

In [44]:
# Apply attention weights to values
attn_output = attn_weights @ Vh

In [45]:
attn_output

tensor([[[[  1.6323,  -7.3257,  -4.4774,   1.5864],
          [  1.6323,  -7.3257,  -4.4774,   1.5864],
          [ -4.4277,  -0.3395,   3.6692,   3.1080]],

         [[ -1.3689,  -1.3359, -10.3603,  -2.7999],
          [  0.4137,  -3.0978,   0.9074,   2.0655],
          [ -1.3689,  -1.3359, -10.3601,  -2.7998]]],


        [[[ -0.8769,  -3.0010,   1.1872,   2.2387],
          [ -0.8901,  -3.0026,   1.1822,   2.2474],
          [ -1.7953,   2.4614,  -1.1720,   1.0276]],

         [[  0.0782,   1.9823,  -4.5615,   0.8238],
          [ -0.7249,   1.3781,  -4.1267,   0.4926],
          [  0.0782,   1.9823,  -4.5615,   0.8238]]]],
       grad_fn=<UnsafeViewBackward0>)

In [46]:
# Concatenate multi-head outputs
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

In [47]:
attn_output

tensor([[[  1.6323,  -7.3257,  -4.4774,   1.5864,  -1.3689,  -1.3359, -10.3603,
           -2.7999],
         [  1.6323,  -7.3257,  -4.4774,   1.5864,   0.4137,  -3.0978,   0.9074,
            2.0655],
         [ -4.4277,  -0.3395,   3.6692,   3.1080,  -1.3689,  -1.3359, -10.3601,
           -2.7998]],

        [[ -0.8769,  -3.0010,   1.1872,   2.2387,   0.0782,   1.9823,  -4.5615,
            0.8238],
         [ -0.8901,  -3.0026,   1.1822,   2.2474,  -0.7249,   1.3781,  -4.1267,
            0.4926],
         [ -1.7953,   2.4614,  -1.1720,   1.0276,   0.0782,   1.9823,  -4.5615,
            0.8238]]], grad_fn=<ViewBackward0>)

In [48]:
attn_output.shape

torch.Size([2, 3, 8])

In [49]:
# Feedforward network
ffn = torch.relu(attn_output @ W1 + b1)
ffn = ffn @ W2 + b2

In [51]:
ffn

tensor([[[  5.5664,  13.5499,  27.6570,  63.4399, -43.8924,  -0.5320, -20.1907,
           35.3387],
         [  6.3338,  12.6149,  15.1894,  25.7221, -19.6118,  -7.8365,  -9.7928,
            8.9380],
         [ -1.2811, -14.2705,  11.0819,  23.9507, -33.0082,   4.6001,  -9.2054,
           14.9693]],

        [[  1.9444,  -7.8102,  13.5497,  24.8362, -32.8808,  -5.3498, -13.8637,
           10.4085],
         [  4.3597,  -2.8524,  13.8946,  24.0713, -28.3426,  -6.1110,  -9.9676,
           11.2889],
         [-11.5424, -23.9073,  -3.3235,  -0.7025, -18.3985,   9.8123, -11.0779,
            0.4239]]], grad_fn=<AddBackward0>)

In [52]:
ffn.shape

torch.Size([2, 3, 8])

In [50]:
# Take the output for the last token position
final_token = ffn[:, -1, :]

In [53]:
final_token

tensor([[ -1.2811, -14.2705,  11.0819,  23.9507, -33.0082,   4.6001,  -9.2054,
          14.9693],
        [-11.5424, -23.9073,  -3.3235,  -0.7025, -18.3985,   9.8123, -11.0779,
           0.4239]], grad_fn=<SliceBackward0>)

In [54]:
# Output projection to vocab
logits = final_token @ W_out + b_out
print(logits)

loss = F.cross_entropy(logits, Y)
print(loss)

tensor([[ 51.3055, -82.7576, -60.5026, -60.0071,  15.2986],
        [  9.0535, -58.4483,   1.8867, -34.0361,  19.8173]],
       grad_fn=<AddBackward0>)
tensor(67.0316, grad_fn=<NllLossBackward0>)


In [55]:
predicted_indices = torch.argmax(logits, dim=1)

In [56]:
predicted_words = [idx2word[i.item()] for i in predicted_indices]

In [57]:
predicted_words

['east', 'the']

In [58]:
#Query = job seeker’s resume (what they want).
#Keys = job postings (what’s available).
#Values = job descriptions (actual info you’d get).