In [None]:
import torch
from torch import nn

In [None]:
class TransformerClassifier(nn.Module):
    """A transformer-based classifier for sequence data.

    The model projects each time step of the input to ``d_model`` features,
    runs the sequence through a stack of transformer encoder layers,
    mean-pools the encoded sequence over the time dimension and maps the
    pooled vector to class logits with a linear layer.

    Note:
        No positional encoding is added to the projected inputs, so the
        order of the time steps is not explicitly modelled.

    Args:
        input_dim (int): Dimension of input features
        num_classes (int): Number of output classes
        d_model (int, optional): Dimension of transformer model. Defaults to 512.
        nhead (int, optional): Number of attention heads. Defaults to 8.
        num_encoder_layers (int, optional): Number of transformer encoder layers. Defaults to 3.
        dim_feedforward (int, optional): Dimension of feedforward network. Defaults to 2048.
        dropout (float, optional): Dropout rate. Defaults to 0.1.
    """

    def __init__(
            self,
            input_dim: int,
            num_classes: int,
            d_model: int = 512,
            nhead: int = 8,
            num_encoder_layers: int = 3,
            dim_feedforward: int = 2048,
            dropout: float = 0.1
    ) -> None:
        super().__init__()

        # Input projection layer: input_dim -> d_model for every time step
        self.input_projection = nn.Linear(input_dim, d_model)

        # Transformer encoder; batch_first=True keeps tensors in
        # (batch_size, seq_length, d_model) layout throughout the model.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_encoder_layers
        )

        # Output classifier: pooled d_model vector -> class logits
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_length, input_dim)

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, num_classes)
        """
        # Project input to d_model dimensions
        x = self.input_projection(x)

        # Apply transformer encoder -> (batch_size, seq_length, d_model)
        x = self.transformer_encoder(x)

        # Collapse the sequence dimension so that one prediction is produced
        # per sequence. Without this step the linear layer would be applied
        # per time step and return (batch_size, seq_length, num_classes),
        # contradicting the documented output shape.
        pooled = x.mean(dim=1)

        # Classification layer -> (batch_size, num_classes)
        return self.classifier(pooled)