# Test SelfReflectiveCritic

End-to-end test notebook for the NLI-based verification module.

Tests the `SelfReflectiveCritic` class which uses `roberta-large-mnli`
to verify rationale statements against context passages following the
Self-MedRAG algorithm.

**The first run will download the model (~1.4 GB).**

## 1. Setup & Import

In [None]:
import sys
import os
import time

# The notebook is expected to run from a directory one level below the
# project root; put that parent directory on sys.path so backend imports work.
_parent_dir = os.path.join(os.getcwd(), "..")
PROJECT_ROOT = os.path.abspath(_parent_dir)
if PROJECT_ROOT not in sys.path:
    # Prepend so the project's packages win over any same-named installed ones.
    sys.path.insert(0, PROJECT_ROOT)

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Patch settings to avoid .env validation errors.
# setdefault only fills in NLI_MODEL_NAME when it is not already present, so a
# value exported in the environment is left untouched.
# NOTE(review): this runs before the backend imports below, presumably because
# the backend validates its settings at import time — confirm.
os.environ.setdefault("NLI_MODEL_NAME", "roberta-large-mnli")

# Request/response schemas used to build inputs and to type-check outputs.
from backend.app.models.schemas import (
    VerificationRequest,
    VerificationResponse,
    StatementVerification,
)
# The class under test.
from backend.app.services.verification_service import SelfReflectiveCritic

print("Imports successful.")

In [None]:
# Build the critic once (this loads the model) and reuse it in every cell below.
critic = SelfReflectiveCritic()

# Report the configuration the remaining tests will run with.
startup_report = (
    f"\nDevice: {critic.device}",
    f"Sentence threshold (tau): {critic.tau}",
    f"Passage threshold (theta): {critic.theta}",
)
print(*startup_report, sep="\n")

---

## 2. Helper Function

In [None]:
def display_result(result, title=""):
    """Print a human-readable summary of a verification result.

    Args:
        result: Object exposing ``is_passed``, ``support_score``,
            ``supported_statements`` and ``unsupported_statements``. Each
            statement entry must expose ``statement`` and ``confidence_score``;
            supported entries must also expose ``best_passage``.
        title: Optional heading printed between two rule lines. Omitted when
            empty.
    """
    rule = "=" * 60

    def _preview(text, limit=60):
        # Truncate long passages to keep each entry on one line; falsy values
        # (None, "") are passed through unchanged.
        if text and len(text) > limit:
            return text[:limit] + "..."
        return text

    print(rule)
    if title:
        print(f"  {title}")
        print(rule)

    print(f"  Passed: {result.is_passed}")
    print(f"  Support Score (S_i): {result.support_score}")

    supported = result.supported_statements
    unsupported = result.unsupported_statements
    total = len(supported) + len(unsupported)
    print(f"  Supported: {len(supported)} / {total}")
    print()

    # Supported entries also show the passage that backed them.
    if supported:
        print("  [SUPPORTED]")
        for entry in supported:
            excerpt = _preview(entry.best_passage)
            print(f"    - [{entry.confidence_score:.4f}] {entry.statement}")
            print(f"      Best passage: {excerpt}")
        print()

    # Unsupported entries list only the score and the statement text.
    if unsupported:
        print("  [UNSUPPORTED]")
        for entry in unsupported:
            print(f"    - [{entry.confidence_score:.4f}] {entry.statement}")
        print()

    print(rule)
    print()

---

## 3. Test Case 1: Fully Supported

All statements are clearly entailed by the passages.

In [None]:
# Every claim below is stated, almost verbatim, in one of the passages.
aspirin_claims = [
    "Aspirin reduces inflammation.",
    "Aspirin is used to treat headaches.",
    "Aspirin works by inhibiting cyclooxygenase enzymes.",
]
aspirin_evidence = [
    "Aspirin (acetylsalicylic acid) is a nonsteroidal anti-inflammatory drug (NSAID) "
    "that reduces inflammation, pain, and fever. It works by irreversibly inhibiting "
    "cyclooxygenase (COX) enzymes, which are involved in prostaglandin synthesis.",
    "Aspirin is commonly used to relieve minor aches and pains including headaches, "
    "muscle pain, toothaches, and the common cold.",
]
request_1 = VerificationRequest(statements=aspirin_claims, passages=aspirin_evidence)

result_1 = critic.verify(request_1)
display_result(result_1, title="Test 1: Fully Supported (expected: is_passed=True)")

# The verdict must be a pass and the aggregate score must be at least 0.7.
assert result_1.is_passed is True, "Test 1 FAILED: expected is_passed=True"
assert result_1.support_score >= 0.7, "Test 1 FAILED: expected support_score >= 0.7"
print("Test 1 PASSED")

---

## 4. Test Case 2: Mixed Results

Some statements are supported, some are not.

In [None]:
# Two claims restate the passages; the weight-gain and antibiotic claims do not.
metformin_claims = [
    "Metformin is used to treat type 2 diabetes.",
    "Metformin works by decreasing hepatic glucose production.",
    "Metformin causes significant weight gain.",
    "Metformin is an antibiotic.",
]
metformin_evidence = [
    "Metformin is the first-line medication for the treatment of type 2 diabetes, "
    "particularly in overweight patients. It acts by decreasing hepatic glucose "
    "production and increasing insulin sensitivity.",
    "Common side effects of metformin include gastrointestinal symptoms such as "
    "diarrhea, nausea, and abdominal pain. Metformin is generally associated with "
    "weight neutrality or modest weight loss, not weight gain.",
]
request_2 = VerificationRequest(statements=metformin_claims, passages=metformin_evidence)

result_2 = critic.verify(request_2)
display_result(result_2, title="Test 2: Mixed (some supported, some not)")

# Both buckets must be populated for the result to count as "mixed".
unsupported_count_2 = len(result_2.unsupported_statements)
assert unsupported_count_2 > 0, "Test 2 FAILED: expected some unsupported"
supported_count_2 = len(result_2.supported_statements)
assert supported_count_2 > 0, "Test 2 FAILED: expected some supported"
print("Test 2 PASSED")

---

## 5. Test Case 3: Fully Unsupported

All statements are contradicted by or irrelevant to the passages.

In [None]:
# Each claim is directly contradicted by one of the passages.
contradicted_payload = {
    "statements": [
        "The Earth is flat.",
        "Water boils at 50 degrees Celsius at sea level.",
        "Humans can breathe underwater without equipment.",
    ],
    "passages": [
        "The Earth is an oblate spheroid, roughly spherical in shape, with a "
        "circumference of approximately 40,075 kilometers.",
        "Water boils at 100 degrees Celsius (212 degrees Fahrenheit) at standard "
        "atmospheric pressure at sea level.",
        "Humans require specialized equipment such as scuba gear or submarines "
        "to breathe underwater, as human lungs cannot extract dissolved oxygen "
        "from water.",
    ],
}
request_3 = VerificationRequest(**contradicted_payload)

result_3 = critic.verify(request_3)
display_result(result_3, title="Test 3: Fully Unsupported (expected: is_passed=False)")

# The overall verdict must be an explicit failure.
assert result_3.is_passed is False, "Test 3 FAILED: expected is_passed=False"
print("Test 3 PASSED")

---

## 6. Test Case 4: Medical Domain

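Realistic medical statements from a clinical context. The first three statements should be supported by the passages; the fourth (statins lowering blood sugar) is deliberately false and should not be.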

In [None]:
# Four clinical claims: the first three restate the passages, while the statin
# claim is explicitly contradicted by the last passage.
request_4 = VerificationRequest(
    statements=[
        "Hypertension is a major risk factor for stroke.",
        "ACE inhibitors lower blood pressure by blocking angiotensin-converting enzyme.",
        "Beta-blockers reduce heart rate and cardiac output.",
        "Statins are primarily used to lower blood sugar levels.",
    ],
    passages=[
        "Hypertension (high blood pressure) is one of the most significant risk factors "
        "for cardiovascular disease, including stroke, heart attack, and heart failure. "
        "Managing blood pressure through lifestyle changes and medication is crucial.",
        "ACE inhibitors (angiotensin-converting enzyme inhibitors) such as lisinopril and "
        "enalapril work by blocking the enzyme that converts angiotensin I to angiotensin "
        "II, thereby reducing vasoconstriction and lowering blood pressure.",
        "Beta-blockers (beta-adrenergic blocking agents) reduce heart rate, decrease "
        "cardiac output, and lower blood pressure. They are used to treat hypertension, "
        "angina, and arrhythmias.",
        "Statins (HMG-CoA reductase inhibitors) are medications primarily used to lower "
        "LDL cholesterol levels in the blood. They are not used for blood sugar control.",
    ],
)

result_4 = critic.verify(request_4)
display_result(result_4, "Test 4: Medical Domain")

# The first 3 should be supported, the 4th should not
print(f"Supported count: {len(result_4.supported_statements)}")
print(f"Unsupported count: {len(result_4.unsupported_statements)}")

# Previously this cell printed "PASSED" unconditionally. Check the expectations
# that matter so a regression in the critic is actually caught here.
assert len(result_4.supported_statements) > 0, (
    "Test 4 FAILED: expected at least one supported statement"
)
# Match on a substring so the check does not depend on whether the service
# echoes the statement text back unmodified.
assert not any(
    "Statins" in s.statement for s in result_4.supported_statements
), "Test 4 FAILED: the statin/blood-sugar claim must not be supported"
print("Test 4 PASSED (review results above)")

---

## 7. Test Case 5: Threshold Sensitivity

Vary tau and theta to observe how thresholds affect the verdict.

In [None]:
# Use the medical domain request from Test 4
# Test with different tau (sentence) thresholds while theta stays at its
# original value.

# Remember the configured thresholds; the theta sweep in the next cell reuses
# these names as well.
original_tau = critic.tau
original_theta = critic.theta

tau_values = [0.3, 0.5, 0.7, 0.9]

print(f"{'tau':>6} | {'theta':>6} | {'Supported':>10} | {'Score':>8} | {'Passed':>8}")
print("-" * 52)

try:
    for tau in tau_values:
        critic.tau = tau
        critic.theta = original_theta
        result = critic.verify(request_4)
        print(f"{tau:6.1f} | {critic.theta:6.1f} | "
              f"{len(result.supported_statements):10d} | "
              f"{result.support_score:8.4f} | "
              f"{str(result.is_passed):>8}")
finally:
    # Restore original values even if verify() raises, so the shared critic is
    # never left with a swept threshold for the cells that follow.
    critic.tau = original_tau
    critic.theta = original_theta

print("\nThreshold sensitivity test complete.")

In [None]:
# Test with different theta (passage/overall) thresholds while tau stays at
# its original value. Relies on original_tau / original_theta captured by the
# tau sweep cell above.
theta_values = [0.3, 0.5, 0.7, 0.9]

print(f"{'tau':>6} | {'theta':>6} | {'Supported':>10} | {'Score':>8} | {'Passed':>8}")
print("-" * 52)

try:
    for theta in theta_values:
        critic.tau = original_tau
        critic.theta = theta
        result = critic.verify(request_4)
        print(f"{critic.tau:6.1f} | {theta:6.1f} | "
              f"{len(result.supported_statements):10d} | "
              f"{result.support_score:8.4f} | "
              f"{str(result.is_passed):>8}")
finally:
    # Restore even if verify() raises, so later cells (benchmark, schema
    # validation) run against the originally configured thresholds.
    critic.tau = original_tau
    critic.theta = original_theta

print("\nTheta sensitivity test complete.")

---

## 8. Performance Benchmark

Measure average inference time per request, per statement, and per NLI (statement, passage) pair.

In [None]:
# One untimed call first so one-off start-up cost stays out of the timings.
_ = critic.verify(request_1)

# Time several identical calls on the Test 4 request.
num_runs = 5
times = []
for _run in range(num_runs):
    t0 = time.perf_counter()
    critic.verify(request_4)
    times.append(time.perf_counter() - t0)

avg_time = sum(times) / len(times)
num_statements = len(request_4.statements)
num_passages = len(request_4.passages)
# The pair count assumes every statement is scored against every passage.
num_pairs = num_statements * num_passages

per_statement = avg_time / num_statements
per_pair = avg_time / num_pairs

bench_rule = "=" * 50
bench_report = (
    bench_rule,
    "       PERFORMANCE BENCHMARK",
    bench_rule,
    f"  Device:             {critic.device}",
    f"  Statements:         {num_statements}",
    f"  Passages:           {num_passages}",
    f"  Total NLI pairs:    {num_pairs}",
    f"  Runs:               {num_runs}",
    f"  Avg total time:     {avg_time:.4f}s",
    f"  Avg per statement:  {per_statement:.4f}s",
    f"  Avg per NLI pair:   {per_pair:.4f}s",
    bench_rule,
)
print("\n".join(bench_report))

---

## 9. Schema Validation

In [None]:
# Re-run Test 1's request and check the response against the declared schema.
result = critic.verify(request_1)

assert isinstance(result, VerificationResponse), "Output must be VerificationResponse"

# Top-level fields and the Python type each one must have.
for field_name, expected_type in (
    ("is_passed", bool),
    ("support_score", float),
    ("supported_statements", list),
    ("unsupported_statements", list),
):
    assert isinstance(getattr(result, field_name), expected_type), (
        f"{field_name} must be {expected_type.__name__}"
    )

# Every entry, whichever bucket it landed in, must be well-formed.
all_entries = result.supported_statements + result.unsupported_statements
for entry in all_entries:
    assert isinstance(entry, StatementVerification), "Each item must be StatementVerification"
    assert isinstance(entry.statement, str), "statement must be str"
    assert entry.label in ("Supported", "Unsupported"), "label must be Supported or Unsupported"
    assert isinstance(entry.confidence_score, float), "confidence_score must be float"
    assert 0.0 <= entry.confidence_score <= 1.0, "confidence_score must be in [0, 1]"

# The response must also serialize to JSON; show only the beginning.
json_output = result.model_dump_json(indent=2)
print("JSON output (first 500 chars):")
print(json_output[:500])
print("\nSchema validation PASSED")

In [None]:
# Closing banner; only reached when every cell above ran without raising.
closing_rule = "=" * 50
print("", closing_rule, "  ALL TESTS COMPLETED SUCCESSFULLY", closing_rule, sep="\n")