<a href="https://colab.research.google.com/github/suleiman-odeh/NLP_Project_Team16/blob/main/Gemma_2/zero_shot_indirect_Gemma_2_9B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install / upgrade (-U) the libraries used by the following cells, quietly (-q).
!pip install -q -U transformers bitsandbytes accelerate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m61.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import pandas as pd
import json

# Load the Cleaned Data (one JSON object per line).
# NOTE: a missing file only prints an error; `df` is then not (re)assigned,
# so later cells that use it will fail until the file is uploaded.
try:
    df = pd.read_json('QEvasion_cleaned.jsonl', lines=True)
except FileNotFoundError:
    print("ERROR: Could not find 'QEvasion_cleaned.jsonl'. Please upload it to Colab files.")
else:
    print(f"Data loaded successfully! Total rows: {len(df)}")


# Collapse the 9 fine-grained evasion labels into the 3 clarity IDs:
#   0 = Clear Reply, 1 = Ambivalent Reply, 2 = Clear Non-Reply
# Source: Project Proposal & Paper Taxonomy
EVASION_TO_CLARITY_ID = {
    # 0: Clear Reply
    **dict.fromkeys(['Explicit'], 0),
    # 1: Ambivalent Reply
    **dict.fromkeys(
        ['Implicit', 'Dodging', 'Deflection', 'General', 'Partial/half-answer'],
        1,
    ),
    # 2: Clear Non-Reply
    **dict.fromkeys(
        ['Declining to answer', 'Claims ignorance', 'Clarification'],
        2,
    ),
}

# --- Step 3: Define the Gemma-Specific Prompt Function ---
def create_gemma_indirect_prompt(question, answer):
    """
    Build a zero-shot classification prompt for Gemma-2-9b-It.

    The prompt names and briefly defines the 9 evasion labels, asks for the
    label name only, and wraps the request in Gemma's chat-turn markers so
    the text ends right where the model's reply should begin.

    Args:
        question: The question text; interpolated as-is via an f-string.
        answer: The answer text; interpolated as-is via an f-string.

    Returns:
        The full prompt string, ending with "<start_of_turn>model\\n".
    """
    label_definitions = [
        "Explicit: The information requested is explicitly stated.",
        "Implicit: The answer is implied but not explicitly stated.",
        "General: The answer is too general or lacks specificity.",
        "Partial/half-answer: Only addresses part of the question.",
        "Dodging: Ignores the question entirely.",
        "Deflection: Shifts focus to a related but different topic.",
        "Declining to answer: Refuses to answer (directly or indirectly).",
        "Claims ignorance: Claims not to know the answer.",
        "Clarification: Asks for clarification instead of answering.",
    ]
    numbered_definitions = [
        f"{position}. {definition}"
        for position, definition in enumerate(label_definitions, start=1)
    ]

    instruction_lines = [
        "You are an expert political discourse analyst.",
        "Your task is to classify the relationship between a Question and an Answer into exactly one of the following 9 categories.",
        "",
        "Definitions:",
        *numbered_definitions,
        "",
        "Instructions:",
        "- Analyze the Question and Answer carefully.",
        "- Return ONLY the category name from the list above.",
        "- Do not add explanations or punctuation.",
    ]
    system_instruction = "\n".join(instruction_lines)

    # Gemma chat layout: a user turn closed by <end_of_turn>, then an open
    # model turn. The trailing "" yields the final newline after "model".
    turn_lines = [
        "<start_of_turn>user",
        system_instruction,
        "",
        f'Question: "{question}"',
        f'Answer: "{answer}"',
        "",
        "Classify the answer:<end_of_turn>",
        "<start_of_turn>model",
        "",
    ]
    return "\n".join(turn_lines)

print("Setup complete. Prompt function and Mapping dictionary are ready.")

In [None]:
import torch
import gc
from tqdm import tqdm

# --- Step 1: Prepare Test Data ---
# We filter only the rows meant for testing
test_df = df[df['split_type'] == 'test'].copy()
print(f"Processing Test Set: {len(test_df)} samples")

# Storage for results
predictions_evasion = []
predictions_clarity = []

# Canonical label names, longest first, so the prefix match in
# _parse_evasion_label always prefers the most specific label.
_CANONICAL_LABELS = sorted(EVASION_TO_CLARITY_ID, key=len, reverse=True)


def _parse_evasion_label(generated_text):
    """Normalize raw model output to a canonical evasion label when possible.

    Even with greedy decoding the model may vary the formatting: different
    casing, markdown emphasis or quotes, an echoed list number from the
    prompt ("1. Explicit"), trailing punctuation, or a short explanation
    after the label. An exact dictionary lookup would score all of those as
    errors, so only the first output line is considered, cleaned, and matched
    case-insensitively against the known labels.

    Args:
        generated_text: The decoded continuation produced by the model.

    Returns:
        A key of EVASION_TO_CLARITY_ID when a label is recognized. Otherwise
        the stripped text with trailing periods removed (the original
        cleanup), so unrecognized outputs are still recorded for inspection.
    """
    fallback = generated_text.strip().rstrip('.')
    first_line = fallback.split('\n', 1)[0]
    # Remove markdown emphasis / quotes, then an echoed list number ("1. ", "4) ").
    candidate = first_line.strip().strip('*"\'`').strip()
    candidate = candidate.lstrip('0123456789').lstrip('.) ').strip('*"\'` ')
    # Drop trailing punctuation the prompt asked the model not to emit.
    candidate = candidate.rstrip('.:;,!').strip().casefold()
    for label in _CANONICAL_LABELS:
        key = label.casefold()
        rest = candidate[len(key):]
        # Accept an exact match, or the label followed by a non-word character
        # (e.g. "explicit - the answer ..."), but not e.g. "generally".
        if candidate.startswith(key) and (not rest or not rest[0].isalnum()):
            return label
    return fallback


# --- Step 2: Inference Loop ---
print("Starting Zero-Shot Inference (Gemma-2-9B-It)...")

# NOTE(review): `model` and `tokenizer` are not created in any cell shown here;
# load Gemma-2-9B-It and its tokenizer before running this cell.
# Ensure model is in eval mode to disable dropout etc.
model.eval()

for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    # A. Create Prompt
    prompt_text = create_gemma_indirect_prompt(row['question'], row['cleaned_answer'])

    # B. Tokenize onto the model's own device rather than a hard-coded "cuda".
    # NOTE(review): the prompt has no <bos>; this relies on the tokenizer adding
    # special tokens by default — confirm for the Gemma tokenizer in use.
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    prompt_length = inputs.input_ids.shape[1]

    # C. Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,   # We only need a short label
            do_sample=False,     # Deterministic (Greedy) for reproducibility
            pad_token_id=tokenizer.eos_token_id,
        )

    # D. Decode only the newly generated tokens (drop the echoed prompt).
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # E. Normalize to a canonical label, then map to a Clarity ID (Indirect Logic).
    # Anything still not a known label maps to -1 (Error).
    pred_label = _parse_evasion_label(generated_text)
    mapped_id = EVASION_TO_CLARITY_ID.get(pred_label, -1)

    predictions_evasion.append(pred_label)
    predictions_clarity.append(mapped_id)

# --- Step 3: Save Results ---
test_df['pred_evasion'] = predictions_evasion
test_df['pred_clarity_id'] = predictions_clarity

# Surface how many generations could not be mapped, so parse failures are not
# silently counted as wrong predictions downstream.
n_unparsed = predictions_clarity.count(-1)
print(f"Unrecognized model outputs (mapped to -1): {n_unparsed} / {len(predictions_clarity)}")

output_filename = "gemma_indirect_predictions.csv"
test_df.to_csv(output_filename, index=False)
print(f"\nInference complete. Results saved to '{output_filename}'.")