# Judge Evaluation Pipeline Tutorial

## Learning Objectives

- ✅ Split datasets deterministically with hash-based partitioning
- ✅ Select balanced few-shot examples from training set
- ✅ Build judge prompts with few-shot calibration
- ✅ Calculate TPR (True Positive Rate) and TNR (True Negative Rate)
- ✅ Interpret judge bias and correction strategies
- ✅ Apply judgy library for statistical bias correction

## Estimated Time

**Execution:** 15-20 minutes (depends on API)

**Prerequisites:** Completed [Parallel Labeling Tutorial](parallel_labeling_tutorial.ipynb)

## ⚠️ API Cost Warning

**Estimated cost:** $0.10-0.30 for demo (20 traces with `gpt-4o-mini`)  
**Full dataset cost:** $1.00-2.00 for 150 traces  
**Estimated evaluation time:** 30-60 seconds for demo | 5-10 minutes for full dataset

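For this tutorial, we **limit evaluation to 20 traces** to keep costs under $0.30. To run on the full test set, pass `limit=None` to `evaluate_judge` in cell-11; deleting `limit=20` from the call is not enough, because the function's default is also 20.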

In [None]:
from pathlib import Path
import hashlib
import json
from typing import List, Dict, Any
from sklearn.metrics import confusion_matrix
import litellm
from pydantic import BaseModel
from dotenv import load_dotenv

load_dotenv()
print("✓ Setup complete")

## 1. Load Labeled Data

In [None]:
DATA_FILE = Path("nurtureboss_traces_labeled.json")

with open(DATA_FILE) as f:
    labeled_traces = json.load(f)

print(f"✓ Loaded {len(labeled_traces)} labeled conversations")

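# Count label distribution (records with a missing label are counted as neither)
pass_count = sum(1 for t in labeled_traces if t.get('all_responses_substantiated') is True)
fail_count = sum(1 for t in labeled_traces if t.get('all_responses_substantiated') is False)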

print(f"  PASS: {pass_count} ({pass_count/len(labeled_traces)*100:.1f}%)")
print(f"  FAIL: {fail_count} ({fail_count/len(labeled_traces)*100:.1f}%)")

## 2. Deterministic Dataset Splitting (Hash-Based)

In [None]:
def split_dataset(
    records: List[Dict],
    train_size: int = 20,
    dev_size: int = 30
) -> Dict[str, List[Dict]]:
    """Split dataset deterministically using ID hash.
    
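    Records are sorted by the SHA-256 hash of their ID and then sliced, so the
    same input set always produces the same splits. Adding or removing records
    can shift other records across a split boundary.
    """
    # Sort by hash of ID for deterministic order (str() tolerates non-string IDs)
    sorted_records = sorted(
        records,
        key=lambda r: hashlib.sha256(str(r['id']).encode()).hexdigest()
    )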
    
    train = sorted_records[:train_size]
    dev = sorted_records[train_size:train_size + dev_size]
    test = sorted_records[train_size + dev_size:]
    
    return {'train': train, 'dev': dev, 'test': test}

# Split data
splits = split_dataset(labeled_traces)

print(f"Dataset Splits:")
for name, data in splits.items():
    pass_pct = sum(1 for r in data if r.get('all_responses_substantiated'))/len(data)*100
    print(f"  {name.upper()}: {len(data)} traces ({pass_pct:.1f}% PASS)")
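
The sort-and-slice split above is reproducible for a fixed dataset, but a record's split can change when the dataset grows. A minimal sketch of an alternative, where each ID is mapped to a split from its hash alone, is shown below. The helper name `assign_split` and the percentage parameters are hypothetical, not part of the tutorial's code.

```python
import hashlib


def assign_split(record_id: str, train_pct: int = 10, dev_pct: int = 20) -> str:
    """Map an ID to a split using only its own hash (hypothetical helper)."""
    digest = hashlib.sha256(record_id.encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % 100  # stable bucket in [0, 100)
    if bucket < train_pct:
        return "train"
    if bucket < train_pct + dev_pct:
        return "dev"
    return "test"


# Toy IDs, made up for illustration
ids = [f"trace-{i}" for i in range(200)]
first = {i: assign_split(i) for i in ids}

# Adding unrelated records does not move existing ones
second = {i: assign_split(i) for i in ids + ["trace-new-1", "trace-new-2"]}
assert all(first[i] == second[i] for i in ids)
```

The trade-off is that split sizes become approximate percentages rather than the exact counts (`train_size`, `dev_size`) used in the tutorial.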

## 3. Select Few-Shot Examples

In [None]:
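def select_few_shot(
    train_data: List[Dict],
    num_pass: int = 1,
    num_fail: int = 1
) -> Dict[str, Dict]:
    """Select one PASS and one FAIL example from the training set.
    
    Note: num_pass and num_fail are currently unused; the function returns
    the first example of each class, or None if that class is empty.
    """
    
    # Compare against True/False explicitly so unlabeled records (None) are skipped
    pass_examples = [r for r in train_data if r.get('all_responses_substantiated') is True]
    fail_examples = [r for r in train_data if r.get('all_responses_substantiated') is False]
    
    return {
        'pass': pass_examples[0] if pass_examples else None,
        'fail': fail_examples[0] if fail_examples else None
    }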

# Get few-shot examples
few_shot = select_few_shot(splits['train'])

print("Few-Shot Examples Selected:")
print(f"  PASS example: {few_shot['pass']['id'] if few_shot['pass'] else 'None'}")
print(f"  FAIL example: {few_shot['fail']['id'] if few_shot['fail'] else 'None'}")
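
If more than one example per class is wanted, the selection can be generalized. The sketch below is one way to do it, assuming the same `all_responses_substantiated` label key; the function name `select_few_shot_many` is hypothetical, and it returns lists rather than single records, so downstream code that expects `few_shot['pass']['id']` would need adjusting.

```python
from typing import Dict, List


def select_few_shot_many(
    train_data: List[Dict],
    num_pass: int = 1,
    num_fail: int = 1,
    label_key: str = "all_responses_substantiated",
) -> Dict[str, List[Dict]]:
    """Return up to num_pass PASS and num_fail FAIL examples (hypothetical helper)."""
    # Skip records that have no label at all
    labeled = [r for r in train_data if r.get(label_key) is not None]
    passes = [r for r in labeled if r[label_key]]
    fails = [r for r in labeled if not r[label_key]]
    return {"pass": passes[:num_pass], "fail": fails[:num_fail]}


# Toy records, made up for illustration
toy_train = [
    {"id": "a", "all_responses_substantiated": True},
    {"id": "b", "all_responses_substantiated": False},
    {"id": "c", "all_responses_substantiated": True},
    {"id": "d"},  # unlabeled, skipped
]
picked = select_few_shot_many(toy_train, num_pass=2, num_fail=1)
print([r["id"] for r in picked["pass"]], [r["id"] for r in picked["fail"]])
```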

## 4. Build Judge Prompt

In [None]:
class JudgeResult(BaseModel):
    all_responses_substantiated: bool
    reason: str

def build_judge_prompt(
    messages: List[Dict],
    metadata: Dict,
    example_pass_conv: str,
    example_fail_conv: str
) -> str:
    """Construct judge prompt with few-shot examples."""
    
    conv = "\n".join(f"{m['role'].upper()}: {m['content']}" for m in messages)
    meta_str = json.dumps(metadata, indent=2) if metadata else "<none>"
    
    return f"""
You are evaluating substantiation in AI conversations.

PASS = All claims verified by user input, tool outputs, or metadata
FAIL = At least one unsubstantiated claim

Rules:
1. Courtesy statements don't need evidence ("How can I help?")
2. Paraphrases of tool outputs count as substantiated
3. Specific claims need specific evidence (e.g., "balcony" must appear in tools)
4. When uncertain, default to PASS

Few-Shot Examples:
--- PASS Example ---
{example_pass_conv}

--- FAIL Example ---
{example_fail_conv}
--- End Examples ---

Evaluate this conversation:

=== CONVERSATION ===
{conv}

=== METADATA ===
{meta_str}

Return JSON: {{"all_responses_substantiated": bool, "reason": str}}
""".strip()

# Test prompt construction
sample = splits['dev'][0]
test_prompt = build_judge_prompt(
    sample['messages'],
    {k:v for k,v in sample.items() if k not in ['id', 'messages', 'all_responses_substantiated', 'substantiation_rationale']},
    "USER: Find recipes\nAGENT: Searching database...",
    "USER: Tell me about A11\nAGENT: It has a balcony"
)
print(f"Prompt constructed ({len(test_prompt)} chars)")
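
The conversation formatting and the metadata filter appear several times in this notebook. A small sketch of two helpers that could factor them out is shown below; the names `format_conversation`, `extract_metadata`, and `NON_METADATA_KEYS` are hypothetical, and the excluded keys are copied from the cells above.

```python
from typing import Any, Dict, List

# Keys that are labels or conversation content rather than metadata
NON_METADATA_KEYS = {
    "id",
    "messages",
    "all_responses_substantiated",
    "substantiation_rationale",
}


def format_conversation(messages: List[Dict[str, str]]) -> str:
    """Render messages as 'ROLE: content' lines."""
    return "\n".join(f"{m['role'].upper()}: {m['content']}" for m in messages)


def extract_metadata(trace: Dict[str, Any]) -> Dict[str, Any]:
    """Keep every field of a trace that is not a label or the messages."""
    return {k: v for k, v in trace.items() if k not in NON_METADATA_KEYS}


# Toy trace, made up for illustration
toy_trace = {
    "id": "t1",
    "messages": [{"role": "user", "content": "Hi"}],
    "all_responses_substantiated": True,
    "property": "A11",
}
print(format_conversation(toy_trace["messages"]))
print(extract_metadata(toy_trace))
```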

## 5. Evaluate Judge on Test Set

In [None]:
def evaluate_judge(
    test_data: List[Dict],
    few_shot_examples: Dict,
    model: str = "gpt-4o-mini",
    limit: int = 20  # Limit for demo; pass None to evaluate all traces
) -> Dict:
    """Run judge on test set and calculate metrics."""
    
    # Format few-shot examples
    pass_conv = "\n".join(
        f"{m['role'].upper()}: {m['content']}"
        for m in few_shot_examples['pass']['messages']
    ) if few_shot_examples['pass'] else "<no example>"
    
    fail_conv = "\n".join(
        f"{m['role'].upper()}: {m['content']}"
        for m in few_shot_examples['fail']['messages']
    ) if few_shot_examples['fail'] else "<no example>"
    
    y_true = []
    y_pred = []
    
    # Limit for demo to save API costs
    test_subset = test_data[:limit]
    
    for trace in test_subset:
        # Ground truth
        truth = trace.get('all_responses_substantiated')
        if truth is None:
            continue
        
        # Build prompt
        prompt = build_judge_prompt(
            trace['messages'],
            {k:v for k,v in trace.items() if k not in ['id', 'messages', 'all_responses_substantiated', 'substantiation_rationale']},
            pass_conv,
            fail_conv
        )
        
        # Call judge
        response = litellm.completion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            response_format=JudgeResult,
            temperature=0
        )
        
        result = JudgeResult(**json.loads(response.choices[0].message.content))
        
        y_true.append(truth)
        y_pred.append(result.all_responses_substantiated)
        
        print(f"." if truth == result.all_responses_substantiated else "X", end="")
    
    print(f"\n\n✓ Evaluated {len(y_true)} traces")
    
    # Calculate confusion matrix (rows = ground truth, columns = prediction)
    cm = confusion_matrix(y_true, y_pred, labels=[True, False])
    # With labels=[True, False], index 0 is PASS (positive) and index 1 is FAIL (negative)
    tn, fp, fn, tp = cm[1,1], cm[1,0], cm[0,1], cm[0,0]
    
    tpr = tp / (tp + fn) if (tp + fn) > 0 else 0
    tnr = tn / (tn + fp) if (tn + fp) > 0 else 0
    accuracy = (tp + tn) / len(y_true) if y_true else 0
    
    return {
        'tpr': tpr,
        'tnr': tnr,
        'accuracy': accuracy,
        'tp': tp,
        'tn': tn,
        'fp': fp,
        'fn': fn
    }

# Run evaluation (limited to 20 for demo)
print("Evaluating judge (20 traces, ~30-60 seconds)...\n")
metrics = evaluate_judge(splits['test'], few_shot, limit=20)

print(f"\nJudge Performance:")
print(f"  Accuracy: {metrics['accuracy']*100:.1f}%")
print(f"  TPR (Sensitivity): {metrics['tpr']*100:.1f}%")
print(f"  TNR (Specificity): {metrics['tnr']*100:.1f}%")
print(f"\nConfusion Matrix:")
print(f"  TP: {metrics['tp']}  FN: {metrics['fn']}")
print(f"  FP: {metrics['fp']}  TN: {metrics['tn']}")
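
To make the metric definitions concrete, here is a small worked example in plain Python with PASS (`True`) as the positive class, matching the evaluation cell. The labels are made up and the helper name `judge_rates` is hypothetical.

```python
from typing import Dict, List


def judge_rates(y_true: List[bool], y_pred: List[bool]) -> Dict[str, float]:
    """Count TP/FN/FP/TN and derive TPR, TNR, and accuracy."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t and p)          # PASS judged PASS
    fn = sum(1 for t, p in pairs if t and not p)      # PASS judged FAIL
    fp = sum(1 for t, p in pairs if not t and p)      # FAIL judged PASS
    tn = sum(1 for t, p in pairs if not t and not p)  # FAIL judged FAIL
    return {
        "tp": tp, "fn": fn, "fp": fp, "tn": tn,
        "tpr": tp / (tp + fn) if (tp + fn) else 0.0,
        "tnr": tn / (tn + fp) if (tn + fp) else 0.0,
        "accuracy": (tp + tn) / len(pairs) if pairs else 0.0,
    }


# Made-up labels: 5 true PASS traces, 3 true FAIL traces
y_true = [True, True, True, False, False, True, False, True]
y_pred = [True, False, True, False, True, True, False, True]
rates = judge_rates(y_true, y_pred)
print(rates)
```

Here the judge catches 4 of 5 PASS traces (TPR = 0.8) and 2 of 3 FAIL traces (TNR ≈ 0.667), for an accuracy of 0.75.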

In [None]:
# ========================================
# VALIDATION: Verify judge metrics
# ========================================

# Assert metrics were calculated
assert 'tpr' in metrics, "TPR not calculated"
assert 'tnr' in metrics, "TNR not calculated"
assert 'accuracy' in metrics, "Accuracy not calculated"

# Assert metric ranges (probabilities: 0-1)
assert 0 <= metrics['tpr'] <= 1, f"Invalid TPR: {metrics['tpr']}"
assert 0 <= metrics['tnr'] <= 1, f"Invalid TNR: {metrics['tnr']}"
assert 0 <= metrics['accuracy'] <= 1, f"Invalid accuracy: {metrics['accuracy']}"

# Assert confusion matrix values are non-negative
for key in ['tp', 'tn', 'fp', 'fn']:
    assert metrics[key] >= 0, f"Invalid {key}: {metrics[key]}"

# Assert confusion matrix sums correctly
total_predictions = metrics['tp'] + metrics['tn'] + metrics['fp'] + metrics['fn']
assert total_predictions > 0, "No predictions made"
# Demo-only bound: raise or remove this check when evaluating with a larger limit
assert total_predictions <= 20, f"Too many predictions: {total_predictions} (expected ≤20 for demo)"

# Warn if judge performance is poor
if metrics['tpr'] < 0.75 or metrics['tnr'] < 0.75:
    print(f"⚠️  WARNING: Judge performance below 75% threshold")
    print(f"   TPR: {metrics['tpr']:.1%} | TNR: {metrics['tnr']:.1%}")
    print(f"   Consider refining prompt or using better model")

print(f"✅ VALIDATION PASSED:")
print(f"   - All metrics in valid range [0, 1]")
print(f"   - Confusion matrix: {total_predictions} predictions")
print(f"   - TPR: {metrics['tpr']:.1%} | TNR: {metrics['tnr']:.1%}")

## 6. Interpret Results

In [None]:
print("Interpretation:\n")

if metrics['tpr'] >= 0.85 and metrics['tnr'] >= 0.85:
    print("✅ GOOD: Judge is reliable (TPR & TNR ≥ 85%)")
    print("   Can use for automated evaluation with statistical correction")
elif metrics['tpr'] >= 0.75 and metrics['tnr'] >= 0.75:
    print("⚠️  ACCEPTABLE: Judge is usable but needs improvement")
    print("   Consider refining prompt or using better model")
else:
    print("❌ POOR: Judge is unreliable (TPR or TNR < 75%)")
    print("   Options: Better model, clearer criteria, or manual evaluation")

print(f"\nBias Analysis:")
if metrics['tpr'] < metrics['tnr'] - 0.1:
    print("  Judge is TOO STRICT (misses valid PASSes)")
    print("  → Add lenient examples to few-shot")
elif metrics['tnr'] < metrics['tpr'] - 0.1:
    print("  Judge is TOO LENIENT (misses FAILs)")
    print("  → Add strict examples to few-shot")
else:
    print("  Judge is BALANCED")
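
With only 20 demo traces, TPR and TNR are each estimated from a handful of examples, so a threshold such as 85% can be crossed by chance. One way to gauge this, sketched below, is a Wilson score interval on each rate. This is a supplementary check and not part of the tutorial's pipeline; the function name `wilson_interval` is hypothetical.

```python
import math
from typing import Tuple


def wilson_interval(successes: int, n: int, z: float = 1.96) -> Tuple[float, float]:
    """Wilson score interval for a binomial proportion (z=1.96 gives ~95%)."""
    if n <= 0:
        raise ValueError("n must be positive")
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return max(0.0, center - half), min(1.0, center + half)


# Example: the judge got 9 of 10 true-PASS traces right (TPR = 90%)
low, high = wilson_interval(9, 10)
print(f"TPR 9/10, approx. 95% interval: [{low:.1%}, {high:.1%}]")
```

To apply it to the notebook's results, pass `metrics['tp']` with `metrics['tp'] + metrics['fn']` for TPR, and `metrics['tn']` with `metrics['tn'] + metrics['fp']` for TNR.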

## 7. Bias Correction with judgy Library (Optional)

In [None]:
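# Requires: pip install judgy
# NOTE: the call below is illustrative; confirm the function name, arguments,
# and return fields against the documentation of the judgy version you install.
# from judgy import correct_rate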
#
# # Observed pass rate on large dataset
# p_obs = 0.75  # 75% of production traces judged as PASS
#
# # Correct using measured TPR/TNR
# corrected = correct_rate(
#     p_obs,
#     tpr=metrics['tpr'],
#     tnr=metrics['tnr'],
#     n=1000  # sample size
# )
#
# print(f"Observed rate: {p_obs*100:.1f}%")
# print(f"Corrected rate: {corrected['theta_hat']*100:.1f}%")
# print(f"95% CI: [{corrected['ci_lower']*100:.1f}%, {corrected['ci_upper']*100:.1f}%]")

print("⏭️  Install judgy to run bias correction: pip install judgy")
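
The idea behind the correction can be shown without any library. If the true pass rate is θ, the judge reports PASS at rate `p_obs = θ·TPR + (1 − θ)·(1 − TNR)`; solving for θ gives `(p_obs + TNR − 1) / (TPR + TNR − 1)`, a standard estimator often attributed to Rogan and Gladen. The sketch below implements only this point estimate. Whether judgy computes exactly this is an assumption, and the function name `correct_pass_rate` is hypothetical.

```python
def correct_pass_rate(p_obs: float, tpr: float, tnr: float) -> float:
    """Point estimate of the true pass rate from an observed judge pass rate."""
    denom = tpr + tnr - 1
    if denom <= 0:
        # A judge with TPR + TNR <= 1 is no better than chance; no correction possible
        raise ValueError("TPR + TNR must exceed 1 for the correction to be defined")
    theta = (p_obs + tnr - 1) / denom
    return min(1.0, max(0.0, theta))  # clip to a valid probability


# Made-up numbers: 75% judged PASS, judge has TPR = 0.90 and TNR = 0.80
corrected = correct_pass_rate(0.75, tpr=0.90, tnr=0.80)
print(f"Observed: 75.0% | Corrected: {corrected:.1%}")
```

With these made-up numbers the corrected rate is about 78.6%: the judge's missed PASSes outweigh its missed FAILs, so the observed rate understates the true one. This sketch gives no confidence interval.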

## Summary

### What You Learned

1. **Hash-based splitting** - Deterministic, reproducible train/dev/test
2. **Few-shot calibration** - Use training examples to guide judge
3. **TPR/TNR metrics** - Measure judge bias quantitatively
4. **Bias correction** - Adjust observed rates with judgy library

### Judge Quality Thresholds

| TPR and TNR | Quality | Action |
|-------------|---------|--------|
| Both ≥ 0.90 | Excellent | Deploy confidently |
| Both ≥ 0.85, lower one < 0.90 | Good | Use with correction |
| Both ≥ 0.75, lower one < 0.85 | Acceptable | Refine prompt |
| Either < 0.75 | Poor | Better model or manual eval |

### Next Steps

- Run on full test set (pass `limit=None` to `evaluate_judge`, and relax the ≤20 assertion in the validation cell)
- Iterate on prompt if TPR/TNR < 0.85
- Apply to production logs with judgy correction
- See [Bias Correction Tutorial](../homeworks/hw3/bias_correction_tutorial.md)

---

**Tutorial Status:** ✅ Complete  
**Last Updated:** 2025-10-30