In [1]:
with open("data/paul_graham_essay.txt", "r") as f:
    data = f.read()

In [2]:
len(data)

75012

In [3]:
vocab = set(data)

In [4]:
"".join(vocab)

'Mq20B5\nw;[yaGdAV8DRLm!H3g.F9 vIl%c\'b/$O?"]k:CNPW6+xo1,ps4nJKift-X)7UuEhSjz—YT&e(r'

In [5]:
len(vocab)

81

In [6]:
stoi = {char: idx for idx, char in enumerate(vocab)}
itos = {idx: char for char, idx in stoi.items()}

def encode(text):
    return [stoi[char] for char in text]

def decode(tokens):
    return "".join(itos[token] for token in tokens)

test_text = "some random string"

print(decode(encode(test_text)))

some random string


In [7]:
import torch

tokenized_data = torch.tensor(encode(data), dtype=torch.long)

In [8]:
n = int(0.9 * len(tokenized_data))
train_data = tokenized_data[:n]
test_data = tokenized_data[n:]

In [9]:
class BigramModel(torch.nn.Module):
    def __init__(self, vocab_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=vocab_size)
    
    def forward(self, inputs, targets=None):
        # inputs.shape -> (batch_size, context_length)
        logits =  self.embedding(inputs)
        # logits.shape -> (batch_size, context_length, vocab_size)
        # but for torch's cross entropy loss, need (batch_size, vocab_size, context_length)
        if targets is None:
            return logits, None
        logits = logits.permute(0, 2, 1)
        loss = torch.nn.functional.cross_entropy(logits, targets)
        logits = logits.permute(0, 2, 1)
        return logits, loss
    
    def generate(self, start_idx, max_tokens=100):
        for _ in range(max_tokens):
            logits, _ = self(start_idx)
            # we only care about last prediction
            pred = logits[:, -1, :] # (batch_size, vocab_size)
            prob = torch.nn.functional.softmax(pred, dim=-1)
            idx = prob.multinomial(num_samples=1) # (batch_size, 1)
            start_idx = torch.cat([start_idx, idx], dim=-1)
        return start_idx

In [10]:
device = "cuda"

In [11]:
m = BigramModel(len(vocab))
m.to(device)

BigramModel(
  (embedding): Embedding(81, 81)
)

In [32]:
context_length = 16
batch_size = 16
unif = torch.ones(train_data.shape[0] - context_length)
unif.to(device)
def get_batch():
    batch_indices = unif.multinomial(batch_size, replacement=False)
    inputs = torch.stack([train_data[i:i+context_length] for i in batch_indices])
    targets = torch.stack([train_data[i+1:i+1+context_length] for i in batch_indices])
    inputs = inputs.to(device)
    targets = targets.to(device)
    return inputs, targets

logits, loss = m(*get_batch())

In [182]:
logits.shape

torch.Size([4, 16, 81])

In [14]:
loss

tensor(5.0071, device='cuda:0', grad_fn=<NllLoss2DBackward0>)

In [15]:
start_idx = get_batch()[0]
print(start_idx)
m.generate(start_idx, 1)

tensor([[51, 28,  7, 51, 80, 13, 55, 28, 13, 60, 13, 57, 34, 62, 28, 24],
        [60, 57, 25, 28, 47, 78, 28, 11, 31, 55, 51, 28, 57, 51, 62, 60],
        [55, 54, 25, 28, 30, 28, 42, 57, 78,  7, 28, 61, 80, 51, 20, 28],
        [ 6, 76, 70, 78, 28, 24, 51, 51, 13, 28, 54, 11, 80, 62, 28,  7]],
       device='cuda:0')


tensor([[51, 28,  7, 51, 80, 13, 55, 28, 13, 60, 13, 57, 34, 62, 28, 24, 24],
        [60, 57, 25, 28, 47, 78, 28, 11, 31, 55, 51, 28, 57, 51, 62, 60, 51],
        [55, 54, 25, 28, 30, 28, 42, 57, 78,  7, 28, 61, 80, 51, 20, 28, 27],
        [ 6, 76, 70, 78, 28, 24, 51, 51, 13, 28, 54, 11, 80, 62, 28,  7, 76]],
       device='cuda:0')

In [16]:
decode(m.generate(start_idx, 100)[0].tolist())

'o words didn\'t ga9DngiAaz&:Dtt[&851y)95sk7?V5o+0o.If P+K??4/HXAyO?(9"-lD(IWW3—\'sPfa8f8\'3.zT"n?"Wu]\n+)&$"W981o/gGVM7c'

In [None]:
optim = torch.optim.AdamW(params=m.parameters(), lr=1e-3)

In [43]:
for _ in range(10000):
    _, loss = m(*get_batch())
    optim.zero_grad()
    loss.backward()
    optim.step()
    
print(loss)

tensor(2.3222, device='cuda:0', grad_fn=<NllLoss2DBackward0>)


In [44]:
decode(m.generate(start_idx, 100)[0].tolist())

"o words didn't gobyor line abud d noft d ee, ors, o.\nI op Dar tzeon othanid Th mes oulvin guthis), d is. des soubour"