In [None]:
import os
import shutil
import torch
import torch.nn as nn
from torch.nn import functional as F

# --- data location ---
input_file = '../data/input.txt'

# --- hyperparameters ---
batch_size = 32       # sequences drawn per batch
block_size = 8        # tokens of context per sequence
max_iters = 3000      # optimisation steps
eval_interval = 300   # steps between loss reports
eval_iters = 200      # batches averaged per loss estimate
n_embed = 32          # embedding width
lr = 1e-2             # learning rate

# --- compute device ---
# Flip to True to try CUDA first, then Apple MPS; otherwise stay on the CPU.
using_gpu = False

device = torch.device("cpu")
if using_gpu and torch.cuda.is_available():
    device = torch.device("cuda")
elif using_gpu and torch.backends.mps.is_available():
    device = torch.device("mps")

# fixed seed so runs are repeatable
torch.manual_seed(1337)

# read the whole corpus into memory as one string
with open(input_file, mode='r', encoding='utf-8') as fh:
    text = fh.read()

# vocabulary: every distinct character, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)

# lookup tables: character -> id and id -> character
dict_stoi = {symbol: index for index, symbol in enumerate(chars)}
dict_itos = dict(enumerate(chars))


def encode(s):
    """Turn a string into the list of integer ids of its characters."""
    return [dict_stoi[c] for c in s]


def decode(l):
    """Turn a sequence of integer ids back into a string."""
    return ''.join(dict_itos[i] for i in l)


# keep the first 90% of the encoded corpus for training, the rest for validation
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]


# data loading
def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    `split` picks the source: 'train' reads from `train_data`, any other value
    reads from `val_data`. Returns a pair `(x, y)` of tensors, each made of
    `batch_size` rows of `block_size` tokens and already moved to `device`;
    `y` holds the same windows as `x` shifted one position to the right.
    """
    source = val_data if split != 'train' else train_data
    # one random start offset per row of the batch
    starts = torch.randint(len(source) - block_size, (batch_size, ))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)


# Use the called form of the decorator: the bare `@torch.no_grad` spelling is
# only accepted by newer PyTorch releases.
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of `model` on the train and validation splits.

    Puts `model` in eval mode, averages the loss over `eval_iters` random
    batches from `get_batch` for each split, and switches the model back to
    train mode before returning, even if evaluation is interrupted.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor holding the mean loss.
    """
    out = {}
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # never leave the model stuck in eval mode if a batch fails or the
        # cell is interrupted part-way through
        model.train()
    return out


class BigramLanguageModel(nn.Module):
    """Language model made of a token embedding followed by a linear head.

    Each token id is looked up in an embedding table and projected to one
    logit per vocabulary entry, so the prediction at a position depends only
    on the token at that position.
    """

    def __init__(self, n_vocab=None, embed_dim=None):
        """Build the embedding table and the output projection.

        Args:
            n_vocab: number of distinct tokens; defaults to the notebook-level
                `vocab_size` when omitted.
            embed_dim: embedding width; defaults to the notebook-level
                `n_embed` when omitted.
        """
        super().__init__()
        if n_vocab is None:
            n_vocab = vocab_size
        if embed_dim is None:
            embed_dim = n_embed
        self.token_embedding_table = nn.Embedding(n_vocab, embed_dim)
        self.lm_head = nn.Linear(embed_dim, n_vocab)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) tensor of integer token ids.
            targets: optional (B, T) tensor of integer token ids.

        Returns:
            (logits, loss). Without targets, logits has shape
            (B, T, vocab) and loss is None. With targets, logits is
            flattened to (B*T, vocab) and loss is the cross-entropy
            between the flattened logits and targets.
        """
        token_embed = self.token_embedding_table(idx)  # (B, T, embed_dim)
        # project embeddings to one score per vocabulary entry
        logits = self.lm_head(token_embed)  # (B, T, vocab)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy wants (N, C) scores against (N,) class ids
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens one at a time.

        Args:
            idx: (B, T) tensor of integer token ids used as the starting context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.
        """
        for _ in range(max_new_tokens):
            # get predictions for every position of the running sequence
            logits, _ = self(idx)
            # keep only the scores of the last timestep -> (B, vocab)
            logits = logits[:, -1, :]
            # turn scores into a probability distribution over the vocabulary
            probs = F.softmax(logits, dim=-1)
            # draw one token id per row from that distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled id to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# The constructor picks its sizes up from the notebook-level `vocab_size` and
# `n_embed`. The model has to live on the same device as the batches that
# get_batch() returns, so move it there right away.
model = BigramLanguageModel().to(device)

# use the configured learning rate rather than a second, hard-coded value
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

for step in range(max_iters):

    # every eval_interval steps, report the averaged train / val loss
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}")

    # sample a batch of data and take one optimisation step on it
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# generate 500 tokens starting from a single token with id 0 and print the text
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))


step 0: train loss: 4.7305, val loss: 4.7241
step 300: train loss: 4.3818, val loss: 4.3896
step 600: train loss: 4.0801, val loss: 4.0784
step 900: train loss: 3.8066, val loss: 3.8117
step 1200: train loss: 3.5844, val loss: 3.5850
step 1500: train loss: 3.3757, val loss: 3.3829
step 1800: train loss: 3.2182, val loss: 3.2218
step 2100: train loss: 3.0817, val loss: 3.0810
step 2400: train loss: 2.9663, val loss: 2.9739
step 2700: train loss: 2.8809, val loss: 2.8800

WTraceliNCIUMkiszercol, phoboserd, my.
LBFFFV&Womf DCInosBTO nem!reZ3Fo?
SP anenu at igIUupriyjulop$M$ERUGXK:LQ?Twf thJyQXEreatheCiPk.
W:CJuCby qVvPond bonjPDC3QJgke atotiM:CEbbjrlg:
AUppatROLJLq?&btXRSSirp'shetes d s$o gZd mZeviye my havbOsone mJqmot ay fo CHoxZYhe malalyx!$Duir.
FXSje itas; pe o'ddugCK&ondu srgCOWILEvetrathacktR:w.
Pg
DIUDUx?&acI', sren,sJnCjuselotO:Mkzltl'?'W:soread pd,zmJakzegm;DWhinfedinsetha;vesRIsrGBTr l3Q!ad?SilQJMIVjK oTEXm ZldRBy!id w? tr oth; tNCUpld,
tin
GS:tRaworxK:


In [None]:
import torch

# reproducible toy input: B sequences of T positions with C channels each
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# one attention head: three bias-free projections of width head_dim
head_dim = 16
query = nn.Linear(C, head_dim, bias=False)
key = nn.Linear(C, head_dim, bias=False)
value = nn.Linear(C, head_dim, bias=False)

q, k, v = query(x), key(x), value(x)  # each (B, T, head_dim)

# pairwise scores: (B, T, head_dim) @ (B, head_dim, T) -> (B, T, T)
wei = torch.matmul(q, k.transpose(1, 2))

# causal mask: a position may only look at itself and earlier positions, so
# everything above the diagonal is set to -inf before the softmax
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)
out = torch.matmul(wei, v)  # (B, T, head_dim)

print(out.shape)
print(wei)


torch.Size([4, 8, 32])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5877, 0.4123, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4457, 0.2810, 0.2733, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2220, 0.7496, 0.0175, 0.0109, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0379, 0.0124, 0.0412, 0.0630, 0.8454, 0.0000, 0.0000, 0.0000],
         [0.5497, 0.2187, 0.0185, 0.0239, 0.1831, 0.0062, 0.0000, 0.0000],
         [0.2576, 0.0830, 0.0946, 0.0241, 0.1273, 0.3627, 0.0507, 0.0000],
         [0.0499, 0.1052, 0.0302, 0.0281, 0.1980, 0.2657, 0.1755, 0.1474]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4289, 0.5711, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5413, 0.1423, 0.3165, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0635, 0.8138, 0.0557, 0.0669, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4958, 0.0758, 0.2224, 0.0156, 0.1905, 0.0000, 0.0000, 0.0000],


True
