# Task 2: Multi-Task Learning Expansion
Expand the sentence transformer to handle a multi-task learning setting.
1. Task A: Sentence Classification – Classify sentences into predefined classes (you can make these up).
2. Task B: [Choose another relevant NLP task such as Named Entity Recognition,
Sentiment Analysis, etc.] (you can make the labels up)
Describe the changes made to the architecture to support multi-task learning.

This notebook demonstrates how to extend a sentence-transformer model (BERT) to perform **two** tasks simultaneously:

1. **Task A: Sentence Classification**  
2. **Task B: Named Entity Recognition (NER)**

We share a single encoder and add two task-specific heads. We then train with a combined loss.


# 1. Import Libraries

We import PyTorch, Hugging Face Transformers, the `datasets` library, and NumPy.

In [None]:

import torch
import torch.nn as nn
from torch.optim import AdamW  # the AdamW in transformers is deprecated in favor of this one
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer
from datasets import load_dataset
import numpy as np


# 2. Data Preparation

- **Task A**: Sentiment classification on IMDB (binary: positive vs. negative).  
- **Task B**: Synthetic NER labels (4 classes: O, PER, LOC, ORG) for demonstration only.

We’ll:
1. Load the IMDB dataset.
2. Tokenize text to fixed length.
3. Create random NER labels aligned to the token count.
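
The alignment in step 3 can be sketched without a tokenizer. The helper below (`align_labels`, a hypothetical name not used elsewhere in this notebook) assumes the usual BERT layout of `[CLS]`, word pieces, `[SEP]`, then `[PAD]`, and marks every non-word position with `-100` so the loss can skip it. It is a minimal illustration of the idea, not the preprocessing code itself.

```python
import random

IGNORE_INDEX = -100  # label value that nn.CrossEntropyLoss skips by default


def align_labels(word_piece_count, max_length=128, num_tags=4, seed=0):
    """Build a label list of exactly max_length entries for one example.

    Assumed layout: [CLS] + word pieces (truncated) + [SEP] + [PAD]...
    """
    rng = random.Random(seed)
    kept = min(word_piece_count, max_length - 2)  # leave room for [CLS] and [SEP]
    labels = [IGNORE_INDEX]                                    # [CLS]
    labels += [rng.randrange(num_tags) for _ in range(kept)]   # real tokens get a random tag
    labels += [IGNORE_INDEX]                                   # [SEP]
    labels += [IGNORE_INDEX] * (max_length - len(labels))      # [PAD] up to max_length
    return labels


short = align_labels(word_piece_count=5, max_length=8)    # needs one [PAD] slot
long = align_labels(word_piece_count=300, max_length=8)   # truncated to 6 word pieces
print(short)
print(len(short), len(long))
```

Both calls return lists of the same length, which is what lets the labels be stacked into a `(batch, seq_len)` tensor later.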


In [None]:
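# 2.1 Load the IMDB dataset (includes the train and test splits)
dataset = load_dataset("imdb")

# 2.2 Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

MAX_LENGTH = 128
NUM_NER_TAGS = 4      # O, PER, LOC, ORG
IGNORE_INDEX = -100   # default ignore_index of nn.CrossEntropyLoss

def create_ner_labels(special_tokens_mask):
    """
    Generate synthetic token-level labels for one encoded example.
    In a real setting, you'd use human-annotated NER tags.
    Real tokens get a random int (0-3); [CLS], [SEP], and [PAD] positions
    get IGNORE_INDEX, so the list always matches the encoded length.
    """
    random_tags = np.random.randint(0, NUM_NER_TAGS, size=len(special_tokens_mask))
    return [
        IGNORE_INDEX if is_special else int(tag)
        for is_special, tag in zip(special_tokens_mask, random_tags)
    ]

def preprocess(batch):
    """
    Tokenize the batch of texts, create Task A & B labels.
    Returns plain Python lists; `datasets` stores them as Arrow columns.
    """
    # Convert IMDB 0/1 labels to our binary classification labels
    task_a_labels = [1 if lbl == 1 else 0 for lbl in batch["label"]]

    # Tokenize to max_length=128; the special-tokens mask flags
    # [CLS], [SEP], and [PAD] positions with 1
    enc = tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_special_tokens_mask=True
    )

    # Generate synthetic NER labels per example, aligned to the encoded tokens
    task_b_labels = [create_ner_labels(mask) for mask in enc["special_tokens_mask"]]

    return {
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "task_a_labels": task_a_labels,
        "task_b_labels": task_b_labels
    }

# Apply preprocessing to every split
dataset = dataset.map(preprocess, batched=True, remove_columns=["text","label"])

# Return PyTorch tensors (not Python lists) when indexing the dataset
dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "task_a_labels", "task_b_labels"]
)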


# 3. Model Architecture

We define a single `BertModel` encoder plus:

- **Classification head** (pooled output → 2 classes)  
- **NER head** (token outputs → 4 classes)  
- Shared dropout for regularization.


In [None]:
class MultiTaskTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        # 3.1 Shared BERT encoder
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        # 3.2 Task A: Sentence classification head
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
        
        # 3.3 Task B: Token-level NER head
        self.ner = nn.Linear(self.bert.config.hidden_size, 4)
        
        # 3.4 Dropout to reduce overfitting
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input_ids, attention_mask):
        # 3.5 Pass inputs through BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # 3.6 Pooled [CLS] output for classification
        pooled = outputs.pooler_output
        
        # 3.7 Full sequence output for token tagging
        sequence = outputs.last_hidden_state
        
        # 3.8 Compute logits for each task
        logits_a = self.classifier(self.dropout(pooled))          # shape: (batch, 2)
        logits_b = self.ner(self.dropout(sequence))               # shape: (batch, seq_len, 4)
        
        return logits_a, logits_b


# 4. Training Setup

- **Loss functions**:  
  - Task A: `CrossEntropyLoss` over sentence logits  
  - Task B: `CrossEntropyLoss` over flattened token logits  
- **Optimizer**: AdamW with a small learning rate.
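
To make the Task B loss concrete: the `(batch, seq_len, 4)` logits are flattened to `(batch * seq_len, 4)`, and the cross-entropy is averaged only over positions whose label is not `-100`, the default `ignore_index` of `nn.CrossEntropyLoss`. The pure-Python sketch below reproduces that computation on a toy batch. The function name `masked_cross_entropy` and the toy numbers are made up for illustration.

```python
import math

IGNORE_INDEX = -100


def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over positions whose label is not ignore_index.

    logits: list of per-token score lists, shape (num_tokens, num_classes)
    labels: list of ints, shape (num_tokens,)
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # padding / special tokens contribute nothing
        m = max(scores)  # subtract the max for numerical stability
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_norm - scores[label]  # -log softmax(scores)[label]
        count += 1
    # Sketch convention: return 0.0 when every position is ignored
    return total / count if count else 0.0


# Toy batch: 2 sequences x 3 tokens x 4 classes
batch_logits = [
    [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [9.0, 9.0, 9.0, 9.0]],
    [[0.0, 3.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [5.0, 5.0, 5.0, 5.0]],
]
batch_labels = [
    [0, 2, IGNORE_INDEX],
    [1, IGNORE_INDEX, IGNORE_INDEX],
]

# Flatten (batch, seq_len, classes) -> (batch * seq_len, classes)
flat_logits = [tok for seq in batch_logits for tok in seq]
flat_labels = [lbl for seq in batch_labels for lbl in seq]

loss = masked_cross_entropy(flat_logits, flat_labels)
print(f"masked token loss: {loss:.4f}")
```

Only three of the six positions carry a real label here, so the mean is taken over three terms; the logits at ignored positions have no effect on the result.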


In [None]:

# 4.1 Instantiate the model
model = MultiTaskTransformer()

# 4.2 Define losses
loss_fn_a = nn.CrossEntropyLoss()
loss_fn_b = nn.CrossEntropyLoss()

# 4.3 Optimizer
optimizer = AdamW(model.parameters(), lr=2e-5)


# 5. Collate Function & DataLoader

We need to pad NER label lists up to `max_length` so they align with token sequences.


In [None]:

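def collate_fn(batch):
    """
    Batch is a list of dicts:
      - input_ids: token IDs, length seq_len (list or tensor)
      - attention_mask: 1 = real token, 0 = padding (list or tensor)
      - task_a_labels: int
      - task_b_labels: token-level labels (list or tensor)
    We stack the inputs and pad (or truncate) NER labels to length 128.
    """
    max_len = 128

    # torch.as_tensor accepts both Python lists and existing tensors
    input_ids = torch.stack([torch.as_tensor(item["input_ids"], dtype=torch.long) for item in batch])
    attention = torch.stack([torch.as_tensor(item["attention_mask"], dtype=torch.long) for item in batch])
    task_a    = torch.tensor([int(item["task_a_labels"]) for item in batch], dtype=torch.long)

    # Truncate each task_b_labels sequence to 128, then right-pad with -100,
    # which CrossEntropyLoss ignores by default
    task_b_padded = []
    for item in batch:
        lbls = torch.as_tensor(item["task_b_labels"], dtype=torch.long)[:max_len]
        task_b_padded.append(
            torch.nn.functional.pad(lbls, (0, max_len - len(lbls)), value=-100)
        )
    task_b = torch.stack(task_b_padded)

    return {
        "input_ids": input_ids,
        "attention_mask": attention,
        "task_a_labels": task_a,
        "task_b_labels": task_b
    }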

# Create DataLoader for training
train_loader = DataLoader(dataset["train"], batch_size=16, shuffle=True, collate_fn=collate_fn)


# 6. Training Loop

For each epoch and batch:
1. Forward pass through shared encoder + both heads  
2. Compute Task A and Task B losses  
3. Sum losses (here equally weighted)  
4. Backpropagate & optimizer step  
5. Track average loss
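
The five steps above can be sketched as follows. To keep the example quick to run, it uses a tiny stand-in model (`TinyMultiTask`, hypothetical) that returns the same `(logits_a, logits_b)` pair as `MultiTaskTransformer` but needs no BERT download, and it trains on random toy tensors. The `weight_a` and `weight_b` parameters are also an assumption added for illustration; their defaults of 1.0 give the equally weighted sum described in step 3.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)


class TinyMultiTask(nn.Module):
    """Hypothetical stand-in for MultiTaskTransformer: same outputs, no BERT."""

    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)  # plays the role of the shared encoder
        self.classifier = nn.Linear(hidden, 2)         # Task A head
        self.ner = nn.Linear(hidden, 4)                # Task B head

    def forward(self, input_ids, attention_mask):
        sequence = self.embed(input_ids)                               # (batch, seq_len, hidden)
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (sequence * mask).sum(1) / mask.sum(1).clamp(min=1)   # masked mean pooling
        return self.classifier(pooled), self.ner(sequence)


def train(model, loader, epochs=2, lr=1e-3, weight_a=1.0, weight_b=1.0):
    loss_fn_a = nn.CrossEntropyLoss()
    loss_fn_b = nn.CrossEntropyLoss()  # ignore_index defaults to -100
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    history = []
    model.train()
    for epoch in range(epochs):
        running = 0.0
        for input_ids, attention_mask, labels_a, labels_b in loader:
            optimizer.zero_grad()
            # 1. Forward pass through the shared encoder and both heads
            logits_a, logits_b = model(input_ids, attention_mask)
            # 2. Per-task losses; token logits and labels are flattened for Task B
            loss_a = loss_fn_a(logits_a, labels_a)
            loss_b = loss_fn_b(logits_b.reshape(-1, logits_b.size(-1)), labels_b.reshape(-1))
            # 3. Combine the losses
            loss = weight_a * loss_a + weight_b * loss_b
            # 4. Backpropagate and update
            loss.backward()
            optimizer.step()
            running += loss.item()
        # 5. Track the average loss per epoch
        history.append(running / len(loader))
        print(f"epoch {epoch + 1}: avg loss {history[-1]:.4f}")
    return history


# Random toy data: 32 examples of 8 tokens, last 2 positions are padding
n, seq_len = 32, 8
input_ids = torch.randint(1, 100, (n, seq_len))
attention_mask = torch.ones(n, seq_len, dtype=torch.long)
attention_mask[:, -2:] = 0
labels_a = torch.randint(0, 2, (n,))
labels_b = torch.randint(0, 4, (n, seq_len))
labels_b[attention_mask == 0] = -100  # padded positions are ignored by the loss

loader = DataLoader(
    TensorDataset(input_ids, attention_mask, labels_a, labels_b),
    batch_size=8,
    shuffle=True,
)
history = train(TinyMultiTask(), loader)
```

To run the same loop on this notebook's objects, one would pass `model` and `train_loader` and read the four tensors from each batch dict by key (`batch["input_ids"]`, `batch["attention_mask"]`, `batch["task_a_labels"]`, `batch["task_b_labels"]`) instead of unpacking a tuple.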




# 7. Inference Example

A helper function to run both tasks on a single input string.


In [None]:

NER_TAGS = ["O", "PER", "LOC", "ORG"]  # tag names for label IDs 0-3

def predict(text: str):
    # Tokenize a single sentence (no padding needed for a batch of one)
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    
    model.eval()
    with torch.no_grad():
        logits_a, logits_b = model(enc.input_ids, enc.attention_mask)
    
    # Task A: classification label
    pred_a = torch.argmax(logits_a, dim=1).item()
    sentiment = "Positive" if pred_a == 1 else "Negative"
    
    # Task B: NER tags per token
    # (the labels in this demo are synthetic, so the predicted tags are not meaningful)
    tags = torch.argmax(logits_b, dim=2)[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(enc.input_ids[0])
    
    # Drop [CLS]/[SEP] and map tag IDs to tag names
    entities = [
        (tok, NER_TAGS[tag])
        for tok, tag in zip(tokens, tags)
        if tok not in tokenizer.all_special_tokens
    ]
    
    return {"sentiment": sentiment, "entities": entities}

# Test
sample = "Christopher Nolan directed Inception in London"
print(predict(sample))


# 8. Summary of Architectural Changes

- **Shared Encoder**: One BERT processes all inputs.  
- **Task Heads**: Separate linear layers for sentence-level and token-level outputs.  
- **Dropout**: Regularizes both heads.  
- **Combined Loss**: Sum of classification and NER cross-entropy.  
- **Padding Strategy**: NER labels padded to match token length, with `-100` to ignore in the loss.
