<a href="https://colab.research.google.com/github/yiwenwangANU/pytorch_review/blob/main/Extra_Shakespeare_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [62]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
import os
from pathlib import Path
import requests

In [63]:
if torch.cuda.is_available():
  device = torch.device("cuda")
  print("GPU ready!")
else:
  device = "cpu"
  print("GPU not available!")

GPU ready!


##Data download and explore

In [64]:
data_dir = Path.home() / 'data'
os.makedirs(data_dir, exist_ok=True)

data_dir = os.path.join(data_dir, "input.txt")
if os.path.exists(data_dir):
  print("Data ready.")
else:
  print(f"Downloading training data to {data_dir}")
  url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
  torch.hub.download_url_to_file(url, data_dir)

Data ready.


In [65]:
with open(data_dir, 'r', encoding='utf-8') as f:
    text = f.read()
    print(f"length of dataset in characters: {len(text)}")
    print('First 300 chars:')
    print(text[:300])

length of dataset in characters: 1115394
First 300 chars:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us


##Tokenization

In [66]:
vocab = sorted(list(set(text)))
vocab_size = len(vocab)
print(f'Number of unique chars: {vocab_size}')
print(vocab)

Number of unique chars: 65
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [67]:
charTotoken = {char:i for i,char in enumerate(vocab)}
tokenTochar = {i:char for i,char in enumerate(vocab)}

In [68]:
def encode(input):
  return [charTotoken[char] for char in input]

def decode(input):
  return ''.join([tokenTochar[token] for token in input])

In [69]:
print(encode('First Citizen:'))
print(decode(encode('First Citizen:')))
# Check other tokenizer like gpt-tokenizer

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10]
First Citizen:


In [70]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape)
print(data[:100])

torch.Size([1115394])
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


##Train test split and create training set

In [71]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [72]:
# create training/validation set by randomly sample sequence at block_size in training data
def get_sample(data_set, batch_size, block_size):
  idxes = torch.randint(len(data_set)-block_size, (batch_size,))
  xb = torch.stack([data_set[idx:idx+block_size] for idx in idxes])
  yb = torch.stack([data_set[idx+1: idx+block_size+1] for idx in idxes])
  return xb, yb

In [73]:
batch_size = 4
block_size = 8

xb, yb = get_sample(train_data, batch_size, block_size)
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[ 1, 50, 47, 43, 57, 58,  8,  0],
        [ 0, 19, 24, 27, 33, 15, 17, 31],
        [41, 46,  1, 41, 53, 52, 57, 59],
        [59, 52, 42, 56, 43, 42,  1, 57]])
targets:
torch.Size([4, 8])
tensor([[50, 47, 43, 57, 58,  8,  0, 32],
        [19, 24, 27, 33, 15, 17, 31, 32],
        [46,  1, 41, 53, 52, 57, 59, 51],
        [52, 42, 56, 43, 42,  1, 57, 46]])
----
when input is [1] the target: 50
when input is [1, 50] the target: 47
when input is [1, 50, 47] the target: 43
when input is [1, 50, 47, 43] the target: 57
when input is [1, 50, 47, 43, 57] the target: 58
when input is [1, 50, 47, 43, 57, 58] the target: 8
when input is [1, 50, 47, 43, 57, 58, 8] the target: 0
when input is [1, 50, 47, 43, 57, 58, 8, 0] the target: 32
when input is [0] the target: 19
when input is [0, 19] the target: 24
when input is [0, 19, 24] the target: 27
when input is [0, 19, 24, 27] the target: 33
when input is [0, 19, 24, 27, 33] the target: 15
when input is [0, 19, 24, 

In [74]:
# 10 unique embeddings and use 3 to represent each token
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

# Input tensor (indices of words or items)
input_indices = torch.tensor([1, 5, 7, 7])

# Get embeddings
output = embedding(input_indices)
print(output)

tensor([[ 1.1477, -0.6008, -1.5328],
        [-0.9608, -0.1654, -1.4084],
        [-1.3594, -0.5436, -1.1195],
        [-1.3594, -0.5436, -1.1195]], grad_fn=<EmbeddingBackward0>)


##Bigram Model

In [75]:
torch.manual_seed(42)

# A Bigram Model predicts the next token based on the embedding of previous token
# (i.e., it learns a probability distribution of token transitions).
class BigramLanguageModel(nn.Module):
  def __init__(self, vocab_size):
    super().__init__()
    # each token directly reads off the logits for the next token from a lookup table
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
  def forward(self, x):
    return self.token_embedding_table(x)

model_bigram = BigramLanguageModel(vocab_size)

In [76]:
# use the last token in x to predict next token
# and loop to generate text at seq_len length
torch.manual_seed(42)
def generate(x, seq_len): # x-> shape(Batch size, 1)
  for _ in range(seq_len):
    logits = model_bigram(x) # shape(Batch size, Sequence length, Vocabulary size)
    logits_last = logits[:,-1,:] # shape(Batch size, Vocabulary size)
    logits_last_prob = F.softmax(logits_last, dim=-1) # shape(Batch size, Vocabulary size)
    token_next = torch.multinomial(logits_last_prob, num_samples=1) # shape(B, 1)
    x = torch.cat((x, token_next), dim=1) # (B, T+1)
  return x

In [77]:
# generate use untrained model to generate text
new_text = generate(torch.zeros(1, 1, dtype=torch.long), seq_len=100)
print(new_text)
print(decode(new_text.squeeze().numpy()))

tensor([[ 0, 25, 60, 15, 19, 27, 61, 34, 35, 42, 40, 49, 49, 17, 51, 50, 12, 35,
          7, 11,  8, 50, 49, 28, 38,  7, 23, 54, 25, 17, 59, 28, 30, 30, 49, 17,
          4, 28, 43,  8, 63, 49,  7, 46, 15, 33, 59, 60, 49, 16, 62, 42, 24,  9,
         62, 62, 17, 25, 59, 28, 40, 38, 28, 19, 37, 12, 46, 31,  7, 52, 15, 40,
         39, 55, 59, 36, 46, 15,  9, 56, 16, 21,  7, 54, 64, 29, 14, 50,  5, 26,
         62, 42, 34, 28, 40, 47, 56, 39, 55, 35, 42]])

MvCGOwVWdbkkEml?W-;.lkPZ-KpMEuPRRkE&Pe.yk-hCUuvkDxdL3xxEMuPbZPGY?hS-nCbaquXhC3rDI-pzQBl'NxdVPbiraqWd


In [78]:
optimizer = torch.optim.Adam(params=model_bigram.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [79]:
epoches = 1000
for epoch in range(epoches):
  model_bigram.train()
  xb, yb = get_sample(train_data, batch_size, block_size) # get random sample in each epoch
  logits = model_bigram(xb)
  B, T, C = logits.shape # change the shape to match the input shape of loss_fn
  logits = logits.view(B*T, C)
  targets = yb.view(B*T)
  loss = loss_fn(logits, targets)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  if(epoch == 0):
    print(f'Logits shape: {logits.shape}')
  if(epoch % 100 == 0):
    print(f'epoch: {epoch}, train loss: {loss.item():.2f}')


Logits shape: torch.Size([32, 65])
epoch: 0, train loss: 4.92
epoch: 100, train loss: 4.71
epoch: 200, train loss: 4.48
epoch: 300, train loss: 4.51
epoch: 400, train loss: 4.54
epoch: 500, train loss: 4.62
epoch: 600, train loss: 4.43
epoch: 700, train loss: 4.26
epoch: 800, train loss: 4.13
epoch: 900, train loss: 4.05


In [80]:
# generate use trained model to generate text
print(decode(generate(torch.zeros(1, 1, dtype=torch.long), seq_len=500).squeeze().numpy()))


CE
Br?rczWpGY3fn WBJSGPgfff,oio
NLQn!vTuX.yvNa3F?kaq'hjGtCjQKocN
UpouPblMLLMptjVr,y.y xdVVzJgsdX.IW,gC$OVCODnCwcTVF&XhUKAt&xoIVF?flvwgdWKnN;3uoiaF$z
M?kI;h
DbuMG,H3LYNmrDxpWHHvAKOF-jU.ho;fBuOya-IS
ghOEb&ZQ,l;:mslpcNN
KpVEYRIIM,'hCRblAcWTo;niab&am&K3ZnkIMaqntjh-ARIaqu
ghjZTBRS$J?qBQbwyCjhRheuyJoGtsutJBy-j&,r,'k$PTuZm3tPbu mKH-IseadVJVrw'x f -d3si'g.cu,k-JMiv?R-jlqIeairoan
3VEmtWcEIccN
AFDHhesYzETkUGI;f .bYh,g:gSeIKptaxTBykJG.!gsz ffgpGLHa3xDkWlrwVm3C
Kh.H?Hvy!ogQahk!DmgiZ&X.SkeI-kooERqKpjvTWPS:Nz


##Single head self-attention

In [81]:
## lower triangle masking
## to remove the influence from tokens to previous ones
B, T, C = 4, 8, 32 # batch, context/sequence length, channel/embedding
x = torch.rand(B, T, C)

def apply_mask(wei):
  tri = torch.tril(torch.ones(T, T))
  wei = torch.masked_fill(wei, tri==0, float('-inf')) # for text generation
  wei = F.softmax(wei, dim=-1)
  return wei

wei = torch.zeros(T, T)
print(apply_mask(wei))

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [82]:
## key, query and affinity matrix
B, T, C = 4, 8, 32
x = torch.rand(B, T, C)

head_size = 16 # as embedding size
# match each token from embedding to head that answer who I am
key = nn.Linear(C, head_size, bias=False)
# match each token from embedding to head that ask what I want
query = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)

# affinity matrix indicates the relationship between tokens
affinity = q @ k.transpose(-2, -1) # (B, T, head) @ (B, head, T) -> (B, T, T)
print(apply_mask(affinity)) # after apply mask and softmax

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5855, 0.4145, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3671, 0.3057, 0.3272, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3067, 0.2027, 0.2134, 0.2772, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2399, 0.1743, 0.1903, 0.2152, 0.1803, 0.0000, 0.0000, 0.0000],
         [0.2356, 0.1099, 0.1793, 0.1609, 0.1562, 0.1582, 0.0000, 0.0000],
         [0.1727, 0.1160, 0.1506, 0.1440, 0.1019, 0.1473, 0.1674, 0.0000],
         [0.1329, 0.1086, 0.0991, 0.1404, 0.1164, 0.1108, 0.1354, 0.1564]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5109, 0.4891, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3849, 0.3327, 0.2825, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2558, 0.2811, 0.1516, 0.3114, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1803, 0.2208, 0.1489, 0.2042, 0.2458, 0.0000, 0.0000, 0.0000],
         [0.1801, 0.162

$$
\text{Attention}(Q, K, V) = \text{softmax} \left( \frac{Q K^T}{\sqrt{d_k}} \right) V
$$

In [83]:
k = torch.randn(B, T, head_size)
k.var()

tensor(0.9660)

In [84]:
q = torch.randn(B, T, head_size)
q.var()

tensor(0.9567)

In [85]:
affinity = q @ k.transpose(-2, -1) * head_size ** -0.5 # to keep the var
affinity.var()

tensor(0.8634)

In [86]:
# value matrix and normalization
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
# match each token from embedding to head that related what can it provide
value = nn.Linear(C, head_size, bias=False)

k = key(x)
q = query(x)
v = value(x)

affinity = q @ k.transpose(-2, -1) / torch.sqrt(torch.tensor(head_size))
affinity = apply_mask(affinity) # (B, T, T)
out = affinity @ v # (B, T, head_size)
out[0]

tensor([[-0.4008, -0.2824, -0.0504, -0.1648,  0.2199, -0.6905,  0.3993,  0.0089,
          0.0030,  0.3607, -0.5116, -0.2277,  0.1935, -0.4891,  0.3365,  0.0191],
        [-0.5572, -0.3694,  0.0655, -0.1710, -0.0216, -0.5794, -0.0694,  0.0759,
          0.1772,  0.1894,  0.0330, -0.3166, -0.1639, -0.6994,  0.6523, -0.3480],
        [-0.3748, -0.3747,  0.0972, -0.3189,  0.0290, -0.6332,  0.0132,  0.3073,
          0.0451,  0.0726,  0.0844, -0.3094, -0.1222, -0.4423,  0.7516, -0.3954],
        [-0.3757, -0.1354,  0.0678, -0.3544,  0.5488, -0.5346,  0.2777,  0.4462,
          0.4850,  0.0387,  0.4747,  0.0056, -0.1109,  0.2689,  0.5192, -0.0423],
        [-0.2842, -0.2239,  0.1320, -0.2791,  0.2471, -0.6057,  0.1309,  0.4110,
          0.2765, -0.0763,  0.4293, -0.1000, -0.2380, -0.0087,  0.6724, -0.3274],
        [-0.2859, -0.1018,  0.1047, -0.1025,  0.4274, -0.4564,  0.3123,  0.2511,
          0.3937, -0.0916,  0.3692, -0.0540, -0.0763, -0.0555,  0.3209, -0.1108],
        [-0.5626, -0.2

##Model with single-head attention

In [87]:
batch_size = 32 # B
block_size = 8 # T
embedding_dims = 32 # C
vocab_size = 65
epoches = 5000
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [88]:
class Head(nn.Module):
  # single head self attention
  def __init__(self, block_size, embedding_dims, head_size):
    super().__init__()
    self.query = nn.Linear(embedding_dims, head_size)
    self.key = nn.Linear(embedding_dims, head_size)
    self.value = nn.Linear(embedding_dims, head_size)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x):
    B, T, C = x.shape

    q = self.query(x)
    k = self.key(x)
    v = self.value(x)
    affinity = q @ k.transpose(-2, -1) / torch.sqrt(torch.tensor(head_size))

    affinity = torch.masked_fill(affinity, self.tril[:T, :T] == 0, float('-inf')) # for text generation
    affinity = F.softmax(affinity, dim=-1)

    out = affinity @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
    return out

In [89]:
class SingleHeadModel(nn.Module):
  def __init__(self, vocab_size, embedding_dims, block_size):
    super().__init__()
    self.token_embedding = nn.Embedding(vocab_size, embedding_dims)
    self.positon_embedding = nn.Embedding(block_size, embedding_dims)
    self.self_attention_head = Head(block_size, embedding_dims, head_size=embedding_dims)
    self.decoder = nn.Linear(embedding_dims, vocab_size)

  def forward(self, x):
    B, T = x.shape
    x = self.token_embedding(x) + self.positon_embedding(torch.arange(T, device=device)) # (B, T, C) + (T, C) -> (B, T, C)
    x = self.self_attention_head(x) # (B, T, C)
    x = self.decoder(x) # (B, T, vocab_size)
    return x

model_single_head = SingleHeadModel(vocab_size=vocab_size, embedding_dims=embedding_dims, block_size=block_size).to(device)

In [90]:
# use x(shape(B, T)) to predict next token,
# after go through the model still use last token to predict the next token,
# and loop to generate text at new_seq_len length
def generate(x, new_seq_len): # x-> shape(B, T)
  model_single_head.eval()
  x = x.to(device)
  for _ in range(new_seq_len):
    with torch.inference_mode():
      context = x[:, -block_size:] # if x exceeds block_size, only feed the last block_size tokens to the model
      logits = model_single_head(context) # (B, T, vocab_size)
      logits_last = logits[:,-1,:] # (B, vocab_size)
      logits_last_prob = F.softmax(logits_last, dim=-1) # (B, vocab_size)
      token_next = torch.multinomial(logits_last_prob, num_samples=1) # (B, 1)
      x = torch.cat((x, token_next), dim=1) # (B, T+1)
  return x

In [91]:
optimizer = torch.optim.Adam(params=model_single_head.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [94]:
for epoch in range(epoches):
  model_single_head.train()
  x_train, y_train = get_sample(train_data, batch_size, block_size) # get random sample in each epoch
  x_train, y_train = x_train.to(device), y_train.to(device)
  logits = model_single_head(x_train)
  B, T, C = logits.shape # change the shape to match the input shape of loss_fn
  logits = logits.view(B*T, C)
  targets = y_train.view(B*T)
  loss = loss_fn(logits, targets)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  model_single_head.eval()
  val_loss = 0.0
  with torch.inference_mode():
    test_loss = 0.0
    x_val, y_val = get_sample(val_data, batch_size, block_size) # get random sample in each epoch
    x_val, y_val = x_train.to(device), y_train.to(device)
    val_logits = model_single_head(x_val)
    B, T, C = val_logits.shape # change the shape to match the input shape of loss_fn
    val_logits = val_logits.view(B*T, C)
    val_targets = y_val.view(B*T)
    val_loss = loss_fn(val_logits, val_targets)
  if(epoch == 0):
    print(f'Logits shape: {logits.shape}')
  if(epoch % 1000 == 0):
    print(f'epoch: {epoch}, train loss: {loss.item():.4f}, val loss: {val_loss.item():.4f}')

Logits shape: torch.Size([256, 65])
epoch: 0, train loss: 2.3565, val loss: 2.3519
epoch: 1000, train loss: 2.5141, val loss: 2.5082
epoch: 2000, train loss: 2.6513, val loss: 2.6451
epoch: 3000, train loss: 2.2744, val loss: 2.2686
epoch: 4000, train loss: 2.3394, val loss: 2.3345


In [95]:
# generate use trained model to generate text
print(decode(generate(torch.zeros(1, 1, dtype=torch.long), new_seq_len=500).squeeze().cpu().numpy()))


IFO: as; nte ty hn my thur mod thand he, preckn,
OH:
D
TUESY:
Sow lapiwisers we hakergoeptend muprt, ody inde eve ditlat mang yofo thy
SThind.

Wheer url hat ght me machan fl Coust dieere--sbwurore cense wince me, woess thee?

Firomest;
Al me hming I hee lars ole!-
I ind hnerve she, f'd ban ndowr:
Lit ot t's lore ry, bd ph-e
peleme thildosos I hearung.

COLOREV ORER:
Tom, ban te ssefr I wos blodu'es cemnode bs as GmR:
Kef DI basy Vo on my.

IRWINIIONUnstiloru!

INIUCHake I eald barsit ut wntier.


##Multi-head self-attention model

In [118]:
batch_size = 32 # B
block_size = 8 # T
embedding_dims = 32 # C
vocab_size = 65
num_heads = 4
epoches = 5000
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [119]:
class Head(nn.Module): # in -> (B, T, C) out -> (B, T, head_size)
  # single head self attention
  def __init__(self, block_size, embedding_dims, head_size):
    super().__init__()
    self.query = nn.Linear(embedding_dims, head_size)
    self.key = nn.Linear(embedding_dims, head_size)
    self.value = nn.Linear(embedding_dims, head_size)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x):
    B, T, C = x.shape

    q = self.query(x)
    k = self.key(x)
    v = self.value(x)
    affinity = q @ k.transpose(-2, -1) / torch.sqrt(torch.tensor(head_size))

    affinity = torch.masked_fill(affinity, self.tril[:T, :T] == 0, float('-inf')) # for text generation
    affinity = F.softmax(affinity, dim=-1)

    out = affinity @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
    return out

In [120]:
class MultiHead(nn.Module): # in -> (B, T, C) out -> (B, T, num_heads * head_size)
  def __init__(self, num_heads, block_size, embedding_dims, head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(block_size, embedding_dims, head_size) for _ in range(num_heads)])
  def forward(self, x):
    return torch.cat([head(x) for head in self.heads], dim=-1)

In [121]:
class FeedForward(nn.Module): # in (B, T, C) out -> (B, T, C)
  def __init__(self, embedding_dims):
    super().__init__()
    self.layer_stack = nn.Sequential(
        nn.Linear(embedding_dims, embedding_dims), # add feedforward network
        nn.ReLU(),
        nn.Linear(embedding_dims, embedding_dims)
        )
  def forward(self, x):
    return self.layer_stack(x)

In [122]:
class AttentionBlock(nn.Module): # in (B, T, C) out -> (B, T, C)
  def __init__(self, num_heads, block_size, embedding_dims):
    super().__init__()
    head_size = embedding_dims // num_heads
    self.self_attention = MultiHead(num_heads, block_size, embedding_dims, head_size)
    self.feedforward = FeedForward(embedding_dims)
  def forward(self, x): # x shape (B, T, C)
    x = self.self_attention(x)
    x = self.feedforward(x)
    return x

In [123]:
class MultiHeadModel(nn.Module):
  def __init__(self, num_heads, vocab_size, embedding_dims, block_size):
    super().__init__()
    self.token_embedding = nn.Embedding(vocab_size, embedding_dims)
    self.positon_embedding = nn.Embedding(block_size, embedding_dims)
    self.attention_blocks = nn.Sequential(
        AttentionBlock(num_heads, block_size, embedding_dims),
        AttentionBlock(num_heads, block_size, embedding_dims),
        AttentionBlock(num_heads, block_size, embedding_dims),
    )
    self.linear = nn.Linear(embedding_dims, vocab_size)

  def forward(self, x):
    B, T = x.shape
    x = self.token_embedding(x) + self.positon_embedding(torch.arange(T, device=device)) # (B, T, C) + (T, C) -> (B, T, C)
    x = self.attention_blocks(x) # (B, T, C)
    x = self.linear(x) # (B, T, vocab_size)
    return x

model_multi_head = MultiHeadModel(num_heads=num_heads,
                                  vocab_size=vocab_size,
                                  embedding_dims=embedding_dims,
                                  block_size=block_size).to(device)

In [124]:
optimizer = torch.optim.Adam(params=model_multi_head.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [125]:
for epoch in range(epoches):
  model_multi_head.train()
  x_train, y_train = get_sample(train_data, batch_size, block_size) # get random sample in each epoch
  x_train, y_train = x_train.to(device), y_train.to(device)
  logits = model_multi_head(x_train)
  B, T, C = logits.shape # change the shape to match the input shape of loss_fn
  logits = logits.view(B*T, C)
  targets = y_train.view(B*T)
  loss = loss_fn(logits, targets)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  model_multi_head.eval()
  val_loss = 0.0
  with torch.inference_mode():
    test_loss = 0.0
    x_val, y_val = get_sample(val_data, batch_size, block_size) # get random sample in each epoch
    x_val, y_val = x_train.to(device), y_train.to(device)
    val_logits = model_multi_head(x_val)
    B, T, C = val_logits.shape # change the shape to match the input shape of loss_fn
    val_logits = val_logits.view(B*T, C)
    val_targets = y_val.view(B*T)
    val_loss = loss_fn(val_logits, val_targets)
  if(epoch == 0):
    print(f'Logits shape: {logits.shape}')
  if(epoch % 1000 == 0):
    print(f'epoch: {epoch}, train loss: {loss.item():.4f}, val loss: {val_loss.item():.4f}')

Logits shape: torch.Size([256, 65])
epoch: 0, train loss: 4.1529, val loss: 4.1456
epoch: 1000, train loss: 2.7593, val loss: 2.7370
epoch: 2000, train loss: 2.5666, val loss: 2.5541
epoch: 3000, train loss: 2.4018, val loss: 2.3912
epoch: 4000, train loss: 2.3907, val loss: 2.3790


In [126]:
def generate(x, new_seq_len): # x-> shape(B, T)
  model_multi_head.eval()
  x = x.to(device)
  for _ in range(new_seq_len):
    with torch.inference_mode():
      context = x[:, -block_size:] # if x exceeds block_size, only feed the last block_size tokens to the model
      logits = model_multi_head(context) # (B, T, vocab_size)
      logits_last = logits[:,-1,:] # (B, vocab_size)
      logits_last_prob = F.softmax(logits_last, dim=-1) # (B, vocab_size)
      token_next = torch.multinomial(logits_last_prob, num_samples=1) # (B, 1)
      x = torch.cat((x, token_next), dim=1) # (B, T+1)
  return x

In [127]:
# generate use trained model to generate text
print(decode(generate(torch.zeros(1, 1, dtype=torch.long), new_seq_len=500).squeeze().cpu().numpy()))


Wour ghok earnf therer enl: by dot it
And shubancllelice.

KFOORVYD:
To hod had mnerpveme hearver thy fipu blaonge sbomm liten nois butaxnse rear'n; thate noser buslece thilgood liek leroparebiredsy Vrowting erbe none; py ckise wirty: ther ler hatuasst gakes theardjrerpangd you'sm,'k treabuchoe, meid: ther, wou! the boveray, as.

I thother; cherst woursarmnsesemong mee'd tnesk you moud bearourtmordfr the trrer thur jomem siir frededfecoder? shaciles confamidely domernn, thel, Tow, minealtondaed,


Make model deeper and applying tricks that to make sure network in optimizable
1.  add residual connection
2.  layer norm
3.  dropout layer

In [185]:
batch_size = 64 # B
block_size = 256 # T
embedding_dims = 384 # C
vocab_size = 65
num_heads = 6
n_layers = 6
dropout = 0.2

epoches = 5000
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [186]:
class Head(nn.Module): # in -> (B, T, C) out -> (B, T, head_size)
  # single head self attention
  def __init__(self, block_size, embedding_dims, head_size):
    super().__init__()
    self.query = nn.Linear(embedding_dims, head_size)
    self.key = nn.Linear(embedding_dims, head_size)
    self.value = nn.Linear(embedding_dims, head_size)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B, T, C = x.shape

    q = self.query(x)
    k = self.key(x)
    v = self.value(x)
    affinity = q @ k.transpose(-2, -1) / torch.sqrt(torch.tensor(head_size))

    affinity = torch.masked_fill(affinity, self.tril[:T, :T] == 0, float('-inf')) # for text generation
    affinity = F.softmax(affinity, dim=-1)
    affinity = self.dropout(affinity) # dropout the affinity matrix
    out = affinity @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
    return out

In [187]:
class MultiHead(nn.Module): # in -> (B, T, C) out -> (B, T, num_heads * head_size)
  def __init__(self, num_heads, block_size, embedding_dims, head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(block_size, embedding_dims, head_size) for _ in range(num_heads)])
    self.proj = nn.Linear(embedding_dims, embedding_dims) # residual connection
    self.dropout = nn.Dropout(dropout)
  def forward(self, x):
    x = torch.cat([head(x) for head in self.heads], dim=-1)
    x = self.proj(x)
    x = self.dropout(x)
    return x

In [188]:
class FeedForward(nn.Module): # in (B, T, C) out -> (B, T, C)
  def __init__(self, embedding_dims):
    super().__init__()
    self.layer_stack = nn.Sequential(
        nn.Linear(embedding_dims, 4 * embedding_dims), # add feedforward network
        nn.ReLU(),
        nn.Linear(4 * embedding_dims, embedding_dims), # projection layer that go back to residual pathway
        nn.Dropout(dropout) # add dropout after residual path way
        )
  def forward(self, x):
    return self.layer_stack(x)

In [189]:
class AttentionBlock(nn.Module): # in (B, T, C) out -> (B, T, C)
  def __init__(self, num_heads, block_size, embedding_dims):
    super().__init__()
    head_size = embedding_dims // num_heads
    self.self_attention = MultiHead(num_heads, block_size, embedding_dims, head_size)
    self.feedforward = FeedForward(embedding_dims)
    self.layernorm1 = nn.LayerNorm(embedding_dims)
    self.layernorm2 = nn.LayerNorm(embedding_dims)
  def forward(self, x): # x shape (B, T, C)
    x = x + self.self_attention(self.layernorm1(x)) # add residual connection and layer norm
    x = x + self.feedforward(self.layernorm2(x)) # add residual connection
    return x

In [190]:
class MultiHeadModel(nn.Module):
  def __init__(self, num_heads, vocab_size, embedding_dims, block_size):
    super().__init__()
    self.token_embedding = nn.Embedding(vocab_size, embedding_dims)
    self.positon_embedding = nn.Embedding(block_size, embedding_dims)
    self.attention_blocks = nn.Sequential(
      *[AttentionBlock(num_heads, block_size, embedding_dims) for _ in range(n_layers)]
      )
    self.layernorm = nn.LayerNorm(embedding_dims) # add layer norm
    self.linear = nn.Linear(embedding_dims, vocab_size)

  def forward(self, x):
    B, T = x.shape
    x = self.token_embedding(x) + self.positon_embedding(torch.arange(T, device=device)) # (B, T, C) + (T, C) -> (B, T, C)
    x = self.attention_blocks(x)
    x = self.layernorm(x)
    x = self.linear(x) # (B, T, vocab_size)
    return x

model_multi_head = MultiHeadModel(num_heads=num_heads,
                                  vocab_size=vocab_size,
                                  embedding_dims=embedding_dims,
                                  block_size=block_size).to(device)

In [191]:
optimizer = torch.optim.Adam(params=model_multi_head.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [192]:
for epoch in range(epoches):
  model_multi_head.train()
  x_train, y_train = get_sample(train_data, batch_size, block_size) # get random sample in each epoch
  x_train, y_train = x_train.to(device), y_train.to(device)
  logits = model_multi_head(x_train)
  B, T, C = logits.shape # change the shape to match the input shape of loss_fn
  logits = logits.view(B*T, C)
  targets = y_train.view(B*T)
  loss = loss_fn(logits, targets)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  model_multi_head.eval()
  val_loss = 0.0
  with torch.inference_mode():
    test_loss = 0.0
    x_val, y_val = get_sample(val_data, batch_size, block_size) # get random sample in each epoch
    x_val, y_val = x_train.to(device), y_train.to(device)
    val_logits = model_multi_head(x_val)
    B, T, C = val_logits.shape # change the shape to match the input shape of loss_fn
    val_logits = val_logits.view(B*T, C)
    val_targets = y_val.view(B*T)
    val_loss = loss_fn(val_logits, val_targets)
  if(epoch == 0):
    print(f'Logits shape: {logits.shape}')
  if(epoch % 1000 == 0):
    print(f'epoch: {epoch}, train loss: {loss.item():.4f}, val loss: {val_loss.item():.4f}')

Logits shape: torch.Size([16384, 65])
epoch: 0, train loss: 4.3724, val loss: 3.9283
epoch: 1000, train loss: 1.4644, val loss: 1.3727
epoch: 2000, train loss: 1.2678, val loss: 1.1781
epoch: 3000, train loss: 1.1298, val loss: 1.0222
epoch: 4000, train loss: 1.0539, val loss: 0.9323


In [193]:
def generate(x, new_seq_len): # x-> shape(B, T)
  model_multi_head.eval()
  x = x.to(device)
  for _ in range(new_seq_len):
    with torch.inference_mode():
      context = x[:, -block_size:] # if x exceeds block_size, only feed the last block_size tokens to the model
      logits = model_multi_head(context) # (B, T, vocab_size)
      logits_last = logits[:,-1,:] # (B, vocab_size)
      logits_last_prob = F.softmax(logits_last, dim=-1) # (B, vocab_size)
      token_next = torch.multinomial(logits_last_prob, num_samples=1) # (B, 1)
      x = torch.cat((x, token_next), dim=1) # (B, T+1)
  return x

In [194]:
# generate use trained model to generate text
print(decode(generate(torch.zeros(1, 1, dtype=torch.long), new_seq_len=2000).squeeze().cpu().numpy()))


Or seem'st a bodk indeed great proper bench
And that's before his limbs.

COMINIUS:
Do you think oft mine: thus but is an end
The chamber of birth countenance guilty of
The spirits and language of him what we heard the brother
Laid's prenzie.

MARCIUS:
There is a present death to our mortal lady,
Our soldier vice to instruct me of our
business to bite in the spirit of custom, the state
And strength o' the melancholy brains
All reasons of near, the nice o' the time;
The earth of the royalties and kindly be here,
Ten thyself and all the sovereigns were still shed
He was the prevail of his livery.

LEONTES:
Nothing but he won that; only of my master
Appear it, he is sent in my sex tongue!

PAULINA:
Is true? Is this truth?

LEONTES:
Ay, as the imagine of him? My cause.

LEONTES:
No, I promised well for that.

LEONTES:
My reasons for thee:
For a friar good Lord ever made her;
Behold, hold, am I, that, unless
Dead, where that he profits like him beaten an
The state and do choke of children 