In [1]:
import math
import torch


class SimpleSelfAttentionModel(torch.nn.Module):
    """Single causal self-attention layer with sinusoidal positional encodings.

    The forward pass adds fixed sine/cosine positional encodings to the input,
    runs masked (causal) multi-head self-attention, and then applies a
    residual connection followed by LayerNorm (post-norm).

    Args:
        embed_size: Feature dimension of the input and output.
        dropout: Dropout probability, used both on the attention weights
            (inside ``MultiheadAttention``) and on the attention output
            before the residual add.
        max_len: Longest sequence length supported by the precomputed
            positional-encoding table. Defaults to 5000.
    """

    def __init__(self, embed_size, dropout, max_len=5000):
        super().__init__()

        # Pick the largest head count in 2..8 that divides embed_size, because
        # MultiheadAttention requires embed_dim % num_heads == 0. If no such
        # divisor exists (e.g. embed_size=11), fall back to a single head
        # instead of letting max() raise on an empty sequence.
        self.num_heads = max(
            (i for i in range(2, 9) if embed_size % i == 0), default=1
        )

        self.attention = torch.nn.MultiheadAttention(
            embed_dim=embed_size,
            num_heads=self.num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm = torch.nn.LayerNorm(embed_size)
        self.dropout = torch.nn.Dropout(dropout)

        # Registered as a buffer so that it follows the module through
        # .to()/.half() and is saved in the state_dict, without being a
        # trainable parameter.
        self.register_buffer(
            "positional_encodings",
            self.generate_positional_encodings(max_len=max_len, embed_size=embed_size),
        )

    def generate_positional_encodings(self, max_len, embed_size):
        """Return a ``(max_len, embed_size)`` table of sinusoidal encodings.

        Even columns hold ``sin(pos / 10000**(2k / embed_size))`` and odd
        columns hold the matching cosine. An odd ``embed_size`` is supported;
        the last even column then has no cosine partner.
        """
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # When embed_size is odd there is one fewer odd column than there are
        # frequencies, so trim div_term to make the shapes line up.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        return pe

    def generate_square_subsequent_mask(self, sz, device=None, dtype=None):
        """Return an ``(sz, sz)`` additive causal attention mask.

        Entries strictly above the diagonal are ``-inf``, so position ``i``
        cannot attend to any ``j > i``. All other entries are ``0``. The
        ``device`` and ``dtype`` arguments default to torch's defaults.
        """
        mask = torch.full((sz, sz), float("-inf"), device=device, dtype=dtype)
        return torch.triu(mask, diagonal=1)

    def forward(self, x):
        """Apply positional encoding and causal self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_size)``, or
                ``(seq_len, embed_size)`` for a single unbatched sequence.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional-encoding table
                length (``max_len``).
        """
        seq_len = x.size(-2)
        max_len = self.positional_encodings.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={max_len} "
                "of the positional-encoding table"
            )

        # The (seq_len, embed_size) slice broadcasts over any leading batch
        # dimension, so no explicit repeat is needed.
        x = x + self.positional_encodings[:seq_len].to(x.device)

        # Build the mask directly on the input's device and dtype. This avoids
        # a CPU->device copy on every call and a mixed-dtype mask under
        # reduced precision.
        causal_mask = self.generate_square_subsequent_mask(
            seq_len, device=x.device, dtype=x.dtype
        )
        attn_output, _ = self.attention(x, x, x, attn_mask=causal_mask)
        return self.norm(x + self.dropout(attn_output))


# Demo: push one random batch through the model and confirm that the output
# keeps the (batch, seq_len, embed_size) layout of the input.
embed_size, dropout = 256, 0.1

model = SimpleSelfAttentionModel(embed_size, dropout)

# The attention layer uses batch_first=True, so the layout is
# (batch=32, seq_len=10, embed_size).
x = torch.rand(32, 10, embed_size)
output = model(x)

print(output.shape)  # same shape as x: (batch_size, seq_len, embed_size)

torch.Size([32, 10, 256])
