# Decoder only Transformer / GPT 

**Components** 
1. Masked MHA (Multi head Attention) : Attention Mechanism 
2. Feed Forward Network : Two nn.Linear layers with an activation in between
3. Layer Norm : Normalizes activations to 0 mean / 1 variance, which helps with vanishing gradients
4. Residual connections : Adds inputs to outputs (x + f(x))

In [2]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import numpy as np

## Multi head attention

In [10]:
class MHA(nn.Module):
    """Masked multi-head self-attention.

    Projects the input to queries, keys and values, splits them into
    ``num_heads`` heads of size ``d_model // num_heads``, applies scaled
    dot-product attention per head and merges the heads through a final
    linear projection.

    Args:
        d_model: Size of the input/output feature dimension.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, t, batch, seq):
        """Reshape [batch, seq, d_model] -> [batch, num_heads, seq, d_k]."""
        return t.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape [batch, seq, d_model].
            mask: Optional tensor broadcastable to [batch, num_heads, seq, seq].
                Either a boolean tensor (True marks positions that must NOT
                be attended to) or a float tensor in which blocked positions
                hold ``-inf``; any other float value means "allowed".

        Returns:
            Tensor of shape [batch, seq, d_model].
        """
        batch, seq, d_model = x.shape

        Q = self._split_heads(self.W_q(x), batch, seq)
        K = self._split_heads(self.W_k(x), batch, seq)
        V = self._split_heads(self.W_v(x), batch, seq)

        # Scale by sqrt(d_k) so the dot products do not grow with head size.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            # masked_fill needs a boolean tensor. A boolean mask is used
            # directly; comparing it against -inf would be all-False and
            # would silently disable masking. A float mask is converted by
            # locating its -inf entries.
            if mask.dtype == torch.bool:
                blocked = mask
            else:
                blocked = mask == float("-inf")
            scores = scores.masked_fill(blocked.to(scores.device), float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        attn_outputs = torch.matmul(attn_weights, V)

        # Merge heads: [batch, num_heads, seq, d_k] -> [batch, seq, d_model].
        attn_output = attn_outputs.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq, d_model)

        return self.W_o(attn_output)


In [11]:
d_model = 64
heads = 8
seq = 10
batch = 4

# Causal (look-ahead) mask: 1s strictly above the diagonal mark future
# positions, which are then replaced by -inf so softmax gives them weight 0.
causal_mask = torch.triu(torch.ones(seq, seq), diagonal=1)
causal_mask = causal_mask.masked_fill(causal_mask == 1, float("-inf"))

mha = MHA(d_model, heads)
x = torch.randn(batch, seq, d_model)
# Pass the mask so the masked attention path is actually exercised.
out = mha(x, causal_mask)

print(f"x.shape : {x.shape}")
print(f"out.shape : {out.shape}")


x.shape : torch.Size([4, 10, 64])
out.shape : torch.Size([4, 10, 64])


## FeedForward Network

In [17]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Expands the feature dimension from ``d_model`` to ``d_ff`` and then
    compresses it back to ``d_model``, giving the model extra capacity to
    enrich the attention output for each token.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (expanded) feature size.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()

        self.linear1 = nn.Linear(d_model, d_ff)  # expand
        self.linear2 = nn.Linear(d_ff, d_model)  # compress
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the transformed tensor; the last dimension stays d_model."""
        hidden = self.linear1(x)
        activated = self.relu(hidden)
        return self.linear2(activated)


In [18]:
# Shape check for FFN: the output must have the same shape as the input.
batch, seq = 2, 10
d_ff, d_model = 256, 64

ff = FFN(d_model, d_ff)
x = torch.randn(batch, seq, d_model)
out = ff(x)

print(f"out.shape : {out.shape}")

out.shape : torch.Size([2, 10, 64])


## Decoder Block

In [19]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked attention and feed-forward sub-layers.

    Each sub-layer's output goes through dropout, is added back to its
    input (residual connection) and is then layer-normalized.

    Args:
        d_model: Feature size of the tokens.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.attention = MHA(d_model, num_heads)
        self.feedforward = FFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is forwarded to attention."""
        # Sub-layer 1: attention -> dropout -> residual add -> norm.
        attended = self.dropout(self.attention(x, mask))
        x = self.norm1(x + attended)

        # Sub-layer 2: feed-forward -> dropout -> residual add -> norm.
        transformed = self.dropout(self.feedforward(x))
        return self.norm2(x + transformed)

In [20]:
d_model = 64
d_ff = 256
seq = 10  # was `seq - 10`, a no-op expression that relied on an earlier cell
batch = 2
heads = 8

# Causal mask: 1s above the diagonal (future positions) become -inf.
causal_mask = torch.triu(torch.ones(seq, seq), diagonal=1)
causal_mask = causal_mask.masked_fill(causal_mask == 1, float("-inf"))

decoder_block = DecoderBlock(d_model, heads, d_ff)
x = torch.randn(batch, seq, d_model)
out = decoder_block(x, causal_mask)

print(f"x shape : {x.shape}")
print(f"out shape : {out.shape}")

x shape : torch.Size([2, 10, 64])
out shape : torch.Size([2, 10, 64])


# Decoder Only Transformer - implemented
Stacking multiple decoder blocks 

In [23]:
class DecoderTransformer(nn.Module):
    """Decoder-only (GPT-style) Transformer.

    Token ids are embedded, summed with learned position embeddings, passed
    through a stack of ``DecoderBlock`` layers under a causal mask, layer
    normalized and projected to vocabulary-sized logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / hidden feature size.
        num_heads: Attention heads per decoder block.
        num_layers: Number of stacked decoder blocks.
        d_ff: Hidden size of each block's feed-forward network.
        seq: Maximum sequence length (size of the position-embedding table).
        dropout: Dropout probability.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, seq, dropout=0.1):
        super().__init__()

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq, d_model)  # seq is max_seq_len
        # Stack of decoder blocks applied one after another.
        self.blocks = nn.ModuleList(
            [DecoderBlock(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.ln = nn.LayerNorm(d_model)
        # Output projection: [batch, seq, d_model] -> [batch, seq, vocab_size].
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape [batch, seq].

        Returns:
            Logits of shape [batch, seq, vocab_size].

        Raises:
            IndexError: If ``seq`` exceeds the maximum sequence length the
                position-embedding table was built for.
        """
        _, seq = x.shape

        max_seq = self.pos_emb.num_embeddings
        if seq > max_seq:
            # Same exception type as the out-of-range embedding lookup would
            # raise on CPU, but with an actionable message.
            raise IndexError(
                f"sequence length {seq} exceeds maximum sequence length {max_seq}"
            )

        # Build helper tensors on the input's device so the model also works
        # after being moved off the CPU.
        device = x.device

        # Causal mask: -inf strictly above the diagonal (future positions).
        mask = torch.triu(torch.ones(seq, seq, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))

        token_emb = self.token_emb(x)  # [batch, seq, d_model]
        positions = torch.arange(seq, device=device).unsqueeze(0)  # [1, seq]
        pos_emb = self.pos_emb(positions)  # [1, seq, d_model]

        # pos_emb broadcasts over the batch dimension; dropout regularizes.
        x = self.dropout(token_emb + pos_emb)

        for block in self.blocks:
            x = block(x, mask)

        x = self.ln(x)
        # Logits are raw, unnormalized scores; softmax turns them into
        # probabilities.
        logits = self.head(x)

        return logits

In [26]:
# Hyperparameters for a small demo model.
vocab = 500
d_model = 64
d_ff = 256
heads = 8
num_layers = 4
seq_len = 128  # max_seq_len

model = DecoderTransformer(
    vocab_size=vocab,
    d_model=d_model,
    num_heads=heads,
    num_layers=num_layers,
    d_ff=d_ff,
    seq=seq_len,
    dropout=0.1,
)

# Random batch of token ids, shorter than the maximum sequence length.
batch = 2
seq = 20
x = torch.randint(low=0, high=vocab, size=(batch, seq))
print(f"x: {x.shape}")

logits = model(x)
print(f"logits shape : {logits.shape}")
print(f"total params: {sum(p.numel() for p in model.parameters())}")

x: torch.Size([2, 20])
logits shape : torch.Size([2, 20, 500])
total params: 271232


Implemented a decoder-only Transformer (GPT style).