In [1]:
import torch
import torch.nn as nn 
from torch.nn import functional as F

In [2]:
#Load file: read the raw corpus and derive the character vocabulary
with open('input.txt', mode='r', encoding='utf-8') as fh:
    text = fh.read()
chars = sorted(set(text))  # every distinct character, in sorted order
vocab_size = len(chars)

In [3]:
#Tokenizer -> Mapping tokens to integers (Token = 1 character in our vocab)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    """Map a string to a list of token ids, one id per character (str -> [int]).

    Raises KeyError for any character that is not in ``chars``.
    """
    return [stoi[c] for c in s]

def decode(l):
    """Map a sequence of token ids back to a string ([int] -> str).

    Each id is passed through ``int()`` first. Tensors hash by identity, so
    ``itos[tensor(3)]`` would raise KeyError; converting lets callers pass a
    1-D LongTensor directly as well as plain Python ints.
    """
    return ''.join(itos[int(i)] for i in l)

In [119]:
#Global Vars (hyperparameters)
batch_size = 4   # sequences per batch drawn by get_batch
block_size = 8   # context length: window size in get_batch and size of the position table
emb_dim = 32     # embedding / residual-stream width
n_head = 16      # attention heads per DecoderBlock
assert emb_dim % n_head == 0, "emb_dim must be divisible by n_head"
head_dim = emb_dim//n_head  # per-head width; 32 // 16 = 2 with these settings
# NOTE(review): the shape comments in the Head cell assumed head_dim == 16
# (i.e. n_head == 2) -- confirm which value of n_head is intended.
vocab_size = len(chars)
n_layers = 6     # number of stacked DecoderBlocks
torch.manual_seed(1337)  # make batch sampling and weight init reproducible on Restart & Run All

In [6]:
#Tokenize Data: encode the entire corpus into a single 1-D tensor of int64 token ids
data = torch.tensor(encode(text), dtype=torch.int64)

In [7]:
#Split data -> train/validation: first 90% of tokens for training, remaining 10% for validation
n = int(len(data)*0.9)
train_data, val_data = data[:n], data[n:]

In [8]:
#Get Batch
def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
        (x, y), each stacked to shape (batch_size, block_size). ``y`` is the
        same window as ``x`` shifted one token ahead, i.e. next-token targets.
    """
    source = val_data if split != 'train' else train_data
    # Upper bound keeps start + block_size + 1 <= len(source), so every target slice is full length.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = [source[s:s + block_size] for s in starts]
    targets = [source[s + 1:s + 1 + block_size] for s in starts]
    return torch.stack(inputs), torch.stack(targets)

In [36]:
#Define Matrices: standalone modules for stepping through the pieces by hand (WE is used in the next cell)
WE = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim)  # token embedding matrix
# TODO: WP position embedding matrix
WU = nn.Linear(in_features=emb_dim, out_features=vocab_size)          # unembedding matrix (emb_dim -> vocab logits)
WQ = nn.Linear(in_features=emb_dim, out_features=head_dim, bias=False)  # query projection
WK = nn.Linear(in_features=emb_dim, out_features=head_dim, bias=False)  # key projection
WV = nn.Linear(in_features=emb_dim, out_features=head_dim, bias=False)  # value projection

In [41]:
#Token Embeddings: embed one sampled training batch with the standalone WE matrix
x, y = get_batch(split='train')
emb_tokens = WE(x)  # (batch_size, block_size, emb_dim)
emb_tokens.shape

torch.Size([4, 8, 32])

In [100]:
class Head(nn.Module): #Single Head of Masked Self-Attention
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        head_dim: output width of the query/key/value projections.

    Reads the notebook globals ``emb_dim`` (input feature width) and
    ``block_size`` (largest sequence length the causal mask supports).
    """

    def __init__(self, head_dim):
        super().__init__()
        self.WQ = nn.Linear(emb_dim, head_dim, bias=False) #query matrix
        self.WK = nn.Linear(emb_dim, head_dim, bias=False) #key matrix
        self.WV = nn.Linear(emb_dim, head_dim, bias=False) #value matrix
        # Lower-triangular ones: query position t may attend only to key positions <= t.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, inputs):
        """Apply causal self-attention.

        Args:
            inputs: tensor of shape (batch, T, emb_dim) with T <= block_size.

        Returns:
            Tensor of shape (batch, T, head_dim).
        """
        T = inputs.shape[-2]
        query = self.WQ(inputs) # (batch, T, head_dim)
        key = self.WK(inputs)   # (batch, T, head_dim)
        value = self.WV(inputs) # (batch, T, head_dim)
        # Scale by 1/sqrt(head_dim), i.e. MULTIPLY by head_dim**-0.5, so softmax does not saturate.
        attention_matrix = query @ key.transpose(-2, -1) * key.shape[-1]**-0.5 # (batch, T, T)
        # Slice the mask to the actual length so sequences shorter than block_size also work.
        attention_matrix = attention_matrix.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        attention_weights = F.softmax(attention_matrix, dim=-1) # normalize over keys for each query row
        outputs = attention_weights @ value # (batch, T, head_dim)
        return outputs
    
class MultiHead(nn.Module):
    """Multi-head self-attention.

    Runs ``n_head`` independent ``Head`` modules on the same input,
    concatenates their outputs along the last axis, and linearly projects
    the ``n_head * head_dim`` features back to ``emb_dim``.
    """

    def __init__(self, n_head, head_dim, emb_dim):
        super().__init__()
        self.heads = nn.ModuleList(Head(head_dim) for _ in range(n_head))
        concat_dim = n_head * head_dim
        self.proj = nn.Linear(concat_dim, emb_dim)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]  # n_head tensors, each (..., head_dim)
        concatenated = torch.cat(per_head, dim=-1)    # (..., n_head * head_dim)
        return self.proj(concatenated)                # (..., emb_dim)

In [101]:
multi_head = MultiHead(n_head=n_head, head_dim=head_dim, emb_dim=emb_dim)
out = multi_head(emb_tokens)  # (batch_size, block_size, emb_dim)

In [107]:
#Add & Norm (post-norm): residual sum of the attention input and output, then LayerNorm
residual = emb_tokens + out #(batch_size, block_size, emb_dim)
idx = nn.LayerNorm(emb_dim)(residual) # build the LayerNorm AND apply it over the last (emb_dim) axis

In [110]:
class FeedForwardNetwork(nn.Module):
    """Position-wise MLP: Linear(emb_dim -> 4*emb_dim) -> ReLU -> Linear(4*emb_dim -> emb_dim).

    Acts on the last dimension only, so any leading (batch, time) shape passes through.
    """

    def __init__(self, emb_dim):
        super().__init__()
        hidden_dim = 4 * emb_dim  # 4x expansion of the hidden layer
        layers = [
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

In [120]:
class DecoderBlock(nn.Module):
    """Pre-LayerNorm Transformer decoder block.

    Computes ``x + attn(LN1(x))`` and then ``x + ffn(LN2(x))``. The output has
    the same shape as the input.
    """

    def __init__(self, emb_dim, n_head):
        super().__init__()
        per_head_dim = emb_dim // n_head  # each head gets an equal share of emb_dim
        self.self_attention = MultiHead(n_head, per_head_dim, emb_dim)
        self.feed_forward = FeedForwardNetwork(emb_dim)
        self.layer1_norm = nn.LayerNorm(emb_dim)
        self.layer2_norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        attn_out = self.self_attention(self.layer1_norm(x))
        x = x + attn_out  # residual around attention
        mlp_out = self.feed_forward(self.layer2_norm(x))
        return x + mlp_out  # residual around the MLP



In [118]:
block = DecoderBlock(emb_dim, n_head)  # DecoderBlock takes exactly (emb_dim, n_head)
out = block(emb_tokens)  # shape check: one block should preserve (batch_size, block_size, emb_dim)
out.shape

torch.Size([4, 8, 32])

In [124]:
class GPTLanguageModel(nn.Module):
    """Decoder-only character-level language model.

    Pipeline: token embedding + position embedding -> ``n_layers`` DecoderBlocks
    -> final LayerNorm -> linear head over the vocabulary. Reads the notebook
    globals ``vocab_size``, ``emb_dim``, ``block_size``, ``n_head``, ``n_layers``.
    """
    def __init__(self):
        super().__init__()
        self.token_emb_table = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb_table = nn.Embedding(block_size, emb_dim)
        self.blocks = nn.Sequential(*[DecoderBlock(emb_dim, n_head) for _ in range(n_layers)])
        self.layerFinal_norm = nn.LayerNorm(emb_dim) # Following OpenAIs GPT implementation 
        self.final_linear = nn.Linear(emb_dim, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Init Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, inputs, targets=None):
        """Run the model on a batch of token ids.

        Args:
            inputs: LongTensor of token ids, shape (B, T) with T <= block_size.
            targets: optional LongTensor of next-token ids, shape (B, T).

        Returns:
            (logits, loss): logits has shape (B, T, vocab_size); loss is the mean
            cross-entropy against ``targets``, or None when targets is None.

        Raises:
            ValueError: if T exceeds block_size (no position embedding exists for it).
        """
        B, T = inputs.shape
        if T > block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {block_size}")
        tok_emb = self.token_emb_table(inputs)  # (B, T, emb_dim)
        # Position embeddings are indexed by position 0..T-1, not by token id.
        positions = torch.arange(T, device=inputs.device)
        pos_emb = self.pos_emb_table(positions)  # (T, emb_dim), broadcasts over the batch
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.layerFinal_norm(x)
        logits = self.final_linear(x)  # (B, T, vocab_size)
        if targets is None:
            return logits, None
        # cross_entropy expects (N, C) logits and (N,) class targets.
        loss = F.cross_entropy(logits.reshape(B * T, -1), targets.reshape(B * T))
        return logits, loss
        


