In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Number of trailing tokens (along dim 1) that `Transformer.generate` keeps
# when it truncates its input before each forward pass.
context_size = 512

In [None]:
### 1-) MAIN CLASS DEFINITIONS ###
class Transformer(nn.Module):
    """Thin wrapper that owns a ``TransformerEncoder`` and delegates to it.

    All constructor arguments are forwarded unchanged, in the same order, to
    ``TransformerEncoder``.
    """

    def __init__(
        self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout=0.1
    ):
        super(Transformer, self).__init__()
        self.encoder = TransformerEncoder(
            vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout
        )

    def forward(self, x, mask=None):
        """Return ``self.encoder(x, mask)`` with no further processing."""
        return self.encoder(x, mask)

    def generate(self, input, max_length=500):
        """Run ``max_length`` forward passes over the truncated ``input``.

        NOTE(review): this method is unfinished. The result of each forward
        pass is assigned to ``logits`` and then discarded, nothing is appended
        to ``input``, and the method implicitly returns ``None``.

        Args:
            input: Tensor indexed as ``[:, positions]``; only the last
                ``context_size`` entries along dim 1 are kept each iteration.
            max_length: Number of loop iterations.
        """
        for _ in range(max_length):
            # `context_size` is the module-level global, not the `max_len`
            # passed to the constructor. Rebinding `input` here also drops any
            # older tokens for good.
            input = input[:, -context_size:]  # restrict input to the context size
            logits = self.forward(input)
            # get the next token prediction (same as the earlier generate function)
            # TODO(review): the referenced "earlier generate function" is not in
            # this file. `logits` is the raw encoder output; no projection to
            # `vocab_size` is visible in this class -- confirm one exists before
            # sampling from it.


### 2-) TRANSFORMER ENCODER CLASS DEFINITIONS ###
class TransformerEncoder(nn.Module):
    """Embed token ids, add positional information, and run a stack of blocks.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding width, shared by every block.
        n_layers: Number of ``TransformerBlock`` instances in the stack.
        n_heads: Passed through to each ``TransformerBlock``.
        d_ff: Passed through to each ``TransformerBlock``.
        max_len: Passed through to ``PositionalEncoding``.
        dropout: Dropout probability for the embedding dropout and each block.
    """

    def __init__(
        self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout=0.1
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        # Blocks are created one at a time, in order, so parameter
        # initialisation consumes the RNG in the same sequence as before.
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(TransformerBlock(d_model, n_heads, d_ff, dropout))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Return the output of the last block for token ids ``x``.

        ``mask`` is handed unchanged to every block.
        """
        # Embedding -> positional encoding -> dropout.
        hidden = self.dropout(self.pos_encoding(self.embedding(x)))

        # Feed the running activation through each block in turn.
        for block in self.layers:
            hidden = block(hidden, mask)

        return hidden

### 3-) POSITIONAL ENCODING CLASS DEFINITIONS ###
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    The table is computed once for ``max_len`` positions and stored in the
    ``pe`` buffer with shape ``(max_len, 1, d_model)``. Even feature columns
    hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the last (feature) dimension of the inputs.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the angles to match instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Buffer layout stays (max_len, 1, d_model) so existing state_dicts
        # keep loading.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of each sequence position.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        # Slice by sequence length (dim 1), not batch size (dim 0), then move
        # the singleton axis to the front so it broadcasts over the batch.
        return x + self.pe[:seq_len].transpose(0, 1)
    

### 4-) TRANSFORMER BLOCKS CLASS DEFINITIONS ###
class TransformerBlock(nn.Module):
    """One attention + feed-forward layer with residuals and post-LayerNorm.

    Args:
        d_model: Feature width of the input and output.
        n_heads: Passed through to ``MultiheadAttention``.
        d_ff: Passed through to ``FeedForward``.
        dropout: Dropout probability, used by both sub-layers and on each
            sub-layer output before the residual add.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiheadAttention(d_model, n_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention then the feed-forward network to ``x``.

        ``mask`` is forwarded unchanged to the attention module.
        """
        # Sub-layer 1: self-attention (query = key = value = x), dropout,
        # residual add, LayerNorm.
        attended = self.dropout(self.attention(x, x, x, mask))
        x = self.norm1(x + attended)

        # Sub-layer 2: feed-forward network, dropout, residual add, LayerNorm.
        transformed = self.dropout(self.ffn(x))
        return self.norm2(x + transformed)
    

### 5-) MULTIHEAD ATTENTION CLASS DEFINITIONS ###
class MultiheadAttention(nn.Module):
    """Batch-first multi-head scaled dot-product attention.

    Args:
        d_model: Feature width of the inputs and of the output.
        n_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MultiheadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Define the linear layers for Q, K, V transformations
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

        # Final linear transformation for output
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, batch_size):
        """Reshape ``(batch, len, d_model)`` to ``(batch, n_heads, len, head_dim)``."""
        return x.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape ``(batch, q_len, d_model)``.
            key: Tensor of shape ``(batch, k_len, d_model)``.
            value: Tensor of shape ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_heads, q_len, k_len)``. Positions where it equals
                0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = query.size(0)

        # Linear projections, then split the feature axis into heads.
        Q = self._split_heads(self.q_proj(query), batch_size)
        K = self._split_heads(self.k_proj(key), batch_size)
        V = self._split_heads(self.v_proj(value), batch_size)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
        if mask is not None:
            # Use the dtype's most negative finite value: a literal -1e9 is
            # not representable in float16, while -inf would turn a fully
            # masked row into NaNs after the softmax.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        # The dropout module was constructed but never applied before, so the
        # `dropout` argument had no effect. It is a no-op in eval mode.
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, V)

        # Concatenate heads and apply final linear transformation
        attn_output = (
            attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        )
        return self.out_proj(attn_output)

### 6-) FEED FORWARD CLASS DEFINITIONS ###
class FeedForward(nn.Module):
    """Two-layer MLP over the last dimension: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature width.
        d_ff: Hidden feature width between the two linear layers.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Expand to the hidden width, then project back to the model width.
        self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.ReLU()  # or nn.GELU()

    def forward(self, x):
        """Return ``fc2(dropout(activation(fc1(x))))``."""
        hidden = self.dropout(self.activation(self.fc1(x)))
        return self.fc2(hidden)