<a href="https://colab.research.google.com/github/syedmahmoodiagents/transformers/blob/main/GPT_using_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class GPTBlock(nn.Module):
    """One transformer decoder block: self-attention, then a GELU feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``, i.e.
    normalisation is applied after the residual addition.

    Args:
        embed_dim: Width of the block's input and output features.
        num_heads: Number of attention heads given to ``nn.MultiheadAttention``.
        ffn_hidden: Width of the hidden feed-forward layer.
        dropout: Dropout probability used inside attention and on the output
            of both sub-layers.
    """

    def __init__(self, embed_dim=768, num_heads=12, ffn_hidden=3072, dropout=0.1):
        super().__init__()

        # Self-attention; batch_first=True means inputs are (batch, seq, embed_dim).
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)

        # Position-wise feed-forward network: embed_dim -> ffn_hidden -> embed_dim.
        self.linear1 = nn.Linear(embed_dim, ffn_hidden)
        self.linear2 = nn.Linear(ffn_hidden, embed_dim)

        # One LayerNorm per sub-layer; a single Dropout module is shared by both.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Apply the attention and feed-forward sub-layers to ``x``.

        Args:
            x: Input tensor laid out as (batch, seq, embed_dim).
            attn_mask: Optional mask forwarded unchanged to the attention
                module. The block does not build a mask itself, so attention
                is only causal when the caller supplies a causal mask.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention weights are never used, so skip computing them;
        # need_weights=False avoids the extra averaging over heads.
        attn_out, _ = self.self_attn(x, x, x, attn_mask=attn_mask, need_weights=False)

        # Residual + Norm
        x = self.norm1(x + self.dropout(attn_out))

        # Feed-forward
        ffn_out = self.linear2(F.gelu(self.linear1(x)))

        # Residual + Norm
        x = self.norm2(x + self.dropout(ffn_out))
        return x


In [None]:
class MiniGPT(nn.Module):
    """Decoder-only language model built from a stack of ``GPTBlock`` layers.

    Token ids are embedded, summed with learned positional embeddings, passed
    through the blocks under a causal mask, normalised, and projected to
    vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids (size of the token embedding
            table and of the output projection).
        max_len: Number of positions in the positional embedding table, i.e.
            the longest sequence ``forward`` accepts.
        embed_dim: Width of the embeddings and of every block.
        num_heads: Attention heads per block.
        num_layers: Number of stacked ``GPTBlock`` layers.
        ffn_hidden: Hidden width of each block's feed-forward network.
        dropout: Dropout probability passed to every block.
    """

    def __init__(self, vocab_size, max_len=512, embed_dim=768, num_heads=12, num_layers=12, ffn_hidden=3072, dropout=0.1):
        super().__init__()
        # Learned lookup tables: one row per token id and one per position.
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_len, embed_dim)

        # Transformer decoder layers
        self.layers = nn.ModuleList([
            GPTBlock(embed_dim, num_heads, ffn_hidden, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)  # LM prediction head

    def forward(self, input_ids):
        """Compute next-token logits for a batch of token id sequences.

        Args:
            input_ids: Integer tensor of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, vocab_size).

        Raises:
            IndexError: If ``seq_len`` is larger than the positional
                embedding table (``max_len``).
        """
        _, seq_len = input_ids.shape

        # Fail early with a clear message; otherwise the out-of-range position
        # index only surfaces inside the embedding lookup.
        max_len = self.pos_emb.num_embeddings
        if seq_len > max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len {max_len} "
                "of the positional embedding table"
            )

        # Token + positional embeddings; positions broadcast over the batch.
        tok_emb = self.token_emb(input_ids)
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = tok_emb + self.pos_emb(pos)

        # Additive causal mask of shape (seq_len, seq_len): 0 on and below the
        # diagonal, -inf above it, so a position cannot attend to later ones.
        # It is created in the activations' dtype so it matches the attention
        # inputs if the model is cast to another floating-point dtype.
        attn_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=x.device, dtype=x.dtype),
            diagonal=1,
        )

        # Pass through GPT blocks
        for layer in self.layers:
            x = layer(x, attn_mask)

        x = self.norm(x)
        logits = self.lm_head(x)  # (batch, seq_len, vocab_size)
        return logits


In [None]:
# Demo: build a shallow MiniGPT and run a single forward pass on random ids.
vocab_size = 30522
model = MiniGPT(num_layers=2, vocab_size=vocab_size)  # two blocks instead of the default 12

# Three random sequences, each ten token ids long.
input_ids = torch.randint(low=0, high=vocab_size, size=(3, 10))
logits = model(input_ids)

# One score per vocabulary entry at every position: (3, 10, 30522).
print(logits.shape)


torch.Size([3, 10, 30522])
