In [23]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [24]:
data = open('input.txt').read()
alphabet = set([c for c in data])
stoi = {s:i for i,s in enumerate(alphabet)}
itos = {value: key for key, value in stoi.items()}
vocab_size = len(alphabet)
print(f"{vocab_size=}")
def encode(s):
    return [stoi[c] for c in s]

def decode(encoded):
    return ''.join([itos[i] for i in encoded])

encoded = torch.tensor(encode(data), dtype=torch.long)
encoded[:20]

vocab_size=65


tensor([42, 30,  2, 50, 53, 19, 51, 30, 53, 30, 28, 33, 18, 25, 62, 26, 33,  9,
        48,  2])

In [25]:
x = int(0.9 * len(data))
training_data = encoded[:x]
validation_data = encoded[x:]

In [26]:
torch.manual_seed(42)
batch_size = 32
block_size = 8
head_size = 16


def get_batch(split):
    data = training_data if split == 'train' else validation_data
    indices = torch.randint(0, data.shape[0] - 1 - block_size, (batch_size,))
    xs = torch.stack([data[i:i+block_size] for i in indices])
    ys = torch.stack([data[i+1:i+block_size+1] for i in indices])
    return xs, ys

xbatch, ybatch = get_batch('train')

query, key, value


In [27]:
# single self attention head
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.query = nn.Linear(vocab_size, head_size)
        self.key = nn.Linear(vocab_size, head_size)
        self.value = nn.Linear(vocab_size, head_size)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x):
        print(f'{x.shape=}, {vocab_size=}, {head_size=}')
        q = self.query(x) # B, T, H
        k = self.key(x) # B, T, H

        print(f'{q.shape=}, {k.shape=}')


        wei = q @ k.transpose(-1, -2) * (head_size ** -0.5) # B, T, T
        print(f'{wei.shape=}, {self.mask.shape=}')
        wei = wei.masked_fill(self.mask == 0, -float('inf')) # B, T, T
        wei = F.softmax(wei, dim=-1) # B, T, T
        v = self.value(x) # B, T, H
        return wei @ v # B T T @ B T H -> B T H

        

In [30]:
class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, vocab_size)
        self.position_embeddings = nn.Embedding(block_size, vocab_size)
        self.head = Head()
        self.final = nn.Linear(head_size, vocab_size)
    

    def forward(self, x, targets=None):
        tokens_embeddings = self.token_embeddings(x) 
        position_embeddings = self.position_embeddings(torch.arange(x.shape[1]))
        # print(f'{tokens_embeddings.shape=}, {position_embeddings.shape=}')
        x = tokens_embeddings + position_embeddings

        x = self.head(x) # (batch_size, block_size, head_size)
        logits = self.final(x) # (batch_size, block_size, vocab_size)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            loss = F.cross_entropy(
                logits.view(B * T, C),
                targets.view(B * T)
            )
        # print(f'{x.shape=}')
        return loss, logits
    
    def generate(self, length):
        result = ''
        context = torch.randint(0, vocab_size, (1,1))
        for _ in range(length):
            print(f'{context.shape=}')
            _, logits = self(context)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            sample = torch.multinomial(probs, 1, replacement=True)
            context = torch.cat([context, sample], dim=1)
            if context.shape[0] >= block_size:
                context = context[1:]
            result += itos[sample.item()]
        return result

In [31]:
m = LanguageModel()
m.generate(10)

context.shape=torch.Size([1, 1])
x.shape=torch.Size([1, 1, 65]), vocab_size=65, head_size=16
q.shape=torch.Size([1, 1, 16]), k.shape=torch.Size([1, 1, 16])
wei.shape=torch.Size([1, 1, 1]), self.mask.shape=torch.Size([8, 8])


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [1, 8] but got: [1, 1].