In [1]:
# This is a GPT-2 model demo with multi-head attention, a stack of transformer blocks,
# and everything else learned so far.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use a GPU when one is available: CUDA first, then Apple-silicon MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
# hyperparameters for GPT-2 small (124M)
n_vocab = 50257  # GPT-2 vocabulary size
embed_dim = 768  # embedding dimension
seq_len = 1024  # maximum sequence length (context window)
n_heads = 12  # attention heads per transformer block
n_blocks = 12  # number of transformer blocks
batch_size = 8

# The embedding is split evenly across heads (768 / 12 = 64 per head), so fail fast here
# rather than with a cryptic reshape error inside attention.
assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"

# Fixed seed so weight initialisation and multinomial sampling are reproducible on re-run.
torch.manual_seed(0)

Class for multi-head attention

In [4]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused Q/K/V projection.

    Args:
        d_model: model (embedding) width. Defaults to the notebook-level ``embed_dim``.
        num_heads: number of attention heads. Defaults to the notebook-level ``n_heads``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.

    ``forward`` maps ``[batch, tokens, d_model]`` to ``[batch, tokens, d_model]``.
    """

    def __init__(self, d_model=None, num_heads=None):
        super().__init__()

        # Fall back to the notebook hyperparameters so `MultiHeadAttention()` keeps working.
        if d_model is None:
            d_model = embed_dim
        if num_heads is None:
            num_heads = n_heads
        if d_model % num_heads != 0:
            raise ValueError(
                f"embedding dim {d_model} is not divisible by number of heads {num_heads}"
            )

        # head dimensionality is the embedding width split evenly across the heads
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # the three Q, K, V weight matrices live in one layer; its output is split in forward()
        self.QKV = nn.Linear(d_model, 3 * d_model, bias=True)

        # final linear projection that mixes the concatenated head outputs
        self.W0 = nn.Linear(d_model, d_model, bias=True)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape [B, T, E]; returns [B, T, E]."""
        B, T, E = x.shape  # [batch, tokens (or seq_len), embed_dim]

        # one matmul produces Q, K and V side by side; split along the feature axis
        qkv = self.QKV(x)  # [B, T, 3*E]
        q, k, v = torch.split(qkv, E, dim=2)  # each [B, T, E]

        # [B, T, E] -> [B, T, nHeads, head_dim] -> [B, nHeads, T, head_dim]
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # PyTorch's SDPA handles the multi-head shape; is_causal=True masks future tokens
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # recombine heads: [B, nHeads, T, head_dim] -> [B, T, E]
        out = out.transpose(1, 2).reshape(B, T, E)

        # finally apply the linear mixing matrix
        return self.W0(out)

Transformer block

In [5]:
class TransformerBlock(nn.Module):
    """One pre-LayerNorm block: causal self-attention, then a 4x-wide MLP.

    Each sub-layer's output is added back onto its input (residual connection).
    Widths come from the notebook-level ``embed_dim``; attention is a default
    ``MultiHeadAttention()``.
    """

    def __init__(self):
        super().__init__()

        # attention sub-layer: normalise first, then attend
        self.layernorm_1 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.attn = MultiHeadAttention()

        # feed-forward sub-layer: normalise, expand 4x, GELU, contract back
        hidden_width = 4 * embed_dim
        self.layernorm_2 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.mlp_1 = nn.Linear(embed_dim, hidden_width, bias=True)
        self.gelu = nn.GELU()
        self.mlp_2 = nn.Linear(hidden_width, embed_dim, bias=True)

    def forward(self, x):
        # residual stream after the attention sub-layer
        resid = x + self.attn(self.layernorm_1(x))

        # residual stream after the MLP sub-layer
        expanded = self.gelu(self.mlp_1(self.layernorm_2(resid)))
        return resid + self.mlp_2(expanded)
        
        

Class for the full model

In [6]:
# the full model class, which uses the previously defined classes
class LanguageModel(nn.Module):
    """GPT-2 style decoder-only language model sized by the notebook hyperparameters.

    ``forward`` maps token ids [B, T] (T <= seq_len) to next-token logits [B, T, n_vocab].
    ``generate`` appends sampled tokens to a prompt.
    """

    def __init__(self):
        super().__init__()

        # token + position embeddings
        self.wte = nn.Embedding(n_vocab, embed_dim)  # token embeddings
        self.wpe = nn.Embedding(seq_len, embed_dim)  # position embeddings

        # stack of n_blocks transformer blocks; * unpacks the list into Sequential's arguments
        self.transformerBlocks = nn.Sequential(*[TransformerBlock() for _ in range(n_blocks)])

        # final LayerNorm applied after all transformer blocks
        self.layernorm_final = nn.LayerNorm(embed_dim, eps=1e-5)
        # unembedding matrix
        self.final_head = nn.Linear(embed_dim, n_vocab, bias=False)
        # Tie the unembedding to the token embedding by sharing the *same* Parameter object.
        # Wrapping it in a new nn.Parameter instead yields a second, distinct parameter that
        # only aliases the storage: it gets its own .grad, parameters() yields it twice, and
        # .to(device) can move the two separately, breaking the alias.
        self.final_head.weight = self.wte.weight

    def forward(self, idx):
        """Return logits [B, T, n_vocab] for token ids ``idx`` of shape [B, T].

        Raises:
            IndexError: if T exceeds ``seq_len`` (no positional embedding exists for it).
        """
        T = idx.shape[-1]
        if T > seq_len:
            raise IndexError(f"sequence length {T} exceeds the maximum context of {seq_len}")

        # ---------- embeddings ----------
        token_emb = self.wte(idx)  # [B, T, E]
        # build position ids on the input's own device, not a notebook-global one
        posit_emb = self.wpe(torch.arange(T, device=idx.device))  # [T, E]
        x = token_emb + posit_emb  # broadcast over the batch -> [B, T, E]

        # ---------- transformer blocks ----------
        x = self.transformerBlocks(x)

        # ---------- unembedding ----------
        x = self.layernorm_final(x)
        logits = self.final_head(x)  # [B, T, n_vocab]
        return logits

    @torch.no_grad()
    def generate(self, idx, temperature=1., max_new_tokens=50):
        """Autoregressively sample ``max_new_tokens`` tokens after the prompt ``idx``.

        Args:
            idx: [B, T] tensor of prompt token ids.
            temperature: logits are divided by this before softmax (higher = flatter).
            max_new_tokens: number of tokens to append.

        Returns:
            [B, T + max_new_tokens] tensor: the prompt followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # condition only on the most recent seq_len tokens (size of the position table)
            logits = self(idx[:, -seq_len:])  # [B, T, n_vocab]
            logits = logits[:, -1, :]  # [B, n_vocab]  last position predicts the next token

            # temperature-scaled softmax over the vocabulary
            probs = F.softmax(logits / temperature, dim=-1)  # [B, n_vocab]

            # sample the next token from that distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # [B, 1]

            # append it and continue
            idx = torch.cat((idx, idx_next), dim=1)  # [B, T + 1]
        return idx

Create an instance and test it out

In [7]:
model = LanguageModel()
model = model.to(device)  # Module.to moves parameters in place and returns the module itself
model  # display the module tree

LanguageModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (transformerBlocks): Sequential(
    (0): TransformerBlock(
      (layernorm_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (QKV): Linear(in_features=768, out_features=2304, bias=True)
        (W0): Linear(in_features=768, out_features=768, bias=True)
      )
      (layernorm_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp_1): Linear(in_features=768, out_features=3072, bias=True)
      (gelu): GELU(approximate='none')
      (mlp_2): Linear(in_features=3072, out_features=768, bias=True)
    )
    (1): TransformerBlock(
      (layernorm_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (QKV): Linear(in_features=768, out_features=2304, bias=True)
        (W0): Linear(in_features=768, out_features=768, bias=True)
      )
      (layernorm_2): LayerNorm((768,), eps=1e-05, elementwise_affine=Tr

In [8]:
# run some fake data (random token ids) through the model
data = torch.randint(0, n_vocab, size=(batch_size, seq_len)).to(device)
# Inference only: no_grad avoids retaining autograd activations for a
# [batch_size, seq_len, n_vocab] logits tensor that we never backprop through.
with torch.no_grad():
    out = model(data)
print(f'input size: {data.shape}')
print(f'Output size: {out.shape}')

input size: torch.Size([8, 1024])
Output size: torch.Size([8, 1024, 50257])


How many parameters do we have?

In [12]:
from torchinfo import summary

# per-layer input/output shapes and parameter counts for one forward pass on `data`
summary(
    model,
    input_data=data,
    col_names=['input_size', 'output_size', 'num_params'],
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
LanguageModel                            [8, 1024]                 [8, 1024, 50257]          --
├─Embedding: 1-1                         [8, 1024]                 [8, 1024, 768]            38,597,376
├─Embedding: 1-2                         [1024]                    [1024, 768]               786,432
├─Sequential: 1-3                        [8, 1024, 768]            [8, 1024, 768]            --
│    └─TransformerBlock: 2-1             [8, 1024, 768]            [8, 1024, 768]            --
│    │    └─LayerNorm: 3-1               [8, 1024, 768]            [8, 1024, 768]            1,536
│    │    └─MultiHeadAttention: 3-2      [8, 1024, 768]            [8, 1024, 768]            2,362,368
│    │    └─LayerNorm: 3-3               [8, 1024, 768]            [8, 1024, 768]            1,536
│    │    └─Linear: 3-4                  [8, 1024, 768]            [8, 1024, 3072]           2,362,368
│ 

In [15]:
# We don't actually have 163M trainable params as shown above.
# This is because summary() counts the unembedding matrix separately, even though it is meant to be tied to the token-embedding matrix.

In [14]:
# Count each distinct weight tensor once: parameters that share storage (e.g. an
# unembedding matrix tied to the token embedding) are keyed by data pointer, so they
# are not double-counted.
unique_param_sizes = {p.data_ptr(): p.numel() for p in model.parameters() if p.requires_grad}
print(f'Total trainable params = {sum(unique_param_sizes.values())}')
# ~124M for GPT-2 small when the unembedding shares storage with wte

Total trainable params = 124439808


In [13]:
class tempA:
    """Scratch helper: captures the notebook-level ``batch_size`` and prints it."""

    def __init__(self):
        # read batch_size once, at construction time
        self.a = batch_size

    def disp(self):
        """Print the stored value, e.g. ``Val is:  8``."""
        label = "Val is: "
        print(label, self.a)

In [15]:
# build the scratch helper, then print the value it captured
a = tempA()
a.disp()

Val is:  8


In [12]:
# NOTE(review): the TypeError recorded below looks like stale output from an earlier run in
# which `a` was bound to the class tempA itself rather than an instance (the message says
# `self` is missing). With `a = tempA()` from the cell above, this call prints the stored value.
a.disp()

TypeError: tempA.disp() missing 1 required positional argument: 'self'