In [87]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import random
import math
import tiktoken

In [88]:
vocab_size = 50257
batch_size = 4
context_len = 1024
emb_neur = 768
epochs = 50
num_blocks = 12
num_heads = 12
dropout_neur = 0.2
lr = 3e-4

enc = tiktoken.get_encoding("gpt2")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
torch.manual_seed(1337)

cuda


<torch._C.Generator at 0x1c23ae87530>

In [89]:
class DataLoader():
    def __init__(self, B, T):
        self.B = B
        self.T = T
        
        with open('input.txt', 'r') as f:
            text = f.read()
        text = enc.encode(text)
        self.tokens = torch.tensor(text)
        
        self.current_step = 1

        print(f"loaded {len(text)} tokens")

    def next_batch(self):
        B, T = self.B, self.T
        
        self.current_step += 1
        tokens = self.tokens[(self.current_step-1)*B*T:self.current_step*B*T+1]
        x = (tokens[:-1]).view(B, T)
        y = (tokens[1:]).view(B, T)
        if (self.current_step+1)*B*T+1 > len(self.tokens):
            self.current_step = 1
        return x, y

In [92]:
class SelfAttention(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.qkv = nn.Linear(emb_neur, 3 * emb_neur)
        self.proj = nn.Linear(emb_neur, emb_neur)
        self.proj.COMES_TO_RESIDUAL = 1
        # self.dropout = nn.Dropout(dropout_neur)

    def forward(self, idx):
        assert emb_neur % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        B, T, C = idx.shape
        qkv = self.qkv(idx)
        q, k, v = qkv.split(emb_neur, dim=2)
        q = q.view(B, T, num_heads, C//num_heads).transpose(1, 2) # B, nh, T, hs
        k = k.view(B, T, num_heads, C//num_heads).transpose(1, 2) # B, nh, T, hs
        v = v.view(B, T, num_heads, C//num_heads).transpose(1, 2) # B, nh, T, hs

        attention = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.shape[-1]))
        attention = torch.tril(attention[:, :, :T, :T])
        # mask = torch.tril(torch.ones((T, T), device=idx.device)).view(1, 1, T, T)  # (1, 1, T, T)
        # attention = attention.masked_fill(mask == 0, float("-inf"))  # (B, nh, T, T)
        attention = attention.masked_fill(attention == 0, float("-inf"))
        attention = F.softmax(attention, dim=-1)

        out = attention @ v # B, nh, T, hs 
        out = out.transpose(2, 1).contiguous().view(B, T, C)
        out = self.proj(out)
        # out = self.dropout(out)

        return out
        


class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        # self.net = nn.Sequential(
        #     nn.Linear(emb_neur, 4 * emb_neur),
        #     nn.GELU(),
        #     nn.Linear(4 * emb_neur, emb_neur),
        #     nn.Dropout(dropout_neur),
        # )
        self.upl = nn.Linear(emb_neur, 4 * emb_neur)
        self.gelu = nn.GELU()
        self.dwnl = nn.Linear(4 * emb_neur, emb_neur)
        self.dwnl.COMES_TO_RESIDUAL = 1

    def forward(self, idx):
        idx = self.upl(idx)
        idx = self.gelu(idx)
        idx = self.dwnl(idx)
        return idx
        # return self.net(idx)


class Block(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.attentions = SelfAttention(num_heads)
        self.ffn = FeedForward()
        self.ln1 = nn.LayerNorm(emb_neur)
        self.ln2 = nn.LayerNorm(emb_neur)

    def forward(self, idx):
        idx = idx + self.attentions(self.ln1(idx))
        idx = idx + self.ffn(self.ln2(idx))
        return idx

        
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokens_embedding = nn.Embedding(vocab_size, emb_neur)
        self.position_embedding = nn.Embedding(context_len, emb_neur)
        self.blocks = nn.Sequential( *[Block(num_heads) for _ in range(num_blocks)])
        self.ln = nn.LayerNorm(emb_neur)
        self.ll_head = nn.Linear(emb_neur, vocab_size)

        self.tokens_embedding.weight = self.ll_head.weight

        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        std = (1.0 / math.sqrt(emb_neur))
        if isinstance(module, nn.Linear):
            if hasattr(module, "COMES_TO_RESIDUAL"):
                std *= (1.0)/(math.sqrt(2*num_blocks))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=(1.0 / math.sqrt(emb_neur)))

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        embedded_tokens = self.tokens_embedding(idx) # B, T, emb_neur
        embedded_position = self.position_embedding(torch.arange(T, device=device)) # T, emb_neur
        
        idx = embedded_tokens + embedded_position # B, T, emb_neur
        idx = self.blocks(idx)
        idx = self.ln(idx)
        logits = self.ll_head(idx)
        
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_tokens):
        for _ in range(max_tokens):
            logits, _ = self.forward(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx

In [93]:
m = GPT()
m = m.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

data_loader = DataLoader(4, 32)
optmizer = torch.optim.Adam(m.parameters(), lr=lr)

124.490065 M parameters
loaded 338025 tokens


In [96]:
for epoch in range(50):
    x, y = data_loader.next_batch()
    x, y = x.to(device), y.to(device)
    
    logits, loss = m(x, y)
    
    optmizer.zero_grad()
    loss.backward()
    optmizer.step()
    

In [78]:
enc.decode(m.generate(torch.tensor(enc.encode("Hello")).to(device).view(1, -1), 50)[0].tolist())

'Hello good. words\n.: unlimited sinister, forth heavily for howicted flares;arrell but e him but evenull absorption: Stop:,IA,coins,aid SomFromFirst mecomponentOur not, nickname assistantsBR:blown your,\n\n'

In [97]:
loss

tensor(6.3758, device='cuda:0', grad_fn=<NllLossBackward0>)