In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F

# Seed PyTorch's default random number generator so that the random draws
# made below (weight init, batch sampling, token sampling) start from a
# fixed state.
SEED = 47
torch.manual_seed(SEED)

In [None]:
def get_device():
    """Return the preferred ``torch.device``.

    Preference order is CUDA, then Apple's MPS backend, then the CPU.
    The MPS probe only runs when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        name = "cuda"
    elif torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)

In [None]:
# --- Runtime ---------------------------------------------------------------
DEVICE = get_device()

# --- Data / batching -------------------------------------------------------
BLOCK_SIZE = 256  # tokens per training window; also the size of the positional table
BATCH_SIZE = 64  # windows sampled per batch
VOCAB_SIZE = None  # placeholder, assigned once the character vocabulary is built

# --- Model -----------------------------------------------------------------
EMBED_SIZE = 128  # width of token/positional embeddings
NUM_HEADS = 8  # attention heads per transformer block
NUM_LAYERS = 4  # number of transformer blocks
DROPOUT_PCT = 0.2  # passed to nn.Dropout as the drop probability

# --- Optimisation / evaluation ---------------------------------------------
LEARNING_RATE = 3e-4
EVAL_ITERS = 100  # batches averaged per split when estimating the loss
EVAL_EPOCHS = 100  # training steps between two loss estimates

print(f"{DEVICE} enabled!")

In [None]:
# Read the whole training corpus into memory as a single string.
with open(file="datasets/input.txt", mode="r", encoding="utf-8") as f:
    text = f.read()

# Show the size and a short preview only: printing the entire corpus floods
# the cell output and bloats the saved notebook.
print(f"corpus length: {len(text):,} characters")
print(text[:500])

In [None]:
# Character-level vocabulary: every distinct character of the corpus, sorted
# so that a character's position (and therefore its id) is deterministic.
chars = list(set(text))
chars.sort()
VOCAB_SIZE = len(chars)
print(VOCAB_SIZE, chars)

In [None]:
# Lookup tables between characters and integer ids; a character's id is its
# position in `chars`.
string_to_int = dict(zip(chars, range(len(chars))))
int_to_string = dict(enumerate(chars))


def encode(s):
    """Return the list of integer ids for the characters of ``s``.

    Raises ``KeyError`` for a character that is not in the vocabulary.
    """
    ids = []
    for ch in s:
        ids.append(string_to_int[ch])
    return ids


def decode(l):
    """Return the string spelled by the iterable of integer ids ``l``.

    Raises ``KeyError`` for an id that is not in the vocabulary.
    """
    return "".join(int_to_string[token_id] for token_id in l)

In [None]:
# Round-trip a short sample through the codec and show both representations.
encoded_data = encode("hi there")
decoded_data = decode(encoded_data)

print(f"encoded data: {encoded_data}")
print(f"decoded data: {decoded_data}")

In [None]:
# Encode the whole corpus once and keep the ids as a 1-D int64 tensor.
data = encode(text)
tensor_data = torch.tensor(data, dtype=torch.int64)

# Displayed as the cell's value: the first 100 token ids.
tensor_data[:100]

In [None]:
# Contiguous split: the first 80% of the token stream is used for training,
# the remaining tail for validation.
train_split = int(0.8 * len(tensor_data))
train_data, val_data = tensor_data[:train_split], tensor_data[train_split:]
print(f"training data size: {len(train_data)}\nval data size: {len(val_data)}")

In [None]:
def get_batch(split):
    """Sample a random batch of (input, target) windows.

    Args:
        split: ``"train"`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A pair ``(x, y)`` of tensors of shape ``(BATCH_SIZE, BLOCK_SIZE)``
        moved to ``DEVICE``. ``y`` is ``x`` shifted one position ahead in the
        source stream, i.e. ``y[b, t]`` is the token following ``x[b, t]``.
    """
    source = train_data if split == "train" else val_data
    # Upper bound leaves room for the one-token look-ahead of the targets.
    starts = torch.randint(len(source) - BLOCK_SIZE, (BATCH_SIZE,))

    inputs, targets = [], []
    for start in starts:
        stop = start + BLOCK_SIZE
        inputs.append(source[start:stop])
        targets.append(source[start + 1 : stop + 1])

    return torch.stack(inputs).to(DEVICE), torch.stack(targets).to(DEVICE)

In [None]:
# Draw one training batch and print both halves: `xb` holds the input windows
# and `yb` the same windows shifted one token ahead (see `get_batch`).
xb, yb = get_batch("train")
print(f"inputs: {xb}\nshape: {xb.shape}\n")
print(f"outputs: {yb}\nshape: {yb.shape}\n")

# Optional walkthrough, left disabled because it prints one line per
# (batch row, position) pair: for every position t the context xb[b, : t + 1]
# is paired with the target yb[b, t].
# for b in range(BATCH_SIZE):
#     for t in range(BLOCK_SIZE):
#         input = xb[b, : t + 1]
#         output = yb[b, t]
#         print(f"when input is {input} output is {output}")
#     print()

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of its input.

    The input is expanded to four times ``hidden_size``, passed through a
    ReLU, projected back to ``hidden_size`` and finally through dropout with
    probability ``DROPOUT_PCT``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        inner_size = hidden_size * 4
        layers = [
            nn.Linear(hidden_size, inner_size),
            nn.ReLU(),
            nn.Linear(inner_size, hidden_size),
            nn.Dropout(DROPOUT_PCT),
        ]
        self.ff_layer = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``; the output has the same shape as ``x``."""
        return self.ff_layer(x)

In [None]:
class SelfAttentionHead(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
    """

    def __init__(self, head_size) -> None:
        super().__init__()
        self.key = nn.Linear(EMBED_SIZE, head_size)
        self.query = nn.Linear(EMBED_SIZE, head_size)
        self.value = nn.Linear(EMBED_SIZE, head_size)
        # Lower-triangular mask: position t may attend to positions <= t only.
        self.register_buffer(
            name="tril", tensor=torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE))
        )
        self.dropout = nn.Dropout(DROPOUT_PCT)

    def forward(self, x):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Tensor of shape ``(B, T, EMBED_SIZE)``. ``T`` must not exceed
                ``BLOCK_SIZE``, the side length of the causal mask buffer.

        Returns:
            Tensor of shape ``(B, T, head_size)``.
        """
        T = x.shape[1]
        k = self.key(x)
        q = self.query(x)
        # Scale by the key/query width (d_k = head_size). The embedding width
        # is num_heads times larger and would over-damp the scores.
        scale = k.shape[-1] ** -0.5
        weight = (q @ k.transpose(-2, -1)) * scale
        # Hide future positions before normalising so they get zero weight.
        weight = weight.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        weight = F.softmax(weight, dim=-1)
        weight = self.dropout(weight)
        v = self.value(x)
        out = weight @ v
        return out

In [None]:
class MultiHeadAttention(nn.Module):
    """Several ``SelfAttentionHead`` modules run side by side.

    The head outputs are concatenated along the last dimension, passed
    through a linear projection of width ``EMBED_SIZE`` and then dropout.

    Args:
        num_heads: Number of attention heads to create.
        head_size: Output width of each head.
    """

    def __init__(self, num_heads, head_size) -> None:
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(SelfAttentionHead(head_size))
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(EMBED_SIZE, EMBED_SIZE)
        self.dropout = nn.Dropout(DROPOUT_PCT)

    def forward(self, x):
        """Run every head on ``x`` and merge their outputs."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))

In [None]:
class Block(nn.Module):
    """Pre-norm transformer block: self-attention, then a feed-forward MLP.

    Each sub-layer is applied to a layer-normalised copy of its input and
    added back to that input (residual connection).

    Args:
        num_heads: Number of attention heads. Must be a positive divisor of
            ``EMBED_SIZE`` so the concatenated heads are exactly
            ``EMBED_SIZE`` wide.

    Raises:
        ValueError: If ``num_heads`` is not a positive divisor of
            ``EMBED_SIZE``.
    """

    def __init__(self, num_heads):
        super().__init__()
        # Fail early: with a non-divisor, EMBED_SIZE // num_heads truncates and
        # the concatenated heads no longer match the projection's input width,
        # which would only surface later as a shape error in forward().
        if num_heads <= 0 or EMBED_SIZE % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive divisor of "
                f"EMBED_SIZE ({EMBED_SIZE})"
            )
        self.sa_heads = MultiHeadAttention(
            num_heads=num_heads, head_size=EMBED_SIZE // num_heads
        )
        self.ff_layer = FeedForward(EMBED_SIZE)
        self.ln1_layer = nn.LayerNorm(EMBED_SIZE)
        self.ln2_layer = nn.LayerNorm(EMBED_SIZE)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers with residuals."""
        x = x + self.sa_heads(self.ln1_layer(x))
        x = x + self.ff_layer(self.ln2_layer(x))
        return x

In [None]:
class BiagramLanguageModel(nn.Module):
    """Character-level language model built from transformer blocks.

    Token and positional embeddings are summed, passed through
    ``NUM_LAYERS`` ``Block`` modules followed by a final ``LayerNorm``, and
    projected to ``VOCAB_SIZE`` logits per position. (The class name is kept
    for compatibility with existing callers.)
    """

    def __init__(self) -> None:
        super().__init__()
        self.token_embeddings_table = nn.Embedding(VOCAB_SIZE, EMBED_SIZE)
        self.positional_embeddings_table = nn.Embedding(BLOCK_SIZE, EMBED_SIZE)
        layers = [Block(num_heads=NUM_HEADS) for _ in range(NUM_LAYERS)]
        layers.append(nn.LayerNorm(EMBED_SIZE))
        self.blocks = nn.Sequential(*layers)
        self.lm_head = nn.Linear(EMBED_SIZE, VOCAB_SIZE)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``. ``T`` must
                not exceed ``BLOCK_SIZE``, the number of rows in the
                positional embedding table.
            targets: Optional integer tensor with the same number of elements
                as ``idx`` holding the expected next token for each position.

        Returns:
            ``(logits, loss)``. Without ``targets`` the logits have shape
            ``(B, T, VOCAB_SIZE)`` and ``loss`` is ``None``. With ``targets``
            the logits are flattened to ``(B * T, VOCAB_SIZE)`` and ``loss``
            is their cross-entropy against the flattened targets.
        """
        B, T = idx.shape
        token_embeddings = self.token_embeddings_table(idx)
        # Build the position indices on the input's own device instead of a
        # module-level global, so the model works wherever it and its inputs
        # have been moved.
        position_embeddings = self.positional_embeddings_table(
            torch.arange(T, device=idx.device)
        )
        x = token_embeddings + position_embeddings
        x = self.blocks(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_tokens):
        """Autoregressively sample ``max_tokens`` new tokens after ``idx``.

        Sampling runs in evaluation mode (dropout disabled) and without
        autograd tracking; the module's previous train/eval mode is restored
        afterwards, even if sampling raises.

        Args:
            idx: Integer tensor of shape ``(B, T)`` holding the prompt ids.
            max_tokens: Number of tokens to append to every row.

        Returns:
            Integer tensor of shape ``(B, T + max_tokens)``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_tokens):
                    # The positional table only covers BLOCK_SIZE positions,
                    # so condition on the most recent BLOCK_SIZE tokens.
                    idx_cond = idx[:, -BLOCK_SIZE:]
                    logits, _ = self(idx_cond)
                    # Keep only the distribution for the last position.
                    logits = logits[:, -1, :]
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                    idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx

In [None]:
# Build the model on the selected device and push the sample batch through
# it once as a smoke test of the output shape and the initial loss.
model = BiagramLanguageModel().to(DEVICE)
logits, loss = model(xb, yb)
print(logits.shape, loss)

In [None]:
# Sample 100 tokens starting from a single token of id 0; the decoded string
# is the cell's displayed value.
start_tokens = torch.zeros(1, 1, dtype=torch.long).to(DEVICE)
sampled_ids = model.generate(start_tokens, max_tokens=100)[0].tolist()
decode(sampled_ids)

In [None]:
# AdamW over every model parameter, stepping with LEARNING_RATE.
optimiser = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of ``model`` on the train and validation splits.

    For each split, ``EVAL_ITERS`` random batches are drawn with
    ``get_batch`` and their losses averaged. The model is switched to
    evaluation mode for the measurement and its previous train/eval mode is
    restored afterwards, even if a batch raises.

    Returns:
        Dict mapping ``"train"`` and ``"val"`` to a 0-d tensor holding the
        mean loss for that split.
    """
    # The decorator is written in its called form, which every PyTorch
    # release accepts; the bare ``@torch.no_grad`` form is only understood
    # by newer releases.
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ["train", "val"]:
            losses = torch.zeros(EVAL_ITERS)
            for k in range(EVAL_ITERS):
                xb, yb = get_batch(split)
                _, loss = model(xb, yb)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore whatever mode the caller had instead of forcing train().
        model.train(was_training)
    return out

In [None]:
# Total number of optimisation steps (one random batch per step).
MAX_ITERS = 4000

for epoch in tqdm(range(MAX_ITERS)):
    # get_batch already returns tensors on DEVICE, so no further transfer
    # is needed here.
    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    # Report every EVAL_EPOCHS steps, and once more after the final step so
    # the end-of-training losses are always shown.
    if epoch % EVAL_EPOCHS == 0 or epoch == MAX_ITERS - 1:
        losses = estimate_loss()
        tqdm.write(
            f"epoch: {epoch} | training loss: {losses['train']:.4f} validation loss: {losses['val']:.4f}"
        )

In [None]:
# Sample 1000 tokens starting from a single token of id 0 and print the
# decoded text.
start_tokens = torch.zeros(1, 1, dtype=torch.long).to(DEVICE)
generated_ids = model.generate(start_tokens, max_tokens=1000)[0].tolist()
print(decode(generated_ids))