In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [25]:
# Hyperparameters for the toy character-level attention model below.
batch_size = 4   # number of independent windows sampled per batch
block_size = 8   # context length: window size in get_batch and size of the causal mask
n_embed = 4      # model width: input feature size of every attention projection
dropout = 0.2    # dropout probability on attention weights and after the output projection
seed = 1337      # fixed RNG seed so batch sampling, weight init and dropout are reproducible

# Seed early so a Restart & Run All reproduces the same batches and weights.
torch.manual_seed(seed)


In [4]:
import urllib.request

import os

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
filename = "input.txt"

# Only hit the network when the corpus is not on disk yet. Restart & Run All
# is then fast, and it keeps working offline after the first successful run.
if os.path.isfile(filename):
    print(f"Using cached {filename}")
else:
    urllib.request.urlretrieve(url, filename)
    print(f"File downloaded as {filename}")


File downloaded as input.txt


In [5]:
# Load the full corpus into memory as a single UTF-8-decoded string.
with open(file='input.txt', mode='r', encoding='utf-8') as f:
    text = f.read()

In [6]:
# Corpus size in characters, shown as this cell's output.
len(text)

1115394

In [7]:
# Character-level vocabulary: every distinct character in the corpus, sorted so
# that the character -> id assignment is deterministic across runs.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> character


def encode(s):
    """Map a string to a list of integer token ids, one per character.

    Raises KeyError for any character that is not in the corpus vocabulary.
    """
    return [stoi[ch] for ch in s]


def decode(k):
    """Map an iterable of integer token ids back to the corresponding string."""
    return ''.join([itos[i] for i in k])


# Round-trip sanity check: should print the input unchanged.
print(decode(encode("hii there")))

hii there


In [8]:
# Encode the whole corpus into a 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data)

# The first 90% of the sequence is for training; the remaining 10% is held out for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

tensor([18, 47, 56,  ..., 45,  8,  0])


In [9]:
def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: "train" selects ``train_data``; any other value selects ``val_data``.

    Returns:
        (x, y): ``x`` has shape (batch_size, block_size) and holds windows that
        start at random offsets. ``y`` holds the same windows shifted one
        position to the right (the next-token targets). Both are moved to ``device``.

    Raises:
        ValueError: if the chosen split has no more than ``block_size`` tokens,
        so that no full (input, target) window fits in it.
    """
    source = train_data if split == "train" else val_data
    if len(source) <= block_size:
        raise ValueError(
            f"{split!r} split has {len(source)} tokens; "
            f"need more than block_size={block_size}"
        )
    # randint's upper bound is exclusive, so the target window's last
    # index (start + block_size) always stays inside the split.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[i:i + block_size] for i in starts])
    y = torch.stack([source[i + 1:i + block_size + 1] for i in starts])
    return x.to(device), y.to(device)


In [40]:
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Projects the input to keys, queries and values of width ``head_size``.
    Each query is scored against every key with scaled dot-product attention,
    positions to the future are masked out with a lower-triangular mask, and
    the result is the attention-weighted sum of the values.
    """

    def __init__(self, head_size):
        # nn.Module has to be initialised before submodules or buffers are assigned.
        super().__init__()
        # Bias-free projections from the model width (n_embed) to the head width.
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # Entry (i, j) is 1 when key position j <= query position i. It is a
        # buffer rather than a parameter, so it is never trained.
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over a 3-D input ``x`` (B, T, n_embed); returns (B, T, head_size).

        T must be <= block_size, because the precomputed mask is sliced to [:T, :T].
        """
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # Scaled dot-product scores: (B, T, hs) @ (B, hs, T) -> (B, T, T).
        scale = keys.shape[-1] ** -0.5
        scores = queries.matmul(keys.transpose(-2, -1)) * scale
        # Hide future positions so that each token only attends to itself and earlier tokens.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values


In [42]:
# Smoke test: a batch of 2 full-length sequences through a single head of width 6.
# The configured sizes are used instead of literals, so this cell tracks the config cell.
x = torch.ones(2, block_size, n_embed)
head = Head(6)
out = head(x)
assert out.shape == (2, block_size, 6)
out.shape

torch.Size([2, 8, 6])

In [45]:
class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side on the same input.

    Each head's output (width ``head_size``) is concatenated along the last
    dimension. The concatenation is projected back to ``n_embed`` features and
    passed through dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        concat_width = num_heads * head_size
        self.proj = nn.Linear(concat_width, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # (B, T, C) -> num_heads tensors of (B, T, head_size) -> (B, T, num_heads*head_size)
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))


In [46]:
# Smoke test: 4 heads of width 6 give 4 * 6 = 24 concatenated features, and
# `proj` maps them back to n_embed. Reuses `x` from the Head smoke-test cell.
multiheadattn = MultiHeadAttention(4, 6)
out = multiheadattn(x)
assert out.shape == (x.shape[0], x.shape[1], n_embed)
out.shape

torch.Size([2, 8, 24])
torch.Size([2, 8, 4])
