In [None]:
import torch
import torch.nn as nn

class VanillaTransformer(nn.Module):
    """Standard Transformer encoder over a fixed-length token sequence.

    A learned positional encoding is added to the input tokens, which are then
    passed through ``num_layers`` stacked ``nn.TransformerEncoderLayer`` blocks
    (``batch_first=True``, feed-forward width ``4 * token_dim``).

    Args:
        token_dim: Embedding size D of each token (``d_model``); must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability used inside each encoder layer.
        num_tokens: Sequence length N that the learned positional encoding is
            sized for. Defaults to 9 (the dRoFE token sequence length).
    """

    def __init__(self, token_dim=128, num_heads=4, num_layers=2, dropout=0.1, num_tokens=9):
        super().__init__()

        self.token_dim = token_dim
        self.num_tokens = num_tokens
        # Learned positional encoding, one vector per token position. The
        # leading singleton dimension broadcasts over the batch.
        self.pos_encoding = nn.Parameter(torch.randn(1, num_tokens, token_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=token_dim,
            nhead=num_heads,
            dropout=dropout,
            dim_feedforward=4 * token_dim,
            batch_first=True,  # inputs are [B, N, D], not [N, B, D]
        )

        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """Encode a token sequence.

        Args:
            x: Tensor of shape ``[B, num_tokens, token_dim]`` (e.g. the token
                sequence produced by dRoFE).

        Returns:
            Tensor of the same shape ``[B, num_tokens, token_dim]`` holding the
            contextually encoded tokens.

        Raises:
            ValueError: If ``x`` is not 3-D or its last two dimensions do not
                match ``(num_tokens, token_dim)``. Without this check, an input
                with a single token would silently broadcast against the
                positional encoding and expand to ``num_tokens`` tokens.
        """
        if x.dim() != 3 or x.shape[1] != self.num_tokens or x.shape[2] != self.token_dim:
            raise ValueError(
                f"expected input of shape [B, {self.num_tokens}, {self.token_dim}], "
                f"got {list(x.shape)}"
            )
        x = x + self.pos_encoding  # add learned positional encoding
        return self.transformer(x)  # apply the stacked encoder layers




In [None]:
# Smoke test: run the encoder on random tokens and check the output shape.
B, N, D = 2, 9, 128  # batch size, number of tokens, token dimension
dummy_tokens = torch.randn(B, N, D)  # random stand-in for the dRoFE token sequence

model = VanillaTransformer(token_dim=D, num_heads=4, num_layers=2)
output = model(dummy_tokens)

print("Transformer output shape:", output.shape)
# NOTE: in the real (non-dummy) pipeline, attention should use the rotated Q and K from the dRoFE tokens.


Transformer output shape: torch.Size([2, 9, 128])


What it does:
Processes the enriched tokens produced by dRoFE to learn contextual relationships among them (e.g., how different frequency bands interact given the demographic data).

Adds a learned positional encoding to the tokens, then applies stacked self-attention and feed-forward layers to produce context-aware embeddings.

Input: Enriched tokens from dRoFE ([B, 9, D]).

Output: Contextually enriched tokens of shape [B, 9, D].
