In [2]:
# Setup and Imports
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import BertModel, get_linear_schedule_with_warmup
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, accuracy_score, classification_report
from tqdm import tqdm
import json
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Add src to path
sys.path.append(str(Path.cwd().parent / 'src'))
from multitask_bert import MultiTaskBERT

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

# Verify GPU usage
if device.type == 'cuda':
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: Not using GPU! Training will be very slow.")
    print("In Colab: Runtime → Change runtime type → Hardware accelerator → GPU")

# Enable mixed precision training
from torch.cuda.amp import autocast, GradScaler
use_amp = device.type == 'cuda'
scaler = GradScaler() if use_amp else None
if use_amp:
    print("✓ Mixed precision training enabled")

# Create directories
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('results', exist_ok=True)

Using device: cuda
✓ GPU: Tesla T4
✓ GPU Memory: 15.83 GB
✓ Mixed precision training enabled


In [18]:
from google.colab import drive
drive.mount('/content/drive')
BASE_PATH = 'drive/MyDrive/Deep Learning project/'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Part 1.3 — Experiments

This notebook contains all experimental configurations for toxicity detection:
1. **Baseline**: Fine-tune BERT only on toxicity (single-task)
2. **Multi-Task**: Train jointly on both tasks (shared encoder)
3. **Sequential**: Pre-train on emotion → fine-tune for toxicity
4. **Ablation studies** with different loss weights (1:1, 2:1, 1:2)

### Metrics Tracked
- **F1 Score**: Primary metric for toxicity classification
- **Precision**: Precision for toxic class
- **Recall**: Recall for toxic class
- **ROC-AUC**: Area under ROC curve
- **Accuracy**: Overall classification accuracy


In [None]:
class ToxicityDataset(Dataset):
    """Dataset for toxicity classification."""

    def __init__(self, data):
        self.tokens = data["tokens"]
        self.labels = data["labels"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": self.tokens["input_ids"][idx],
            "attention_mask": self.tokens["attention_mask"][idx],
            "labels": self.labels[idx]
        }

class EmotionDataset(Dataset):
    """Dataset for emotion classification."""

    def __init__(self, data):
        self.tokens = data["tokens"]
        self.labels = data["labels"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": self.tokens["input_ids"][idx],
            "attention_mask": self.tokens["attention_mask"][idx],
            "labels": self.labels[idx]
        }

def toxicity_collate_fn(batch):
    """Collate function for toxicity batches."""
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch])
    }

def emotion_collate_fn(batch):
    """Collate function for emotion batches."""
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch])
    }

def load_data(data_root="../data/processed/tokenized"):
    """Load separate datasets for toxicity and emotion classification."""
    data_root = Path(data_root)

    tox_dir = data_root / "toxicity"
    emo_dir = data_root / "emotion"

    if not tox_dir.exists() or not emo_dir.exists():
        raise FileNotFoundError(f"Data directories not found: {tox_dir}, {emo_dir}")

    print("Loading toxicity data...")
    toxicity_train_data = torch.load(tox_dir / "train.pt")
    toxicity_val_data = torch.load(tox_dir / "val.pt")
    toxicity_test_data = torch.load(tox_dir / "test.pt")

    print("Loading emotion data...")
    emotion_train_data = torch.load(emo_dir / "train.pt")
    emotion_val_data = torch.load(emo_dir / "val.pt")
    emotion_test_data = torch.load(emo_dir / "test.pt")

    tox_train = ToxicityDataset(toxicity_train_data)
    tox_val = ToxicityDataset(toxicity_val_data)
    tox_test = ToxicityDataset(toxicity_test_data)
    emo_train = EmotionDataset(emotion_train_data)
    emo_val = EmotionDataset(emotion_val_data)
    emo_test = EmotionDataset(emotion_test_data)

    print(f"Toxicity - Train: {len(tox_train)}, Val: {len(tox_val)}, Test: {len(tox_test)}")
    print(f"Emotion - Train: {len(emo_train)}, Val: {len(emo_val)}, Test: {len(emo_test)}")

    return tox_train, tox_val, tox_test, emo_train, emo_val, emo_test


In [10]:
def evaluate_toxicity(model, dataloader, device):
    """Evaluate toxicity model and return metrics."""
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].long().to(device)

            # Get predictions (only toxicity head)
            if isinstance(model, MultiTaskBERT):
                logits, _ = model(input_ids, attention_mask)
            else:
                # Single-task baseline model
                logits = model(input_ids, attention_mask)

            loss = criterion(logits, labels)
            total_loss += loss.item()

            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    # Calculate metrics
    f1 = f1_score(all_labels, all_preds, average='weighted')
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    accuracy = accuracy_score(all_labels, all_preds)

    # ROC-AUC (multi-class)
    try:
        # Convert to one-hot for multi-class ROC-AUC
        num_classes = len(np.unique(all_labels))
        if num_classes > 2:
            # Use macro average for multi-class
            roc_auc = roc_auc_score(all_labels, all_preds, average='macro', multi_class='ovr', labels=range(num_classes))
        else:
            roc_auc = roc_auc_score(all_labels, all_preds)
    except:
        roc_auc = 0.0

    avg_loss = total_loss / len(dataloader)

    return {
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'roc_auc': roc_auc,
        'accuracy': accuracy,
        'loss': avg_loss
    }


In [None]:
# Configuration
DATA_BASE_PATH = '../data/processed/tokenized'
TOX_DATA_PATH = f'{DATA_BASE_PATH}/toxicity'
EMO_DATA_PATH = f'{DATA_BASE_PATH}/emotion'

EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2e-5
GRADIENT_CLIP = 1.0

print("Configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")

# Load data
tox_train, tox_val, tox_test, emo_train, emo_val, emo_test = load_data(DATA_BASE_PATH)


Configuration:
  Epochs: 3
  Batch size: 32
  Learning rate: 2e-05
Loading toxicity data...
Loading emotion data...
Toxicity - Train: 126580, Val: 15823, Test: 15823
Emotion - Train: 39064, Val: 4883, Test: 4883


### 1. Baseline Model: Single-Task Toxicity Fine-tuning


In [12]:
# Baseline: Single-task BERT for toxicity only
class BaselineBERT(nn.Module):
    """Single-task BERT model for toxicity classification only."""

    def __init__(self, bert_model_name='bert-base-uncased', num_labels=6):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        return logits

def train_baseline(model, train_loader, val_loader, epochs=3, lr=2e-5, device='cpu', save_path='checkpoints/baseline_best.pt', use_amp=False, scaler=None):
    """Train baseline single-task model."""
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps)

    model.to(device)
    best_val_f1 = 0.0

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        print(f"\nEpoch {epoch+1}/{epochs}")
        for batch in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].long().to(device)

            # Mixed precision training
            if use_amp and scaler is not None:
                with autocast():
                    logits = model(input_ids, attention_mask)
                    loss = criterion(logits, labels)

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(input_ids, attention_mask)
                loss = criterion(logits, labels)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
                optimizer.step()

            scheduler.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Train Loss: {avg_loss:.4f}")

        # Validation
        val_metrics = evaluate_toxicity(model, val_loader, device)
        print(f"Val Metrics - F1: {val_metrics['f1']:.4f}, Precision: {val_metrics['precision']:.4f}, "
              f"Recall: {val_metrics['recall']:.4f}, ROC-AUC: {val_metrics['roc_auc']:.4f}")

        # Save best model
        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': val_metrics['f1'],
                'val_metrics': val_metrics
            }, save_path)
            print(f"Saved best model (F1: {val_metrics['f1']:.4f})")

    return model, val_metrics

In [13]:
# Train baseline
print("="*80)
print("Training Baseline Model (Single-Task Toxicity)")
print("="*80)

baseline_model = BaselineBERT()
tox_train_loader = DataLoader(tox_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=toxicity_collate_fn, num_workers=2, pin_memory=True)
tox_val_loader = DataLoader(tox_val, batch_size=BATCH_SIZE, collate_fn=toxicity_collate_fn, num_workers=2, pin_memory=True)
tox_test_loader = DataLoader(tox_test, batch_size=BATCH_SIZE, collate_fn=toxicity_collate_fn, num_workers=2, pin_memory=True)

baseline_model, baseline_val_metrics = train_baseline(
    baseline_model, tox_train_loader, tox_val_loader,
    epochs=EPOCHS, lr=LEARNING_RATE, device=device,
    save_path='checkpoints/baseline_best.pt', use_amp=use_amp, scaler=scaler
)

# Test evaluation
print("\nEvaluating on test set...")
baseline_test_metrics = evaluate_toxicity(baseline_model, tox_test_loader, device)
print(f"Test Metrics - F1: {baseline_test_metrics['f1']:.4f}, Precision: {baseline_test_metrics['precision']:.4f}, "
      f"Recall: {baseline_test_metrics['recall']:.4f}, ROC-AUC: {baseline_test_metrics['roc_auc']:.4f}")

# Save final model
torch.save(baseline_model.state_dict(), 'checkpoints/baseline_final.pt')
print("Saved final baseline model")

Training Baseline Model (Single-Task Toxicity)


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]


Epoch 1/3


Training: 100%|██████████| 3956/3956 [13:06<00:00,  5.03it/s]


Train Loss: 0.0740


Evaluating: 100%|██████████| 495/495 [01:45<00:00,  4.69it/s]


Val Metrics - F1: 0.9937, Precision: 0.9915, Recall: 0.9958, ROC-AUC: 0.0000
Saved best model (F1: 0.9937)

Epoch 2/3


Training: 100%|██████████| 3956/3956 [13:08<00:00,  5.02it/s]


Train Loss: 0.0316


Evaluating: 100%|██████████| 495/495 [01:45<00:00,  4.69it/s]


Val Metrics - F1: 0.9937, Precision: 0.9915, Recall: 0.9958, ROC-AUC: 0.0000

Epoch 3/3


Training: 100%|██████████| 3956/3956 [13:04<00:00,  5.04it/s]


Train Loss: 0.0213


Evaluating: 100%|██████████| 495/495 [01:45<00:00,  4.68it/s]


Val Metrics - F1: 0.9935, Precision: 0.9922, Recall: 0.9953, ROC-AUC: 0.0000

Evaluating on test set...


Evaluating: 100%|██████████| 495/495 [01:45<00:00,  4.71it/s]


Test Metrics - F1: 0.9919, Precision: 0.9903, Recall: 0.9941, ROC-AUC: 0.0000
Saved final baseline model


### 2. Multi-Task Model: Joint Training


In [21]:
def train_multitask(model, tox_train_loader, tox_val_loader, emo_train_loader, emo_val_loader,
                    epochs=3, lr=2e-5, device='cpu', save_path='checkpoints/multitask_best.pt'):
    """Train multi-task model with alternating batches."""
    optimizer = AdamW(model.parameters(), lr=lr)

    steps_per_epoch = max(len(tox_train_loader), len(emo_train_loader))
    total_steps = steps_per_epoch * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps)

    criterion_tox = nn.CrossEntropyLoss()
    criterion_emo = nn.BCEWithLogitsLoss()

    model.to(device)
    best_val_f1 = 0.0

    for epoch in range(epochs):
        model.train()
        total_tox_loss = 0
        total_emo_loss = 0

        print(f"\nEpoch {epoch+1}/{epochs}")
        tox_iter = iter(tox_train_loader)
        emo_iter = iter(emo_train_loader)

        progress_bar = tqdm(range(steps_per_epoch), desc="Training")
        for step in progress_bar:
            # Train on toxicity
            try:
                tox_batch = next(tox_iter)
                optimizer.zero_grad()

                input_ids = tox_batch["input_ids"].to(device)
                attention_mask = tox_batch["attention_mask"].to(device)
                tox_labels = tox_batch["labels"].long().to(device)

                tox_logits, _ = model(input_ids, attention_mask)
                loss = model.lambda_tox * criterion_tox(tox_logits, tox_labels)

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
                optimizer.step()
                scheduler.step()

                total_tox_loss += loss.item()
            except StopIteration:
                tox_iter = iter(tox_train_loader)

            # Train on emotion
            try:
                emo_batch = next(emo_iter)
                optimizer.zero_grad()

                input_ids = emo_batch["input_ids"].to(device)
                attention_mask = emo_batch["attention_mask"].to(device)
                emo_labels = emo_batch["labels"].float().to(device)

                _, emo_logits = model(input_ids, attention_mask)
                loss = model.lambda_emo * criterion_emo(emo_logits, emo_labels)

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
                optimizer.step()
                scheduler.step()

                total_emo_loss += loss.item()
            except StopIteration:
                emo_iter = iter(emo_train_loader)

            progress_bar.set_postfix({
                'tox_loss': f'{total_tox_loss/(step+1):.4f}',
                'emo_loss': f'{total_emo_loss/(step+1):.4f}'
            })

        avg_tox_loss = total_tox_loss / steps_per_epoch
        avg_emo_loss = total_emo_loss / steps_per_epoch
        print(f"Train Loss - Tox: {avg_tox_loss:.4f}, Emo: {avg_emo_loss:.4f}")

        # Validation (only toxicity metrics)
        val_metrics = evaluate_toxicity(model, tox_val_loader, device)
        print(f"Val Metrics - F1: {val_metrics['f1']:.4f}, Precision: {val_metrics['precision']:.4f}, "
              f"Recall: {val_metrics['recall']:.4f}, ROC-AUC: {val_metrics['roc_auc']:.4f}")

        # Save best model
        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': val_metrics['f1'],
                'val_metrics': val_metrics,
                'lambda_tox': model.lambda_tox,
                'lambda_emo': model.lambda_emo
            }, save_path)
            print(f"Saved best model (F1: {val_metrics['f1']:.4f})")

    return model, val_metrics

In [None]:
# Train multi-task with different loss weights
print("="*80)
print("Training Multi-Task Models with Different Loss Weights")
print("="*80)

emo_train_loader = DataLoader(emo_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=emotion_collate_fn, num_workers=0)
emo_val_loader = DataLoader(emo_val, batch_size=BATCH_SIZE, collate_fn=emotion_collate_fn, num_workers=0)

# Loss weight configurations
loss_configs = [
    (1.0, 1.0, '1_1'),
    (2.0, 1.0, '2_1'),
    (1.0, 2.0, '1_2')
]

multitask_results = []

for lambda_tox, lambda_emo, suffix in loss_configs:
    print(f"\n{'='*80}")
    print(f"Multi-Task Training: λ_tox={lambda_tox}, λ_emo={lambda_emo}")
    print(f"{'='*80}")

    model = MultiTaskBERT(lambda_tox=lambda_tox, lambda_emo=lambda_emo)
    save_path = f'checkpoints/multitask_w{lambda_tox}_{lambda_emo}_best.pt'

    model, val_metrics = train_multitask(
        model, tox_train_loader, tox_val_loader, emo_train_loader, emo_val_loader,
        epochs=EPOCHS, lr=LEARNING_RATE, device=device, save_path=save_path
    )

    # Test evaluation
    test_metrics = evaluate_toxicity(model, tox_test_loader, device)
    print(f"\nTest Metrics - F1: {test_metrics['f1']:.4f}, Precision: {test_metrics['precision']:.4f}, "
          f"Recall: {test_metrics['recall']:.4f}, ROC-AUC: {test_metrics['roc_auc']:.4f}")

    # Save final model
    torch.save(model.state_dict(), f'checkpoints/multitask_w{lambda_tox}_{lambda_emo}_final.pt')

    multitask_results.append({
        'model': 'multitask',
        'config': f'multitask_w{lambda_tox}_{lambda_emo}',
        'lambda_tox': lambda_tox,
        'lambda_emo': lambda_emo,
        'val_f1': val_metrics['f1'],
        'val_precision': val_metrics['precision'],
        'val_recall': val_metrics['recall'],
        'val_roc_auc': val_metrics['roc_auc'],
        'val_accuracy': val_metrics['accuracy'],
        'test_f1': test_metrics['f1'],
        'test_precision': test_metrics['precision'],
        'test_recall': test_metrics['recall'],
        'test_roc_auc': test_metrics['roc_auc'],
        'test_accuracy': test_metrics['accuracy']
    })

Training Multi-Task Models with Different Loss Weights

Multi-Task Training: λ_tox=1.0, λ_emo=1.0

Epoch 1/3


Training:   6%|▌         | 247/3956 [05:14<1:18:10,  1.26s/it, tox_loss=0.4848, emo_loss=0.3900]

### 3. Sequential Model: Pre-train on Emotion → Fine-tune for Toxicity


In [None]:
def train_sequential(emo_train_loader, emo_val_loader, tox_train_loader, tox_val_loader,
                     epochs_pretrain=2, epochs_finetune=3, lr=2e-5, device='cpu',
                     save_path='checkpoints/sequential_best.pt'):
    """Sequential training: pre-train on emotion, then fine-tune for toxicity."""

    # Step 1: Pre-train on emotion
    print("="*80)
    print("Step 1: Pre-training on Emotion")
    print("="*80)

    model = MultiTaskBERT(lambda_tox=0.0, lambda_emo=1.0)  # Only emotion loss
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion_emo = nn.BCEWithLogitsLoss()

    steps_per_epoch = len(emo_train_loader)
    total_steps = steps_per_epoch * epochs_pretrain
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps)

    model.to(device)

    for epoch in range(epochs_pretrain):
        model.train()
        total_loss = 0

        print(f"\nPre-train Epoch {epoch+1}/{epochs_pretrain}")
        for batch in tqdm(emo_train_loader, desc="Training"):
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            emo_labels = batch["labels"].float().to(device)

            _, emo_logits = model(input_ids, attention_mask)
            loss = criterion_emo(emo_logits, emo_labels)

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(emo_train_loader)
        print(f"Pre-train Loss: {avg_loss:.4f}")

    print("\nPre-training complete!")

    # Step 2: Fine-tune for toxicity
    print("\n" + "="*80)
    print("Step 2: Fine-tuning for Toxicity")
    print("="*80)

    # Now switch to toxicity-only training
    model.lambda_tox = 1.0
    model.lambda_emo = 0.0  # Disable emotion loss

    optimizer = AdamW(model.parameters(), lr=lr * 0.1)  # Lower learning rate for fine-tuning
    criterion_tox = nn.CrossEntropyLoss()

    steps_per_epoch = len(tox_train_loader)
    total_steps = steps_per_epoch * epochs_finetune
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps)

    best_val_f1 = 0.0

    for epoch in range(epochs_finetune):
        model.train()
        total_loss = 0

        print(f"\nFine-tune Epoch {epoch+1}/{epochs_finetune}")
        for batch in tqdm(tox_train_loader, desc="Training"):
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            tox_labels = batch["labels"].long().to(device)

            tox_logits, _ = model(input_ids, attention_mask)
            loss = criterion_tox(tox_logits, tox_labels)

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(tox_train_loader)
        print(f"Fine-tune Loss: {avg_loss:.4f}")

        # Validation
        val_metrics = evaluate_toxicity(model, tox_val_loader, device)
        print(f"Val Metrics - F1: {val_metrics['f1']:.4f}, Precision: {val_metrics['precision']:.4f}, "
              f"Recall: {val_metrics['recall']:.4f}, ROC-AUC: {val_metrics['roc_auc']:.4f}")

        # Save best model
        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': val_metrics['f1'],
                'val_metrics': val_metrics
            }, save_path)
            print(f"Saved best model (F1: {val_metrics['f1']:.4f})")

    return model, val_metrics

In [None]:
# Train sequential model
sequential_model, sequential_val_metrics = train_sequential(
    emo_train_loader, emo_val_loader, tox_train_loader, tox_val_loader,
    epochs_pretrain=2, epochs_finetune=EPOCHS, lr=LEARNING_RATE, device=device,
    save_path='checkpoints/sequential_best.pt'
)

# Test evaluation
print("\nEvaluating on test set...")
sequential_test_metrics = evaluate_toxicity(sequential_model, tox_test_loader, device)
print(f"Test Metrics - F1: {sequential_test_metrics['f1']:.4f}, Precision: {sequential_test_metrics['precision']:.4f}, "
      f"Recall: {sequential_test_metrics['recall']:.4f}, ROC-AUC: {sequential_test_metrics['roc_auc']:.4f}")

# Save final model
torch.save(sequential_model.state_dict(), 'checkpoints/sequential_final.pt')
print("Saved final sequential model")

In [None]:
# Compile all results
all_results = []

# Baseline results
all_results.append({
    'model': 'baseline',
    'config': 'single_task',
    'lambda_tox': 1.0,
    'lambda_emo': 0.0,
    'val_f1': baseline_val_metrics['f1'],
    'val_precision': baseline_val_metrics['precision'],
    'val_recall': baseline_val_metrics['recall'],
    'val_roc_auc': baseline_val_metrics['roc_auc'],
    'val_accuracy': baseline_val_metrics['accuracy'],
    'test_f1': baseline_test_metrics['f1'],
    'test_precision': baseline_test_metrics['precision'],
    'test_recall': baseline_test_metrics['recall'],
    'test_roc_auc': baseline_test_metrics['roc_auc'],
    'test_accuracy': baseline_test_metrics['accuracy']
})

# Multi-task results
for result in multitask_results:
    all_results.append(result)

# Sequential results
all_results.append({
    'model': 'sequential',
    'config': 'pretrain_emotion_finetune_toxicity',
    'lambda_tox': 1.0,
    'lambda_emo': 1.0,
    'val_f1': sequential_val_metrics['f1'],
    'val_precision': sequential_val_metrics['precision'],
    'val_recall': sequential_val_metrics['recall'],
    'val_roc_auc': sequential_val_metrics['roc_auc'],
    'val_accuracy': sequential_val_metrics['accuracy'],
    'test_f1': sequential_test_metrics['f1'],
    'test_precision': sequential_test_metrics['precision'],
    'test_recall': sequential_test_metrics['recall'],
    'test_roc_auc': sequential_test_metrics['roc_auc'],
    'test_accuracy': sequential_test_metrics['accuracy']
})

# Create DataFrame and save
results_df = pd.DataFrame(all_results)
results_path = 'results/metrics.csv'
results_df.to_csv(results_path, index=False)

print("\n" + "="*80)
print("All Results Summary")
print("="*80)
print(results_df.to_string(index=False))
print(f"\nResults saved to {results_path}")


### Checkpoint Files:
- `checkpoints/baseline_best.pt` / `baseline_final.pt`: Single-task toxicity model
- `checkpoints/multitask_w1.0_1.0_best.pt` / `multitask_w1.0_1.0_final.pt`: Multi-task (1:1 weights)
- `checkpoints/multitask_w2.0_1.0_best.pt` / `multitask_w2.0_1.0_final.pt`: Multi-task (2:1 weights)
- `checkpoints/multitask_w1.0_2.0_best.pt` / `multitask_w1.0_2.0_final.pt`: Multi-task (1:2 weights)
- `checkpoints/sequential_best.pt` / `sequential_final.pt`: Sequential training model

### Metrics:
All evaluation metrics (F1, Precision, Recall, ROC-AUC, Accuracy) are saved in `results/metrics.csv`
