In [20]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from tqdm import tqdm 
from pathlib import Path

In [21]:
# Hyperparams
device = torch.device("cpu")  # get_batch moves every batch it returns onto this device
block_size = 64  # number of consecutive token ids in each window sampled by get_batch
batch_size = 12  # number of windows get_batch stacks into one batch


In [22]:
# Data loader: read the corpus and build a character-level vocabulary.
data_path = Path("data/tiny.txt")
# Build the hint from data_path itself so the message can never name a different file.
assert data_path.exists(), f"Create {data_path.as_posix()} with some text"
text = data_path.read_text(encoding="utf-8")
chars = sorted(set(text))  # unique characters of the corpus, in sorted order
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # maps each character (ch) to its integer id (i)
itos = {i: ch for ch, i in stoi.items()}  # inverse of stoi: maps each integer id back to its character

In [23]:
def encode(s):
    """
    Convert a string into a list of integer ids by looking up each character in ``stoi``.
    """
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids

def decode(l):
    """
    Convert a list of integer ids back into a string using the ``itos`` mapping.
    """
    pieces = [itos[idx] for idx in l]
    return "".join(pieces)

In [24]:
pairs = encode(text)

pairs[:20]

[25, 1, 29, 22, 33, 21, 28, 1, 35, 21, 28, 33, 1, 42, 45, 55, 44, 45, 50, 43]

In [26]:
decode_pairs = decode(pairs)

decode_pairs[:20]

'I OFTEN WENT fishing'

In [5]:
# Encode the whole corpus once; the first 90% is used for training and the rest for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = data.size(0)
train_data, val_data = data[: int(0.9 * n)], data[int(0.9 * n):]

In [28]:
def get_batch(split):
    """
    Sample a random batch of (input, target) windows from one data split.

    ``split == "train"`` draws from ``train_data``; any other value draws from
    ``val_data``. Each of the ``batch_size`` rows of the first result holds
    ``block_size`` consecutive ids; the matching row of the second result is the
    same window shifted one position to the right. Both are moved to ``device``.
    """
    src = val_data if split != "train" else train_data
    # The exclusive upper bound keeps the one-step-shifted target window in range.
    starts = torch.randint(len(src) - block_size, (batch_size,)).tolist()
    inputs = [src[s:s + block_size] for s in starts]
    targets = [src[s + 1:s + block_size + 1] for s in starts]
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

In [29]:
batches = get_batch(split="train")

batches[1]

tensor([[40,  1, 55, 51, 49, 41, 56, 44, 45, 50, 43,  8,  1, 24, 51, 59,  1, 37,
         38, 51, 57, 56,  1, 56, 44, 41,  1, 56, 45, 49, 41,  1, 61, 51, 57,  1,
         43, 37, 58, 41,  1, 37,  0,  1, 48, 37, 54, 43, 41,  1, 39, 51, 50, 56,
         54, 45, 38, 57, 56, 45, 51, 50,  1, 56],
        [ 1, 61, 51, 57,  1, 59, 41, 54, 41,  1, 38, 51, 54, 50,  1, 59, 37, 55,
          0,  1, 52, 41, 54, 42, 51, 54, 49, 41, 40,  1, 38, 41, 39, 37, 57, 55,
         41,  1, 61, 51, 57,  1, 59, 37, 50, 56, 41, 40,  1, 55, 51, 49, 41, 56,
         44, 45, 50, 43,  8,  1, 24, 51, 59,  1],
        [54, 56, 44,  1, 37,  1, 48, 51, 56,  1, 56, 51,  1, 37,  1, 44, 51, 56,
         41, 48,  6,  1, 45, 55, 50, 65, 56,  1, 45, 56, 16, 65,  0,  1, 17, 55,
          1, 25,  1, 56, 37, 48, 47, 41, 40,  6,  1, 25,  1, 59, 54, 51, 56, 41,
          1, 56, 44, 41, 55, 41,  1, 56, 59, 51],
        [51,  1, 48, 51, 51, 47,  1, 37, 56,  1, 61, 51, 57, 54,  1, 44, 51, 56,
         41, 48,  1, 37, 55,  1, 25,  0,

In [1]:
import torch
import torch.nn.functional as F

In [2]:
# --- Step 1: Setup ---
# Suppose we have 4 tokens ("I", "love", "machine", "learning")
# and each token is represented by an embedding vector of size 8.
torch.manual_seed(0)
x = torch.randn(1, 4, 8)  # (batch=1, seq_len=4, embed_dim=8)

In [3]:
x

tensor([[[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160,
          -2.1152],
         [ 0.3223, -1.2633,  0.3500,  0.3081,  0.1198,  1.2377,  1.1168,
          -0.2473],
         [-1.3527, -1.6959,  0.5667,  0.7935,  0.5988, -1.5551, -0.3414,
           1.8530],
         [ 0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463,
          -0.8437]]])

In [4]:
# We'll use 2 heads
n_heads = 2
head_dim = x.size(-1) // n_heads  # 8 / 2 = 4

In [6]:
# Number of dimensions in each head
head_dim 

4

In [7]:
# Define learnable weights for Q, K, V ---
W_q = torch.randn(8, 8)
W_k = torch.randn(8, 8)
W_v = torch.randn(8, 8)


In [8]:
W_q

tensor([[-6.1358e-01,  3.1593e-02, -4.9268e-01,  2.4841e-01,  4.3970e-01,
          1.1241e-01,  6.4079e-01,  4.4116e-01],
        [-1.0231e-01,  7.9244e-01, -2.8967e-01,  5.2507e-02,  5.2286e-01,
          2.3022e+00, -1.4689e+00, -1.5867e+00],
        [-6.7309e-01,  8.7283e-01,  1.0554e+00,  1.7784e-01, -2.3034e-01,
         -3.9175e-01,  5.4329e-01, -3.9516e-01],
        [-4.4622e-01,  7.4402e-01,  1.5210e+00,  3.4105e+00, -1.5312e+00,
         -1.2341e+00,  1.8197e+00, -5.5153e-01],
        [-5.6925e-01,  9.1997e-01,  1.1108e+00,  1.2899e+00, -1.4782e+00,
          2.5672e+00, -4.7312e-01,  3.3555e-01],
        [-1.6293e+00, -5.4974e-01, -4.7983e-01, -4.9968e-01, -1.0670e+00,
          1.1149e+00, -1.4067e-01,  8.0575e-01],
        [-9.3348e-02,  6.8705e-01, -8.3832e-01,  8.9182e-04,  8.4189e-01,
         -4.0003e-01,  1.0395e+00,  3.5815e-01],
        [-2.4600e-01,  2.3025e+00, -1.8817e+00, -4.9727e-02, -1.0450e+00,
         -9.5650e-01,  3.3532e-02,  7.1009e-01]])

In [9]:
W_k

tensor([[ 1.6459, -1.3602,  0.3446,  0.5199, -2.6133, -1.6965, -0.2282,  0.2800],
        [ 0.2469,  0.0769,  0.3380,  0.4544,  0.4569, -0.8654,  0.7813, -0.9268],
        [-0.2188, -2.4351, -0.0729, -0.0340,  0.9625,  0.3492, -0.9215, -0.0562],
        [-0.6227, -0.4637,  1.9218, -0.4025,  0.1239,  1.1648,  0.9234,  1.3873],
        [-0.8834, -0.4189, -0.8048,  0.5656,  0.6104,  0.4669,  1.9507, -1.0631],
        [-0.0773,  0.1164, -0.5940, -1.2439, -0.1021, -1.0335, -0.3126,  0.2458],
        [-0.2596,  0.1183,  0.2440,  1.1646,  0.2886,  0.3866, -0.2011, -0.1179],
        [ 0.1922, -0.7722, -1.9003,  0.1307, -0.7043,  0.3147,  0.1574,  0.3854]])

In [10]:
W_v

tensor([[ 0.9671, -0.9911,  0.3016, -0.1073,  0.9985, -0.4987,  0.7611,  0.6183],
        [ 0.3140,  0.2133, -0.1201,  0.3605, -0.3140, -1.0787,  0.2408, -1.3962],
        [-0.0661, -0.3584, -1.5616, -0.3546,  1.0811,  0.1315,  1.5735,  0.7814],
        [-1.0787, -0.7209,  1.4708,  0.2756,  0.6668, -0.9944, -1.1894, -1.1959],
        [-0.5596,  0.5335,  0.4069,  0.3946,  0.1715,  0.8760, -0.2871,  1.0216],
        [-0.0744, -1.0922,  0.3920,  0.5945,  0.6623, -1.2063,  0.6074, -0.5472],
        [ 1.1711,  0.0975,  0.9634,  0.8403, -1.2537,  0.9868, -0.4947, -1.2830],
        [ 0.9552,  1.2836, -0.6659,  0.5651,  0.2877, -0.0334, -1.0619, -0.1144]])

In [13]:
# Compute Q, K, V ---
Q = x @ W_q    # (1, 4, 8) x.W_q
K = x @ W_k    # (1, 4, 8) x.W_k
V = x @ W_v    # (1, 4, 8) x.W_v

In [14]:
Q

tensor([[[ 0.1102, -6.1774,  4.8199, -1.0106, -0.4241,  2.9541, -0.8527,
           0.8973],
         [-2.5697, -0.8284,  0.1135,  0.6763, -1.3704, -1.9123,  3.7348,
           3.1003],
         [ 2.0369,  5.1362,  1.1734,  3.8390, -4.2767, -7.0900,  3.0189,
           1.5741],
         [-3.6219, -1.3414,  1.4725,  1.7928, -2.2853,  4.3418,  1.6553,
           2.7115]]])

In [15]:
K

tensor([[[-2.9403,  3.5750,  1.2555, -1.9510,  3.9666,  1.2075,  0.3567,
          -1.3451],
         [-0.5890, -1.1137,  0.1614, -0.7459, -0.6012,  0.1583, -1.5152,
           1.6187],
         [-3.2272, -1.9414, -2.7186,  0.3052,  2.5244,  7.2225,  1.2088,
           1.9976],
         [-0.7440, -0.3622,  0.1995, -0.1397, -0.8188, -1.5034,  1.5917,
          -0.5069]]])

In [16]:
V

tensor([[[-3.8832, -1.7764,  1.2725, -1.0399, -0.9309,  1.8705,  1.5665,
           2.3719],
         [ 0.4719, -2.4329,  1.9300,  1.0527,  0.6714,  0.6640,  0.5529,
          -0.0912],
         [-1.5834,  4.5669, -1.8508, -0.3763,  0.3576,  3.7913, -4.4054,
           2.7140],
         [-0.2380, -2.9206,  3.4978,  1.6303,  0.7276,  0.3178,  0.9314,
           0.3603]]])

In [17]:
print("Q shape:", Q.shape)
print("K shape:", K.shape)
print("V shape:", V.shape)

Q shape: torch.Size([1, 4, 8])
K shape: torch.Size([1, 4, 8])
V shape: torch.Size([1, 4, 8])


In [None]:
# Split into heads 
# We reshape (batch, seq_len, embed_dim) → (batch, n_heads, seq_len, head_dim)
def split_heads(tensor):
    """
    Reshape (batch, seq_len, embed_dim) into (batch, n_heads, seq_len, head_dim).

    Uses the notebook-level ``n_heads`` and ``head_dim``.
    """
    batch, length, _ = tensor.shape  # the three-way unpack also requires a 3-D input
    per_head = tensor.view(batch, length, n_heads, head_dim)
    # Swap the sequence and head axes so the result is (batch, heads, seq_len, head_dim).
    return per_head.permute(0, 2, 1, 3)

In [19]:
Q_heads = split_heads(Q)
K_heads = split_heads(K)
V_heads = split_heads(V)

In [20]:
Q_heads

tensor([[[[ 0.1102, -6.1774,  4.8199, -1.0106],
          [-2.5697, -0.8284,  0.1135,  0.6763],
          [ 2.0369,  5.1362,  1.1734,  3.8390],
          [-3.6219, -1.3414,  1.4725,  1.7928]],

         [[-0.4241,  2.9541, -0.8527,  0.8973],
          [-1.3704, -1.9123,  3.7348,  3.1003],
          [-4.2767, -7.0900,  3.0189,  1.5741],
          [-2.2853,  4.3418,  1.6553,  2.7115]]]])

In [21]:
K_heads

tensor([[[[-2.9403,  3.5750,  1.2555, -1.9510],
          [-0.5890, -1.1137,  0.1614, -0.7459],
          [-3.2272, -1.9414, -2.7186,  0.3052],
          [-0.7440, -0.3622,  0.1995, -0.1397]],

         [[ 3.9666,  1.2075,  0.3567, -1.3451],
          [-0.6012,  0.1583, -1.5152,  1.6187],
          [ 2.5244,  7.2225,  1.2088,  1.9976],
          [-0.8188, -1.5034,  1.5917, -0.5069]]]])

In [22]:
V_heads

tensor([[[[-3.8832, -1.7764,  1.2725, -1.0399],
          [ 0.4719, -2.4329,  1.9300,  1.0527],
          [-1.5834,  4.5669, -1.8508, -0.3763],
          [-0.2380, -2.9206,  3.4978,  1.6303]],

         [[-0.9309,  1.8705,  1.5665,  2.3719],
          [ 0.6714,  0.6640,  0.5529, -0.0912],
          [ 0.3576,  3.7913, -4.4054,  2.7140],
          [ 0.7276,  0.3178,  0.9314,  0.3603]]]])

In [None]:
# Compute scaled attention scores for each head: Q K^T / sqrt(head_dim)
scores = torch.matmul(Q_heads, K_heads.transpose(-2, -1)) / (head_dim ** 0.5)  # Q K^T, scaled by sqrt(head_dim)
# Per head, Q is a 4x4 matrix (seq_len x head_dim)
# Per head, K is a 4x4 matrix (seq_len x head_dim)
# K^T is therefore 4x4 (head_dim x seq_len)
# So Q K^T is a 4x4 (seq_len x seq_len) matrix per head; scores has shape (1, 2, 4, 4)

In [29]:

scores

tensor([[[[ -7.1925,   4.1732,  -0.8871,   1.6291],
          [  1.7086,   0.9751,   4.8995,   1.0701],
          [  3.1781,  -4.7970,  -9.2815,  -1.8389],
          [  2.1023,   1.2639,   5.4185,   1.6120]],

         [[  0.1869,   1.7336,  10.5136,  -2.9530],
          [ -5.2913,  -0.0598,  -3.2815,   4.1851],
          [-13.2826,  -0.2889, -27.6053,   9.0842],
          [ -3.4395,   1.9712,  16.5036,  -1.6980]]]])

In [30]:
# Apply softmax to get attention weights ---
attn_weights = F.softmax(scores, dim=-1)

In [31]:
attn_weights

tensor([[[[1.0679e-05, 9.2175e-01, 5.8467e-03, 7.2394e-02],
          [3.7994e-02, 1.8246e-02, 9.2370e-01, 2.0064e-02],
          [9.9308e-01, 3.4153e-04, 3.8531e-06, 6.5782e-03],
          [3.3786e-02, 1.4608e-02, 9.3091e-01, 2.0691e-02]],

         [[3.2739e-05, 1.5374e-04, 9.9981e-01, 1.4172e-06],
          [7.5512e-05, 1.4126e-02, 5.6347e-04, 9.8523e-01],
          [1.9327e-10, 8.4972e-05, 1.1639e-16, 9.9992e-01],
          [2.1818e-09, 4.8823e-07, 1.0000e+00, 1.2449e-08]]]])

In [32]:
# Multiply weights by V to get attention outputs ---
head_outputs = torch.matmul(attn_weights, V_heads)

In [33]:
head_outputs

tensor([[[[ 0.4085, -2.4272,  2.0214,  1.0861],
          [-1.6063,  4.0479, -1.5559, -0.3352],
          [-3.8577, -1.7841,  1.2874, -1.0216],
          [-1.6032,  4.0954, -1.5794, -0.3363]],

         [[ 0.3576,  3.7907, -4.4045,  2.7135],
          [ 0.7265,  0.3247,  0.9231,  0.3554],
          [ 0.7276,  0.3178,  0.9314,  0.3602],
          [ 0.3576,  3.7913, -4.4054,  2.7140]]]])

In [34]:
# Concatenate heads back together.
# (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
# Batch and seq_len are read from the tensor itself and -1 infers the merged
# width, so this cell does not depend on the toy sizes 1, 4 and 8.
combined = head_outputs.transpose(1, 2).contiguous().view(
    head_outputs.size(0), head_outputs.size(2), -1
)

In [36]:
combined # 4x8 matrix 

tensor([[[ 0.4085, -2.4272,  2.0214,  1.0861,  0.3576,  3.7907, -4.4045,
           2.7135],
         [-1.6063,  4.0479, -1.5559, -0.3352,  0.7265,  0.3247,  0.9231,
           0.3554],
         [-3.8577, -1.7841,  1.2874, -1.0216,  0.7276,  0.3178,  0.9314,
           0.3602],
         [-1.6032,  4.0954, -1.5794, -0.3363,  0.3576,  3.7913, -4.4054,
           2.7140]]])

In [37]:
W_o = torch.randn(8, 8)
out = combined @ W_o

In [38]:
out

tensor([[[ -2.2239,   4.5809,   3.1000,  -8.3132,   8.8128, -10.4970,  11.6042,
           12.2652],
         [  2.0165,  -1.3627,  -2.3379,   4.4722,  -7.3122,   1.6044,   3.8696,
           -5.5359],
         [ -0.5414,  -7.4888,  -0.1208,  -4.1437,   2.8600,  -5.1260,  -1.8987,
            6.9527],
         [  2.9528,   2.7813,  -2.2634,   1.1013,  -0.6279,  -6.1177,  14.9080,
            3.5088]]])

In [39]:
W_o

tensor([[-0.3433,  1.5713,  0.1916,  0.3799, -0.1448,  0.6376, -0.2813, -1.3299],
        [-0.1420, -0.5341, -0.5234,  0.8615, -0.8870,  0.8388,  1.1529, -1.7611],
        [-1.4777, -1.7557,  0.0762, -1.0786,  1.4403, -0.1106,  0.5769, -0.1692],
        [-0.0640,  1.0384,  0.9068, -0.4755, -0.8707,  0.1447,  1.9029,  0.3904],
        [-0.0394, -0.8015, -0.4955, -0.3615,  0.5851, -1.1560, -0.1434, -0.1947],
        [-0.0856,  1.3945,  0.5969, -0.4828, -0.3661, -1.3271,  1.6953,  2.0655],
        [-0.2340,  0.7073,  0.5800,  0.2683, -2.0589,  0.5340, -0.5354, -0.8637],
        [-0.0235,  1.1717,  0.3987, -0.1987, -1.1559, -0.3167,  0.9403, -1.1470]])

## How to compute Q, K, V with one combined linear layer (step-by-step)

In [47]:
import torch.nn.functional as F
import math
import torch.nn as nn


In [41]:
torch.manual_seed(0)

<torch._C.Generator at 0x20abb15e930>

In [42]:
# --- toy settings ---
B = 1           # batch size
T = 4           # sequence length (tokens)
D = 8           # model dimension (embedding size)
H = 2           # number of heads
assert D % H == 0
d_head = D // H

In [44]:
# example input embeddings (B, T, D)
x = torch.randn(B, T, D)

In [45]:
x

tensor([[[-0.6136,  0.0316, -0.4927,  0.2484,  0.4397,  0.1124,  0.6408,
           0.4412],
         [-0.1023,  0.7924, -0.2897,  0.0525,  0.5229,  2.3022, -1.4689,
          -1.5867],
         [-0.6731,  0.8728,  1.0554,  0.1778, -0.2303, -0.3918,  0.5433,
          -0.3952],
         [-0.4462,  0.7440,  1.5210,  3.4105, -1.5312, -1.2341,  1.8197,
          -0.5515]]])

In [54]:
# Compute Q, K, V with one combined linear layer that outputs 3*D features at once

# For a single sequence this means we have
# Q = 4x8 matrix
# K = 4x8 matrix
# V = 4x8 matrix

# They sit side by side along the feature axis, giving ONE BIG MATRIX of size 4x24

qkv_proj = nn.Linear(D, 3 * D, bias=False)

In [55]:
qkv_proj 

Linear(in_features=8, out_features=24, bias=False)

In [50]:
qkv = qkv_proj(x) 

In [None]:
qkv # This matrix is of size 4x24 

tensor([[[-0.2039, -0.2304, -0.0942, -0.3087,  0.0143,  0.2670, -0.0181,
          -0.2633,  0.1654,  0.3899, -0.3810, -0.1740,  0.0521,  0.0545,
           0.0235,  0.1031,  0.0125,  0.0571, -0.0644,  0.0663,  0.3658,
          -0.1877,  0.0694,  0.0232],
         [ 1.4472, -0.2227,  0.4260,  1.1066,  1.1998,  0.0841, -0.0277,
          -0.1952, -0.6675,  0.2054,  0.0737, -0.0966,  0.0340,  0.4915,
          -0.8958, -0.4052,  0.5921,  0.4196, -0.2111,  0.3721, -0.4131,
          -0.2653, -0.3723,  0.4390],
         [-0.4211,  0.7516,  0.7574, -0.3153, -0.6000, -0.2862,  0.6364,
          -0.0470, -0.0236,  0.1096,  0.4059,  0.4322, -0.5328,  0.1104,
           0.2557, -0.0593, -0.2273, -0.3729, -0.0600,  0.2089,  0.0102,
           0.3741,  0.3793, -0.3332],
         [-0.9990,  2.1407, -0.4576, -0.4705, -2.1723, -1.6044, -0.1510,
           0.9739,  0.8386,  0.6835,  0.5413, -0.7491, -1.8316, -1.2455,
          -0.2053,  0.2213, -1.1326, -1.4582, -0.2638, -0.7247, -0.3658,
          

In [52]:
print("qkv shape (combined):", qkv.shape)

qkv shape (combined): torch.Size([1, 4, 24])


In [57]:
# Split combined tensor into Q, K, V
Q_chunk, K_chunk, V_chunk = qkv.chunk(3, dim=-1)

In [58]:
Q_chunk

tensor([[[-0.2039, -0.2304, -0.0942, -0.3087,  0.0143,  0.2670, -0.0181,
          -0.2633],
         [ 1.4472, -0.2227,  0.4260,  1.1066,  1.1998,  0.0841, -0.0277,
          -0.1952],
         [-0.4211,  0.7516,  0.7574, -0.3153, -0.6000, -0.2862,  0.6364,
          -0.0470],
         [-0.9990,  2.1407, -0.4576, -0.4705, -2.1723, -1.6044, -0.1510,
           0.9739]]], grad_fn=<SplitBackward0>)

In [59]:
K_chunk

tensor([[[ 0.1654,  0.3899, -0.3810, -0.1740,  0.0521,  0.0545,  0.0235,
           0.1031],
         [-0.6675,  0.2054,  0.0737, -0.0966,  0.0340,  0.4915, -0.8958,
          -0.4052],
         [-0.0236,  0.1096,  0.4059,  0.4322, -0.5328,  0.1104,  0.2557,
          -0.0593],
         [ 0.8386,  0.6835,  0.5413, -0.7491, -1.8316, -1.2455, -0.2053,
           0.2213]]], grad_fn=<SplitBackward0>)

In [60]:
V_chunk

tensor([[[ 0.0125,  0.0571, -0.0644,  0.0663,  0.3658, -0.1877,  0.0694,
           0.0232],
         [ 0.5921,  0.4196, -0.2111,  0.3721, -0.4131, -0.2653, -0.3723,
           0.4390],
         [-0.2273, -0.3729, -0.0600,  0.2089,  0.0102,  0.3741,  0.3793,
          -0.3332],
         [-1.1326, -1.4582, -0.2638, -0.7247, -0.3658, -0.4818,  0.7437,
          -0.8007]]], grad_fn=<SplitBackward0>)

In [65]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention with one fused projection for Q, K and V.

    The input is projected once to ``3 * d_model`` features, split into query,
    key and value, divided into ``n_heads`` heads of size ``d_model // n_heads``,
    passed through scaled dot-product attention, merged back together and sent
    through an output projection. Neither projection has a bias term.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        """
        d_model: embedding size of the input and output.
        n_heads: number of attention heads; must divide d_model.
        dropout: dropout probability applied to the attention probabilities.
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Combined projection producing [Q | K | V] in one matmul:
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        # Final output projection (after concatenating heads)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, tensor):
        """Reshape (B, T, D) -> (B, T, H, d_head), then permute to (B, H, T, d_head)."""
        B, T, _ = tensor.shape
        return tensor.view(B, T, self.n_heads, self.d_head).permute(0, 2, 1, 3)

    def forward(self, x, mask=None):
        """
        x: (B, T, D)
        mask: optional attention mask broadcastable to (B, H, T, T), e.g.
              (B, 1, T, T). Positions where ``mask == 0`` are not attended to.
        returns: a tuple ``(out, attn_probs)`` with out (B, T, D) and
                 attn_probs (B, H, T, T). Both are always returned.

        A query row whose keys are all masked gets all-zero attention
        probabilities (and hence a zero attention output) instead of NaN.
        """
        B, T, D = x.shape
        assert D == self.d_model, f"expected last dimension {self.d_model}, got {D}"

        # 1) Combined linear -> (B, T, 3*D), then split into Q, K, V -> each (B, T, D)
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        # 2) Split heads -> each (B, H, T, d_head)
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # 3) Scaled dot-product: scores = Q @ K^T / sqrt(d_head) -> (B, H, T, T)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)

        # 4) Apply mask (if provided). mask==0 -> set to -inf so softmax yields 0.
        blocked = None
        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        # 5) Softmax to get attention probabilities -> (B, H, T, T)
        attn_probs = F.softmax(scores, dim=-1)

        if blocked is not None:
            # Softmax over a row that is entirely -inf is NaN. Zero exactly those
            # rows so one fully masked query cannot poison the output with NaN.
            fully_blocked = blocked.all(dim=-1, keepdim=True)
            attn_probs = attn_probs.masked_fill(fully_blocked, 0.0)

        attn_probs = self.dropout(attn_probs)

        # 6) Weighted sum of values: (B, H, T, T) @ (B, H, T, d_head) -> (B, H, T, d_head)
        attn_out = torch.matmul(attn_probs, v)

        # 7) Recombine heads: (B, H, T, d_head) -> (B, T, H, d_head) -> (B, T, D)
        attn_out = attn_out.permute(0, 2, 1, 3).contiguous().view(B, T, D)

        # 8) Final linear projection
        out = self.out_proj(attn_out)  # (B, T, D)

        # return both the output and the attention probs for inspection/debugging
        return out, attn_probs



In [67]:
B = 1
T = 4          # 4 tokens: "I love machine learning"
D = 8          # embedding dimension
H = 2          # heads
x = torch.randn(B, T, D)

In [68]:
mha = MultiHeadSelfAttention(d_model=D, n_heads=H, dropout=0.0)

mha

MultiHeadSelfAttention(
  (qkv_proj): Linear(in_features=8, out_features=24, bias=False)
  (out_proj): Linear(in_features=8, out_features=8, bias=False)
  (dropout): Dropout(p=0.0, inplace=False)
)

In [69]:
causal = torch.tril(torch.ones(T, T)).bool()
mask = causal.unsqueeze(0).unsqueeze(0) 

In [70]:
causal

tensor([[ True, False, False, False],
        [ True,  True, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True,  True]])

In [71]:
mask

tensor([[[[ True, False, False, False],
          [ True,  True, False, False],
          [ True,  True,  True, False],
          [ True,  True,  True,  True]]]])

In [72]:
out, attn_probs = mha(x, mask=mask)

out

tensor([[[ 0.3435,  0.6913,  0.1346,  0.0889,  0.1877, -0.1585,  0.6864,
          -0.4301],
         [ 0.2950,  0.3847, -0.0949, -0.0579,  0.1814, -0.2265,  0.3393,
          -0.2377],
         [ 0.2599,  0.5260,  0.0767, -0.0978,  0.1644, -0.2744,  0.3204,
          -0.4330],
         [ 0.2502,  0.3490, -0.0122, -0.0971,  0.1564, -0.2560,  0.2445,
          -0.2967]]], grad_fn=<UnsafeViewBackward0>)

In [73]:
attn_probs

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5478, 0.4522, 0.0000, 0.0000],
          [0.3274, 0.2866, 0.3861, 0.0000],
          [0.2485, 0.2721, 0.2112, 0.2682]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4438, 0.5562, 0.0000, 0.0000],
          [0.3240, 0.3522, 0.3239, 0.0000],
          [0.2426, 0.2616, 0.2515, 0.2442]]]], grad_fn=<SoftmaxBackward0>)