# Task 4: Training Loop Implementation (BONUS)
If not already done, code the training loop for the Multi-Task Learning Expansion in Task 2.
Explain any assumptions or decisions made, paying special attention to how training within an
MTL framework operates. Please note you need not actually train the model.

Things to focus on:
- Handling of hypothetical data
- Forward pass
- Metrics

In this notebook we implement a multi-task training loop for our `MultiTaskTransformer` from Task 2.  
We illustrate:

- **Handling of hypothetical data** via `train_loader`  
- **Forward pass** through the shared encoder and both task heads  
- **Metrics** for each task  

> **Note:** This code is illustrative; you don’t need to run it end-to-end. To run it, first define:
> ```python
> model        # MultiTaskTransformer instance
> train_loader # yields dicts with input_ids, attention_mask, task_a_labels, task_b_labels
> loss_fn_a    # CrossEntropyLoss() for Task A (sentence-level labels)
> loss_fn_b    # CrossEntropyLoss() for Task B (default ignore_index=-100 skips padded tokens)
> optimizer    # e.g. AdamW(model.parameters(), lr=2e-5)
> ```
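
The names listed above come from Task 2, which is not reproduced in this notebook. As a rough illustration of the interface the loop expects, here is a minimal stand-in. `TinyMultiTaskModel`, its layer sizes, and its mean-pooling step are assumptions made for this sketch and are not the actual `MultiTaskTransformer`; the only property that matters here is that one forward call returns `(logits_a, logits_b)` with the shapes used later.

```python
import torch
import torch.nn as nn

class TinyMultiTaskModel(nn.Module):
    """Hypothetical stand-in: a shared embedding 'encoder' plus two task heads."""

    def __init__(self, vocab_size=100, hidden=16, num_classes_a=3, num_classes_b=5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden, padding_idx=0)  # shared layer
        self.head_a = nn.Linear(hidden, num_classes_a)  # sentence-level head
        self.head_b = nn.Linear(hidden, num_classes_b)  # token-level head

    def forward(self, input_ids, attention_mask):
        h = self.embed(input_ids)                          # (B, seq_len, hidden)
        mask = attention_mask.unsqueeze(-1).to(h.dtype)    # (B, seq_len, 1)
        # Mean-pool over real (non-padding) tokens for the sentence representation
        pooled = (h * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.head_a(pooled), self.head_b(h)

model = TinyMultiTaskModel()
loss_fn_a = nn.CrossEntropyLoss()
loss_fn_b = nn.CrossEntropyLoss()  # ignore_index defaults to -100
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Quick shape check on random token ids
input_ids = torch.randint(1, 100, (2, 7))
attention_mask = torch.ones(2, 7, dtype=torch.long)
logits_a, logits_b = model(input_ids=input_ids, attention_mask=attention_mask)
```

With a stand-in like this, the training cell below can be exercised on toy batches before the real model is plugged in.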


In [6]:

import torch

# Number of epochs (for illustration)
num_epochs = 3

for epoch in range(1, num_epochs + 1):
    # Set model to training mode (enables dropout, etc.)
    model.train()
    
    # Accumulators for averaged loss and metrics
    total_loss = 0.0
    correct_a, total_a = 0, 0
    correct_b, total_b = 0, 0
    
    # Iterate over each batch from the DataLoader
    for batch in train_loader:
        # Reset gradients from previous step
        optimizer.zero_grad()
        
        # ----- Forward pass -----
        # Shared encoder produces:
        #   logits_a: shape (B, num_classes_A)
        #   logits_b: shape (B, seq_len, num_classes_B)
        logits_a, logits_b = model(
            input_ids=batch["input_ids"],            # Tensor [B, seq_len]
            attention_mask=batch["attention_mask"],  # Tensor [B, seq_len]
        )
        
        # ----- Task A: Sentence Classification -----
        # Compute cross-entropy loss between logits and true labels
        loss_a = loss_fn_a(logits_a, batch["task_a_labels"])
        # Derive predicted class indices
        preds_a = torch.argmax(logits_a, dim=1)
        # Update accuracy counters
        correct_a += (preds_a == batch["task_a_labels"]).sum().item()
        total_a   += batch["task_a_labels"].size(0)
        
        # ----- Task B: Named Entity Recognition -----
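        # logits_b is [B, seq_len, C]; flatten to (B*seq_len, C) for the loss.
        # reshape (rather than view) also works if the tensor is non-contiguous.
        C = logits_b.size(-1)
        loss_b = loss_fn_b(
            logits_b.reshape(-1, C),              # shape (B*seq_len, C)
            batch["task_b_labels"].reshape(-1)    # shape (B*seq_len,)
        )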
        # Compute token-level predictions
        preds_b = torch.argmax(logits_b, dim=2)  # shape (B, seq_len)
        # Create mask for non-padding tokens (label != -100)
        mask = batch["task_b_labels"] != -100    # shape (B, seq_len)
        # Update token-level accuracy
        correct_b += ((preds_b == batch["task_b_labels"]) & mask).sum().item()
        total_b   += mask.sum().item()
        
        # ----- Combine & Backpropagate -----
        # Equal-weight sum of both task losses
        loss = loss_a + loss_b
        total_loss += loss.item()
        
        # Backpropagate gradients and update parameters
        loss.backward()
        optimizer.step()
    
    # ----- End-of-Epoch Metrics -----
    avg_loss = total_loss / len(train_loader)
    acc_a    = correct_a / total_a if total_a else 0.0
    acc_b    = correct_b / total_b if total_b else 0.0
    
    print(
        f"Epoch {epoch}/{num_epochs}  "
        f"| Loss: {avg_loss:.4f}  "
        f"| Task A Acc: {acc_a:.4f}  "
        f"| Task B Token Acc: {acc_b:.4f}"
    )


TypeError: list indices must be integers or slices, not str

### Assumptions & Decisions

1. **Data Handling**  
   - We assume `train_loader` uses our custom `collate_fn`, so each `batch` is a **dict** of tensors:
     - `input_ids`, `attention_mask`: shape `(B, seq_len)`
     - `task_a_labels`: shape `(B,)`
     - `task_b_labels`: shape `(B, seq_len)` with `-100` for padding positions.

2. **Forward Pass**  
   - A single call to `model(...)` yields two sets of logits:
     - **Sentence logits** for Task A
     - **Token logits** for Task B

3. **Loss Computation**  
   - **Task A:** Standard `CrossEntropyLoss` on `(B, num_classes_A)` vs. `(B,)`.  
   - **Task B:** Flatten `(B, seq_len, num_classes_B)` and `(B, seq_len)` before applying `CrossEntropyLoss`, which skips positions labeled `-100` because that is its default `ignore_index`.

4. **Metrics**  
   - **Task A Accuracy:** Number of correct sentence predictions divided by total sentences.  
   - **Task B Token Accuracy:** Number of correct token predictions (excluding padding) divided by total valid tokens.

5. **Loss Aggregation**  
   - We add `loss_a` and `loss_b` with equal weight.  
   - For imbalanced tasks, introduce a weight α:  
     ```python
     loss = α * loss_a + (1 - α) * loss_b
     ```

6. **MTL Dynamics**  
   - Gradients from both losses flow back through the shared encoder, encouraging representations beneficial to both tasks.
   - This joint optimization can improve generalization via inductive transfer between related tasks.
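
To make items 3 to 5 concrete, the sketch below runs the loss, masking, and weighting steps on random tensors in place of real model outputs. The tensor sizes, the choice of which positions count as padding, and `alpha = 0.5` are assumptions picked for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for the illustration
B, seq_len, num_classes_a, num_classes_b = 4, 6, 3, 5

# Random stand-ins for the two sets of logits returned by the model
logits_a = torch.randn(B, num_classes_a, requires_grad=True)
logits_b = torch.randn(B, seq_len, num_classes_b, requires_grad=True)

labels_a = torch.randint(0, num_classes_a, (B,))
labels_b = torch.randint(0, num_classes_b, (B, seq_len))
labels_b[:, -2:] = -100  # pretend the last two positions of each row are padding

loss_fn_a = nn.CrossEntropyLoss()
loss_fn_b = nn.CrossEntropyLoss(ignore_index=-100)  # -100 is also the default

loss_a = loss_fn_a(logits_a, labels_a)
loss_b = loss_fn_b(
    logits_b.reshape(-1, num_classes_b),  # (B*seq_len, C)
    labels_b.reshape(-1),                 # (B*seq_len,)
)

# Token accuracy over non-padding positions only
preds_b = logits_b.argmax(dim=-1)
mask = labels_b != -100
token_acc = ((preds_b == labels_b) & mask).sum().item() / mask.sum().item()

# Weighted combination; alpha = 0.5 weights the two tasks equally
alpha = 0.5
loss = alpha * loss_a + (1 - alpha) * loss_b
loss.backward()
```

After `backward()`, the gradient of `logits_b` at the padded positions is zero, which is what "ignoring" those tokens means in practice: they contribute neither to the loss value nor to the parameter update.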



In [4]:
# The TypeError above means `batch` arrived as a list of examples rather than a dict of tensors.
# To fix this, pass our custom `collate_fn` when creating the DataLoader:

from torch.utils.data import DataLoader

# Recreate train_loader with collate_fn so each batch is a dict of tensors
train_loader = DataLoader(
    dataset["train"],
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn   # <-- ensures batch["input_ids"], etc. are tensors
)


NameError: name 'dataset' is not defined
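
Neither `dataset` nor `collate_fn` is defined in this notebook, which is why the cell above fails. As a hedged sketch of what such a collate function might look like, the example below pads a toy list of examples into a dict of tensors. `toy_dataset`, `toy_collate_fn`, `toy_loader`, `PAD_ID`, and the per-example field name `task_a_label` are all hypothetical names introduced here; the real Task 2 code may differ.

```python
import torch
from torch.utils.data import DataLoader

PAD_ID = 0           # assumed padding token id
IGNORE_INDEX = -100  # label value skipped by CrossEntropyLoss by default

# Toy examples with made-up token ids and labels
toy_dataset = [
    {"input_ids": [11, 7, 8, 12], "task_a_label": 1, "task_b_labels": [0, 1, 2, 0]},
    {"input_ids": [11, 9],        "task_a_label": 0, "task_b_labels": [1, 0]},
]

def toy_collate_fn(examples):
    """Pad a list of example dicts into one dict of equal-length tensors."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    input_ids, attention_mask, task_b_labels = [], [], []
    for ex in examples:
        n_real = len(ex["input_ids"])
        n_pad = max_len - n_real
        input_ids.append(ex["input_ids"] + [PAD_ID] * n_pad)
        attention_mask.append([1] * n_real + [0] * n_pad)
        task_b_labels.append(ex["task_b_labels"] + [IGNORE_INDEX] * n_pad)
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "task_a_labels": torch.tensor(
            [ex["task_a_label"] for ex in examples], dtype=torch.long
        ),
        "task_b_labels": torch.tensor(task_b_labels, dtype=torch.long),
    }

toy_loader = DataLoader(
    toy_dataset, batch_size=2, shuffle=False, collate_fn=toy_collate_fn
)
batch = next(iter(toy_loader))  # a dict, so batch["input_ids"] works
```

Padding the token labels with `-100` rather than a real class id is what lets the loss and the accuracy mask in the training loop skip those positions.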

### Assumptions & Decisions

- **Hypothetical Data**: We rely on our `train_loader` to supply properly tokenized inputs and padded labels.  
- **Loss Aggregation**: We sum Task A and Task B losses equally; this can be adjusted via weighting if tasks differ in scale.  
- **Metrics**:  
  - **Task A**: Sentence‐level accuracy.  
  - **Task B**: Token‐level accuracy, excluding padding tokens (label = `-100`).  
- **Modularity**: This loop can be extended to include learning‐rate schedulers, gradient clipping, or dynamic loss weighting without altering the core structure.  
- **No Actual Training**: This code is illustrative; define a real dataset, loss functions, and optimizer before running it on your data.
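
The modularity point can be illustrated with a short sketch that adds gradient clipping and a learning-rate schedule around the same zero-grad, forward, backward, step sequence. The single linear layer, the random data, `max_norm=1.0`, and the linear decay to 10% of the initial rate are assumptions chosen to keep the example small; they are not tuned recommendations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(8, 3)  # stand-in for the multi-task model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

num_steps = 10
# Linearly decay the learning rate from 100% to 10% of its initial value
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 1.0 - 0.9 * step / num_steps
)

for step in range(num_steps):
    x = torch.randn(4, 8)            # random inputs in place of real batches
    y = torch.randint(0, 3, (4,))    # random class labels

    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()

    # Rescale gradients in place so their global L2 norm is at most max_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()
    scheduler.step()  # advance the schedule after the optimizer update

# Gradient norm left over from the final step, measured after clipping
post_clip_norm = torch.sqrt(
    sum(p.grad.pow(2).sum() for p in model.parameters())
).item()
final_lr = scheduler.get_last_lr()[0]
```

In the multi-task loop, the clipping call would sit between `loss.backward()` and `optimizer.step()`, and `scheduler.step()` would follow `optimizer.step()`, leaving the rest of the loop unchanged.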
