## Transformers Architecture

### Imports

In [80]:
import math
import torch
import torch.nn as nn
from torchinfo import summary
import torch.nn.functional as F

### Positional Encoding and Token Embedding

In [81]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to a batch of embeddings.

    The table has shape ``(1, context_len, embed_dim)``: even feature columns
    hold ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    The table is stored as a non-persistent buffer, so it follows the module
    through ``.to()`` / ``.cuda()`` / ``.double()`` while staying out of
    ``state_dict()``.

    Args:
        context_len: Number of positions to precompute.
        embed_dim: Size of the feature dimension (even or odd).
    """

    def __init__(self, context_len, embed_dim):
        super().__init__()

        # Column vector of positions, shape (context_len, 1), so that it
        # broadcasts against the row of frequencies below.
        position = torch.arange(0, context_len, dtype=torch.float).unsqueeze(1)

        # One frequency per even feature index: exp(-2i * ln(10000) / embed_dim).
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -math.log(10000.0) / embed_dim)

        pe = torch.zeros(context_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_dim there is one fewer odd column than even
        # columns, so only the first embed_dim // 2 frequencies are used.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])

        # Leading singleton dimension broadcasts over the batch in forward().
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as (batch, sequence, feature); the table is sliced
        along the sequence dimension to match.
        """
        return x + self.pe[:, : x.size(1), :]

class TokenEmbedding(nn.Module):
    """Learned lookup table that turns integer token ids into vectors.

    Args:
        embed_dim: Length of each embedding vector.
        vocab_size: Number of rows in the lookup table.
    """

    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        self.embed_dim, self.vocab_size = embed_dim, vocab_size

        self.embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.embed_dim,
        )

    def forward(self, x):
        """Look up the embedding row for every id in ``x``."""
        token_vectors = self.embedding(x)
        return token_vectors

### MultiHead Attention

In [82]:
class MultiHeadAttention(nn.Module):
    """Masked multi-head self-attention.

    The input is projected to keys, queries and values, split into
    ``n_heads`` heads of size ``embed_dim // n_heads``, and combined with
    scaled dot-product attention. A lower-triangular mask sets every score
    for a later position to ``-inf`` before the softmax, so a position only
    attends to itself and earlier positions.

    Args:
        embed_dim: Feature size of the input and output.
        context_len: Side length of the precomputed mask, i.e. the longest
            sequence ``forward`` accepts.
        n_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, embed_dim, context_len, n_heads):
        super().__init__()

        if embed_dim % n_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.embed_dim = embed_dim
        self.context_len = context_len
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        self.K = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.Q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.V = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

        # Registered as a buffer so the mask moves with the module across
        # devices; non-persistent so state_dict() keeps the same keys as when
        # the mask was a plain attribute.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(self.context_len, self.context_len)),
            persistent=False,
        )

        self.out = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def forward(self, x):
        """Apply masked self-attention to ``x`` of shape (B, T, C).

        Returns a tensor of the same shape.

        Raises:
            ValueError: If ``T`` is larger than ``context_len``.
        """
        B, T, C = x.shape
        if T > self.context_len:
            raise ValueError(
                f"sequence length ({T}) exceeds context_len ({self.context_len})"
            )

        K = self.K(x)
        Q = self.Q(x)
        V = self.V(x)

        # (B, T, C) -> (B, n_heads, T, head_dim): each head attends independently.
        K = K.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        Q = Q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        attn_scores = Q @ K.transpose(-2, -1)

        scaled_scores = attn_scores / math.sqrt(self.head_dim)

        # Zeros in the lower-triangular mask mark later positions; -inf there
        # becomes a weight of exactly 0 after the softmax.
        masked_scores = scaled_scores.masked_fill(self.mask[:T, :T] == 0, float('-inf'))

        attention_weights = F.softmax(masked_scores, dim=-1)

        context_vectors = attention_weights @ V

        # (B, n_heads, T, head_dim) -> (B, T, C): concatenate the heads again.
        output = context_vectors.transpose(1, 2).contiguous().view(B, T, C)

        return self.out(output)


### Decoder class

In [83]:
class TransformerDecoder(nn.Module):
    """One decoder block: self-attention then a feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``,
    i.e. the normalisation is applied after the residual addition.

    Args:
        embed_dim: Feature size of the input and output.
        context_len: Passed through to ``MultiHeadAttention``.
        n_heads: Passed through to ``MultiHeadAttention``.
        dropout: Drop probability used on both sub-layer outputs. Defaults
            to 0.5, the value that used to be hard-coded.
    """

    def __init__(self, embed_dim, context_len, n_heads, dropout=0.5):
        super().__init__()
        self.embed_dim = embed_dim
        self.context_len = context_len
        self.n_heads = n_heads

        self.attn = MultiHeadAttention(self.embed_dim, self.context_len, self.n_heads)

        self.layer_norm1 = nn.LayerNorm(self.embed_dim)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim)

        # Feed-forward network widens to 4 * embed_dim and projects back.
        self.feedforward = nn.Sequential(
            nn.Linear(self.embed_dim, 4 * self.embed_dim),
            nn.GELU(),
            nn.Linear(4 * self.embed_dim, self.embed_dim)
        )

        # A single Dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        attn_output = self.attn(x)

        x = self.layer_norm1(x + self.dropout(attn_output))

        ff_out = self.feedforward(x)

        x = self.layer_norm2(x + self.dropout(ff_out))

        return x

### Model

In [84]:
class DecoderOnlyModel(nn.Module):
    """Token embedding, positional encoding, decoder blocks and a vocab projection.

    Args:
        embed_dim: Feature size used throughout the model.
        vocab_size: Number of token ids; also the size of the output logits.
        context_len: Passed to the positional encoding and decoder blocks.
        n_heads: Attention heads per decoder block.
        n_layers: Number of stacked decoder blocks.
    """

    def __init__(self, embed_dim, vocab_size, context_len, n_heads, n_layers):
        super().__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.context_len = context_len
        self.n_heads = n_heads
        self.n_layers = n_layers

        self.embeddings = TokenEmbedding(self.embed_dim, self.vocab_size)
        self.positional_encoding = PositionalEncoding(self.context_len, self.embed_dim)

        self.decoder_blocks = nn.Sequential(*[TransformerDecoder(self.embed_dim, self.context_len, self.n_heads) for _ in range(self.n_layers)])

        self.linear = nn.Linear(self.embed_dim, self.vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to per-position logits over the vocabulary."""
        embed = self.embeddings(x)
        pos = self.positional_encoding(embed)

        # Kept in a local variable: storing the activation on ``self`` would
        # hold the tensor (and its autograd graph) alive between calls.
        hidden = self.decoder_blocks(pos)

        out = self.linear(hidden)

        return out


### Testing the model

In [85]:
# Seed the RNG so the random weights and the random input are reproducible.
torch.manual_seed(0)

embed_dim = 8
vocab_size = 4
context_len = 8  # the model is built for sequences up to this length; the input below uses all of it
n_heads = 2
n_layers = 1

# Create input data with sequence length = context_len
input_indices = torch.randint(0, vocab_size, (1, context_len))
print(f"Input shape: {input_indices.shape}\n")

# Create a model built to handle sequences up to length context_len
model = DecoderOnlyModel(embed_dim, vocab_size, context_len, n_heads, n_layers)
print(f"Model parameters:\n{model}\n")  # printing the module shows its layer structure

# eval() switches dropout off so the logits depend only on the weights and
# the input; no_grad() skips building an autograd graph for this check.
model.eval()
with torch.no_grad():
    logits = model(input_indices)
print(f"Output logits shape: {logits.shape}")
print(f"Output logits:\n{logits}")

Input shape: torch.Size([1, 8])

Model parameters:
DecoderOnlyModel(
  (embeddings): TokenEmbedding(
    (embedding): Embedding(4, 8)
  )
  (positional_encoding): PositionalEncoding()
  (decoder_blocks): Sequential(
    (0): TransformerDecoder(
      (attn): MultiHeadAttention(
        (K): Linear(in_features=8, out_features=8, bias=False)
        (Q): Linear(in_features=8, out_features=8, bias=False)
        (V): Linear(in_features=8, out_features=8, bias=False)
        (out): Linear(in_features=8, out_features=8, bias=False)
      )
      (layer_norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (layer_norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (feedforward): Sequential(
        (0): Linear(in_features=8, out_features=32, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=32, out_features=8, bias=True)
      )
      (dropout): Dropout(p=0.5, inplace=False)
    )
  )
  (linear): Linear(in_features=8, out_features=4, b

In [106]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchinfo import summary # Import torchinfo

# --- CORRECTED CUSTOM MODULES ---

class TokenEmbedding(nn.Module):
    """Learned lookup table from integer token ids to ``embed_dim`` vectors."""

    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        # One row per vocabulary entry, one column per embedding feature.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

    def forward(self, x):
        """Return the embedding row for every id in ``x``."""
        token_vectors = self.embedding(x)
        return token_vectors

class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to a batch of embeddings.

    The table is registered as the buffer ``pe`` with shape
    ``(1, context_len, embed_dim)``: even feature columns hold
    ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    Args:
        context_len: Number of positions to precompute.
        embed_dim: Size of the feature dimension (even or odd).
    """

    def __init__(self, context_len, embed_dim):
        super().__init__()
        # Ensure the order is (context_len, embed_dim)
        pe = torch.zeros(context_len, embed_dim)
        # Column vector of positions so it broadcasts against the frequencies.
        position = torch.arange(0, context_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even feature index.
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -math.log(10000.0) / embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_dim there is one fewer odd column than even
        # columns; slicing the frequencies avoids a shape-mismatch error.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        # Leading singleton dimension broadcasts over the batch in forward().
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        x = x + self.pe[:, :x.size(1), :]
        return x

class MultiHeadAttention(nn.Module):
    """Masked multi-head self-attention with bias-free projections.

    A lower-triangular ``mask`` buffer of side ``context_len`` sets the
    scores of later positions to ``-inf`` before the softmax, so each
    position only attends to itself and earlier positions.
    """

    def __init__(self, embed_dim, context_len, n_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.context_len = context_len
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        # Projections are created in the order K, Q, V, out.
        self.K = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.V = nn.Linear(embed_dim, embed_dim, bias=False)

        lower_triangle = torch.ones(context_len, context_len).tril()
        self.register_buffer('mask', lower_triangle)

        self.out = nn.Linear(embed_dim, embed_dim, bias=False)

    def _split_heads(self, projected, batch, seq_len):
        """Reshape (batch, seq, embed) into (batch, heads, seq, head_dim)."""
        per_head = projected.view(batch, seq_len, self.n_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply masked self-attention to ``x`` of shape (B, T, C)."""
        batch, seq_len, channels = x.shape

        keys = self._split_heads(self.K(x), batch, seq_len)
        queries = self._split_heads(self.Q(x), batch, seq_len)
        values = self._split_heads(self.V(x), batch, seq_len)

        # Scaled dot-product scores, shape (batch, heads, seq, seq).
        scale = self.head_dim**-0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

        # Zeros in the mask mark the later positions to hide.
        hidden_positions = self.mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden_positions, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # Merge the heads back into a single feature dimension.
        merged = torch.matmul(weights, values).transpose(1, 2).contiguous()
        merged = merged.view(batch, seq_len, channels)
        return self.out(merged)

class TransformerDecoder(nn.Module):
    """One decoder block: self-attention followed by a feed-forward network.

    Both sub-layers follow the pattern ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, embed_dim, context_len, n_heads, dropout=0.1):
        super().__init__()
        hidden_dim = 4 * embed_dim  # width of the feed-forward hidden layer

        self.attention = MultiHeadAttention(embed_dim, context_len, n_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        # The same Dropout module is applied to both sub-layer outputs.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        attended = self.norm1(x + self.dropout(self.attention(x)))
        return self.norm2(attended + self.dropout(self.feed_forward(attended)))

class DecoderOnlyModel(nn.Module):
    """Token embedding, positional encoding, decoder blocks and an LM head.

    ``forward`` maps integer token ids to per-position logits of size
    ``vocab_size``.
    """

    def __init__(self, embed_dim, vocab_size, context_len, n_heads, n_layers):
        super().__init__()
        self.token_embedding = TokenEmbedding(embed_dim, vocab_size)
        self.positional_encoding = PositionalEncoding(context_len, embed_dim)

        blocks = [
            TransformerDecoder(embed_dim, context_len, n_heads)
            for _ in range(n_layers)
        ]
        self.decoder_blocks = nn.Sequential(*blocks)

        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return vocabulary logits for every position of the id tensor ``x``."""
        hidden = self.positional_encoding(self.token_embedding(x))
        hidden = self.decoder_blocks(hidden)
        return self.lm_head(hidden)


# --- TEST SCRIPT ---

# Model hyperparameters for the summary run.
embed_dim, vocab_size, context_len = 512, 100_000, 1024
n_heads, n_layers = 16, 8

model = DecoderOnlyModel(
    embed_dim=embed_dim,
    vocab_size=vocab_size,
    context_len=context_len,
    n_heads=n_heads,
    n_layers=n_layers,
)

# torchinfo is given the shape of a full-length batch of token ids;
# dtypes=[torch.long] makes that input integer-typed for the embedding lookup.
batch_size = 6
input_shape = (batch_size, context_len)

# Last expression of the cell, so the summary table is shown as the cell output.
summary(model, input_size=input_shape, dtypes=[torch.long], depth=3)

Layer (type:depth-idx)                   Output Shape              Param #
DecoderOnlyModel                         [6, 1024, 100000]         --
├─TokenEmbedding: 1-1                    [6, 1024, 512]            --
│    └─Embedding: 2-1                    [6, 1024, 512]            51,200,000
├─PositionalEncoding: 1-2                [6, 1024, 512]            --
├─Sequential: 1-3                        [6, 1024, 512]            --
│    └─TransformerDecoder: 2-2           [6, 1024, 512]            --
│    │    └─MultiHeadAttention: 3-1      [6, 1024, 512]            1,048,576
│    │    └─Dropout: 3-2                 [6, 1024, 512]            --
│    │    └─LayerNorm: 3-3               [6, 1024, 512]            1,024
│    │    └─Sequential: 3-4              [6, 1024, 512]            2,099,712
│    │    └─Dropout: 3-5                 [6, 1024, 512]            --
│    │    └─LayerNorm: 3-6               [6, 1024, 512]            1,024
│    └─TransformerDecoder: 2-3           [6, 1024, 512]  