Colab Notebook

In [3]:
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/1.2 MB[0m [31m6.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken
Successfully installed tiktoken-0.9.0


Training

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import tiktoken
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import os
from datetime import datetime

In [None]:

@dataclass
class Config:
    vocab_size: int = 50257
    max_seq_len: int = 2048
    dim: int = 768
    num_layers: int = 8
    num_heads: int = 12
    dropout: float = 0.1

class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_head = config.num_heads
        self.n_embd = config.dim

        # Linear projections for Q, K, V
        self.c_attn = nn.Linear(config.dim, 3 * config.dim) # [n_embd, 3 * n_embd] # for q,k, v we are defined together and split later eg: instead of defining q,k,v separately 100*100,100*100,100*100 , we are defining as 100*300 and then split later.
        self.c_proj = nn.Linear(config.dim, config.dim) # [n_embd, n_embd]

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        B, T, C = x.size() # [B, T, n_embd] Batch - B, Token_len - T, embedding dim - C

        # Linear projection and split into Q, K, V
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # [B, T, n_embd] each

        # Reshape for multi-head attention
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [B, n_head, T, n_embd/n_head]
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [B, n_head, T, n_embd/n_head]
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [B, n_head, T, n_embd/n_head]

        # Attention scores
        att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5)) # [B, n_head, T, T]
        att = F.softmax(att, dim=-1) # [B, n_head, T, T]
        att = self.attn_dropout(att) # [B, n_head, T, T]

        # Weighted sum of values
        y = att @ v # [B, n_head, T, n_embd/n_head]

        # Reshape and project
        y = y.transpose(1, 2).contiguous().view(B, T, C) # [B, T, n_embd]
        y = self.c_proj(y) # [B, T, n_embd]
        y = self.resid_dropout(y) # [B, T, n_embd]

        return y

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.dim, 4 * config.dim) # [n_embd, 4 * n_embd]
        self.c_proj = nn.Linear(4 * config.dim, config.dim) # [4 * n_embd, n_embd]
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x) # [B, T, 4 * n_embd]
        x = F.gelu(x) # [B, T, 4 * n_embd]
        x = self.c_proj(x) # [B, T, n_embd]
        x = self.dropout(x) # [B, T, n_embd]
        return x

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.dim) # [n_embd]
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.dim) # [n_embd]
        self.mlp = FeedForward(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x)) # [B, T, n_embd] #calls forward func of multihead attention
        x = x + self.mlp(self.ln_2(x)) # [B, T, n_embd]
        return x

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.dim) # [vocab_size, n_embd]
        self.wpe = nn.Embedding(config.max_seq_len, config.dim) # [max_seq_len, n_embd]
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.dim) # [n_embd]
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False) # [n_embd, vocab_size]

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx):
        B, T = idx.size() # [B, T]  o/p-->torch.Size([4, 128])

        # Positional embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0) # [1, T]

        # Token and position embeddings
        tok_emb = self.wte(idx) # [B, T, n_embd]
        pos_emb = self.wpe(pos) # [1, T, n_embd]

        # Combine embeddings and apply dropout
        x = self.drop(tok_emb + pos_emb) # [B, T, n_embd]

        # Transformer blocks
        for block in self.blocks:
            x = block(x) # [B, T, n_embd] #calls forward method inside TransformerBlock

        # Final layer norm and linear projection
        x = self.ln_f(x) # [B, T, n_embd]
        logits = self.lm_head(x) # [B, T, vocab_size]

        return logits

class TextDataset(Dataset):
    def __init__(self, text, seq_len):
        self.text = text
        self.seq_len = seq_len
        self.num_samples = len(text) // seq_len  # Non-overlapping sequences

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        start = idx * self.seq_len
        return (
            self.text[start:start+self.seq_len],    # Input sequence
            self.text[start+1:start+self.seq_len+1] # Target sequence
        )

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def estimate_memory(model):
    param_size = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
    total_size = (param_size + buffer_size) / (1024 ** 2)  # Convert to MB
    return total_size

def train(model, dataset, config, epochs, lr=1e-4, log_dir="logs"):

    os.makedirs(log_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    print(f"Model Size: {count_parameters(model) / 1e6:.2f}M parameters")
    print(f"Estimated Model Memory Usage: {estimate_memory(model):.2f} MB")

    model.train()
    step = 0  # Initialize step counter

    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            step += 1  # Increment step counter

            #print(f"Step {step} (Epoch {epoch+1}, Batch {batch_idx+1})")

            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # print(f"Outputs shape: {outputs.shape}")  # Should be (batch_size, seq_len, vocab_size)
            # print(f"Targets shape: {targets.shape}")  # Should be (batch_size, seq_len)

            loss = criterion(outputs.view(-1, config.vocab_size), targets.view(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx+1}, Step {step}: Loss = {loss.item():.6f}")

        avg_loss = total_loss / len(dataloader)
        print(f" Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.6f}")

        if avg_loss < 0.099999:
            print("Early stopping as loss is below target")
            break

        # Save model checkpoint
        torch.save(model.state_dict(), f"{log_dir}/model_epoch_{epoch+1}.pt")
        print("Checkpoint saved..")

        # Log to file
        with open(f"{log_dir}/training_log.txt", "a") as log_file:
            log_file.write(f"{datetime.now()} - Step {step} - Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.6f}\n")


if __name__ == '__main__':
    config = Config()
    model = DecoderOnlyTransformer(config)

    with open('input.txt','r') as fp:
        text = fp.read()

    enc = tiktoken.get_encoding("gpt2")
    idx = torch.tensor(enc.encode(text), dtype=torch.long)
    #print(idx.shape)
    #print(idx)

    dataset = TextDataset(idx, seq_len=512)
    train(model, dataset, config, epochs=100)

    torch.save(model.state_dict(), "final_model.pt")
    print("Model training complete.")


Model Size: 135.47M parameters
Estimated Model Memory Usage: 516.79 MB
Epoch 1/100, Batch 1, Step 1: Loss = 10.984265
Epoch 1/100, Batch 2, Step 2: Loss = 9.836699
Epoch 1/100, Batch 3, Step 3: Loss = 9.416938
Epoch 1/100, Batch 4, Step 4: Loss = 9.196118
Epoch 1/100, Batch 5, Step 5: Loss = 8.933338
Epoch 1/100, Batch 6, Step 6: Loss = 9.046812
Epoch 1/100, Batch 7, Step 7: Loss = 8.860850
Epoch 1/100, Batch 8, Step 8: Loss = 8.885395
Epoch 1/100, Batch 9, Step 9: Loss = 8.777838
Epoch 1/100, Batch 10, Step 10: Loss = 8.579698
Epoch 1/100, Batch 11, Step 11: Loss = 8.542672
Epoch 1/100, Batch 12, Step 12: Loss = 8.391079
Epoch 1/100, Batch 13, Step 13: Loss = 8.531050
Epoch 1/100, Batch 14, Step 14: Loss = 8.376135
Epoch 1/100, Batch 15, Step 15: Loss = 8.302006
Epoch 1/100, Batch 16, Step 16: Loss = 8.169391
Epoch 1/100, Batch 17, Step 17: Loss = 8.161842
Epoch 1/100, Batch 18, Step 18: Loss = 8.148483
Epoch 1/100, Batch 19, Step 19: Loss = 7.996818
Epoch 1/100, Batch 20, Step 20: Lo

Inference

In [None]:
# Load the model
model = DecoderOnlyTransformer(config)
model.load_state_dict(torch.load("final_model.pt"))
model.eval()  # Set the model to evaluation mode

  model.load_state_dict(torch.load("/content/logs/final_model.pt"))


DecoderOnlyTransformer(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(2048, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-7): 8 x TransformerBlock(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (c_attn): Linear(in_features=768, out_features=2304, bias=True)
        (c_proj): Linear(in_features=768, out_features=768, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): FeedForward(
        (c_fc): Linear(in_features=768, out_features=3072, bias=True)
        (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [26]:
# Example input text
input_text = "More learned than the ears--waving thy head, Which often, thus, correcting thy stout heart,"

# Encode the input text
enc = tiktoken.get_encoding("gpt2")
input_ids = torch.tensor(enc.encode(input_text), dtype=torch.long).unsqueeze(0)  # Add batch dimension

In [27]:
# Forward pass
with torch.no_grad():  # Disable gradient calculation for inference
    logits = model(input_ids)

# Get the predicted token indices
predicted_indices = torch.argmax(logits, dim=-1)
print(predicted_indices)

# Decode the predicted indices to text
predicted_text = enc.decode(predicted_indices.view(-1).tolist())  # Flattens the tensor
print("Predicted text:", predicted_text)

tensor([[ 4499,   621,   262, 11368,   438,    86,  2703, 11906,  1182,    11,
          9022,  1690,    11,  4145,    11,  3613, 11906, 39171,  2612,    11,
          4145]])
Predicted text:  learned than the ears--waving thy head, Which often, thus, save thy stout heart, thus
