In [1]:
# Language model built from multiple stacked transformer blocks

In [2]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Load the pretrained GPT-2 tokenizer; its vocab_size sets n_vocab in the config cell.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Seed torch's RNG so that weight initialisation and token sampling are
# reproducible when the notebook is re-run from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)

# data hyperparameters
seq_len = 8  # context length: size of the positional-embedding table
n_vocab = tokenizer.vocab_size  # number of distinct tokens the tokenizer can emit

# model hyperparameters
embed_dim = 128  # width of every embedding / hidden vector
nTransformerBlocks = 12  # number of stacked transformer blocks

batch_size = 5

In [5]:
# create one attention head
class OneAttentionHead(nn.Module):
    """A single causal self-attention head.

    Projects the input into queries, keys and values (each ``embed_dim`` wide,
    no bias), runs causally-masked scaled dot-product attention, then mixes the
    attended values with an output projection ``W0``.
    """

    def __init__(self, embed_dim):
        super().__init__()

        def projection():
            # square, bias-free linear map used for every projection in the head
            return nn.Linear(embed_dim, embed_dim, bias=False)

        # registration order (key, query, value, W0) matches the module repr
        self.key = projection()
        self.query = projection()
        self.value = projection()
        self.W0 = projection()  # output projection applied after attention

    def forward(self, x):
        """Apply causal self-attention to ``x``; the result has the same shape as ``x``."""
        attended = F.scaled_dot_product_attention(
            self.query(x),
            self.key(x),
            self.value(x),
            is_causal=True,  # each position may only attend to itself and earlier positions
        )
        return self.W0(attended)
    

# Next: the transformer block (attention + MLP sublayers)

In [6]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer block: causal self-attention followed by an MLP.

    Each sublayer LayerNorm-normalises its input, transforms it, and adds the
    result back onto the residual stream, so the output has the same shape as
    the input and blocks can be stacked.
    """

    def __init__(self, embed_dim):
        super().__init__()
        hidden_dim = 4 * embed_dim  # the MLP expands to 4x the embedding width

        # attention sublayer
        self.layerNormAttn = nn.LayerNorm(embed_dim)
        self.attn = OneAttentionHead(embed_dim)

        # feed-forward (MLP) sublayer: expand -> GELU -> contract
        self.layerNormMLP = nn.LayerNorm(embed_dim)
        self.W1 = nn.Linear(embed_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.W2 = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x):
        """Run both residual sublayers.

        The result goes either to the next transformer block or to the final
        layernorm + unembedding of the full model.
        """
        # attention sublayer with residual connection
        residual = x + self.attn(self.layerNormAttn(x))

        # MLP sublayer with residual connection
        mlp_in = self.layerNormMLP(residual)
        mlp_out = self.W2(self.gelu(self.W1(mlp_in)))
        return residual + mlp_out
        
        

The full model

In [13]:
# the full model class, which calls the previously defined classes
class LanguageModel(nn.Module):
    """Decoder-only language model built from stacked TransformerBlocks.

    Uses the notebook-level globals ``n_vocab`` (vocabulary size) and
    ``seq_len`` (size of the positional-embedding table) at construction time.

    Args:
        nTransformerBlocks: number of TransformerBlock layers to stack.
        embed_dim: width of the token/position embeddings and of every block.
    """

    def __init__(self, nTransformerBlocks, embed_dim):
        super().__init__()

        # token and learned positional embedding tables
        self.embedding = nn.Embedding(n_vocab, embed_dim)
        self.positions = nn.Embedding(seq_len, embed_dim)

        # stack of transformer blocks; * unpacks the list so each block becomes
        # a positional argument to nn.Sequential
        self.transformerBlocks = nn.Sequential(
            *[TransformerBlock(embed_dim) for _ in range(nTransformerBlocks)]
        )

        # final layernorm after all blocks, then projection back to vocab logits
        self.finalLayerNorm = nn.LayerNorm(embed_dim)
        self.finalLinear = nn.Linear(embed_dim, n_vocab, bias=False)

        # Weight tying: the unembedding reuses the *same* Parameter object as the
        # token embedding (both are [n_vocab, embed_dim]). Wrapping it in a new
        # nn.Parameter(...) would only share storage while registering a second
        # parameter with its own .grad and optimizer state.
        self.finalLinear.weight = self.embedding.weight

    def forward(self, tokx):
        """Map token ids to next-token logits.

        Args:
            tokx: integer token ids of shape [batch, T]; T may not exceed the
                positional-embedding table size (``seq_len``).
        Returns:
            Logits of shape [batch, T, n_vocab].
        """
        # ----------embeddings-------------
        token_embed = self.embedding(tokx)  # [batch, T, embed_dim]
        # build position ids on the input's device so the model also works on GPU
        position_ids = torch.arange(tokx.shape[-1], device=tokx.device)
        posit_embed = self.positions(position_ids)  # [T, embed_dim]
        x = token_embed + posit_embed  # broadcasts to [batch, T, embed_dim]

        # ----------transformer blocks------
        x = self.transformerBlocks(x)

        # ----------unembedding-------------
        x = self.finalLayerNorm(x)
        return self.finalLinear(x)  # [batch, T, n_vocab]

    def generate(self, tokx, temperature=1., n_new_tokens=50):
        """Autoregressively sample ``n_new_tokens`` tokens and append them to ``tokx``.

        Args:
            tokx: integer token ids of shape [batch, T]. At each step only the
                most recent tokens that fit in the positional table are fed in.
            temperature: positive softmax temperature; values below 1 sharpen
                the distribution, values above 1 flatten it.
            n_new_tokens: number of tokens to sample.
        Returns:
            Token ids of shape [batch, T + n_new_tokens].
        Raises:
            ValueError: if ``temperature`` is not positive (it is a divisor).
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # use the model's own context size rather than the global seq_len
        context_len = self.positions.num_embeddings

        # sampling needs no autograd graph; without this every step would keep
        # activations alive for a backward pass that never happens
        with torch.no_grad():
            for _ in range(n_new_tokens):
                logits = self(tokx[:, -context_len:])
                logits = logits[:, -1, :]  # only the last position predicts the next token

                # temperature-scaled softmax over the whole vocabulary
                probs = F.softmax(logits / temperature, dim=-1)

                # probabilistically sample one token per sequence: [batch, 1]
                tokx_next = torch.multinomial(probs, num_samples=1)

                tokx = torch.cat((tokx, tokx_next), dim=1)  # [batch, tokens + 1]
        return tokx

Create a model instance and inspect it

In [14]:
# Instantiate the model from the config-cell hyperparameters and display its module tree.
llm = LanguageModel(nTransformerBlocks=nTransformerBlocks, embed_dim=embed_dim)
llm

LanguageModel(
  (embedding): Embedding(50257, 128)
  (positions): Embedding(8, 128)
  (transformerBlocks): Sequential(
    (0): TransformerBlock(
      (layerNormAttn): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): OneAttentionHead(
        (key): Linear(in_features=128, out_features=128, bias=False)
        (query): Linear(in_features=128, out_features=128, bias=False)
        (value): Linear(in_features=128, out_features=128, bias=False)
        (W0): Linear(in_features=128, out_features=128, bias=False)
      )
      (layerNormMLP): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (W1): Linear(in_features=128, out_features=512, bias=True)
      (gelu): GELU(approximate='none')
      (W2): Linear(in_features=512, out_features=128, bias=True)
    )
    (1): TransformerBlock(
      (layerNormAttn): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): OneAttentionHead(
        (key): Linear(in_features=128, out_features=128, bias=False)
 

In [15]:
llm.get_submodule("transformerBlocks.2")  # the third transformer block (0-indexed)

TransformerBlock(
  (layerNormAttn): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (attn): OneAttentionHead(
    (key): Linear(in_features=128, out_features=128, bias=False)
    (query): Linear(in_features=128, out_features=128, bias=False)
    (value): Linear(in_features=128, out_features=128, bias=False)
    (W0): Linear(in_features=128, out_features=128, bias=False)
  )
  (layerNormMLP): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (W1): Linear(in_features=128, out_features=512, bias=True)
  (gelu): GELU(approximate='none')
  (W2): Linear(in_features=512, out_features=128, bias=True)
)

In [17]:
llm.get_submodule("transformerBlocks.2.attn")  # attention head of the third block

OneAttentionHead(
  (key): Linear(in_features=128, out_features=128, bias=False)
  (query): Linear(in_features=128, out_features=128, bias=False)
  (value): Linear(in_features=128, out_features=128, bias=False)
  (W0): Linear(in_features=128, out_features=128, bias=False)
)

In [19]:
llm.get_parameter("transformerBlocks.2.attn.query.weight").detach()  # query weights, without grad tracking

tensor([[-0.0487, -0.0691,  0.0671,  ..., -0.0229, -0.0550, -0.0579],
        [ 0.0035,  0.0604, -0.0296,  ..., -0.0209, -0.0724, -0.0425],
        [-0.0564, -0.0723, -0.0847,  ..., -0.0562,  0.0730,  0.0602],
        ...,
        [ 0.0515,  0.0644,  0.0506,  ..., -0.0121, -0.0835,  0.0549],
        [-0.0323,  0.0249, -0.0067,  ...,  0.0657,  0.0785,  0.0063],
        [-0.0679, -0.0389, -0.0240,  ..., -0.0848, -0.0562, -0.0834]])

In [20]:
# create data: one sentence, shifted by one token to form (input, next-token target) pairs
tokens = tokenizer.encode('I prefer oat milk in my coffee.')
X = torch.tensor(tokens[:-1]).unsqueeze(0)  # unsqueeze adds a leading batch dim -> [1, T]
y = torch.tensor(tokens[1:]).unsqueeze(0)   # target at each position is the following token

# the positional-embedding table only has seq_len rows, so longer inputs would fail in forward()
assert X.shape[1] <= seq_len, f"input has {X.shape[1]} tokens but the context length is {seq_len}"

print(X.shape)
print(y.shape)

torch.Size([1, 8])
torch.Size([1, 8])


In [21]:
# forward pass: next-token logits for every position of the input
out = llm(X)
out.size()

torch.Size([1, 8, 50257])