In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
# hyperparameters
block_size = 8
batch_size = 32
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
eval_iters = 200
device = "cuda" if torch.cuda.is_available() else "cpu"
embed_size = 32
num_heads = 4
# ------------------
data_file = "./input.txt"

In [5]:
with open(data_file, 'r', encoding="utf-8") as file:
    content = file.read()

print(f"Number of characters : {len(content)}")

characters = sorted(list(set(content)))
vocab_size = len(characters)
print(f"Vocabulary size : {len(characters)}")

stoi = { c:i for i, c in enumerate(characters)}
itos = { i:c for i, c in enumerate(characters)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda x: "".join([itos[i] for i in x])

Number of characters : 1115394
Vocabulary size : 65


In [19]:
# Train-test split

data = torch.tensor(encode(content), dtype=torch.long)
n_train = int(0.9*len(data))
train_data = data[:n_train]
valid_data = data[n_train:]

def get_batch(split):
    data = train_data if split == "train" else valid_data
    idxs = torch.randint(len(data)-block_size, size=(batch_size,))
    x = torch.stack([data[idx:idx+block_size] for idx in idxs])
    y = torch.stack([data[idx+1:idx+block_size+1] for idx in idxs])
    x, y = x.to(device), y.to(device)
    return x, y

In [7]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [14]:
class Head(nn.Module):
    
    def __init__(self, head_size):
        super(Head, self).__init__()
        self.query = nn.Linear(embed_size, head_size, bias=False)
        self.key = nn.Linear(embed_size, head_size, bias=False)
        self.value = nn.Linear(embed_size, head_size, bias=False)
        
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)) )
        self.dropout = nn.Dropout()
        
    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x)
        k = self.key(x)
        
        weights = q @ k.transpose(-1, -2)*k.shape[-1]**-0.5
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        
        v = self.value(x)
        out = weights@v
        return out

In [26]:
model = Head(5)

embeddings = nn.Embedding(30, embed_size)
idx = torch.randint(30, size=(4, 8))
x = embeddings(idx)

out = model(x)

In [27]:
out.shape

torch.Size([4, 8, 5])

In [46]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, num_heads, head_size):
        super(MultiHeadAttention, self).__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads*head_size, embed_size)
        self.dropout = nn.Dropout()
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
    

In [24]:
model = MultiHeadAttention(4, 5)

embeddings = nn.Embedding(30, embed_size)
idx = torch.randint(30, size=(4, 8))
x = embeddings(idx)

out = model(x)

In [25]:
out.shape

torch.Size([4, 8, 32])

In [32]:
class Feedforward(nn.Module):
    
    def __init__(self):
        super(Feedforward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_size, 4*embed_size),
            nn.ReLU(),
            nn.Linear(4*embed_size, embed_size),
            nn.Dropout()
        )
        
    def forward(self, x):
        out = self.net(x)
        return out
    

In [33]:
model = Feedforward()

embeddings = nn.Embedding(30, embed_size)
idx = torch.randint(30, size=(4, 8))
x = embeddings(idx)

out = model(x)

In [34]:
out.shape

torch.Size([4, 8, 32])

In [42]:
class Block(nn.Module):
    
    def __init__(self, num_heads):
        super(Block, self).__init__()
        head_size = embed_size//num_heads
        self.communication = MultiHeadAttention(num_heads, head_size)
        self.computation = Feedforward()
        self.ln1 = nn.LayerNorm(embed_size)
        self.ln2 = nn.LayerNorm(embed_size)
        
    def forward(self, x):
        x = x + self.communication(self.ln1(x))
        x = x + self.computation(self.ln2(x))
        return x
    

In [43]:
model = Block(num_heads=4)

embeddings = nn.Embedding(30, embed_size)
idx = torch.randint(30, size=(4, 8))
x = embeddings(idx)

out = model(x)

In [44]:
out.shape

torch.Size([4, 8, 32])

In [84]:
class GPTLanguageModel(nn.Module):
    
    def __init__(self, vocab_size, num_blocks, num_heads):
        super(GPTLanguageModel, self).__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        self.position_embedding_table = nn.Embedding(block_size, embed_size)
        
        self.blocks = nn.Sequential(*[Block(num_heads) for _ in range(num_blocks)])
        self.ln_f = nn.LayerNorm(embed_size)
        self.proj = nn.Linear(embed_size, vocab_size)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_embed = self.token_embedding_table(idx) # B,T,C
        pos_embed = self.position_embedding_table(torch.arange(T)) # T,C
        embed = tok_embed + pos_embed
        
        x = self.blocks(embed)
        x = self.ln_f(x)
        logits = self.proj(x)
        
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(B*T, -1), targets.contiguous().view(-1))
        return logits, loss
    
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for i in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] # B,C
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=-1)
        return idx
            

In [85]:
model = GPTLanguageModel(vocab_size, 4, 4)
model = model.to(device)

print("Before training : ")
context = torch.zeros((1, 1), dtype=torch.long, device=device)
out = model.generate(context, max_new_tokens=1000)
print(decode(out.squeeze().tolist()))

Before training : 

f&ooDvvGJLMrWSKAt-tlW!iMOOSyoLYSK
tMDX:n-rnuucH wRL
rSMJTdpK;;G?3-adnSUtdAl3:qqwrVdOMgCndjYGgAdmk
eSPxZcg;zqq!JA:tftQtL QftG
fTkCcNi.jDWj!pLR'jMOMY!i 'lWhJnjd!eJ?r?MmWa'MrtNO:VzA-Bq-hvrPN lgqSiK,
; XJADC
-KY$A;,L3BlORtxxEMeqwQhlZBqJXdfLnNArJKkSjJJOvC?;C
h$goebqlq;B!BStixAW-CukStNtL QeAxHJfYSz'XfzpOmnsOWxTPRBNd?bgFWLH VWGGYURP.g?YOR3:SK?LeGU&fh,oMlcy kK
J!JozBI.YnGYR.PGo-jTzxBatN?,t;X vOK?-iTqddSL3'Os3&iizWNm ?.O3KjoLQUrzj
s3EhPdgTc;Rh-uC fJAJY?aYvhL:hndMxXa'S! pYCq-,,wAr
tPi.lD;oIaDg3vNQELmSVC:xFCgP!t$YhvjSX Lyh-slT$L!.-YTLZuKOcfj;YFOROzkSSr-nVN WeBoW
$Ewcsog?nNl
S
acuslCMwvOn?Lzj!udsRXKMRVYcLaq!WjDrxlY,U
eJq:mMyeTgT,lSUzx-dY xtqahkubL- ays-CwR:jzxsYKQ$o3fSJ$cHBRYUUFDMF.upzpSVoacZ3wzgCEMJX:
S?laQwq!S
jKKtvB?Z!tpAbG:m&qmpSJ-wFAQ?,dSTmBlM,NS?ONckuEB!JKfiYgJnSvZWyutitDI$AWtlIXOC,:AuwE'kN&rdxLrrW
p!fhrdh!if?.T!Q&-3qlgGNQsLCf-L-dIy L;A
fJdB3!tcsHVgIFVkT,do$SA,GT
q-tNWMsAzqMQsn,hC.PPH!!Mkd-y :lY?wAObX-Lwz$aSAarncKCSSjqqlCSOwo!YXRXcwYji
kMC?!MAMTdraS'yqYdc-
JFDShgnlB?llb$!

In [86]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [87]:
for i in range(max_iters):
    if i % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.3505, val loss 4.3439
step 300: train loss 2.5748, val loss 2.5755
step 600: train loss 2.4860, val loss 2.4876
step 900: train loss 2.4262, val loss 2.4336
step 1200: train loss 2.4183, val loss 2.4220
step 1500: train loss 2.4016, val loss 2.4025
step 1800: train loss 2.3965, val loss 2.4074
step 2100: train loss 2.3793, val loss 2.3988
step 2400: train loss 2.3651, val loss 2.3931
step 2700: train loss 2.3607, val loss 2.3852


In [88]:
print("After training : ")
context = torch.zeros((1, 1), dtype=torch.long, device=device)
out = model.generate(context, max_new_tokens=500)
print(decode(out.squeeze().tolist()))

After training : 


Loutil, his hof ncors where ut; sead tot to feche sored.



thEcAn, thy reat m.
Thite Th our od'g.
D Lety nom e cond losn oft no so hend tisty! oy,
I thy meoreeo'ds hes yous hegs shene wor go hy themalledteangoun wimy clecalle.


BUSTOFCHow boreeknt, theand cime slam lers her lat. othryert,
My Plse.
Gffong sucis toob youidy. w, Fof fein CHin uy An: tnou.
Ye.
XECIUSI nw fome by fo sou, wivem ontri, pak. harm RI aterravath. bun to abk.

ARMy tart an yombrelt les, stik fuy onar s san tirvis hipano
