<a href="https://colab.research.google.com/github/sambhavpurohit14/Smart_Gallery/blob/encoder-functions/sentence_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

In [34]:
class TextEncoder(nn.Module):
    """Frozen sentence-transformer backbone with a trainable projection head.

    The pre-trained ``SentenceTransformer`` is used purely as a feature
    extractor (all of its parameters have ``requires_grad = False``); only
    the ``Linear -> LayerNorm`` projection that maps the sentence embedding
    to ``out_dim`` features receives gradients.
    """

    def __init__(self, model_name: str, out_dim: int = 512):
        """Load the backbone and build the projection head.

        Args:
            model_name: Name or path passed to ``SentenceTransformer``.
            out_dim: Size of the projected embedding.

        Raises:
            ValueError: If the backbone does not report an embedding size.
        """
        super().__init__()

        # Pre-trained sentence transformer used as a fixed feature extractor.
        self.sentence_transformer = SentenceTransformer(model_name)

        # Freeze the backbone so that only the projection head is trainable.
        for param in self.sentence_transformer.parameters():
            param.requires_grad = False

        # The embedding size is documented as possibly unknown (None); fail
        # with a clear message instead of an opaque error inside nn.Linear.
        embedding_dim = self.sentence_transformer.get_sentence_embedding_dimension()
        if embedding_dim is None:
            raise ValueError(
                f"Sentence transformer {model_name!r} does not report an embedding dimension"
            )

        # Projection head: backbone embedding size -> out_dim, then LayerNorm.
        self.projection = nn.Sequential(
            nn.Linear(embedding_dim, out_dim),
            nn.LayerNorm(out_dim),
        )

    def forward(self, sentences):
        """Encode ``sentences`` and project the embeddings to ``out_dim``.

        Args:
            sentences: Input forwarded unchanged to ``SentenceTransformer.encode``.

        Returns:
            Tensor of projected embeddings whose last dimension is ``out_dim``.
        """
        # The backbone is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            sentence_embeddings = self.sentence_transformer.encode(sentences, convert_to_tensor=True)

        # The backbone picks its own device (e.g. a GPU when one is available)
        # independently of where the projection head lives, so align the
        # embeddings with the head before projecting to avoid a device/dtype
        # mismatch error.
        head_weight = self.projection[0].weight
        sentence_embeddings = sentence_embeddings.to(device=head_weight.device, dtype=head_weight.dtype)

        return self.projection(sentence_embeddings)


In [35]:
if __name__ == "__main__":
    # Smoke test: push one sentence through the encoder and report the
    # shape of the projected embedding.
    sentences = ["acm sanganitra is the best"]
    model_name, out_dim = "all-mpnet-base-v2", 512

    # Build the encoder (frozen backbone + trainable projection head).
    text_encoder = TextEncoder(model_name=model_name, out_dim=out_dim)

    embeddings = text_encoder(sentences)
    print(embeddings.shape)

torch.Size([1, 512])


In [36]:
!pip install torchinfo



In [37]:
# NOTE(review): the preceding cell installs `torchinfo`, but this import pulls
# `summary` from `torchsummary`, which is a different package; `summary` is also
# not used in the code shown here. Confirm which package is intended.
from torchsummary import summary

# Print the module tree of the encoder built in the demo cell.
print(text_encoder)

# Count trainable parameters
def count_trainable_params(model):
    """Tally the parameter elements of ``model`` by trainability.

    Args:
        model: Object exposing ``parameters()`` whose items provide
            ``numel()`` and ``requires_grad``.

    Returns:
        Tuple ``(total, trainable, frozen)`` of element counts, where
        ``frozen`` counts parameters with ``requires_grad`` set to False.
    """
    trainable_count = 0
    frozen_count = 0
    for param in model.parameters():
        size = param.numel()
        if param.requires_grad:
            trainable_count += size
        else:
            frozen_count += size
    return trainable_count + frozen_count, trainable_count, frozen_count

# Gather the parameter counts for the encoder and report them line by line.
total_params, trainable_params, frozen_params = count_trainable_params(text_encoder)

summary_lines = [
    "\nParameter Summary:",
    f"Total Parameters: {total_params}",
    f"Trainable Parameters: {trainable_params}",
    f"Frozen Parameters: {frozen_params}",
]
for summary_line in summary_lines:
    print(summary_line)


TextEncoder(
  (sentence_transformer): SentenceTransformer(
    (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel 
    (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
    (2): Normalize()
  )
  (projection): Sequential(
    (0): Linear(in_features=768, out_features=512, bias=True)
    (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
)

Parameter Summary:
Total Parameters: 109881216
Trainable Parameters: 394752
Frozen Parameters: 109486464
