In [78]:
import torch
from dataclasses import dataclass
from torch import nn

@dataclass
class config:
    """Hyper-parameters shared by the attention, transformer-block and GPT1 modules.

    `vocab_size` is required; the other fields default to a 768-dim model with
    12 attention heads, 12 blocks and a 4*768 feed-forward hidden layer.

    Raises:
        ValueError: if `embedding_dim` is not divisible by `num_attention_heads`.
    """
    vocab_size: int
    embedding_dim: int = 768
    num_attention_heads: int = 12
    num_attention_blocks: int = 12
    ff_hidden_dim: int = 4 * 768
    bias: bool = True
    # Dropout probability (GPT1 reads `config.dropout`).
    dropout: float = 0.1
    # Context length: maximum number of tokens per sequence.
    block_size: int = 512

    def __post_init__(self):
        # Each head has width embedding_dim // num_attention_heads, and the head
        # outputs are concatenated back to embedding_dim, so the split must be exact.
        if self.num_attention_heads <= 0 or self.embedding_dim % self.num_attention_heads != 0:
            raise ValueError(
                f"embedding_dim ({self.embedding_dim}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
    

class causal_attention_head(nn.Module):
    """A single causal scaled dot-product self-attention head.

    Queries, keys and values have width
    ``head_size = embedding_dim // num_attention_heads``. Query position ``i``
    may attend only to key positions ``j <= i`` that are not padding.
    """

    # Quoted (forward-reference) annotation so this class can be defined without
    # the `config` dataclass being in scope at definition time.
    def __init__(self, config: "config"):
        super().__init__()

        self.embedding_dim = config.embedding_dim
        self.head_size = self.embedding_dim // config.num_attention_heads

        # Projections stored as (head_size, embedding_dim) and applied as X @ W.T.
        self.W_q = nn.Parameter(torch.empty(self.head_size, self.embedding_dim))
        self.W_k = nn.Parameter(torch.empty(self.head_size, self.embedding_dim))
        self.W_v = nn.Parameter(torch.empty(self.head_size, self.embedding_dim))
        for weight in (self.W_q, self.W_k, self.W_v):
            torch.nn.init.normal_(weight, mean=0.0, std=0.02)

    def forward(self, X, padding_mask=None):
        """Apply causal self-attention for this head.

        Args:
            X: (batch, seq, embedding_dim) input features.
            padding_mask: (batch, seq) with non-zero for real tokens and 0 for
                padding, or None when nothing is padded.

        Returns:
            (batch, seq, head_size) attention output.
        """
        seq_len = X.shape[1]

        if padding_mask is not None:
            padding_mask = padding_mask.to(X.device)
            # Zero the features of padding positions; (batch, seq, 1) broadcasts over features.
            X = X * padding_mask.unsqueeze(2)

        q = X @ self.W_q.T  # (batch, seq, head_size)
        k = X @ self.W_k.T
        v = X @ self.W_v.T

        # scores[b, i, j] = <q_i, k_j> / sqrt(head_size): how strongly query i matches key j.
        scores = torch.bmm(q, k.transpose(1, 2)) / (self.head_size ** 0.5)  # (batch, seq, seq)

        # allowed[..., i, j] is True where query i may attend to key j. The mask is
        # created on X's device so masked_fill also works for CUDA inputs.
        allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=X.device).tril().unsqueeze(0)
        if padding_mask is not None:
            # Padding keys must get -inf (not merely a 0 score), otherwise they still
            # receive softmax weight, e.g. real tokens after left padding.
            key_is_token = (padding_mask != 0).unsqueeze(1)  # (batch, 1, seq)
            # Keep the diagonal so a padding query whose visible keys are all padding
            # still has one finite score (an all -inf row would softmax to NaN).
            diagonal = torch.eye(seq_len, dtype=torch.bool, device=X.device).unsqueeze(0)
            allowed = (allowed & key_is_token) | diagonal

        weights = torch.softmax(scores.masked_fill(~allowed, float('-inf')), dim=-1)  # (batch, seq, seq)
        return torch.bmm(weights, v)  # (batch, seq, head_size)

class self_attention(nn.Module):
    """Multi-head causal self-attention.

    Runs `num_attention_heads` independent `causal_attention_head`s on the same
    input, concatenates their outputs on the feature axis and mixes them with
    the output projection `W_o`.
    """

    def __init__(self, config: config):
        super().__init__()

        self.embedding_dim = config.embedding_dim
        self.num_heads = config.num_attention_heads
        self.head_size = self.embedding_dim // config.num_attention_heads

        self.attention_heads = nn.ModuleList(
            [causal_attention_head(config) for _ in range(self.num_heads)]
        )

        # The concatenated heads have num_heads * head_size features (equal to
        # embedding_dim only when it divides evenly). The bias follows config.bias,
        # like the other Linear layers in this notebook.
        self.W_o = nn.Linear(self.num_heads * self.head_size, self.embedding_dim, bias=config.bias)

    def forward(self, X, padding_mask):
        """Apply all heads to X and project the concatenation back to embedding_dim.

        Args:
            X: (batch, seq, embedding_dim) input features.
            padding_mask: (batch, seq) mask passed unchanged to every head.

        Returns:
            (batch, seq, embedding_dim) tensor.
        """
        # Each head output: (batch, seq, head_size)
        head_outputs = [head(X, padding_mask) for head in self.attention_heads]
        concatenated = torch.cat(head_outputs, dim=-1)  # (batch, seq, num_heads * head_size)
        return self.W_o(concatenated)
    

class transformer_block(nn.Module):
    """Post-LayerNorm transformer block.

    Computes ``h = LN1(x + attention(x))`` and then ``out = LN2(h + ffn(h))``,
    where ffn is Linear -> GELU -> Linear.
    """

    def __init__(self, config: config):
        super().__init__()

        self.attention_block = self_attention(config)
        # One LayerNorm per sub-layer so each residual connection has its own
        # gain/bias parameters (a single shared module would tie them together).
        self.layerNorm = nn.LayerNorm(config.embedding_dim, bias=config.bias)
        self.layerNorm2 = nn.LayerNorm(config.embedding_dim, bias=config.bias)

        self.ff_hidden_dim = config.ff_hidden_dim
        # Position-wise feed-forward network. No activation after the output
        # projection: the branch output is added straight into the residual stream.
        self.linear = nn.Sequential(
            nn.Linear(config.embedding_dim, self.ff_hidden_dim, bias=config.bias),
            nn.GELU(),
            nn.Linear(self.ff_hidden_dim, config.embedding_dim, bias=config.bias),
        )

    def forward(self, X, padding_mask):
        """Run attention and feed-forward sub-layers.

        Args:
            X: (batch, seq, embedding_dim) input features.
            padding_mask: (batch, seq) mask forwarded to the attention sub-layer.

        Returns:
            (batch, seq, embedding_dim) tensor.
        """
        attended = self.layerNorm(X + self.attention_block(X, padding_mask))
        return self.layerNorm2(attended + self.linear(attended))



class GPT1(nn.Module):
    """Decoder-only GPT-1 style language model.

    token + learned position embeddings -> dropout -> stack of
    `transformer_block`s -> bias-free linear LM head producing vocabulary logits.
    """

    def __init__(self, config: config):
        super().__init__()

        self.token_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)

        # `block_size` and `dropout` are read with fallbacks because not every
        # config object defines them.
        self.block_size = getattr(config, "block_size", 512)
        # Learned absolute position embeddings; without them the blocks only see
        # token order through the causal mask.
        self.position_embedding = nn.Embedding(self.block_size, config.embedding_dim)

        self.drop = nn.Dropout(getattr(config, "dropout", 0.1))

        # Each block maps (batch, seq, embedding_dim) -> (batch, seq, embedding_dim).
        self.transformer = nn.ModuleList(
            [transformer_block(config) for _ in range(config.num_attention_blocks)]
        )

        # (batch, seq, embedding_dim) -> (batch, seq, vocab_size)
        self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)

        self.apply(self._init_weights)

    def forward(self, X, padding_mask):
        """Compute next-token logits.

        Args:
            X: (batch, seq) token ids with seq <= block_size.
            padding_mask: (batch, seq) mask forwarded to every transformer block.

        Returns:
            (batch, seq, vocab_size) logits.

        Raises:
            ValueError: if the sequence is longer than the position table.
        """
        seq_len = X.shape[1]
        if seq_len > self.block_size:
            raise ValueError(f"sequence length {seq_len} exceeds block_size {self.block_size}")

        # NOTE(review): positions are absolute 0..seq-1, so with left padding the
        # first real token does not get position 0.
        positions = torch.arange(seq_len, device=X.device)  # (seq,)
        hidden = self.token_embedding(X) + self.position_embedding(positions)  # broadcasts over batch
        hidden = self.drop(hidden)

        for block in self.transformer:
            hidden = block(hidden, padding_mask)

        return self.lm_head(hidden)

    def _init_weights(self, module):
        """Initialise one submodule (called on every module via `self.apply`).

        Linear/Embedding weights ~ N(0, 0.02) and Linear biases = 0. LayerNorm is
        reset to the identity affine map (weight 1, bias 0).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # A gain drawn from N(0, 0.02) would shrink every normalised activation
            # ~50x and flip signs at random; LayerNorm gain must start at 1.
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
                
        


In [106]:
import tiktoken

# tiktoken's GPT-2 byte-pair encoding.
enc = tiktoken.get_encoding("gpt2")

sentence = "Hello world, I'm testing GPT-2 BPE!"

# Text -> list of integer token ids.
token_ids = enc.encode(sentence)
print("Token IDs:", token_ids)

# Ids -> text; for this sentence the round trip reproduces the input exactly.
decoded = enc.decode(token_ids)
print("Decoded text:", decoded)

# Decode every id on its own to see how the sentence was split into tokens.
# A token holding only part of a multi-byte UTF-8 character may not decode cleanly.
tokens = list(map(lambda tid: enc.decode([tid]), token_ids))
print("Tokens:", tokens)

Token IDs: [15496, 995, 11, 314, 1101, 4856, 402, 11571, 12, 17, 347, 11401, 0]
Decoded text: Hello world, I'm testing GPT-2 BPE!
Tokens: ['Hello', ' world', ',', ' I', "'m", ' testing', ' G', 'PT', '-', '2', ' B', 'PE', '!']


In [30]:

X = torch.ones(1, 3, 5)
# Stack X with two all-ones copies along a new dim 1 -> shape (1, 3, 3, 5).
X = torch.stack([X] + [torch.ones_like(X) for _ in range(2)], dim=1)
# torch.tril on a batched tensor acts on the last two dims (here 3 x 5), so this is
# lower-triangular over X's last two axes, not a square (seq, seq) causal mask like
# the one built inside causal_attention_head.
# NOTE(review): the saved output below shows a shape, not this tensor — it looks stale.
attention_mask = torch.ones_like(X).tril()
print(attention_mask)


torch.Size([1, 3, 3, 5])


In [75]:
# Toy causal softmax: the strictly-upper-triangular entries of an all-ones score
# matrix become -inf, so row i spreads its weight evenly over columns 0..i.
# (Despite its name, `padding_max` holds causally masked scores, not a padding mask.)
padding_max = torch.ones(3, 3).masked_fill(
    torch.ones(3, 3, dtype=torch.bool).triu(diagonal=1), float('-inf')
)
torch.softmax(padding_max, dim=1)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])