# In [1]:
import torch
import torch.nn as nn

# ============================
# Simple Transformer Encoder Block
# ============================

class SimpleEncoderBlock(nn.Module):
    """A single post-norm Transformer encoder layer.

    Multi-head self-attention is followed by a position-wise feed-forward
    network. Each sub-layer is wrapped in a residual connection and followed
    by LayerNorm, i.e. ``x = norm(x + sublayer(x))``.

    Tensors are laid out as ``(batch, seq_len, d_model)`` because the
    attention layer is constructed with ``batch_first=True``.

    Args:
        d_model: Embedding size of each token. Must be divisible by
            ``num_heads`` (``nn.MultiheadAttention`` enforces this).
        num_heads: Number of attention heads.
        ff_hidden: Hidden width of the feed-forward network.
        dropout: Dropout probability applied to the attention weights and
            to each sub-layer output before its residual add. The default
            of 0.0 disables dropout entirely.
    """

    def __init__(self, d_model=128, num_heads=8, ff_hidden=512, dropout=0.0):
        super().__init__()

        # Multi-head self-attention; batch_first=True -> (batch, seq, d_model).
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Position-wise feed-forward network. The layer order is left
        # untouched so state_dict keys (ffn.0.*, ffn.2.*) stay compatible
        # with checkpoints saved from earlier versions of this block.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_hidden),
            nn.ReLU(),
            nn.Linear(ff_hidden, d_model),
        )

        # LayerNorm applied after each residual add (post-norm layout).
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Residual-branch dropout. nn.Dropout has no parameters, so adding
        # these does not change the state_dict.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Run the encoder block.

        Args:
            x: Input tensor of shape ``(batch, seq_len, d_model)``.
            attn_mask: Optional mask forwarded to ``nn.MultiheadAttention``
                (e.g. a ``(seq_len, seq_len)`` bool causal mask where
                ``True`` blocks attention).
            key_padding_mask: Optional ``(batch, seq_len)`` mask forwarded
                to ``nn.MultiheadAttention``. ``True`` marks padded keys
                that must be ignored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # ---- Multi-head Self Attention + Residual ----
        # need_weights=False: the attention map is never used, so skip
        # computing/averaging it.
        attn_output, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + self.dropout1(attn_output))

        # ---- Feed Forward + Residual ----
        ff_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ff_output))

        return x


# ============================
# Test the Encoder Block
# ============================

# PART (a): Initialize d_model = 128, h = 8
# Sizes used by this smoke test, named once instead of repeated inline.
D_MODEL = 128
NUM_HEADS = 8
BATCH_SIZE = 32
SEQ_LEN = 10

# Build the block first: its weight initialisation draws from the RNG
# before the random input below does.
encoder = SimpleEncoderBlock(d_model=D_MODEL, num_heads=NUM_HEADS)

# Random input laid out as (batch_size, seq_length, embedding_size).
batch = torch.randn(BATCH_SIZE, SEQ_LEN, D_MODEL)

# One forward pass; the printed shapes should match.
output = encoder(batch)

print("Input shape :", batch.shape)
print("Output shape:", output.shape)


# Output:
# Input shape : torch.Size([32, 10, 128])
# Output shape: torch.Size([32, 10, 128])
