# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Compute device: use the GPU when CUDA is usable, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyperparameters -- data dimensions
vocab_size = 1000  # number of distinct token ids (illustrative value)
block_size = 128   # maximum sequence length (context window)

# Hyperparameters -- model architecture
embed_dim = 64     # width of token/position embeddings
num_heads = 4      # attention heads per Transformer block
num_layers = 2     # number of stacked Transformer blocks
ffn_hidden = 256   # hidden width of the feed-forward network

# 1. Self Attention Head
class SelfAttention(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        embed_dim: Size of the last dimension of the input.
        head_size: Output size of the key/query/value projections.
        max_len: Longest sequence the causal mask supports. ``None`` (the
            default) falls back to the module-level ``block_size``.
    """

    def __init__(self, embed_dim, head_size, max_len=None):
        super().__init__()
        if max_len is None:
            max_len = block_size
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        self.dropout = nn.Dropout(0.1)
        # Lower-triangular mask: position t may only attend to positions <= t.
        # Registered as a buffer so it follows the module across devices
        # without being treated as a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(max_len, max_len)))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (B, T, embed_dim).

        Returns:
            Tensor of shape (B, T, head_size).

        Raises:
            ValueError: If T exceeds the size of the causal mask.
        """
        _, T, _ = x.shape
        max_len = self.tril.size(0)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds the maximum supported length {max_len}"
            )

        k = self.key(x)    # (B,T,head_size)
        q = self.query(x)  # (B,T,head_size)
        # Scale by the key/query width (head_size), not the input embedding
        # width, so the score variance stays ~1 regardless of the head count.
        wei = q @ k.transpose(-2, -1) * k.size(-1) ** -0.5  # (B,T,T)
        # Hide future positions; -inf becomes a weight of 0 after softmax.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)  # (B,T,head_size)
        out = wei @ v      # (B,T,head_size)
        return out

# 2. Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Several self-attention heads run in parallel, then mixed by a linear layer.

    Each head produces ``embed_dim // num_heads`` features; the head outputs
    are concatenated along the last dimension and passed through an
    ``embed_dim -> embed_dim`` projection followed by dropout.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must be positive and divide
            ``embed_dim`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Without this check the concatenated heads would be narrower than
        # embed_dim and the projection would only fail later, inside forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )
        head_size = embed_dim // num_heads
        self.heads = nn.ModuleList([SelfAttention(embed_dim, head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return the projected concatenation of all head outputs for ``x``."""
        # Every head sees the same input; outputs are joined on the feature axis.
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

# 3. Feed Forward Network
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of the input.

    Expands ``embed_dim`` features to ``ffn_hidden``, applies ReLU, projects
    back to ``embed_dim`` and finishes with dropout (p=0.1).
    """

    def __init__(self, embed_dim, ffn_hidden):
        super().__init__()
        expand = nn.Linear(embed_dim, ffn_hidden)
        contract = nn.Linear(ffn_hidden, embed_dim)
        self.net = nn.Sequential(expand, nn.ReLU(), contract, nn.Dropout(0.1))

    def forward(self, x):
        """Run ``x`` through the MLP; the output has the same shape as ``x``."""
        transformed = self.net(x)
        return transformed

# 4. Transformer Block
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: attention then feed-forward, each with a residual.

    Layer normalization is applied to the input of each sub-layer, and the
    sub-layer output is added back onto the un-normalized stream.
    """

    def __init__(self, embed_dim, num_heads, ffn_hidden):
        super().__init__()
        # Sub-layers
        self.sa = MultiHeadAttention(embed_dim, num_heads)
        self.ffwd = FeedForward(embed_dim, ffn_hidden)
        # One normalization per sub-layer
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

# 5. GPT Model
class SimpleGPT(nn.Module):
    """Minimal decoder-only Transformer language model.

    Sizes are taken from the module-level hyperparameters ``vocab_size``,
    ``embed_dim``, ``num_heads``, ``num_layers``, ``block_size`` and
    ``ffn_hidden``.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding_table = nn.Embedding(block_size, embed_dim)
        self.blocks = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, ffn_hidden) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(embed_dim)  # final layer norm
        self.head = nn.Linear(embed_dim, vocab_size)  # output head

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer tensor of token ids, shape (B, T).
            targets: Optional integer tensor of target token ids with B*T
                elements (typically shape (B, T)).

        Returns:
            A ``(logits, loss)`` tuple. Without ``targets``, ``logits`` has
            shape (B, T, vocab_size) and ``loss`` is ``None``. With
            ``targets``, ``logits`` is returned flattened to
            (B*T, vocab_size) and ``loss`` is the cross-entropy between the
            flattened logits and targets.

        Raises:
            ValueError: If T exceeds the number of learned positions.
        """
        _, T = idx.shape

        # Fail early with a readable message instead of an out-of-range
        # embedding lookup deep inside the position table.
        max_len = self.position_embedding_table.num_embeddings
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds the maximum supported length {max_len}"
            )

        tok_emb = self.token_embedding_table(idx)      # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T,C)
        x = tok_emb + pos_emb                          # (B,T,C), pos_emb broadcasts over B

        x = self.blocks(x)                             # (B,T,C)
        x = self.ln_f(x)                               # (B,T,C)
        logits = self.head(x)                          # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, classes) logits and (N,) targets.
            logits = logits.view(-1, logits.size(-1))  # (B*T,vocab_size)
            # reshape (not view) so non-contiguous targets are accepted too.
            targets = targets.reshape(-1)              # (B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

# Create model
# Instantiate the model and move it to the selected device.
model = SimpleGPT().to(device)
# Report the total parameter count in millions.
print(f"Simple GPT Model has {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters.")

# Dummy Input Example
# Random token ids in [0, vocab_size); inputs and targets are independent noise,
# so this only checks that the forward pass and loss computation run.
idx = torch.randint(0, vocab_size, (2, 20)).to(device)  # (batch_size, sequence_length)
targets = torch.randint(0, vocab_size, (2, 20)).to(device)
logits, loss = model(idx, targets)
# Because targets are passed, the model returns logits flattened to (B*T, vocab_size).
print("Logits shape:", logits.shape)  # (40, vocab_size), i.e. (2*20, vocab_size)
print("Loss:", loss.item())
