Functions for extraction evaluation

In [None]:
import re
import Levenshtein
from collections import Counter
from spellchecker import SpellChecker
import pandas as pd
import numpy as np
import pdb
import math

# Module-level spell checker, built once with default arguments and shared by
# compute_text_metrics and compute_table_metrics_markdown (via spell.unknown).
spell = SpellChecker()

def normalize_whitespace(text):
    """
    Return ``text`` with every run of whitespace (spaces, tabs, newlines, ...)
    replaced by a single space and with leading/trailing whitespace removed.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()


# FREE TEXT EVALUATION

def compute_text_metrics(extracted, gold):
    """
    Compute free-text evaluation metrics between an extracted text and its gold
    reference.

    Both strings are whitespace-normalized first and then tokenized by
    splitting on whitespace.

    Returns a dict with:
      - 'precision', 'recall', 'f1': token-level scores based on the multiset
        overlap of extracted and gold tokens.
      - 'ld_norm': character-level Levenshtein distance divided by the length
        of the gold text, capped at 1.0.
      - 'W_norm': number of entries returned by ``spell.unknown`` for the
        extracted tokens, divided by the number of gold tokens, capped at 1.
        (pyspellchecker returns a set, so a repeated misspelling counts once.)

    Edge cases:
      - Both texts empty: treated as an exact match (precision = recall =
        f1 = 1, ld_norm = 0, W_norm = 0), mirroring the "both tables absent"
        branch of compute_table_metrics_markdown.
      - Gold empty but extracted not: precision/recall/f1 are 0, ld_norm is
        1.0 and W_norm is 0.
    """

    # Normalize whitespace in both texts
    extracted = normalize_whitespace(extracted)
    gold = normalize_whitespace(gold)

    # Nothing expected and nothing extracted: the edit distance is 0, so this
    # must not receive the worst-case ld_norm or a zero F1.
    if not extracted and not gold:
        return {
            'precision': 1,
            'recall': 1,
            'f1': 1,
            'ld_norm': 0,
            'W_norm': 0
        }

    # Tokenize (splitting by whitespace)
    tokens_ext = extracted.split()
    tokens_gold = gold.split()

    # Multiset intersection: a token counts as matched at most as many times
    # as it appears in both texts.
    count_ext = Counter(tokens_ext)
    count_gold = Counter(tokens_gold)
    intersection = sum((count_ext & count_gold).values())
    precision = intersection / len(tokens_ext) if tokens_ext else 0
    recall = intersection / len(tokens_gold) if tokens_gold else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0

    # Normalized Levenshtein distance; an empty gold with non-empty extraction
    # gets the maximum penalty.
    ld = Levenshtein.distance(extracted, gold)
    ld_norm = min(ld / len(gold), 1.0) if gold else 1.0

    # Misspelled words: words in the extracted text unknown to the spell checker
    misspelled = spell.unknown(tokens_ext)
    W_norm = len(misspelled) / len(tokens_gold) if tokens_gold else 0
    W_norm = min(W_norm, 1)

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'ld_norm': ld_norm,
        'W_norm': W_norm
    }


# IMAGE EVALUATION

def compute_image_penalty(expected_images, extracted_images=0):
    """
    Return how many expected images are missing from the extraction.

    The result is ``expected_images - extracted_images`` when that difference
    is not negative, and 0 otherwise (extra images are not penalized).
    """
    shortfall = expected_images - extracted_images
    return 0 if 0 > shortfall else shortfall

# Table Markdown Parsing Function
def markdown_to_df(md):
    """
    Parse a markdown (pipe) table string into a pandas DataFrame of strings.

    Blank lines are ignored. The first remaining line is the header; the line
    after it is skipped only when it is a delimiter row (every cell looks like
    ``---``, ``:---``, ``---:`` or ``:---:``), so a table written without a
    delimiter row does not lose its first data row.

    Ragged tables are made rectangular instead of raising: short rows are
    padded with empty strings, and if any row has more cells than the header,
    extra columns with an empty name are appended so no cell is discarded.

    Returns an empty DataFrame (no columns) when fewer than two non-blank
    lines are present. A header with no data rows yields a DataFrame that has
    columns but no rows, for which ``.empty`` is True.
    """
    def split_row(line):
        # Drop the outer pipes, then split on the inner ones and trim cells.
        return [cell.strip() for cell in line.strip('|').split('|')]

    # Split by lines and ignore empty lines.
    lines = [line.strip() for line in md.splitlines() if line.strip()]
    if len(lines) < 2:
        return pd.DataFrame()

    headers = split_row(lines[0])
    body = lines[1:]
    if all(re.fullmatch(r':?-+:?', cell) for cell in split_row(body[0])):
        body = body[1:]
    data_rows = [split_row(line) for line in body]

    # Make the table rectangular so the DataFrame constructor cannot fail on
    # a header/row length mismatch.
    width = max([len(headers)] + [len(row) for row in data_rows])
    headers = headers + [''] * (width - len(headers))
    data_rows = [row + [''] * (width - len(row)) for row in data_rows]
    return pd.DataFrame(data_rows, columns=headers)

# TABLE EVALUATION FOR MARKDOWN TABLES
def compute_table_metrics_markdown(extracted_md, gold_md):
    """
    Compute metrics comparing an extracted markdown table with its gold table.

    Both inputs are parsed with markdown_to_df. Three cases are handled:

      1. Both tables parse to an empty DataFrame: nothing was expected and
         nothing was extracted, so perfect scores are returned with
         'detected' = 1.
      2. Only the extracted table is empty: the table was not detected, so the
         worst scores are returned with 'detected' = 0.
      3. Otherwise the metrics below are computed and 'detected' = 1.

    Returned keys:
      - 'row_score', 'col_score': 1 minus the relative difference in the
        number of rows/columns w.r.t. the gold table, floored at 0 (0 when
        the gold table has no rows/columns).
      - 'precision_table', 'recall_table', 'f1_table': cell-level scores based
        on the multiset overlap of cell values (headers excluded).
      - 'ld_table_norm': Levenshtein distance between the CSV serializations
        of both tables, divided by the gold serialization length, capped at 1.
      - 'W_table_norm': number of entries returned by ``spell.unknown`` for
        the words of the extracted table (headers and cells), divided by the
        number of words in the gold table, capped at 1; 1 when the gold table
        contains no words.
      - 'detected': 1 or 0 as described above.
    """
    df_ext = markdown_to_df(extracted_md)
    df_gold = markdown_to_df(gold_md)

    if df_ext.empty and df_gold.empty:  # Both tables are missing
        return {
            'row_score': 1,  # Perfect score since both are absent
            'col_score': 1,  # Perfect score since both are absent
            'precision_table': 1,  # No false positives
            'recall_table': 1,     # No false negatives
            'f1_table': 1,         # Since both precision and recall are perfect
            'ld_table_norm': 0,    # No changes needed
            'W_table_norm': 0,     # No misspellings as there's nothing present
            'detected': 1          # Indicate detection of no tables successfully
        }

    if df_ext.empty:
        # Table wasn't detected.
        return {
            'row_score': 0,
            'col_score': 0,
            'precision_table': 0,
            'recall_table': 0,
            'f1_table': 0,
            'ld_table_norm': 1,       # A high penalty value
            'W_table_norm': 1,        # Assume worst-case misspelling ratio
            'detected': 0
        }

    # If the table exists: cell-level multiset overlap.
    flat_ext = df_ext.values.flatten().tolist()
    flat_gold = df_gold.values.flatten().tolist()

    count_ext = Counter(flat_ext)
    count_gold = Counter(flat_gold)
    intersection = sum((count_ext & count_gold).values())
    precision = intersection / len(flat_ext) if flat_ext else 0
    recall = intersection / len(flat_gold) if flat_gold else 0
    f1_table = 2 * precision * recall / (precision + recall) if (precision + recall) else 0

    # Row and Column Scores
    rows_ext, rows_gold = (df_ext.shape[0], df_gold.shape[0])
    row_score = 1 - abs(rows_ext - rows_gold) / rows_gold if rows_gold > 0 else 0
    cols_ext, cols_gold = (df_ext.shape[1], df_gold.shape[1])
    col_score = 1 - abs(cols_ext - cols_gold) / cols_gold if cols_gold > 0 else 0

    # Serialize tables for the character-level edit distance.
    serialized_ext = df_ext.to_csv(index=False)
    serialized_gold = df_gold.to_csv(index=False)
    ld_table = Levenshtein.distance(serialized_ext, serialized_gold)
    ld_table_norm = ld_table / len(serialized_gold) if serialized_gold else 1
    ld_table_norm = min(ld_table_norm, 1)

    # Misspelled words in table text. Words are taken per header/cell rather
    # than by whitespace-splitting the CSV text: CSV joins cells with commas,
    # so a line such as "name,age" would be handed to the spell checker as a
    # single (unknown) word.
    tokens_table_ext = _table_word_tokens(df_ext)
    tokens_table_gold = _table_word_tokens(df_gold)
    misspelled_table = spell.unknown(tokens_table_ext)
    W_table_norm = len(misspelled_table) / len(tokens_table_gold) if tokens_table_gold else 1
    W_table_norm = min(W_table_norm, 1)
    result = {
        'row_score': max(row_score, 0),
        'col_score': max(col_score, 0),
        'precision_table': precision,
        'recall_table': recall,
        'f1_table': f1_table,
        'ld_table_norm': ld_table_norm,
        'W_table_norm': W_table_norm,
        'detected': 1
    }
    return result


def _table_word_tokens(df):
    """Return the whitespace-separated words of every header and cell of ``df``."""
    cells = [*df.columns, *df.values.flatten().tolist()]
    # Missing cells (None) carry no text and are skipped.
    return [word for cell in cells if cell is not None for word in str(cell).split()]


# AGGREGATE TABLE SCORES

def aggregate_table_scores(table_metrics_list, lambda_weight):
    """
    Aggregate the metrics of several tables into one score.

    The per-table composite score is defined as:
      Score_table = f1_table + row_score + col_score
                    - lambda_weight * (ld_table_norm + W_table_norm)

    Parameters:
      table_metrics_list: list of metric dicts with the keys 'f1_table',
        'row_score', 'col_score', 'ld_table_norm', 'W_table_norm' and,
        optionally, 'detected' (a missing 'detected' key is treated as 1).
      lambda_weight: weight of the combined penalty term.

    Returns:
      (avg_score, TPRES) where avg_score is the arithmetic mean of the
      per-table composite scores (undetected tables are included and
      contribute whatever their metric values yield) and TPRES is the table
      presence flag: 1 if at least one table has a truthy 'detected' value,
      0 otherwise. An empty list returns (0.0, 0) instead of dividing by zero.
    """
    if not table_metrics_list:
        return 0.0, 0

    total = 0
    for metrics in table_metrics_list:
        score = (metrics['f1_table'] + metrics['row_score'] + metrics['col_score'] -
                 lambda_weight * (metrics['ld_table_norm'] + metrics['W_table_norm']))
        print("Score for table:", score)
        total += score
    avg_score = total / len(table_metrics_list)

    # Presence flag: 0 only when no table at all was detected.
    any_detected = any(metrics.get('detected', 1) for metrics in table_metrics_list)
    TPRES = 1 if any_detected else 0
    return avg_score, TPRES



# COMPOSITE SECTION SCORE

def compute_section_score(text_metrics, table_score, image_penalty, presence, weights, lambda_weight):
    """
    Combine the free-text, table and image components of one section into a
    single score.

    Each component is squashed with tanh before being combined:
        text  = tanh(f1 - lambda_weight * (ld_norm + W_norm))
        table = tanh(table_score)
        image = tanh(image_penalty)
        combined = w_text * text + w_table * table - w_img * image
        result = presence * tanh(combined)

    The weights are read from ``weights`` under the keys 'w_text', 'w_table'
    and 'w_img'; each defaults to 1.0 when absent. Intermediate values are
    printed as a side effect.
    """
    w_text = weights.get('w_text', 1.0)
    w_table = weights.get('w_table', 1.0)
    w_img = weights.get('w_img', 1.0)

    # Free-text component: F1 minus the weighted edit/misspelling penalties.
    f1_text = text_metrics['f1']
    text_penalty = text_metrics['ld_norm'] + text_metrics['W_norm']
    squashed_text = math.tanh(f1_text - lambda_weight * text_penalty)
    print("Normalized text component:", squashed_text)

    # Table component: the aggregated table score is used as-is.
    squashed_table = math.tanh(table_score)
    print("Normalized table component:", squashed_table)

    # Image component.
    squashed_image = math.tanh(image_penalty)
    print("Normalized image penalty:", squashed_image)

    # Weighted combination; the image term is subtracted as a penalty.
    combined = w_text * squashed_text + w_table * squashed_table - w_img * squashed_image
    print("Raw section score before presence:", combined)

    # Squash once more and gate by the presence indicator.
    section_score = presence * math.tanh(combined)
    print("Final section score:", section_score)
    return section_score

# AGGREGATE PROTOCOL SCORE

def aggregate_protocol_score(section_scores, section_weights):
    """
    Return the weighted sum of the section scores.

    Every key of ``section_weights`` contributes ``weight * score``; a section
    that has a weight but no entry in ``section_scores`` contributes 0.
    """
    weighted_terms = [
        section_weights[name] * section_scores.get(name, 0)
        for name in section_weights
    ]
    return sum(weighted_terms)



Template to use the functions

In [None]:

# Usage template: fill in the empty strings with the gold and extracted
# content of one section, then run to obtain that section's composite score.
if __name__ == "__main__":
    # Free text of the section (gold reference vs. extraction output).
    gold_free_text = ""
    extracted_free_text = ""    
    text_metrics = compute_text_metrics(extracted_free_text, gold_free_text)
    
    print("Text Metrics:", text_metrics)
  
    # Table 1: gold and extracted tables as markdown strings.
    gold_table_md = ("")
    extracted_table_md = ("")

    table_metrics = compute_table_metrics_markdown(extracted_table_md, gold_table_md)

    # Table 2: the same variables are reused, so table 1 must be scored above
    # before they are overwritten here.
    gold_table_md = ("")
    extracted_table_md = ("")


    table_metrics_2 = compute_table_metrics_markdown(extracted_table_md, gold_table_md)

    table_metrics_list = [table_metrics, table_metrics_2] 

    # Print the per-table metrics.
    print("Table Metrics List:")
    for i, metrics in enumerate(table_metrics_list):
        print(f"Table {i+1} Metrics:", metrics)

    # Weight of the penalty terms (edit distance + misspellings); the same
    # value is used for the table aggregation and the section score below.
    lambda_weight = 0.5  # Equal weighting.
    # NOTE(review): TPRES is returned here but not passed on to
    # compute_section_score below.
    agg_table_score, TPRES = aggregate_table_scores(table_metrics_list, lambda_weight)

    print("Aggregated Table Score:", agg_table_score)
    
    # extracted_images is left at its default of 0.
    image_penalty = compute_image_penalty(expected_images=0)
    
    # section is present.
    presence = 1
    
    # Define component weights.
    comp_weights = {'w_text': 1, 'w_table': 1, 'w_img': 1}
    
    # Compute composite section score.
    clinical_section_score = compute_section_score(text_metrics, agg_table_score, image_penalty, presence, comp_weights, lambda_weight)
    print("Section Score:", clinical_section_score)
    
