In [2]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import requests



In [3]:
DATA_PATH = "sales_textbook.txt"
DATA_SOURCE_URL = "https://huggingface.co/datasets/goendalf666/sales-textbook_for_convincing_and_selling/raw/main/sales_textbook.txt"

if not os.path.exists(DATA_PATH):
    r = requests.get(DATA_SOURCE_URL)
    with open(DATA_PATH, "w") as f:
        f.write(r.text)

with open(DATA_PATH, "r") as f:
    text = f.read()
    

In [4]:
import tiktoken
encoding = tiktoken.get_encoding("cl100k_base")
tokenized_text = encoding.encode(text)

max_token_value = max(tokenized_text) + 1

tokenized_text = torch.tensor(tokenized_text)


In [5]:
# hyperparameters
context_length = 16
d_model = 64
n_heads = 4
batch_size = 4

In [6]:
# split the data into training and validation sets
train_size = int(0.9 * len(tokenized_text))
val_size = len(tokenized_text) - train_size

print(f"Training set size: {train_size}")
print(f"Validate set size: {val_size}")

train_data = tokenized_text[:train_size]
val_data = tokenized_text[train_size:]

data = train_data

idxs = torch.randint(low=0, high=len(data) - context_length, size=(n_heads,))

x_batch = torch.stack([data[idx:idx + context_length] for idx in idxs])
y_batch = torch.stack([data[idx + 1:idx + context_length + 1] for idx in idxs])

print(x_batch.shape, y_batch.shape)

Training set size: 70127
Validate set size: 7792
torch.Size([4, 16]) torch.Size([4, 16])


In [7]:
# prepare for token embedding
token_embedding_lookup_table = nn.Embedding(max_token_value, d_model)
x = token_embedding_lookup_table(x_batch.data)
y = token_embedding_lookup_table(y_batch.data)


In [8]:
# prepare for position encoding
positional_encoding_lookup_table = torch.zeros(context_length, d_model)
position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)

div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))

positional_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
positional_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)

positional_encoding_lookup_table = positional_encoding_lookup_table.unsqueeze(0).expand(batch_size, -1, -1)

input_embedding_x = x + positional_encoding_lookup_table
input_embedding_y = y + positional_encoding_lookup_table

x_plot = input_embedding_x[0].detach().numpy()

In [9]:
# prepare for Q, K, V
query = key = value = input_embedding_x

Wq = nn.Linear(d_model, d_model)
Wk = nn.Linear(d_model, d_model)
Wv = nn.Linear(d_model, d_model)

Q = Wq(query)
Q = Q.reshape(batch_size, -1, n_heads, d_model // n_heads).permute(0, 2, 1, 3)

K = Wk(key)
K = K.reshape(batch_size, -1, n_heads, d_model // n_heads).permute(0, 2, 1, 3)

V = Wv(value)
V = V.reshape(batch_size, -1, n_heads, d_model // n_heads).permute(0, 2, 1, 3)

output = Q @ K.transpose(-2, -1) / math.sqrt(d_model // n_heads)

# apply mask
mask = torch.triu(torch.ones(context_length, context_length), diagonal=-1).bool()
output = output.masked_fill(mask, float('-inf'))

# apply softmax
attention_score = F.softmax(output, dim=-1)

# apply attention
A = attention_score @ V

# apply concatenate
A = A.permute(0, 2, 1, 3).reshape(batch_size, context_length, d_model)
Wo = nn.Linear(d_model, d_model)

output = Wo(A)
output.shape

torch.Size([4, 16, 64])

In [10]:
# apply residual connection
output = output + x
print(output.shape)
print(x.shape)

torch.Size([4, 16, 64])
torch.Size([4, 16, 64])


In [11]:
# apply layer norm
layer_norm1 = nn.LayerNorm(d_model)
layer_norm_output = layer_norm1(output)


In [12]:
# apply feed forward network
output = nn.Linear(d_model, d_model * 4)(layer_norm_output)
output = nn.ReLU()(output)
output = nn.Linear(d_model * 4, d_model)(output)

output = output + layer_norm_output


In [13]:
# apply layer norm again
layer_norm2 = nn.LayerNorm(d_model)
output = layer_norm2(output)

In [14]:
# apply final linear layer
output = nn.Linear(d_model, max_token_value)(output)
output.shape

torch.Size([4, 16, 100070])

In [15]:
logits = F.softmax(output, dim=-1)
predicted_index = torch.argmax(logits[0,0]).item()

encoding.decode([0])

'!'