In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import einops
from torch.utils.data import Dataset, DataLoader

In [2]:
with open(r"\Users\micah\Desktop\tiny-shakespeare.txt", 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
vocab = sorted(list(set(text)))
vocab_size = len(vocab)

batch_size = 64
seq_len = 256
n_embd = 384
head_size = 64
n_heads = 8
depth = 6

lr = 1e-4
wd = 1e-2
betas = (0.9, 0.99)
eps = 1e-8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [35]:
stoi = {s:i for i, s in enumerate(vocab)}
itos = {i:s for i, s in enumerate(vocab)}
tokenize = lambda s: torch.tensor([stoi[c] for c in s])
detokenize = lambda c: ''.join([itos[x] for x in c])

data = tokenize(text)
n = int(len(data)*.9)
train_data = data[:n]
val_data = data[n:]

In [53]:
class data(Dataset):
    def __init__(self, text, seq_len=seq_len) -> None:
        super().__init__()
        self.data = text
        self.seq_len = seq_len


    def __len__(self):
        return len(self.data) - self.seq_len - 1

    def __getitem__(self, idx):
        return self.data[idx: idx + seq_len], self.data[idx+1: idx + seq_len + 1]

In [8]:
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

In [9]:
class SwiGLU(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.beta = nn.parameter(torch.ones())
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate)*x

In [10]:
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return torch.nn.functional.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

In [11]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn 
        self.drop = nn.Dropout(p=0.1)

    def forward(self, x, **kwargs):
        y = self.fn(x, **kwargs)
        x = self.drop(x)
        return y + x

In [12]:
class Token_Embedding(nn.Module):
    def __init__(self, vocab_size=vocab_size, n_embd=n_embd):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embd)

    def forward(self, x):
        return self.embedding(x)

In [13]:
class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, base=10000):
        super().__init__()
        inv_freq = float(head_dim)/(base ** torch.arange(0, head_dim, 2).float())
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self.head_dim = head_dim
        self.seq_len_cached = None
        self.batch_size_cached = None
        self.cos_cached: torch.tensor | None = None
        self.sin_cached: torch.tensor | None = None

    def trig(self, seq_len: int, device=device, dtype=torch.bfloat16) -> torch.Tensor:
        if seq_len != self.seq_len_cached: 
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum('i,j -> ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).float().to(device)

            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()

        return self.cos_cached, self.sin_cached
    
    def forward(self, q, k):
        _, _, seq_len, _ = q.shape
        cos, sin = self.trig(seq_len)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

            

In [14]:
class MultiQueryAttentionHead(nn.Module):
    def __init__(self, n_heads, head_size, n_embd, attention_drop = 0.1, ff_drop = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.attention_drop = nn.Dropout(attention_drop)
        self.qkv = nn.Linear(n_embd, n_heads*head_size + 2*head_size)
        self.rotary = RotaryEmbedding(head_size)
        self.LNorm = LayerNorm(n_embd)
        self.ff_out = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(ff_drop),
            nn.Linear(n_heads * head_size, n_embd, bias=False)
        )
    
    def forward(self, x):
        # x = (B, T, E) ---> (B, num_heads, T, h_size)
        B, T, E = x.shape
        x = self.LNorm(x)

        qkv = self.qkv(x) #(B, T, n_heads*head_size + 2*head_size)

        # q has shape (B, n_heads, T, head_size)
        q = qkv[:, :, : n_heads*head_size].view((B, T, n_heads, head_size)).transpose(-2, -3)
        # k has shape (B, T, head_size)
        k = qkv[:, :, n_heads*head_size:n_heads*head_size+head_size].view((B, 1, T, head_size))
        # v has shape (B, T, head_size)
        v = qkv[:, :, -head_size:].view((B, 1, T, head_size))
        
        q, k = self.rotary(q, k)

        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
        y = einops.rearrange(y, 'b h t d -> b t (h d)')
        return self.ff_out(y)

In [15]:
class LaTeXModel(nn.Module):
    def __init__(self, depth, n_heads, head_size, n_embd=n_embd, vocab_size=vocab_size) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.head_size = head_size
        self.n_embd = n_embd

        self.token_embedding = Token_Embedding(n_embd = n_embd)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            block = Residual(MultiQueryAttentionHead(n_heads, head_size, n_embd))
            self.layers.append(block)
        
        self.LNorm = LayerNorm(n_embd)
        self.to_logits = nn.Linear(n_embd, vocab_size, bias=False)

        #self.to_logits.weight = self.token_embedding.weight
        #nn.init.normal_(self.token_embedding.weight, std=0.02)

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -seq_len:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :] # becomes (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

    def forward(self, x, targets=None):
        x = self.token_embedding(x)

        for layer in self.layers:
            x = layer(x)
        embeds = self.LNorm(x)

        logits = self.to_logits(embeds)
        
        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss

In [None]:
class Trainer:
    def __init__(self, seq_len = seq_len):
        if torch.cuda.is_available():
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.epochs = 5
        self.dataset = data(train_data, seq_len)
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        self.model = LaTeXModel(depth, n_heads, head_size)
        self.model.to(self.device)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr = lr, weight_decay = wd, betas = betas, eps = eps)
        
    def train(self):
        for iter in range(self.epochs):
            for i, batch in enumerate(self.dataloader):
                data, target = batch[0].to(self.device), batch[1].to(self.device)
                self.model.train()
                _, loss = self.model(data, target)
                if i % 100 == 0:
                    print('{i/}'loss)
                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                self.optimizer.step()                

In [23]:
trainer = Trainer()
trainer.train()

tensor(4.3473, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.6651, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.5323, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.5063, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.4709, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.4233, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.4011, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.3808, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.3423, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.3233, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.2758, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.2784, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.2251, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.2194, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.1689, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.1388, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.1391, device='cuda:0', grad_fn=

KeyboardInterrupt: 

In [24]:
context = torch.zeros((1, 128), dtype=torch.long).to(device)
print(detokenize(trainer.model.generate(context, max_new_tokens=500)[0].tolist()))

































































































































AGCIOvStoMC
WncaOCLOPOROtMWerNGoCulaftswaWioOfrosowhLonUKIOHiuavo,OviGLABUKaromilprdo'Gres hed HERARNolonovascarond:
Folounincond?
Tond SCUCetel myopto'st uourersh!
DUSTUTIO:
Ser,
A LoGod, thary 'Wham in tlaut onot anee eartle therery what buth ista's ur heart,
The father's to you lovercy? dive disposs pon
Look ison. O's hid in thou chatly canised
A not, in so for gron; he nothan come
Your ther side mame minegualdner hou woul at
There blown by thee falk spoke of
Affect the joy ly counter trive o


In [28]:
c = data()

In [39]:
x = DataLoader(c, batch_size=1, shuffle=True)

In [52]:
for i, batch in enumerate(x):
    print(batch[0])
    print('\n')
    print(batch[1])
    break

tensor([[47, 58, 46,  1, 58, 46, 47, 52, 43, 11,  0, 13, 52, 42,  1, 47, 52,  1,
         58, 46, 47, 57,  1, 60, 53, 61,  1, 42, 53,  1, 41, 46, 39, 47, 52,  1,
         51, 63,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 47, 52, 43,  2,  0,
         13, 52, 42,  6,  1, 43, 56, 43,  1, 51, 63,  1, 49, 52, 43, 43,  1, 56,
         47, 57, 43,  1, 44, 56, 53, 51,  1, 58, 46, 43,  1, 43, 39, 56, 58, 46,
          5, 57,  1, 41, 53, 50, 42,  1, 44, 39, 41, 43,  6,  0, 21,  1, 58, 46,
         56, 53, 61,  1, 51, 63,  1, 46, 39, 52, 42, 57,  6,  1, 51, 47, 52, 43,
          1, 43, 63, 43, 57,  6,  1, 51, 63,  1, 46, 43, 39, 56, 58,  1, 58, 53,
          1, 58, 46, 43, 43,  6,  0, 32, 46, 53, 59,  1, 57, 43, 58, 58, 43, 56,
          1, 59, 54,  1, 39, 52, 42,  1, 54, 50, 59, 41, 49, 43, 56,  1, 42, 53,
         61, 52,  1, 53, 44,  1, 49, 47, 52, 45, 57,  6,  0, 14, 43, 57, 43, 43,
         41, 46, 47, 52, 45,  1, 58, 46, 43, 43,  6,  1, 47, 44,  1, 61, 47, 58,
         46,  1, 58, 46, 43,

In [43]:
detokenize([27, 36, 18, 27, 30, 16, 10,  0, 13, 61, 39, 63,  6,  1, 39, 61, 39, 63,
         6,  1, 58, 53,  1, 51, 43, 43, 58,  1, 58, 46, 43,  1, 55, 59, 43, 43,
        52,  5, 57,  1, 45, 56, 43, 39, 58,  1, 54, 53, 61, 43, 56,  2,  0,  9,
         1, 23, 21, 26, 19,  1, 20, 17, 26, 30, 37,  1, 34, 21,  0,  0, 23, 21,
        26, 19,  1, 17, 16, 35, 13, 30, 16,  1, 21, 34, 10,  0, 32, 46, 59, 57,
         1, 44, 39, 56,  1, 53, 59, 56,  1, 44, 53, 56, 58, 59, 52, 43,  1, 49,
        43, 43, 54, 57,  1, 39, 52,  1, 59, 54, 61, 39, 56, 42,  1, 41, 53, 59,
        56, 57, 43,  6,  0, 13, 52, 42,  1, 61, 43,  1, 39, 56, 43,  1, 45, 56,
        39, 41, 43, 42,  1, 61, 47, 58, 46,  1, 61, 56, 43, 39, 58, 46, 57,  1,
        53, 44,  1, 60, 47, 41, 58, 53, 56, 63,  8,  0, 14, 59, 58,  6,  1, 47,
        52,  1, 58, 46, 43,  1, 51, 47, 42, 57, 58,  1, 53, 44,  1, 58, 46, 47,
        57,  1, 40, 56, 47, 45, 46, 58,  7, 57, 46, 47, 52, 47, 52, 45,  1, 42,
        39, 63,  6,  0, 21,  1, 57, 54, 63,  1, 39,  1, 40, 50, 39, 41, 49,  6,
         1, 57, 59, 57, 54, 47, 41, 47, 53, 59, 57,  6,  1, 58, 46, 56, 43, 39,
        58, 43, 52, 47])

"OXFORD:\nAway, away, to meet the queen's great power!\n3 KING HENRY VI\n\nKING EDWARD IV:\nThus far our fortune keeps an upward course,\nAnd we are graced with wreaths of victory.\nBut, in the midst of this bright-shining day,\nI spy a black, suspicious, threateni"