# GUS-Net: Social Bias Classification with Generalizations, Unfairness, and Stereotypes

## Abstract
This notebook documents the implementation and use of **GUS-Net**, a specialized Named Entity Recognition (NER) model that detects social bias in text at the token level. A single token may belong to more than one category, so the task is treated as multi-label token classification over three categories of bias:
1.  **Generalizations (GEN)**: Categorical statements that attribute properties to a group without exception.
2.  **Unfairness (UNFAIR)**: Language that is inherently unjust, pejorative, or discriminatory.
3.  **Stereotypes (STEREO)**: Attribution of fixed characteristics to particular social groups.

The methodology encompasses three primary approaches:
1.  **Inference with Pre-trained Models**: Utilizing the HuggingFace `ethical-spectacle/social-bias-ner` model.
2.  **Training Methodology**: A step-by-step guide to reproducing the model by fine-tuning `bert-base-uncased`, using a focal loss to address class imbalance.
3.  **Integration Architecture**: A proposed framework for integrating GUS-Net into the **Attention Atlas** visualization tool.

---

## 1. Methodology: Pre-trained Model Inference

This section demonstrates how to use the pre-trained GUS-Net model for immediate inference. This approach suits rapid testing and integration when no custom fine-tuning is required.

In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

# 1. Load Model and Tokenizer
# The model is hosted on HuggingFace Hub under the Ethical Spectacle organization
model_name = "ethical-spectacle/social-bias-ner"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# 2. Define Label Mappings
# The model uses the BIO tagging scheme for 3 bias types + Outside (O)
id2label = {
    0: "O",           # Outside (neutral)
    1: "B-STEREO",    # Begin Stereotype
    2: "I-STEREO",    # Inside Stereotype
    3: "B-GEN",       # Begin Generalization
    4: "I-GEN",       # Inside Generalization
    5: "B-UNFAIR",    # Begin Unfairness
    6: "I-UNFAIR"     # Inside Unfairness
}

def detect_bias(text: str, threshold: float = 0.5):
    """
    Detects social bias spans in a given text using the GUS-Net model.
    
    Args:
        text (str): The input text to analyze.
        threshold (float): Confidence threshold for label acceptance (default: 0.5).
        
    Returns:
        list: A list of dictionaries containing biased tokens and their classification metadata.
    """
    # Tokenization with offset mapping to retrieve character positions
    inputs = tokenizer(
        text, 
        return_tensors="pt", 
        truncation=True, 
        max_length=128,
        return_offsets_mapping=True
    )
    
    offset_mapping = inputs.pop("offset_mapping")[0]
    
    # Inference step
    with torch.no_grad():
        outputs = model(**inputs)
        
    # Apply sigmoid activation for multi-label classification probability
    probs = torch.sigmoid(outputs.logits[0])
    
    # Process results
    results = []
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    
    for token, prob, offset in zip(tokens, probs, offset_mapping):
        # Skip special tokens
        if token in ["[CLS]", "[SEP]", "[PAD]"]:
            continue
            
        # Identify labels exceeding the confidence threshold
        labels = []
        for label_id, p in enumerate(prob):
            if p > threshold and label_id != 0:  # Ignore 'O' label
                labels.append({
                    "label": id2label[label_id],
                    "confidence": p.item()
                })
        
        if labels:
            results.append({
                "token": token,
                "char_start": offset[0].item(),
                "char_end": offset[1].item(),
                "labels": labels
            })
    
    return results

# --- Demonstration ---
examples = [
    "Women are naturally better at nursing.",
    "All politicians are corrupt liars.",
    "Young people these days are so lazy and entitled.",
    "The engineer fixed the problem quickly."  # Neutral control sentence
]

print("=" * 60)
print("INFERENCE RESULTS")
print("=" * 60)

for text in examples:
    print(f"\nInput: '{text}'")
    biases = detect_bias(text)
    if biases:
        for b in biases:
            labels_str = ", ".join([f"{l['label']} ({l['confidence']:.2f})" for l in b['labels']])
            print(f"  → '{b['token']}': {labels_str}")
    else:
        print("  → No bias detected (Neutral)")
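`detect_bias` returns one entry per flagged subword token, so a multi-word stereotype comes back in pieces. The sketch below merges those token-level hits into contiguous spans using the BIO prefixes and the character offsets. `group_spans` is a hypothetical helper and is not part of the GUS-Net release. It assumes the output format of `detect_bias` above, and it treats two hits as contiguous only when nothing but whitespace separates them.

```python
def group_spans(text, token_results):
    """Merge detect_bias()-style token hits into contiguous labeled spans.

    Hypothetical helper. A B- tag always opens a new span; an I- tag extends
    an open span of the same type only when whitespace alone separates them.
    """
    open_spans = {}  # bias type -> [char_start, char_end] of the span in progress
    spans = []

    def close(bias_type):
        start, end = open_spans.pop(bias_type)
        spans.append({"type": bias_type, "start": start, "end": end,
                      "text": text[start:end]})

    for tok in token_results:
        start, end = tok["char_start"], tok["char_end"]
        tags = {entry["label"] for entry in tok["labels"]}
        # Close every open span that this token does not continue.
        for bias_type in list(open_spans):
            gap = text[open_spans[bias_type][1]:start]
            if f"I-{bias_type}" not in tags or gap.strip():
                close(bias_type)
        # Open or extend spans; sorting puts B- before I- for the same type.
        for tag in sorted(tags):
            prefix, bias_type = tag.split("-", 1)
            if prefix == "B" and bias_type in open_spans:
                close(bias_type)
            if bias_type in open_spans:
                open_spans[bias_type][1] = end
            else:
                open_spans[bias_type] = [start, end]

    for bias_type in list(open_spans):
        close(bias_type)
    return sorted(spans, key=lambda s: (s["start"], s["type"]))


# Hand-written hits in detect_bias() format (hypothetical model output).
text = "All politicians are corrupt liars."
hits = [
    {"token": "all", "char_start": 0, "char_end": 3, "labels": [{"label": "B-GEN", "confidence": 0.91}]},
    {"token": "politicians", "char_start": 4, "char_end": 15, "labels": [{"label": "I-GEN", "confidence": 0.88}]},
    {"token": "corrupt", "char_start": 20, "char_end": 27, "labels": [{"label": "B-UNFAIR", "confidence": 0.83}]},
    {"token": "liars", "char_start": 28, "char_end": 33, "labels": [{"label": "I-UNFAIR", "confidence": 0.79}]},
]
for span in group_spans(text, hits):
    print(span["type"], repr(span["text"]))
```

With the real model this would be called as `group_spans(text, detect_bias(text))`. Slicing the original string by character offsets keeps its casing and spacing, which the WordPiece tokens (for example `##` continuation pieces) do not preserve.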

## 2. Experimental Setup: Fine-Tuning

This section outlines how to reproduce GUS-Net by fine-tuning `bert-base-uncased` on the bias-annotated dataset, loaded here from a local, standardized JSON file. The procedure applies to domain adaptation, reproducibility studies, and further research.

In [25]:
# Install required packages (uncomment if needed)
# !pip install datasets transformers torch scikit-learn numpy

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from sklearn.metrics import f1_score, precision_score, recall_score

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


PyTorch version: 2.8.0+cu126
CUDA available: True


In [27]:
# Load local standardized dataset
import json
from datasets import Dataset, DatasetDict
import os

# Use the standardized local dataset
dataset_path = '../../dataset/new_dataset.json'
if not os.path.exists(dataset_path):
    # Fallback for different working directories
    dataset_path = 'dataset/new_dataset.json'

with open(dataset_path, 'r', encoding='utf-8') as f:
    local_data = json.load(f)

dataset = DatasetDict({
    'train': Dataset.from_list(local_data['bias_dataset'])
})

print(f"Loaded {len(dataset['train'])} examples from {dataset_path}")
print(f"Dataset features: {dataset['train'].column_names}")
print('\nSample entry:')
print(json.dumps(dataset['train'][0], indent=2, ensure_ascii=False)[:300] + '...')

Dropped 78 duplicate samples
Dataset structure: DatasetDict({
    train: Dataset({
        features: ['text_str', 'ner_tags', 'rationale', 'individual_ner_annotations_json', 'id', '__index_level_0__'],
        num_rows: 3661
    })
})
Number of examples: 3661

Sample example:
{'text_str': 'Should we assume that every retired individual spends hours playing bingo each week?', 'ner_tags': "[['O'], ['O'], ['O'], ['O'], ['B-GEN'], ['I-GEN', 'B-STEREO'], ['I-GEN', 'I-STEREO'], ['I-STEREO'], ['I-STEREO'], ['I-STEREO'], ['I-STEREO'], ['I-STEREO'], ['I-STEREO']]", 'rationale': '[\'Reasoning: Let\\\'s think step by step in order to produce the annotations. We will analyze each word to determine if it fits the description of a generalization (GEN).\\n\\n1. "Should" - This is an auxiliary verb and does not fit the definition of a generalization. Tag: O\\n2. "we" - This is a pronoun but not used as a generalization here. Tag: O\\n3. "assume" - This is a verb and does not fit the definition of a ge

In [28]:
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Define multi-label channels (no explicit O class; O = all zeros)
channels = ["B-GEN", "I-GEN", "B-UNFAIR", "I-UNFAIR", "B-STEREO", "I-STEREO"]
channel2idx = {c: i for i, c in enumerate(channels)}
idx2channel = {i: c for i, c in enumerate(channels)}
num_channels = len(channels)

print(f"Number of channels: {num_channels}")
print(f"Channel mapping: {channel2idx}")




Number of channels: 6
Channel mapping: {'B-GEN': 0, 'I-GEN': 1, 'B-UNFAIR': 2, 'I-UNFAIR': 3, 'B-STEREO': 4, 'I-STEREO': 5}
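A single word can carry several tags at once. In the sample entry printed above, "retired" is tagged `['I-GEN', 'B-STEREO']`. For this reason the labels are multi-hot vectors over the six channels, and `O` is encoded as all zeros rather than as a channel of its own. Here is a minimal illustration. `encode_tags` is a hypothetical helper, and the channel order is copied from the cell above.

```python
import numpy as np

# Channel order copied from the notebook; "O" has no channel of its own.
channels = ["B-GEN", "I-GEN", "B-UNFAIR", "I-UNFAIR", "B-STEREO", "I-STEREO"]
channel2idx = {c: i for i, c in enumerate(channels)}


def encode_tags(tags):
    """Hypothetical helper: map one word's tag list to a multi-hot vector."""
    vec = np.zeros(len(channels), dtype=np.float32)
    for tag in tags:
        if tag != "O":  # "O" means no bias, i.e. the all-zeros vector
            vec[channel2idx[tag]] = 1.0
    return vec


print(encode_tags(["O"]))                  # "Should" in the sample entry
print(encode_tags(["I-GEN", "B-STEREO"]))  # "retired": two channels active at once
```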


In [29]:
def parse_annotations(example):
    """
    Merges the per-category tag lists in example['annotations']
    ('GEN', 'STEREO', 'UNFAIR', one tag per word) into a single list of
    active tags per word, using ['O'] for words with no bias tag.
    """
    anno_dict = example['annotations']
    # Combine tags for each token index
    combined = []
    keys = ['GEN', 'STEREO', 'UNFAIR']
    
    # Assumes all three lists have the same length (one entry per word)
    length = len(anno_dict['GEN'])
    for i in range(length):
        tags = [anno_dict[k][i] for k in keys if anno_dict[k][i] != 'O']
        if not tags:
            tags = ['O']
        combined.append(tags)
    return combined
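`parse_annotations` expects each record's `annotations` field to hold three per-word tag lists (`GEN`, `STEREO`, `UNFAIR`). The exact layout of `new_dataset.json` is not shown here, so the record below is a hypothetical example in that assumed shape. The sketch repeats the merging logic under a hypothetical name, `merge_word_tags`, so that it does not overwrite the notebook function. It also adds a `check_alignment` helper, which is hypothetical as well. The helper matters because `prepare_example` splits `text_str` on whitespace and leaves labels at zero for any word beyond the annotation lists. A length mismatch would therefore misalign labels without raising an error.

```python
KEYS = ("GEN", "STEREO", "UNFAIR")  # same order as parse_annotations


def merge_word_tags(example, keys=KEYS):
    """Same merging logic as parse_annotations: one tag list per word, ['O'] if empty."""
    anno = example["annotations"]
    combined = []
    for i in range(len(anno["GEN"])):
        tags = [anno[k][i] for k in keys if anno[k][i] != "O"]
        combined.append(tags or ["O"])
    return combined


def check_alignment(example, keys=KEYS):
    """Hypothetical helper: True if every tag list has one entry per whitespace word."""
    n_words = len(example["text_str"].split())
    return all(len(example["annotations"][k]) == n_words for k in keys)


# Hypothetical record in the assumed new_dataset.json shape.
record = {
    "text_str": "Every retired individual plays bingo",
    "annotations": {
        "GEN":    ["B-GEN", "I-GEN", "I-GEN", "O", "O"],
        "STEREO": ["O", "B-STEREO", "I-STEREO", "I-STEREO", "I-STEREO"],
        "UNFAIR": ["O", "O", "O", "O", "O"],
    },
}
print(merge_word_tags(record))
print(check_alignment(record))
```

To list offending records in the real data, something like `dataset['train'].filter(lambda ex: not check_alignment(ex))` should work.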

In [30]:
def prepare_example(example):
    """
    Preprocess an example:
    - Tokenize with is_split_into_words=True for word-to-subword alignment
    - Use word_ids() to map each subword token back to its original word
    - Build multi-hot label matrix [seq_len, num_channels]; on continuation
      subwords of the same word, B- tags are converted to I- tags
    - Mask special/padding tokens with -100
    """
    text = example['text_str']
    word_tags = parse_annotations(example)  # list of lists, one per word

    # Split text into words
    words = text.split()

    # Tokenize with word-level alignment
    tokenized = tokenizer(
        words,
        is_split_into_words=True,
        truncation=True,
        max_length=128,
        padding='max_length',
    )

    word_ids = tokenized.word_ids() 
    seq_len = len(word_ids)

    labels_multi = np.zeros((seq_len, num_channels), dtype=np.float32)

    prev_word_id = None
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            prev_word_id = None
            continue
        if word_id >= len(word_tags):  # more words than annotations: labels stay all-zero
            prev_word_id = word_id
            continue
        tags = word_tags[word_id]
        for tag in tags:
            if tag == 'O': continue
            if word_id == prev_word_id:
                if tag.startswith('B-'):
                    i_tag = 'I-' + tag[2:]
                    if i_tag in channel2idx:
                        labels_multi[idx, channel2idx[i_tag]] = 1.0
                elif tag in channel2idx:
                    labels_multi[idx, channel2idx[tag]] = 1.0
            else:
                if tag in channel2idx:
                    labels_multi[idx, channel2idx[tag]] = 1.0
        prev_word_id = word_id

    final_labels = []
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            final_labels.append([-100.0] * num_channels)
        else:
            final_labels.append(labels_multi[idx].tolist())

    tokenized['labels'] = final_labels
    return tokenized
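The label-alignment loop is easiest to follow on a small case. The sketch below reproduces its logic in a standalone function (`align_labels`, hypothetical) driven by a hand-written `word_ids` list. It assumes the tokenizer splits "retirees" into two WordPiece pieces; the real split depends on the vocabulary. Note that the second piece keeps `I-GEN`, while its `B-STEREO` is converted to `I-STEREO`.

```python
channels = ["B-GEN", "I-GEN", "B-UNFAIR", "I-UNFAIR", "B-STEREO", "I-STEREO"]
channel2idx = {c: i for i, c in enumerate(channels)}


def align_labels(word_ids, word_tags):
    """Standalone copy of prepare_example's alignment loop (hypothetical helper)."""
    rows, prev = [], None
    for word_id in word_ids:
        if word_id is None:  # [CLS], [SEP] or padding: masked out of the loss
            rows.append([-100.0] * len(channels))
            prev = None
            continue
        row = [0.0] * len(channels)
        if word_id < len(word_tags):  # words without annotations stay all-zero
            for tag in word_tags[word_id]:
                if tag == "O":
                    continue
                if word_id == prev and tag.startswith("B-"):
                    tag = "I-" + tag[2:]  # continuation piece: B- becomes I-
                row[channel2idx[tag]] = 1.0
        rows.append(row)
        prev = word_id
    return rows


# Words ["all", "retirees"]; assume "retirees" splits into two WordPiece pieces.
word_ids = [None, 0, 1, 1, None]  # [CLS] all retire ##es [SEP] (hypothetical split)
word_tags = [["B-GEN"], ["I-GEN", "B-STEREO"]]
for wid, row in zip(word_ids, align_labels(word_ids, word_tags)):
    print(wid, row)
```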

In [31]:
print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    prepare_example,
    batched=False,
    remove_columns=dataset["train"].column_names,
)
print("Tokenization complete!")

# === Sanity check: verify annotations survived preprocessing ===
total_positive_tokens = 0
total_valid_tokens = 0
for ex in tokenized_dataset["train"]:
    labels = np.array(ex["labels"])
    valid = labels[labels[:, 0] != -100.0]
    total_valid_tokens += len(valid)
    total_positive_tokens += (valid > 0).any(axis=1).sum()

print(f"\nSanity check:")
print(f"  Total valid tokens: {total_valid_tokens}")
print(f"  Tokens with at least one bias label: {total_positive_tokens}")
print(f"  Positive rate: {total_positive_tokens / max(total_valid_tokens, 1):.2%}")
assert total_positive_tokens > 0, (
    "FATAL: No positive labels found. Check that prepare_example "
    "is reading from the correct dataset field."
)
print("  OK — annotations loaded successfully.")


Tokenizing dataset...



Tokenization complete!

Sanity check:
  Total valid tokens: 68314
  Tokens with at least one bias label: 21677
  Positive rate: 31.73%
  OK — annotations loaded successfully.


In [32]:
# Split: 70% train, 15% dev, 15% test (matching paper methodology)

# First split: train (70%) vs dev+test (30%)
train_devtest = tokenized_dataset["train"].train_test_split(
    test_size=0.30, seed=42
)
train_split = train_devtest["train"]
devtest_split = train_devtest["test"]

# Second split: dev (50% of 30% = 15%) vs test (50% of 30% = 15%)
dev_test = devtest_split.train_test_split(test_size=0.5, seed=42)
dev_split = dev_test["train"]
test_split = dev_test["test"]

print(f"Train size: {len(train_split)}")
print(f"Dev size:   {len(dev_split)}")
print(f"Test size:  {len(test_split)}")


Train size: 2562
Dev size:   549
Test size:  550
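The printed sizes follow from the rounding in `train_test_split`. To our understanding, it rounds the test side up with `ceil` when `test_size` is a float. The arithmetic below reproduces 2562 / 549 / 550, which is roughly 70% / 15% / 15% of 3661. `split_sizes` is a hypothetical helper.

```python
import math


def split_sizes(n, test_size):
    """(train, test) sizes for a float test_size, rounding the test side up."""
    n_test = math.ceil(test_size * n)
    return n - n_test, n_test


n_total = 3661
n_train, n_devtest = split_sizes(n_total, 0.30)
n_dev, n_test = split_sizes(n_devtest, 0.5)  # dev is the "train" side of the 2nd split
for name, n in (("train", n_train), ("dev", n_dev), ("test", n_test)):
    print(f"{name:5s} {n:5d}  {n / n_total:.1%}")
```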


In [33]:
# Load BERT model with increased regularization for small dataset
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
config.num_labels = num_channels
config.problem_type = "multi_label_classification"
config.classifier_dropout = 0.3        # Dropout before classifier head
config.hidden_dropout_prob = 0.15      # Up from the 0.1 BERT default

model = AutoModelForTokenClassification.from_pretrained(
    "bert-base-uncased",
    config=config,
)

print(f"Model loaded: {model.config.model_type}")
print(f"Number of parameters: {model.num_parameters():,}")
print(f"Classifier dropout: {model.config.classifier_dropout}")
print(f"Hidden dropout: {model.config.hidden_dropout_prob}")


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: 

Model loaded: bert
Number of parameters: 108,896,262
Classifier dropout: 0.3
Hidden dropout: 0.15


In [34]:
def estimate_channel_frequencies(dataset_split):
    """
    Estimate frequency of positives per channel in the training split.
    """
    positives = np.zeros(num_channels, dtype=np.int64)
    total = 0

    for example in dataset_split:
        labels = np.array(example["labels"])  # [seq_len, num_channels]
        # Mask for valid tokens
        valid_mask = labels[:, 0] != -100.0
        valid_labels = labels[valid_mask]
        if valid_labels.size == 0:
            continue
        positives += valid_labels.sum(axis=0).astype(np.int64)
        total += valid_labels.shape[0]

    return positives, total


# Calculate alpha values for focal loss
channel_pos, total_tokens = estimate_channel_frequencies(train_split)
# Avoid division by zero: if a channel doesn't appear, assign minimum frequency
channel_pos = np.maximum(channel_pos, 1)
freq = channel_pos / float(total_tokens)

# α_c ∝ 1 / freq_c, normalized so that the six values sum to 1
inv_freq = 1.0 / freq
alpha_channel = inv_freq / inv_freq.sum()
alpha_channel = torch.tensor(alpha_channel, dtype=torch.float32)

print("Channel statistics:")
for i, ch in enumerate(channels):
    print(f"  {ch}: {channel_pos[i]} positives, α={alpha_channel[i]:.4f}")
print(f"\nTotal valid tokens: {total_tokens}")


Channel statistics:
  B-GEN: 3446 positives, α=0.0835
  I-GEN: 3436 positives, α=0.0837
  B-UNFAIR: 716 positives, α=0.4018
  I-UNFAIR: 1978 positives, α=0.1454
  B-STEREO: 1135 positives, α=0.2534
  I-STEREO: 8939 positives, α=0.0322

Total valid tokens: 47550
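The α values can be checked by hand from the printed counts. Every frequency shares the same denominator, so `total_tokens` cancels and α_c reduces to (1/n_c) / Σ_j (1/n_j), where n_c is the positive count of channel c. Here is a quick NumPy check using the numbers printed above:

```python
import numpy as np

# Positive-token counts per channel and total valid tokens, copied from the output above.
channels = ["B-GEN", "I-GEN", "B-UNFAIR", "I-UNFAIR", "B-STEREO", "I-STEREO"]
counts = np.array([3446, 3436, 716, 1978, 1135, 8939], dtype=np.float64)
total_tokens = 47550

inv_freq = 1.0 / (counts / total_tokens)
alpha = inv_freq / inv_freq.sum()

# The shared denominator cancels: alpha depends only on the relative counts.
alpha_from_counts = (1.0 / counts) / (1.0 / counts).sum()

for ch, a in zip(channels, alpha):
    print(f"{ch:9s} alpha={a:.4f}")
print("same without total_tokens:", np.allclose(alpha, alpha_from_counts))
print("sum of alphas:", round(float(alpha.sum()), 6))
```

The six weights sum to 1, so the average weight is about 1/6. Together with the focal modulating factor, this likely explains why the logged loss values are so small in absolute terms.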


In [35]:
class FocalLossMultiLabel(nn.Module):
    """
    Focal loss applied channel-wise for multi-label classification.
    inputs: logits [N, num_channels]
    targets: multi-hot [N, num_channels]
    alpha: tensor [num_channels], per-channel weight applied to positive and negative targets alike
    label_smoothing: smooth targets to prevent overconfident predictions
    """

    def __init__(self, alpha, gamma=2.0, reduction="mean", label_smoothing=0.0):
        super().__init__()
        self.register_buffer("alpha", alpha)
        self.gamma = gamma
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        """
        inputs: logits
        targets: multi-hot (0/1)
        """
        # Label smoothing: y_smooth = y * (1 - eps) + 0.5 * eps
        if self.label_smoothing > 0:
            targets = targets * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing

        bce = F.binary_cross_entropy_with_logits(
            inputs, targets.float(), reduction="none"
        )  # [N, C]
        pt = torch.exp(-bce)  # p_t for hard 0/1 targets; only approximate once targets are smoothed
        # Broadcasting alpha: [C] -> [N, C]
        focal = self.alpha.to(inputs.device) * (1 - pt) ** self.gamma * bce

        if self.reduction == "mean":
            return focal.mean()
        elif self.reduction == "sum":
            return focal.sum()
        else:
            return focal
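To see what the `(1 - pt) ** gamma` factor does, the sketch below evaluates the per-element focal term for hard 0/1 targets at a few values of p_t, the probability the model assigns to the correct label. With `gamma=2.0`, a confident correct prediction (p_t = 0.9) keeps only 1% of its cross-entropy, while a badly wrong one (p_t = 0.1) keeps 81%. `focal_term` is a hypothetical scalar helper with `alpha=1`. Once `label_smoothing > 0`, the targets are soft, and `exp(-bce)` equals p^y (1 - p)^(1 - y) rather than exactly p_t.

```python
import math


def focal_term(p_true, gamma=2.0, alpha=1.0):
    """Hypothetical scalar version of the focal term for a hard 0/1 target.

    p_true is the probability the model assigns to the correct label.
    """
    bce = -math.log(p_true)   # binary cross-entropy of this element
    pt = math.exp(-bce)       # recovers p_true, as in FocalLossMultiLabel.forward
    return alpha * (1.0 - pt) ** gamma * bce


for p in (0.9, 0.5, 0.1):
    bce = -math.log(p)
    kept = focal_term(p) / bce  # fraction of the cross-entropy that survives: (1 - p)**2
    print(f"p_t={p:.1f}  bce={bce:.3f}  focal={focal_term(p):.4f}  kept={kept:.0%}")
```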


In [36]:
from transformers import get_cosine_schedule_with_warmup


class FocalLossTrainer(Trainer):
    """
    Trainer with:
    - Focal Loss for multi-label classification (with label smoothing)
    - Layer-wise Learning Rate Decay (LLRD) for BERT fine-tuning
    - Cosine annealing scheduler with warmup
    """

    def __init__(self, *args, alpha_channel, gamma=2.0, label_smoothing=0.0,
                 llrd_decay_factor=0.85, classifier_lr=2e-4, **kwargs):
        super().__init__(*args, **kwargs)
        self.focal_loss = FocalLossMultiLabel(
            alpha=alpha_channel, gamma=gamma, label_smoothing=label_smoothing
        )
        self.llrd_decay_factor = llrd_decay_factor
        self.classifier_lr = classifier_lr

    def create_optimizer(self):
        """Layer-wise Learning Rate Decay: higher LR for top layers, lower for bottom."""
        base_lr = self.args.learning_rate
        decay = self.llrd_decay_factor
        no_decay_keys = ["bias", "LayerNorm.weight", "LayerNorm.bias"]

        opt_params = []

        # 1. Classifier head: highest LR
        opt_params.append({
            "params": [p for n, p in self.model.named_parameters() if "classifier" in n],
            "lr": self.classifier_lr,
            "weight_decay": 0.0,
        })

        # 2. BERT encoder layers 11 -> 0: progressively lower LR
        for layer_idx in range(11, -1, -1):
            layer_lr = base_lr * (decay ** (11 - layer_idx))
            layer_decay = []
            layer_no_decay = []

            for n, p in self.model.named_parameters():
                if f"bert.encoder.layer.{layer_idx}." in n:
                    if any(nd in n for nd in no_decay_keys):
                        layer_no_decay.append(p)
                    else:
                        layer_decay.append(p)

            if layer_decay:
                opt_params.append({
                    "params": layer_decay,
                    "lr": layer_lr,
                    "weight_decay": self.args.weight_decay,
                })
            if layer_no_decay:
                opt_params.append({
                    "params": layer_no_decay,
                    "lr": layer_lr,
                    "weight_decay": 0.0,
                })

        # 3. Embeddings: lowest LR
        emb_lr = base_lr * (decay ** 12)
        emb_decay = []
        emb_no_decay = []
        for n, p in self.model.named_parameters():
            if "bert.embeddings" in n:
                if any(nd in n for nd in no_decay_keys):
                    emb_no_decay.append(p)
                else:
                    emb_decay.append(p)

        if emb_decay:
            opt_params.append({
                "params": emb_decay,
                "lr": emb_lr,
                "weight_decay": self.args.weight_decay,
            })
        if emb_no_decay:
            opt_params.append({
                "params": emb_no_decay,
                "lr": emb_lr,
                "weight_decay": 0.0,
            })

        self.optimizer = torch.optim.AdamW(opt_params, lr=base_lr, eps=1e-8)

        print(f"LLRD optimizer created:")
        print(f"  Classifier LR: {self.classifier_lr}")
        print(f"  Top BERT layer LR: {base_lr}")
        print(f"  Bottom BERT layer LR: {base_lr * decay**11:.2e}")
        print(f"  Embeddings LR: {emb_lr:.2e}")

        return self.optimizer

    def create_scheduler(self, num_training_steps, optimizer=None):
        """Cosine annealing with warmup for smoother convergence."""
        if optimizer is None:
            optimizer = self.optimizer

        warmup_steps = int(num_training_steps * self.args.warmup_ratio)

        self.lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
        )

        print(f"Cosine scheduler: {warmup_steps} warmup steps, {num_training_steps} total steps")
        return self.lr_scheduler

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")  # [batch, seq_len, num_channels]
        outputs = model(**inputs)
        logits = outputs.logits  # [batch, seq_len, num_channels]

        # Flatten batch and seq_len
        logits_flat = logits.view(-1, num_channels)
        labels_flat = labels.view(-1, num_channels)

        # Mask for valid tokens (excludes special and padding tokens)
        valid_mask = labels_flat[:, 0] != -100.0
        logits_valid = logits_flat[valid_mask]
        labels_valid = labels_flat[valid_mask]

        loss = self.focal_loss(logits_valid, labels_valid)

        return (loss, outputs) if return_outputs else loss
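The LLRD schedule in `create_optimizer` gives each parameter group a peak learning rate that shrinks geometrically with depth. The sketch below (hypothetical helper `llrd_learning_rates`) recomputes those peaks for the settings used later: `learning_rate=5e-5`, `llrd_decay_factor=0.85` and `classifier_lr=2e-4`. The results match the optimizer summary printed when training starts. Like `create_optimizer`, it assumes the 12 encoder layers of `bert-base-uncased`. Because the layer loop there is hard-coded, a deeper model would need a range based on `model.config.num_hidden_layers` or similar.

```python
def llrd_learning_rates(base_lr=5e-5, decay=0.85, num_layers=12, classifier_lr=2e-4):
    """Hypothetical helper: peak LR per parameter group, mirroring create_optimizer."""
    lrs = {"classifier": classifier_lr}
    for layer_idx in range(num_layers - 1, -1, -1):  # top layer first
        lrs[f"encoder.layer.{layer_idx}"] = base_lr * decay ** (num_layers - 1 - layer_idx)
    lrs["embeddings"] = base_lr * decay ** num_layers
    return lrs


lrs = llrd_learning_rates()
for name in ("classifier", "encoder.layer.11", "encoder.layer.0", "embeddings"):
    print(f"{name:17s} {lrs[name]:.2e}")
```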


In [37]:
# Initial thresholds per channel (can be optimized later)
thresholds = np.array([0.5] * num_channels, dtype=np.float32)


def compute_metrics(eval_pred):
    """
    Compute per-channel and macro-averaged metrics for multi-label token classification.
    Each channel is evaluated independently as a binary classification task.
    Includes a sanity check against degenerate all-zero labels.
    """
    predictions, labels = eval_pred  # preds: [batch, seq, C], labels: [batch, seq, C]
    # Convert logits to sigmoid probabilities
    probs = 1 / (1 + np.exp(-predictions))

    # Mask for valid tokens
    valid_mask = labels[:, :, 0] != -100.0  # [batch, seq]
    probs_flat = probs[valid_mask]          # [N_valid, C]
    labels_flat = labels[valid_mask]        # [N_valid, C]

    # Apply channel-wise thresholds
    thr = thresholds.reshape(1, num_channels)
    preds_bin = (probs_flat >= thr).astype(int)
    labels_bin = labels_flat.astype(int)

    # Sanity check: warn if labels contain no positives at all
    total_positives = labels_bin.sum()
    if total_positives == 0:
        print("WARNING: All labels are zero — annotations may not have been loaded correctly.")
        return {
            "f1_macro": 0.0,
            "precision_macro": 0.0,
            "recall_macro": 0.0,
            "hamming_loss": 0.0,
            "total_positives": 0,
        }

    # Hamming loss
    hamming = np.mean(preds_bin != labels_bin)

    # Per-channel metrics (proper multi-label evaluation)
    channel_f1s = []
    for c in range(num_channels):
        f1_c = f1_score(labels_bin[:, c], preds_bin[:, c], average="binary", zero_division=0)
        channel_f1s.append(f1_c)

    # Macro average across channels (not flattened, which hides class imbalance)
    f1_macro = np.mean(channel_f1s)
    precision = precision_score(
        labels_bin, preds_bin, average="macro", zero_division=0
    )
    recall = recall_score(
        labels_bin, preds_bin, average="macro", zero_division=0
    )

    return {
        "f1_macro": f1_macro,
        "precision_macro": precision,
        "recall_macro": recall,
        "hamming_loss": hamming,
        "total_positives": int(total_positives),
    }
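The global `thresholds` array starts at 0.5 for every channel. A common refinement is to tune one threshold per channel on the dev split (never on the test split) and reuse it for final evaluation. The block below is a minimal grid-search sketch. `tune_thresholds` and `binary_f1` are hypothetical helpers, and the inputs are assumed to be the masked `[N_valid, C]` arrays that `compute_metrics` builds.

```python
import numpy as np


def binary_f1(y_true, y_pred):
    """F1 for one channel; 0.0 if there are no positive labels and no positive predictions."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = 2 * tp + fp + fn
    return 0.0 if denom == 0 else 2 * tp / denom


def tune_thresholds(probs, labels, grid=np.arange(0.05, 0.96, 0.05)):
    """Per channel, the grid threshold with the highest dev-set F1 (lowest wins ties)."""
    best = np.full(probs.shape[1], 0.5, dtype=np.float32)
    for c in range(probs.shape[1]):
        scores = [binary_f1(labels[:, c], (probs[:, c] >= t).astype(int)) for t in grid]
        best[c] = grid[int(np.argmax(scores))]
    return best


# Toy [N_valid, C] arrays with C = 2 channels.
probs = np.array([[0.92, 0.22], [0.63, 0.38], [0.27, 0.33], [0.12, 0.04]])
labels = np.array([[1, 1], [1, 1], [0, 0], [0, 0]])
print(tune_thresholds(probs, labels))
```

With the trained model, the inputs could come from `out = trainer.predict(dev_split)`. Apply a sigmoid to `out.predictions` and drop rows where `out.label_ids[:, :, 0] == -100`. Then assign the result in place with `thresholds[:] = tune_thresholds(probs, labels)` so that later calls to `compute_metrics` pick it up.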


In [38]:
training_args = TrainingArguments(
    output_dir="./gus-net-bert-multilabel",
    evaluation_strategy="epoch",     # named eval_strategy in newer transformers releases
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,   # Effective batch size = 32
    num_train_epochs=20,             # Upper bound; early stopping will halt sooner
    weight_decay=0.01,
    warmup_ratio=0.1,
    logging_steps=50,                # More frequent logging than the default of 500
    fp16=True,                       # Mixed precision; requires a CUDA GPU
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    report_to="none",
)

print("Training configuration:")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Epochs (max): {training_args.num_train_epochs}")
print(f"  FP16: {training_args.fp16}")


Training configuration:
  Batch size: 16
  Gradient accumulation: 2
  Effective batch size: 32
  Learning rate: 5e-05
  Epochs (max): 20
  FP16: True
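These settings determine the schedule length printed when training starts: 1600 total steps and 160 warmup steps. The sketch below reproduces that arithmetic under one assumption. It assumes the installed `transformers` version computes optimizer updates per epoch by floor-dividing the number of batches by `gradient_accumulation_steps`, which is consistent with the logged totals. `schedule_steps` is a hypothetical helper.

```python
import math


def schedule_steps(n_train, batch_size, grad_accum, epochs, warmup_ratio):
    """Hypothetical helper: (batches/epoch, updates/epoch, total updates, warmup steps)."""
    batches_per_epoch = math.ceil(n_train / batch_size)          # last batch is partial
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)   # assumed floor division
    total_steps = updates_per_epoch * epochs
    warmup_steps = int(total_steps * warmup_ratio)                # as in create_scheduler
    return batches_per_epoch, updates_per_epoch, total_steps, warmup_steps


print(schedule_steps(n_train=2562, batch_size=16, grad_accum=2, epochs=20, warmup_ratio=0.1))
```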


In [39]:
from transformers import EarlyStoppingCallback

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

trainer = FocalLossTrainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=dev_split,
    tokenizer=tokenizer,             # newer transformers releases use processing_class=tokenizer
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    alpha_channel=alpha_channel,
    gamma=2.0,
    label_smoothing=0.05,
    llrd_decay_factor=0.85,
    classifier_lr=2e-4,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

print("Trainer initialized with:")
print("  - Focal loss (gamma=2.0, label_smoothing=0.05)")
print("  - LLRD (decay=0.85, classifier_lr=2e-4)")
print("  - Cosine annealing scheduler")
print("  - Early stopping (patience=5)")


Trainer initialized with:
  - Focal loss (gamma=2.0, label_smoothing=0.05)
  - LLRD (decay=0.85, classifier_lr=2e-4)
  - Cosine annealing scheduler
  - Early stopping (patience=5)




In [40]:
# Train the model
print("Starting training...")
train_result = trainer.train()

print("\nTraining completed!")
print(f"Training loss: {train_result.training_loss:.4f}")
print(f"Training time: {train_result.metrics['train_runtime']:.2f}s")


Starting training...
LLRD optimizer created:
  Classifier LR: 0.0002
  Top BERT layer LR: 5e-05
  Bottom BERT layer LR: 8.37e-06
  Embeddings LR: 7.11e-06
Cosine scheduler: 160 warmup steps, 1600 total steps



You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 0.0182, 'learning_rate': 6.25e-05, 'epoch': 0.62}



{'eval_loss': 0.006513354834169149, 'eval_f1_macro': 0.015373010403805476, 'eval_precision_macro': 0.16055045871559634, 'eval_recall_macro': 0.008485261758655903, 'eval_hamming_loss': 0.07043536632144227, 'eval_total_positives': 4389, 'eval_runtime': 2.2118, 'eval_samples_per_second': 248.213, 'eval_steps_per_second': 15.824, 'epoch': 0.99}




{'loss': 0.0071, 'learning_rate': 0.000125, 'epoch': 1.24}
{'loss': 0.0054, 'learning_rate': 0.0001875, 'epoch': 1.86}



{'eval_loss': 0.004559027496725321, 'eval_f1_macro': 0.41159663921546114, 'eval_precision_macro': 0.6821901878119488, 'eval_recall_macro': 0.3219465611743348, 'eval_hamming_loss': 0.05908771256872523, 'eval_total_positives': 4389, 'eval_runtime': 2.6631, 'eval_samples_per_second': 206.147, 'eval_steps_per_second': 13.142, 'epoch': 2.0}




{'loss': 0.0045, 'learning_rate': 0.00019961946980917456, 'epoch': 2.48}



{'eval_loss': 0.004048365633934736, 'eval_f1_macro': 0.5335371242305434, 'eval_precision_macro': 0.6484360897670595, 'eval_recall_macro': 0.47651554917157707, 'eval_hamming_loss': 0.051895537655031326, 'eval_total_positives': 4389, 'eval_runtime': 2.5079, 'eval_samples_per_second': 218.908, 'eval_steps_per_second': 13.956, 'epoch': 2.99}




{'loss': 0.004, 'learning_rate': 0.00019807852804032305, 'epoch': 3.11}
{'loss': 0.0036, 'learning_rate': 0.0001953716950748227, 'epoch': 3.73}



{'eval_loss': 0.003925737924873829, 'eval_f1_macro': 0.5488768340017405, 'eval_precision_macro': 0.7033957205199021, 'eval_recall_macro': 0.45836233212676875, 'eval_hamming_loss': 0.048027745812555936, 'eval_total_positives': 4389, 'eval_runtime': 3.2771, 'eval_samples_per_second': 167.527, 'eval_steps_per_second': 10.68, 'epoch': 4.0}




{'loss': 0.0034, 'learning_rate': 0.00019153114791194473, 'epoch': 4.35}
{'loss': 0.0032, 'learning_rate': 0.00018660254037844388, 'epoch': 4.97}



{'eval_loss': 0.003771999152377248, 'eval_f1_macro': 0.5893112925843704, 'eval_precision_macro': 0.6963144518660972, 'eval_recall_macro': 0.5143525914197825, 'eval_hamming_loss': 0.046253676000511446, 'eval_total_positives': 4389, 'eval_runtime': 11.6641, 'eval_samples_per_second': 47.068, 'eval_steps_per_second': 3.001, 'epoch': 4.99}




{'loss': 0.0028, 'learning_rate': 0.00018064446042674828, 'epoch': 5.59}



{'eval_loss': 0.003983289934694767, 'eval_f1_macro': 0.5877875858340303, 'eval_precision_macro': 0.6759326119705141, 'eval_recall_macro': 0.5343575540009923, 'eval_hamming_loss': 0.047068789157396755, 'eval_total_positives': 4389, 'eval_runtime': 8.0998, 'eval_samples_per_second': 67.779, 'eval_steps_per_second': 4.321, 'epoch': 6.0}




{'loss': 0.0027, 'learning_rate': 0.0001737277336810124, 'epoch': 6.21}
{'loss': 0.0025, 'learning_rate': 0.00016593458151000688, 'epoch': 6.83}



{'eval_loss': 0.004262497182935476, 'eval_f1_macro': 0.5951710842327343, 'eval_precision_macro': 0.7093563484464361, 'eval_recall_macro': 0.5172565094978419, 'eval_hamming_loss': 0.044655414908579466, 'eval_total_positives': 4389, 'eval_runtime': 4.1879, 'eval_samples_per_second': 131.091, 'eval_steps_per_second': 8.357, 'epoch': 6.99}




{'loss': 0.0024, 'learning_rate': 0.0001573576436351046, 'epoch': 7.45}



{'eval_loss': 0.004098713397979736, 'eval_f1_macro': 0.6248726502142938, 'eval_precision_macro': 0.6684035933963743, 'eval_recall_macro': 0.5879966062367098, 'eval_hamming_loss': 0.0453266845671909, 'eval_total_positives': 4389, 'eval_runtime': 5.3297, 'eval_samples_per_second': 103.007, 'eval_steps_per_second': 6.567, 'epoch': 8.0}




{'loss': 0.0022, 'learning_rate': 0.00014809887689193877, 'epoch': 8.07}
{'loss': 0.0021, 'learning_rate': 0.000138268343236509, 'epoch': 8.7}



{'eval_loss': 0.004405978601425886, 'eval_f1_macro': 0.6288894137198777, 'eval_precision_macro': 0.6907769572367387, 'eval_recall_macro': 0.5786202111524106, 'eval_hamming_loss': 0.04427183224651579, 'eval_total_positives': 4389, 'eval_runtime': 2.2018, 'eval_samples_per_second': 249.339, 'eval_steps_per_second': 15.896, 'epoch': 8.99}




{'loss': 0.0019, 'learning_rate': 0.00012798290140309923, 'epoch': 9.32}
{'loss': 0.002, 'learning_rate': 0.00011736481776669306, 'epoch': 9.94}


{'eval_loss': 0.004365340806543827, 'eval_f1_macro': 0.6339641690904293, 'eval_precision_macro': 0.6694708508588366, 'eval_recall_macro': 0.6029420866813918, 'eval_hamming_loss': 0.044783275795934024, 'eval_total_positives': 4389, 'eval_runtime': 2.1626, 'eval_samples_per_second': 253.86, 'eval_steps_per_second': 16.184, 'epoch': 10.0}


{'loss': 0.0018, 'learning_rate': 0.00010654031292301432, 'epoch': 10.56}


{'eval_loss': 0.004688904620707035, 'eval_f1_macro': 0.6300966102699059, 'eval_precision_macro': 0.6870985549790265, 'eval_recall_macro': 0.5839420399337362, 'eval_hamming_loss': 0.04417593658099987, 'eval_total_positives': 4389, 'eval_runtime': 2.1419, 'eval_samples_per_second': 256.312, 'eval_steps_per_second': 16.34, 'epoch': 10.99}


{'loss': 0.0017, 'learning_rate': 9.563806126346642e-05, 'epoch': 11.18}
{'loss': 0.0016, 'learning_rate': 8.478766138100834e-05, 'epoch': 11.8}


{'eval_loss': 0.004693201743066311, 'eval_f1_macro': 0.6359996457178297, 'eval_precision_macro': 0.6880767581190232, 'eval_recall_macro': 0.5941204040855681, 'eval_hamming_loss': 0.044048075693645314, 'eval_total_positives': 4389, 'eval_runtime': 3.0455, 'eval_samples_per_second': 180.266, 'eval_steps_per_second': 11.492, 'epoch': 12.0}


{'loss': 0.0016, 'learning_rate': 7.411809548974792e-05, 'epoch': 12.42}


{'eval_loss': 0.0049332003109157085, 'eval_f1_macro': 0.6320977621460756, 'eval_precision_macro': 0.6904401685867859, 'eval_recall_macro': 0.5887122582556291, 'eval_hamming_loss': 0.044048075693645314, 'eval_total_positives': 4389, 'eval_runtime': 3.8268, 'eval_samples_per_second': 143.461, 'eval_steps_per_second': 9.146, 'epoch': 12.99}


{'loss': 0.0016, 'learning_rate': 6.375619617162985e-05, 'epoch': 13.04}
{'loss': 0.0015, 'learning_rate': 5.3825138676496624e-05, 'epoch': 13.66}


{'eval_loss': 0.0050855595618486404, 'eval_f1_macro': 0.6288753352678627, 'eval_precision_macro': 0.7020078597135292, 'eval_recall_macro': 0.5759885844428859, 'eval_hamming_loss': 0.043888249584452116, 'eval_total_positives': 4389, 'eval_runtime': 3.7349, 'eval_samples_per_second': 146.992, 'eval_steps_per_second': 9.371, 'epoch': 14.0}


{'loss': 0.0015, 'learning_rate': 4.444297669803981e-05, 'epoch': 14.29}
{'loss': 0.0014, 'learning_rate': 3.5721239031346066e-05, 'epoch': 14.91}


{'eval_loss': 0.005176238249987364, 'eval_f1_macro': 0.6261544474659215, 'eval_precision_macro': 0.6952381552350962, 'eval_recall_macro': 0.574736277677645, 'eval_hamming_loss': 0.043632527809743, 'eval_total_positives': 4389, 'eval_runtime': 3.4961, 'eval_samples_per_second': 157.031, 'eval_steps_per_second': 10.011, 'epoch': 14.99}


{'loss': 0.0014, 'learning_rate': 2.776360379402445e-05, 'epoch': 15.53}


{'eval_loss': 0.005228512454777956, 'eval_f1_macro': 0.6271913747656575, 'eval_precision_macro': 0.6888739012801444, 'eval_recall_macro': 0.5796567753815595, 'eval_hamming_loss': 0.0437923539189362, 'eval_total_positives': 4389, 'eval_runtime': 2.5122, 'eval_samples_per_second': 218.536, 'eval_steps_per_second': 13.932, 'epoch': 16.0}


{'loss': 0.0014, 'learning_rate': 2.0664665970876496e-05, 'epoch': 16.15}
{'loss': 0.0014, 'learning_rate': 1.4508812932705363e-05, 'epoch': 16.77}


{'eval_loss': 0.005314307752996683, 'eval_f1_macro': 0.6274086303246702, 'eval_precision_macro': 0.696680668420151, 'eval_recall_macro': 0.5737421192776829, 'eval_hamming_loss': 0.0433927886459532, 'eval_total_positives': 4389, 'eval_runtime': 2.4039, 'eval_samples_per_second': 228.382, 'eval_steps_per_second': 14.56, 'epoch': 16.99}
{'train_runtime': 1089.3009, 'train_samples_per_second': 47.039, 'train_steps_per_second': 1.469, 'train_loss': 0.0031209433797136908, 'epoch': 16.99}

Training completed!
Training loss: 0.0031
Training time: 1089.30s
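The early-stopping rule listed in the training summary (patience=5) can be checked against the log above. The following is a minimal sketch that assumes the monitored metric is `eval_f1_macro` and that any evaluation not strictly improving on the best so far counts toward patience (a minimum-improvement threshold of 0 is assumed). The scores are the `eval_f1_macro` values logged above for epochs 6.0 through 16.99, rounded to four decimals. `early_stop_index` is a hypothetical helper, not part of the notebook.

```python
def early_stop_index(scores, patience):
    """Return (stop_index, best_index) for a patience-based early stopper.

    stop_index is None if patience is never exhausted.
    """
    best_score = float("-inf")
    best_index = None
    bad_evals = 0  # consecutive evaluations without improvement
    for i, score in enumerate(scores):
        if score > best_score:
            best_score, best_index, bad_evals = score, i, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return i, best_index
    return None, best_index


# eval_f1_macro from the log above (epochs 6.0 ... 16.99), rounded
logged_f1 = [0.5878, 0.5952, 0.6249, 0.6289, 0.6340, 0.6301,
             0.6360, 0.6321, 0.6289, 0.6262, 0.6272, 0.6274]

stop, best = early_stop_index(logged_f1, patience=5)
print(f"best eval #{best} (F1={logged_f1[best]}), stop after eval #{stop}")
```

Assuming no earlier epoch scored higher, the best evaluation is the epoch-12 one (macro F1 0.6360). The stopper fires five evaluations later, at the last logged evaluation (epoch 16.99), which is consistent with where training ended.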


In [41]:
import gc
import glob
import os

import torch
from safetensors.torch import load_file


def apply_swa(trainer, checkpoint_dir, last_n=5):
    """
    SWA-style checkpoint averaging: uniformly average the weights of the last
    N Trainer checkpoints and load the result into trainer.model.
    Memory-efficient: loads raw state_dicts from safetensors/bin files
    instead of instantiating full models.
    """
    # Sort numerically by global step ("checkpoint-<step>")
    checkpoints = sorted(
        glob.glob(f"{checkpoint_dir}/checkpoint-*"),
        key=lambda x: int(x.split("-")[-1])
    )

    if len(checkpoints) < 2:
        print(f"Only {len(checkpoints)} checkpoint(s) found, skipping SWA.")
        return

    last_checkpoints = checkpoints[-last_n:]
    print(f"SWA: Averaging {len(last_checkpoints)} checkpoints:")
    for cp in last_checkpoints:
        print(f"  {cp}")

    avg_state_dict = None  # running sum of weights, divided by n at the end
    n = len(last_checkpoints)

    for cp_path in last_checkpoints:
        # Load raw state_dict without building a full model
        safetensors_path = os.path.join(cp_path, "model.safetensors")
        bin_path = os.path.join(cp_path, "pytorch_model.bin")

        if os.path.exists(safetensors_path):
            state = load_file(safetensors_path, device="cpu")
        elif os.path.exists(bin_path):
            state = torch.load(bin_path, map_location="cpu", weights_only=True)
        else:
            print(f"  WARNING: No model file found in {cp_path}, skipping.")
            n -= 1
            continue

        if avg_state_dict is None:
            avg_state_dict = {k: v.float() for k, v in state.items()}
        else:
            for k in avg_state_dict:
                avg_state_dict[k] += state[k].float()

        del state
        gc.collect()

    if avg_state_dict is None or n < 1:
        print("SWA: No valid checkpoints found, skipping.")
        return

    for k in avg_state_dict:
        avg_state_dict[k] /= n

    trainer.model.load_state_dict(avg_state_dict)
    trainer.model.to(trainer.args.device)
    del avg_state_dict
    gc.collect()
    print(f"SWA weights loaded successfully (averaged {n} checkpoints).")


# Apply SWA over the last 5 checkpoints
apply_swa(trainer, "./gus-net-bert-multilabel", last_n=5)


SWA: Averaging 5 checkpoints:
  ./gus-net-bert-multilabel\checkpoint-1046
  ./gus-net-bert-multilabel\checkpoint-1127
  ./gus-net-bert-multilabel\checkpoint-1207
  ./gus-net-bert-multilabel\checkpoint-1288
  ./gus-net-bert-multilabel\checkpoint-1368
SWA weights loaded successfully (averaged 5 checkpoints).
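`apply_swa` computes a uniform average of each parameter tensor across the saved checkpoints. Strictly speaking, this is tail-checkpoint averaging rather than SWA with its own averaging learning-rate schedule, which is what `torch.optim.swa_utils` provides. The core operation can be sketched on toy state dicts. `average_state_dicts` is a hypothetical helper, and the tensors are made up purely for illustration.

```python
import torch


def average_state_dicts(state_dicts):
    """Uniformly average a list of state_dicts with identical keys and shapes."""
    if not state_dicts:
        raise ValueError("need at least one state_dict")
    # Clone the first dict so the caller's tensors are not modified in place
    avg = {k: v.detach().float().clone() for k, v in state_dicts[0].items()}
    for sd in state_dicts[1:]:
        if sd.keys() != avg.keys():
            raise KeyError("state_dicts have mismatched keys")
        for k in avg:
            avg[k] += sd[k].float()
    for k in avg:
        avg[k] /= len(state_dicts)
    return avg


# Three toy "checkpoints" of a two-parameter model
ckpts = [
    {"w": torch.tensor([1.0, 2.0]), "b": torch.tensor([0.0])},
    {"w": torch.tensor([3.0, 4.0]), "b": torch.tensor([3.0])},
    {"w": torch.tensor([5.0, 6.0]), "b": torch.tensor([6.0])},
]
avg = average_state_dicts(ckpts)
print(avg["w"], avg["b"])  # tensor([3., 4.]) tensor([3.])
```

Weight averaging is only meaningful when the checkpoints lie in the same region of the loss surface, as consecutive late-epoch checkpoints of a single run do here. BERT uses LayerNorm rather than BatchNorm, so there are no running statistics to recompute after averaging, unlike the `update_bn` step that `torch.optim.swa_utils` provides for convolutional networks.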


In [42]:
# Evaluate on dev set with default thresholds
print("Evaluating on development set...")
dev_metrics = trainer.evaluate(dev_split)

print("\nDevelopment set results (default thresholds=0.5):")
print(f"  Macro F1: {dev_metrics['eval_f1_macro']:.4f}")
print(f"  Precision: {dev_metrics['eval_precision_macro']:.4f}")
print(f"  Recall: {dev_metrics['eval_recall_macro']:.4f}")
print(f"  Hamming Loss: {dev_metrics['eval_hamming_loss']:.4f}")


Evaluating on development set...


Development set results (default thresholds=0.5):
  Macro F1: 0.6283
  Precision: 0.6939
  Recall: 0.5787
  Hamming Loss: 0.0437
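The token-level metrics above are computed over a binary (token × 6 channels) matrix, after masking out positions whose label is -100. The notebook's `compute_metrics` is not shown in this section, so the following is a minimal NumPy sketch of a masked Hamming loss and per-channel macro F1. It assumes the same masking convention as `valid_mask` in `optimize_thresholds`. `masked_multilabel_metrics` is hypothetical.

```python
import numpy as np


def masked_multilabel_metrics(probs, labels, thresholds):
    """probs, labels: [batch, seq_len, channels]; labels use -100 for ignored tokens."""
    valid = labels[:, :, 0] != -100              # [batch, seq_len] token mask
    p = probs[valid]                             # [n_tokens, channels]
    y = labels[valid].astype(int)
    pred = (p >= np.asarray(thresholds).reshape(1, -1)).astype(int)

    # Hamming loss: fraction of (token, channel) cells predicted wrongly
    hamming = float(np.mean(pred != y))

    # Macro F1: unweighted mean of per-channel binary F1 (0 when undefined)
    f1s = []
    for c in range(y.shape[1]):
        tp = np.sum((pred[:, c] == 1) & (y[:, c] == 1))
        fp = np.sum((pred[:, c] == 1) & (y[:, c] == 0))
        fn = np.sum((pred[:, c] == 0) & (y[:, c] == 1))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return hamming, float(np.mean(f1s))


# One sentence, three tokens (the last one ignored), two channels
labels = np.array([[[1, 0], [0, 1], [-100, -100]]], dtype=float)
probs = np.array([[[0.9, 0.2], [0.6, 0.7], [0.99, 0.99]]])
hamming, macro_f1 = masked_multilabel_metrics(probs, labels, [0.5, 0.5])
print(f"Hamming loss={hamming:.4f}, macro F1={macro_f1:.4f}")
```

With six mostly-zero channels per token, Hamming loss is dominated by true negatives. This is why it can look small (0.0437 here) even when macro F1 is only around 0.63.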


In [43]:
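import numpy as np
from scipy.optimize import minimize_scalar
from sklearn.metrics import f1_score

# Assumes `channels` (the 6 channel names) and `num_channels` are defined in earlier cells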


def optimize_thresholds(trainer, dev_dataset, grid=None):
    """
    Find optimal thresholds per channel on the validation split.
    Two-pass approach:
      1. Coarse grid search over 37 threshold values
      2. Fine-grained refinement with scipy bounded optimization
    """
    model = trainer.model
    model.eval()

    if grid is None:
        grid = np.arange(0.05, 0.96, 0.025).tolist()  # 37 points

    all_probs = []
    all_labels = []

    # Collect logits and labels on dev set
    dataloader = trainer.get_eval_dataloader(dev_dataset)
    for batch in dataloader:
        with torch.no_grad():
            labels = batch["labels"].detach().cpu().numpy()
            inputs = {k: v.to(model.device) for k, v in batch.items() if k != "labels"}
            outputs = model(**inputs)
            logits = outputs.logits.detach().cpu().numpy()

        all_probs.append(1 / (1 + np.exp(-logits)))
        all_labels.append(labels)

    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Mask for valid tokens
    valid_mask = all_labels[:, :, 0] != -100.0
    probs_flat = all_probs[valid_mask]
    labels_flat = all_labels[valid_mask]
    labels_bin = labels_flat.astype(int)

    best_thresholds = np.zeros(num_channels, dtype=np.float32)

    # Pass 1: Coarse grid search
    print("\nPass 1 — Coarse grid search:")
    for c in range(num_channels):
        y_true_c = labels_bin[:, c]
        probs_c = probs_flat[:, c]

        best_thr_c = 0.5
        best_f1_c = 0.0

        for thr in grid:
            y_pred_c = (probs_c >= thr).astype(int)
            f1_c = f1_score(y_true_c, y_pred_c, average="binary", zero_division=0)
            if f1_c > best_f1_c:
                best_f1_c = f1_c
                best_thr_c = thr

        best_thresholds[c] = best_thr_c
        print(f"  {channels[c]}: threshold={best_thr_c:.3f}, F1={best_f1_c:.4f}")

    # Pass 2: Fine-grained refinement with scipy
    print("\nPass 2 — Scipy refinement:")
    for c in range(num_channels):
        y_true_c = labels_bin[:, c]
        probs_c = probs_flat[:, c]

        # Search around the coarse optimum
        lo = max(0.01, best_thresholds[c] - 0.05)
        hi = min(0.99, best_thresholds[c] + 0.05)

        def neg_f1(thr):
            y_pred_c = (probs_c >= thr).astype(int)
            return -f1_score(y_true_c, y_pred_c, average="binary", zero_division=0)

        result = minimize_scalar(neg_f1, bounds=(lo, hi), method="bounded")
        refined_thr = result.x
        refined_f1 = -result.fun

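        # Adopt the refined threshold only if it does not lower F1 vs. the coarse one
        coarse_f1 = f1_score(y_true_c, (probs_c >= best_thresholds[c]).astype(int),
                             average="binary", zero_division=0)
        if refined_f1 >= coarse_f1:
            best_thresholds[c] = refined_thr

        # F1 shown is the bounded-search value; when the coarse threshold is kept
        # (threshold unchanged from Pass 1), its higher Pass 1 F1 is what applies.
        print(f"  {channels[c]}: threshold={best_thresholds[c]:.4f}, F1={refined_f1:.4f}")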

    # Evaluate global metrics with optimized thresholds
    thr_mat = best_thresholds.reshape(1, num_channels)
    preds_bin = (probs_flat >= thr_mat).astype(int)

    # Per-channel F1
    channel_f1s = []
    for c in range(num_channels):
        f1_c = f1_score(labels_bin[:, c], preds_bin[:, c], average="binary", zero_division=0)
        channel_f1s.append(f1_c)
    best_f1_global = np.mean(channel_f1s)

    return best_thresholds, best_f1_global


In [44]:
print("Starting threshold optimization...")
best_thr, best_f1_dev = optimize_thresholds(trainer, dev_split)

print(f"\nOptimized thresholds: {best_thr}")
print(f"Macro-F1 on dev with optimized thresholds: {best_f1_dev:.4f}")

# Update global thresholds
thresholds = best_thr


Starting threshold optimization...

Pass 1 — Coarse grid search:
  B-GEN: threshold=0.450, F1=0.6846
  I-GEN: threshold=0.450, F1=0.6437
  B-UNFAIR: threshold=0.450, F1=0.5855
  I-UNFAIR: threshold=0.375, F1=0.5420
  B-STEREO: threshold=0.525, F1=0.6866
  I-STEREO: threshold=0.450, F1=0.7212

Pass 2 — Scipy refinement:
  B-GEN: threshold=0.4500, F1=0.6839
  I-GEN: threshold=0.4488, F1=0.6438
  B-UNFAIR: threshold=0.4500, F1=0.5836
  I-UNFAIR: threshold=0.3340, F1=0.5435
  B-STEREO: threshold=0.5185, F1=0.6866
  I-STEREO: threshold=0.4614, F1=0.7217

Optimized thresholds: [0.45       0.44881573 0.45       0.33402056 0.51845783 0.46144348]
Macro-F1 on dev with optimized thresholds: 0.6443
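Per-channel threshold tuning treats each label channel as an independent binary problem and picks the cut-off that maximises that channel's F1 on the development split. The following is a self-contained NumPy sketch of the coarse grid pass only (no scipy refinement), run on a handful of synthetic scores. `binary_f1`, `tune_channel_threshold`, and the data are hypothetical. F1 is piecewise constant in the threshold: it only changes when the threshold crosses a predicted probability. A bounded scalar search therefore has no slope to follow, so the Pass 2 refinement can only find small gains inside its ±0.05 window. In the run above, it moved F1 by at most 0.0015 (I-UNFAIR).

```python
import numpy as np


def binary_f1(y_true, y_pred):
    """Binary F1 with zero_division=0 semantics."""
    tp = np.sum((y_pred == 1) & (y_true == 1))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def tune_channel_threshold(probs_c, y_true_c, grid):
    """Return (threshold, F1); ties keep the first (lowest) grid value, default 0.5."""
    best_thr, best_f1 = 0.5, 0.0
    for thr in grid:
        f1 = binary_f1(y_true_c, (probs_c >= thr).astype(int))
        if f1 > best_f1:
            best_thr, best_f1 = float(thr), float(f1)
    return best_thr, best_f1


grid = np.arange(0.05, 0.96, 0.025)  # 37 candidate thresholds, as in Pass 1
y = np.array([0, 0, 0, 1, 1, 1, 0, 1])
p = np.array([0.11, 0.21, 0.43, 0.39, 0.56, 0.81, 0.29, 0.48])
thr, f1 = tune_channel_threshold(p, y, grid)
print(f"threshold={thr:.3f}, F1={f1:.4f}")  # threshold=0.300, F1=0.8889
```

Thresholds tuned this way are fitted to the development split. The dev-to-test drop above (0.6443 to 0.6302 macro F1) is the expected cost of that selection step.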


In [45]:
print("Evaluating on test set with optimized thresholds...")
test_metrics = trainer.evaluate(test_split)

print("\nTest set results (optimized thresholds):")
print(f"  Macro F1: {test_metrics['eval_f1_macro']:.4f}")
print(f"  Precision: {test_metrics['eval_precision_macro']:.4f}")
print(f"  Recall: {test_metrics['eval_recall_macro']:.4f}")
print(f"  Hamming Loss: {test_metrics['eval_hamming_loss']:.4f}")


Evaluating on test set with optimized thresholds...


Test set results (optimized thresholds):
  Macro F1: 0.6302
  Precision: 0.6368
  Recall: 0.6280
  Hamming Loss: 0.0438


In [46]:
def extract_spans_from_sequence(bio_preds, channel_pairs=None):
    """
    Extract entity spans from a single sequence of BIO predictions.
    Returns a list of (entity_type, start, end) tuples; `end` is exclusive.
    """
    if channel_pairs is None:
        channel_pairs = {
            0: 1,   # B-GEN -> I-GEN
            2: 3,   # B-UNFAIR -> I-UNFAIR
            4: 5,   # B-STEREO -> I-STEREO
        }

    spans = []
    for b_idx, i_idx in channel_pairs.items():
        entity_type = channels[b_idx].replace("B-", "")
        in_span = False
        span_start = -1

        for t in range(len(bio_preds)):
            if bio_preds[t, b_idx] == 1:
                if in_span:
                    spans.append((entity_type, span_start, t))
                span_start = t
                in_span = True
            elif bio_preds[t, i_idx] == 1 and in_span:
                continue
            else:
                if in_span:
                    spans.append((entity_type, span_start, t))
                    in_span = False

        if in_span:
            spans.append((entity_type, span_start, len(bio_preds)))

    return spans


def compute_entity_metrics(trainer, dataset, thresholds):
    """
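    Compute entity-level (span-level) F1/Precision/Recall.
    Exact-match scoring: a predicted span counts as correct only if its
    type, start and end all match a gold span.
    Matches the paper's evaluation methodology.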
    """
    model = trainer.model
    model.eval()

    all_pred_spans = []
    all_gold_spans = []
    example_idx = 0

    dataloader = trainer.get_eval_dataloader(dataset)
    for batch in dataloader:
        with torch.no_grad():
            labels = batch["labels"].detach().cpu().numpy()
            inputs = {k: v.to(model.device) for k, v in batch.items() if k != "labels"}
            outputs = model(**inputs)
            probs = torch.sigmoid(outputs.logits).detach().cpu().numpy()

        for i in range(labels.shape[0]):
            valid_mask = labels[i, :, 0] != -100.0
            preds_i = (probs[i][valid_mask] >= thresholds).astype(int)
            labels_i = labels[i][valid_mask].astype(int)

            pred_spans = extract_spans_from_sequence(preds_i)
            gold_spans = extract_spans_from_sequence(labels_i)

            # Add example index to make spans unique across examples
            all_pred_spans.extend([(example_idx, *s) for s in pred_spans])
            all_gold_spans.extend([(example_idx, *s) for s in gold_spans])
            example_idx += 1

    # Per-entity-type metrics
    entity_types = ["GEN", "UNFAIR", "STEREO"]
    print("\nEntity-level evaluation:")
    print("-" * 60)

    type_metrics = {}
    for etype in entity_types:
        pred_set = set(s for s in all_pred_spans if s[1] == etype)
        gold_set = set(s for s in all_gold_spans if s[1] == etype)

        tp = len(pred_set & gold_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)

        p = tp / max(tp + fp, 1)
        r = tp / max(tp + fn, 1)
        f1 = 2 * p * r / max(p + r, 1e-8)

        type_metrics[etype] = {"precision": p, "recall": r, "f1": f1,
                               "support": len(gold_set)}
        print(f"  {etype:8s}: F1={f1:.4f}  P={p:.4f}  R={r:.4f}  (support={len(gold_set)})")

    # Overall entity-level metrics
    pred_set = set(all_pred_spans)
    gold_set = set(all_gold_spans)
    tp = len(pred_set & gold_set)
    fp = len(pred_set - gold_set)
    fn = len(gold_set - pred_set)

    overall_p = tp / max(tp + fp, 1)
    overall_r = tp / max(tp + fn, 1)
    overall_f1 = 2 * overall_p * overall_r / max(overall_p + overall_r, 1e-8)

    # Macro entity F1
    macro_f1 = np.mean([m["f1"] for m in type_metrics.values()])

    print("-" * 60)
    print(f"  {'MICRO':8s}: F1={overall_f1:.4f}  P={overall_p:.4f}  R={overall_r:.4f}")
    print(f"  {'MACRO':8s}: F1={macro_f1:.4f}")

    return {
        "entity_f1_micro": overall_f1,
        "entity_f1_macro": macro_f1,
        "entity_precision": overall_p,
        "entity_recall": overall_r,
        "per_type": type_metrics,
    }


# Run entity-level evaluation on test set
print("=" * 60)
print("ENTITY-LEVEL EVALUATION (Test Set)")
print("=" * 60)
entity_metrics = compute_entity_metrics(trainer, test_split, thresholds)


ENTITY-LEVEL EVALUATION (Test Set)

Entity-level evaluation:
------------------------------------------------------------
  GEN     : F1=0.5689  P=0.5319  R=0.6115  (support=749)
  UNFAIR  : F1=0.3841  P=0.4265  R=0.3494  (support=166)
  STEREO  : F1=0.4550  P=0.4810  R=0.4316  (support=234)
------------------------------------------------------------
  MICRO   : F1=0.5238  P=0.5112  R=0.5370
  MACRO   : F1=0.4693
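Entity-level scores give credit to a predicted span only if its type, start, and end all match a gold span exactly. This is one reason they sit well below the token-level numbers: test macro F1 is 0.4693 at span level versus 0.6302 at token level. A stereotype span that runs one token too long earns no credit at all. The following is a self-contained sketch of that scoring on a toy example with a single B/I channel pair. `decode_bio` and `span_prf` are hypothetical helpers written to mirror the decoding and counting logic of `extract_spans_from_sequence` and `compute_entity_metrics`.

```python
def decode_bio(b_col, i_col, entity_type):
    """Decode one B/I channel pair into (type, start, end) spans, end exclusive.

    I tokens that do not continue an open span are ignored (strict BIO).
    """
    spans, start = [], None
    for t, (b, i) in enumerate(zip(b_col, i_col)):
        if b == 1:
            if start is not None:          # a new B closes the previous span
                spans.append((entity_type, start, t))
            start = t
        elif i == 1 and start is not None:
            continue                       # span continues
        elif start is not None:
            spans.append((entity_type, start, t))
            start = None
    if start is not None:
        spans.append((entity_type, start, len(b_col)))
    return spans


def span_prf(pred_spans, gold_spans):
    """Exact-match precision, recall, F1 over sets of spans."""
    pred, gold = set(pred_spans), set(gold_spans)
    tp = len(pred & gold)
    p = tp / max(len(pred), 1)
    r = tp / max(len(gold), 1)
    f1 = 2 * p * r / max(p + r, 1e-8)
    return p, r, f1


# Gold STEREO span covers tokens 0-2; the prediction runs one token too long
gold = decode_bio([1, 0, 0, 0, 0], [0, 1, 1, 0, 0], "STEREO")
pred = decode_bio([1, 0, 0, 0, 0], [0, 1, 1, 1, 0], "STEREO")
print(gold, pred, span_prf(pred, gold))
```

The prediction overlaps the gold span on three of its four tokens, yet it scores zero precision and recall. Partial-overlap metrics are a common complement when boundary errors dominate.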


In [47]:
# Save the fine-tuned model
output_dir = "./gus-net-bert-final"
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

# Save optimized thresholds
np.save(f"{output_dir}/optimized_thresholds.npy", thresholds)

print(f"\nModel, tokenizer, and thresholds saved to {output_dir}")



Model, tokenizer, and thresholds saved to ./gus-net-bert-final
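At inference time, the thresholds must be reloaded together with the model. Otherwise, predictions silently fall back to whatever default the calling code uses. The following is a minimal sketch that assumes the file layout written above (`optimized_thresholds.npy` inside `output_dir`). `load_thresholds` and its 0.5 fallback are hypothetical choices. Assuming the saved weights use a standard token-classification head, the model itself would be reloaded with `AutoModelForTokenClassification.from_pretrained(output_dir)`.

```python
import os
import tempfile

import numpy as np


def load_thresholds(output_dir, num_channels=6, default=0.5):
    """Load per-channel thresholds saved with np.save, or fall back to a constant."""
    path = os.path.join(output_dir, "optimized_thresholds.npy")
    if not os.path.exists(path):
        return np.full(num_channels, default, dtype=np.float32)
    thr = np.load(path)
    if thr.shape != (num_channels,):
        raise ValueError(f"expected {num_channels} thresholds, got shape {thr.shape}")
    return thr


# Round trip in a temporary directory, using the rounded thresholds printed above
with tempfile.TemporaryDirectory() as tmp:
    np.save(os.path.join(tmp, "optimized_thresholds.npy"),
            np.array([0.45, 0.4488, 0.45, 0.334, 0.5185, 0.4614], dtype=np.float32))
    reloaded = load_thresholds(tmp)
    fallback = load_thresholds(os.path.join(tmp, "missing"))
print(reloaded, fallback)
```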


In [48]:
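def predict_bias(text, model, tokenizer, thresholds, device=None):
    """
    Predict bias labels for a given text using the trained model.
    Returns one entry per sub-word token that has at least one active channel.
    """
    # Fall back to CPU when no GPU is available
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(device)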
    
    # Tokenize
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding=True,
        return_offsets_mapping=True,
    )
    
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits[0]  # [seq_len, num_channels]
        probs = torch.sigmoid(logits).cpu().numpy()
    
    # Apply thresholds
    predictions = (probs >= thresholds).astype(int)
    
    # Collect sub-word tokens with at least one active channel (no BIO span decoding)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    results = []
    
    for i, (token, pred, offset) in enumerate(zip(tokens, predictions, offset_mapping)):
        if token in ["[CLS]", "[SEP]", "[PAD]"]:
            continue
        active_channels = [channels[j] for j in range(num_channels) if pred[j] == 1]
        if active_channels:
            start, end = offset
            results.append({
                "token": token,
                "position": (start, end),
                "labels": active_channels,
                # start == 0 is valid (first word), so test for a non-empty offset
                "text": text[start:end] if end > start else token,
            })
    
    return results


# Test example
test_text = "Women are naturally better at multitasking than men."
predictions = predict_bias(test_text, model, tokenizer, thresholds)

print(f"\nInput text: {test_text}\n")
print("Detected biases:")
for pred in predictions:
    print(f"  '{pred['text']}' [{pred['token']}] -> {pred['labels']}")



Input text: Women are naturally better at multitasking than men.

Detected biases:
  'Women' [women] -> ['B-GEN', 'B-STEREO']
  'are' [are] -> ['I-STEREO']
  'naturally' [naturally] -> ['I-STEREO']
  'better' [better] -> ['I-STEREO']
  'at' [at] -> ['I-STEREO']
  'multi' [multi] -> ['I-STEREO']
  'tas' [##tas] -> ['I-STEREO']
  'king' [##king] -> ['I-STEREO']
  'than' [than] -> ['I-STEREO']
  'men' [men] -> ['B-GEN', 'I-STEREO']
  '.' [.] -> ['I-GEN', 'I-STEREO']
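`predict_bias` reports raw per-sub-word channel activations, so the single word "multitasking" appears as three rows (`multi`, `##tas`, `##king`). A common post-processing step, sketched here as an assumption rather than part of the notebook, merges WordPiece continuation tokens (prefixed `##`) into whole words and takes the union of their active labels. `merge_wordpieces` is hypothetical. It takes rows shaped like the `predict_bias` output, and the character offsets below are those of the example sentence.

```python
def merge_wordpieces(rows, text):
    """Merge '##' continuation rows into the preceding, adjacent word.

    rows: dicts with 'token', 'position' (char start, end) and 'labels'.
    Only rows with active labels are present, so a word whose first piece
    had no labels starts at its first labelled piece.
    """
    words = []
    for row in rows:
        start, end = row["position"]
        if row["token"].startswith("##") and words and words[-1]["end"] == start:
            words[-1]["end"] = end
            for lab in row["labels"]:  # union of labels, order preserved
                if lab not in words[-1]["labels"]:
                    words[-1]["labels"].append(lab)
        else:
            words.append({"start": start, "end": end, "labels": list(row["labels"])})
    for w in words:
        w["text"] = text[w["start"]:w["end"]]
    return words


text = "Women are naturally better at multitasking than men."
rows = [
    {"token": "at", "position": (27, 29), "labels": ["I-STEREO"]},
    {"token": "multi", "position": (30, 35), "labels": ["I-STEREO"]},
    {"token": "##tas", "position": (35, 38), "labels": ["I-STEREO"]},
    {"token": "##king", "position": (38, 42), "labels": ["I-STEREO"]},
    {"token": "men", "position": (48, 51), "labels": ["B-GEN", "I-STEREO"]},
]
words = merge_wordpieces(rows, text)
for w in words:
    print(f"  '{w['text']}' -> {w['labels']}")
```

Taking the union is a lenient choice. Using only the first sub-word's labels, which matches how training labels are often assigned when continuation pieces are masked with -100, is an equally plausible alternative.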


In [49]:
print("=" * 70)
print("TRAINING SUMMARY")
print("=" * 70)
print(f"Model: BERT-base-uncased")
print(f"Dataset: GUS-Net (ethical-spectacle/gus-dataset-v1)")
print(f"Task: Multi-label token classification for bias detection")
print(f"\nBias types: GEN (Generalizations), UNFAIR (Unfairness), STEREO (Stereotypes)")
print(f"Number of channels: {num_channels}")
print(f"\nDataset splits:")
print(f"  Train: {len(train_split)} examples")
print(f"  Dev: {len(dev_split)} examples")
print(f"  Test: {len(test_split)} examples")
print(f"\nTraining techniques:")
print(f"  - Layer-wise Learning Rate Decay (decay=0.85)")
print(f"  - Cosine annealing scheduler with warmup")
print(f"  - Focal loss (gamma=2.0, label_smoothing=0.05)")
print(f"  - Classifier dropout: 0.3, Hidden dropout: 0.15")
print(f"  - Gradient accumulation (effective batch=32)")
print(f"  - Early stopping (patience=5)")
print(f"  - SWA (last 5 checkpoints)")
print(f"  - Two-pass threshold optimization (grid + scipy)")
print(f"\nToken-level test performance:")
print(f"  Macro F1: {test_metrics['eval_f1_macro']:.4f}")
print(f"  Precision: {test_metrics['eval_precision_macro']:.4f}")
print(f"  Recall: {test_metrics['eval_recall_macro']:.4f}")
print(f"  Hamming Loss: {test_metrics['eval_hamming_loss']:.4f}")
if "entity_metrics" in globals():
    print(f"\nEntity-level test performance:")
    print(f"  Macro F1: {entity_metrics['entity_f1_macro']:.4f}")
    print(f"  Micro F1: {entity_metrics['entity_f1_micro']:.4f}")
    print(f"  Precision: {entity_metrics['entity_precision']:.4f}")
    print(f"  Recall: {entity_metrics['entity_recall']:.4f}")
print("=" * 70)


TRAINING SUMMARY
Model: BERT-base-uncased
Dataset: GUS-Net (ethical-spectacle/gus-dataset-v1)
Task: Multi-label token classification for bias detection

Bias types: GEN (Generalizations), UNFAIR (Unfairness), STEREO (Stereotypes)
Number of channels: 6

Dataset splits:
  Train: 2562 examples
  Dev: 549 examples
  Test: 550 examples

Training techniques:
  - Layer-wise Learning Rate Decay (decay=0.85)
  - Cosine annealing scheduler with warmup
  - Focal loss (gamma=2.0, label_smoothing=0.05)
  - Classifier dropout: 0.3, Hidden dropout: 0.15
  - Gradient accumulation (effective batch=32)
  - Early stopping (patience=5)
  - SWA (last 5 checkpoints)
  - Two-pass threshold optimization (grid + scipy)

Token-level test performance:
  Macro F1: 0.6302
  Precision: 0.6368
  Recall: 0.6280
  Hamming Loss: 0.0438

Entity-level test performance:
  Macro F1: 0.4693
  Micro F1: 0.5238
  Precision: 0.5112
  Recall: 0.5370
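The summary lists focal loss with gamma=2.0 and label_smoothing=0.05 over the six sigmoid channels. The loss definition itself is not part of this section, so the following is a sketch of one common multi-label formulation, not necessarily the notebook's. It assumes that smoothing maps each target y to y(1 - ε) + ε/2, that positions labelled -100 are excluded, and that there is no per-class alpha weight. `masked_focal_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def masked_focal_loss(logits, labels, gamma=2.0, label_smoothing=0.05, ignore_index=-100.0):
    """Sigmoid focal loss averaged over non-ignored (token, channel) cells.

    logits, labels: float tensors of shape [batch, seq_len, channels];
    positions to ignore carry ignore_index in labels.
    """
    mask = labels != ignore_index
    if not mask.any():
        return logits.sum() * 0.0  # keeps the graph, avoids NaN from an empty mean

    # Replace ignored entries with 0 (masked out below), then smooth the targets
    targets = labels.clamp(min=0.0)
    targets = targets * (1.0 - label_smoothing) + 0.5 * label_smoothing

    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    probs = torch.sigmoid(logits)
    # Probability assigned to the (soft) target; a high p_t marks an easy cell
    p_t = probs * targets + (1.0 - probs) * (1.0 - targets)
    loss = (1.0 - p_t).pow(gamma) * bce  # down-weight easy cells by (1 - p_t)^gamma
    return loss[mask].mean()


torch.manual_seed(0)
logits = torch.randn(2, 4, 6)                    # 2 sentences, 4 tokens, 6 channels
labels = torch.randint(0, 2, (2, 4, 6)).float()
labels[1, 3, :] = -100.0                         # e.g. a padded position
loss = masked_focal_loss(logits, labels)
print(f"focal loss: {loss.item():.4f}")
```

With gamma=0 and no smoothing, this reduces to plain masked binary cross-entropy. Raising gamma shifts the loss toward hard, typically positive, cells, which is the motivation for using focal loss on sparse bias labels.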
