# Fine‑tuning DeBERTa‑v3‑base Cross‑Encoder for Hallucination Detection
This notebook shows how to fine‑tune a **cross‑encoder** built on `microsoft/deberta-v3-base` to classify responses as grounded (`PASS`) or hallucinated (`FAIL`). The training data combines the **HaluEval** dataset with a synthetic, domain-specific challenger set.

**Main steps**
1. Install / import libraries  
2. Load & preprocess the datasets (`query`, `context`, `response`, `label`)  
3. Tokenise (query + context, response) pairs with the DeBERTa tokenizer  
4. Fine‑tune using the `Trainer` API (PyTorch backend)  
5. Evaluate, save the model and run a quick inference check

*Created automatically on 2025‑07‑01 02:22 UTC*

In [None]:
!pip install -q transformers datasets scikit-learn accelerate evaluate sentencepiece

In [1]:
import pandas as pd
import torch, random, numpy as np
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    set_seed
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

SEED = 42
set_seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

## 1  Load HaluEval and the domain-specific challenger dataset

In [2]:
# === Configuration ===
#input_path = "../data/halueval_groundedness.csv"
#output_path = "../data/halueval_biencoder_results.csv"

#input_path_challenger = "../data/synthetic_groundedness_challenger_set_long.csv"
#output_path_challenger = "../data/synthetic_groundedness_challenger_set_biencoder_results.csv"


# === Configuration for Google Colab ===
input_path = "/content/halueval_groundedness.csv"
output_path = "/content/halueval_biencoder_results.csv"

input_path_challenger = "/content/synthetic_groundedness_challenger_set_long.csv"
output_path_challenger = "/content/synthetic_groundedness_challenger_set_biencoder_results.csv"

from google.colab import files
# Upload files
uploaded = files.upload()

Saving halueval_groundedness.csv to halueval_groundedness.csv
Saving synthetic_groundedness_challenger_set.csv to synthetic_groundedness_challenger_set.csv
Saving synthetic_groundedness_challenger_set_long.csv to synthetic_groundedness_challenger_set_long.csv


In [7]:
# Load datasets
gen_df = pd.read_csv(input_path)
dom_df = pd.read_csv(input_path_challenger)

print("HaluEval:", gen_df.shape)
print("Challenger:", dom_df.shape)
display(gen_df.head())
display(dom_df.head())

HaluEval: (1000, 4)
Challenger: (770, 4)


|   | query | context | response | label |
|---|---|---|---|---|
| 0 | The manager in which Mark Lazarus clashed with... | He transferred to Wolverhampton Wanderers for... | 1948 and 1964 | PASS |
| 1 | The manager in which Mark Lazarus clashed with... | He transferred to Wolverhampton Wanderers for... | Mark Lazarus clashed with the manager of Wolve... | FAIL |
| 2 | No. 11 Squadron RAAF was based at what base, 2... | No. 11 Squadron is a Royal Australian Air Forc... | RAAF Base Edinburgh | PASS |
| 3 | No. 11 Squadron RAAF was based at what base, 2... | No. 11 Squadron is a Royal Australian Air Forc... | RAAF Base Edinburgh, located in Australia. | FAIL |
| 4 | Which movie starring Kim Roi-ha is based on Ko... | Kim Roi-ha (born November 15, 1965) is a South... | Memories of Murder | PASS |


|   | query | context | response | label |
|---|---|---|---|---|
| 0 | I made a big purchase that I regretted and can... | ### Refund Not Showing Up? Here's What You Nee... | To verify your refund, please follow these ste... | PASS |
| 1 | I am still waiting to find out the status of m... | **Refund Status Inquiry**\r\n\r\nWe understand... | We understand that waiting for a refund can be... | PASS |
| 2 | As advised by you, i requested seller to refun... | ### Refund Status Inquiry\r\n\r\nThank you for... | Thank you for reaching out to us regarding you... | PASS |
| 3 | Hurry and refund me, I am waiting | ### Refund Not Showing Up\r\n\r\nWe understand... | We understand that waiting for a refund can be... | PASS |
| 4 | I was supposed to get a refund but I do not se... | **FAQ: My Refund is Not Showing Up**\r\n\r\nIf... | If you are expecting a refund but do not see i... | PASS |


In [8]:
# Map labels from strings into numeric
label_map = {'PASS': 1, 'FAIL': 0}
gen_df['label'] = gen_df['label'].map(label_map)
dom_df['label'] = dom_df['label'].map(label_map)

# Combine query and context with markers
gen_df['query_context'] = "Query: " + gen_df['query'].astype(str) + " Context: " + gen_df['context'].astype(str)
dom_df['query_context'] = "Query: " + dom_df['query'].astype(str) + " Context: " + dom_df['context'].astype(str)

# Check label mapping
assert gen_df['label'].isna().sum() == 0, 'Unknown labels in gen_df!'
assert dom_df['label'].isna().sum() == 0, 'Unknown labels in dom_df!'

print("General label counts:\n", gen_df['label'].value_counts())
print("Challenger label counts:\n", dom_df['label'].value_counts())


General label counts:
 label
1    500
0    500
Name: count, dtype: int64
Challenger label counts:
 label
1    385
0    385
Name: count, dtype: int64
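The two preprocessing rules above (the strict `PASS`/`FAIL` to 1/0 mapping and the `Query: ... Context: ...` prefix) also have to be reproduced exactly at inference time. A minimal stdlib sketch that keeps both rules in one place is shown below; `make_example` is a hypothetical helper, not part of the notebook or any library, and raising on an unknown label is an alternative to the map-then-assert-no-NaN check used in the cell above.

```python
# Hypothetical helper: a single place for the training/inference input format.
LABEL_MAP = {'PASS': 1, 'FAIL': 0}  # same mapping as label_map in the notebook


def make_example(query, context, response, label=None):
    """Return (first_segment, second_segment, label_id) for the cross-encoder.

    label may be None at inference time; unknown labels raise instead of becoming NaN.
    """
    first = f"Query: {query} Context: {context}"
    if label is None:
        return first, response, None
    if label not in LABEL_MAP:
        raise ValueError(f"unknown label {label!r}; expected one of {sorted(LABEL_MAP)}")
    return first, response, LABEL_MAP[label]


seg_a, seg_b, y = make_example("Who wrote Hamlet?", "Hamlet is a play by Shakespeare.", "Shakespeare", "PASS")
print(seg_a)  # Query: Who wrote Hamlet? Context: Hamlet is a play by Shakespeare.
print(y)      # 1
```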


In [9]:
VAL_FRAC = 0.15
TEST_FRAC = 0.15
SEED = 42

def group_pairs(df):
    pair_key = ['query', 'context']
    groups = df.groupby(pair_key, as_index=False)
    return [g[1] for g in groups]

def three_way_pair_split(pairs, val_frac=0.15, test_frac=0.15, seed=42):
    pairs = pairs.copy()
    random.seed(seed)
    random.shuffle(pairs)
    n = len(pairs)
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    test_pairs = pairs[:n_test]
    val_pairs = pairs[n_test:n_test+n_val]
    train_pairs = pairs[n_test+n_val:]
    train = pd.concat(train_pairs, ignore_index=True)
    val = pd.concat(val_pairs, ignore_index=True)
    test = pd.concat(test_pairs, ignore_index=True)
    return train, val, test

# --- Group into pairs ---
gen_pairs = group_pairs(gen_df)
domain_pairs = group_pairs(dom_df)

# --- Group-level 3-way split: every response for a (query, context) pair lands in the same split ---
gen_train, gen_val, gen_test = three_way_pair_split(gen_pairs, val_frac=VAL_FRAC, test_frac=TEST_FRAC, seed=SEED)
domain_train, domain_val, domain_test = three_way_pair_split(domain_pairs, val_frac=VAL_FRAC, test_frac=TEST_FRAC, seed=SEED)

# --- Combine splits ---
train = pd.concat([gen_train, domain_train], ignore_index=True)
val = pd.concat([gen_val, domain_val], ignore_index=True)
test = pd.concat([gen_test, domain_test], ignore_index=True)

# --- Save to disk ---
train.to_csv("combined_train.csv", index=False)
val.to_csv("combined_val.csv", index=False)
test.to_csv("combined_test.csv", index=False)

print(f"Train: {len(train)} rows")
print(f"Val: {len(val)} rows")
print(f"Test: {len(test)} rows")


Train: 1242 rows
Val: 264 rows
Test: 264 rows
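Because the split operates on (query, context) groups rather than rows, the row counts above follow from the group counts. Assuming every group holds exactly two responses (one `PASS`, one `FAIL`), which the balanced label counts are consistent with but do not prove, HaluEval has 500 groups and the challenger set 385. The stdlib sketch below re-implements the same shuffle-and-slice idea on plain lists (`split_groups` is a hypothetical name) and checks both the arithmetic and that no group leaks across splits.

```python
import random


def split_groups(group_ids, val_frac=0.15, test_frac=0.15, seed=42):
    """Shuffle group ids and slice them into (train, val, test), like three_way_pair_split."""
    ids = list(group_ids)
    rng = random.Random(seed)  # local RNG: does not reset the global random state
    rng.shuffle(ids)
    n_test = int(len(ids) * test_frac)
    n_val = int(len(ids) * val_frac)
    return ids[n_test + n_val:], ids[n_test:n_test + n_val], ids[:n_test]


ROWS_PER_GROUP = 2  # assumption: one PASS and one FAIL response per (query, context)

totals = {"train": 0, "val": 0, "test": 0}
for name, n_groups in [("HaluEval", 500), ("challenger", 385)]:
    train_g, val_g, test_g = split_groups(range(n_groups))
    # No group may appear in more than one split (that would leak the shared context)
    assert not (set(train_g) & set(val_g) or set(train_g) & set(test_g) or set(val_g) & set(test_g))
    for split, groups in zip(totals, (train_g, val_g, test_g)):
        totals[split] += len(groups) * ROWS_PER_GROUP
    print(name, len(train_g), len(val_g), len(test_g))

print(totals)  # {'train': 1242, 'val': 264, 'test': 264}
```

Note the integer truncation: 385 × 0.15 = 57.75 becomes 57 groups (114 rows) for both validation and test, so the remainder goes to training.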


## 2  Tokenise pairs with DeBERTa tokenizer

In [11]:
model_name = 'microsoft/deberta-v3-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)

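MAX_LEN = 512  # matches max_position_embeddings in the DeBERTa-v3-base config

def tokenize_fn(batch):
    # Encode (query+context, response) as one sequence pair; truncation=True trims the longer segment first
    return tokenizer(batch['query_context'], batch['response'], truncation=True, max_length=MAX_LEN)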

train_ds = Dataset.from_pandas(train)
train_ds = train_ds.map(tokenize_fn, batched=True)
train_ds = train_ds.rename_column('label', 'labels')

test_ds = Dataset.from_pandas(test)
test_ds = test_ds.map(tokenize_fn, batched=True)
test_ds = test_ds.rename_column('label', 'labels')

val_ds = Dataset.from_pandas(val)
val_ds = val_ds.map(tokenize_fn, batched=True)
val_ds = val_ds.rename_column('label', 'labels')


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
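As the warning notes, `truncation=True` on its own does not truncate here, because the tokenizer has no predefined maximum length; a `max_length` must be supplied. The stdlib sketch below illustrates what truncating a pair means, on pre-split word lists: it lays the pair out as `[CLS] A [SEP] B [SEP]` (the layout DeBERTa-v3 uses for sequence pairs) with segment ids 0/1, removing tokens one at a time from the end of whichever segment is currently longer. It is a simplified illustration of the longest-first idea, not the library's exact algorithm; `pack_pair` is a hypothetical name.

```python
def pack_pair(a_tokens, b_tokens, max_length):
    """Lay out a token pair as [CLS] A [SEP] B [SEP], trimming longest-first to fit max_length."""
    a, b = list(a_tokens), list(b_tokens)
    budget = max_length - 3  # room left after [CLS], [SEP], [SEP]
    if budget < 0:
        raise ValueError("max_length too small for the special tokens")
    while len(a) + len(b) > budget:
        # Drop one token from the end of the currently longer segment (A wins ties)
        if len(a) >= len(b):
            a.pop()
        else:
            b.pop()
    tokens = ["[CLS]"] + a + ["[SEP]"] + b + ["[SEP]"]
    segment_ids = [0] * (len(a) + 2) + [1] * (len(b) + 1)
    return tokens, segment_ids


context = "Query: who wrote Hamlet ? Context: Hamlet is a tragedy written by William Shakespeare".split()
response = "William Shakespeare".split()
tokens, segs = pack_pair(context, response, max_length=12)
print(tokens)  # 12 tokens: the long query+context is cut, the short response survives intact
```

With long contexts and short responses, longest-first truncation mostly removes context, which can cut away exactly the evidence a grounded answer relies on; that is a trade-off worth checking on the long challenger examples.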



## 3  Define evaluation metrics

In [12]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
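    preds = np.argmax(logits, axis=-1)
    # average='binary' scores only pos_label=1, i.e. PASS (grounded);
    # pass pos_label=0 to report precision/recall/F1 for FAIL (hallucinated) instead
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0)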
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc, 'precision': precision, 'recall': recall, 'f1': f1}
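With `average='binary'`, scikit-learn reports precision, recall and F1 for the positive class only, `pos_label=1` by default; under `label_map` that is `PASS` (grounded), not the hallucinated class. The stdlib sketch below computes the same three quantities by hand for either choice of positive class on six made-up predictions, to show how much the headline F1 can move. `binary_prf` is a hypothetical helper.

```python
def binary_prf(labels, preds, pos_label=1):
    """Precision, recall and F1 for one positive class (cf. average='binary')."""
    tp = sum(1 for y, p in zip(labels, preds) if y == pos_label and p == pos_label)
    fp = sum(1 for y, p in zip(labels, preds) if y != pos_label and p == pos_label)
    fn = sum(1 for y, p in zip(labels, preds) if y == pos_label and p != pos_label)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Toy data: 1 = PASS (grounded), 0 = FAIL (hallucinated)
labels = [1, 1, 1, 1, 0, 0]
preds = [1, 1, 1, 1, 1, 0]  # one hallucination slips through as PASS

print(binary_prf(labels, preds, pos_label=1))  # PASS as positive: P=0.8, R=1.0, F1~0.889
print(binary_prf(labels, preds, pos_label=0))  # FAIL as positive: P=1.0, R=0.5, F1~0.667
```

For a hallucination detector, FAIL-class recall arguably tracks the more costly error, so `metric_for_best_model='f1'` may select a different checkpoint depending on which class is treated as positive.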

## 4  Load pre‑trained DeBERTa‑v3‑base with classification head

In [13]:
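# Naming the classes makes the saved config and pipeline output read FAIL/PASS instead of LABEL_0/LABEL_1
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2, id2label={0: 'FAIL', 1: 'PASS'}, label2id={'FAIL': 0, 'PASS': 1})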


Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 5  Set training hyper‑parameters

In [None]:
training_args = TrainingArguments(
    output_dir='/content/deberta_halueval_ckpt',
    eval_strategy='epoch',  # formerly `evaluation_strategy`, which recent transformers releases reject
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    logging_steps=20,
    report_to='none'
)
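A quick sanity check on what these settings imply for the 1,242-row training split. Assuming a single GPU and no gradient accumulation (neither is configured above), each epoch takes ceil(1242 / 8) optimizer steps, because the final partial batch is kept by default; `logging_steps=20` then sets how many training-loss entries appear, while evaluation and checkpointing happen once per epoch.

```python
import math

n_train = 1242        # rows in the combined training split (printed above)
batch_size = 8        # per_device_train_batch_size, assuming one device
epochs = 3            # num_train_epochs
logging_steps = 20

steps_per_epoch = math.ceil(n_train / batch_size)  # final partial batch still counts as a step
total_steps = steps_per_epoch * epochs
n_log_entries = total_steps // logging_steps       # losses logged at steps 20, 40, ..., 460

print(steps_per_epoch, total_steps, n_log_entries)  # 156 468 23
```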

## 6  Fine‑tune

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    processing_class=tokenizer,  # successor to the deprecated `tokenizer=` argument
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
test_results = trainer.evaluate(eval_dataset=test_ds, metric_key_prefix='test')  # keys: test_accuracy, test_f1, ...
print("Test set results:", test_results)

## 7  Save fine‑tuned model

In [None]:
trainer.save_model('deberta_crossencoder_halueval')
tokenizer.save_pretrained('deberta_crossencoder_halueval')

## 8  Run quick inference

In [None]:
from transformers import pipeline
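clf = pipeline('text-classification', model='deberta_crossencoder_halueval', tokenizer='deberta_crossencoder_halueval', device=0 if torch.cuda.is_available() else -1)

# Build the first segment exactly as in training ("Query: ... Context: ...")
test_query = 'How do I open a new savings account?'
test_context = 'Savings accounts can be opened online or at any branch; bring a valid photo ID.'  # illustrative context
test_response = 'You can open a savings account online or visit any of our branches with valid ID.'

# The pipeline takes a sentence pair as a dict with 'text' and 'text_pair' keys
pred = clf({'text': f"Query: {test_query} Context: {test_context}", 'text_pair': test_response})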
pred
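The pipeline turns the two logits into probabilities with a softmax and, by default, returns only the highest-scoring label and its score. For a hallucination detector it can help to keep P(FAIL) and pick the decision threshold on the validation set rather than taking the argmax. The stdlib sketch below shows that post-processing on raw logits; the logit values and the 0.3 threshold are made up, `decide` is a hypothetical helper, and the index order (0 = FAIL, 1 = PASS) follows `label_map`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decide(logits, fail_threshold=0.5):
    """Flag a response as FAIL when P(FAIL) >= fail_threshold (index 0 = FAIL, 1 = PASS)."""
    p_fail = softmax(logits)[0]
    return ("FAIL" if p_fail >= fail_threshold else "PASS"), p_fail


example_logits = [-0.2, 0.6]  # made-up [FAIL, PASS] logits for one pair
print(decide(example_logits))                      # ('PASS', ~0.31): same as argmax
print(decide(example_logits, fail_threshold=0.3))  # ('FAIL', ~0.31): stricter, flags more responses
```

Lowering the threshold trades precision on FAIL for recall on FAIL; with two classes, a threshold of 0.5 reproduces the argmax decision.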