In [12]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from utils import MultiTaskBert
from transformers import BertModel, BertTokenizerFast
from datasets import load_dataset
from itertools import chain

In [13]:
# 1. Single Input, Two Labels per Example
# Using ATIS dataset so that each batch naturally contains both labels_intent and labels_ner for the same input_ids. 
# This allows the model to compute both losses in one forward pass, rather than using separate datasets or alternating batches.

In [14]:
atis = load_dataset("tuetschek/atis")
print(atis['train'].features)
atis.keys()

{'id': Value(dtype='int64', id=None), 'intent': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'slots': Value(dtype='string', id=None)}


dict_keys(['train', 'test'])

In [15]:
# Data preprocessing:
# Used chatgpt to assist me since I am not familiar with the datasetl
# Subword tokens stay aligned with their original word‐level slot labels.
# Padding and special tokens are masked out in the slot loss via -100.
# Each example carries both sentence‐level (intent) and token‐level (slot) labels, ready for your multi‐task training loop.

In [16]:
# 1) Load ATIS
atis = load_dataset("tuetschek/atis")

# 2) Build intent2id from ALL splits
splits = atis.keys() 
all_intents = sorted(set(chain.from_iterable(atis[split]["intent"] for split in splits)))
intent2id  = {label: idx for idx, label in enumerate(all_intents)}

# 3) Build slot2id from ALL splits
all_slots = set()
for split in splits:
    for slot_seq in atis[split]["slots"]:
        all_slots.update(slot_seq.split())
slot_labels = sorted(all_slots)
slot2id     = {label: idx for idx, label in enumerate(slot_labels)}

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# 4) Preprocessing / alignment function
def preprocess(batch):
    # split into word‑tokens
    words      = batch["text"].split()
    slot_seq   = batch["slots"].split()
    # tokenize (preserving word→subword mapping)
    enc = tokenizer(words,
                    is_split_into_words=True,
                    padding="max_length",
                    truncation=True,
                    max_length=32,
                    return_tensors="pt")
    # align slot labels to subwords
    word_ids = enc.word_ids(batch_index=0)  # list of length seq_len
    ner_labels = []
    for widx in word_ids:
        if widx is None:
            ner_labels.append(-100)                # will be ignored by loss
        else:
            ner_labels.append(slot2id[slot_seq[widx]])
    enc["labels_ner"]    = torch.tensor([ner_labels])
    enc["labels_intent"] = torch.tensor([intent2id[batch["intent"]]])
    return enc

# 5) Apply to the dataset
atis_tok = atis.map(preprocess, batched=False, remove_columns=atis["train"].column_names)
atis_tok.set_format(type="torch", columns=["input_ids","attention_mask","labels_intent","labels_ner"])

# 6) DataLoader
train_loader = DataLoader(atis_tok["train"], batch_size=16, shuffle=True)
test_loader = DataLoader(atis_tok["test"], batch_size=16)


In [17]:
#We assume both tasks are equally important(loss = loss_sent + loss_ner)
# Even though BERT is large, the two small heads can overfit quickly on limited data.
# Therefore, I added a dropout(0.1)layer on both heads before both the intent classifier and slot classifier. 
# This injects noise at the head level, encouraging the shared encoder to produce features that are robust across random perturbations.

In [18]:
def train(model, optimizer, loss_fn_ner, loss_fn_sent, n_epoch, device):
    for epoch in range(1, n_epoch+1):
        model.train()
        epoch_loss = 0.0
        train_loss = 0.0
        sent_correct, sent_total = 0, 0
        token_correct, token_total = 0, 0

        for batch in train_loader:
            #import pdb;pdb.set_trace()
            # 1) Move data to device
            input_ids      = batch["input_ids"].squeeze(1).to(device)
            attention_mask = batch["attention_mask"].squeeze(1).to(device)
            # 2) Labels
            intent_labels = batch["labels_intent"].to(device)
            if intent_labels.dim()==2 and intent_labels.size(1)==1:
                intent_labels = intent_labels.squeeze(1)
            elif intent_labels.dim()==2 and intent_labels.size(1)>1:
                intent_labels = intent_labels.argmax(dim=1)
            intent_labels = intent_labels.long()

            ner_labels = batch["labels_ner"].to(device)
            if ner_labels.dim()==3 and ner_labels.size(2)==1:
                ner_labels = ner_labels.squeeze(2)
            ner_labels = ner_labels.long()

            #import pdb;pdb.set_trace()
            # 3) Forward pass
            outputs       = model(input_ids, attention_mask)
            sent_logits   = outputs["sent_logits"]   # (B, num_intent_labels)
            token_logits  = outputs["token_logits"]  # (B, T, num_slot_labels)

            # 4) Compute losses
            loss_sent = loss_fn_sent(sent_logits, intent_labels)

            B, T, C = token_logits.size()
            loss_ner  = loss_fn_ner(
                token_logits.view(-1, C),             # (B*T, C)
                ner_labels.view(-1)                   # (B*T,)
            )

            # 5) Aggregate & backprop
            loss = loss_sent + loss_ner
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            train_loss += loss.item()

            # 5) Training Metrics
            # Intent
            sent_preds = sent_logits.argmax(dim=1)
            sent_correct += (sent_preds == intent_labels).sum().item()
            sent_total   += intent_labels.size(0)

            # Slots
            token_preds = token_logits.argmax(dim=2)  # [B, T]
            mask = ner_labels != -100
            ner_labels = ner_labels.squeeze(1) 
            #import pdb;pdb.set_trace()
            # ——— Slots accuracy ———
            token_preds = token_logits.argmax(dim=2)  # [B, T]
            mask        = ner_labels != -100         # [B, T]
            token_correct += (token_preds[mask] == ner_labels[mask]).sum().item()
            token_total   += mask.sum().item()

        avg_train_loss = train_loss / len(train_loader)
        train_sent_acc = sent_correct / sent_total
        train_token_acc= token_correct / token_total
        print(f"\nEpoch {epoch}/{n_epoch}")
        print(f" Train → Loss: {avg_train_loss:.4f} | "
            f"Intent Acc: {train_sent_acc:.4f} | Slot Acc: {train_token_acc:.4f}")
        
        # ——— Evaluation ———
        model.eval()
        test_loss = 0.0
        sent_correct = sent_total = 0
        token_correct = token_total = 0

        with torch.no_grad():
            for batch in test_loader:
                #import pdb;pdb.set_trace()
                # 1) Move data to device
                input_ids      = batch["input_ids"].squeeze(1).to(device)
                attention_mask = batch["attention_mask"].squeeze(1).to(device)
                # 2) Labels
                intent_labels = batch["labels_intent"].to(device)
                if intent_labels.dim()==2 and intent_labels.size(1)==1:
                    intent_labels = intent_labels.squeeze(1)
                elif intent_labels.dim()==2 and intent_labels.size(1)>1:
                    intent_labels = intent_labels.argmax(dim=1)
                intent_labels = intent_labels.long()

                ner_labels = batch["labels_ner"].to(device)
                if ner_labels.dim()==3 and ner_labels.size(2)==1:
                    ner_labels = ner_labels.squeeze(2)
                ner_labels = ner_labels.long()

                #import pdb;pdb.set_trace()
                # 3) Forward pass
                outputs       = model(input_ids, attention_mask)
                sent_logits   = outputs["sent_logits"]   # (B, num_intent_labels)
                token_logits  = outputs["token_logits"]  # (B, T, num_slot_labels)

                # 4) Compute losses
                loss_sent = loss_fn_sent(sent_logits, intent_labels)

                B, T, C = token_logits.size()
                loss_ner  = loss_fn_ner(
                    token_logits.view(-1, C),             # (B*T, C)
                    ner_labels.view(-1)                   # (B*T,)
                )

                # 5) Aggregate & backprop
                loss = loss_sent + loss_ner
                test_loss += loss.item()

                # 5) Training Metrics
                # Intent
                sent_preds = sent_logits.argmax(dim=1)
                sent_correct += (sent_preds == intent_labels).sum().item()
                sent_total   += intent_labels.size(0)

                # Slots
                token_preds = token_logits.argmax(dim=2)  # [B, T]
                mask = ner_labels != -100
                ner_labels = ner_labels.squeeze(1) 
                #import pdb;pdb.set_trace()
                # ——— Slots accuracy ———
                token_preds = token_logits.argmax(dim=2)  # [B, T]
                mask        = ner_labels != -100         # [B, T]
                token_correct += (token_preds[mask] == ner_labels[mask]).sum().item()
                token_total   += mask.sum().item()

            avg_test_loss = test_loss / len(train_loader)
            test_sent_acc = sent_correct / sent_total
            test_token_acc= token_correct / token_total
            print(f"\nEpoch {epoch}/{n_epoch}")
            print(f" Test → Loss: {avg_test_loss:.4f} | "
                f"Intent Acc: {test_sent_acc:.4f} | Slot Acc: {test_token_acc:.4f}")

In [19]:
device       = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device: ', device)
model        = MultiTaskBert(
                   model_name="bert-base-uncased",
                   num_sent_labels=len(intent2id),
                   num_token_labels=len(slot2id),
                   pooling="cls"
               ).to(device)

# Optimizer & losses
optimizer    = AdamW(model.parameters(), lr=2e-5)
loss_fn_sent = nn.CrossEntropyLoss()                   # sentence‐level (intent)
loss_fn_ner  = nn.CrossEntropyLoss(ignore_index=-100)  # token‐level (slots)

n_epoch = 3

Using device:  cuda


In [20]:
train(model, optimizer, loss_fn_ner, loss_fn_sent, n_epoch, device)


Epoch 1/3
 Train → Loss: 1.6269 | Intent Acc: 0.8297 | Slot Acc: 0.8558

Epoch 1/3
 Test → Loss: 0.1701 | Intent Acc: 0.8701 | Slot Acc: 0.9223

Epoch 2/3
 Train → Loss: 0.4674 | Intent Acc: 0.9538 | Slot Acc: 0.9564

Epoch 2/3
 Test → Loss: 0.1006 | Intent Acc: 0.8981 | Slot Acc: 0.9523

Epoch 3/3
 Train → Loss: 0.2509 | Intent Acc: 0.9777 | Slot Acc: 0.9761

Epoch 3/3
 Test → Loss: 0.0702 | Intent Acc: 0.9709 | Slot Acc: 0.9679
