In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
# Display settings: print arrays with 2 decimals and never in scientific notation
np.set_printoptions(suppress=True, precision=2)
# Disable autograd globally so later cells can call .numpy() directly on module outputs
torch.set_grad_enabled(False)
# Fix the RNG seed so every random tensor drawn below is reproducible
torch.manual_seed(79)

<torch._C.Generator at 0x1d1a80bf270>

In [3]:
# Walk through a single transformer layer, beginning with the input embeddings.
# The embeddings are a random (seq_len, emb_dim) tensor.
seq_len, emb_dim = 5, 4
emb = torch.randn((seq_len, emb_dim))
print(emb.numpy())

[[ 0.61  0.23 -0.22  0.59]
 [-0.67  1.13  0.3   1.35]
 [-0.43 -1.38  1.01  1.45]
 [-0.03 -0.56 -0.61  2.2 ]
 [ 0.07  0.72  1.4   1.05]]


In [4]:
# Positional encoding (simplified): a (seq_len, emb_dim) float tensor in which
# every entry of row i equals the position index i, i.e. rows are 0 .. seq_len-1
pos = torch.arange(seq_len, dtype=torch.float32)[:, None].repeat(1, emb_dim)
print(pos.numpy())

[[0. 0. 0. 0.]
 [1. 1. 1. 1.]
 [2. 2. 2. 2.]
 [3. 3. 3. 3.]
 [4. 4. 4. 4.]]


In [5]:
# Combine token embeddings and positional encoding by element-wise addition
emb_pos = torch.add(emb, pos)
print(emb_pos.numpy())

[[ 0.61  0.23 -0.22  0.59]
 [ 0.33  2.13  1.3   2.35]
 [ 1.57  0.62  3.01  3.45]
 [ 2.97  2.44  2.39  5.2 ]
 [ 4.07  4.72  5.4   5.05]]


In [6]:
# Layer normalisation ahead of attention
# Build a LayerNorm over the embedding dimension and apply it to the
# position-augmented embeddings
ln = nn.LayerNorm(normalized_shape=emb_dim)
emb_pos_ln = ln(emb_pos)
print(emb_pos_ln.numpy())

[[ 0.91 -0.22 -1.55  0.85]
 [-1.51  0.76 -0.29  1.03]
 [-0.52 -1.37  0.75  1.14]
 [-0.24 -0.71 -0.75  1.7 ]
 [-1.51 -0.19  1.21  0.49]]


In [7]:
# Multi-head attention, done by hand with two heads.
# Each head has its own query/key/value projection of shape (emb_dim, emb_dim/2);
# the output projection W_o has shape (emb_dim, emb_dim).
head_dim = emb_dim // 2

# Head 1 weights. The draw order (q, k, v) is kept so the seeded values do not change.
W_q_1, W_k_1, W_v_1 = (torch.randn(emb_dim, head_dim) for _ in range(3))
for label, weight in (("Wq1", W_q_1), ("Wk1", W_k_1), ("Wv1", W_v_1)):
    print(label)
    print(weight.numpy())

# Head 2 weights, drawn in the same (q, k, v) order
W_q_2, W_k_2, W_v_2 = (torch.randn(emb_dim, head_dim) for _ in range(3))
for label, weight in (("Wq2", W_q_2), ("Wk2", W_k_2), ("Wv2", W_v_2)):
    print(label)
    print(weight.numpy())

# Output projection applied after the heads are concatenated
W_o = torch.randn(emb_dim, emb_dim)
print("Wo")
print(W_o.numpy())

# Project the layer-normed embeddings into queries, keys and values for head 1
Q1, K1, V1 = (emb_pos_ln @ weight for weight in (W_q_1, W_k_1, W_v_1))
for label, projected in (("Q1", Q1), ("K1", K1), ("V1", V1)):
    print(label)
    print(projected.numpy())

# Same projections for head 2
Q2, K2, V2 = (emb_pos_ln @ weight for weight in (W_q_2, W_k_2, W_v_2))
for label, projected in (("Q2", Q2), ("K2", K2), ("V2", V2)):
    print(label)
    print(projected.numpy())

Wq1
[[-0.74  0.7 ]
 [-1.3   1.73]
 [-0.86 -0.05]
 [-0.76 -0.09]]
Wk1
[[ 0.98 -0.18]
 [-1.44  0.51]
 [ 0.35  1.95]
 [ 1.3  -0.13]]
Wv1
[[-0.02 -0.41]
 [-0.56  0.38]
 [-0.82 -0.78]
 [ 0.23 -0.61]]
Wq2
[[ 0.2  -0.41]
 [ 0.62  1.26]
 [ 1.64  1.35]
 [-0.89  0.74]]
Wk2
[[-0.83  0.09]
 [-0.95 -2.35]
 [-0.59  1.18]
 [-2.3   2.21]]
Wv2
[[ 0.03 -0.8 ]
 [ 0.51  1.45]
 [-0.78 -0.59]
 [ 0.62 -1.37]]
Wo
[[ 0.12 -1.6  -0.37 -1.27]
 [ 0.22  1.31 -0.12 -0.3 ]
 [-0.19 -0.28  0.15  0.42]
 [-0.42 -0.23  1.29  0.22]]
Q1
[[ 0.3   0.25]
 [-0.41  0.18]
 [ 0.65 -2.87]
 [ 0.46 -1.51]
 [-0.06 -1.47]]
K1
[[ 1.77 -3.4 ]
 [-1.33 -0.03]
 [ 3.19  0.71]
 [ 2.72 -1.99]
 [-0.16  2.47]]
V1
[[ 1.56  0.23]
 [ 0.08  0.5 ]
 [ 0.41 -1.59]
 [ 1.39 -0.63]
 [-0.74 -0.69]]
Q2
[[-3.26 -2.1 ]
 [-1.21  1.95]
 [-0.73  0.35]
 [-3.22 -0.54]
 [ 1.14  2.37]]
K2
[[-1.59  0.64]
 [-1.68  0.02]
 [-1.33  6.56]
 [-2.58  4.5 ]
 [-0.41  2.81]]
V2
[[ 1.65 -1.29]
 [ 1.22  1.06]
 [-0.59 -3.58]
 [ 1.27 -2.71]
 [-0.77 -0.46]]


In [8]:
# Raw attention scores: each query row dotted with every key row, per head
K1_t, K2_t = K1.T, K2.T
scores1 = torch.matmul(Q1, K1_t)
scores2 = torch.matmul(Q2, K2_t)
for label, tensor in (
    ("K1^T", K1_t),
    ("Scores1", scores1),
    ("K2^T", K2_t),
    ("Scores2", scores2),
):
    print(label)
    print(tensor.numpy())

K1^T
[[ 1.77 -1.33  3.19  2.72 -0.16]
 [-3.4  -0.03  0.71 -1.99  2.47]]
Scores1
[[-0.33 -0.41  1.14  0.31  0.58]
 [-1.35  0.54 -1.18 -1.48  0.51]
 [10.91 -0.79  0.04  7.48 -7.18]
 [ 5.96 -0.57  0.39  4.26 -3.81]
 [ 4.92  0.11 -1.23  2.78 -3.63]]
K2^T
[[-1.59 -1.68 -1.33 -2.58 -0.41]
 [ 0.64  0.02  6.56  4.5   2.81]]
Scores2
[[ 3.84  5.42 -9.44 -1.03 -4.56]
 [ 3.18  2.08 14.44 11.92  5.98]
 [ 1.38  1.23  3.25  3.44  1.27]
 [ 4.78  5.39  0.76  5.9  -0.19]
 [-0.3  -1.86 14.02  7.71  6.18]]


In [9]:
# Causal mask: lower-triangular ones, so position i only sees positions <= i
mask = torch.tril(torch.ones(seq_len, seq_len))
print("Mask", mask.numpy(), sep="\n")
# Wherever the mask is 0 the score becomes -inf, which softmax maps to weight 0
hidden = mask == 0
scores1 = scores1.masked_fill(hidden, float('-inf'))
scores2 = scores2.masked_fill(hidden, float('-inf'))
print("Scores1 masked", scores1.numpy(), sep="\n")
print("Scores2 masked", scores2.numpy(), sep="\n")
# Scale by sqrt of the per-head dimension, then normalise each row with softmax
scale = np.sqrt(emb_dim // 2)
scores1 = torch.softmax(scores1 / scale, dim=-1)
scores2 = torch.softmax(scores2 / scale, dim=-1)
print("Scores1 softmax", scores1.numpy(), sep="\n")
print("Scores2 softmax", scores2.numpy(), sep="\n")

Mask
[[1. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0.]
 [1. 1. 1. 0. 0.]
 [1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1.]]
Scores1 masked
[[-0.33  -inf  -inf  -inf  -inf]
 [-1.35  0.54  -inf  -inf  -inf]
 [10.91 -0.79  0.04  -inf  -inf]
 [ 5.96 -0.57  0.39  4.26  -inf]
 [ 4.92  0.11 -1.23  2.78 -3.63]]
Scores2 masked
[[ 3.84  -inf  -inf  -inf  -inf]
 [ 3.18  2.08  -inf  -inf  -inf]
 [ 1.38  1.23  3.25  -inf  -inf]
 [ 4.78  5.39  0.76  5.9   -inf]
 [-0.3  -1.86 14.02  7.71  6.18]]
Scores1 softmax
[[1.   0.   0.   0.   0.  ]
 [0.21 0.79 0.   0.   0.  ]
 [1.   0.   0.   0.   0.  ]
 [0.75 0.01 0.01 0.23 0.  ]
 [0.79 0.03 0.01 0.17 0.  ]]
Scores2 softmax
[[1.   0.   0.   0.   0.  ]
 [0.69 0.31 0.   0.   0.  ]
 [0.18 0.16 0.66 0.   0.  ]
 [0.21 0.32 0.01 0.46 0.  ]
 [0.   0.   0.98 0.01 0.  ]]


In [10]:
# Weighted sum of the value vectors using the attention weights, per head
Z1 = torch.matmul(scores1, V1)
Z2 = torch.matmul(scores2, V2)
for label, head_out in (("Z1", Z1), ("Z2", Z2)):
    print(label)
    print(head_out.numpy())

Z1
[[1.56 0.23]
 [0.39 0.44]
 [1.56 0.22]
 [1.5  0.01]
 [1.48 0.06]]
Z2
[[ 1.65 -1.29]
 [ 1.51 -0.55]
 [ 0.09 -2.43]
 [ 1.31 -1.22]
 [-0.57 -3.55]]


In [11]:
# Join the two head outputs side by side along the feature axis
Z = torch.cat([Z1, Z2], dim=-1)
print("Z", Z.numpy(), sep="\n")

Z
[[ 1.56  0.23  1.65 -1.29]
 [ 0.39  0.44  1.51 -0.55]
 [ 1.56  0.22  0.09 -2.43]
 [ 1.5   0.01  1.31 -1.22]
 [ 1.48  0.06 -0.57 -3.55]]


In [12]:
# Apply the output projection W_o to the concatenated heads
output = torch.matmul(Z, W_o)
print("Output", output.numpy(), sep="\n")

Output
[[ 0.46 -2.37 -2.01 -1.63]
 [ 0.08 -0.34 -0.67 -0.11]
 [ 1.24 -1.68 -3.73 -2.54]
 [ 0.44 -2.47 -1.92 -1.61]
 [ 1.8  -1.31 -5.23 -2.91]]


In [13]:
# Add residual connection: the attention output plus the embeddings as they
# were before layernorm (emb_pos).
# The projection Z @ W_o is recomputed here rather than updating `output`
# in place with `+=`, so re-running this cell cannot add the residual twice.
output = Z @ W_o + emb_pos
print("Output with residual connection")
print(output.numpy())

Output with residual connection
[[ 1.07 -2.15 -2.23 -1.04]
 [ 0.41  1.79  0.62  2.24]
 [ 2.81 -1.06 -0.72  0.91]
 [ 3.41 -0.03  0.47  3.59]
 [ 5.86  3.41  0.18  2.14]]


In [14]:
# Second layernorm, applied to the residual sum (reuses the `ln` module created earlier)
output = ln(output)
print("Output with layernorm", output.numpy(), sep="\n")

Output with layernorm
[[ 1.62 -0.8  -0.86  0.03]
 [-1.11  0.68 -0.83  1.27]
 [ 1.51 -1.01 -0.78  0.28]
 [ 0.94 -1.15 -0.84  1.05]
 [ 1.44  0.25 -1.32 -0.37]]


In [15]:
# Feed-forward network, done by hand: linear -> ReLU -> linear
# Shapes: W_ff1 (emb_dim, emb_dim*2), B_ff1 (emb_dim*2,),
#         W_ff2 (emb_dim*2, emb_dim), B_ff2 (emb_dim,)
hidden_dim = emb_dim * 2
# Draw order (W_ff1, W_ff2, B_ff1, B_ff2) is kept so the seeded values do not change
W_ff1 = torch.randn(emb_dim, hidden_dim)
W_ff2 = torch.randn(hidden_dim, emb_dim)
B_ff1 = torch.randn(hidden_dim)
B_ff2 = torch.randn(emb_dim)
for label, param in (("Wff1", W_ff1), ("Wff2", W_ff2), ("Bff1", B_ff1), ("Bff2", B_ff2)):
    print(label)
    print(param.numpy())

# First linear layer: matrix product, then bias
wf1 = torch.matmul(output, W_ff1)
bf1 = wf1 + B_ff1
print("Wf1", wf1.numpy(), sep="\n")
print("Bf1", bf1.numpy(), sep="\n")
# ReLU non-linearity zeroes out the negative activations
rf1 = torch.relu(bf1)
print("Rf1", rf1.numpy(), sep="\n")
# Second linear layer maps back to emb_dim: matrix product, then bias
wf2 = torch.matmul(rf1, W_ff2)
bf2 = wf2 + B_ff2
print("Wf2", wf2.numpy(), sep="\n")
print("Bf2", bf2.numpy(), sep="\n")


Wff1
[[-1.15  0.36  0.31  0.01 -0.86  0.69  0.1   0.69]
 [-0.15  0.06  0.34  1.06  0.21 -0.52  1.14 -1.17]
 [ 0.17  1.28 -0.45  2.57  0.03  0.78 -0.81 -0.68]
 [ 1.76 -0.32  0.68  0.18  0.89  0.91  0.62 -0.91]]
Wff2
[[ 2.17 -0.94  0.74 -0.79]
 [ 1.43 -1.12  0.81  0.86]
 [-0.01 -2.6   0.54 -1.22]
 [-0.62  1.34 -0.15  0.3 ]
 [-0.35 -0.2  -0.17 -0.2 ]
 [-0.46  0.12  0.81  0.56]
 [ 2.05 -0.03 -0.42 -0.  ]
 [-3.42 -2.2  -0.29  1.93]]
Bff1
[-1.41 -0.15 -1.25 -0.58  0.26 -1.85  0.13 -0.  ]
Bff2
[ 0.25  0.52 -1.63  0.03]
Wf1
[[-1.83 -0.58  0.63 -3.03 -1.57  0.9  -0.03  2.6 ]
 [ 3.25 -1.84  1.13 -1.21  2.2  -0.62  2.12 -2.15]
 [-1.24 -0.62  0.66 -3.01 -1.3   1.21 -0.2   2.51]
 [ 0.78 -1.15  0.99 -3.17 -0.16  1.55  0.11  1.61]
 [-2.57 -1.04  0.87 -3.18 -1.56 -0.5   1.26  1.93]]
Bf1
[[-3.24 -0.73 -0.61 -3.61 -1.31 -0.95  0.09  2.6 ]
 [ 1.85 -1.98 -0.11 -1.79  2.45 -2.47  2.25 -2.15]
 [-2.65 -0.76 -0.59 -3.59 -1.04 -0.64 -0.07  2.5 ]
 [-0.63 -1.3  -0.26 -3.75  0.1  -0.3   0.23  1.61]
 [-3.97 -1.19 

In [16]:
# Stand-in for the next layer: project the feed-forward result through a random
# (emb_dim, emb_dim) matrix, then layernorm the product
W = torch.randn(emb_dim, emb_dim)
output = ln(bf2 @ W)
print("Output", output.numpy(), sep="\n")

Output
[[ 0.12  1.54 -1.16 -0.5 ]
 [-0.06 -1.45  1.37  0.14]
 [ 0.12  1.54 -1.16 -0.5 ]
 [-0.09  1.62 -1.08 -0.45]
 [-0.32  1.69 -0.92 -0.44]]


In [17]:
# Map the (seq_len, emb_dim) output onto a vocabulary assumed to hold 10 tokens
vocab_size = 10
# Random (emb_dim, vocab_size) matrix standing in for the final linear layer
wlin = torch.randn((emb_dim, vocab_size))
print("Wlin", wlin.numpy(), sep="\n")
# Logits: one row of vocab_size scores per sequence position
logits = torch.matmul(output, wlin)
print("Logits", logits.numpy(), sep="\n")

Wlin
[[-0.69 -0.08  0.13  0.59 -1.39  0.96 -1.98  0.23  1.38 -0.09]
 [-0.32  2.06  0.49 -0.99 -1.92 -1.9  -0.25 -1.82 -0.75 -0.42]
 [ 0.2   0.29 -0.61  1.04  0.35 -0.05 -0.95 -0.49  0.22  1.01]
 [ 0.67 -0.37  0.84 -0.9  -1.14 -0.39 -0.97  1.77 -0.22  0.  ]]
Logits
[[-1.14  3.02  1.06 -2.22 -2.98 -2.56  0.97 -3.1  -1.14 -1.83]
 [ 0.87 -2.64 -1.43  2.7   3.2   2.56 -0.95  2.21  1.27  1.99]
 [-1.14  3.02  1.06 -2.23 -2.98 -2.56  0.97 -3.09 -1.14 -1.83]
 [-0.97  3.21  1.06 -2.38 -2.86 -2.93  1.24 -3.24 -1.48 -1.76]
 [-0.79  3.41  0.97 -2.43 -2.62 -3.29  1.52 -3.48 -1.82 -1.61]]


In [18]:
# Turn each row of logits into a probability distribution over the vocabulary
probs = logits.softmax(dim=-1)
print("Probs", probs.numpy(), sep="\n")

Probs
[[0.01 0.76 0.11 0.   0.   0.   0.1  0.   0.01 0.01]
 [0.03 0.   0.   0.2  0.32 0.17 0.01 0.12 0.05 0.1 ]
 [0.01 0.76 0.11 0.   0.   0.   0.1  0.   0.01 0.01]
 [0.01 0.77 0.09 0.   0.   0.   0.11 0.   0.01 0.01]
 [0.01 0.79 0.07 0.   0.   0.   0.12 0.   0.   0.01]]
