In [1]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

class SentenceTransformer(nn.Module):
    """Sentence encoder that mean-pools the token embeddings of a pre-trained BERT.

    Attributes:
        bert: The ``BertModel`` backbone loaded from ``model_name``.
        tokenizer: The ``BertTokenizer`` loaded from the same ``model_name``;
            used by :meth:`encode` to turn raw text into tensors.
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Load the pre-trained backbone and its tokenizer.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        super(SentenceTransformer, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return one mask-aware mean-pooled embedding per input sequence.

        Args:
            input_ids: Token ids, shape (batch_size, seq_len).
            attention_mask: Same shape as ``input_ids``; non-zero marks tokens
                that take part in the average, zero marks padding.

        Returns:
            Tensor of shape (batch_size, hidden_size).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        # Broadcast the mask over the hidden dimension so padded positions
        # contribute zero to the sum.
        mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * mask_expanded, dim=1)
        # Clamp the token count so an all-zero mask row cannot divide by zero.
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        return sum_embeddings / sum_mask

    def encode(self, sentences):
        """Tokenize ``sentences`` and return their embeddings without tracking gradients.

        The call is run in eval mode (so dropout cannot perturb the result)
        and on whatever device the model's parameters live on. Every
        submodule's train/eval flag is restored afterwards, so calling
        ``encode`` in the middle of training does not change the model's mode.

        Args:
            sentences: A string or list of strings accepted by the tokenizer.

        Returns:
            Tensor of shape (num_sentences, hidden_size), detached from autograd.
        """
        encoding = self.tokenizer(sentences, padding=True, truncation=True,
                                  max_length=128, return_tensors='pt')
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']

        # The tokenizer was not told about any device, so follow the model
        # explicitly in case it has been moved (e.g. model.to('cuda')).
        first_param = next(self.parameters(), None)
        if first_param is not None:
            input_ids = input_ids.to(first_param.device)
            attention_mask = attention_mask.to(first_param.device)

        # Remember each submodule's own flag: a blanket self.train(...) on exit
        # would also flip submodules that were deliberately left in eval mode.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                embeddings = self(input_ids, attention_mask)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
        return embeddings

# Smoke test: embed two short sentences and inspect the resulting tensor
sample_sentences = ["Hello World!", "I'm AGI! How can I help you?"]
model = SentenceTransformer()
embeddings = model.encode(sample_sentences)

print("Embeddings:", embeddings)
print("Sentence Embeddings Shape:", embeddings.shape)  # Expected: (2, 768)

# Pair every sentence with the leading values of its embedding row
for row, sentence in enumerate(sample_sentences):
    emb = embeddings[row]
    print(f"Sentence: '{sentence}' | Embedding (first 5 values): {emb[:5]}")

# Explanation:

# Transformer Backbone: I chose BERT (bert-base-uncased) because it’s a widely-used, pre-trained transformer model 
# that captures rich contextual information, making it suitable as a foundation for sentence embeddings. 
# Using a pre-trained model saves time compared to training a transformer from scratch and leverages BERT’s general language understanding.

# Pooling Strategy: To convert BERT's token-level embeddings into a single sentence embedding,
# I used mean pooling over the token embeddings. It aggregates information from all tokens,
# uses the attention mask to exclude padding, and is a common choice in sentence transformers
# (e.g., Sentence-BERT) because it often outperforms using the [CLS] token alone for sentence-level tasks.

# I didn’t add a projection layer after pooling to keep the architecture simple and preserve the 
# 768-dimensional embeddings from BERT, which are already rich and usable for downstream tasks.

Embeddings: tensor([[-0.1373, -0.1593,  0.0821,  ..., -0.0644, -0.0986, -0.0170],
        [ 0.0211, -0.0120, -0.1718,  ..., -0.1398, -0.1174,  0.2579]])
Sentence Embeddings Shape: torch.Size([2, 768])
Sentence: 'Hello World!' | Embedding (first 5 values): tensor([-0.1373, -0.1593,  0.0821, -0.3459, -0.2501])
Sentence: 'I'm AGI! How can I help you?' | Embedding (first 5 values): tensor([ 0.0211, -0.0120, -0.1718, -0.3703,  0.1255])


In [2]:
class MultiTaskModel(nn.Module):
    """One shared BERT encoder feeding two linear heads.

    The sentence head classifies a mask-aware mean of the token states; the
    NER head classifies every token state individually.

    Attributes:
        bert: Shared ``BertModel`` backbone.
        tokenizer: ``BertTokenizer`` loaded from the same ``model_name``.
        sentence_classifier: Linear layer, hidden_size -> num_sentence_classes.
        ner_classifier: Linear layer, hidden_size -> num_ner_classes.
    """

    def __init__(self, num_sentence_classes=2, num_ner_classes=5, model_name='bert-base-uncased'):
        """Build the backbone and both task heads.

        Args:
            num_sentence_classes: Output width of the sentence-level head.
            num_ner_classes: Output width of the per-token head.
            model_name: Identifier passed to ``from_pretrained``.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

        # Both heads read vectors of the encoder's hidden width (768 for bert-base)
        encoder_width = self.bert.config.hidden_size
        self.sentence_classifier = nn.Linear(encoder_width, num_sentence_classes)
        self.ner_classifier = nn.Linear(encoder_width, num_ner_classes)

    def forward(self, input_ids, attention_mask):
        """Compute logits for both tasks from a single encoder pass.

        Args:
            input_ids: Token ids, shape (batch_size, seq_len).
            attention_mask: Same shape; zero marks padding positions.

        Returns:
            Tuple ``(sentence_logits, ner_logits)`` with shapes
            (batch_size, num_sentence_classes) and
            (batch_size, seq_len, num_ner_classes).
        """
        token_states = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state  # (batch_size, seq_len, hidden_size)

        # Sentence task: average token states, letting the mask zero out padding
        weights = attention_mask.unsqueeze(-1).expand(token_states.size()).float()
        token_counts = weights.sum(dim=1).clamp(min=1e-9)
        pooled = (token_states * weights).sum(dim=1) / token_counts

        # NER task: the head is applied to every position of the raw token states
        return self.sentence_classifier(pooled), self.ner_classifier(token_states)

# Smoke test: push two sentences through a freshly built model and check both heads' output shapes
sample_sentences = ["Let's get this party started!", "Humans will always win against AGI"]
model = MultiTaskModel(num_sentence_classes=2, num_ner_classes=5)

# Pad to the longest sentence in the batch and return PyTorch tensors
encoding = model.tokenizer(
    sample_sentences,
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors='pt',
)
input_ids, attention_mask = encoding['input_ids'], encoding['attention_mask']

sentence_logits, ner_logits = model(input_ids, attention_mask)
print("Sentence Logits Shape:", sentence_logits.shape)
print("NER Logits Shape:", ner_logits.shape)

# Shared Backbone: The BERT model remains shared to learn representations beneficial for both tasks
# Sentence Classification Head: Takes the pooled sentence embedding and outputs logits for 2 classes (e.g. +/-). I assumed a binary classification task for simplicity.
# NER Head: Operates on the full last_hidden_state to produce per-token logits for 5 NER classes 

Sentence Logits Shape: torch.Size([2, 2])
NER Logits Shape: torch.Size([2, 10, 5])


In [3]:
import torch.nn.functional as F
from torch.optim import AdamW

# Build a fresh multi-task model (2 sentence classes, 5 NER classes) and an
# AdamW optimizer over all of its parameters, backbone included
model = MultiTaskModel(num_ner_classes=5, num_sentence_classes=2)
optimizer = AdamW(params=model.parameters(), lr=2e-5)

# Hypothetical training loop
def train_epoch(model, dataloader, optimizer=None):
    """Run one training pass over ``dataloader``, optimizing both tasks jointly.

    For every batch the sentence-classification and NER cross-entropy losses
    are added with equal weight, one optimizer step is taken, and the batch's
    losses and accuracies are printed.

    Args:
        model: Module that maps ``(input_ids, attention_mask)`` to
            ``(sentence_logits, ner_logits)``.
        dataloader: Iterable of dict batches with keys ``'input_ids'``,
            ``'attention_mask'``, ``'sentence_labels'`` and ``'ner_labels'``.
            NER positions labelled ``-100`` are excluded from loss and accuracy.
        optimizer: Optimizer to step. If omitted, the notebook-level variable
            named ``optimizer`` is used, which is what this function relied on
            before the parameter existed.

    Raises:
        NameError: If ``optimizer`` is omitted and no global ``optimizer`` exists.
    """
    if optimizer is None:
        # Backward compatibility with callers that depend on the global optimizer
        optimizer = globals().get('optimizer')
        if optimizer is None:
            raise NameError("name 'optimizer' is not defined")

    # Batches are moved to wherever the model's parameters live; this is a
    # no-op when they already share a device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    model.train()
    for batch in dataloader:
        input_ids = batch['input_ids']  # Shape: (batch_size, seq_len)
        attention_mask = batch['attention_mask']  # Shape: (batch_size, seq_len)
        sentence_labels = batch['sentence_labels']  # Shape: (batch_size)
        ner_labels = batch['ner_labels']  # Shape: (batch_size, seq_len), -100 for padding
        if device is not None:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            sentence_labels = sentence_labels.to(device)
            ner_labels = ner_labels.to(device)

        # Forward pass
        sentence_logits, ner_logits = model(input_ids, attention_mask)

        # Compute losses. The class count is read off the logits themselves
        # rather than from an outer-scope variable, so the function is
        # self-contained; reshape (not view) also accepts non-contiguous labels.
        num_ner_classes = ner_logits.size(-1)
        sentence_loss = F.cross_entropy(sentence_logits, sentence_labels)
        ner_loss = F.cross_entropy(ner_logits.reshape(-1, num_ner_classes),
                                   ner_labels.reshape(-1), ignore_index=-100)
        total_loss = sentence_loss + ner_loss

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Compute metrics (NER accuracy only over positions with a real label)
        sentence_preds = torch.argmax(sentence_logits, dim=1)
        sentence_acc = (sentence_preds == sentence_labels).float().mean()
        ner_preds = torch.argmax(ner_logits, dim=-1)
        ner_mask = (ner_labels != -100)
        ner_acc = (ner_preds[ner_mask] == ner_labels[ner_mask]).float().mean()

        print(f"Sentence Loss: {sentence_loss.item():.4f}, NER Loss: {ner_loss.item():.4f}, "
              f"Sentence Acc: {sentence_acc.item():.4f}, NER Acc: {ner_acc.item():.4f}")

# Note: In practice, run `for epoch in range(num_epochs): train_epoch(model, dataloader)`