In [1]:
import torch
from dataclasses import dataclass
from torch import nn

@dataclass
class config: 
    vocab_size: int = 50257
    embedding_dim: int = 384#768
    num_attention_heads: int = 6
    num_attention_blocks: int = 6
    ff_hidden_dim: int = 4*384#768
    max_seq_len = 64#128
    dropout=0.01
    device='cuda' if torch.cuda.is_available() else 'cpu'
    bias: bool = True
    

class causal_attention_head(nn.Module):
    def __init__(self, config: config):
        super().__init__()
        
        self.embedding_dim = config.embedding_dim
        self.head_size = self.embedding_dim // config.num_attention_heads
        
        # There are four matrices W_q, W_k, W_v, W_o
        # head_size, embedding_dim
        self.device = torch.device(config.device)
        self.W_q = nn.Parameter(torch.zeros(self.head_size, self.embedding_dim))
        self.W_k = nn.Parameter(torch.zeros(self.head_size, self.embedding_dim))
        self.W_v = nn.Parameter(torch.zeros(self.head_size, self.embedding_dim))
        
        torch.nn.init.normal_(self.W_q, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.W_k, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.W_v, mean=0.0, std=0.02)
            
    def forward(self, X, padding_mask):
        #X: batch, seq, features 
        #padding: batch, seq
        
        #we needs to make it (batch, seq, 1) <- this allows broadcasting along dim=2
        padding_mask = padding_mask.unsqueeze(2)
        
        X = X * padding_mask
        
        seq_len = X.shape[1]
        
        #: batch, seq, head_size
        X_q = X @ self.W_q.T
        X_k = X @ self.W_k.T
        X_v = X @ self.W_v.T
        
        causal_attention_mask = torch.tril(torch.ones(seq_len, seq_len, device=self.device)).unsqueeze(0)
        
        
        #Each element in the row i represents how much of key k_j (j in head_size) is similar??? to query v_i
        scaled_attention_scores = torch.bmm(X_q, X_k.transpose(2,1)) / (self.head_size ** 0.5) # batch, seq, seq
        attention = torch.softmax(scaled_attention_scores.masked_fill(causal_attention_mask==0, float('-inf')), dim=2) # batch, seq, seq
        attention = torch.bmm(attention, X_v) # batch, seq, head_size
        
        return attention

class self_attention(nn.Module):
    def __init__(self, config: config):
        super().__init__()
        
        self.device = torch.device(config.device)
        self.embedding_dim = config.embedding_dim
        self.num_heads = config.num_attention_heads
        self.head_size = self.embedding_dim // config.num_attention_heads
        
  
        self.attention_heads = nn.ModuleList([
            causal_attention_head(config) 
            for _ in range(self.num_heads)
        ])
        
        self.W_o = nn.Linear(self.embedding_dim, self.embedding_dim) 
        
        
    def forward(self, X, padding_mask):
        #Each element: batch, seq, head_size
        head_outputs = []
        for head in self.attention_heads:
            head_outputs.append(head(X, padding_mask))
        
        # Concatenate all head outputs
        #batch, seq, embedding_dim
        concatenated = torch.cat(head_outputs, dim=-1)
        
        # Apply output projection
        output = self.W_o(concatenated)
        
        return output
    

class transformer_block(nn.Module):
    def __init__(self, config:config):
        super().__init__()
        
        self.device = torch.device(config.device)
        self.attention_block = self_attention(config)
        self.layerNorm = nn.LayerNorm(config.embedding_dim, bias=config.bias)
        
        self.ff_hidden_dim = config.ff_hidden_dim
        self.linear = nn.Sequential(
            nn.Linear(config.embedding_dim, self.ff_hidden_dim, bias=config.bias), #bias = True
            nn.GELU(),
            nn.Linear(self.ff_hidden_dim, config.embedding_dim, bias=config.bias), #bias = True
            nn.GELU()
        )
        
    def forward(self, X, padding_mask):
        #X: batch, seq, features 
        
        self_attention_out = self.layerNorm(X + self.attention_block(X, padding_mask))
        linear_out = self.layerNorm(self_attention_out + self.linear(self_attention_out))
        
        return linear_out



class GPT1(nn.Module):
        def __init__(self, config: config):
            super().__init__() 
            
            self.config = config
            self.device = torch.device(config.device)
            
            self.token_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
            self.pos_embedding = nn.Embedding(config.max_seq_len, config.embedding_dim)
            
            self.drop = nn.Dropout(config.dropout)
              
            #batch, seq, embedding_dim
            self.transformer = nn.ModuleList([transformer_block(config) for _ in range(config.num_attention_blocks)])
        
            #embedding_dim, vocab_size -> batch, seq, vocab_size
            self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
            
            self.apply(self._init_weights)
            self.to(self.device)
        
        
        def forward(self, X, padding_mask):
            # X: batch, seq
            
            # Move inputs to the configured device
            X = X.to(self.device)
            padding_mask = padding_mask.to(self.device)
            
            batch_size = X.shape[0]
            seq_len = X.shape[1]
            
            #batch, seq, embedding_dim
            token_embedding = self.token_embedding(X)
            
            #1, seq_len
            positions = torch.arange(seq_len, device=self.device)
            #seq_len -> 1, seq -> batch, seq 
            positions = positions.unsqueeze(0).expand(batch_size, seq_len)
            
            #batch, seq, embedding_dim
            position_embedding = self.token_embedding(positions)
            
            X = self.drop(token_embedding + position_embedding)
            
            for block in self.transformer:
                X = block(X, padding_mask)
            
            out = self.lm_head(X)
            
            return out
        
        def _init_weights(self, module):
            if isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
                
        


In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

ds = load_dataset("roneneldan/TinyStories", split="train").select(range(1000))
tokenizer = AutoTokenizer.from_pretrained('GPT2')

def tokenization(example):
    return tokenizer(example["text"], truncation=True, max_length=512)

ds = ds.map(tokenization, batched=True, remove_columns=ds.column_names)
tokenizer.pad_token = tokenizer.eos_token

# Create data collator for padding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

# Create dataloader with padded batches
from torch.utils.data import DataLoader
dataloader = DataLoader(ds, batch_size=4, collate_fn=data_collator, shuffle=True)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
for batch in dataloader:
    print(batch)  

{'input_ids': tensor([[ 7454,  2402,   257,  ..., 50256, 50256, 50256],
        [ 1026,   373,   257,  ..., 50256, 50256, 50256],
        [ 7554,   373,  2712,  ..., 50256, 50256, 50256],
        [   43,   813,   290,  ...,  2460,   757,    13]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]])}
{'input_ids': tensor([[ 7454,   612,   373,  ..., 50256, 50256, 50256],
        [13787,   318,   257,  ...,   465,  9970,    13],
        [ 7454,  2402,   257,  ..., 50256, 50256, 50256],
        [   50,  3301,  8288,  ..., 50256, 50256, 50256]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}
{'input_ids': tensor([[ 7454,  2402,   257,  ..., 50256, 50256, 50256],
        [ 7454,  2402,   257,  ...,   464,  5268,     0],
        [ 7454,  2402,   257,  ..., 50256, 50256, 50256],
    

In [4]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

c = config
gpt = GPT1(config=c)

# Training loop
optimizer = torch.optim.AdamW(gpt.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

gpt.train()
for epoch in range(5):  # Train for 5 epochs
    total_loss = 0
    for batch_idx, batch in enumerate(dataloader):
        # Get input_ids and create targets (shifted by 1 for next token prediction)
        input_ids = batch['input_ids'].to(config.device)
        padding_mask = batch['attention_mask'].to(config.device)
        targets = input_ids[:, 1:].contiguous()  # Shift targets by 1
        inputs = input_ids[:, :-1].contiguous()  # Remove last token from inputs
        padding_mask = padding_mask[:, :-1].contiguous() 
        
        # Forward pass
        optimizer.zero_grad()
        outputs = gpt(inputs, padding_mask)
        
        # Calculate loss (flatten outputs and targets for CrossEntropyLoss)
        loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')
            
        del input_ids
        del padding_mask
    
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch+1} completed. Average loss: {avg_loss:.4f}')


Epoch 1, Batch 0, Loss: 10.8257
Epoch 1, Batch 10, Loss: 10.7973
Epoch 1, Batch 20, Loss: 10.7693
Epoch 1, Batch 30, Loss: 10.7368
Epoch 1, Batch 40, Loss: 10.6913
Epoch 1, Batch 50, Loss: 10.7254
Epoch 1, Batch 60, Loss: 10.6413
Epoch 1, Batch 70, Loss: 10.5845
Epoch 1, Batch 80, Loss: 10.5665
Epoch 1, Batch 90, Loss: 10.4863
Epoch 1, Batch 100, Loss: 10.4896
Epoch 1, Batch 110, Loss: 10.4523
Epoch 1, Batch 120, Loss: 10.3026
Epoch 1, Batch 130, Loss: 10.3068
Epoch 1, Batch 140, Loss: 10.2305
Epoch 1, Batch 150, Loss: 10.1026
Epoch 1, Batch 160, Loss: 10.0623
Epoch 1, Batch 170, Loss: 9.9200
Epoch 1, Batch 180, Loss: 9.8725
Epoch 1, Batch 190, Loss: 9.7316
Epoch 1, Batch 200, Loss: 9.6160
Epoch 1, Batch 210, Loss: 9.5261
Epoch 1, Batch 220, Loss: 9.4674
Epoch 1, Batch 230, Loss: 9.2690
Epoch 1, Batch 240, Loss: 8.9535
Epoch 1 completed. Average loss: 10.1454
Epoch 2, Batch 0, Loss: 8.9866
Epoch 2, Batch 10, Loss: 8.7392
Epoch 2, Batch 20, Loss: 8.7336
Epoch 2, Batch 30, Loss: 8.3581
E

In [5]:
import tiktoken

# Load the GPT-2 tokenizer
enc = tiktoken.get_encoding("gpt2")

# Get the vocabulary size
vocab_size = enc.n_vocab
print("GPT-2 vocabulary size:", vocab_size)

# Example sentence
sentence = "Hello world, I'm testing GPT-2 BPE!"

# Tokenize into token IDs
token_ids = enc.encode(sentence)
print("Token IDs:", token_ids)

# Decode back to string
decoded = enc.decode(token_ids)
print("Decoded text:", decoded)

# If you want tokens as strings
tokens = [enc.decode([tid]) for tid in token_ids]
print("Tokens:", tokens)

GPT-2 vocabulary size: 50257
Token IDs: [15496, 995, 11, 314, 1101, 4856, 402, 11571, 12, 17, 347, 11401, 0]
Decoded text: Hello world, I'm testing GPT-2 BPE!
Tokens: ['Hello', ' world', ',', ' I', "'m", ' testing', ' G', 'PT', '-', '2', ' B', 'PE', '!']


In [6]:

X = torch.ones(1,3,5)
X = torch.stack([X, torch.ones_like(X), torch.ones_like(X)], dim=1)
attention_mask = torch.tril(torch.ones_like(X))
print(attention_mask)


tensor([[[[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.]],

         [[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.]],

         [[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.]]]])


In [7]:
padding_max = torch.ones(3,3).masked_fill(torch.tril(torch.ones(3,3)) == 0, float('-inf'))
torch.softmax(padding_max, dim=1)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [88]:
# More efficient implementation of self-attention
#The idea is that this results in fewer matrix multiplications and fewer read-writes

x = torch.rand((5,2,5))

W_k_v_q = torch.rand((3*3, 5))


print(x.shape)
W_k_v_q.unsqueeze(0).shape

out = x @ W_k_v_q.unsqueeze(0).transpose(2,1)
print(out)

X_q = out[:, :, :3]
X_k = out[:, :, 3:6]    
X_v = out[:, :, 6:] 

print(X_q.shape)
print(X_k.shape)
print(X_v.shape)

torch.Size([5, 2, 5])
tensor([[[0.7020, 0.3055, 1.3670, 0.7016, 0.2195, 0.8787, 0.9115, 0.9427,
          0.8876],
         [1.3896, 0.7785, 1.7930, 0.9667, 0.6785, 0.9026, 1.1906, 0.8887,
          1.1139]],

        [[1.3520, 0.9128, 1.3772, 0.9284, 0.7301, 0.8384, 1.1292, 0.6220,
          1.1598],
         [0.9522, 0.5236, 1.1161, 1.1226, 0.2809, 1.3056, 1.2456, 1.2239,
          1.4152]],

        [[0.9635, 0.4195, 1.4356, 0.8938, 0.3116, 1.0273, 1.0549, 1.0924,
          1.0963],
         [1.1926, 0.7369, 1.7594, 1.0814, 0.5485, 1.1620, 1.3647, 1.0864,
          1.3058]],

        [[1.3329, 0.5089, 2.0440, 1.0928, 0.4653, 1.1526, 1.2786, 1.3798,
          1.2063],
         [1.6583, 1.2565, 1.4734, 1.4685, 0.9297, 1.2580, 1.6590, 0.9659,
          1.6021]],

        [[1.5449, 0.8312, 2.0136, 1.3256, 0.6441, 1.3790, 1.5633, 1.3671,
          1.5699],
         [1.3842, 0.6683, 1.7953, 1.0887, 0.6273, 0.9163, 1.2133, 1.1131,
          1.0046]]])
torch.Size([5, 2, 3])
torch.Size([5, 2