# Transformer Demo

# Prepare the Enviroment

In [97]:
pip install numpy requests torch tiktoken matplotlib pandas

You should consider upgrading via the '/Users/roger/Dev/Transformer-from-scratch/trans_env/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [98]:
import os
import requests
import pandas as pd
import matplotlib.pyplot as plt
import math
import tiktoken
import torch
import torch.nn as nn

# Setup Hyperparameters

In [99]:
batch_size = 4
context_length = 16
d_model = 64
num_layers = 8
num_heads = 4
learning_rate = 1e-3
dropout = 0.1
max_iters = 5000
eval_interval = 50
eval_iters = 20
device = ('mps' if torch.backends.mps.is_available() 
else ('cuda' if torch.cuda.is_available() else 'cpu'))
TORCH_SEED = 1337
torch.manual_seed(TORCH_SEED)

print(device)

mps


# Prepare the Dataset

In [100]:
if not os.path.exists('sales_textbook.txt'):
    url = 'https://huggingface.co/datasets/goendalf666/sales-textbook_for_convincing_and_selling/raw/main/sales_textbook.txt'
    with open('sales_textbook.txt', 'w') as f:
        f.write(requests.get(url).text)

with open('sales_textbook.txt', 'r') as f:
    text = f.read()

print(text[:100])
print(len(text))

Chapter 1: Building Rapport and Capturing Attention
Subpoint: Understanding the Importance of Buildi
460319


# Step1: Tokenization

In [None]:
encoding = tiktoken.get_encoding("cl100k_base")
tokenized_text = encoding.encode(text)
vocab_size = len(set(tokenized_text))
max_token_value = max(tokenized_text)

print(text[:25])
print(tokenized_text[:20])
print(f"Tokenized text size: {len(tokenized_text)}")
print(f"Vocabulary size: {vocab_size}")
print(f"The maximum token value in the tokenized text is: {max_token_value}")

Chapter 1: Building Rappo
[26072, 220, 16, 25, 17283, 23097, 403, 323, 17013, 1711, 63120, 198, 3214, 2837, 25, 46551, 279, 94100, 315, 17283]
Tokenized text size: 77919
Vocabulary size: 3771
The maximum token value in the tokenized text is: 100069
min token is: 1


# Step2: Word Embedding

In [102]:
# 切分训练和验证集
split_idx = int(len(tokenized_text) * 0.8)
train_data = tokenized_text[:split_idx] 
val_data = tokenized_text[split_idx:]

data = torch.tensor(train_data, dtype=torch.long, device=device)
idxs = torch.randint(low=0, high=len(data) - context_length, size=(batch_size,)) 

x_batch = torch.stack([data[i:i + context_length] for i in idxs]) 
y_batch = torch.stack([data[i + 1:i + context_length + 1] for i in idxs])
print(x_batch.shape)
print(y_batch.shape)
print(x_batch[0])
print(y_batch[0])

token_embedding_lookup_table = nn.Embedding(max_token_value, d_model).to(device)
x = token_embedding_lookup_table(x_batch)
y = token_embedding_lookup_table(y_batch)
print("Token Embedding Lookup Table:", token_embedding_lookup_table.weight.shape)
print(x.shape)
print(y.shape)
print(x[0][1])

torch.Size([4, 16])
torch.Size([4, 16])


tensor([  627,  1383, 88861,   279,  1989,   315, 25607, 16940, 65931,   323,
        32097,    11,   584, 26458, 13520,   449], device='mps:0')
tensor([ 1383, 88861,   279,  1989,   315, 25607, 16940, 65931,   323, 32097,
           11,   584, 26458, 13520,   449,   264], device='mps:0')
Token Embedding Lookup Table: torch.Size([100069, 64])
torch.Size([4, 16, 64])
torch.Size([4, 16, 64])
tensor([-7.2645e-01,  2.7959e-01,  6.1379e-01, -6.8169e-01,  6.1964e-01,
        -8.3871e-01,  1.6872e+00,  9.2419e-01, -1.7660e+00,  1.0969e+00,
         8.2345e-01,  4.0419e-01, -1.0015e+00, -4.0253e-01, -3.6015e-01,
        -1.8088e-01,  1.1226e+00, -1.2955e+00, -1.4447e+00,  1.3121e+00,
        -6.9775e-01,  1.1749e+00,  5.5158e-01, -1.8591e-01,  1.0946e-01,
         7.3421e-01,  1.2775e+00,  1.2458e+00,  2.8220e-01,  1.5083e+00,
        -2.1885e-03, -1.7179e+00,  3.7070e-02,  1.2844e-01,  2.3702e-01,
         3.5079e-01, -5.3175e-01,  3.4943e-01, -7.5952e-01,  3.6131e-01,
         7.6173e-01,  3

# Step3: Positional Encoding

In [103]:
position_encoding_lookup_table = torch.zeros(context_length, d_model).to(device) # (context_length, d_model)
position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1) # (context_length, 1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model/2,)  
position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
position_encoding_lookup_table = position_encoding_lookup_table.unsqueeze(0).expand(batch_size, -1, -1) # (batch_size, context_length, d_model) 
print("Position Encoding Lookup Table: ", position_encoding_lookup_table.shape)

input_embedding_x = x + position_encoding_lookup_table # (batch_size, context_length, d_model)
input_embedding_y = y + position_encoding_lookup_table # (batch_size, context_length, d_model)
X = input_embedding_x # (batch_size, context_length, d_model)
x_plot = input_embedding_x[0].detach().cpu().numpy()
print("Final Input Embedding of x:\n", pd.DataFrame(x_plot))

Position Encoding Lookup Table:  torch.Size([4, 16, 64])
Final Input Embedding of x:
           0         1         2         3         4         5         6   \
0   0.346711  0.416130  0.194183  2.676231  0.104132  0.317640 -0.716679   
1   0.115021  0.819889  1.295355  0.050067  1.152809  0.007295  2.096463   
2   0.025032 -0.896510  2.007128  0.645062  1.861388  1.183611  0.626594   
3   1.540862 -2.196910  1.106263 -0.879057  1.525128 -0.364159  1.305571   
4  -1.026346 -2.520314  0.092276 -1.922128  0.772673 -1.248682  0.811819   
5  -0.333637 -0.845062 -3.047244 -0.539240  0.370664 -0.112350  0.087253   
6  -0.671321  1.424937 -2.025386 -1.031886 -0.385049 -0.656757  0.465220   
7   1.460090  1.632251 -2.407473 -0.108412 -1.577922 -1.534449 -0.767310   
8   0.493083 -2.413755 -1.205481  0.105365 -0.475036  0.941739  0.546560   
9   0.381502 -0.606034  0.249510  1.915294 -0.984113  2.072770 -2.097719   
10 -0.586758 -1.718416  1.647630 -0.431370 -0.642387 -0.391723 -0.296238   
11

# Step4: Transformer Block

### 4.1 Multi-head Attention Overview

<img src="./img/multihead_attention.png" width="600">

### 4.2 Prepare Q,K,V

In [114]:
query = key = value = X # [4, 16, 64] [batch_size, context_length, d_model]

Wq = nn.Linear(d_model, d_model).to(device)
Wk = nn.Linear(d_model, d_model).to(device)
Wv = nn.Linear(d_model, d_model).to(device)

Q = Wq(query)
Q = Q.view(batch_size, -1, num_heads, d_model // num_heads) # [4, 16, 4, 16] [batch_size, context_length, num_heads, d_model // num_heads]
K = Wq(key)
K = K.view(batch_size, -1, num_heads, d_model // num_heads) # [4, 16, 4, 16] [batch_size, context_length, num_heads, d_model // num_heads]
V = Wv(value)
V = V.view(batch_size, -1, num_heads, d_model // num_heads) # [4, 16, 4, 16] [batch_size, context_length, num_heads, d_model // num_heads]

# 交换维度位置，将 num_heads 移到第二维，为后续的注意力计算做准备，使得注意力计算可以并行进行
Q = Q.transpose(1, 2) # [4, 4, 16, 16] [batch_size, num_heads, context_length, d_model // num_heads]
K = K.transpose(1, 2) # [4, 4, 16, 16] [batch_size, num_heads, context_length, d_model // num_heads]
V = V.transpose(1, 2) # [4, 4, 16, 16] [batch_size, num_heads, context_length, d_model // num_heads]

print(Q.shape)
print(pd.DataFrame(Q[0][0].detach().cpu().numpy()))

torch.Size([4, 4, 16, 16])
          0         1         2         3         4         5         6   \
0   1.098015  0.205219  1.033230 -0.421223  0.429190 -0.765578 -0.993587   
1   0.486800  0.758198  0.342134  1.147735  0.352587  0.913575 -0.240828   
2   1.361852  0.288645  0.558850  0.197726 -0.759869 -1.004792 -0.119369   
3  -0.319381  1.608541 -0.265403  0.286685 -0.825504 -0.311573 -0.466977   
4   0.656908  1.683313 -0.752694 -0.501442 -0.103208 -0.331558 -0.179397   
5   0.722305  0.611283 -0.358031 -0.466505 -0.169242  0.115154 -0.692522   
6   0.652447  0.972758 -0.640037 -0.757617 -0.713646  0.129324  0.375101   
7   0.147175  0.174586  0.415578 -0.213032 -0.499469  1.262464  0.319959   
8   0.771177  0.638311 -0.267933  0.029066  0.822012  0.515905  1.583130   
9   0.558526  0.459015  0.656071 -0.618360 -0.238946 -0.036866  0.360074   
10 -0.719556  1.098709 -0.940380  0.338702 -0.291737 -0.410984 -0.367827   
11  1.422824 -0.127827  0.626552 -0.511144  0.215399  0.76142

### 4.3 Calculate QK^T Attention

In [115]:
attention_score = torch.matmul(Q, K.transpose(-2, -1)) # [4, 4, 16, 16] [batch_size, num_heads, context_length, context_length]
print(attention_score.shape)
print(pd.DataFrame(attention_score[0][0].detach().cpu().numpy()))

torch.Size([4, 4, 16, 16])
          0         1         2          3         4         5         6   \
0   8.619459  3.142893  4.520507   3.242733  1.839295  2.879557  0.559735   
1   3.142893  6.451763  2.082761   4.955133  2.036033  2.428303  0.106789   
2   4.520507  2.082761  6.354085   3.475078  2.538650  1.368454  0.815163   
3   3.242733  4.955133  3.475078  10.501399  5.351847  3.757942  2.637617   
4   1.839295  2.036033  2.538650   5.351847  6.621957  3.289231  2.876918   
5   2.879557  2.428303  1.368454   3.757942  3.289231  4.508989  2.594900   
6   0.559735  0.106789  0.815163   2.637617  2.876918  2.594900  4.669538   
7   0.757476  2.951372  0.902415   2.625209 -0.025387  1.103071  1.169249   
8   2.191056  2.811849  0.734108   0.660408  1.074552 -0.038380  1.637440   
9   2.258942  0.744831  1.753906   1.744321  1.680879  1.185197  1.106795   
10 -0.839946  0.034575 -0.145063   3.563659  1.829958  0.376614  1.623793   
11  3.621936  2.011119  2.420086   0.204779  0.87

### 4.4 Scale

In [116]:
attention_score = attention_score / math.sqrt(d_model // num_heads)
print(pd.DataFrame(attention_score[0][0].detach().cpu().numpy()))

          0         1         2         3         4         5         6   \
0   2.154865  0.785723  1.130127  0.810683  0.459824  0.719889  0.139934   
1   0.785723  1.612941  0.520690  1.238783  0.509008  0.607076  0.026697   
2   1.130127  0.520690  1.588521  0.868769  0.634663  0.342114  0.203791   
3   0.810683  1.238783  0.868769  2.625350  1.337962  0.939486  0.659404   
4   0.459824  0.509008  0.634663  1.337962  1.655489  0.822308  0.719229   
5   0.719889  0.607076  0.342114  0.939486  0.822308  1.127247  0.648725   
6   0.139934  0.026697  0.203791  0.659404  0.719229  0.648725  1.167385   
7   0.189369  0.737843  0.225604  0.656302 -0.006347  0.275768  0.292312   
8   0.547764  0.702962  0.183527  0.165102  0.268638 -0.009595  0.409360   
9   0.564735  0.186208  0.438477  0.436080  0.420220  0.296299  0.276699   
10 -0.209987  0.008644 -0.036266  0.890915  0.457490  0.094153  0.405948   
11  0.905484  0.502780  0.605021  0.051195  0.217614  0.497847  0.444722   
12 -0.371107

### 4.5 Mask

In [117]:
attention_score = attention_score.masked_fill(
        torch.triu(torch.ones(attention_score.shape[-2:]).to(device), diagonal=1).bool(), 
        float('-inf')) # 将矩阵的上三角设置为-inf, 形状为[4, 4, 16, 16] [batch_size, num_heads, context_length, context_length]
print(pd.DataFrame(attention_score[0][0].detach().cpu().numpy()))

          0         1         2         3         4         5         6   \
0   2.154865      -inf      -inf      -inf      -inf      -inf      -inf   
1   0.785723  1.612941      -inf      -inf      -inf      -inf      -inf   
2   1.130127  0.520690  1.588521      -inf      -inf      -inf      -inf   
3   0.810683  1.238783  0.868769  2.625350      -inf      -inf      -inf   
4   0.459824  0.509008  0.634663  1.337962  1.655489      -inf      -inf   
5   0.719889  0.607076  0.342114  0.939486  0.822308  1.127247      -inf   
6   0.139934  0.026697  0.203791  0.659404  0.719229  0.648725  1.167385   
7   0.189369  0.737843  0.225604  0.656302 -0.006347  0.275768  0.292312   
8   0.547764  0.702962  0.183527  0.165102  0.268638 -0.009595  0.409360   
9   0.564735  0.186208  0.438477  0.436080  0.420220  0.296299  0.276699   
10 -0.209987  0.008644 -0.036266  0.890915  0.457490  0.094153  0.405948   
11  0.905484  0.502780  0.605021  0.051195  0.217614  0.497847  0.444722   
12 -0.371107

### 4.6 Softmax

In [118]:
attention_score = torch.softmax(attention_score, dim=-1) # [4, 4, 16, 16] [batch_size, num_heads, context_length, context_length]
print(pd.DataFrame(attention_score[0][0].detach().cpu().numpy()))


          0         1         2         3         4         5         6   \
0   1.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   
1   0.304234  0.695766  0.000000  0.000000  0.000000  0.000000  0.000000   
2   0.319981  0.173960  0.506060  0.000000  0.000000  0.000000  0.000000   
3   0.102741  0.157640  0.108886  0.630732  0.000000  0.000000  0.000000   
4   0.111686  0.117317  0.133025  0.268764  0.369208  0.000000  0.000000   
5   0.155401  0.138822  0.106509  0.193563  0.172161  0.233543  0.000000   
6   0.091895  0.082056  0.097954  0.154488  0.164012  0.152847  0.256748   
7   0.082528  0.142824  0.085574  0.131641  0.067859  0.089976  0.091477   
8   0.089952  0.105055  0.062492  0.061351  0.068044  0.051518  0.078326   
9   0.119371  0.081754  0.105212  0.104960  0.103309  0.091268  0.089497   
10  0.053185  0.066182  0.063276  0.159921  0.103675  0.072090  0.098466   
11  0.121795  0.081421  0.090186  0.051834  0.061220  0.081021  0.076829   
12  0.043020

### 4.7 Calculate V Attention

In [119]:
A = torch.matmul(attention_score, V) # [4, 4, 16, 16] [batch_size, num_heads, context_length, d_model // num_heads]
print(attention_score.shape)
print(A.shape)
print(pd.DataFrame(A[0][0].detach().cpu().numpy()))

torch.Size([4, 4, 16, 16])
torch.Size([4, 4, 16, 16])
          0         1         2         3         4         5         6   \
0   1.404491  0.273069 -1.009697  1.125089  0.448860 -0.699946 -0.194428   
1   1.039182  0.530752 -0.623729  1.028149  0.048530 -0.715421 -0.089486   
2   0.333371  0.114961 -0.981869  1.175668 -0.042304 -0.108881  0.517010   
3   0.421122  0.940021 -0.920434  1.004752 -0.253904 -1.109372  0.136975   
4   0.133389  0.445351 -0.638007  0.601900 -0.233886 -0.635158  0.202332   
5   0.164900  0.477558 -0.511829  0.587863 -0.161267 -0.498483  0.268524   
6  -0.026703  0.591076 -0.166389  0.355832  0.050325 -0.386237  0.341763   
7  -0.109981  0.730334 -0.580128  0.493912 -0.014560 -0.267685  0.515086   
8   0.145014  0.400024 -0.521913  0.314520  0.208060 -0.132144 -0.363337   
9   0.135645  0.550333 -0.413918  0.371470  0.161968 -0.244217  0.123829   
10  0.149946  0.645187 -0.412956  0.225751 -0.083613 -0.395455 -0.138985   
11 -0.029733  0.386741 -0.423626  

### 4.8 Concatenate and Output

In [120]:
A = A.transpose(1, 2) # [4, 4, 16, 16] -> [4, 16, 4, 16] [batch_size, context_length, num_heads, d_model // num_heads]
A = A.reshape(batch_size, -1, d_model) # [4, 16, 4, 16] -> [4, 16, 64] [batch_size, context_length, d_model]
print(A.shape)

Wo = nn.Linear(d_model, d_model).to(device) # [64, 64] [d_model, d_model]
output = Wo(A) # [4, 16, 64] [batch_size, context_length, d_model]
print(output.shape)

torch.Size([4, 16, 64])
torch.Size([4, 16, 64])


# Step5: Residual Connection and Layer Normalization

In [121]:
# Add residual connection
output = output + X

# Add Layer Normalization
layer_norm = nn.LayerNorm(d_model).to(device)
output = layer_norm(output) # [4, 16, 64] [batch_size, context_length, d_model]
print(output.shape)

torch.Size([4, 16, 64])


# Step6: Feed Forward Network

In [136]:
# 保存输入，用于后续残差连接
layer_norm_output = output
# 升维
output = nn.Linear(d_model, d_model * 4).to(device)(output) # [4, 16, 256] [batch_size, context_length, d_model * 4]
# 激活函数
output = nn.ReLU().to(device)(output)
# 降维
output = nn.Linear(d_model * 4, d_model).to(device)(output) # [4, 16, 64] [batch_size, context_length, d_model]
# dropout防止过拟合
output = torch.dropout(output, p=dropout, train=True) # [4, 16, 64] [batch_size, context_length, d_model]
print(output.shape)

# 残差连接
output = output + layer_norm_output
# 层归一化
layer_norm = nn.LayerNorm(d_model).to(device)
output = layer_norm(output)
print(output.shape)

torch.Size([4, 16, 64])
torch.Size([4, 16, 64])


# Step7: Repeat step 4 to 6

# Step8: Output Probabilities

In [147]:
# 通过线性层，映射到词表，词表从0开始编码，长度为 max_token_value + 1
logist = nn.Linear(d_model, max_token_value + 1).to(device)(output) # [4, 16, 100070] [batch_size, context_length, max_token_value + 1]
print(logist.shape)
# 计算每一个词的概率分布
probabilities = torch.softmax(logist, dim=-1) 
print(pd.DataFrame(probabilities[0].detach().cpu().numpy()))

predicted_token = torch.argmax(probabilities[0], dim=-1)
print("predict token: ", predicted_token)
predicted_text = encoding.decode(predicted_token.cpu().tolist())
print("predict text: ", predicted_text)

torch.Size([4, 16, 100070])
      0         1         2         3         4         5         6       \
0   0.000003  0.000006  0.000010  0.000015  0.000018  0.000010  0.000007   
1   0.000005  0.000007  0.000010  0.000008  0.000010  0.000005  0.000006   
2   0.000013  0.000019  0.000005  0.000009  0.000007  0.000014  0.000005   
3   0.000011  0.000017  0.000012  0.000004  0.000005  0.000012  0.000008   
4   0.000013  0.000015  0.000018  0.000004  0.000010  0.000008  0.000011   
5   0.000008  0.000015  0.000012  0.000004  0.000010  0.000012  0.000004   
6   0.000012  0.000014  0.000010  0.000005  0.000003  0.000008  0.000003   
7   0.000020  0.000009  0.000016  0.000005  0.000007  0.000004  0.000008   
8   0.000020  0.000015  0.000009  0.000005  0.000006  0.000003  0.000009   
9   0.000008  0.000008  0.000010  0.000016  0.000009  0.000002  0.000004   
10  0.000006  0.000011  0.000010  0.000003  0.000007  0.000007  0.000008   
11  0.000006  0.000026  0.000008  0.000015  0.000016  0.0000