In [2]:
import torch
import torch.nn as nn
import math

# Let's define our configuration class first - it's like the blueprint for our BERT model
class BertConfig:
    """Plain container for the hyperparameters that size a BERT model.

    Every constructor argument is stored unchanged as an attribute of the same
    name. No validation happens here; ``BertSelfAttention`` is what checks that
    ``hidden_size`` is divisible by ``num_attention_heads``.

    Args:
        vocab_size: number of rows in the word-embedding table.
        hidden_size: width of the token representations throughout the model.
        num_hidden_layers: how many transformer layers the encoder stacks.
        num_attention_heads: number of heads each self-attention splits into.
        intermediate_size: inner width of each feed-forward sub-layer.
        hidden_dropout_prob: dropout applied to embeddings and sub-layer outputs.
        attention_probs_dropout_prob: dropout applied to the attention weights.
        max_position_embeddings: number of rows in the position table, i.e. the
            longest sequence the model can index positions for.
        type_vocab_size: number of segment (token-type) ids, e.g. 2 for a
            sentence-A / sentence-B setup.
    """

    def __init__(self,
                 vocab_size=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2):
        # Copy every argument onto the instance under its own name.
        vars(self).update(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=type_vocab_size,
        )

# This handles turning our input tokens into proper embeddings
class BertEmbeddings(nn.Module):
    """Turn token ids into input vectors for the encoder.

    Each token's vector is the sum of its word embedding, the embedding of its
    position (0, 1, ..., seq_len - 1) and the embedding of its token type, after
    which LayerNorm and dropout are applied.

    Args:
        config: object exposing ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings``, ``type_vocab_size`` and
            ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        # Creation order is kept fixed so seeded initialisation is reproducible.
        self.word_embeddings = nn.Embedding(config.vocab_size, width)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, width)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, width)

        self.LayerNorm = nn.LayerNorm(width, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        """Embed ``input_ids``, assumed (batch, seq) integer ids.

        Args:
            input_ids: token ids; dimension 1 is taken as the sequence length.
            token_type_ids: optional segment ids shaped like ``input_ids``;
                all zeros when omitted.

        Returns:
            Tensor of shape ``input_ids.shape + (hidden_size,)``.
        """
        segment_ids = torch.zeros_like(input_ids) if token_type_ids is None else token_type_ids

        # One row of positions 0..seq_len-1, broadcast (without copying) to every sequence.
        seq_len = input_ids.size(1)
        position_ids = (
            torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
            .unsqueeze(0)
            .expand_as(input_ids)
        )

        summed = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(position_ids)
            + self.token_type_embeddings(segment_ids)
        )
        return self.dropout(self.LayerNorm(summed))

# The heart of the attention mechanism
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected into queries, keys and values. Each projection is
    split into ``num_attention_heads`` heads of
    ``hidden_size // num_attention_heads`` features, attention is computed
    independently per head, and the heads are concatenated back to
    ``hidden_size`` features.

    Args:
        config: object exposing ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        heads, hidden = config.num_attention_heads, config.hidden_size
        if hidden % heads:
            raise ValueError("Hey, the hidden size needs to be divisible by number of attention heads!")

        self.num_attention_heads = heads
        self.attention_head_size = hidden // heads
        self.all_head_size = hidden

        # Q, K, V are created in this order so seeded initialisation stays reproducible.
        self.query = nn.Linear(hidden, self.all_head_size)
        self.key = nn.Linear(hidden, self.all_head_size)
        self.value = nn.Linear(hidden, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads and move the head axis forward.

        (batch, seq, all_head_size) -> (batch, heads, seq, head_size).
        """
        split_shape = (*x.shape[:-1], self.num_attention_heads, self.attention_head_size)
        return x.view(split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        """Attend over ``hidden_states`` and return a tensor of the same shape.

        Args:
            hidden_states: input of shape (batch, seq, hidden_size).
            attention_mask: optional additive mask, added to the raw
                (batch, heads, query, key) scores before the softmax; it must
                broadcast to that shape.
        """
        queries, keys, values = (
            self.transpose_for_scores(projection(hidden_states))
            for projection in (self.query, self.key, self.value)
        )

        # Scale by sqrt(head_size) so the dot products do not grow with the head width.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            scores = scores + attention_mask

        weights = self.dropout(torch.softmax(scores, dim=-1))

        # Back to (batch, seq, heads, head_size), then merge the head axes.
        merged = torch.matmul(weights, values).permute(0, 2, 1, 3).contiguous()
        return merged.view(*merged.shape[:-2], self.all_head_size)

# Putting it all together for a complete attention block
class BertAttention(nn.Module):
    """Self-attention sub-layer: attention, output projection, dropout, residual add, LayerNorm.

    Args:
        config: object exposing ``hidden_size``, ``hidden_dropout_prob`` and
            whatever ``BertSelfAttention`` reads.
    """

    def __init__(self, config):
        super().__init__()
        # Attribute names (``self``, ``output``, ``LayerNorm``) form the state_dict keys.
        self.self = BertSelfAttention(config)
        self.output = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask=None):
        """Return ``LayerNorm(hidden_states + dropout(output(attention(hidden_states))))``."""
        attended = self.self(hidden_states, attention_mask)
        branch = self.dropout(self.output(attended))
        # Post-norm residual: the skip connection is added before normalising.
        return self.LayerNorm(hidden_states + branch)

# The feed-forward part of the transformer
class BertIntermediate(nn.Module):
    """First half of the feed-forward sub-layer: widen to ``intermediate_size`` and apply GELU.

    Args:
        config: object exposing ``hidden_size`` and ``intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # nn.GELU with its default (exact, erf-based) formulation.
        self.intermediate_act_fn = nn.GELU()

    def forward(self, hidden_states):
        """Project ``hidden_states`` up to ``intermediate_size`` features, then activate."""
        widened = self.dense(hidden_states)
        return self.intermediate_act_fn(widened)

# Final output layer with residual connection
class BertOutput(nn.Module):
    """Second half of the feed-forward sub-layer: project back down, dropout, residual add, LayerNorm.

    Args:
        config: object exposing ``intermediate_size``, ``hidden_size`` and
            ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``.

        Args:
            hidden_states: the widened feed-forward activations.
            input_tensor: the sub-layer's input, used as the residual.
        """
        branch = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(branch + input_tensor)

# A complete transformer layer
class BertLayer(nn.Module):
    """One transformer layer: a self-attention sub-layer followed by a feed-forward sub-layer.

    Args:
        config: configuration object forwarded to each sub-module.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        """Apply attention, then the feed-forward block with the attention result as its residual."""
        attended = self.attention(hidden_states, attention_mask)
        return self.output(self.intermediate(attended), attended)

# Stack up all our layers
class BertEncoder(nn.Module):
    """A stack of ``config.num_hidden_layers`` separately-parameterised ``BertLayer`` blocks.

    Args:
        config: configuration object; ``num_hidden_layers`` sets the depth.
    """

    def __init__(self, config):
        super().__init__()
        layers = [BertLayer(config) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(layers)

    def forward(self, hidden_states, attention_mask=None):
        """Feed ``hidden_states`` through every layer in order, sharing one ``attention_mask``."""
        output = hidden_states
        for transformer_layer in self.layer:
            output = transformer_layer(output, attention_mask)
        return output

# The full BERT model!
class BertModel(nn.Module):
    """BERT encoder: token embeddings followed by a stack of transformer layers.

    Args:
        config: a ``BertConfig`` (or any object exposing the same attributes).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

    @staticmethod
    def _extend_attention_mask(attention_mask, dtype):
        """Build the additive mask that the self-attention layers add to their scores.

        A 2-D ``(batch, seq)`` keep/pad mask (1/True = attend, 0/False = padding)
        is reshaped to ``(batch, 1, 1, seq)`` so it broadcasts over heads and
        query positions. It is then mapped to 0 for kept tokens and -10000 for
        padding. Masks of any other rank are taken to be additive already and
        are only cast to ``dtype``.
        """
        if attention_mask.dim() != 2:
            return attention_mask.to(dtype)
        keep = attention_mask[:, None, None, :].to(dtype)
        return (1.0 - keep) * -10000.0

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: integer token ids, assumed (batch, seq). When no
                ``attention_mask`` is given, token id 0 is treated as padding.
            token_type_ids: optional segment ids shaped like ``input_ids``;
                all zeros when omitted.
            attention_mask: optional mask. A 2-D (batch, seq) mask uses
                1 = attend / 0 = padding and is converted to an additive mask.
                Any other rank is treated as an additive mask that must
                broadcast against the (batch, heads, query, key) scores.

        Returns:
            The final hidden states, shape (batch, seq, hidden_size).
        """
        embeddings = self.embeddings(input_ids, token_type_ids)
        if attention_mask is None:
            attention_mask = input_ids != 0
        # Match the activations' dtype: a float32 mask added to half-precision
        # scores would promote them to float32, which then mismatches the value
        # tensor's dtype in the attention matmul.
        extended_mask = self._extend_attention_mask(attention_mask, embeddings.dtype)
        return self.encoder(embeddings, extended_mask)

# Let's test it out
if __name__ == "__main__":
    # Smoke test on a deliberately tiny model: 2 layers, 2 heads, width 128.
    config = BertConfig(
        vocab_size=30522,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=512,
    )
    model = BertModel(config)

    # Two sequences of length 6; the first is right-padded with token id 0.
    input_ids = torch.tensor([
        [101, 7592, 2023, 102, 0, 0],
        [101, 2023, 2003, 1037, 3978, 102],
    ])
    # Segment ids: the second sequence switches to segment 1 for its last two tokens.
    token_type_ids = torch.tensor([
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1],
    ])

    output = model(input_ids, token_type_ids)
    # Expect (batch=2, seq=6, hidden=128).
    print(f"Got output shape: {output.shape}")

Got output shape: torch.Size([2, 6, 128])
