# Sentence Transformer Implementation

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [None]:
# SentenceTransformer: an nn.Module that wraps a pretrained BERT model, takes a list of sentences as input, and returns one fixed-size embedding per sentence

class SentenceTransformer(nn.Module):
    """Encode sentences into fixed-size embeddings using a pretrained BERT model.

    Each sentence is tokenized, passed through BERT, and the last-layer token
    embeddings are mean-pooled over the real (non-padding) tokens.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained`` for
            both the tokenizer and the model.
    """

    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name)  # Splits sentences into BERT token ids
        self.bert = BertModel.from_pretrained(model_name)  # Produces contextualized token embeddings

    def forward(self, sentences):
        """Return one embedding per input sentence.

        Args:
            sentences: A sentence or list of sentences accepted by the tokenizer.

        Returns:
            A tensor with one row per sentence and one column per BERT hidden
            unit (e.g. ``[3, 768]`` for three sentences with bert-base).
        """
        encoded = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

        # The tokenizer always builds CPU tensors; move them next to the model
        # weights so the module keeps working after `.to(device)`.
        device = next(self.bert.parameters()).device
        encoded = {key: value.to(device) for key, value in encoded.items()}

        outputs = self.bert(
            input_ids=encoded['input_ids'],
            token_type_ids=encoded['token_type_ids'],
            attention_mask=encoded['attention_mask'],
        )
        hidden = outputs.last_hidden_state

        # Mean-pool over real tokens only. A plain `.mean(dim=1)` would also
        # average the vectors at padding positions, so a sentence's embedding
        # would change with the length of the other sentences in the batch.
        mask = encoded['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        token_counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
        embeddings = summed / token_counts
        return embeddings

In [None]:
# Sample sentences of different lengths, used to confirm that every
# sentence still maps to an embedding of the same fixed size.
sentences = [
    'I love Pizza',
    'I went to the zoo and saw a tiger',
    'After all, you are my wonderwall',
]

In [None]:
# Instantiate the model with its default checkpoint ('bert-base-uncased');
# the constructor loads the pretrained tokenizer and BERT weights.
sentence_transformer = SentenceTransformer()



In [None]:
# Pass the list of sentences to the model; calling the module runs its
# forward() method and returns the pooled sentence embeddings.
embeds = sentence_transformer(sentences)

In [None]:
# Preview the embeddings tensor (one row per input sentence)
embeds

tensor([[ 0.2334,  0.1736,  0.1612,  ...,  0.2010,  0.1730, -0.0193],
        [ 0.0689, -0.3667, -0.3938,  ..., -0.4283,  0.0947,  0.0498],
        [ 0.3152,  0.0161,  0.2936,  ..., -0.0361,  0.0472, -0.1615]],
       grad_fn=<MeanBackward1>)

In [None]:
# Check the overall shape of the embeddings, then the shape of each
# sentence's embedding, to confirm they all have the same length.

print('The shape of the embeddings are', embeds.shape)
print('\nChecking the length of embeddings to make sure it is the same for all elements\n')

row_labels = (
    'Length of Sentence one',
    'Length of Sentence two',
    'Length of Sentence three',
)
for row_index, row_label in enumerate(row_labels):
    print(row_label, embeds[row_index].shape)


The shape of the embeddings are torch.Size([3, 768])

Checking the length of embeddings to make sure it is the same for all elements

Length of Sentence one torch.Size([768])
Length of Sentence two torch.Size([768])
Length of Sentence three torch.Size([768])
