In [None]:
class TextEncoder(nn.Module):
    """Encode tokenized text into one fixed-size vector per sequence.

    Runs the pretrained ``bert-base-uncased`` model, mean-pools its last
    hidden states over the non-padding tokens, and projects the pooled
    vector to ``out_dim`` with a linear layer followed by layer norm.

    Args:
        out_dim: Size of the embedding returned by ``forward``.
    """

    def __init__(self, out_dim):
        super().__init__()
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.projection = nn.Sequential(
            nn.Linear(768, out_dim),
            nn.LayerNorm(out_dim)
        )

    def forward(self, input_ids, attention_mask):
        """Return masked-mean-pooled, projected embeddings.

        Args:
            input_ids: Token ids, shape (batch_size, seq_length).
            attention_mask: 1 for real tokens and 0 for padding, shape
                (batch_size, seq_length).

        Returns:
            Tensor of shape (batch_size, out_dim). A sequence whose mask is
            entirely zero pools to a zero vector before projection instead
            of producing NaN.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # (batch_size, seq_length, hidden)

        # Keep the mask as (batch_size, seq_length, 1) and let broadcasting
        # apply it across the hidden dimension; cast so the arithmetic stays
        # in the hidden-state dtype.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)

        # Sum of the non-padding token vectors: (batch_size, hidden).
        summed_hidden_states = (hidden_states * mask).sum(dim=1)

        # Number of non-padding tokens per sequence: (batch_size, 1).
        # No keepdim here: a (batch, 1, hidden) count would broadcast against
        # the (batch, hidden) sum into a (batch, batch, hidden) tensor.
        # Clamp to 1 so an all-padding row divides by 1 rather than 0.
        token_counts = mask.sum(dim=1).clamp(min=1)

        avg_pooling = summed_hidden_states / token_counts  # (batch_size, hidden)

        # Apply the projection layer (linear + layernorm).
        return self.projection(avg_pooling)  # (batch_size, out_dim)