In [2]:
from tokenizers import Tokenizer

# Path of the serialized tokenizer definition (a JSON file, per its extension).
file_name = "bpe_tokenizer.json"
# Rebuild the tokenizer from that file; later cells reuse `tokenizer` for
# encoding the corpus and decoding generated ids.
tokenizer = Tokenizer.from_file(file_name)
# The vocabulary size is passed to the model constructor further down.
vocab_size = tokenizer.get_vocab_size()
print(vocab_size)

8000


In [3]:
# Round-trip sanity check: encode a short sample, then decode it back.
text = "To be, or not to be:\n\nThat is \nthe question."

encoding = tokenizer.encode(text)
# In the recorded output a leading space shows up as "Ġ" and a newline as "Ċ".
print(encoding.tokens)
decoding = tokenizer.decode(encoding.ids)
# The recorded output reproduces the input text, including its line breaks.
print(decoding)

['To', 'Ġbe', ',', 'Ġor', 'Ġnot', 'Ġto', 'Ġbe', ':', 'Ċ', 'Ċ', 'That', 'Ġis', 'Ġ', 'Ċ', 'the', 'Ġquestion', '.']
To be, or not to be:

That is 
the question.


In [5]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# Read the whole training corpus (any plain-text book works) into memory.
with open("input.txt", mode="r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

# Tokenize the corpus in one pass and keep the ids as a 1-D tensor of
# int64 values, which is what the batching code slices later on.
encoding = tokenizer.encode(text)
token_ids = encoding.ids
print(len(token_ids))
data = torch.tensor(token_ids, dtype=torch.long)

317285


In [6]:
# Flip to True to force CPU execution even when CUDA is available.
cpu_only = False
# CUDA is only probed when the CPU override is off.
use_cuda = (not cpu_only) and torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
print("device:", device)

device: cpu


In [7]:
# Hold out the final 10% of the token stream for validation; the split is
# positional (no shuffling), so validation text comes from the end of the corpus.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [8]:
block_size = 128  # context length (tokens per training sequence)
batch_size = 32  # sequences sampled per batch


def get_batch(split):
    """Sample a random batch of (input, target) token windows.

    Args:
        split: ``"train"`` draws from ``train_data``; any other value draws
            from ``val_data``.

    Returns:
        A tuple ``(x, y)`` of tensors, each of shape
        ``(batch_size, block_size)`` and moved to ``device``. ``y`` is ``x``
        shifted one position to the right, i.e. the next token at every
        position.

    Raises:
        ValueError: If the selected split holds fewer than
            ``block_size + 1`` tokens, so no full window fits.
    """
    data_src = train_data if split == "train" else val_data

    # Each window needs block_size inputs plus one extra token that serves as
    # the target of the last input position.
    n_starts = len(data_src) - block_size
    if n_starts <= 0:
        raise ValueError(
            f"split {split!r} has {len(data_src)} tokens; "
            f"need at least {block_size + 1}"
        )

    # randint's upper bound is exclusive, so passing n_starts lets every valid
    # start offset (0 .. n_starts - 1) be drawn, including the final window.
    ix = torch.randint(n_starts, (batch_size,))

    # Input tokens
    x = torch.stack([data_src[i:i + block_size] for i in ix])
    # Target = next token (inputs shifted by one position)
    y = torch.stack([data_src[i + 1:i + block_size + 1] for i in ix])

    return x.to(device), y.to(device)

In [9]:
# Smoke test: draw one training batch and display the (inputs, targets) pair.
get_batch("train")

(tensor([[   5,  106,  717,  ...,   64,  359,  309],
         [  64,  136,  305,  ...,    4,   29,  125],
         [4418,   64,  208,  ...,   64,  510, 7733],
         ...,
         [ 315,  101,  622,  ...,  144, 1188,  170],
         [ 106,  328,  176,  ...,   64,   20,  265],
         [  64,  302, 5910,  ..., 4749,  157, 7233]]),
 tensor([[ 106,  717,   10,  ...,  359,  309,  102],
         [ 136,  305,  117,  ...,   29,  125,  406],
         [  64,  208,  114,  ...,  510, 7733,  384],
         ...,
         [ 101,  622,  157,  ..., 1188,  170,  870],
         [ 328,  176,  463,  ...,   20,  265,  499],
         [ 302, 5910,  868,  ...,  157, 7233,    9]]))

In [10]:
class SubwordEmbedding(nn.Module):
    """Token embedding plus a learned absolute position embedding.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Width of each embedding vector.
        block_size: Number of positions in the position table; sequences
            passed to ``forward`` must not be longer than this.
    """

    def __init__(self, vocab_size, d_model, block_size):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)

    def forward(self, x):
        """Embed token ids of shape (B, T) into vectors of shape (B, T, d_model)."""
        _, T = x.shape
        tok = self.token_emb(x)  # (B, T, d_model)
        # Build the position indices on the input's own device so the module
        # works wherever it is moved, independent of any notebook-level global.
        positions = torch.arange(T, device=x.device)
        pos = self.pos_emb(positions)  # (T, d_model), broadcast over the batch
        return tok + pos

In [11]:
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention.

    The input of shape (B, T, C) is projected to keys, queries and values,
    split into ``n_heads`` heads of width ``C // n_heads``, attended with a
    lower-triangular mask so a position only sees itself and earlier
    positions, then merged and passed through an output projection.
    """

    def __init__(self, d_model, n_heads, block_size):
        super().__init__()
        assert d_model % n_heads == 0

        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Bias-free projections for keys, queries and values.
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.query = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)

        # Output projection applied after the heads are merged.
        self.proj = nn.Linear(d_model, d_model)

        # Boolean lower-triangular matrix: True where attention is allowed.
        causal = torch.ones(block_size, block_size).tril().bool()
        self.register_buffer("mask", causal)

    def _to_heads(self, t):
        """Reshape (B, T, C) into (B, n_heads, T, d_head)."""
        batch, seq_len, _ = t.shape
        return t.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        batch, seq_len, width = x.shape

        # Project the full width once, then split each result into heads.
        keys = self._to_heads(self.key(x))
        queries = self._to_heads(self.query(x))
        values = self._to_heads(self.value(x))

        # Scaled dot-product scores: (B, n_heads, T, T)
        scale = self.d_head ** 0.5
        scores = (queries @ keys.transpose(-2, -1)) / scale

        # Hide future positions before normalizing.
        visible = self.mask[:seq_len, :seq_len]
        scores = scores.masked_fill(~visible, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values per head: (B, n_heads, T, d_head)
        mixed = weights @ values

        # Merge heads back into the model width.
        merged = mixed.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.proj(merged)

In [12]:
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention then a feed-forward network.

    Layer normalization is applied to the input of each sub-layer, and each
    sub-layer's output is added back onto its input (residual connection).
    The feed-forward network expands to four times ``d_model`` with a ReLU
    in between.
    """

    def __init__(self, d_model, block_size, head_n):
        super().__init__()
        hidden = 4 * d_model

        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, head_n, block_size)
        self.ln2 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x):
        # Attention sub-layer, added onto the residual stream.
        attended = self.attn(self.ln1(x))
        x = x + attended
        # Feed-forward sub-layer, added onto the residual stream.
        expanded = self.ff(self.ln2(x))
        return x + expanded

In [13]:
class SubwordLM(nn.Module):
    """Single-block transformer language model over subword tokens.

    Pipeline: embeddings -> one transformer block -> final layer norm ->
    linear head producing one logit per vocabulary entry.
    """

    def __init__(self, vocab_size, d_model, block_size, head_n):
        super().__init__()
        self.embed = SubwordEmbedding(vocab_size, d_model, block_size)
        self.block = TransformerBlock(d_model, block_size, head_n)
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x, targets=None):
        """Return logits, or ``(logits, loss)`` when ``targets`` is given.

        Args:
            x: Token ids of shape (B, T).
            targets: Optional token ids of shape (B, T) to score against.

        Returns:
            ``logits`` of shape (B, T, vocab_size) when ``targets`` is None;
            otherwise ``(logits, loss)`` where ``loss`` is the cross-entropy
            over all B * T positions.
        """
        hidden = self.ln(self.block(self.embed(x)))  # (B, T, d_model)
        logits = self.head(hidden)  # (B, T, vocab_size)

        if targets is not None:
            B, T, V = logits.shape
            # cross_entropy takes raw logits; flatten batch and time together.
            flat_logits = logits.view(B * T, V)
            flat_targets = targets.view(B * T)
            return logits, F.cross_entropy(flat_logits, flat_targets)

        return logits

In [14]:
# File where trained weights are cached between notebook runs.
model_path = "../subword.pth"
# Build the language model and move it to the selected device.
model = SubwordLM(
    vocab_size,
    d_model=128,
    block_size=block_size,
    head_n=4,
).to(device)

In [15]:
def train(max_steps=50000, eval_interval=500, eval_iters=20,
          patience=3, min_delta=0.05, lr=3e-4):
    """Train the global ``model`` with AdamW and validation-based early stopping.

    Every ``eval_interval`` steps the loss on the held-out split is estimated.
    Training stops once that loss has failed to improve by more than
    ``min_delta`` for ``patience`` consecutive evaluations.

    Args:
        max_steps: Upper bound on optimizer steps.
        eval_interval: Number of steps between evaluations.
        eval_iters: Validation batches averaged per evaluation.
        patience: Evaluations without sufficient improvement before stopping.
        min_delta: Minimum decrease in validation loss that counts as progress.
        lr: AdamW learning rate.

    Side effects:
        Writes the best-so-far weights to ``best.pth`` and, when training
        ends, the final weights to ``model_path``.
    """
    best_loss = float("inf")
    patience_counter = 0

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Another cell (e.g. text generation) may have left the model in eval mode.
    model.train()
    for step in range(max_steps):
        xb, yb = get_batch("train")

        _, loss = model(xb, yb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % eval_interval == 0:
            train_loss = loss.item()
            # A single training minibatch is too noisy to drive early
            # stopping; use an averaged loss on the held-out split instead.
            val_loss = _estimate_val_loss(eval_iters)
            print(f"step {step}, train loss {train_loss:.4f}, val loss {val_loss:.4f}")
            if best_loss - val_loss > min_delta:
                best_loss = val_loss
                patience_counter = 0
                torch.save(model.state_dict(), "best.pth")
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print("Early stopping triggered")
                break

    # Saves the weights as of the last step; the best checkpoint is best.pth.
    torch.save(model.state_dict(), model_path)


@torch.no_grad()
def _estimate_val_loss(eval_iters):
    """Average the model's loss over ``eval_iters`` random validation batches."""
    n_batches = max(1, eval_iters)
    was_training = model.training
    model.eval()
    try:
        total = 0.0
        for _ in range(n_batches):
            xb, yb = get_batch("val")
            _, batch_loss = model(xb, yb)
            total += batch_loss.item()
    finally:
        # Put the model back into whatever mode the caller had it in.
        model.train(was_training)
    return total / n_batches

In [16]:
# Reuse cached weights when they exist; otherwise train from scratch.
if not os.path.exists(model_path):
    train()
else:
    # On a CPU run, force every stored tensor onto the CPU while loading;
    # otherwise leave the storage locations recorded in the file untouched.
    load_target = torch.device('cpu') if device == "cpu" else None
    saved_state = torch.load(
        model_path,
        weights_only=True,
        map_location=load_target,
    )
    model.load_state_dict(saved_state)

step 0, loss 9.1873


KeyboardInterrupt: 

In [None]:
@torch.no_grad()
def generate(model, start, max_new_tokens=200):
    """Sample a continuation of ``start`` from ``model``.

    Args:
        model: Language model that returns logits of shape (B, T, vocab)
            when called with token ids of shape (B, T).
        start: Prompt text; it is encoded with the notebook's ``tokenizer``.
        max_new_tokens: Number of tokens to sample after the prompt.

    Returns:
        The decoded text of the prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``start`` encodes to no tokens while tokens are
            requested, since there is no last position to predict from.
    """
    prompt_ids = tokenizer.encode(start).ids
    if not prompt_ids and max_new_tokens > 0:
        raise ValueError("start must encode to at least one token")

    # Remember the current mode so sampling does not leave the model in eval
    # mode for a later training call.
    was_training = model.training
    model.eval()
    try:
        idx = torch.tensor([prompt_ids], device=device, dtype=torch.long)

        for _ in range(max_new_tokens):
            # The position table only covers block_size tokens, so crop the
            # context to the most recent ones.
            idx_cond = idx[:, -block_size:]
            logits = model(idx_cond)
            # Only the final position predicts the next token.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
    finally:
        model.train(was_training)

    return tokenizer.decode(idx[0].tolist())


In [None]:
# Sample a continuation from the model, using a speaker tag as the prompt.
print(generate(model, "CORIOLANUS:"))
# NOTE(review): the string literal below looks like a pasted sample of earlier
# generated output kept for reference; as the cell's last expression it is also
# echoed as the cell result. Confirm this is intentional.
"""CORIOLANUS:
I come, that I do lament the way to land
Montague against my power, or move indeed
mans to Henry's off, his grace.
Most mighty suit, then are humble weeds shall behold
Your very root of government;
True, depends to walk; that I do refuse,
Thou art not a coward-- deeds as you are one of them;
So as for how can you distingusers
by Place-deic bark: therefore thou Romeo, for his sake,
Some haunt I do I thank thee, I have kill'd.

JULIET:
Go ran a tears, my mother playsion, sir.

ROMEO:
'Tis now, no dancing?

FRIAR LAURENCE:
One king! why, my wife is no end from your bed!

FRIAR LAURENCE:
The heads of unsoting both my lady's lord that;
For I have more favour else is a par"""