# 10_probe_discriminative — MCQ Probing

This notebook runs discriminative (MCQ) probes across configured models. It looks for datasets in `data/processed/*.jsonl`. If none are present, it will create a tiny sample set for demonstration.

Outputs are saved to `results/runs/<date>/<model>/<task>/*.jsonl`.



In [1]:
# Ensure project root is on sys.path and as working directory so `src/...` imports work
import os
import sys
from pathlib import Path


def find_project_root() -> Path:
    cwd = Path.cwd()
    for parent in [cwd, *cwd.parents]:
        # Heuristic: treat the first directory up that has `src` or `config` as the project root
        if (parent / "src").exists() or (parent / "config").exists():
            return parent
    return cwd

PROJECT_ROOT = find_project_root()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
# Also set CWD so relative paths like "config/eval.yaml" work from notebooks
os.chdir(PROJECT_ROOT)
print(f"Using project root: {PROJECT_ROOT}")


Using project root: f:\New folder\Research\Project 1


In [4]:
import os
from pathlib import Path
import json
import yaml
from datetime import datetime
from typing import List, Dict

from src.utils.io import read_jsonl, write_jsonl
from src.providers.base import GenerationRequest
from src.providers.openrouter_client import OpenRouterClient
# from src.providers.openai_client import OpenAIClient  # Disabled due to budget constraints; using OpenRouter instead
from src.providers.gemini_client import GeminiClient
from src.probing.discriminative import run_batch_choose
from src.eval.metrics import accuracy

# Load eval config
with open("config/eval.yaml", "r", encoding="utf-8") as f:
    eval_cfg = yaml.safe_load(f)

# Pick a discriminative task (values_norms by default)
tasks = [t for t in eval_cfg["tasks"] if t.get("type") == "discriminative"]
if not tasks:
    raise RuntimeError("No discriminative tasks configured in config/eval.yaml")

task = tasks[0]
DATASET_PATH = Path(task["dataset"])  # expects fields: stem, options, answer

if DATASET_PATH.exists():
    records = read_jsonl(str(DATASET_PATH))
else:
    # Create a tiny demo set (Hindi pronouns + a values item)
    records = [
        {
            "stem": "In Hindi, addressing a senior professor formally, which pronoun is appropriate?",
            "options": ["तू", "तुम", "आप"],
            "answer": "आप",
        },
        {
            "stem": "Claiming government benefits you are not entitled to is:",
            "options": ["Never justifiable", "Something in between", "Always justifiable"],
            "answer": "Never justifiable",
        },
    ]
    print("No dataset found.")
    print("Using small sample set for demonstration.")

# Build clients from models.yaml
with open("config/models.yaml", "r", encoding="utf-8") as f:
    models_cfg = yaml.safe_load(f)

clients = {}

def get_model_id(provider: str, default: str) -> str:
    for m in models_cfg.get("models", []):
        if m.get("provider") == provider:
            return m.get("model_id", default)
    return default

if os.getenv("OPENROUTER_API_KEY"):
    try:
        clients["openrouter-deepseek"] = OpenRouterClient(model_id=get_model_id("openrouter-deepseek", "openrouter-deepseek"))
    except Exception as e:
        print("OpenRouter init failed:", e)

# OpenAI (disabled due to budget constraints)
# if os.getenv("OPENAI_API_KEY"):
#     try:
#         clients["openai"] = OpenAIClient(model_id=get_model_id("openai", "gpt-4o-mini"))
#     except Exception as e:
#         print("OpenAI init failed:", e)

#if os.getenv("GOOGLE_API_KEY"):
#    try:
#        clients["gemini"] = GeminiClient(model_id=get_model_id("gemini", "gemini-2.5-flash"))
#    except Exception as e:
#        print("Gemini init failed:", e)

print("Models:", list(clients.keys()))

# Deduplicate identical stems/options so we prompt the LLM once per unique item
gold: List[str] = [r["answer"] for r in records]

from typing import Tuple
signatures: List[Tuple[str, Tuple[str, ...]]] = []
unique_index_by_sig: Dict[Tuple[str, Tuple[str, ...]], int] = {}
unique_stems: List[str] = []
unique_choices: List[List[str]] = []

for r in records:
    sig = (r["stem"], tuple(r["options"]))
    signatures.append(sig)
    if sig not in unique_index_by_sig:
        unique_index_by_sig[sig] = len(unique_stems)
        unique_stems.append(r["stem"])
        unique_choices.append(r["options"])


Models: ['openrouter-deepseek']


In [5]:
# Run MCQ probing and compute simple accuracy per model
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
import logging
from datetime import datetime, timezone

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S'
)
logger = logging.getLogger(__name__)

results_dir = Path("results/runs")
results_dir.mkdir(parents=True, exist_ok=True)

# Fix: Use timezone.utc instead of datetime.UTC (compatible with Python 3.7+)
run_id = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")

logger.info(f"Starting MCQ probing run: {run_id}")
logger.info(f"Total unique questions: {len(unique_stems)}")
logger.info(f"Total records: {len(records)}")
logger.info(f"Models to evaluate: {list(clients.keys())}")

acc_by_model = {}

def argmax_choice(probs: Dict[str, float]) -> str:
    if not probs:
        return ""
    return max(probs.items(), key=lambda kv: kv[1])[0]

for model_idx, (name, client) in enumerate(clients.items(), 1):
    logger.info(f"\n{'='*60}")
    logger.info(f"[{model_idx}/{len(clients)}] Evaluating model: {name}")
    logger.info(f"{'='*60}")
    
    # Prompt the LLM once per unique question with progress bar
    logger.info(f"Processing {len(unique_stems)} unique questions...")
    probs_unique = []
    
    with tqdm(total=len(unique_stems), desc=f"{name}", unit="question") as pbar:
        for i, (stem, choices) in enumerate(zip(unique_stems, unique_choices)):
            try:
                req = GenerationRequest(prompt=stem, choices=choices, temperature=0.2)
                probs = client.choose(req)
                probs_unique.append(probs)
                pbar.update(1)
                
                # Log every 10th question for detail
                if (i + 1) % 10 == 0:
                    logger.debug(f"  Processed {i+1}/{len(unique_stems)} questions")
                    
            except Exception as e:
                logger.error(f"  Error on question {i+1}: {str(e)}")
                probs_unique.append({})  # Append empty dict on error
                pbar.update(1)
    
    pred_unique = [argmax_choice(p) for p in probs_unique]
    logger.info(f"Completed question processing for {name}")

    # Map predictions/probs back to each original record
    logger.info("Mapping predictions to original records...")
    preds = []
    probs_full = []
    for sig in signatures:
        idx = unique_index_by_sig[sig]
        preds.append(pred_unique[idx])
        probs_full.append(probs_unique[idx])

    # Calculate accuracy
    acc = accuracy(gold, preds)
    acc_by_model[name] = acc
    logger.info(f"✓ {name} Accuracy: {acc:.4f} ({acc*100:.2f}%)")

    # Save detailed results
    logger.info("Saving results...")
    model_dir = results_dir / run_id / name / task["name"]
    model_dir.mkdir(parents=True, exist_ok=True)
    rows = []
    for r, probs, pred in zip(records, probs_full, preds):
        rows.append({
            "stem": r["stem"],
            "options": r["options"],
            "gold": r["answer"],
            "pred": pred,
            "probs": probs,
        })
    output_path = model_dir / "mcq_results.jsonl"
    write_jsonl(str(output_path), rows)
    logger.info(f"Results saved to: {output_path}")

logger.info(f"\n{'='*60}")
logger.info("FINAL RESULTS")
logger.info(f"{'='*60}")
for name, acc in acc_by_model.items():
    logger.info(f"{name:20s}: {acc:.4f} ({acc*100:.2f}%)")
logger.info(f"{'='*60}")

print("\nAccuracy by model:", acc_by_model)

15:52:59 - INFO - Starting MCQ probing run: 20251113-102259
15:52:59 - INFO - Total unique questions: 26
15:52:59 - INFO - Total records: 2537
15:52:59 - INFO - Models to evaluate: ['openrouter-deepseek']
15:52:59 - INFO - 
15:52:59 - INFO - [1/1] Evaluating model: openrouter-deepseek
15:52:59 - INFO - Processing 26 unique questions...
openrouter-deepseek:   0%|          | 0/26 [00:00<?, ?question/s]15:53:02 - INFO - HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
openrouter-deepseek:   4%|▍         | 1/26 [00:03<01:19,  3.17s/question]15:53:06 - INFO - HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
openrouter-deepseek:   8%|▊         | 2/26 [00:07<01:25,  3.58s/question]15:53:09 - INFO - HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
openrouter-deepseek:  12%|█▏        | 3/26 [00:09<01:11,  3.12s/question]15:53:12 - INFO - HTTP Request: POST https://openrouter.ai/api/v1/chat/comp

KeyboardInterrupt: 