In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNGs (all devices) so weight initialisation, dropout, batch
# sampling and text generation repeat across Restart & Run All. Bit-exact GPU
# results may additionally require deterministic kernels.
torch.manual_seed(1337)

In [2]:
from pathlib import Path

# Decode explicitly as UTF-8 instead of relying on the platform default encoding.
text = Path('../data/tiny-shakespeare.txt').read_text(encoding='utf-8')

In [3]:
# Peek at the first 1000 characters; print() renders the newlines.
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [4]:

class CharTokenizer:
  """Character-level tokenizer: every distinct character is one token.

  Token ids are the positions of the characters in the ``vocabulary``
  sequence given to the constructor.
  """

  def __init__(self, vocabulary):
    # vocabulary: ordered sequence of characters; the character at index i gets token id i.
    self.token_id_for_char = {char: token_id for token_id, char in enumerate(vocabulary)}
    self.char_for_token_id = {token_id: char for token_id, char in enumerate(vocabulary)}

  @classmethod
  def train_from_text(cls, text):
    """Build a tokenizer whose vocabulary is the sorted set of characters in ``text``."""
    return cls(sorted(set(text)))

  def encode(self, text):
    """Map each character of ``text`` to its token id.

    Returns a 1-D ``torch.long`` tensor (empty for an empty string).
    Raises ``KeyError`` for a character that is not in the vocabulary.
    """
    return torch.tensor([self.token_id_for_char[char] for char in text], dtype=torch.long)

  def decode(self, token_ids):
    """Map token ids back to a string.

    Accepts a 1-D tensor (as returned by ``encode``) or any iterable of ints.
    Raises ``KeyError`` for an id outside the vocabulary.
    """
    if isinstance(token_ids, torch.Tensor):
      token_ids = token_ids.tolist()
    return ''.join(self.char_for_token_id[token_id] for token_id in token_ids)

  def vocabulary_size(self):
    """Number of distinct tokens."""
    return len(self.token_id_for_char)

In [5]:
# Build the vocabulary from every distinct character in the corpus.
tokenizer = CharTokenizer.train_from_text(text)

In [6]:
print(tokenizer.encode("Hello world"))
print(tokenizer.decode(tokenizer.encode("Hello world")))
# Round trip: decoding the encoded ids must give back the original text.
assert tokenizer.decode(tokenizer.encode("Hello world")) == "Hello world"

tensor([20, 43, 50, 50, 53,  1, 61, 53, 56, 50, 42])
Hello world


In [7]:
# Number of distinct characters in the corpus; this becomes config["vocabulary_size"], the model's output width.
print(f"Vocabulary size: {tokenizer.vocabulary_size()}")

Vocabulary size: 65


In [8]:
from torch.utils.data import Dataset

class TokenIdsDataset(Dataset):
  """Sliding-window next-token dataset over a 1-D sequence of token ids.

  Item ``pos`` is a pair ``(x, y)`` with ``x = data[pos:pos + block_size]`` and
  ``y`` the same window shifted one token to the right (the next-token targets).
  """

  def __init__(self, data, block_size):
    self.data = data
    self.block_size = block_size

  def __len__(self):
    # Each window needs block_size + 1 tokens (inputs plus the shifted target).
    # Clamp at 0 so a too-short sequence gives an empty dataset instead of a
    # negative length, which len() rejects.
    return max(0, len(self.data) - self.block_size)

  def __getitem__(self, pos):
    length = len(self)
    # Support Python-style negative indices.
    if pos < 0:
      pos += length
    # Raise IndexError (an assert would be stripped under -O) so out-of-range
    # access fails loudly and the sequence protocol terminates correctly.
    if not 0 <= pos < length:
      raise IndexError(f"index {pos} out of range for dataset of length {length}")

    x = self.data[pos:pos + self.block_size]
    y = self.data[pos + 1:pos + 1 + self.block_size]
    return x, y

In [9]:
config = {
  "vocabulary_size": tokenizer.vocabulary_size(),
  "context_size": 256,      # longest sequence: size of the causal mask and the positional-embedding table
  "embedding_dim": 768,
  "heads_num": 12,
  "layers_num": 10,         # number of transformer Blocks
  "dropout_rate": 0.1,
  "use_bias": False,        # bias for the Q/K/V projections in AttentionHead
}

# Each head gets an equal slice of the embedding and the heads' outputs are
# concatenated back together, so the split must be exact.
assert config["embedding_dim"] % config["heads_num"] == 0, "embedding_dim must be divisible by heads_num"
config["head_size"] = config["embedding_dim"] // config["heads_num"]

In [10]:
class AttentionHead(nn.Module):
  """A single causal (masked) self-attention head.

  Projects the input to queries, keys and values of size ``head_size``.
  Each position may attend only to itself and earlier positions.
  """

  def __init__(self, config):
    super().__init__()
    self.Q_weights = nn.Linear(config["embedding_dim"], config["head_size"], config["use_bias"])
    self.K_weights = nn.Linear(config["embedding_dim"], config["head_size"], config["use_bias"])
    self.V_weights = nn.Linear(config["embedding_dim"], config["head_size"], config["use_bias"])

    self.dropout = nn.Dropout(config["dropout_rate"])

    # Lower-triangular ones: entry (i, j) is 1 iff token i may attend to token j (j <= i).
    # NOTE: "casual" is a typo for "causal", but the buffer name is part of the
    # state_dict keys, so it is kept for checkpoint compatibility.
    casual_attention_mask = torch.tril(torch.ones(config["context_size"], config["context_size"]))
    self.register_buffer('casual_attention_mask', casual_attention_mask)

  def forward(self, input):
    """Apply causal attention.

    ``input`` has shape ``(..., tokens_num, embedding_dim)`` with
    ``tokens_num <= context_size``. The result has shape
    ``(..., tokens_num, head_size)``.
    """
    tokens_num = input.shape[-2]
    Q = self.Q_weights(input)
    K = self.K_weights(input)
    V = self.V_weights(input)

    # Transpose the last two dims (not dims 1/2) so any number of leading batch dims works.
    attention_scores = Q @ K.transpose(-2, -1)
    # Block attention to future positions before the softmax.
    attention_scores = attention_scores.masked_fill(
        self.casual_attention_mask[:tokens_num, :tokens_num] == 0,
        -torch.inf
    )
    attention_scores = attention_scores / (K.shape[-1] ** 0.5)
    attention_scores = torch.softmax(attention_scores, dim=-1)
    attention_scores = self.dropout(attention_scores)

    return attention_scores @ V

In [11]:
# Random dummy batch of shape (8, context_size, embedding_dim) for shape checks.
# NOTE: this rebinds Python's built-in `input` for the rest of the notebook.
input = torch.rand(8, config["context_size"], config["embedding_dim"])

In [12]:
# A single attention head, instantiated to smoke-test its forward pass.
ah = AttentionHead(config)

In [13]:
# Run the head on the dummy batch.
output = ah(input)

In [14]:
# Each head maps embedding_dim -> head_size for every token.
assert output.shape == (input.shape[0], input.shape[1], config["head_size"])
output.shape

torch.Size([8, 256, 64])

In [15]:
class MultiHeadAttention(nn.Module):
  """Several causal attention heads run in parallel.

  Their outputs are concatenated along the last dimension, projected back to
  ``embedding_dim`` and passed through dropout.
  """

  def __init__(self, config):
    super().__init__()

    heads_list = [AttentionHead(config) for _ in range(config["heads_num"])]
    self.heads = nn.ModuleList(heads_list)

    # The concatenated head outputs are heads_num * head_size wide. Sizing the
    # projection from that, rather than assuming it equals embedding_dim, keeps
    # the module valid even if embedding_dim is not divisible by heads_num.
    # When it is divisible, the parameter shapes are unchanged.
    concat_dim = config["heads_num"] * config["head_size"]
    self.linear = nn.Linear(concat_dim, config["embedding_dim"])
    self.dropout = nn.Dropout(config["dropout_rate"])

  def forward(self, input):
    """Map ``(batch, tokens_num, embedding_dim)`` to a tensor of the same shape."""
    heads_outputs = [head(input) for head in self.heads]
    scores_change = torch.cat(heads_outputs, dim=-1)
    scores_change = self.linear(scores_change)
    return self.dropout(scores_change)

In [16]:
# Multi-head attention instance for a shape check.
mha = MultiHeadAttention(config)

In [17]:
# Fresh random dummy batch of shape (8, context_size, embedding_dim); rebinds the built-in `input`.
input = torch.rand(8, config["context_size"], config["embedding_dim"])

In [18]:
# Run multi-head attention on the dummy batch.
output = mha(input)

In [19]:
# Multi-head attention preserves the input shape (batch, tokens, embedding_dim).
assert output.shape == input.shape
output.shape

torch.Size([8, 256, 768])

In [20]:
class FeedForward(nn.Module):
  """Position-wise MLP: expand to 4x embedding_dim, apply GELU, project back, then dropout."""

  def __init__(self, config):
    super().__init__()
    embedding_dim = config["embedding_dim"]
    hidden_dim = 4 * embedding_dim

    # Kept as an nn.Sequential in this exact order: the submodule indices
    # (linear_layers.0, linear_layers.2) form the state_dict keys.
    self.linear_layers = nn.Sequential(
        nn.Linear(embedding_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, embedding_dim),
        nn.Dropout(config["dropout_rate"]),
    )

  def forward(self, input):
    """Apply the MLP to every position independently; the last dimension is preserved."""
    return self.linear_layers(input)

In [21]:
# Feed-forward instance for a shape check.
ff = FeedForward(config)

In [22]:
# Fresh random dummy batch of shape (8, context_size, embedding_dim); rebinds the built-in `input`.
input = torch.rand(8, config["context_size"], config["embedding_dim"])

In [23]:
# Run the feed-forward layer; bind to `output` (previously misspelled `ouptut`,
# so the next cell showed the stale multi-head-attention result).
output = ff(input)

In [24]:
# The feed-forward layer preserves the input shape (batch, tokens, embedding_dim).
assert output.shape == input.shape
output.shape

torch.Size([8, 256, 768])

In [25]:
class Block(nn.Module):
  """One pre-LayerNorm transformer block.

  Attention, then feed-forward, each applied to a layer-normalised input and
  added back through a residual connection.
  """

  def __init__(self, config):
    super().__init__()

    # Creation order decides how parameter initialisation consumes the RNG;
    # keep it stable for reproducible weights.
    self.multi_head = MultiHeadAttention(config)
    self.layer_norm_1 = nn.LayerNorm(config["embedding_dim"])

    self.feed_forward = FeedForward(config)
    self.layer_norm_2 = nn.LayerNorm(config["embedding_dim"])

  def forward(self, input):
    """Compute x -> x + attention(LN1(x)) -> that + feed_forward(LN2(that))."""
    attended = self.multi_head(self.layer_norm_1(input)) + input
    return self.feed_forward(self.layer_norm_2(attended)) + attended

In [26]:
# One transformer block instance for a shape check.
b = Block(config)

In [27]:
# Run the block; bind to `output` (previously misspelled `ouptut`, so the
# next cell showed a stale result instead of the block's output).
output = b(input)

In [28]:
# A transformer block preserves the input shape (batch, tokens, embedding_dim).
assert output.shape == input.shape
output.shape

torch.Size([8, 256, 768])

In [29]:
class DemoGPT(nn.Module):
  """Decoder-only character-level GPT.

  Token embeddings plus learned positional embeddings feed a stack of
  transformer Blocks, then a final LayerNorm. A bias-free linear unembedding
  produces per-position logits over the vocabulary.
  """

  def __init__(self, config):
    super().__init__()

    self.token_embedding_layer = nn.Embedding(config["vocabulary_size"], config["embedding_dim"])
    self.positional_embedding_layer = nn.Embedding(config["context_size"], config["embedding_dim"])

    blocks = [Block(config) for _ in range(config["layers_num"])]
    self.layers = nn.Sequential(*blocks)

    self.layer_norm = nn.LayerNorm(config["embedding_dim"])
    self.unembedding = nn.Linear(config["embedding_dim"], config["vocabulary_size"], bias=False)

  def forward(self, token_ids):
    """Compute next-token logits.

    Args:
      token_ids: integer tensor of shape (batch, tokens_num), tokens_num <= context_size.

    Returns:
      Tensor of shape (batch, tokens_num, vocabulary_size).

    Raises:
      ValueError: if tokens_num exceeds the positional-embedding table size.
    """
    _batch_size, tokens_num = token_ids.shape
    context_size = self.positional_embedding_layer.num_embeddings
    # Fail clearly up front: an out-of-range embedding index is an opaque
    # IndexError on CPU and a device-side assert on CUDA.
    if tokens_num > context_size:
      raise ValueError(f"sequence length {tokens_num} exceeds context size {context_size}")

    x = self.token_embedding_layer(token_ids)
    # Build position ids on the input's device (not the notebook-global
    # `device`) so the model works wherever it and its inputs live.
    positions = torch.arange(tokens_num, device=token_ids.device)
    x = x + self.positional_embedding_layer(positions)

    x = self.layers(x)
    x = self.layer_norm(x)
    return self.unembedding(x)

In [30]:
# Build the full model and move its parameters and buffers to the selected device.
model = DemoGPT(config).to(device)

In [31]:
# Forward a single 2-token prompt ("Hi"); unsqueeze(dim=0) adds the batch dimension, giving ids of shape (1, 2).
output = model(tokenizer.encode("Hi").unsqueeze(dim=0).to(device))

In [32]:
# One row of logits over the vocabulary for each of the 2 prompt tokens.
assert output.shape == (1, 2, config["vocabulary_size"])
output.shape

torch.Size([1, 2, 65])

In [33]:
def generate(model, prompt_ids, max_tokens, context_size=None):
  """Autoregressively sample ``max_tokens`` new tokens.

  Args:
    model: callable mapping a (batch, tokens) id tensor to logits of shape
      (batch, tokens, vocabulary_size).
    prompt_ids: (batch, tokens) tensor of token ids to continue.
    max_tokens: number of tokens to append.
    context_size: longest sequence the model accepts; defaults to
      ``config["context_size"]``. Only the last ``context_size`` tokens are
      fed to the model, so generation continues past the context window
      instead of stopping early.

  Returns:
    (batch, prompt_len + max_tokens) tensor: the prompt followed by the sampled tokens.
  """
  if context_size is None:
    context_size = config["context_size"]

  output_ids = prompt_ids
  with torch.no_grad():
    for _ in range(max_tokens):
      # Sliding window: keep only the most recent tokens the model can handle.
      logits = model(output_ids[:, -context_size:])
      # Only the prediction at the last position determines the next token.
      logits = logits[:, -1, :]
      probs = F.softmax(logits, dim=-1)
      # Sample a random token given the softmax distribution
      next_token_id = torch.multinomial(probs, num_samples=1)
      # Add new token to the output, and repeat the process
      output_ids = torch.cat([output_ids, next_token_id], dim=-1)
  return output_ids

In [34]:
def generate_with_prompt(model, tokenizer, prompt, max_tokens=100):
  """Encode ``prompt``, sample ``max_tokens`` continuation tokens and return the decoded text.

  The returned string starts with the prompt itself. The model is put in eval
  mode (dropout off) while sampling. Its previous train/eval mode is restored
  afterwards, so calling this does not silently change the model's state.
  """
  was_training = model.training
  model.eval()
  try:
    # Add a batch dimension: (tokens,) -> (1, tokens).
    prompt_ids = tokenizer.encode(prompt).unsqueeze(dim=0).to(device)
    return tokenizer.decode(generate(model, prompt_ids, max_tokens=max_tokens)[0])
  finally:
    model.train(was_training)

In [35]:
# Sample from the untrained model; expect near-random characters at this point.
generate_with_prompt(model, tokenizer, "First Citizen:\n")

'First Citizen:\nTA\n!Q.-OQBkz?QuY\nrZjsE bSQj&zPdmVprb3Kr!h?$L-Nk pD3fXz  WbFVsMYFNnXIUgXGXCjoqZLA.VWV;RFTZ.wAvNnD-!VR'

In [36]:
# Training hyperparameters
batch_size = 64             # sequences per optimisation step
train_iterations = 5_000    # number of optimisation steps (batches drawn by the sampler)
evaluation_interval = 100   # print a generated sample every N steps
learning_rate = 4e-4        # AdamW learning rate

In [37]:
# Encode the whole corpus once and keep it on `device`; the dataset slices its windows directly from this tensor.
train_data = tokenizer.encode(text).to(device)

In [38]:
# Every start offset in the corpus yields one (input, target) window of length context_size.
train_dataset = TokenIdsDataset(train_data, config["context_size"])

In [39]:
from torch.utils.data import Dataset, DataLoader, RandomSampler

# Draw window start positions uniformly with replacement. num_samples is
# batch_size * train_iterations, so the loader yields exactly train_iterations full batches.
train_sampler = RandomSampler(train_dataset, num_samples=batch_size * train_iterations, replacement=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)

In [40]:
# AdamW over all model parameters, using the learning_rate hyperparameter.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
for step_num, sample in enumerate(train_dataloader):

  # Make sure dropout is active: generate_with_prompt() may leave the model in eval mode.
  model.train()
  input, targets = sample
  logits = model(input)

  # Flatten (batch, tokens, vocab) -> (batch * tokens, vocab) and the targets
  # to (batch * tokens,), as F.cross_entropy expects. Sizes come from the
  # tensors themselves rather than batch_size/context_size, so a batch of a
  # different size or length would not break the reshape.
  logits_view = logits.reshape(-1, logits.shape[-1])
  targets_view = targets.reshape(-1)

  loss = F.cross_entropy(logits_view, targets_view)
  # Backward propagation
  loss.backward()
  # Update model parameters
  optimizer.step()
  # Set to None to reduce memory usage
  optimizer.zero_grad(set_to_none=True)

  print(f"Step {step_num}. Loss {loss.item():.3f}")

  if step_num % evaluation_interval == 0:
    print("Demo GPT:\n" + generate_with_prompt(model, tokenizer, "\n"))

Step 0. Loss 4.389
Demo GPT:

o o   o   a  t o  t t o
  r X    o  w    s   e t o n       r  t      p    e s s   e t    t t   o    
Step 1. Loss 4.356
Step 2. Loss 4.362
Step 3. Loss 4.475
Step 4. Loss 4.126
Step 5. Loss 3.460
Step 6. Loss 3.013
Step 7. Loss 2.933
Step 8. Loss 2.900
Step 9. Loss 2.864
Step 10. Loss 2.822
Step 11. Loss 2.804
Step 12. Loss 2.735
Step 13. Loss 2.710
Step 14. Loss 2.712
Step 15. Loss 2.694
Step 16. Loss 2.677
Step 17. Loss 2.608
Step 18. Loss 2.624
Step 19. Loss 2.623
Step 20. Loss 2.594
Step 21. Loss 2.606
Step 22. Loss 2.600
Step 23. Loss 2.570
Step 24. Loss 2.573
Step 25. Loss 2.560
Step 26. Loss 2.561
Step 27. Loss 2.535
Step 28. Loss 2.553
Step 29. Loss 2.552
Step 30. Loss 2.545
Step 31. Loss 2.535
Step 32. Loss 2.538
Step 33. Loss 2.513
Step 34. Loss 2.526
Step 35. Loss 2.524
Step 36. Loss 2.505
Step 37. Loss 2.506
Step 38. Loss 2.503
Step 39. Loss 2.495
Step 40. Loss 2.515
Step 41. Loss 2.499
Step 42. Loss 2.517
Step 43. Loss 2.510
Step 44. Loss 2.49