<sub>Developed by SeongKu Kang, August 2025 — Do not distribute</sub>

# Task 2: Fine-tuning with Cross-Encoders

In this notebook, we shift our focus from representation-based models (e.g., dual encoders that independently encode queries and documents) to **cross-encoders**, where a query and a document are processed *together* through the same BERT model.  

The key idea of a cross-encoder is that it allows the model to **jointly attend** to both the query and the document tokens. Instead of computing embeddings separately and then measuring similarity, the cross-encoder takes the concatenated sequence:

`[CLS] query tokens [SEP] document tokens [SEP]`

and produces contextualized representations for the whole pair. The final hidden state of the `[CLS]` token is then fed to a classifier, for example to predict the relevance of the document to the query.  
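
To make the input layout concrete, here is a minimal sketch of sentence-pair packing. It is a toy illustration, not the real tokenizer: `pack_pair` is a hypothetical helper, whitespace splitting stands in for WordPiece, and truncation is left out. The segment IDs correspond to what a BERT tokenizer returns as `token_type_ids`.

```python
def pack_pair(query_tokens, doc_tokens, max_length=16):
    """Toy BERT-style packing of a (query, document) pair. Hypothetical helper."""
    tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + doc_tokens + ["[SEP]"]
    if len(tokens) > max_length:
        raise ValueError("pair is longer than max_length (truncation not shown)")

    # Segment 0 covers [CLS] + query + first [SEP]; segment 1 covers document + final [SEP]
    token_type_ids = [0] * (len(query_tokens) + 2) + [1] * (len(doc_tokens) + 1)
    attention_mask = [1] * len(tokens)

    # Pad on the right; padded positions get attention_mask = 0
    pad = max_length - len(tokens)
    tokens = tokens + ["[PAD]"] * pad
    token_type_ids = token_type_ids + [0] * pad
    attention_mask = attention_mask + [0] * pad
    return tokens, token_type_ids, attention_mask


tokens, token_type_ids, attention_mask = pack_pair(
    "usb c cable".split(),                      # hypothetical query
    "braided usb c charging cable 2m".split(),  # hypothetical product title
)
for tok, seg, mask in zip(tokens, token_type_ids, attention_mask):
    print(f"{tok:10s} segment={seg} mask={mask}")
```

Because both segments sit in one sequence, self-attention in every layer can relate a query token directly to a document token.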

This approach is particularly suited for **ESCI (Exact, Substitute, Complement, Irrelevant)** classification, where fine-grained semantic distinctions matter. Cross-encoders capture subtle interactions between query and document words (e.g., synonyms, negations, product attributes) that embedding-only approaches often miss.  

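Although cross-encoders are more computationally expensive at inference time (every query–document pair needs its own forward pass, so document representations cannot be precomputed), they typically deliver **higher accuracy** for tasks requiring nuanced semantic matching.  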
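
The cost difference can be sketched by counting encoder forward passes at query time. The function name and the candidate count below are assumptions for illustration only.

```python
def online_encoder_passes(num_candidates):
    """Encoder forward passes needed at query time for ONE query (illustrative)."""
    # Dual encoder: candidate embeddings are computed offline, so only the query is encoded.
    dual_encoder = 1
    # Cross-encoder: every (query, candidate) pair is encoded jointly.
    cross_encoder = num_candidates
    return dual_encoder, cross_encoder


dual, cross = online_encoder_passes(num_candidates=100)  # hypothetical candidate count
print(f"dual encoder: {dual} pass per query, cross-encoder: {cross} passes per query")
```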

In this notebook, we will:
1. Format query–document pairs into BERT inputs.  
2. Fine-tune a pre-trained BERT model with supervised ESCI labels.  
3. Evaluate classification performance.  

⚠️ **Note**   
This notebook may take a long time to run, since it repeatedly encodes texts with BERT.  
It is provided mainly for **reference**, and you are encouraged to review the workflow rather than execute every cell.

In [1]:
import json
import random
from tqdm import tqdm
from pathlib import Path
from utils import *  # provides load_corpus, load_queries, print_eval_result_esci
import copy

import pandas as pd
from collections import Counter

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Path config (Cross-encoder setting: on-the-fly encoding, no precomputed embeddings used)
ROOT = Path("dataset")

pid2text = load_corpus(ROOT / "corpus.jsonl")
qid2text = load_queries(ROOT / "queries_1k.jsonl")
test_qid2text = load_queries(ROOT / "queries_test.jsonl")

In [3]:
# === ESCI label mapping ===
esci_label2id = {"E": 0, "S": 1, "C": 2, "I": 3}
esci_id2label = {0: "E", 1: "S", 2: "C", 3: "I"}

def load_qrels(path, esci_label2id):
    """
    Load qrels file (query-document relevance labels).
    - Keeps only rows with ESCI labels (E/S/C/I).
    - Maps string labels to numeric IDs.
    """
    df = pd.read_csv(path, sep="\t", header=0)        # read TSV with header
    df = df[df["label"].isin(esci_label2id)].copy()   # keep valid ESCI rows; copy avoids SettingWithCopyWarning
    df["label_idx"] = df["label"].map(esci_label2id)  # map to numeric IDs
    return df

# === Load qrels (train & test) ===
QRELS_TRAIN_PATH = ROOT / "qrels_1k.tsv"
QRELS_TEST_PATH = ROOT / "qrels_test.tsv"

train_qrels_df = load_qrels(QRELS_TRAIN_PATH, esci_label2id)
test_qrels_df = load_qrels(QRELS_TEST_PATH, esci_label2id)

train_qrels_df.head(3), test_qrels_df.head(3)

(   query-id   corpus-id label  label_idx
 0     17397  B07R9HPJPW     E          0
 1     17397  B07Q34YJ4T     E          0
 2     17397  B07NH217CV     E          0,
    query-id   corpus-id label  label_idx
 0     12712  B081FY5ZYQ     S          1
 1     12712  B07X45P8C3     S          1
 2     12712  B07V9PPK61     S          1)
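
Before training, it can help to check how the four ESCI classes are distributed, since a skewed distribution affects how accuracy should be read. A minimal sketch with `Counter` follows; the label list is hypothetical, and in the notebook the same numbers could be obtained with `train_qrels_df["label"].value_counts(normalize=True)`.

```python
from collections import Counter

# Hypothetical labels standing in for train_qrels_df["label"]
labels = ["E", "E", "E", "S", "E", "I", "C", "E", "S", "E"]

counts = Counter(labels)
total = sum(counts.values())
shares = {label: counts[label] / total for label in "ESCI"}

for label in "ESCI":
    print(f"{label}: {counts[label]:2d} ({shares[label]:.0%})")
```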

In [4]:
def build_triplets(qrels_df):
    """
    Build (query_id, product_id, label) triplets from qrels dataframe.
    Args:
        qrels_df: DataFrame where each row = (qid, pid, label)
    Returns:
        list of tuples: [(qid, pid, label), ...]
    """
    triplets = []
    for row in qrels_df.itertuples(index=False):
        qid = row[0]      # query ID
        pid = row[1]      # product/document ID
        label = row[-1]   # numeric label index (label_idx, the last column)
        triplets.append((qid, pid, label))
    return triplets

# Build train/test triplets
train_triplets = build_triplets(train_qrels_df)
test_triplets = build_triplets(test_qrels_df)

print(f"# Train triplets: {len(train_triplets)}")
print(f"# Test triplets:  {len(test_triplets)}")

# Train triplets: 20303
# Test triplets:  10149


In [5]:
class CrossEncoderDataset(Dataset):
    def __init__(self, triplets, qid2text, pid2text, tokenizer, max_length=512):
        """
        Cross-encoder dataset: each instance is (query, document, label)
        Args:
            triplets: list of (qid, pid, label)
            qid2text: dict mapping query IDs to query text
            pid2text: dict mapping product IDs to document text
            tokenizer: HuggingFace tokenizer for BERT/Transformer
            max_length: maximum sequence length for tokenization
        """
        self.pairs = []
        for qid, pid, label in triplets:
            if qid in qid2text and pid in pid2text:  # safeguard
                self.pairs.append((qid2text[qid], pid2text[pid], label))
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        query, doc, label = self.pairs[idx]

        # Tokenize query–document pair for cross-encoder input
        encoded = self.tokenizer(
            query,
            doc,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Note: token_type_ids are not returned here, so BERT falls back to
        # all-zero segment IDs for both the query and the document.
        return {
            "input_ids": encoded["input_ids"].squeeze(0),            # token IDs (query + doc packed together)
            "attention_mask": encoded["attention_mask"].squeeze(0),  # 1 = real token, 0 = padding
            "y": torch.tensor(label, dtype=torch.long)               # ESCI class index (0-3)
        }
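
The dataset above pads every pair to `max_length=512`, which is one reason the notebook is slow. A common alternative is to pad each batch only to its longest sequence. The sketch below shows the idea on plain Python lists; `pad_to_longest` and the token IDs are hypothetical, and wiring it into the notebook would mean tokenizing without padding and passing a `collate_fn` to `DataLoader` (the `transformers` library ships `DataCollatorWithPadding` for this purpose).

```python
PAD_ID = 0  # ID of [PAD] in bert-base-uncased


def pad_to_longest(batch_ids, pad_id=PAD_ID):
    """Pad a batch of token-ID lists to the longest sequence in the batch."""
    longest = max(len(ids) for ids in batch_ids)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch_ids]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch_ids]
    return input_ids, attention_mask


# Hypothetical, already-tokenized pairs (101 = [CLS], 102 = [SEP]; other IDs are made up)
batch_ids = [
    [101, 2000, 102, 2001, 2002, 102],
    [101, 2003, 2004, 102, 2005, 2006, 2007, 2008, 102],
    [101, 2009, 102, 2010, 102],
]
input_ids, attention_mask = pad_to_longest(batch_ids)

padded_positions = sum(len(row) for row in input_ids)  # positions BERT processes with dynamic padding
fixed_positions = len(batch_ids) * 512                 # positions with padding="max_length"
print(f"dynamic: {padded_positions} positions, fixed: {fixed_positions} positions")
```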

In [6]:
from collections import defaultdict
from sklearn.metrics import accuracy_score, f1_score

def evaluate(model, dataloader, device="cpu", num_classes=None):
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            y = batch["y"].to(device)
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(y.cpu().tolist())

    acc = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average="macro", zero_division=0)

    # === Compute per-class accuracy (share of each true class predicted correctly, i.e. recall) ===
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    
    for y_true, y_pred in zip(all_labels, all_preds):
        class_total[y_true] += 1
        if y_true == y_pred:
            class_correct[y_true] += 1

    per_class_acc = {}
    class_range = range(num_classes) if num_classes is not None else sorted(class_total.keys())
    for cls in class_range:
        total = class_total[cls]
        correct = class_correct[cls]
        per_class_acc[cls] = correct / total if total > 0 else 0.0

    return {
        "accuracy": acc,
        "f1_macro": f1_macro,
        "per_class_accuracy": per_class_acc
    }
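
`evaluate` reports both accuracy and macro-F1, and the two can diverge sharply when one class dominates. The pure-Python sketch below uses hypothetical labels and a degenerate classifier that always predicts class 0 (`E`). The `macro_f1` helper is an illustrative reimplementation that averages over all `num_classes`; it is not the scikit-learn code.

```python
def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1; a class with no TP/FP/FN scores 0."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes


# Hypothetical, imbalanced labels: 6 x E (0), 2 x S (1), 1 x C (2), 1 x I (3)
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2, 3]
y_pred = [0] * len(y_true)  # always predict E

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
f1 = macro_f1(y_true, y_pred, num_classes=4)
print(f"accuracy = {accuracy:.4f}, macro-F1 = {f1:.4f}")
```

Here accuracy is 0.6 while macro-F1 is only 0.1875, because three of the four classes contribute an F1 of zero.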

In [7]:
from transformers import AutoTokenizer, AutoModel

class CrossEncoderESCI(nn.Module):
    def __init__(self, model_name="bert-base-uncased", num_labels=4):
        super().__init__()
        # Pretrained BERT (or any Transformer encoder)
        self.bert = AutoModel.from_pretrained(model_name)
        # Linear classifier: [CLS] hidden vector → num_labels
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        # BERT encoding
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        # CLS token embedding (batch_size × hidden_size)
        cls_output = outputs.last_hidden_state[:, 0, :]
        # Classification logits (batch_size × num_labels)
        logits = self.classifier(cls_output)
        return logits
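
The model returns raw logits, which is what `F.cross_entropy` expects because it applies log-softmax internally. To read a prediction as class probabilities, a softmax can be applied afterwards. A minimal pure-Python sketch with made-up logits for one query-document pair (in the notebook, `torch.softmax(logits, dim=1)` would do the same per batch):

```python
import math

esci_id2label = {0: "E", 1: "S", 2: "C", 3: "I"}


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, -1.0, 0.5]  # hypothetical output for one pair
probs = softmax(logits)
pred_idx = max(range(len(probs)), key=probs.__getitem__)
pred_label = esci_id2label[pred_idx]

for idx, p in enumerate(probs):
    print(f"{esci_id2label[idx]}: {p:.3f}")
print("predicted:", pred_label)
```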

In [8]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
test_dataset = CrossEncoderDataset(test_triplets, test_qid2text, pid2text, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=64)

results_dict = {'valid':{}, 'test': {}}

In [9]:
# === Split into train/validation sets ===
random.shuffle(train_triplets)  
val_size = int(len(train_triplets) * 0.2)  
val_triplets = train_triplets[:val_size]   # first 20% for validation
train_triplets = train_triplets[val_size:] # remaining 80% for training

# === Build datasets/loaders ===
train_dataset = CrossEncoderDataset(train_triplets, qid2text, pid2text, tokenizer)
val_dataset = CrossEncoderDataset(val_triplets, qid2text, pid2text, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)   # shuffle for training
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)      # no shuffle for validation
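
The split above shuffles individual triplets, so the same query can end up in both the training and validation sets. If validation is meant to mimic queries unseen during training (an assumption about the goal, not something this notebook states), one option is to split by query ID instead. A minimal sketch, where `split_by_query`, the seed, and the toy triplets are all hypothetical:

```python
import random


def split_by_query(triplets, val_fraction=0.2, seed=0):
    """Split (qid, pid, label) triplets so that no query appears on both sides."""
    qids = sorted({qid for qid, _, _ in triplets})
    rng = random.Random(seed)          # local RNG: reproducible, no global state
    rng.shuffle(qids)
    n_val = int(len(qids) * val_fraction)
    val_qids = set(qids[:n_val])
    train = [t for t in triplets if t[0] not in val_qids]
    val = [t for t in triplets if t[0] in val_qids]
    return train, val


# Toy data: 10 queries with 3 candidate products each
toy_triplets = [(q, f"P{q}_{d}", d % 4) for q in range(10) for d in range(3)]
toy_train, toy_val = split_by_query(toy_triplets)

train_qids = {q for q, _, _ in toy_train}
val_qids = {q for q, _, _ in toy_val}
print(len(toy_train), len(toy_val), train_qids & val_qids)
```

Unlike the in-place `random.shuffle` plus reassignment, this returns new lists, so re-running it does not shrink the training set.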

In [13]:
# === Initialize model and optimizer ===
model = CrossEncoderESCI(model_name="bert-base-uncased", num_labels=4).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)  # AdamW is standard for transformer fine-tuning

In [None]:
best_val_acc = -1
best_model_state = None
patience = 5
patience_counter = 0

val_acc_list = []
test_acc_list = []

EPOCHS = 500

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = F.cross_entropy(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"[Epoch {epoch}] Train Loss: {avg_loss:.4f}")

    # === Validation ===
    val_result = evaluate(model, val_loader, device=device)
    val_acc = val_result["f1_macro"]  # despite the name, model selection uses macro-F1
    val_acc_list.append(val_acc)

    is_improved = val_acc > best_val_acc
    print_eval_result_esci(val_result, stage="val", is_improved=is_improved)

    # === Test ===
    test_result = evaluate(model, test_loader, device=device)
    test_acc = test_result["f1_macro"]  # macro-F1 as well, for consistency with validation
    test_acc_list.append(test_acc)
    print_eval_result_esci(test_result, stage="test")

    # === Update best model ===
    if is_improved:
        best_val_acc = val_acc
        best_model_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1

    # === Early stopping ===
    if patience_counter >= patience:
        print(f"[Early Stopping] No improvement for {patience} consecutive epochs.")
        break

Epoch 1: 100%|██████████| 508/508 [05:35<00:00,  1.52it/s]


[Epoch 1] Train Loss: 1.0564
[VAL ] Acc: 0.5953 | F1-macro: 0.4423 *
        E: 0.9110 | S: 0.3331 | C: 0.0510 | I: 0.4214
[TEST] Acc: 0.5061 | F1-macro: 0.2926
        E: 0.9074 | S: 0.1606 | C: 0.0000 | I: 0.1916


Epoch 2: 100%|██████████| 508/508 [05:35<00:00,  1.52it/s]


[Epoch 2] Train Loss: 0.7908
[VAL ] Acc: 0.6771 | F1-macro: 0.6270 *
        E: 0.7871 | S: 0.6445 | C: 0.4847 | I: 0.5036
[TEST] Acc: 0.5269 | F1-macro: 0.4099
        E: 0.6924 | S: 0.4829 | C: 0.1499 | I: 0.2535


Epoch 3:  22%|██▏       | 111/508 [01:12<04:22,  1.51it/s]
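
The numbers in the progress bars can be cross-checked against the configuration. The arithmetic below uses only values printed in this notebook and assumes that every triplet passes the ID safeguard in `CrossEncoderDataset`. Validation and test evaluation add further time per epoch that the training bar does not show.

```python
import math

num_triplets = 20303      # "# Train triplets" printed earlier
batch_size = 32           # train_loader batch size
iters_per_second = 1.52   # rate reported by tqdm

val_size = int(num_triplets * 0.2)        # 20% held out for validation
train_size = num_triplets - val_size
batches_per_epoch = math.ceil(train_size / batch_size)
seconds_per_epoch = batches_per_epoch / iters_per_second
minutes, seconds = divmod(round(seconds_per_epoch), 60)

print(f"train pairs: {train_size}, batches/epoch: {batches_per_epoch}")
print(f"approx. training time per epoch: {minutes} min {seconds} s")
```

The 508 batches match the progress bar, and the estimate of about five and a half minutes agrees with the logged `05:35` up to rounding of the reported rate.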

In [None]:
model.load_state_dict(best_model_state)
final_test_result = evaluate(model, test_loader, device=device)
print_eval_result_esci(final_test_result, stage="final_test")
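
The checkpoint restored above is the one selected by the early-stopping rule in the training loop. That rule can be isolated as a small function and replayed on a list of validation scores. In the sketch below, `early_stopping_epoch` is a hypothetical helper; the first two scores are the logged validation macro-F1 values and the remaining ones are made up.

```python
def early_stopping_epoch(scores, patience):
    """Return (best_epoch, stop_epoch) for patience-based early stopping.

    Only a strict improvement resets the counter, as in the training loop.
    """
    best_score = -1.0
    best_epoch = None
    waited = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best_score:
            best_score, best_epoch, waited = score, epoch, 0
        else:
            waited += 1
        if waited >= patience:
            return best_epoch, epoch
    return best_epoch, len(scores)  # ran out of epochs before patience was exhausted


# Epochs 1-2: logged validation macro-F1; epochs 3-7: hypothetical values
val_scores = [0.4423, 0.6270, 0.61, 0.62, 0.60, 0.62, 0.61]
best_epoch, stop_epoch = early_stopping_epoch(val_scores, patience=5)
print(f"best epoch: {best_epoch}, stopped after epoch: {stop_epoch}")
```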