# Dataset

We download a ~1MB file containing the entirety of Shakespeare's work. This is the dataset we will train our language model on.

In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-07-14 15:55:20--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-07-14 15:55:21 (2.92 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [31]:
import torch
from tqdm import trange
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
from IPython.core.display_functions import clear_output
import matplotlib.pyplot as plt
import math
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps:0"
elif torch.cuda.is_available():
    device = "cuda:0"


In [3]:
with open('input.txt') as f:
    text = f.read()

print("----Sample Shakespeare----")
print(text[:100])

----Sample Shakespeare----
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


## Tokenization

Tokenization converts raw sub-sequences of text (substrings) to sequences of integers. For example, `"ll." -> 208`. We will be developing a character level language model, so we will be converting each individual word into an integer. For example, `"Hello" -> 48`.

In [4]:
def split_to_words(text):
    words = []
    word = ""
    for c in text:
        if c.isalnum():
            word += c
        else:
            words.append(word)
            words.append(c)
            word = ""
    words.append(word)
    return words

words = list(set(split_to_words(text)))
vocab_size = len(words)
print("Number of distinct words in text: {}".format(vocab_size))

Number of distinct words in text: 13334


In [5]:
# Use index to map words to integer
stoi = {word:i for i, word in enumerate(words)}
itos = {i:word for i, word in enumerate(words)}
def words_to_tokens(words):
    return [stoi[w] for w in words]

def tokens_to_words(int_list):
    return [itos[i] for i in int_list]

sample_words = split_to_words(text)[:10]
print("Original text: {}".format("".join(sample_words)))
print("Encoded text: {}".format(words_to_tokens(sample_words)))
print("Decoded text: {}".format(tokens_to_words(words_to_tokens(sample_words))))

Original text: First Citizen:
Before we 
Encoded text: [11026, 12486, 12555, 10887, 0, 8957, 8995, 12486, 11929, 12486]
Decoded text: ['First', ' ', 'Citizen', ':', '', '\n', 'Before', ' ', 'we', ' ']


In [6]:
tokenized_text = words_to_tokens(split_to_words(text))
print("Encoded text sample: {}".format(tokenized_text[:10]))
print(tokens_to_words(tokenized_text[:10]))
tokenized_text = torch.tensor(tokenized_text)

Encoded text sample: [11026, 12486, 12555, 10887, 0, 8957, 8995, 12486, 11929, 12486]
['First', ' ', 'Citizen', ':', '', '\n', 'Before', ' ', 'we', ' ']


## Data Split

In [7]:
context_size = 32
split_factor = 0.9
split_index = int(split_factor * len(tokenized_text))
train = tokenized_text[:split_index].to(device)
test = tokenized_text[split_index:].to(device)

## Data Loader

In [8]:
class TextDataset(Dataset):
    def __init__(self, text, context_size):
        self.text = text
        self.context_size = context_size
        assert self.context_size < len(text), "context_size must be less than len(text)"

    def __len__(self):
        return len(self.text) - self.context_size

    def __getitem__(self, idx):
        return self.text[idx:idx + self.context_size],  self.text[idx + 1:idx + self.context_size + 1]

train_set = TextDataset(train, context_size)
test_set = TextDataset(test, context_size)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

# Embeddings

We will use PCA to create the token embeddings

In [12]:
# Create co-occurrence matrix
# The co-occurrence matrix X is a VxV (V is our vocab size) symmetric matrix where X_ij is how many times the ith word appears within W words away from the jth word.
W = 10
X = torch.stack([torch.zeros(len(words)) for _ in range(len(words))])
for i in trange(len(tokenized_text)):
    words_to_right = tokenized_text[i+1:i+W+1]
    words_to_left = tokenized_text[i-W:i]
    X[tokenized_text[i], words_to_right] += 1.0
    X[tokenized_text[i], words_to_left] += 1.0
X = X.to(device)

100%|██████████| 528579/528579 [00:07<00:00, 67490.79it/s]


In [13]:
# Torch has a bug on mps devices so this won't work on MacBooks
embedding_dim = 256
X -= X.mean(dim=1, keepdim=True)
X /= X.std(dim=1, keepdim=True)
cov = (X @ X.T)/(X.shape[0] - 1)
L, Q = torch.linalg.eigh(cov)
principle_eigv = Q[:, -embedding_dim:].T
embeddings = X @ principle_eigv.T # (vocab_size, embedding_dim)

NotImplementedError: The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

# Model

In [14]:
class Attn(nn.Module):
    def __init__(self, num_heads, dqk, dv):
        super(Attn, self).__init__()
        self.num_heads = num_heads
        self.dqk = dqk
        self.dv = dv
        self.Wq = nn.Parameter(torch.randn(num_heads, embedding_dim, dqk))
        nn.init.kaiming_uniform_(self.Wq, a=math.sqrt(5))
        self.Wk = nn.Parameter(torch.randn(num_heads, embedding_dim, dqk))
        nn.init.kaiming_uniform_(self.Wk, a=math.sqrt(5))
        self.Wv = nn.Parameter(torch.randn(num_heads, embedding_dim, dv))
        nn.init.kaiming_uniform_(self.Wv, a=math.sqrt(5))
        self.Wo = nn.Parameter(torch.randn(num_heads * dv, embedding_dim))
        nn.init.kaiming_uniform_(self.Wo, a=math.sqrt(5))
    def forward(self, x, use_mask=False):

        if len(x.shape) == 2:
            x = x.unsqueeze(0)

        B, N, D = x.shape
        x = x.unsqueeze(1)
        q = x @ self.Wq.unsqueeze(0)
        k = x @ self.Wk.unsqueeze(0)
        v = x @ self.Wv.unsqueeze(0)
        qk =  q @ k.transpose(-2, -1) * (self.dqk ** -0.5)

        if use_mask:
            mask = torch.tril_indices(qk.shape[-2], qk.shape[-1], -1)
            qk[:, :, mask[0], mask[1]] = float('-inf')

        softmax_qk = F.softmax(qk, dim=-1)
        qkv = softmax_qk @ v
        concat_qkv = qkv.permute(0, 2, 1, 3).reshape(B, N, self.num_heads * self.dv)
        out = concat_qkv @ self.Wo.unsqueeze(0)

        return out


In [15]:
class DecoderLayer(nn.Module):
    def __init__(self, num_heads, dqk=embedding_dim, dv=embedding_dim):
        super(DecoderLayer, self).__init__()
        self.masked_attn = Attn(num_heads, dqk, dv)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, 3 * embedding_dim)
        self.linear2 = nn.Linear(3 * embedding_dim, embedding_dim)
    def forward(self, x):
        x = self.masked_attn(self.norm1(x)) + x
        x = self.linear2(F.relu(self.linear1(self.norm2(x)))) + x
        return x

In [16]:
class LLM(nn.Module):
  def __init__(self, num_blocks, num_heads_per_block, key_query_dim=embedding_dim, value_dim=embedding_dim):
    super(LLM, self).__init__()
    self.num_blocks = num_blocks
    self.attn = Attn(num_heads_per_block, key_query_dim, value_dim)
    self.position_embedding = nn.Embedding(context_size, embedding_dim)
    self.token_embedding = embeddings
    self.decoder_layers = nn.ModuleList([DecoderLayer(num_heads_per_block, key_query_dim, value_dim) for _ in range(num_blocks)])
    self.norm = nn.LayerNorm(embedding_dim)
    self.out = nn.Linear(embedding_dim, vocab_size)

  def forward(self, tokens):
    token_emb = self.token_embedding[tokens]
    pos_emb = self.position_embedding(torch.arange(tokens.shape[1], device=device))
    x = token_emb + pos_emb
    for layer in self.decoder_layers:
        x = layer(x)

    return self.out(self.norm(x))

  def generate(self, input_tokens, max_generate_tokens=500):
    for _ in range(max_generate_tokens):
      logits = self(input_tokens[: , -context_size:])
      logits = logits[:, -1, :]
      probs = F.softmax(logits, dim=-1)
      next_token = torch.multinomial(probs, num_samples=1)
      input_tokens = torch.cat([input_tokens, next_token], dim=1)
    return input_tokens

In [32]:
num_blocks = 6
num_heads_per_block = 8
if os.path.exists("./model.pt"):
    model = torch.load("./model.pt", map_location=device)
else:
    model = LLM(num_blocks, num_heads_per_block).to(device)
lr = 1e-4
opt = optim.AdamW(model.parameters(), lr=lr)
num_epochs = 100
model.eval()

LLM(
  (attn): Attn()
  (position_embedding): Embedding(32, 256)
  (decoder_layers): ModuleList(
    (0-5): 6 x DecoderLayer(
      (masked_attn): Attn()
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=256, out_features=768, bias=True)
      (linear2): Linear(in_features=768, out_features=256, bias=True)
    )
  )
  (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (out): Linear(in_features=256, out_features=13334, bias=True)
)

In [33]:
train_loss_evolution = []
for epoch in trange(num_epochs):
    train_loss = 0
    for i, (x, y) in enumerate(train_loader):
        logits = model(x)
        batch_size, _, _ = logits.shape
        loss = F.cross_entropy(logits.view(batch_size * context_size, -1), y.view(batch_size * context_size, -1).squeeze())
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss += loss.item()
    train_loss_evolution.append(train_loss/len(train_loader))
    clear_output()
    print(f"Epoch {epoch}, Loss {train_loss/len(train_loader)}")
    plt.plot(train_loss_evolution)

  0%|          | 0/100 [14:43<?, ?it/s]


KeyboardInterrupt: 

In [34]:
test_loss = 0
with torch.no_grad():
    for i, (x, y) in enumerate(test_loader):
        logits = model(x)
        batch_size, _, _ = logits.shape
        loss = F.cross_entropy(logits.view(batch_size * context_size, -1), y.view(batch_size * context_size, -1).squeeze())
        test_loss += loss.item()

print("Test loss: ", test_loss / len(test_loader))

Test loss:  0.12407381848835772


In [35]:
initial = test[132:164].unsqueeze(0)
print("".join(tokens_to_words(model.generate(initial, max_generate_tokens=1000).squeeze().tolist())))


BAPTISTA:
After my death the one half of my lands,
And in possession twenty and with this-northneitherquarry persons.
boss, life good Duke as thy Adieu this gentleman?
The King is a high
Accursed in the mark of his time-traitor were.'
How, arms thou hast well, in him ere kinsman.

POLIXENES:
Till, but preachment, I would have been took;
For 'Let I think'd this to come spell.

FRIAR Murderer:
My, sir, if in the Unto, sir;
Fair foolish, come to the put world of your own?

ANGELO:
Be without me of his remembrance bold face?
And acquit, go helding is so she, which you heard;
velvet to plead point of his prince;
For favourites, in flatter expiring well-wall, scatter, which I have,
Who murder Polixenes tongues brought attend do from;
The R northern, they! man well, come.

POLIXENES:
O is have, some son seem, there thou shalt.

ROMEO:
Quick, sir, Richmond up?

Servant:
Those, therefore.

GLOUCESTER:
mirth my lord?

BUCKINGHAM:
Ay, when, sir;
Let on hope on't brains than yours.

GREGORY:
Is b