# GUS-Net with GPT-2: Social Bias Token Classification

This notebook trains a **GPT-2 based** multi-label token classifier for social bias detection,
replicating the GUS-Net methodology on the same dataset but swapping the BERT backbone for GPT-2.

**Key differences from the BERT version:**
- GPT-2 uses **BPE** tokenization (vs BERT's WordPiece)
- GPT-2 is **autoregressive** (left-to-right causal attention) — each token only sees left context
- No `[CLS]`/`[SEP]` tokens — GPT-2 uses `<|endoftext|>`
- LLRD targets `transformer.h.{i}` layers (vs `bert.encoder.layer.{i}`)

**Training pipeline** (identical to BERT version):
- Multi-label sigmoid with 6 channels (B/I × GEN/UNFAIR/STEREO)
- Focal Loss with per-channel α weighting
- Layer-wise Learning Rate Decay (LLRD)
- Cosine annealing + warmup
- SWA (Stochastic Weight Averaging)
- Two-pass threshold optimization

---

## 1. Setup

In [66]:
import os
import gc
import glob
import json
import ast

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    GPT2TokenizerFast,
    GPT2ForTokenClassification,
    AutoConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
    get_cosine_schedule_with_warmup,
)
from sklearn.metrics import f1_score, precision_score, recall_score
from scipy.optimize import minimize_scalar

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available:  {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.8.0+cu126
CUDA available:  True
GPU: NVIDIA GeForce RTX 2060


## 2. Dataset

In [67]:
print("Loading dataset...")
dataset = load_dataset("ethical-spectacle/gus-dataset-v1", split="train")

print(f"Loaded {len(dataset)} examples")
print(f"Columns: {dataset.column_names}")
print("\nSample entry:")
print(json.dumps(dataset[0], indent=2, ensure_ascii=False)[:400] + "...")

Loading dataset...
Loaded 3739 examples
Columns: ['text_str', 'ner_tags', 'rationale', 'individual_ner_annotations_json', 'id']

Sample entry:
{
  "text_str": "Should we assume that every retired individual spends hours playing bingo each week?",
  "ner_tags": "[['O'], ['O'], ['O'], ['O'], ['B-GEN'], ['I-GEN', 'B-STEREO'], ['I-GEN', 'I-STEREO'], ['I-STEREO'], ['I-STEREO'], ['I-STEREO'], ['I-STEREO'], ['I-STEREO'], ['I-STEREO']]",
  "rationale": "['Reasoning: Let\\'s think step by step in order to produce the annotations. We will analyze ...


## 3. Tokenization with GPT-2 BPE

In [68]:
# ── GPT-2 tokenizer setup ──
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)

# GPT-2 has no pad token by default — use eos_token
tokenizer.pad_token = tokenizer.eos_token
# Pad on the RIGHT (like BERT) so label indices align correctly
tokenizer.padding_side = "right"

print(f"Vocab size:    {tokenizer.vocab_size}")
print(f"Pad token:     {tokenizer.pad_token!r} (id={tokenizer.pad_token_id})")
print(f"Padding side:  {tokenizer.padding_side}")

Vocab size:    50257
Pad token:     '<|endoftext|>' (id=50256)
Padding side:  right


In [69]:
# ── Multi-label channel definition ──
# 6 binary channels (no explicit O — O = all zeros)
channels = ["B-GEN", "I-GEN", "B-UNFAIR", "I-UNFAIR", "B-STEREO", "I-STEREO"]
channel2idx = {c: i for i, c in enumerate(channels)}
idx2channel = {i: c for i, c in enumerate(channels)}
num_channels = len(channels)

print(f"Channels ({num_channels}): {channels}")

Channels (6): ['B-GEN', 'I-GEN', 'B-UNFAIR', 'I-UNFAIR', 'B-STEREO', 'I-STEREO']


In [70]:
def parse_annotations(example):
    """Extract per-word multi-hot tags from the ner_tags string."""
    # The HF dataset provides 'ner_tags' as a string representation of a list of lists.
    # e.g. "[['O'], ['B-GEN'], ...]"
    return ast.literal_eval(example["ner_tags"])

def prepare_example(example):
    """Tokenize with GPT-2 BPE and build multi-hot label matrix.

    Steps:
      1. Split text into words
      2. Tokenize with is_split_into_words=True (preserves word→subword map)
      3. Use word_ids() to align BPE subwords to original words
      4. Build [seq_len, num_channels] multi-hot labels
      5. Mask special/padding positions with -100
    """
    text = example["text_str"]
    word_tags = parse_annotations(example)
    words = text.split()

    tokenized = tokenizer(
        words,
        is_split_into_words=True,
        truncation=True,
        max_length=128,
        padding="max_length",
    )

    word_ids = tokenized.word_ids()
    seq_len = len(word_ids)
    labels_multi = np.zeros((seq_len, num_channels), dtype=np.float32)

    prev_word_id = None
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            prev_word_id = None
            continue
        if word_id >= len(word_tags):
            prev_word_id = word_id
            continue
        tags = word_tags[word_id]
        for tag in tags:
            if tag == "O":
                continue
            # Continuation subword: B- → I-
            if word_id == prev_word_id:
                if tag.startswith("B-"):
                    i_tag = "I-" + tag[2:]
                    if i_tag in channel2idx:
                        labels_multi[idx, channel2idx[i_tag]] = 1.0
                elif tag in channel2idx:
                    labels_multi[idx, channel2idx[tag]] = 1.0
            else:
                if tag in channel2idx:
                    labels_multi[idx, channel2idx[tag]] = 1.0
        prev_word_id = word_id

    # Mask padding / special positions with -100
    final_labels = []
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            final_labels.append([-100.0] * num_channels)
        else:
            final_labels.append(labels_multi[idx].tolist())

    tokenized["labels"] = final_labels
    return tokenized

In [71]:
print("Tokenizing dataset with GPT-2 BPE...")
tokenized_dataset = DatasetDict({
    "train": dataset.map(
        prepare_example,
        batched=False,
        remove_columns=dataset.column_names,
    )
})
print("Tokenization complete!")

# ── Sanity check ──
total_positive = 0
total_valid = 0
for ex in tokenized_dataset["train"]:
    labels = np.array(ex["labels"])
    valid = labels[labels[:, 0] != -100.0]
    total_valid += len(valid)
    total_positive += (valid > 0).any(axis=1).sum()

print(f"\nSanity check:")
print(f"  Total valid tokens:  {total_valid}")
print(f"  Tokens with bias:    {total_positive}")
print(f"  Positive rate:       {total_positive / max(total_valid, 1):.2%}")
assert total_positive > 0, "FATAL: No positive labels found."
print("  OK")

Tokenizing dataset with GPT-2 BPE...
Tokenization complete!

Sanity check:
  Total valid tokens:  67503
  Tokens with bias:    21450
  Positive rate:       31.78%
  OK


## 4. Train / Dev / Test Split

In [72]:
# 70% train, 15% dev, 15% test
train_devtest = tokenized_dataset["train"].train_test_split(test_size=0.30, seed=42)
train_split = train_devtest["train"]
dev_test = train_devtest["test"].train_test_split(test_size=0.5, seed=42)
dev_split = dev_test["train"]
test_split = dev_test["test"]

print(f"Train: {len(train_split)}")
print(f"Dev:   {len(dev_split)}")
print(f"Test:  {len(test_split)}")

Train: 2617
Dev:   561
Test:  561


## 5. GPT-2 Model Setup

In [73]:
config = AutoConfig.from_pretrained("gpt2")
config.num_labels = num_channels
config.problem_type = "multi_label_classification"
config.pad_token_id = tokenizer.pad_token_id
config.classifier_dropout = 0.3
config.resid_pdrop = 0.15
config.embd_pdrop = 0.15
config.attn_pdrop = 0.15

model = GPT2ForTokenClassification.from_pretrained("gpt2", config=config)
# Resize embeddings to account for pad token
model.resize_token_embeddings(len(tokenizer))

print(f"Model:       {config.model_type}")
print(f"Parameters:  {model.num_parameters():,}")
print(f"Layers:      {config.n_layer}")
print(f"Heads:       {config.n_head}")
print(f"Hidden dim:  {config.n_embd}")
print(f"Classifier dropout: {config.classifier_dropout}")
print(f"Residual dropout:   {config.resid_pdrop}")

Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model:       gpt2
Parameters:  124,444,422
Layers:      12
Heads:       12
Hidden dim:  768
Classifier dropout: 0.3
Residual dropout:   0.15


### Optional: Disable Causal Mask (Bidirectional Mode)

GPT-2 uses a **causal (triangular) attention mask** by default — each token only sees
tokens to its left.  For NER this is a disadvantage because right context helps
disambiguate tokens.

Uncommenting the cell below replaces the causal mask with a full mask, making GPT-2
effectively bidirectional.  This breaks the pre-training assumption but can improve NER
performance.  Keep it **commented** for a fair causal-vs-bidirectional comparison.

In [74]:
# ── Uncomment to enable bidirectional attention ──
# for block in model.transformer.h:
#     block.attn.bias.fill_(True)
# print("Causal mask DISABLED — GPT-2 is now bidirectional.")

## 6. Focal Loss & Channel Statistics

In [75]:
def estimate_channel_frequencies(dataset_split):
    """Estimate positive-label frequency per channel."""
    positives = np.zeros(num_channels, dtype=np.int64)
    total = 0
    for ex in dataset_split:
        labels = np.array(ex["labels"])
        valid = labels[labels[:, 0] != -100.0]
        if valid.size == 0:
            continue
        positives += valid.sum(axis=0).astype(np.int64)
        total += valid.shape[0]
    return positives, total


channel_pos, total_tokens = estimate_channel_frequencies(train_split)
channel_pos = np.maximum(channel_pos, 1)
freq = channel_pos / float(total_tokens)
inv_freq = 1.0 / freq
alpha_channel = inv_freq / inv_freq.sum()
alpha_channel = torch.tensor(alpha_channel, dtype=torch.float32)

print("Channel statistics:")
for i, ch in enumerate(channels):
    print(f"  {ch:10s}: {channel_pos[i]:>6} positives, α={alpha_channel[i]:.4f}")
print(f"\nTotal valid tokens: {total_tokens}")

Channel statistics:
  B-GEN     :   3532 positives, α=0.0830
  I-GEN     :   3332 positives, α=0.0880
  B-UNFAIR  :    753 positives, α=0.3893
  I-UNFAIR  :   1859 positives, α=0.1577
  B-STEREO  :   1177 positives, α=0.2491
  I-STEREO  :   8908 positives, α=0.0329

Total valid tokens: 47182


In [76]:
class FocalLossMultiLabel(nn.Module):
    """Channel-wise focal loss for multi-label token classification.

    L = α_c · (1 − p_t)^γ · BCE(logit, target)
    """

    def __init__(self, alpha, gamma=2.0, reduction="mean", label_smoothing=0.0):
        super().__init__()
        self.register_buffer("alpha", alpha)
        self.gamma = gamma
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        if self.label_smoothing > 0:
            targets = targets * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing
        bce = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction="none")
        pt = torch.exp(-bce)
        focal = self.alpha.to(inputs.device) * (1 - pt) ** self.gamma * bce
        return focal.mean() if self.reduction == "mean" else focal.sum()

## 7. Trainer with GPT-2 LLRD

In [77]:
class FocalLossTrainerGPT2(Trainer):
    """Trainer with Focal Loss and Layer-wise LR Decay for GPT-2.

    GPT-2 parameter groups:
      - classifier.*                         → classifier_lr
      - transformer.h.{11..0}.*              → base_lr × decay^(11−i)
      - transformer.wte, wpe, ln_f           → base_lr × decay^12  (embeddings)
    """

    def __init__(self, *args, alpha_channel, gamma=2.0, label_smoothing=0.0,
                 llrd_decay_factor=0.85, classifier_lr=2e-4, **kwargs):
        super().__init__(*args, **kwargs)
        self.focal_loss = FocalLossMultiLabel(
            alpha=alpha_channel, gamma=gamma, label_smoothing=label_smoothing,
        )
        self.llrd_decay_factor = llrd_decay_factor
        self.classifier_lr = classifier_lr

    # ── LLRD optimizer ──────────────────────────────────────────────
    def create_optimizer(self):
        base_lr = self.args.learning_rate
        decay = self.llrd_decay_factor
        # GPT-2 layer norms use ln_1 / ln_2 / ln_f (not LayerNorm)
        no_decay_keys = ["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]
        n_layers = self.model.config.n_layer  # 12 for gpt2-base

        opt_params = []

        # 1) Classifier head — highest LR
        opt_params.append({
            "params": [p for n, p in self.model.named_parameters()
                        if "classifier" in n or "score" in n],
            "lr": self.classifier_lr,
            "weight_decay": 0.0,
        })

        # 2) Transformer blocks: h.{n_layers-1} → h.0
        for layer_idx in range(n_layers - 1, -1, -1):
            layer_lr = base_lr * (decay ** (n_layers - 1 - layer_idx))
            layer_prefix = f"transformer.h.{layer_idx}."
            d, nd = [], []
            for n, p in self.model.named_parameters():
                if layer_prefix in n:
                    (nd if any(k in n for k in no_decay_keys) else d).append(p)
            if d:
                opt_params.append({"params": d, "lr": layer_lr,
                                   "weight_decay": self.args.weight_decay})
            if nd:
                opt_params.append({"params": nd, "lr": layer_lr,
                                   "weight_decay": 0.0})

        # 3) Embeddings + final LN — lowest LR
        emb_lr = base_lr * (decay ** n_layers)
        emb_names = ["transformer.wte", "transformer.wpe", "transformer.ln_f"]
        d, nd = [], []
        for n, p in self.model.named_parameters():
            if any(n.startswith(prefix) for prefix in emb_names):
                (nd if any(k in n for k in no_decay_keys) else d).append(p)
        if d:
            opt_params.append({"params": d, "lr": emb_lr,
                               "weight_decay": self.args.weight_decay})
        if nd:
            opt_params.append({"params": nd, "lr": emb_lr, "weight_decay": 0.0})

        self.optimizer = torch.optim.AdamW(opt_params, lr=base_lr, eps=1e-8)
        print(f"LLRD optimizer (GPT-2):")
        print(f"  Classifier LR:     {self.classifier_lr}")
        print(f"  Top layer LR:      {base_lr}")
        print(f"  Bottom layer LR:   {base_lr * decay**(n_layers-1):.2e}")
        print(f"  Embeddings LR:     {emb_lr:.2e}")
        return self.optimizer

    # ── Cosine scheduler ────────────────────────────────────────────
    def create_scheduler(self, num_training_steps, optimizer=None):
        if optimizer is None:
            optimizer = self.optimizer
        warmup = int(num_training_steps * self.args.warmup_ratio)
        self.lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup,
            num_training_steps=num_training_steps,
        )
        print(f"Cosine scheduler: {warmup} warmup / {num_training_steps} total")
        return self.lr_scheduler

    # ── Loss ────────────────────────────────────────────────────────
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        logits_flat = logits.view(-1, num_channels)
        labels_flat = labels.view(-1, num_channels)
        valid = labels_flat[:, 0] != -100.0
        loss = self.focal_loss(logits_flat[valid], labels_flat[valid])
        return (loss, outputs) if return_outputs else loss

In [78]:
# ── Metrics ──
thresholds = np.array([0.5] * num_channels, dtype=np.float32)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    probs = 1 / (1 + np.exp(-predictions))
    valid = labels[:, :, 0] != -100.0
    probs_flat = probs[valid]
    labels_flat = labels[valid]
    thr = thresholds.reshape(1, num_channels)
    preds_bin = (probs_flat >= thr).astype(int)
    labels_bin = labels_flat.astype(int)

    if labels_bin.sum() == 0:
        return {"f1_macro": 0.0, "precision_macro": 0.0,
                "recall_macro": 0.0, "hamming_loss": 0.0}

    channel_f1s = [f1_score(labels_bin[:, c], preds_bin[:, c],
                            average="binary", zero_division=0)
                   for c in range(num_channels)]
    return {
        "f1_macro": np.mean(channel_f1s),
        "precision_macro": precision_score(labels_bin, preds_bin,
                                           average="macro", zero_division=0),
        "recall_macro": recall_score(labels_bin, preds_bin,
                                      average="macro", zero_division=0),
        "hamming_loss": np.mean(preds_bin != labels_bin),
    }

## 8. Training

In [79]:
training_args = TrainingArguments(
    output_dir="./gus-net-gpt2",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,    # effective batch = 32
    num_train_epochs=20,
    weight_decay=0.01,
    warmup_ratio=0.1,
    logging_steps=50,
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    report_to="none",
)

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

trainer = FocalLossTrainerGPT2(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=dev_split,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    alpha_channel=alpha_channel,
    gamma=2.0,
    label_smoothing=0.05,
    llrd_decay_factor=0.85,
    classifier_lr=2e-4,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

print("Trainer ready:")
print(f"  Focal Loss (γ=2.0, smoothing=0.05)")
print(f"  LLRD (decay=0.85, classifier_lr=2e-4)")
print(f"  Cosine scheduler + early stopping (patience=5)")

Trainer ready:
  Focal Loss (γ=2.0, smoothing=0.05)
  LLRD (decay=0.85, classifier_lr=2e-4)
  Cosine scheduler + early stopping (patience=5)


  self.scaler = torch.cuda.amp.GradScaler()


In [None]:
print("Starting training...")
train_result = trainer.train()

print(f"\nTraining loss: {train_result.training_loss:.4f}")
print(f"Runtime:       {train_result.metrics['train_runtime']:.1f}s")

Starting training...
LLRD optimizer (GPT-2):
  Classifier LR:     0.0002
  Top layer LR:      5e-05
  Bottom layer LR:   8.37e-06
  Embeddings LR:     7.11e-06
Cosine scheduler: 164 warmup / 1640 total


  0%|          | 0/1640 [00:00<?, ?it/s]

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


## 9. Stochastic Weight Averaging

In [None]:
def apply_swa(trainer, checkpoint_dir, last_n=5):
    """Average weights from the last N checkpoints (memory-safe)."""
    checkpoints = sorted(
        glob.glob(f"{checkpoint_dir}/checkpoint-*"),
        key=lambda x: int(x.split("-")[-1]),
    )
    if len(checkpoints) < 2:
        print(f"Only {len(checkpoints)} checkpoint(s), skipping SWA.")
        return

    last = checkpoints[-last_n:]
    print(f"SWA: averaging {len(last)} checkpoints")

    # Free memory before loading checkpoints
    trainer.model.cpu()
    torch.cuda.empty_cache()
    gc.collect()

    # Running average to avoid holding all states at once
    avg = None
    n = 0
    for cp in last:
        sf = os.path.join(cp, "model.safetensors")
        bf = os.path.join(cp, "pytorch_model.bin")
        if os.path.exists(sf):
            from safetensors.torch import load_file
            state = load_file(sf, device="cpu")
        elif os.path.exists(bf):
            state = torch.load(bf, map_location="cpu", weights_only=True)
        else:
            continue
        n += 1
        if avg is None:
            avg = {k: v.float() for k, v in state.items()}
        else:
            # Running average: avg = avg + (new - avg) / n
            for k in avg:
                avg[k] += (state[k].float() - avg[k]) / n
        del state
        gc.collect()

    if avg and n > 0:
        trainer.model.load_state_dict(avg)
        del avg
        gc.collect()
        trainer.model.to(trainer.args.device)
        print(f"SWA applied ({n} checkpoints).")
    else:
        trainer.model.to(trainer.args.device)
        print("SWA: no valid checkpoints found.")


apply_swa(trainer, "./gus-net-gpt2", last_n=5)

## 10. Threshold Optimization

In [None]:
def optimize_thresholds(trainer, dev_dataset):
    """Two-pass threshold optimization: coarse grid + scipy refinement."""
    model = trainer.model
    model.eval()
    grid = np.arange(0.05, 0.96, 0.025).tolist()

    all_probs, all_labels = [], []
    for batch in trainer.get_eval_dataloader(dev_dataset):
        with torch.no_grad():
            labels = batch["labels"].cpu().numpy()
            inputs = {k: v.to(model.device) for k, v in batch.items() if k != "labels"}
            logits = model(**inputs).logits.cpu().numpy()
        all_probs.append(1 / (1 + np.exp(-logits)))
        all_labels.append(labels)

    probs = np.concatenate(all_probs)
    labels = np.concatenate(all_labels)
    valid = labels[:, :, 0] != -100.0
    pf = probs[valid]
    lf = labels[valid].astype(int)

    best_thr = np.zeros(num_channels, dtype=np.float32)

    # Pass 1: coarse
    print("Pass 1 — grid search:")
    for c in range(num_channels):
        best_f1, best_t = 0, 0.5
        for t in grid:
            f1 = f1_score(lf[:, c], (pf[:, c] >= t).astype(int),
                          average="binary", zero_division=0)
            if f1 > best_f1:
                best_f1, best_t = f1, t
        best_thr[c] = best_t
        print(f"  {channels[c]:10s}: thr={best_t:.3f}, F1={best_f1:.4f}")

    # Pass 2: scipy
    print("\nPass 2 — refinement:")
    for c in range(num_channels):
        lo = max(0.01, best_thr[c] - 0.05)
        hi = min(0.99, best_thr[c] + 0.05)
        res = minimize_scalar(
            lambda t: -f1_score(lf[:, c], (pf[:, c] >= t).astype(int),
                                average="binary", zero_division=0),
            bounds=(lo, hi), method="bounded",
        )
        if -res.fun >= f1_score(lf[:, c], (pf[:, c] >= best_thr[c]).astype(int),
                                 average="binary", zero_division=0):
            best_thr[c] = res.x
        print(f"  {channels[c]:10s}: thr={best_thr[c]:.4f}, F1={-res.fun:.4f}")

    # Global macro F1
    preds = (pf >= best_thr.reshape(1, -1)).astype(int)
    macro = np.mean([f1_score(lf[:, c], preds[:, c], average="binary", zero_division=0)
                     for c in range(num_channels)])
    return best_thr, macro


best_thr, best_f1_dev = optimize_thresholds(trainer, dev_split)
print(f"\nOptimized thresholds: {best_thr}")
print(f"Dev macro-F1: {best_f1_dev:.4f}")
thresholds = best_thr

## 11. Evaluation

In [None]:
# ── Token-level ──
print("Evaluating on test set (token-level)...")
test_metrics = trainer.evaluate(test_split)

print(f"\nTest results (optimized thresholds):")
print(f"  Macro F1:      {test_metrics['eval_f1_macro']:.4f}")
print(f"  Precision:     {test_metrics['eval_precision_macro']:.4f}")
print(f"  Recall:        {test_metrics['eval_recall_macro']:.4f}")
print(f"  Hamming Loss:  {test_metrics['eval_hamming_loss']:.4f}")

In [None]:
# ── Entity-level (span) evaluation ──

def extract_spans(bio_preds):
    """Extract (entity_type, start, end) spans from BIO predictions."""
    pairs = {0: 1, 2: 3, 4: 5}  # B→I channel pairs
    spans = []
    for b_idx, i_idx in pairs.items():
        etype = channels[b_idx].replace("B-", "")
        in_span, start = False, -1
        for t in range(len(bio_preds)):
            if bio_preds[t, b_idx] == 1:
                if in_span:
                    spans.append((etype, start, t))
                start, in_span = t, True
            elif bio_preds[t, i_idx] == 1 and in_span:
                continue
            else:
                if in_span:
                    spans.append((etype, start, t))
                    in_span = False
        if in_span:
            spans.append((etype, start, len(bio_preds)))
    return spans


def compute_entity_metrics(trainer, dataset, thresholds):
    model = trainer.model
    model.eval()
    all_pred, all_gold = [], []
    ex_idx = 0
    for batch in trainer.get_eval_dataloader(dataset):
        with torch.no_grad():
            labels = batch["labels"].cpu().numpy()
            inputs = {k: v.to(model.device) for k, v in batch.items() if k != "labels"}
            probs = torch.sigmoid(model(**inputs).logits).cpu().numpy()
        for i in range(labels.shape[0]):
            valid = labels[i, :, 0] != -100.0
            p = (probs[i][valid] >= thresholds).astype(int)
            g = labels[i][valid].astype(int)
            all_pred.extend([(ex_idx, *s) for s in extract_spans(p)])
            all_gold.extend([(ex_idx, *s) for s in extract_spans(g)])
            ex_idx += 1

    print("\nEntity-level evaluation:")
    print("-" * 60)
    for etype in ["GEN", "UNFAIR", "STEREO"]:
        ps = set(s for s in all_pred if s[1] == etype)
        gs = set(s for s in all_gold if s[1] == etype)
        tp = len(ps & gs)
        p = tp / max(len(ps), 1)
        r = tp / max(len(gs), 1)
        f1 = 2 * p * r / max(p + r, 1e-8)
        print(f"  {etype:8s}: F1={f1:.4f}  P={p:.4f}  R={r:.4f}  (support={len(gs)})")

    ps, gs = set(all_pred), set(all_gold)
    tp = len(ps & gs)
    micro_p = tp / max(len(ps), 1)
    micro_r = tp / max(len(gs), 1)
    micro_f1 = 2 * micro_p * micro_r / max(micro_p + micro_r, 1e-8)
    print("-" * 60)
    print(f"  {'MICRO':8s}: F1={micro_f1:.4f}  P={micro_p:.4f}  R={micro_r:.4f}")
    return {"micro_f1": micro_f1, "micro_p": micro_p, "micro_r": micro_r}


print("=" * 60)
print("ENTITY-LEVEL EVALUATION (Test Set)")
print("=" * 60)
entity_metrics = compute_entity_metrics(trainer, test_split, thresholds)

## 12. Inference Demo

In [None]:
def predict_bias(text, model, tokenizer, thresholds, device="cuda"):
    model.eval()
    model.to(device)
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True,
        max_length=128, padding=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        probs = torch.sigmoid(model(**inputs).logits[0]).cpu().numpy()
    preds = (probs >= thresholds).astype(int)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    results = []
    for i, (tok, pred) in enumerate(zip(tokens, preds)):
        active = [channels[j] for j in range(num_channels) if pred[j] == 1]
        if active:
            clean = tok.replace("Ġ", "").replace("Ċ", "")
            results.append({"token": clean, "labels": active})
    return results


examples = [
    "Women are naturally better at nursing.",
    "All politicians are corrupt liars.",
    "Young people these days are so lazy and entitled.",
    "The engineer fixed the problem quickly.",
]

device = "cuda" if torch.cuda.is_available() else "cpu"
print("=" * 60)
print("INFERENCE DEMO (GPT-2)")
print("=" * 60)
for text in examples:
    print(f"\nInput: '{text}'")
    biases = predict_bias(text, model, tokenizer, thresholds, device)
    if biases:
        for b in biases:
            print(f"  → '{b['token']}': {', '.join(b['labels'])}")
    else:
        print("  → No bias detected (Neutral)")

## 13. Save Model

In [None]:
output_dir = "./gus-net-gpt2-final"
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
np.save(f"{output_dir}/optimized_thresholds.npy", thresholds)

print(f"Model, tokenizer and thresholds saved to {output_dir}")

## 14. Summary

In [None]:
# Store causal results for final comparison
causal_results = {
    "token_f1": test_metrics["eval_f1_macro"],
    "token_precision": test_metrics["eval_precision_macro"],
    "token_recall": test_metrics["eval_recall_macro"],
    "token_hamming": test_metrics["eval_hamming_loss"],
    "entity_micro_f1": entity_metrics["micro_f1"],
    "entity_micro_p": entity_metrics["micro_p"],
    "entity_micro_r": entity_metrics["micro_r"],
    "thresholds": thresholds.copy(),
}

print("Causal GPT-2 results saved for comparison.")
print(f"  Token Macro F1:   {causal_results['token_f1']:.4f}")
print(f"  Entity Micro F1:  {causal_results['entity_micro_f1']:.4f}")

---

# Part B — GPT-2 Bidirectional Mode

GPT-2 uses a **causal (triangular) attention mask** — each token only attends to
tokens to its **left**. This is ideal for language generation but a disadvantage for
NER, where right context helps disambiguate tokens.

By filling the causal mask bias with `True`, we remove the triangular constraint and
let every token attend to every other token — effectively making GPT-2 **bidirectional**
(similar to BERT).

This **breaks the pre-training assumption** (the model was pre-trained with causal
attention), but fine-tuning can adapt the attention heads to use the full context.

Below we retrain from scratch with the same hyperparameters to measure the impact.

## B.1 Free Causal Model & Create Bidirectional Model

In [41]:
# ── Free causal model from GPU ──
del model, trainer
gc.collect()
torch.cuda.empty_cache()
print("Causal model freed from memory.")

# ── Fresh GPT-2 with bidirectional attention ──
config_bi = AutoConfig.from_pretrained("gpt2")
config_bi.num_labels = num_channels
config_bi.problem_type = "multi_label_classification"
config_bi.pad_token_id = tokenizer.pad_token_id
config_bi.classifier_dropout = 0.3
config_bi.resid_pdrop = 0.15
config_bi.embd_pdrop = 0.15
config_bi.attn_pdrop = 0.15

model_bi = GPT2ForTokenClassification.from_pretrained("gpt2", config=config_bi)
model_bi.resize_token_embeddings(len(tokenizer))

# ── Disable causal mask → bidirectional ──
for block in model_bi.transformer.h:
    # block.attn.bias is the causal mask (lower-triangular).
    # Filling with True lets every position attend to every other position.
    block.attn.bias.fill_(True)

print("Causal mask DISABLED — GPT-2 is now bidirectional.")
print(f"Parameters: {model_bi.num_parameters():,}")

Causal model freed from memory.


Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Causal mask DISABLED — GPT-2 is now bidirectional.
Parameters: 124,444,422


## B.2 Training (Bidirectional)

In [43]:
training_args_bi = TrainingArguments(
    output_dir="./gus-net-gpt2-bidirectional",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    num_train_epochs=20,
    weight_decay=0.01,
    warmup_ratio=0.1,
    logging_steps=50,
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    report_to="none",
)

# Reset thresholds for bidirectional model
thresholds = np.array([0.5] * num_channels, dtype=np.float32)

trainer_bi = FocalLossTrainerGPT2(
    model=model_bi,
    args=training_args_bi,
    train_dataset=train_split,
    eval_dataset=dev_split,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    alpha_channel=alpha_channel,
    gamma=2.0,
    label_smoothing=0.05,
    llrd_decay_factor=0.85,
    classifier_lr=2e-4,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

print("Bidirectional trainer ready (same hyperparameters as causal).")

Bidirectional trainer ready (same hyperparameters as causal).


In [44]:
print("Starting bidirectional training...")
train_result_bi = trainer_bi.train()

print(f"\nTraining loss: {train_result_bi.training_loss:.4f}")
print(f"Runtime:       {train_result_bi.metrics['train_runtime']:.1f}s")

Starting bidirectional training...
LLRD optimizer (GPT-2):
  Classifier LR:     0.0002
  Top layer LR:      5e-05
  Bottom layer LR:   8.37e-06
  Embeddings LR:     7.11e-06
Cosine scheduler: 164 warmup / 1640 total


  0%|          | 0/1640 [00:00<?, ?it/s]

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.1212, 'learning_rate': 5.853658536585366e-05, 'epoch': 0.61}


  0%|          | 0/36 [00:00<?, ?it/s]

  probs = 1 / (1 + np.exp(-predictions))


{'eval_loss': 0.009381568990647793, 'eval_f1_macro': 0.03204712485758974, 'eval_precision_macro': 0.20303992388832323, 'eval_recall_macro': 0.027388654351307756, 'eval_hamming_loss': 0.07959542656112577, 'eval_runtime': 56.4962, 'eval_samples_per_second': 9.93, 'eval_steps_per_second': 0.637, 'epoch': 1.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0106, 'learning_rate': 0.00011951219512195122, 'epoch': 1.22}
{'loss': 0.0085, 'learning_rate': 0.0001804878048780488, 'epoch': 1.83}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.007518548984080553, 'eval_f1_macro': 0.04111549309123349, 'eval_precision_macro': 0.3333709273182957, 'eval_recall_macro': 0.02383874663775426, 'eval_hamming_loss': 0.06868301899084661, 'eval_runtime': 56.5279, 'eval_samples_per_second': 9.924, 'eval_steps_per_second': 0.637, 'epoch': 2.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.008, 'learning_rate': 0.00019973826287588464, 'epoch': 2.44}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.006969396024942398, 'eval_f1_macro': 0.06088337713648617, 'eval_precision_macro': 0.3887381636344292, 'eval_recall_macro': 0.04070803775096044, 'eval_hamming_loss': 0.06643538877487866, 'eval_runtime': 56.0871, 'eval_samples_per_second': 10.002, 'eval_steps_per_second': 0.642, 'epoch': 3.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0075, 'learning_rate': 0.0001984059629273457, 'epoch': 3.05}
{'loss': 0.0071, 'learning_rate': 0.00019596019297180145, 'epoch': 3.66}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.006604738533496857, 'eval_f1_macro': 0.07464292206353153, 'eval_precision_macro': 0.327238944630249, 'eval_recall_macro': 0.05207316759344267, 'eval_hamming_loss': 0.06610964526531808, 'eval_runtime': 55.892, 'eval_samples_per_second': 10.037, 'eval_steps_per_second': 0.644, 'epoch': 4.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.007, 'learning_rate': 0.00019242862705875577, 'epoch': 4.27}
{'loss': 0.0069, 'learning_rate': 0.00018785122509109426, 'epoch': 4.88}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.006280513945966959, 'eval_f1_macro': 0.07486101872242872, 'eval_precision_macro': 0.3498623035007253, 'eval_recall_macro': 0.04969812177446492, 'eval_hamming_loss': 0.06555588129906512, 'eval_runtime': 55.909, 'eval_samples_per_second': 10.034, 'eval_steps_per_second': 0.644, 'epoch': 5.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0062, 'learning_rate': 0.00018227978067612868, 'epoch': 5.49}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.005899177398532629, 'eval_f1_macro': 0.11405044681108815, 'eval_precision_macro': 0.44109679293721227, 'eval_recall_macro': 0.07718448689506617, 'eval_hamming_loss': 0.0648066712270758, 'eval_runtime': 56.0324, 'eval_samples_per_second': 10.012, 'eval_steps_per_second': 0.642, 'epoch': 6.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0065, 'learning_rate': 0.00017577733507749007, 'epoch': 6.1}
{'loss': 0.0059, 'learning_rate': 0.00016841746389904304, 'epoch': 6.71}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.005661883857101202, 'eval_f1_macro': 0.2046087324393708, 'eval_precision_macro': 0.511397742687098, 'eval_recall_macro': 0.1476145484252211, 'eval_hamming_loss': 0.06553959412358709, 'eval_runtime': 56.0103, 'eval_samples_per_second': 10.016, 'eval_steps_per_second': 0.643, 'epoch': 7.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.006, 'learning_rate': 0.0001602834445720413, 'epoch': 7.32}
{'loss': 0.0057, 'learning_rate': 0.0001514673140654609, 'epoch': 7.93}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.005395388696342707, 'eval_f1_macro': 0.23433040953754017, 'eval_precision_macro': 0.5701048411881858, 'eval_recall_macro': 0.1658618966822866, 'eval_hamming_loss': 0.06303136909997069, 'eval_runtime': 58.4746, 'eval_samples_per_second': 9.594, 'eval_steps_per_second': 0.616, 'epoch': 8.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0056, 'learning_rate': 0.0001420688274815834, 'epoch': 8.54}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.005251815542578697, 'eval_f1_macro': 0.24415355806369102, 'eval_precision_macro': 0.5680875859001535, 'eval_recall_macro': 0.17046940581626582, 'eval_hamming_loss': 0.062412456431805596, 'eval_runtime': 55.8304, 'eval_samples_per_second': 10.048, 'eval_steps_per_second': 0.645, 'epoch': 9.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0054, 'learning_rate': 0.00013219432932038712, 'epoch': 9.15}
{'loss': 0.0054, 'learning_rate': 0.00012195555018446599, 'epoch': 9.76}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.005169693846255541, 'eval_f1_macro': 0.2735190447103653, 'eval_precision_macro': 0.5675943884078787, 'eval_recall_macro': 0.20044055577789585, 'eval_hamming_loss': 0.06205413857128897, 'eval_runtime': 55.9794, 'eval_samples_per_second': 10.022, 'eval_steps_per_second': 0.643, 'epoch': 10.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0054, 'learning_rate': 0.00011146834253984006, 'epoch': 10.37}
{'loss': 0.0053, 'learning_rate': 0.00010085136983760677, 'epoch': 10.98}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.005161278415471315, 'eval_f1_macro': 0.31485292028047573, 'eval_precision_macro': 0.5376139170394071, 'eval_recall_macro': 0.23913452103191565, 'eval_hamming_loss': 0.0625264666601518, 'eval_runtime': 56.188, 'eval_samples_per_second': 9.984, 'eval_steps_per_second': 0.641, 'epoch': 11.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0052, 'learning_rate': 9.022476382910982e-05, 'epoch': 11.59}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.005049669183790684, 'eval_f1_macro': 0.315449525693348, 'eval_precision_macro': 0.5471163451847674, 'eval_recall_macro': 0.23936118398999576, 'eval_hamming_loss': 0.06257532818658589, 'eval_runtime': 56.3902, 'eval_samples_per_second': 9.949, 'eval_steps_per_second': 0.638, 'epoch': 12.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0051, 'learning_rate': 7.970876526719333e-05, 'epoch': 12.2}
{'loss': 0.0052, 'learning_rate': 6.942236337409622e-05, 'epoch': 12.8}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.005014353897422552, 'eval_f1_macro': 0.32038838781788087, 'eval_precision_macro': 0.5609624335487331, 'eval_recall_macro': 0.24468338975072393, 'eval_hamming_loss': 0.06189126681650868, 'eval_runtime': 56.0816, 'eval_samples_per_second': 10.003, 'eval_steps_per_second': 0.642, 'epoch': 13.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0051, 'learning_rate': 5.9481949470499255e-05, 'epoch': 13.41}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.0049743917770683765, 'eval_f1_macro': 0.32993247288580646, 'eval_precision_macro': 0.5705957049024466, 'eval_recall_macro': 0.2514932465280891, 'eval_hamming_loss': 0.061646959184338254, 'eval_runtime': 56.7907, 'eval_samples_per_second': 9.878, 'eval_steps_per_second': 0.634, 'epoch': 14.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0051, 'learning_rate': 5.000000000000002e-05, 'epoch': 14.02}
{'loss': 0.005, 'learning_rate': 4.108380385068289e-05, 'epoch': 14.63}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.004941117484122515, 'eval_f1_macro': 0.31355721686043164, 'eval_precision_macro': 0.5959559622760103, 'eval_recall_macro': 0.2321657399926543, 'eval_hamming_loss': 0.06062086712922245, 'eval_runtime': 56.4177, 'eval_samples_per_second': 9.944, 'eval_steps_per_second': 0.638, 'epoch': 15.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0051, 'learning_rate': 3.283424837422355e-05, 'epoch': 15.24}
{'loss': 0.0052, 'learning_rate': 2.5344677838803733e-05, 'epoch': 15.85}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.004926861729472876, 'eval_f1_macro': 0.3280018381527943, 'eval_precision_macro': 0.5824266325365969, 'eval_recall_macro': 0.24773459449121882, 'eval_hamming_loss': 0.06097918498973908, 'eval_runtime': 57.0752, 'eval_samples_per_second': 9.829, 'eval_steps_per_second': 0.631, 'epoch': 16.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0048, 'learning_rate': 1.8699837232516227e-05, 'epoch': 16.46}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.004918945021927357, 'eval_f1_macro': 0.33512640271048916, 'eval_precision_macro': 0.5785517461840494, 'eval_recall_macro': 0.2551600196537773, 'eval_hamming_loss': 0.06122349262190951, 'eval_runtime': 56.0161, 'eval_samples_per_second': 10.015, 'eval_steps_per_second': 0.643, 'epoch': 17.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0051, 'learning_rate': 1.2974913368195695e-05, 'epoch': 17.07}
{'loss': 0.005, 'learning_rate': 8.234684139637205e-06, 'epoch': 17.68}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.004913683515042067, 'eval_f1_macro': 0.3291431309093964, 'eval_precision_macro': 0.5911900390495181, 'eval_recall_macro': 0.24738983419598573, 'eval_hamming_loss': 0.06071859018209062, 'eval_runtime': 55.929, 'eval_samples_per_second': 10.031, 'eval_steps_per_second': 0.644, 'epoch': 18.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.005, 'learning_rate': 4.53278555542519e-06, 'epoch': 18.29}
{'loss': 0.005, 'learning_rate': 1.9111048439335977e-06, 'epoch': 18.9}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.004912855103611946, 'eval_f1_macro': 0.3322603871657897, 'eval_precision_macro': 0.5871754223589785, 'eval_recall_macro': 0.2508063300533998, 'eval_hamming_loss': 0.060832600410436824, 'eval_runtime': 55.8602, 'eval_samples_per_second': 10.043, 'eval_steps_per_second': 0.644, 'epoch': 19.0}


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


{'loss': 0.0051, 'learning_rate': 3.99306496554519e-07, 'epoch': 19.51}


  0%|          | 0/36 [00:00<?, ?it/s]

{'eval_loss': 0.004913104698061943, 'eval_f1_macro': 0.33211686178296024, 'eval_precision_macro': 0.5850317810467742, 'eval_recall_macro': 0.25097895218355687, 'eval_hamming_loss': 0.060946610638783025, 'eval_runtime': 56.4172, 'eval_samples_per_second': 9.944, 'eval_steps_per_second': 0.638, 'epoch': 20.0}
{'train_runtime': 14268.1446, 'train_samples_per_second': 3.668, 'train_steps_per_second': 0.115, 'train_loss': 0.009423801802643916, 'epoch': 20.0}

Training loss: 0.0094
Runtime:       14268.1s


## B.3 SWA (Bidirectional)

In [45]:
apply_swa(trainer_bi, "./gus-net-gpt2-bidirectional", last_n=5)

SWA: averaging 5 checkpoints
SWA applied (5 checkpoints).


## B.4 Threshold Optimization (Bidirectional)

In [46]:
best_thr_bi, best_f1_dev_bi = optimize_thresholds(trainer_bi, dev_split)
print(f"\nOptimized thresholds (bidirectional): {best_thr_bi}")
print(f"Dev macro-F1: {best_f1_dev_bi:.4f}")
thresholds = best_thr_bi

Pass 1 — grid search:
  B-GEN     : thr=0.375, F1=0.5344
  I-GEN     : thr=0.375, F1=0.4035
  B-UNFAIR  : thr=0.375, F1=0.4438
  I-UNFAIR  : thr=0.350, F1=0.3206
  B-STEREO  : thr=0.450, F1=0.4515
  I-STEREO  : thr=0.425, F1=0.5535

Pass 2 — refinement:
  B-GEN     : thr=0.3866, F1=0.5401
  I-GEN     : thr=0.3750, F1=0.4004
  B-UNFAIR  : thr=0.3873, F1=0.4505
  I-UNFAIR  : thr=0.3466, F1=0.3287
  B-STEREO  : thr=0.4500, F1=0.4497
  I-STEREO  : thr=0.4315, F1=0.5561

Optimized thresholds (bidirectional): [0.3866261  0.375      0.38731122 0.34655127 0.45       0.43150443]
Dev macro-F1: 0.4551


## B.5 Evaluation (Bidirectional)

In [47]:
# ── Token-level ──
print("Evaluating bidirectional model on test set (token-level)...")
test_metrics_bi = trainer_bi.evaluate(test_split)

print(f"\nBidirectional test results (optimized thresholds):")
print(f"  Macro F1:      {test_metrics_bi['eval_f1_macro']:.4f}")
print(f"  Precision:     {test_metrics_bi['eval_precision_macro']:.4f}")
print(f"  Recall:        {test_metrics_bi['eval_recall_macro']:.4f}")
print(f"  Hamming Loss:  {test_metrics_bi['eval_hamming_loss']:.4f}")

Evaluating bidirectional model on test set (token-level)...


  else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


  0%|          | 0/36 [00:00<?, ?it/s]


Bidirectional test results (optimized thresholds):
  Macro F1:      0.4349
  Precision:     0.4062
  Recall:        0.4973
  Hamming Loss:  0.0827


In [48]:
# ── Entity-level ──
print("=" * 60)
print("ENTITY-LEVEL EVALUATION — Bidirectional (Test Set)")
print("=" * 60)
entity_metrics_bi = compute_entity_metrics(trainer_bi, test_split, thresholds)

ENTITY-LEVEL EVALUATION — Bidirectional (Test Set)

Entity-level evaluation:
------------------------------------------------------------
  GEN     : F1=0.3344  P=0.2913  R=0.3927  (support=764)
  UNFAIR  : F1=0.2457  P=0.2535  R=0.2384  (support=151)
  STEREO  : F1=0.1265  P=0.1576  R=0.1057  (support=246)
------------------------------------------------------------
  MICRO   : F1=0.2898  P=0.2708  R=0.3118


## B.6 Inference Demo (Bidirectional)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("=" * 60)
print("INFERENCE DEMO (GPT-2 Bidirectional)")
print("=" * 60)
for text in examples:
    print(f"\nInput: '{text}'")
    biases = predict_bias(text, model_bi, tokenizer, thresholds, device)
    if biases:
        for b in biases:
            print(f"  → '{b['token']}': {', '.join(b['labels'])}")
    else:
        print("  → No bias detected (Neutral)")

## B.7 Save Bidirectional Model

In [49]:
output_dir_bi = "./gus-net-gpt2-bidirectional-final"
trainer_bi.save_model(output_dir_bi)
tokenizer.save_pretrained(output_dir_bi)
np.save(f"{output_dir_bi}/optimized_thresholds.npy", thresholds)

print(f"Bidirectional model saved to {output_dir_bi}")

Bidirectional model saved to ./gus-net-gpt2-bidirectional-final


---

## Comparison: Causal vs Bidirectional GPT-2

In [None]:
bidirectional_results = {
    "token_f1": test_metrics_bi["eval_f1_macro"],
    "token_precision": test_metrics_bi["eval_precision_macro"],
    "token_recall": test_metrics_bi["eval_recall_macro"],
    "token_hamming": test_metrics_bi["eval_hamming_loss"],
    "entity_micro_f1": entity_metrics_bi["micro_f1"],
    "entity_micro_p": entity_metrics_bi["micro_p"],
    "entity_micro_r": entity_metrics_bi["micro_r"],
    "thresholds": thresholds.copy(),
}

# ── Comparison table ──
def delta(bi, ca):
    d = bi - ca
    return f"{d:+.4f}" if d != 0 else "  0.0000"

print("=" * 70)
print("COMPARISON: Causal vs Bidirectional GPT-2")
print("=" * 70)
print(f"{'Metric':<25s} {'Causal':>10s} {'Bidirect.':>10s} {'Δ':>10s}")
print("-" * 70)

rows = [
    ("Token Macro F1",    "token_f1"),
    ("Token Precision",   "token_precision"),
    ("Token Recall",      "token_recall"),
    ("Token Hamming Loss","token_hamming"),
    ("Entity Micro F1",   "entity_micro_f1"),
    ("Entity Precision",  "entity_micro_p"),
    ("Entity Recall",     "entity_micro_r"),
]

for label, key in rows:
    c = causal_results[key]
    b = bidirectional_results[key]
    print(f"  {label:<23s} {c:>10.4f} {b:>10.4f} {delta(b, c):>10s}")

print("-" * 70)

# Winner
c_f1 = causal_results["token_f1"]
b_f1 = bidirectional_results["token_f1"]
winner = "Bidirectional" if b_f1 > c_f1 else "Causal" if c_f1 > b_f1 else "Tie"
margin = abs(b_f1 - c_f1)
print(f"\n  Winner (Token F1): {winner} ({margin:+.4f})")

c_ef1 = causal_results["entity_micro_f1"]
b_ef1 = bidirectional_results["entity_micro_f1"]
winner_e = "Bidirectional" if b_ef1 > c_ef1 else "Causal" if c_ef1 > b_ef1 else "Tie"
margin_e = abs(b_ef1 - c_ef1)
print(f"  Winner (Entity F1): {winner_e} ({margin_e:+.4f})")

print("\nOptimized thresholds:")
print(f"  {'Channel':<12s} {'Causal':>10s} {'Bidirect.':>10s}")
for i, ch in enumerate(channels):
    print(f"  {ch:<12s} {causal_results['thresholds'][i]:>10.4f} {bidirectional_results['thresholds'][i]:>10.4f}")
print("=" * 70)