In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
import math
# seed PyTorch's default random number generator with a fixed value
torch.manual_seed(1234)

# pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

#### A simple/minimal implementation of the BERT model (https://arxiv.org/pdf/1810.04805v2.pdf). 

The transformer block and multi-head attention layer implementations are based on Andrej Karpathy's GPT YouTube tutorial. In this case, we use a transformer encoder block, which attends over bi-directional context; this differs from the transformer decoder in GPT, which is unidirectional (achieved via causal masking of the attention weights).

In [None]:
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with causal masking.

    The input is projected to keys, queries and values of width
    ``total_head_size``, split into ``num_heads`` heads of
    ``total_head_size // num_heads`` features each, attended for all heads in
    one batched call, concatenated, and projected back to ``embedding_dim``.

    Args:
        block_size: side length of the lower-triangular ``tril`` buffer.
        embedding_dim: feature size C of the input and of the output.
        total_head_size: combined width H of all heads; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads n.
        dropout_rate: dropout probability for the attention weights (only
            while the module is in training mode) and for the projected output.
    """

    def __init__(self, block_size, embedding_dim, total_head_size, num_heads, dropout_rate):
        super().__init__()

        assert total_head_size % num_heads == 0, "head_size needs to be integer multiple of num_heads"

        # plain hyper-parameters
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.total_head_size = total_head_size
        self.head_size = total_head_size // num_heads  # width h of a single head
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # bias-free projections producing keys, queries and values for all heads at once
        self.key = nn.Linear(embedding_dim, total_head_size, bias=False)
        self.query = nn.Linear(embedding_dim, total_head_size, bias=False)
        self.value = nn.Linear(embedding_dim, total_head_size, bias=False)
        self.attn_dropout = nn.Dropout(dropout_rate)

        # lower-triangular matrix of ones, stored as a buffer (not a parameter);
        # forward() does not read it because the fused attention call masks internally
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)

        # maps the concatenated heads (H) back to the input width (C) so the
        # result can be added to the input in a residual connection
        self.proj = nn.Linear(total_head_size, embedding_dim)
        self.output_dropout = nn.Dropout(dropout_rate)

    def _split_heads(self, t):
        """Reshape a (B,T,H) tensor into the per-head layout (B,n,T,h)."""
        batch, steps, _ = t.shape
        per_head = t.view(batch, steps, self.num_heads, self.head_size)  # (B,T,n,h)
        # heads become the second dimension so attention is batched over (B,n)
        return per_head.transpose(1, 2)  # (B,n,T,h)

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: input of shape (B,T,C), B=batch size, T=sequence length,
                C=embedding_dim.

        Returns:
            Tensor of shape (B,T,C).
        """
        batch, steps, _ = x.shape

        keys = self._split_heads(self.key(x))        # (B,n,T,h)
        queries = self._split_heads(self.query(x))   # (B,n,T,h)
        values = self._split_heads(self.value(x))    # (B,n,T,h)

        # attention dropout is only active while training
        if self.training:
            attn_drop_p = self.dropout_rate
        else:
            attn_drop_p = 0

        # fused PyTorch kernel; the manual (slower) equivalent is
        # softmax(q @ k^T / sqrt(h)) with future positions masked to -inf,
        # dropout on the weights, then a product with v
        heads = F.scaled_dot_product_attention(
            queries, keys, values, dropout_p=attn_drop_p, is_causal=True
        )  # (B,n,T,h)

        # (B,n,T,h) --> (B,T,n,h) --> (B,T,n*h): concatenate the heads again
        merged = heads.transpose(1, 2).reshape(batch, steps, self.total_head_size)

        # project back to the input width and apply output dropout, (B,T,C)
        return self.output_dropout(self.proj(merged))
    

# a simple mlp 
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of its input.

    The hidden layer is four times wider than the input; the second linear
    layer maps back to ``embedding_dim`` so the output can be added to the
    input in a residual connection. Dropout is applied to the output.

    Args:
        embedding_dim: size of the input and output feature dimension.
        dropout_rate: dropout probability applied after the second linear layer.
    """

    def __init__(self, embedding_dim, dropout_rate):
        super().__init__()
        hidden_dim = 4 * embedding_dim  # expand by a factor of 4 for extra capacity
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),  # back to the input width
            nn.Dropout(dropout_rate),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through linear -> ReLU -> linear -> dropout."""
        out = self.net(x)
        return out
    

# transformer block with residual connection and layer norm
class TransformerBlock(nn.Module):
    """Transformer block: self-attention then feed-forward, each with a residual.

    Layer norm is applied to the input of each sub-layer, and the sub-layer
    output is added back onto its un-normalized input.

    Args:
        block_size: forwarded to ``MultiHeadAttention``.
        embedding_dim: feature size of the block's input and output.
        head_size: total width of all attention heads (passed to
            ``MultiHeadAttention`` as ``total_head_size``).
        num_heads: number of attention heads.
        dropout_rate: dropout probability forwarded to both sub-layers.
    """

    def __init__(self, block_size, embedding_dim, head_size, num_heads, dropout_rate):
        super().__init__()
        # multi-head self-attention sub-layer
        self.sa = MultiHeadAttention(
            block_size=block_size,
            embedding_dim=embedding_dim,
            total_head_size=head_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
        )
        # feed-forward sub-layer
        self.ff = FeedForward(embedding_dim=embedding_dim, dropout_rate=dropout_rate)
        # layer norms applied to the inputs of attention and feed-forward respectively
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        # residual around the (normalized) attention sub-layer
        attn_out = self.sa(self.ln1(x))
        x = x + attn_out
        # residual around the (normalized) feed-forward sub-layer
        ff_out = self.ff(self.ln2(x))
        return x + ff_out
    

# language model with multiple transformer blocks
class TransformerLanguageModel(nn.Module):
    """Token-level language model built from a stack of transformer blocks.

    Token and learned position embeddings are summed, passed through
    ``num_blocks`` ``TransformerBlock`` layers and a final layer norm, and
    mapped to vocabulary logits by a linear head.

    Args:
        vocab_size: number of distinct tokens.
        block_size: maximum sequence length T (size of the position embedding table).
        embedding_dim: feature size C used throughout the model.
        head_size: total width of all attention heads in each block.
        num_heads: number of attention heads in each block.
        num_blocks: number of stacked transformer blocks.
        dropout_rate: dropout probability forwarded to every block.
    """

    def __init__(self, vocab_size, block_size, embedding_dim, head_size, num_heads, num_blocks, dropout_rate=0.2):
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.num_heads = num_heads
        # misspelled alias of `num_heads`, kept so code reading the old name keeps working
        self.hum_heads = num_heads
        self.num_blocks = num_blocks

        # token embedding table, shape: (vocab_size,C)
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        # learned position embedding table, shape: (block_size,C)
        self.pos_embedding = nn.Embedding(block_size, embedding_dim)

        # stack of transformer blocks
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(block_size, embedding_dim, head_size, num_heads, dropout_rate)
                for _ in range(num_blocks)
            ]
        )

        # layer norm before the final output layer
        self.ln_f = nn.LayerNorm(embedding_dim)

        # output head: the blocks and ln_f emit embedding_dim features, so the
        # head must take embedding_dim inputs (not head_size)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B,T) with T <= block_size.
            targets: optional integer token ids of shape (B,T).

        Returns:
            ``(logits, loss)``. Without targets, logits has shape
            (B,T,vocab_size) and loss is ``None``. With targets, logits is
            flattened to (B*T,vocab_size) and loss is the mean cross-entropy.
        """
        B, T = idx.shape
        token_embeds = self.token_embedding(idx)  # (B,T,C)
        # build the position indices on the same device as the input so the
        # model does not depend on a module-level `device` global
        positions = torch.arange(T, device=idx.device)
        pos_embeds = self.pos_embedding(positions)  # (T,C)
        x = token_embeds + pos_embeds  # (B,T,C), broadcast over the batch
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            B, T, vocab_size = logits.shape
            # flatten the batch of sequences into one long sequence, (B,T) --> (B*T)
            logits = logits.view(B * T, vocab_size)
            # reshape (rather than view) so non-contiguous targets are accepted too
            targets = targets.reshape(B * T)
            # average negative log likelihood
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens continuing each context in ``idx``.

        Runs in eval mode and restores the model's previous train/eval mode
        afterwards.

        Args:
            idx: integer token ids of shape (B,T), the batch of contexts.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            Tensor of shape (B,T+max_new_tokens): the contexts followed by the
            sampled tokens.
        """
        was_training = self.training
        self.eval()  # switch to inference mode (disables dropout)
        try:
            for _ in range(max_new_tokens):
                # the position embedding only covers block_size positions, so
                # feed at most the last block_size tokens
                idx_crop = idx[:, -self.block_size:]
                logits, _ = self(idx_crop)  # (B,T,vocab_size)
                # the logits at the last position predict the next token
                logits = logits[:, -1, :]  # (B,vocab_size)
                probs = F.softmax(logits, dim=-1)
                # sample the next token from the predicted distribution
                idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
                # append to the running context
                idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)
        finally:
            # restore whichever mode the caller had, instead of forcing train mode
            self.train(was_training)

        return idx