# Stage 2 Evaluation Harness (Preview + Metrics)

This notebook wires the Stage 2 coverage preview directly into the full Stage 2 evaluation harness. It is designed for sanity checking, early warning, and final field-level precision/recall evaluation.

In [None]:

# Standard library imports
import json
from collections import defaultdict
from typing import Dict


## Configuration

In [None]:

# Input files for this evaluation run. Both are relative paths, so they are
# resolved against the kernel's current working directory.
# GOLD_PATH: JSON file with the gold (reference) records.
# PRED_PATH: JSON file with the predicted records for the same reference_ids.
GOLD_PATH = "training_data/stage2_gold_val.json"
PRED_PATH = "training_data/stage2_pred_val.json"


## Load and Index Datasets

In [None]:

def load_and_index(gold_path: str, pred_path: str):
    """Load the gold and prediction JSON files and index both by ``reference_id``.

    Each file is read as UTF-8 JSON and iterated as a sequence of records,
    every record carrying a ``reference_id`` key.

    Args:
        gold_path: Path to the gold-annotation JSON file.
        pred_path: Path to the prediction JSON file.

    Returns:
        A ``(gold_by_id, pred_by_id)`` tuple of dicts, each mapping a
        ``reference_id`` to its full record.

    Raises:
        ValueError: If either file repeats a ``reference_id``, or if the two
            files do not cover exactly the same set of ids.
        KeyError: If a record has no ``reference_id`` key.
    """

    def _preview(ids, limit=5):
        # Sort on repr so ids of mixed types can still be ordered for display.
        shown = sorted(repr(i) for i in ids)
        suffix = ", ..." if len(shown) > limit else ""
        return ", ".join(shown[:limit]) + suffix

    def _read_indexed(path, label):
        with open(path, "r", encoding="utf-8") as f:
            records = json.load(f)

        indexed = {}
        duplicates = set()
        for record in records:
            ref_id = record["reference_id"]
            if ref_id in indexed:
                duplicates.add(ref_id)
            indexed[ref_id] = record

        # A plain dict comprehension would silently keep only the last record
        # for a repeated id, which would distort every downstream count.
        if duplicates:
            raise ValueError(
                f"Duplicate reference_id in {label} file {path!r}: "
                f"{_preview(duplicates)}"
            )
        return indexed

    gold_by_id = _read_indexed(gold_path, "gold")
    pred_by_id = _read_indexed(pred_path, "prediction")

    missing = gold_by_id.keys() - pred_by_id.keys()
    unexpected = pred_by_id.keys() - gold_by_id.keys()
    if missing or unexpected:
        details = []
        if missing:
            details.append(
                f"{len(missing)} missing from predictions ({_preview(missing)})"
            )
        if unexpected:
            details.append(
                f"{len(unexpected)} only in predictions ({_preview(unexpected)})"
            )
        raise ValueError(
            "Gold / prediction reference_id mismatch: " + "; ".join(details)
        )

    return gold_by_id, pred_by_id


# Read both files named in the configuration and index them by reference_id.
# load_and_index raises ValueError when the two files disagree on the id set.
gold_by_id, pred_by_id = load_and_index(
    gold_path=GOLD_PATH,
    pred_path=PRED_PATH,
)

print(f"Loaded {len(gold_by_id)} records")


## Stage 2 Coverage Preview (Pre-flight Check)

In [None]:

def coverage_preview(gold_by_id: Dict, pred_by_id: Dict, min_coverage: float = 0.5):
    """Count, per field, how often gold and predictions carry a non-empty value.

    Only fields listed in each gold record are visited; the prediction for the
    same record and field is looked up alongside. A value counts as filled
    unless it equals ``None``, ``""`` or ``[]``. A field object that is missing
    or is itself ``None`` (JSON ``null``) counts as unfilled.

    Args:
        gold_by_id: Mapping of reference_id to gold record; each record is read
            at ``["stage2_gold"]["fields"]``.
        pred_by_id: Mapping of reference_id to prediction record; each record
            is read at ``["stage2_output"]["fields"]``. Must contain every key
            of ``gold_by_id``.
        min_coverage: Warn for a field when ``pred_filled / gold_filled`` falls
            below this ratio. Defaults to 0.5.

    Returns:
        A ``(field_stats, warnings)`` tuple. ``field_stats`` maps each field
        that was filled at least once (in gold or prediction) to
        ``{"gold_filled": int, "pred_filled": int}``. ``warnings`` is a list of
        human-readable low-coverage messages.
    """

    def _is_filled(field_obj) -> bool:
        # An explicit null field object has no "value" to inspect.
        if field_obj is None:
            return False
        return field_obj.get("value") not in (None, "", [])

    field_stats = defaultdict(lambda: {"gold_filled": 0, "pred_filled": 0})
    warnings = []

    for ref_id in gold_by_id:
        gold_fields = gold_by_id[ref_id]["stage2_gold"]["fields"]
        pred_fields = pred_by_id[ref_id]["stage2_output"]["fields"]

        for field, gold_obj in gold_fields.items():
            # Entries are created lazily on first increment, so a field that is
            # never filled on either side does not appear in the stats.
            if _is_filled(gold_obj):
                field_stats[field]["gold_filled"] += 1

            if _is_filled(pred_fields.get(field)):
                field_stats[field]["pred_filled"] += 1

    for field, stats in field_stats.items():
        # The ratio is undefined when gold never fills the field; skip it.
        if stats["gold_filled"] > 0:
            coverage = stats["pred_filled"] / stats["gold_filled"]
            if coverage < min_coverage:
                warnings.append(
                    f"Low coverage for field '{field}': {coverage:.2f}"
                )

    return field_stats, warnings


# Run the pre-flight coverage check, then print one line of fill counts per
# field followed by any low-coverage warnings.
field_stats, warnings = coverage_preview(gold_by_id, pred_by_id)

print("=== Coverage Preview ===")
for field, stats in field_stats.items():
    print(
        f"{field:<15} | Gold filled: {stats['gold_filled']:<3} "
        f"| Pred filled: {stats['pred_filled']:<3}"
    )

if not warnings:
    print("\nNo coverage warnings.")
else:
    print("\nWarnings:")
    for message in warnings:
        print("⚠️", message)


## Full Stage 2 Field-Level Evaluation

In [None]:

def evaluate_stage2(gold_by_id: Dict, pred_by_id: Dict):
    """Compute per-field precision, recall and F1 by exact value match.

    For every field listed in a gold record, the gold value is compared with
    the predicted value for the same record. A value is empty when it equals
    ``None``, ``""`` or ``[]``; a field object that is missing or is itself
    ``None`` (JSON ``null``) is treated as having no value.

    Counting rules:
        * gold filled and prediction equal to gold   -> true positive
        * gold filled and prediction differs         -> false negative
        * prediction filled and differs from gold    -> false positive

    A wrong, non-empty prediction for a filled gold value therefore counts as
    both a false negative and a false positive.

    Args:
        gold_by_id: Mapping of reference_id to gold record; each record is read
            at ``["stage2_gold"]["fields"]``.
        pred_by_id: Mapping of reference_id to prediction record; each record
            is read at ``["stage2_output"]["fields"]``. Must contain every key
            of ``gold_by_id``.

    Returns:
        Dict mapping each field that received at least one count to
        ``{"precision", "recall", "f1", "tp", "fp", "fn"}``; the three ratios
        are rounded to 3 decimals and are 0.0 when their denominator is zero.
    """

    def _value_of(field_obj):
        # An explicit null field object has no "value" to read.
        return None if field_obj is None else field_obj.get("value")

    empty_values = (None, "", [])
    results = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0})

    for ref_id in gold_by_id:
        gold_fields = gold_by_id[ref_id]["stage2_gold"]["fields"]
        pred_fields = pred_by_id[ref_id]["stage2_output"]["fields"]

        # NOTE(review): fields present only in the prediction are never
        # visited here, so they are not scored as false positives — confirm
        # that gold records always list the full field schema.
        for field, gold_obj in gold_fields.items():
            gold_val = _value_of(gold_obj)
            pred_val = _value_of(pred_fields.get(field))

            if gold_val not in empty_values:
                if pred_val == gold_val:
                    results[field]["tp"] += 1
                else:
                    results[field]["fn"] += 1

            if pred_val not in empty_values and pred_val != gold_val:
                results[field]["fp"] += 1

    metrics = {}
    for field, r in results.items():
        tp, fp, fn = r["tp"], r["fp"], r["fn"]

        # Guard each ratio against a zero denominator.
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall)
            else 0.0
        )

        metrics[field] = {
            "precision": round(precision, 3),
            "recall": round(recall, 3),
            "f1": round(f1, 3),
            "tp": tp,
            "fp": fp,
            "fn": fn
        }

    return metrics


# Score the predictions field by field and print each field's metrics dict.
metrics = evaluate_stage2(gold_by_id, pred_by_id)

print("=== Stage 2 Evaluation Metrics ===")
for field in metrics:
    print(field, metrics[field])


## Combined Evaluation Report

In [None]:

# Assemble the coverage preview and the field metrics into one structure and
# print it as indented JSON.
report = {}
report["coverage_preview"] = {"field_stats": field_stats, "warnings": warnings}
report["metrics"] = metrics

print(json.dumps(report, indent=2))
