# Ablation Study 2: Gemini Extraction + Kimi-K2 Stack Depth

## Objective
Hold the extractor fixed at Gemini-2.5 (the best performer in Ablation 1), cache its first-pass output as a CSV, and then measure how stacking 2, 3, or 5 rounds of the Kimi-K2 verifier changes extraction quality and latency.


In [20]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import io
import time
from typing import List, Dict, Any, Optional
import warnings
warnings.filterwarnings('ignore')

# Plot defaults applied to every figure drawn in this notebook.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6), 'font.size': 10})

# Input and output locations; everything lives under one project directory.
BASE_DIR = "/Users/guoshuyan/Desktop/OpenAD"
GOLD_DATA_CSV = os.path.join(BASE_DIR, "gold_annotation.csv")
FIRST_PASS_OUTPUT = os.path.join(BASE_DIR, "ablation2_gemini_round1.csv")

# Models under study and the verifier stack depths to compare.
EXTRACTOR_MODEL = 'Gemini-2.5'
VERIFIER_MODEL = 'Kimi-K2-Thinking'
VERIFIER_ROUNDS = [2, 3, 5]

# Column order of one extracted criterion; the same order appears in the extractor prompt.
SCHEMA_FIELDS = [
    'trial_id',
    'source_sentence',
    'criterion_type',
    'ad_domain',
    'clinical_concept',
    'operator',
    'value_lower',
    'value_upper',
    'units',
    'diagnostic_framework',
    'severity_stage',
    'temporal_scope',
    'evidence_type',
    'certainty',
]


## Extractor and Verifier Functions (from Ablation Study 1)


In [21]:
# Extractor/verifier prompt builders and mocked LLM calls, defined here so the notebook is self-contained.
# (In practice, import them from the ablation1 module instead of redefining them.)

def create_extractor_prompt(trial_id: str, criterion_type: str, source_sentence: str) -> str:
    """Build the extraction prompt for a single eligibility sentence.

    The prompt asks for a CSV table with exactly one header row and one data
    row in a fixed column order, and embeds the three arguments verbatim.

    Args:
        trial_id: Identifier inserted after ``TRIAL_ID:``.
        criterion_type: Value inserted after ``CRITERION_TYPE:``.
        source_sentence: Sentence inserted after ``SENTENCE:``.

    Returns:
        The complete prompt text, ending with a newline.
    """
    template = """You are an expert in Alzheimer's disease (AD) clinical trial eligibility extraction.

Your task is to extract ONE structured criterion from a single sentence and output it as CSV.

You MUST output a CSV table with:
- EXACTLY one header row
- EXACTLY one data row

The columns MUST be in this exact order:

trial_id,source_sentence,criterion_type,ad_domain,clinical_concept,operator,value_lower,value_upper,units,diagnostic_framework,severity_stage,temporal_scope,evidence_type,certainty

TRIAL_ID: {trial_id}
CRITERION_TYPE: {criterion_type}
SENTENCE: {source_sentence}

Output ONLY the CSV header row and the single data row.
"""
    # Substituted values are inserted as-is; braces inside them are not re-interpreted.
    return template.format(
        trial_id=trial_id,
        criterion_type=criterion_type,
        source_sentence=source_sentence,
    )

def create_verifier_prompt(source_sentence: str, extracted_csv_row: str) -> str:
    """Build the verification prompt for one extracted CSV row.

    The prompt shows the original sentence next to the extracted row and asks
    for a per-field CSV verdict with the columns
    ``field,present_in_text,is_correct,correct_value``.

    Args:
        source_sentence: Sentence inserted after ``ORIGINAL SENTENCE:``.
        extracted_csv_row: Row text inserted after ``EXTRACTED CSV ROW:``.

    Returns:
        The complete prompt text, ending with a newline.
    """
    template = """You are verifying the correctness of a structured extraction of an AD clinical trial eligibility criterion.

ORIGINAL SENTENCE: {source_sentence}
EXTRACTED CSV ROW: {extracted_csv_row}

For EACH field, output: field,present_in_text,is_correct,correct_value

Output ONLY the CSV table.
"""
    return template.format(
        source_sentence=source_sentence,
        extracted_csv_row=extracted_csv_row,
    )

def call_llm_extractor(
    model_name: str,
    trial_id: str,
    criterion_type: str,
    source_sentence: str,
) -> Dict[str, Any]:
    """Return a placeholder extraction record (mocked; no model is called).

    Args:
        model_name: Stored under the ``'model'`` key of the record.
        trial_id: Copied into the record unchanged.
        criterion_type: Copied into the record unchanged.
        source_sentence: Copied into the record unchanged.

    Returns:
        A dict holding every ``SCHEMA_FIELDS`` key plus ``'model'``. Fields the
        mock does not fill are empty strings; the remaining fields carry fixed
        placeholder values.
    """
    # Seed every schema column with a blank so key order follows SCHEMA_FIELDS.
    placeholder: Dict[str, Any] = dict.fromkeys(SCHEMA_FIELDS, '')
    placeholder.update({
        'trial_id': trial_id,
        'criterion_type': criterion_type,
        'source_sentence': source_sentence,
        'ad_domain': 'unspecified',
        'clinical_concept': 'unknown',
        'operator': '=',
        'value_lower': '',
        'value_upper': '',
        'units': '',
        'diagnostic_framework': '',
        'severity_stage': '',
        'temporal_scope': '',
        'evidence_type': 'text',
        'certainty': 'unknown',
        'model': model_name,
    })
    return placeholder

def call_llm_verifier(model_name: str, prompt: str) -> pd.DataFrame:
    """Return a canned verification table (mocked; no model is called).

    Both arguments are accepted for interface compatibility and ignored.

    Returns:
        A DataFrame parsed from a fixed CSV with the columns ``field``,
        ``present_in_text``, ``is_correct`` and ``correct_value``. It has seven
        rows, each marked as present and correct.
    """
    # Fixed response: header line followed by one verdict per field.
    canned_lines = [
        'field,present_in_text,is_correct,correct_value',
        'trial_id,true,true,NCT123',
        'source_sentence,true,true,"MMSE score ≥ 20"',
        'criterion_type,true,true,inclusion',
        'ad_domain,true,true,cognitive',
        'clinical_concept,true,true,MMSE',
        'operator,true,true,≥',
        'value_lower,true,true,20',
    ]
    canned_csv = '\n'.join(canned_lines)
    return pd.read_csv(io.StringIO(canned_csv))

def parse_csv_response(csv_text: str) -> Dict[str, Any]:
    """Parse a CSV extraction response and return its first data row.

    Args:
        csv_text: CSV text whose first line is the header.

    Returns:
        The first data row as a header-to-value mapping, or an empty dict when
        the text contains no data row.
    """
    rows = csv.DictReader(io.StringIO(csv_text))
    for first_row in rows:
        # Only the first data row is wanted; later rows are ignored.
        return first_row
    return {}

def _is_marked_incorrect(flag: Any) -> bool:
    """Return True when a verifier ``is_correct`` cell means "incorrect".

    pandas only parses the column as booleans when every cell is a recognised
    boolean token; otherwise the cells stay as strings, so textual flags are
    interpreted here as well.
    """
    if isinstance(flag, str):
        return flag.strip().lower() in {'false', 'f', 'no', 'n', '0'}
    if pd.isna(flag):
        # A missing verdict is not treated as a request for correction.
        return False
    return not bool(flag)


def apply_verifier_corrections(extracted: Dict[str, Any], verification_df: pd.DataFrame) -> Dict[str, Any]:
    """Apply the verifier's corrections to a copy of an extraction record.

    For every row of ``verification_df`` whose ``is_correct`` cell marks the
    field as incorrect (boolean ``False``, a falsy number, or a textual flag
    such as ``"false"``), the field's value is replaced by ``correct_value``
    when that value is present (not NaN/None).

    Args:
        extracted: Extraction record keyed by field name; it is not modified.
        verification_df: Verifier output with ``field`` and ``is_correct``
            columns and, optionally, a ``correct_value`` column.

    Returns:
        A shallow copy of ``extracted`` with the corrections applied.

    Raises:
        KeyError: If ``verification_df`` lacks the ``field`` or ``is_correct``
            column.
    """
    corrected = extracted.copy()
    for _, row in verification_df.iterrows():
        if not _is_marked_incorrect(row['is_correct']):
            continue
        # A missing column yields None, which is skipped like a NaN cell.
        replacement = row.get('correct_value')
        if pd.notna(replacement):
            corrected[row['field']] = replacement
    return corrected


## Stacked Extraction Pipeline


In [22]:
def _format_csv_row(record: Dict[str, Any], fields: List[str]) -> str:
    """Serialize ``record`` as one quoted CSV line (no trailing newline)."""
    cells = []
    for name in fields:
        value = record.get(name, '')
        # None and NaN (e.g. blanks read back from a cached CSV) become empty
        # cells instead of the literal text "None"/"nan".
        is_missing = value is None or (isinstance(value, float) and value != value)
        cells.append('' if is_missing else str(value))

    buffer = io.StringIO()
    # csv.writer quotes cells containing commas, quotes or newlines, so the
    # row keeps its column alignment.
    csv.writer(buffer, lineterminator='').writerow(cells)
    return buffer.getvalue()


def run_stacked_extraction(
    trial_id: str,
    criterion_type: str,
    source_sentence: str,
    verifier_rounds: int,
    initial_result: Optional[Dict[str, Any]] = None,
) -> tuple:
    """Run a first-pass extraction followed by stacked verifier rounds.

    Args:
        trial_id: Trial identifier passed to the extractor.
        criterion_type: Criterion type passed to the extractor.
        source_sentence: Sentence being extracted and verified.
        verifier_rounds: Number of verifier passes to apply in sequence.
        initial_result: Existing first-pass record to start from. When
            ``None`` the extractor is called with ``EXTRACTOR_MODEL``.

    Returns:
        A 4-tuple ``(final_record, verification_frames, elapsed_seconds,
        tokens_used)``. ``verification_frames`` holds one DataFrame per round,
        each tagged with a 1-based ``verifier_round`` column. ``tokens_used``
        is a rough estimate: whitespace-separated prompt words plus a fixed
        allowance per call.
    """
    # Monotonic clock: the elapsed time cannot be skewed by system clock changes.
    start_time = time.perf_counter()
    tokens_used = 0

    if initial_result is None:
        extractor_prompt = create_extractor_prompt(trial_id, criterion_type, source_sentence)
        extracted = call_llm_extractor(EXTRACTOR_MODEL, trial_id, criterion_type, source_sentence)
        tokens_used += len(extractor_prompt.split()) + 100  # Estimate
    else:
        extracted = initial_result.copy()

    current_result = extracted.copy()
    all_verifications = []

    # Each round verifies the output of the previous round.
    for round_idx in range(verifier_rounds):
        csv_row = _format_csv_row(current_result, SCHEMA_FIELDS)
        verifier_prompt = create_verifier_prompt(source_sentence, csv_row)
        verification_df = call_llm_verifier(VERIFIER_MODEL, verifier_prompt)
        verification_df['verifier_round'] = round_idx + 1
        tokens_used += len(verifier_prompt.split()) + 50  # Estimate

        all_verifications.append(verification_df)
        current_result = apply_verifier_corrections(current_result, verification_df)

    elapsed_time = time.perf_counter() - start_time

    return current_result, all_verifications, elapsed_time, tokens_used

def _canonical_value(value: Any) -> str:
    """Return a comparison key for a field value.

    Text is stripped and lower-cased. Finite numbers are reduced to one
    spelling, so ``'20'``, ``'20.0'`` and ``20.0`` all compare equal; pandas
    reads numeric gold columns as floats, which otherwise never match a
    prediction written as ``'20'``.
    """
    text = str(value).strip().lower()
    try:
        number = float(text)
    except ValueError:
        return text
    if not np.isfinite(number):
        # 'nan' / 'inf' spellings are kept as plain text.
        return text
    return repr(number)


def calculate_field_accuracy(predicted: Dict[str, Any], gold: pd.Series) -> float:
    """Return the fraction of non-empty gold fields that the prediction matches.

    Args:
        predicted: Predicted record keyed by field name.
        gold: Gold row; NaN/None entries are ignored.

    Returns:
        Matches divided by the number of non-empty gold fields. A gold field
        that is absent from ``predicted`` counts as a miss. Returns ``1.0``
        when the gold row has no non-empty fields.
    """
    gold_non_empty = gold[gold.notna()]
    if len(gold_non_empty) == 0:
        return 1.0

    correct = sum(
        1
        for field, gold_value in gold_non_empty.items()
        if field in predicted
        and _canonical_value(predicted[field]) == _canonical_value(gold_value)
    )
    return correct / len(gold_non_empty)


def calculate_over_correction_rate(initial: Dict[str, Any], final: Dict[str, Any], gold: pd.Series) -> float:
    """Return the share of gold fields that went from correct to incorrect.

    A field is over-corrected when ``initial`` matched the gold value and
    ``final`` does not. Fields are counted one by one, so a field repaired
    elsewhere cannot cancel out a field that was broken.

    Args:
        initial: Record before the verifier rounds.
        final: Record after the verifier rounds.
        gold: Gold row; only ``SCHEMA_FIELDS`` with a non-empty gold value
            are considered.

    Returns:
        Over-corrected fields divided by the number of considered fields, or
        ``0.0`` when no field has a gold value.
    """
    total_fields = 0
    over_corrected = 0

    for field in SCHEMA_FIELDS:
        gold_value = gold.get(field, np.nan)
        if pd.isna(gold_value):
            continue
        total_fields += 1

        target = _canonical_value(gold_value)
        was_correct = _canonical_value(initial.get(field, '')) == target
        still_correct = _canonical_value(final.get(field, '')) == target
        if was_correct and not still_correct:
            over_corrected += 1

    if total_fields == 0:
        return 0.0
    return over_corrected / total_fields

def calculate_disagreement_reduction(verifications: List[pd.DataFrame]) -> float:
    """Return the share of schema fields on which all verifier rounds agree.

    For each field in ``SCHEMA_FIELDS`` the ``is_correct`` verdict of the
    first matching row in every verification frame is collected; a field is
    contested when more than one distinct verdict is seen. Fields no frame
    mentions count as uncontested.

    Args:
        verifications: One verification DataFrame per verifier round.

    Returns:
        ``1.0`` minus the contested fraction of ``SCHEMA_FIELDS``; ``0.0``
        when fewer than two frames are supplied or the schema is empty.
    """
    if len(verifications) < 2:
        return 0.0

    schema_size = len(SCHEMA_FIELDS)
    contested = 0

    for name in SCHEMA_FIELDS:
        verdicts = set()
        for frame in verifications:
            matches = frame[frame['field'] == name]
            if len(matches) > 0:
                verdicts.add(matches.iloc[0].get('is_correct', True))

        # More than one distinct verdict means the rounds disagreed.
        if len(verdicts) > 1:
            contested += 1

    if schema_size > 0:
        return 1.0 - (contested / schema_size)
    return 0.0


## Gemini First-Pass Extraction
Generate (and cache) the single-pass Gemini CSV that seeds all verifier stacks.



In [23]:
def _cache_matches_dataset(cached: pd.DataFrame, dataset: pd.DataFrame) -> bool:
    """Return True when cached rows carry the same identifying keys as ``dataset``."""
    key_columns = ['trial_id', 'criterion_type', 'source_sentence']
    for column in key_columns:
        if column not in cached.columns or column not in dataset.columns:
            return False
    # Compare as text and by position, so index labels and dtype differences
    # introduced by the CSV round-trip do not matter.
    cached_keys = cached[key_columns].astype(str).values.tolist()
    dataset_keys = dataset[key_columns].astype(str).values.tolist()
    return cached_keys == dataset_keys


def run_gemini_first_pass(dataset: pd.DataFrame, output_path: str = FIRST_PASS_OUTPUT) -> pd.DataFrame:
    """Run the extractor once per dataset row, persist the CSV, and return it.

    A CSV already present at ``output_path`` is reused only when it has the
    same number of rows as ``dataset`` and the same ``trial_id`` /
    ``criterion_type`` / ``source_sentence`` values in the same order;
    otherwise the outputs are regenerated and the file is overwritten.

    Args:
        dataset: Rows to extract; must provide ``trial_id``,
            ``criterion_type`` and ``source_sentence`` columns.
        output_path: Location of the cached first-pass CSV.

    Returns:
        A DataFrame with one row per dataset row and the ``SCHEMA_FIELDS``
        columns. When served from the cache, the frame is whatever
        ``pd.read_csv`` returns for the stored file.
    """
    if os.path.exists(output_path):
        cached = pd.read_csv(output_path)
        if len(cached) != len(dataset):
            print("Cached Gemini outputs found but row count mismatches dataset; re-generating.")
        elif not _cache_matches_dataset(cached, dataset):
            # Same length but different rows: a stale cache from another sample.
            print("Cached Gemini outputs found but rows do not match dataset; re-generating.")
        else:
            print(f"Found cached Gemini outputs at {output_path}; reusing them.")
            return cached

    outputs: List[Dict[str, Any]] = []

    for _, row in dataset.iterrows():
        extraction = call_llm_extractor(
            EXTRACTOR_MODEL,
            row['trial_id'],
            row['criterion_type'],
            row['source_sentence'],
        )

        # Keep only schema columns and pin the identifying fields to the
        # dataset's own values so later lookups by key succeed.
        normalized = {field: extraction.get(field, '') for field in SCHEMA_FIELDS}
        normalized['trial_id'] = row['trial_id']
        normalized['criterion_type'] = row['criterion_type']
        normalized['source_sentence'] = row['source_sentence']
        outputs.append(normalized)

    first_pass_df = pd.DataFrame(outputs)
    first_pass_df.to_csv(output_path, index=False)
    print(f"Saved Gemini first-pass outputs to {output_path} ({len(first_pass_df)} rows)")
    return first_pass_df



In [24]:
def evaluate_stack_depth(
    trial_id: str,
    criterion_type: str,
    source_sentence: str,
    gold_row: pd.Series,
    verifier_rounds: int,
    initial_extraction: Dict[str, Any],
) -> Dict[str, Any]:
    """Score one sentence before and after a stack of verifier rounds.

    Args:
        trial_id: Trial identifier, echoed in the result.
        criterion_type: Criterion type forwarded to the stacked pipeline.
        source_sentence: Sentence forwarded to the stacked pipeline.
        gold_row: Gold annotation row used for scoring.
        verifier_rounds: Number of verifier rounds to stack.
        initial_extraction: First-pass record; a copy is used as the baseline.

    Returns:
        A dict with the stack depth, trial id, baseline and stacked accuracy,
        their difference, the over-correction rate, the disagreement metric,
        the elapsed time, and the token estimate.
    """
    baseline = initial_extraction.copy()
    baseline_accuracy = calculate_field_accuracy(baseline, gold_row)

    refined, verifier_frames, seconds, token_estimate = run_stacked_extraction(
        trial_id,
        criterion_type,
        source_sentence,
        verifier_rounds,
        initial_result=baseline,
    )
    refined_accuracy = calculate_field_accuracy(refined, gold_row)

    metrics: Dict[str, Any] = {
        'verifier_rounds': verifier_rounds,
        'trial_id': trial_id,
        'single_pass_accuracy': baseline_accuracy,
        'stacked_accuracy': refined_accuracy,
        'stack_improvement': refined_accuracy - baseline_accuracy,
        'over_correction_rate': calculate_over_correction_rate(baseline, refined, gold_row),
        'disagreement_reduction': calculate_disagreement_reduction(verifier_frames),
        'elapsed_time': seconds,
        'tokens_used': token_estimate,
    }
    return metrics


def run_ablation2_evaluation(sample_size: Optional[int] = None) -> pd.DataFrame:
    """Run Ablation 2: cached first-pass extraction, then stacked verifiers.

    Loads the gold annotations, obtains (or reuses) the first-pass
    extractions, and evaluates every depth in ``VERIFIER_ROUNDS`` for each
    gold row.

    Args:
        sample_size: When given, only the first ``sample_size`` gold rows are
            evaluated.

    Returns:
        A DataFrame with one row per (stack depth, gold row) pair, holding
        the metrics produced by ``evaluate_stack_depth``.

    Raises:
        FileNotFoundError: If the gold CSV does not exist.
        ValueError: If the gold dataset has no rows.
    """
    if not os.path.exists(GOLD_DATA_CSV):
        raise FileNotFoundError(f"Gold dataset not found at {GOLD_DATA_CSV}")

    gold_df = pd.read_csv(GOLD_DATA_CSV)
    if sample_size is not None:
        gold_df = gold_df.head(sample_size)

    if gold_df.empty:
        raise ValueError(f"Gold dataset is empty – check {GOLD_DATA_CSV}")

    first_pass_df = run_gemini_first_pass(gold_df, output_path=FIRST_PASS_OUTPUT)
    # Index first-pass rows by their identifying triple for direct lookup.
    first_pass_lookup = {
        (row['trial_id'], row['criterion_type'], row['source_sentence']): row.to_dict()
        for _, row in first_pass_df.iterrows()
    }

    all_results = []
    for verifier_rounds in VERIFIER_ROUNDS:
        print(f"Evaluating {verifier_rounds} Kimi verifier rounds…")
        skipped = 0
        for _, gold_row in gold_df.iterrows():
            key = (gold_row['trial_id'], gold_row['criterion_type'], gold_row['source_sentence'])
            initial_row = first_pass_lookup.get(key)
            if initial_row is None:
                # No first-pass record for this gold row; count it so the
                # omission is visible in the run log.
                skipped += 1
                continue

            initial_extraction = {field: initial_row.get(field, '') for field in SCHEMA_FIELDS}
            result = evaluate_stack_depth(
                gold_row['trial_id'],
                gold_row['criterion_type'],
                gold_row['source_sentence'],
                gold_row,
                verifier_rounds,
                initial_extraction,
            )
            all_results.append(result)

        if skipped:
            print(
                f"  Skipped {skipped} of {len(gold_df)} gold rows "
                "with no matching first-pass extraction."
            )

    return pd.DataFrame(all_results)


## Run Evaluation


In [25]:
# Run the ablation, persist the per-row metrics, then print per-depth means.
results_df = run_ablation2_evaluation()
results_df.to_csv('ablation2_results.csv', index=False)
print("Evaluation complete. Results saved to ablation2_results.csv")
print("Gemini first-pass CSV saved to", FIRST_PASS_OUTPUT)
print("\nResults summary (by Kimi verifier rounds):")

# Metrics averaged within each verifier stack depth.
summary_metrics = [
    'stacked_accuracy',
    'stack_improvement',
    'over_correction_rate',
    'disagreement_reduction',
    'elapsed_time',
    'tokens_used',
]
summary_by_depth = results_df.groupby('verifier_rounds').agg(
    dict.fromkeys(summary_metrics, 'mean')
)
print(summary_by_depth.round(3))


Found cached Gemini outputs at /Users/guoshuyan/Desktop/OpenAD/ablation2_gemini_round1.csv; reusing them.
Evaluating 2 Kimi verifier rounds…
Evaluating 3 Kimi verifier rounds…
Evaluating 5 Kimi verifier rounds…
Evaluation complete. Results saved to ablation2_results.csv
Gemini first-pass CSV saved to /Users/guoshuyan/Desktop/OpenAD/ablation2_gemini_round1.csv

Results summary (by Kimi verifier rounds):
                 stacked_accuracy  stack_improvement  over_correction_rate  \
verifier_rounds                                                              
2                           0.284                0.0                   0.0   
3                           0.284                0.0                   0.0   
5                           0.284                0.0                   0.0   

                 disagreement_reduction  elapsed_time  tokens_used  
verifier_rounds                                                     
2                                   1.0         0.001      240.20