In [2]:
# Import Required Libraries
import os
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer, get_scheduler
from torch.optim import AdamW
from datasets import load_dataset
from collections import defaultdict
from sklearn.metrics import accuracy_score, mean_squared_error
from scipy.stats import pearsonr
import numpy as np

In [3]:
#Set Environment and Seed
os.environ["WANDB_DISABLED"] = "true"
random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x1ab68944af0>

In [4]:
#Load Pretrained Tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [5]:
# Tokenize Each Task Dataset
# Each function takes a single sample (dict) and applies the tokenizer

def tokenize_sst(sample):
    return tokenizer(sample["sentence"], padding="max_length", truncation=True, max_length=128)

def tokenize_qqp(sample):
    return tokenizer(sample["question1"], sample["question2"], padding="max_length", truncation=True, max_length=128)

def tokenize_sts(sample):
    return tokenizer(sample["sentence1"], sample["sentence2"], padding="max_length", truncation=True, max_length=128)

In [6]:
# Load and Prepare Datasets
datasets = {
    "sst": load_dataset("glue", "sst2"),
    "qqp": load_dataset("glue", "qqp"),
    "sts": load_dataset("glue", "stsb"),
}

tokenized = {
    "sst": datasets["sst"].map(tokenize_sst, batched=True),
    "qqp": datasets["qqp"].map(tokenize_qqp, batched=True),
    "sts": datasets["sts"].map(tokenize_sts, batched=True),
}

for task in tokenized:
    if task == "sst":
        tokenized[task] = tokenized[task].remove_columns(["sentence", "idx"])
    elif task == "qqp":
        tokenized[task] = tokenized[task].remove_columns(["question1", "question2", "idx"])
    elif task == "sts":
        tokenized[task] = tokenized[task].remove_columns(["sentence1", "sentence2", "idx"])
    tokenized[task] = tokenized[task].rename_column("label", "labels")
    tokenized[task].set_format("torch")

Map: 100%|██████████| 67349/67349 [00:12<00:00, 5512.85 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 3487.91 examples/s]
Map: 100%|██████████| 1821/1821 [00:00<00:00, 3887.97 examples/s]
Map:   0%|          | 1000/363846 [00:00<02:05, 2902.50 examples/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Map:   1%|          | 3000/363846 [00:01<02:01, 2958.69 examples/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Map:   4%|▍         | 14000/363846 [00:04<01:56, 3011.91 examples/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So 

In [7]:
# Define Multitask BERT Model Class
class MultiTaskBERT(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(0.1)
        self.sst_head = nn.Linear(768, 2)           # classification
        self.qqp_head = nn.Linear(768, 2)           # classification
        self.sts_head = nn.Sequential(              # regression
            nn.Linear(768, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask, task):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = self.dropout(outputs.pooler_output)
        if task == "sst":
            return self.sst_head(cls_output)
        elif task == "qqp":
            return self.qqp_head(cls_output)
        elif task == "sts":
            return self.sts_head(cls_output) * 5  # scale to [0, 5]

In [8]:
#  Task-Specific Dataloaders
batch_size = 16
dataloaders = {
    task: DataLoader(tokenized[task]["train"], batch_size=batch_size, shuffle=True)
    for task in tokenized
}

task_list = list(dataloaders.keys())
task_sizes = {task: len(dataloaders[task].dataset) for task in task_list}
task_probs = [task_sizes[task] / sum(task_sizes.values()) for task in task_list]

In [9]:
# Initialize Model and Optimizer
model = MultiTaskBERT()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=10000)

In [10]:
# Training Loop
num_epochs = 3
model.train()

task_iterators = {task: iter(loader) for task, loader in dataloaders.items()}

for epoch in range(num_epochs):
    for step in range(10000):  # define total steps or batches
        task = random.choices(task_list, weights=task_probs, k=1)[0]
        try:
            batch = next(task_iterators[task])
        except StopIteration:
            task_iterators[task] = iter(dataloaders[task])
            batch = next(task_iterators[task])

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, task=task)

        if task in ["sst", "qqp"]:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(outputs, labels)
        elif task == "sts":
            loss_fn = nn.MSELoss()
            loss = loss_fn(outputs.view(-1), labels.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if step % 100 == 0:
            print(f"Epoch {epoch+1} | Step {step} | Task: {task} | Loss: {loss.item():.4f}")

Epoch 1 | Step 0 | Task: qqp | Loss: 0.7150
Epoch 1 | Step 100 | Task: sst | Loss: 0.7202
Epoch 1 | Step 200 | Task: qqp | Loss: 0.4777
Epoch 1 | Step 300 | Task: qqp | Loss: 0.4509
Epoch 1 | Step 400 | Task: qqp | Loss: 0.4325
Epoch 1 | Step 500 | Task: qqp | Loss: 0.8232
Epoch 1 | Step 600 | Task: sst | Loss: 0.2647
Epoch 1 | Step 700 | Task: qqp | Loss: 0.5158
Epoch 1 | Step 800 | Task: qqp | Loss: 0.4877


KeyboardInterrupt: 

In [None]:
# Save Final Multitask Model
save_path = "./multitask_bert_model"
os.makedirs(save_path, exist_ok=True)
torch.save(model.state_dict(), f"{save_path}/model.pt")
tokenizer.save_pretrained(save_path)
print("✅ Training complete and model saved.")