In [1]:
# Import libraries

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Define the model and the post-processing (mean pooling) that produces fixed-size sentence embeddings
class SentenceTransformerModel(nn.Module):
    """
    A custom PyTorch module that turns raw sentences into fixed-length embeddings.

    It wraps a pre-trained Hugging Face transformer (default: all-MiniLM-L6-v2)
    together with its tokenizer, and reduces the per-token hidden states to one
    vector per sentence with attention-mask-aware mean pooling.
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        """
        Load the pre-trained transformer and its matching tokenizer.

        Args:
            model_name (str): Model name or path passed to ``from_pretrained``.
                              Default is 'sentence-transformers/all-MiniLM-L6-v2'.
        """
        super().__init__()

        # Encoder and tokenizer are loaded from the same checkpoint so they agree on the vocabulary
        self.transformer = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def mean_pooling(self, model_output, attention_mask):
        """
        Average the token embeddings of each sentence, ignoring masked-out positions.

        Args:
            model_output (tuple): Transformer output; element 0 holds the token embeddings.
            attention_mask (tensor): Mask with one entry per token; positions with
                                     value 0 (padding) contribute nothing to the average.

        Returns:
            tensor: One pooled embedding per sentence.
        """
        token_embeddings = model_output[0]  # First element is the token embeddings (hidden states)

        # Keep the mask on the same device as the embeddings so the product below cannot
        # fail with a device mismatch when the caller built the mask elsewhere
        attention_mask = attention_mask.to(token_embeddings.device)

        # Broadcast the mask over the embedding dimension so it can weight every feature
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        # Sum only the embeddings of unmasked tokens along the sequence dimension
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

        # Clamp the token count so a fully masked row divides by 1e-9 instead of 0
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        return sum_embeddings / sum_mask

    def forward(self, sentences):
        """
        Encode input sentences into fixed-length sentence embeddings.

        Args:
            sentences (list of str): Sentences to be encoded into embeddings.

        Returns:
            tensor: A tensor containing one embedding per input sentence, located on
                    the same device as the transformer's parameters.
        """
        # Tokenize the input sentences, padding and truncating as necessary, and return PyTorch tensors
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

        # The tokenizer always produces CPU tensors; move them to wherever the encoder
        # lives so the module keeps working after `.to('cuda')` (or any other device)
        first_param = next(self.transformer.parameters(), None)
        if first_param is not None:
            target_device = first_param.device
            encoded_input = {
                key: value.to(target_device) if isinstance(value, torch.Tensor) else value
                for key, value in encoded_input.items()
            }

        # Token-level embeddings from the transformer
        model_output = self.transformer(**encoded_input)

        # Collapse the token embeddings into a single vector per sentence
        return self.mean_pooling(model_output, encoded_input['attention_mask'])


In [4]:
# Run inference on a few test sentences
sentences = ['Hello World!', 'My name is Suraj3##$', 'Today is a wonderful day, but not great weather though!']
# Build the model and put it in evaluation mode for inference
model = SentenceTransformerModel()
model.eval()
# Call the module itself rather than .forward() so nn.Module's call machinery (hooks) runs;
# no_grad() skips autograd bookkeeping, which inference does not need
with torch.no_grad():
    embed = model(sentences)
print(f'Final answer: {embed}')

Final answer: tensor([[-0.1096,  0.1359, -0.0030,  ..., -0.1648,  0.2019,  0.2012],
        [-0.4113, -0.3770,  0.0335,  ..., -0.4212, -0.3258, -0.1551],
        [-0.0059,  0.4075,  0.7919,  ..., -0.1142, -0.4333,  0.3434]],
       grad_fn=<DivBackward0>)
