<a href="https://colab.research.google.com/github/suleiman-odeh/NLP_Project_Team16/blob/main/Gemma_2/zero_shot_direct_Gemma_2_9B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab setup: install/upgrade the Hugging Face stack used below (transformers model loading,
# bitsandbytes for the 4-bit BitsAndBytesConfig, accelerate for device_map="auto").
# NOTE(review): -U pulls whatever versions are latest at run time; pin them (e.g. with %pip) for reproducible reruns.
!pip install -q -U transformers bitsandbytes accelerate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m41.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
"""
Load the model.
The cell output was cleared because it can't be uploaded to GitHub.
"""
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import login

# The Gemma repos on the Hugging Face Hub are gated, so authenticate with a user
# access token. login() without an argument asks for the token interactively,
# which keeps it out of the notebook file.
login()

# 4-bit NF4 quantization: weights are stored in 4 bits, matmuls are computed in float16.
# NOTE(review): Gemma-2 reference configs use bfloat16; float16 compute may be less
# numerically stable for this model. Consider torch.bfloat16 on GPUs that support it.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model_id = "google/gemma-2-9b-it"

print(f"Loading {model_id} in 4-bit...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Gemma-2 is a native transformers architecture, so trust_remote_code is not
# needed; leaving it off means no Python code from the Hub repo is executed locally.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

print("Model loaded successfully!")


In [5]:
"""
Load the dataset and keep only the TEST split.
We only need test for zero-shot evaluation.
"""

import pandas as pd

DATA_FILE = "QEvasion_cleaned.jsonl"

print(f"Loading data from {DATA_FILE} ...")
df = pd.read_json(DATA_FILE, lines=True)

# Boolean mask over split_type; .copy() so the later column assignments
# (raw_output, pred_clarity_id) write to an independent frame, not a view of df.
is_test_row = df["split_type"].eq("test")
test_df = df.loc[is_test_row].copy()
print(f"Test rows: {len(test_df)}")

# Sanity check: list the available columns and preview two examples.
preview_columns = ["question", "cleaned_answer", "clarity_id"]
print("\nColumns:", list(test_df.columns))
print(test_df[preview_columns].head(2))

Loading data from QEvasion_cleaned.jsonl ...
Test rows: 308

Columns: ['title', 'date', 'president', 'url', 'question_order', 'interview_question', 'interview_answer', 'gpt3.5_summary', 'gpt3.5_prediction', 'question', 'annotator_id', 'annotator1', 'annotator2', 'annotator3', 'inaudible', 'multiple_questions', 'affirmative_questions', 'index', 'clarity_label', 'evasion_label', 'split_type', 'cleaned_answer', 'clarity_id', 'evasion_id', 'annotator1_id', 'annotator2_id', 'annotator3_id']
                                               question  \
3448   Inquiring about the status or information reg...   
3449  Will you invite them to the White House to neg...   

                                         cleaned_answer  clarity_id  
3448  Well, the world has made it clear that these t...           1  
3449  I think that anytime and anyplace that they ar...           1  


In [6]:
"""
define the direct (3-class) prompt and a small parser.
We keep the output format strict so evaluation is reliable.
"""

import re

def make_direct_prompt(question, answer):
    """Build the zero-shot prompt for direct 3-way clarity classification.

    The model is asked to reply with a taxonomy number (1, 2 or 3); the
    number is later translated into the dataset's clarity_id by
    CODE_TO_CLARITY.

    Args:
        question: The interviewer's question (formatted with str()).
        answer: The interviewee's (cleaned) answer (formatted with str()).

    Returns:
        The prompt string, ending with the "Taxonomy code:" cue.
    """
    # NOTE(review): the answer is shown before the question, and the class is
    # called "Ambiguous" while the dataset label is "Ambivalent". The wording is
    # kept verbatim so results stay comparable with earlier runs.
    prompt_lines = [
        "You classify the interviewee response to a question.",
        "",
        "we use this:",
        "1. clear reply (answers what was asked)",
        "2. clear Non-Reply (does not answer/ declines)",
        "3. Ambiguous (partial answer/ too general)",
        "",
        "Return the taxonomy number: 1, 2, or 3.",
        "",
        f"Answer: {answer}",
        "",
        f"Question: {question}",
        "",
        "Taxonomy code:",
    ]
    return "\n".join(prompt_lines)

# Prompt taxonomy code -> dataset clarity_id.
# The prompt numbers the classes in a different order from the dataset ids:
#   1 (clear reply)     -> 0 (Clear Reply)
#   2 (clear non-reply) -> 2 (Non-Reply)
#   3 (ambiguous)       -> 1 (Ambivalent)
CODE_TO_CLARITY = dict(
    [
        (1, 0),
        (2, 2),
        (3, 1),
    ]
)

def parse_direct_output(text):
    """Map raw model output to a dataset clarity_id, or -1 if it cannot be parsed.

    A standalone digit 1-3 takes priority and is translated via CODE_TO_CLARITY.
    Otherwise the lower-cased text is searched for class names, checked in the
    order non-reply -> clear reply -> ambiguous/ambivalent (first match wins).

    Args:
        text: Decoded model output; any value is accepted and passed through str().

    Returns:
        0, 1 or 2 on success; -1 when nothing recognizable is found.
    """
    normalized = str(text).strip().lower()

    digit_match = re.search(r"\b([1-3])\b", normalized)
    if digit_match is not None:
        return CODE_TO_CLARITY[int(digit_match.group(1))]

    # Fallback when the model spells out the class instead of a number.
    keyword_rules = (
        (("non-reply", "non reply"), 2),
        (("clear reply",), 0),
        (("ambiguous", "ambivalent"), 1),
    )
    for keywords, clarity_id in keyword_rules:
        if any(keyword in normalized for keyword in keywords):
            return clarity_id

    return -1


In [7]:
"""
run zero-shot inference on the test split.
We decode only the newly generated tokens.
"""

import torch
from tqdm import tqdm

@torch.no_grad()
def gemma_generate_label(prompt, max_new_tokens=3, max_length=2048):
    """Greedily generate a short label for ``prompt`` and return only the new text.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        prompt: Full prompt string (e.g. from ``make_direct_prompt``).
        max_new_tokens: Upper bound on the number of generated tokens.
        max_length: Maximum number of prompt tokens fed to the model. Longer
            prompts are clipped in the MIDDLE (first half and last half kept)
            instead of at the end. The task instructions at the top and the
            question plus "Taxonomy code:" cue at the bottom therefore survive;
            only part of the long answer text in between is dropped.

    Returns:
        The decoded continuation (prompt tokens excluded, special tokens
        skipped), stripped of surrounding whitespace.
    """
    # NOTE(review): gemma-2-9b-it is instruction-tuned and its model card uses the
    # chat template (tokenizer.apply_chat_template). The raw prompt is kept here
    # to preserve the experimental setup.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    n_tokens = input_ids.shape[1]
    if n_tokens > max_length:
        # Plain truncation=True cuts the END of the prompt, which removes the
        # question and the "Taxonomy code:" cue the model must answer. Keep both ends.
        head = max_length // 2
        tail_start = n_tokens - (max_length - head)
        input_ids = torch.cat([input_ids[:, :head], input_ids[:, tail_start:]], dim=1)
        attention_mask = torch.cat(
            [attention_mask[:, :head], attention_mask[:, tail_start:]], dim=1
        )

    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    out = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )

    # Decode only the tokens generated after the (possibly clipped) prompt.
    new_tokens = out[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Collected per test row, in test_df order.
preds = []
raw_outs = []
y_true = test_df["clarity_id"].astype(int).tolist()

print("Starting Gemma-2 direct zero-shot on TEST...")

# Print the raw/parsed output for the first few rows to eyeball the format.
debug_n = 5
row_triples = zip(test_df["question"], test_df["cleaned_answer"], y_true)
for idx, (question, answer, gold) in enumerate(tqdm(row_triples, total=len(test_df))):
    raw = gemma_generate_label(make_direct_prompt(question, answer), max_new_tokens=3)
    pred = parse_direct_output(raw)

    if idx < debug_n:
        print(f"\n[DEBUG {idx + 1}] raw={repr(raw)} | parsed={pred} | true={gold}")

    raw_outs.append(raw)
    preds.append(pred)

# Lists are assigned positionally, so they line up with test_df's rows.
test_df["raw_output"] = raw_outs
test_df["pred_clarity_id"] = preds

print("\nFinished.")
print("Parsing failures (-1):", preds.count(-1))


Starting Gemma-2 direct zero-shot on TEST...


  0%|          | 1/308 [00:01<09:04,  1.77s/it]


[DEBUG 1] raw='3' | parsed=1 | true=1


  1%|          | 2/308 [00:02<05:47,  1.14s/it]


[DEBUG 2] raw='3' | parsed=1 | true=1


  1%|          | 3/308 [00:03<05:21,  1.05s/it]


[DEBUG 3] raw='3' | parsed=1 | true=1


  1%|▏         | 4/308 [00:04<05:18,  1.05s/it]


[DEBUG 4] raw='3' | parsed=1 | true=1


  2%|▏         | 5/308 [00:05<04:43,  1.07it/s]


[DEBUG 5] raw='3' | parsed=1 | true=1


100%|██████████| 308/308 [05:07<00:00,  1.00it/s]


Finished.
Parsing failures (-1): 0





In [8]:
"""
print the evaluation metrics (macro + weighted).
We ignore any rows we couldn't parse (-1).
"""

from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score

# Score against a fixed label set. This keeps the macro averages over all three
# classes and stops classification_report from raising a target_names size
# mismatch when a class is missing from both y_true and y_pred (e.g. after
# dropping unparsed rows).
CLARITY_LABELS = [0, 1, 2]
CLARITY_NAMES = ["Clear (0)", "Ambivalent (1)", "Non-Reply (2)"]

assert len(preds) == len(y_true), "preds and y_true must be aligned row-for-row"

# Rows the parser could not map (-1) are excluded from scoring; how many were
# excluded is printed as a warning below so the filtering is visible.
valid_idx = [i for i, p in enumerate(preds) if p != -1]

print("\n" + "="*60)
print("DIRECT RESULTS | Gemma-2-9B-IT | Zero-shot | TEST")
print("="*60)

if len(valid_idx) == 0:
    print("No valid predictions parsed. Check the debug outputs and prompt format.")
else:
    if len(valid_idx) < len(preds):
        print(f"Warning: {len(preds) - len(valid_idx)} predictions could not be parsed.")

    y_true_f = [y_true[i] for i in valid_idx]
    y_pred_f = [preds[i] for i in valid_idx]

    acc = accuracy_score(y_true_f, y_pred_f)

    # Same label set and zero_division policy for every averaged metric, so these
    # numbers match the "macro avg" / "weighted avg" rows of the report below.
    metric_kwargs = {"labels": CLARITY_LABELS, "zero_division": 0}

    prec_macro = precision_score(y_true_f, y_pred_f, average="macro", **metric_kwargs)
    rec_macro  = recall_score(y_true_f, y_pred_f, average="macro", **metric_kwargs)
    f1_macro   = f1_score(y_true_f, y_pred_f, average="macro", **metric_kwargs)

    prec_weighted = precision_score(y_true_f, y_pred_f, average="weighted", **metric_kwargs)
    rec_weighted  = recall_score(y_true_f, y_pred_f, average="weighted", **metric_kwargs)
    f1_weighted   = f1_score(y_true_f, y_pred_f, average="weighted", **metric_kwargs)

    print(f"Accuracy:           {acc:.4f}")
    print("-"*30)
    print(f"Macro Precision:    {prec_macro:.4f}")
    print(f"Macro Recall:       {rec_macro:.4f}")
    print(f"Macro F1:           {f1_macro:.4f}")
    print("-"*30)
    print(f"Weighted Precision: {prec_weighted:.4f}")
    print(f"Weighted Recall:    {rec_weighted:.4f}")
    print(f"Weighted F1:        {f1_weighted:.4f}")
    print("-"*60)

    print(classification_report(
        y_true_f,
        y_pred_f,
        labels=CLARITY_LABELS,
        target_names=CLARITY_NAMES,
        zero_division=0
    ))



DIRECT RESULTS | Gemma-2-9B-IT | Zero-shot | TEST
Accuracy:           0.6851
------------------------------
Macro Precision:    0.6211
Macro Recall:       0.4059
Macro F1:           0.3869
------------------------------
Weighted Precision: 0.7045
Weighted Recall:    0.6851
Weighted F1:        0.5920
------------------------------------------------------------
                precision    recall  f1-score   support

     Clear (0)       0.83      0.06      0.12        79
Ambivalent (1)       0.70      0.98      0.81       206
 Non-Reply (2)       0.33      0.17      0.23        23

      accuracy                           0.69       308
     macro avg       0.62      0.41      0.39       308
  weighted avg       0.70      0.69      0.59       308



In [9]:
"""
save files so we can compare models later without rerunning.
"""

# index=True keeps the original dataset row index, so runs from different
# models can be joined row-for-row later.
full_csv_path = "full_test_dataset_zs_direct_gemma.csv"
test_df.to_csv(full_csv_path, index=True)
print(f"Saved: {full_csv_path}")

# Compact file with just predictions, gold labels and the raw generations.
mini_csv_path = "predictions_comparison_direct_gemma.csv"
mini_df = test_df.loc[:, ["pred_clarity_id", "clarity_id", "raw_output"]].copy()
mini_df.to_csv(mini_csv_path, index=True)
print(f"Saved: {mini_csv_path}")


Saved: full_test_dataset_zs_direct_gemma.csv
Saved: predictions_comparison_direct_gemma.csv
