# Task 4: Training Loop Implementation (BONUS)
If not already done, code the training loop for the Multi-Task Learning Expansion in Task 2.
Explain any assumptions or decisions made, paying special attention to how training within an
MTL framework operates. Please note that you need not actually train the model.

Things to focus on:
- Handling of hypothetical data
- Forward pass
- Metrics

In [1]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import DataLoader
# from transformers import BertModel, BertTokenizer, AdamW
# from datasets import load_dataset
# import numpy as np

# # Device: CPU-only
# device = torch.device("cpu")

# # Load tokenizer and IMDB dataset
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# dataset = load_dataset("imdb")

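# # Generate synthetic NER labels (4 classes: O, PER, LOC, ORG): one random tag per
# # token returned by tokenizer.tokenize(text). The text is not truncated here, so the
# # label list can be longer than the 128-token encoding built in preprocess().
# # NOTE(review): labels are presumably not offset for special tokens such as [CLS];
# # verify position alignment with input_ids before substituting real annotations.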
# def create_ner_labels(text):
#     tokens = tokenizer.tokenize(text)
#     return np.random.randint(0, 4, size=len(tokens)).tolist()

# # Tokenization + Label generation
# def preprocess(batch):
#     task_a = [1 if lbl == 1 else 0 for lbl in batch["label"]]  # Binary classification
#     task_b = [create_ner_labels(text) for text in batch["text"]]  # Random NER tags
#     enc = tokenizer(batch["text"], padding="max_length", truncation=True, max_length=128, return_tensors="pt")
#     return {
#         "input_ids": enc["input_ids"],
#         "attention_mask": enc["attention_mask"],
#         "task_a_labels": task_a,
#         "task_b_labels": task_b
#     }

# # Apply preprocessing to dataset
# dataset = dataset.map(preprocess, batched=True, remove_columns=["text", "label"])

# # Multi-Task Model (shared BERT + two task heads)
# class MultiTaskTransformer(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.bert = BertModel.from_pretrained("bert-base-uncased")
#         self.dropout = nn.Dropout(0.1)
#         self.classifier = nn.Linear(self.bert.config.hidden_size, 2)  # Task A: Sentiment
#         self.ner = nn.Linear(self.bert.config.hidden_size, 4)         # Task B: NER

#     def forward(self, input_ids, attention_mask):
#         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
#         pooled = self.dropout(outputs.pooler_output)  # [CLS] for sentence classification
#         sequence = self.dropout(outputs.last_hidden_state)  # full sequence for token-level task
#         return self.classifier(pooled), self.ner(sequence)

# # Instantiate model
# model = MultiTaskTransformer().to(device)

# # Losses and optimizer
# loss_fn_a = nn.CrossEntropyLoss()
# loss_fn_b = nn.CrossEntropyLoss(ignore_index=-100)
# optimizer = AdamW(model.parameters(), lr=2e-5)

# # Custom collate function to pad token-level labels and stack tensors
# def collate_fn(batch):
#     input_ids = torch.stack([item["input_ids"] for item in batch]).to(device)
#     attention_mask = torch.stack([item["attention_mask"] for item in batch]).to(device)
#     task_a = torch.tensor([item["task_a_labels"] for item in batch], dtype=torch.long).to(device)
#     task_b = torch.stack([
#         F.pad(torch.tensor(item["task_b_labels"], dtype=torch.long), (0, 128 - len(item["task_b_labels"])), value=-100)
#         for item in batch
#     ]).to(device)
#     return {
#         "input_ids": input_ids,
#         "attention_mask": attention_mask,
#         "task_a_labels": task_a,
#         "task_b_labels": task_b
#     }

# # Dataloader
# train_loader = DataLoader(dataset["train"], batch_size=4, shuffle=True, collate_fn=collate_fn)

# # -------------------------------
# # TRAINING LOOP (CPU version)
# # -------------------------------
# model.train()  # Set model to training mode

# # Initialize metrics for Task A (sentence classification)
# total_correct_a = 0  # Count of correct predictions
# total_seen_a = 0     # Total number of samples seen

# # Iterate through the training DataLoader
# for step, batch in enumerate(train_loader):
#     optimizer.zero_grad()  # Reset gradients from previous step

#     # Forward pass through the model for both tasks
#     logits_a, logits_b = model(batch["input_ids"], batch["attention_mask"])
#     # logits_a: shape (batch_size, 2)     → for sentence classification
#     # logits_b: shape (batch_size, seq_len, 4) → for token-level NER

#     # Compute loss for Task A (sentence classification)
#     loss_a = loss_fn_a(logits_a, batch["task_a_labels"])

#     # Compute loss for Task B (NER), flatten for token-level CE loss
#     # logits_b.view(-1, 4): shape (batch_size * seq_len, 4)
#     # labels.view(-1): shape (batch_size * seq_len,)
#     # -100 label values will be ignored (padding)
#     loss_b = loss_fn_b(logits_b.view(-1, 4), batch["task_b_labels"].view(-1))

#     # Total loss = combined multitask objective
#     total_loss = loss_a + loss_b

#     # Backward pass: compute gradients
#     total_loss.backward()

#     # Update parameters using gradients
#     optimizer.step()

#     # -------------------------
#     # Task A: Compute accuracy
#     # -------------------------
#     preds = torch.argmax(logits_a, dim=1)  # Pick class with highest logit
#     correct = (preds == batch["task_a_labels"]).sum().item()  # Count correct preds
#     total_correct_a += correct
#     total_seen_a += len(batch["task_a_labels"])

#     # Periodic logging of training progress
#     if step % 20 == 0:
#         print(f"Step {step} | LossA: {loss_a.item():.4f} | LossB: {loss_b.item():.4f} | Total: {total_loss.item():.4f}")

#     # Stop once step index 50 is reached, i.e. after 51 updates (demo/truncated training)
#     if step == 50:
#         break

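# # Compute and report Task A accuracy accumulated over the training batches above
# # (a running training-set accuracy, not a held-out evaluation)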
# print(f"\nTask A (Sentiment) Accuracy: {total_correct_a / total_seen_a:.4f}")





### Assumptions & Explanation for MTL Framework

| Aspect                         | Description                                                                                                                                               |
| ------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Handling Hypothetical Data** | Task A uses the IMDB sentiment labels as a binary target. Task B uses **synthetic NER labels**: random tags (4 classes) stand in for annotated ones, so the token-classification pipeline can be exercised without a labeled NER corpus. |
| **Forward Pass**               | The BERT encoder is shared. Task A (classification) uses the `[CLS]` pooled output. Task B (NER) uses the full hidden sequence for token-wise prediction. |
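| **Loss Handling**              | Two `CrossEntropyLoss` functions are used and their losses are **summed** with equal weight (`loss_a + loss_b`). Task B uses `ignore_index=-100` so that label positions padded with `-100` do not contribute to the loss. |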
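| **Metrics**                    | Only Task A accuracy is tracked, as a running accuracy over the training batches. Task B could be evaluated with token-level precision/recall/F1; this is not implemented here because the labels are random, so such metrics would carry no signal. |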
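| **Design Choice**              | Both tasks are trained simultaneously: each batch goes through one shared forward pass, the two task losses are summed, and a single backward pass and optimizer step update the encoder and both heads. This allows **shared representation learning** and efficient multitask learning. |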
