Task 1: Sentence Transformer Implementation


Implement a sentence transformer model using any deep learning framework of your choice. This model should be able to encode input sentences into fixed-length embeddings. Test your implementation with a few sample sentences and showcase the obtained embeddings. Describe any choices you had to make regarding the model architecture outside of the transformer backbone.

In [9]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizerFast


Before feeding sentences to the sentence transformer, we first have to tokenize and encode them. I chose a pretrained BERT model as the backbone, so the tokenizer must match it.

The get_tokenizer function loads the Hugging Face tokenizer that matches the chosen BERT variant.

The encode_sentences function then takes raw strings and applies that tokenizer to produce token IDs and an attention mask. It pads every sequence in the batch to the length of the longest one, truncates anything beyond max_length, returns PyTorch tensors, and moves them onto the target device.
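
To make the padding and truncation behaviour concrete, here is a minimal sketch that mimics it in plain Python. The toy_encode helper and its whitespace vocabulary are hypothetical stand-ins used only for illustration: the real tokenizer uses WordPiece and adds [CLS] and [SEP] special tokens, both of which this sketch omits.

```python
# Hypothetical toy stand-in for tokenizer(..., padding="longest", truncation=True).
# It splits on whitespace instead of using WordPiece and adds no special tokens.
def toy_encode(sentences, max_length=4, pad_id=0):
    vocab = {}  # word -> id, built on the fly; id 0 is reserved for padding
    batch = []
    for sentence in sentences:
        ids = [vocab.setdefault(w, len(vocab) + 1) for w in sentence.lower().split()]
        batch.append(ids[:max_length])  # truncation: drop tokens beyond max_length
    longest = max(len(ids) for ids in batch)
    # padding: extend every sequence to the longest one in the batch
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch]
    # attention mask: 1 for real tokens, 0 for padding
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch]
    return input_ids, attention_mask

input_ids, attention_mask = toy_encode(["Sushi is the best food", "hello world"])
print(input_ids)
print(attention_mask)
```

The first sentence is cut to max_length tokens and the second is padded up to that length, so its attention mask ends in zeros. Those zeros are what the pooling step later uses to ignore padding.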

In [10]:
# 1) Tokenization & Encoding
def get_tokenizer(model_name="bert-base-uncased"):
    return BertTokenizerFast.from_pretrained(model_name)

def encode_sentences(tokenizer, sentences, max_length=128, device="cpu"):
    encoded = tokenizer(
        sentences,
        padding="longest",     # pad to longest in batch
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return encoded["input_ids"].to(device), encoded["attention_mask"].to(device)


In [11]:
# 2) Sentence Transformer Model
class BertSentenceTransformer(nn.Module):
    def __init__(self,
                 model_name="bert-base-uncased",
                 pooling: str = "mean"   # options: "cls", "mean", "max"
                 ):
        super().__init__()
        self.backbone = BertModel.from_pretrained(model_name)
        self.pooling = pooling.lower()
        # keep the embedding dimension H around for downstream layers
        self.hidden_size = self.backbone.config.hidden_size

    def forward(self, input_ids, attention_mask):
        # 1) get all token embeddings
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        token_embeddings = outputs.last_hidden_state  # (B, T, H)
        # pooler_output is the [CLS] hidden state passed through BERT's dense + tanh pooler
        cls_embedding    = outputs.pooler_output      # (B, H)

        # 2) apply pooling
        if self.pooling == "cls":
            sent_emb = cls_embedding

        elif self.pooling == "mean":
            # mask out padding tokens
            mask = attention_mask.unsqueeze(-1).float()  # (B, T, 1)
            sum_embeddings = (token_embeddings * mask).sum(dim=1)
            lengths = mask.sum(dim=1).clamp(min=1e-9)
            sent_emb = sum_embeddings / lengths

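        elif self.pooling == "max":
            mask = attention_mask.unsqueeze(-1).bool()  # (B, T, 1)
            # fill padding positions with a large negative value so they never win the max;
            # masked_fill broadcasts the mask over H and does not modify the backbone output in place
            token_embeddings = token_embeddings.masked_fill(~mask, -1e9)
            sent_emb = token_embeddings.max(dim=1).values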

        else:
            raise ValueError(f"Unknown pooling: {self.pooling}")


        return sent_emb  # (B, H)
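
The effect of masking in the "mean" and "max" branches is easier to see on a tiny hand-written tensor than on real BERT outputs. The sketch below is a standalone illustration, not part of the model: the values are made up, and the 9.0 rows stand in for whatever the backbone emits at padding positions.

```python
import torch

# Toy batch: 2 sentences, 3 token positions, hidden size 2.
# The second sentence has one real token followed by two padding positions.
token_embeddings = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, -4.0], [9.0, 9.0], [9.0, 9.0]],  # the 9.0 rows stand in for padding outputs
])
attention_mask = torch.tensor([[1, 1, 1], [1, 0, 0]])

mask = attention_mask.unsqueeze(-1)  # (B, T, 1), broadcasts over the hidden dim

# Masked mean: zero out padding, then divide by the number of real tokens.
masked_mean = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

# Naive mean over all positions, for comparison: padding leaks into the result.
naive_mean = token_embeddings.mean(dim=1)

# Masked max: fill padding with a large negative value before taking the max.
masked_max = token_embeddings.masked_fill(mask == 0, -1e9).max(dim=1).values

print(masked_mean)
print(naive_mean)
print(masked_max)
```

For the second sentence, the masked mean and masked max both return the single real token's vector, while the naive mean is pulled toward the padding values. This is why the attention mask is passed into forward and not only into the backbone.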

In [12]:
# 3) Example Usage
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # a) prepare data
    sentences = [
        "Sushi is the best food",
        "lakers is winning this year",
        "what is the capital of US"
    ]
    tokenizer = get_tokenizer("bert-base-uncased")
    input_ids, attention_mask = encode_sentences(tokenizer, sentences, device=device)

    model = BertSentenceTransformer(
        model_name="bert-base-uncased",
        pooling="mean"
    ).to(device)
    model.eval()

    with torch.no_grad():
        embeddings = model(input_ids, attention_mask)
        for sentence, embed in zip(sentences, embeddings):
            print('Sentence:\n',sentence)
            print("Embeddings shape:", embed.shape)



Sentence:
 Sushi is the best food
Embeddings shape: torch.Size([768])
Sentence:
 lakers is winning this year
Embeddings shape: torch.Size([768])
Sentence:
 what is the capital of US
Embeddings shape: torch.Size([768])
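
Printing shapes confirms that every sentence maps to a fixed-length vector, but it says little about the embeddings themselves. One common way to inspect them is pairwise cosine similarity. The sketch below uses a small hand-written tensor as a stand-in for the (B, H) embeddings tensor, so the numbers are illustrative only and are not outputs of the model above.

```python
import torch
import torch.nn.functional as F

# Stand-in for the (B, H) `embeddings` tensor; values are made up for illustration.
embeddings = torch.tensor([
    [1.0, 0.0, 0.0],
    [0.0, 2.0, 0.0],
    [3.0, 0.0, 0.0],
])

# L2-normalise each row, then a matrix product gives all pairwise cosine similarities.
normalised = F.normalize(embeddings, p=2, dim=1)
similarity = normalised @ normalised.T  # (B, B)

print(similarity)
```

Rows 0 and 2 point in the same direction and score 1.0 against each other regardless of magnitude, while row 1 is orthogonal to both and scores 0.0. Replacing the stand-in tensor with the model's embeddings would give a 3 x 3 similarity matrix for the three sample sentences.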
