In [1]:
pip install evaluate



In [2]:
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer
import evaluate
import json
import matplotlib.pyplot as plt
import os
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np

In [3]:
# ============================================================
# 1. Load & clean dataset
# ============================================================

print("Loading dataset...")
ds = load_dataset("ailsntua/QEvasion")

# remove empty rows
ds = ds.filter(lambda x: bool(x["interview_question"]) and bool(x["interview_answer"]))

# map labels
clarity_labels = sorted(set(ds["train"]["clarity_label"]))
evasion_labels = sorted(set(ds["train"]["evasion_label"]))

clarity2id = {c: i for i, c in enumerate(clarity_labels)}
id2clarity = {i: c for c, i in clarity2id.items()}
evasion2id = {c: i for i, c in enumerate(evasion_labels)}
id2evasion = {i: c for c, i in evasion2id.items()}

print("Clarity labels:", clarity2id)
print("Evasion labels:", evasion2id)

Loading dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Clarity labels: {'Ambivalent': 0, 'Clear Non-Reply': 1, 'Clear Reply': 2}
Evasion labels: {'Claims ignorance': 0, 'Clarification': 1, 'Declining to answer': 2, 'Deflection': 3, 'Dodging': 4, 'Explicit': 5, 'General': 6, 'Implicit': 7, 'Partial/half-answer': 8}


In [4]:
# ============================================================
# 2. Tokenizer
# ============================================================

model_name = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)


def preprocess(batch):
    """Use HF's built-in text pair encoding (question, answer)."""

    enc = tokenizer(
        batch["interview_question"],
        batch["interview_answer"],
        padding="max_length",
        truncation=True,
        max_length=512
    )

    enc["labels"] = [
        (
            clarity2id[c] if c in clarity2id else 0,
            evasion2id[e] if e in evasion2id else 0
        )
        for c, e in zip(batch["clarity_label"], batch["evasion_label"])
    ]

    return enc


print("Tokenizing dataset...")
encoded = ds.map(preprocess, batched=True)

# keep only model-required columns
encoded = encoded.remove_columns(
    [c for c in ds["train"].column_names if c not in ["labels"]]
)

encoded.set_format("torch")



Tokenizing dataset...


In [5]:
# ============================================================
# 3. Multi-task model
# ============================================================

class MultiTaskDeberta(nn.Module):
    def __init__(self, base_model, num_clarity, num_evasion):
        super().__init__()

        self.encoder = AutoModel.from_pretrained(base_model)
        hidden = self.encoder.config.hidden_size

        self.clarity_head = nn.Linear(hidden, num_clarity)
        self.evasion_head = nn.Linear(hidden, num_evasion)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        enc_out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        cls = enc_out.last_hidden_state[:, 0, :]

        logits_clarity = self.clarity_head(cls)
        logits_evasion = self.evasion_head(cls)

        loss = None
        if labels is not None:
            labels_clarity = labels[:, 0]
            labels_evasion = labels[:, 1]

            loss = (self.loss_fn(logits_clarity, labels_clarity) +
                    self.loss_fn(logits_evasion, labels_evasion)) / 2.0

        # Return a tuple (loss, logits_clarity, logits_evasion) if labels exist
        return (loss, logits_clarity, logits_evasion) if loss is not None else (logits_clarity, logits_evasion)



print("Initializing model...")
model = MultiTaskDeberta(
    base_model=model_name,
    num_clarity=len(clarity2id),
    num_evasion=len(evasion2id)
)

Initializing model...


In [6]:
# ============================================================
# 4. Metrics
# ============================================================

acc = evaluate.load("accuracy")
precision = evaluate.load("precision")
recall = evaluate.load("recall")
f1 = evaluate.load("f1")


def compute_metrics(pred):
    # Unpack the tuple returned by the model
    # pred.predictions: (logits_clarity, logits_evasion)
    logits_clarity, logits_evasion = pred.predictions

    # Convert logits to predicted labels
    preds_clarity = np.argmax(logits_clarity, axis=-1)
    preds_evasion = np.argmax(logits_evasion, axis=-1)

    # True labels
    labels_clarity = pred.label_ids[:, 0]
    labels_evasion = pred.label_ids[:, 1]

    # Compute metrics
    metrics = {
        "clarity_accuracy": acc.compute(predictions=preds_clarity, references=labels_clarity)["accuracy"],
        "clarity_precision": precision_score(labels_clarity, preds_clarity, average="weighted", zero_division=0),
        "clarity_recall": recall_score(labels_clarity, preds_clarity, average="weighted", zero_division=0),
        "clarity_f1": f1_score(labels_clarity, preds_clarity, average="weighted", zero_division=0),

        "evasion_accuracy": acc.compute(predictions=preds_evasion, references=labels_evasion)["accuracy"],
        "evasion_precision": precision_score(labels_evasion, preds_evasion, average="weighted", zero_division=0),
        "evasion_recall": recall_score(labels_evasion, preds_evasion, average="weighted", zero_division=0),
        "evasion_f1": f1_score(labels_evasion, preds_evasion, average="weighted", zero_division=0),
    }

    return metrics


In [None]:
# ============================================================
# 5. Trainer setup
# ============================================================

training_args = TrainingArguments(
    output_dir="./clarity_model",
    learning_rate=2e-5,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    logging_steps=50,
    do_eval = True,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded["train"],
    eval_dataset=encoded.get("validation", encoded["test"]),
    compute_metrics=compute_metrics
)

print("\nStarting training...\n")
trainer.train()

print("Saving model and trainer state...")
trainer.save_model("./clarity_model")
trainer.save_state()


Starting training...



Epoch,Training Loss,Validation Loss,Clarity Accuracy,Clarity Precision,Clarity Recall,Clarity F1,Evasion Accuracy,Evasion Precision,Evasion Recall,Evasion F1
1,1.3418,2.216249,0.678571,0.648376,0.678571,0.611413,0.006494,1.0,0.006494,0.012903
2,1.1175,2.296629,0.685065,0.698186,0.685065,0.675198,0.097403,1.0,0.097403,0.177515
3,1.1495,2.625414,0.681818,0.677487,0.681818,0.670791,0.058442,1.0,0.058442,0.110429


In [None]:
# Force save trainer state (for plotting)
os.makedirs("./clarity_model", exist_ok=True)
with open("./clarity_model/trainer_state.json", "w") as f:
    json.dump(trainer.state.log_history, f)

In [None]:
# ============================================================
# 6. Plot Training Loss and Evaluation Metrics
# ============================================================

print("\nPlotting training curves...\n")

state_file = "./clarity_model/trainer_state.json"

if os.path.exists(state_file):
    with open(state_file, "r") as f:
        state = json.load(f)

    if isinstance(state, list):
        logs = state
    else:
        logs = state.get("log_history", [])

    steps, train_loss = [], []

    # Metric lists for clarity
    clarity_acc_vals, clarity_prec_vals, clarity_recall_vals, clarity_f1_vals = [], [], [], []

    # Metric lists for evasion
    evasion_acc_vals, evasion_prec_vals, evasion_recall_vals, evasion_f1_vals = [], [], [], []

    for entry in logs:
        if "loss" in entry and "step" in entry:
            steps.append(entry["step"])
            train_loss.append(entry["loss"])

        # Clarity metrics
        if "eval_clarity_accuracy" in entry:
            clarity_acc_vals.append(entry["eval_clarity_accuracy"])
        if "eval_clarity_precision" in entry:
            clarity_prec_vals.append(entry["eval_clarity_precision"])
        if "eval_clarity_recall" in entry:
            clarity_recall_vals.append(entry["eval_clarity_recall"])
        if "eval_clarity_f1" in entry:
            clarity_f1_vals.append(entry["eval_clarity_f1"])

        # Evasion metrics
        if "eval_evasion_accuracy" in entry:
            evasion_acc_vals.append(entry["eval_evasion_accuracy"])
        if "eval_evasion_precision" in entry:
            evasion_prec_vals.append(entry["eval_evasion_precision"])
        if "eval_evasion_recall" in entry:
            evasion_recall_vals.append(entry["eval_evasion_recall"])
        if "eval_evasion_f1" in entry:
            evasion_f1_vals.append(entry["eval_evasion_f1"])

    # Plot training loss
    plt.figure(figsize=(10,5))
    plt.plot(steps, train_loss, label="Training Loss")
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title("Training Loss Over Time")
    plt.grid(True)
    plt.legend()
    plt.show()

    # Plot Clarity metrics
    plt.figure(figsize=(10,5))
    if clarity_acc_vals: plt.plot(range(1, len(clarity_acc_vals)+1), clarity_acc_vals, label="Accuracy")
    if clarity_prec_vals: plt.plot(range(1, len(clarity_prec_vals)+1), clarity_prec_vals, label="Precision")
    if clarity_recall_vals: plt.plot(range(1, len(clarity_recall_vals)+1), clarity_recall_vals, label="Recall")
    if clarity_f1_vals: plt.plot(range(1, len(clarity_f1_vals)+1), clarity_f1_vals, label="F1 Score")
    plt.xlabel("Evaluation Epoch")
    plt.ylabel("Metric Value")
    plt.title("Clarity Head Metrics Over Epochs")
    plt.grid(True)
    plt.legend()
    plt.show()

    # Plot Evasion metrics
    plt.figure(figsize=(10,5))
    if evasion_acc_vals: plt.plot(range(1, len(evasion_acc_vals)+1), evasion_acc_vals, label="Accuracy")
    if evasion_prec_vals: plt.plot(range(1, len(evasion_prec_vals)+1), evasion_prec_vals, label="Precision")
    if evasion_recall_vals: plt.plot(range(1, len(evasion_recall_vals)+1), evasion_recall_vals, label="Recall")
    if evasion_f1_vals: plt.plot(range(1, len(evasion_f1_vals)+1), evasion_f1_vals, label="F1 Score")
    plt.xlabel("Evaluation Epoch")
    plt.ylabel("Metric Value")
    plt.title("Evasion Head Metrics Over Epochs")
    plt.grid(True)
    plt.legend()
    plt.show()

else:
    print("trainer_state.json not found. Run training first.")


In [None]:
# ============================================================
# 7. Inference wrapper
# ============================================================

def predict(question, answer):
    device = torch.device("cpu")
    model.to(device)
    model.eval()

    encoded = tokenizer(
        question,
        answer,
        return_tensors="pt",
        truncation=True,
        max_length=512
    )

    inputs = {k: v.to(device) for k, v in encoded.items()}

    with torch.no_grad():
        out = model(**inputs)

    clarity_id = out["logits_clarity"].argmax(dim=-1).item()
    evasion_id = out["logits_evasion"].argmax(dim=-1).item()

    return id2clarity[clarity_id], id2evasion[evasion_id]

In [None]:
# ============================================================
# 8. Examples
# ============================================================

examples = [
    ("Why did you veto the bill?",
     "Because the timing was not appropriate."),

    ("Can you explain why the health budget was reduced?",
     "We are looking into multiple ways of improving efficiency."),

    ("Do you take responsibility for the outcome?",
     "Our team has been working very hard and we are committed to improvement."),

    ("Why was the water infrastructure project delayed?",
     "We are currently reviewing contractor proposals to ensure compliance."),

    ("Can you explain why the education grant was cut this quarter?",
     "Several departments are undergoing restructuring to optimize future funding."),

    ("Did your office approve the new zoning changes?",
     "Urban development is a collaborative process, and many stakeholders are involved."),

    ("Why haven't the environmental reports been released yet?",
     "We want to ensure that all findings are thoroughly verified before publication."),

    ("Is the transportation department responsible for the recent safety lapses?",
     "Safety is a shared responsibility and we're evaluating all contributing factors."),

    ("Why did you choose not to attend the committee hearing?",
     "Scheduling conflicts required my attention elsewhere, but my team was present."),

    ("Are you planning to revise the energy policy this year?",
     "We are exploring many promising avenues and will have more to share soon."),

    ("Why were the housing subsidies paused?",
     "We're analyzing long-term impacts to ensure sustainability."),

    ("Did the agency conduct the internal audit as required?",
     "Audits are ongoing, and we'll release findings once the review is complete."),

    ("What led to the recent increase in administrative fees?",
     "The agency is evaluating cost structures to improve service delivery.")
]

print("\nExample Predictions:\n")
for q, a in examples:
    c, e = predict(q, a)
    print(f"Q: {q}\nA: {a}\nâ†’ Clarity: {c} | Evasion: {e}\n")