# 1. Multi-Task Transformer Implementation

To support multi-task learning, we extend the original Sentence Transformer architecture (built in Step 1) by adding task-specific heads on top of the shared sentence embedding. Each head handles exactly one task. The main components are:

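- **Transformer Backbone**: Uses a pre-trained transformer model (e.g., DistilBERT) with mean pooling to produce a fixed-size sentence embedding that both heads share. Its weights are frozen by default, so only the heads are trainable.
- **Task A Head**: A linear layer that maps the sentence embedding to sentence-classification logits.
- **Task B Head**: A linear layer that maps the sentence embedding to sentiment-analysis logits.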


In [1]:
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# multi_task_sentence_transformer.py

import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

class MultiTaskSentenceTransformer(nn.Module):
    """Shared sentence-embedding backbone with two linear task heads.

    Task A is sentence classification and Task B is sentiment analysis. Both
    heads read the same sentence embedding produced by a ``SentenceTransformer``
    backbone.
    """

    def __init__(self, transformer_model='distilbert-base-uncased', num_classes_task_a=3,
                 num_classes_task_b=3, device=None, freeze_transformer=True):
        """
        Initializes the multi-task sentence transformer model.

        Args:
            transformer_model (str): Name of the pre-trained transformer model.
            num_classes_task_a (int): Number of classes for Task A (Sentence Classification).
            num_classes_task_b (int): Number of classes for Task B (Sentiment Analysis).
            device (torch.device | str, optional): The device to run the model on.
                Defaults to None, which auto-selects MPS, then CUDA, then CPU.
            freeze_transformer (bool): If True (the default), every backbone
                parameter gets ``requires_grad=False`` so only the two task
                heads are trainable. Pass False to fine-tune the backbone too.
        """
        super().__init__()

        # Normalise to torch.device so a string like "cuda" and a torch.device
        # object behave the same everywhere self.device is used.
        self.device = torch.device(device) if device is not None else self._auto_select_device()

        print(f"Initializing model on device: {self.device}")

        # Initialize the Sentence Transformer model on the specified device
        self.transformer = SentenceTransformer(transformer_model, device=self.device)

        # Both heads consume the same embedding, so query its size only once.
        self.embedding_dim = self.transformer.get_sentence_embedding_dimension()

        # Task A: Sentence Classification Head
        self.classification_head = nn.Linear(self.embedding_dim, num_classes_task_a).to(self.device)

        # Task B: Sentiment Analysis Head
        self.sentiment_head = nn.Linear(self.embedding_dim, num_classes_task_b).to(self.device)

        if freeze_transformer:
            for param in self.transformer.parameters():
                param.requires_grad = False

    @staticmethod
    def _auto_select_device():
        """Return the preferred available device: MPS, then CUDA, then CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def forward(self, sentences):
        """
        Forward pass producing logits and predicted class indices for both tasks.

        Args:
            sentences (list of str | str): Input sentences. A single string is
                treated as a one-element batch.

        Returns:
            dict:
                ``Task_A_Predictions`` / ``Task_B_Predictions``: argmax class
                index per sentence, shape ``(batch,)``.
                ``Task_A_Logits`` / ``Task_B_Logits``: raw head outputs, shape
                ``(batch, num_classes)``. Use these to compute a training loss,
                because argmax is not differentiable.
        """
        if isinstance(sentences, str):
            # encode() returns a 1-D embedding for a bare string, which would
            # make argmax(dim=1) below fail; treat it as a batch of one.
            sentences = [sentences]

        if len(sentences) == 0:
            # encode([]) does not yield a (0, dim) tensor the Linear heads can
            # consume, so build the empty batch explicitly on the model device.
            embeddings = torch.empty(0, self.embedding_dim, device=self.device)
        else:
            # Generate sentence embeddings (shape (batch, embedding_dim)).
            embeddings = self.transformer.encode(sentences, convert_to_tensor=True, device=self.device)

        # Both heads read the same shared embedding.
        task_a_logits = self.classification_head(embeddings)
        task_b_logits = self.sentiment_head(embeddings)

        return {
            'Task_A_Predictions': torch.argmax(task_a_logits, dim=1),
            'Task_B_Predictions': torch.argmax(task_b_logits, dim=1),
            'Task_A_Logits': task_a_logits,
            'Task_B_Logits': task_b_logits,
        }

if __name__ == "__main__":
    # Demo only: the task heads are freshly (randomly) initialised and have not
    # been trained, so the labels printed below are arbitrary, not real
    # predictions.
    model = MultiTaskSentenceTransformer()
    model.eval()

    # Sample sentences for testing
    sample_sentences = [
        "The new smartphone has an excellent camera.",
        "I am disappointed with the service.",
        "The weather today is sunny and bright."
    ]

    # Pure inference: skip building an autograd graph through the trainable heads.
    with torch.no_grad():
        predictions = model(sample_sentences)

    # Define label mappings for demonstration
    task_a_labels = {0: "Technology", 1: "Service", 2: "Weather"}
    task_b_labels = {0: "Negative", 1: "Positive", 2: "Neutral"}

    # Convert each prediction tensor to Python ints once instead of per index.
    task_a_preds = predictions['Task_A_Predictions'].tolist()
    task_b_preds = predictions['Task_B_Predictions'].tolist()

    for idx, (sentence, task_a_pred, task_b_pred) in enumerate(
            zip(sample_sentences, task_a_preds, task_b_preds), start=1):
        print(f"Sentence {idx}: \"{sentence}\"")
        print(f"  Task A Prediction (Sentence Classification): {task_a_labels.get(task_a_pred, 'Unknown')}")
        print(f"  Task B Prediction (Sentiment Analysis): {task_b_labels.get(task_b_pred, 'Unknown')}\n")


        



No sentence-transformers model found with name distilbert-base-uncased. Creating a new one with mean pooling.


Initializing model on device: mps
Sentence 1: "The new smartphone has an excellent camera."
  Task A Prediction (Sentence Classification): Technology
  Task B Prediction (Sentiment Analysis): Negative

Sentence 2: "I am disappointed with the service."
  Task A Prediction (Sentence Classification): Technology
  Task B Prediction (Sentiment Analysis): Negative

Sentence 3: "The weather today is sunny and bright."
  Task A Prediction (Sentence Classification): Technology
  Task B Prediction (Sentiment Analysis): Negative

