## Metrics

This notebook produces all metrics presented in the results section.

Unfortunately, the annotated data cannot be publicly released due to MIMIC license restrictions.

In [None]:
import json
from collections import Counter, defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tiktoken

from llm_discharge_summaries.metrics.distance_functions import (
    delta_distance,
    is_error,
    l1_distance,
)
from llm_discharge_summaries.metrics.krippendorffs_alpha import calc_krippendorffs_alpha

In [None]:
# Show floats in DataFrame output with thousands separators and two decimals.
pd.set_option("display.float_format", "{:,.2f}".format)

In [None]:
# Directory layout: everything is read from ./outputs under the working directory.
OUTPUT_DIR = Path.cwd().joinpath("outputs")
ANNOTATED_DIR = OUTPUT_DIR.joinpath("annotators_eval")
RAW_GENERATED_DIR = OUTPUT_DIR.joinpath("llm_responses")

# Per-1,000-token prices used for the cost estimate.
GPT_4_TURBO_INPUT_COST_PER_1K = 0.01
GPT_4_TURBO_OUTPUT_COST_PER_1K = 0.03
# Name of the tiktoken encoding used for token counting.
TOKENIZER_NAME = "cl100k_base"

# Spreadsheet columns in which annotators record errors.
ERROR_COLUMNS = [
    "Missed- Severe",
    "Missed- Minor",
    "Added- Hallucination",
    "Added- Not relevant",
]

## Load data

In [None]:
# One sub-directory per admission; key each parsed discharge_summary.json by
# the directory name.
hadm_id_to_discharge_summary = {}
for directory in RAW_GENERATED_DIR.iterdir():
    if not directory.is_dir():
        continue
    summary_path = directory / "discharge_summary.json"
    hadm_id_to_discharge_summary[directory.stem] = json.loads(summary_path.read_text())

This could have been stored more cleanly, but the raw messages are kept in a single text file, delimited by a line of 80 `*` characters, in the following order:
- system message
- one shot example
- user physician notes
- llm response
- time taken

In [None]:
# Messages in raw_messages.txt are separated by a line of 80 asterisks.
delimiter = "\n{}\n".format("*" * 80)

# Key the list of raw messages by the admission directory name.
hadm_id_to_messages = {}
for directory in RAW_GENERATED_DIR.iterdir():
    if not directory.is_dir():
        continue
    raw_text = (directory / "raw_messages.txt").read_text()
    hadm_id_to_messages[directory.stem] = raw_text.split(delimiter)

## Extractive

Find % of entries that were extractive i.e. directly from source text.

In [None]:
def find_json_strings(json_object):
    """Collect every string found in a parsed JSON structure.

    Lists are walked element by element and dicts value by value (keys are
    ignored), depth first. Values of any other type are skipped.

    Args:
        json_object: A string, list, dict, or any nesting of these.

    Returns:
        A list of the strings encountered, in traversal order.
    """

    def iter_strings(node):
        if isinstance(node, str):
            yield node
        elif isinstance(node, list):
            for element in node:
                yield from iter_strings(element)
        elif isinstance(node, dict):
            for child in node.values():
                yield from iter_strings(child)

    return list(iter_strings(json_object))

In [None]:
num_extractive_sentences = 0
num_generated_sentences = 0

for hadm_id, admission_messages in hadm_id_to_messages.items():
    # Message 3 is the physician note message; compare case-insensitively
    notes_lowercase = admission_messages[3].lower()

    # Split every generated string into rough "sentences" on ". "
    generated_sentences_lowercase = []
    for text in find_json_strings(hadm_id_to_discharge_summary[hadm_id]):
        for sentence in text.split(". "):
            if sentence != "":
                generated_sentences_lowercase.append(sentence.lower())

    # A sentence counts as extractive when it occurs verbatim in the notes
    for sentence_lowercase in generated_sentences_lowercase:
        num_generated_sentences += 1
        if sentence_lowercase in notes_lowercase:
            num_extractive_sentences += 1

num_extractive_sentences / num_generated_sentences

## Token Lengths

In [None]:
# Load the tiktoken encoding named by TOKENIZER_NAME.
tokenizer = tiktoken.get_encoding(TOKENIZER_NAME)

In [None]:
# The first three messages form the prompt; it is the same for all inputs,
# so any admission can be used to measure it.
first_admission_messages = next(iter(hadm_id_to_messages.values()))
prompt_token_length = 0
for prompt_message in first_admission_messages[:3]:
    prompt_token_length += len(tokenizer.encode(prompt_message))
prompt_token_length

In [None]:
# Message 3 is the physician note message
note_messages = [messages[3] for messages in hadm_id_to_messages.values()]
note_message_length_char = [len(note_message) for note_message in note_messages]
note_message_length_token = [
    len(tokenizer.encode(note_message)) for note_message in note_messages
]

# Quartiles and maximum, first in characters and then in tokens
for lengths in (note_message_length_char, note_message_length_token):
    print(np.percentile(lengths, [25, 50, 75]), np.max(lengths))

## Completion metrics

Calc average time and costs

In [None]:
# The second-to-last split element holds the completion time as
# "<label>: <seconds>". Presumably the raw file ends with a trailing delimiter,
# which would make the last element empty; TODO confirm.
completion_times = [
    float(messages[-2].split(": ")[1]) for messages in hadm_id_to_messages.values()
]

# Quartiles and maximum of the completion time (printed once; the original
# cell repeated this line verbatim)
print(np.percentile(completion_times, [25, 50, 75]), np.max(completion_times))

In [None]:
# Estimated cost per admission: the first four messages are billed as input
# tokens and message 4 (the LLM response) as output tokens.
costs = []
for messages in hadm_id_to_messages.values():
    input_token_counts = [len(tokenizer.encode(message)) for message in messages[:4]]
    num_input_tokens = sum(input_token_counts)
    num_output_tokens = len(tokenizer.encode(messages[4]))

    input_cost = num_input_tokens / 1000 * GPT_4_TURBO_INPUT_COST_PER_1K
    output_cost = num_output_tokens / 1000 * GPT_4_TURBO_OUTPUT_COST_PER_1K
    costs.append(input_cost + output_cost)

print(np.percentile(costs, [25, 50, 75]), np.max(costs))

## Precision Recall

Load all the eval dfs, needs some pre-processing to get in 'nice' pandas format

In [None]:
# Load one evaluation spreadsheet per (annotator, admission) pair and normalise
# it into a tidy DataFrame. Directories are sorted so that the order of
# `eval_dfs` is reproducible (Path.iterdir() yields entries in arbitrary order).
eval_dfs = []
for annotator_dir in sorted(ANNOTATED_DIR.iterdir()):
    if not annotator_dir.is_dir():
        continue
    for hadm_id_dir in sorted(annotator_dir.iterdir()):
        if not hadm_id_dir.is_dir():
            continue
        hadm_id = hadm_id_dir.stem
        # header=1: column names are taken from the second row of the sheet
        df = pd.read_excel(
            hadm_id_dir / f"discharge_summary_{hadm_id}.xlsx",
            engine="openpyxl",
            header=1,
        )
        # Drop empty rows
        df = df.dropna(axis=0, how="all")

        # TODO: Remove when fixed in annotator
        mask = (df["Section"] == "Allergies And Adverse Reaction") & (
            df["Field"].isnull()
        )
        if mask.any():
            df.loc[mask, "Field"] = "Causative Agent"
            df.loc[mask, "Value"] = "No known drug allergies or adverse reactions"

        # Fill empty sections and fields with whatever is above
        df["Section"] = df["Section"].ffill()
        df["Field"] = df["Field"].ffill()
        # Autopopulated fields are not generated by LLM so we can drop them
        # from evaluation. Take an explicit copy so the column assignments
        # below act on an independent frame (avoids SettingWithCopyWarning).
        df = df.loc[df["Value"] != "Autopopulated"].copy()

        # Empty cells are not errors so set to 0
        df[ERROR_COLUMNS] = df[ERROR_COLUMNS].fillna(0)

        # Help with grouping: strip the numeric suffix from repeated fields
        df["Field"] = df["Field"].str.replace(
            r"Causative Agent [0-9]+", "Causative Agent", regex=True
        )
        df["Field"] = df["Field"].str.replace(
            r"Description Of Reaction [0-9]+", "Description Of Reaction", regex=True
        )

        # Help with identification downstream.
        # NOTE(review): these look like plain attributes on the DataFrame object
        # rather than columns (assuming the sheet has no columns with these
        # names), so they would not survive copies or slices of the frame.
        df.hadm_id = hadm_id
        df.annotator = annotator_dir.stem
        eval_dfs.append(df)

In [None]:
# Display the first loaded evaluation frame as a sanity check of the pre-processing.
eval_dfs[0]

Choose one of the two options below to group on a per-field or a per-section basis. Variable names assume grouping is done on a per-field basis.

In [None]:
# Columns passed to groupby() in the cells below: (Section, Field) groups per field.
grouping_key = ["Section", "Field"]
# Alternative: group per section only
# grouping_key = ["Section"]

Per dataframe error analysis

In [None]:
# Accumulate per-group error counts over all evaluation frames.
# `total_errors_list` keeps the total number of errors of each evaluation.
total_errors_list = []
field_errors = None
for eval_df in eval_dfs:
    grouped = eval_df.groupby(grouping_key)
    # Sum num errors per group
    df_field_errors = grouped[ERROR_COLUMNS].sum()
    # Add column for number of values per group
    df_field_errors["Num Values"] = grouped["Value"].count()

    # Add column for number of values not found per group
    not_found_count = (
        eval_df[eval_df["Value"] == "Information not found in notes"]
        .groupby(grouping_key)["Value"]
        .count()
    )
    df_field_errors["Not Found"] = not_found_count
    # Groups without any "not found" value come out as NaN after alignment;
    # set them to 0. Assign the result back instead of calling
    # fillna(inplace=True) on the column selection, which is a chained
    # in-place operation that pandas deprecates.
    df_field_errors["Not Found"] = df_field_errors["Not Found"].fillna(0)

    # Clinical summary is a free text paragraph, so we estimate each sentence as a value
    # NOTE: the two-element key below assumes grouping on ["Section", "Field"].
    clinical_summary_text = eval_df[
        (eval_df["Section"] == "Clinical Summary")
        & (eval_df["Field"] == "Clinical Summary")
    ]["Value"].iloc[0]
    estimated_num_sentences = len(clinical_summary_text.split(". "))
    df_field_errors.at[("Clinical Summary", "Clinical Summary"), "Num Values"] = (
        estimated_num_sentences
    )

    # Separate tracking of total errors for micro calculation later
    total_errors_list.append(df_field_errors[ERROR_COLUMNS].sum().sum())

    if field_errors is None:
        field_errors = df_field_errors
    else:
        # Align on the group index and treat a group that is absent from one
        # side as 0, so a group missing from some evaluations does not turn
        # the running totals into NaN (or get dropped).
        field_errors = field_errors.add(df_field_errors, fill_value=0)

In [None]:
# Display the accumulated per-group error counts.
field_errors

GP Practice is never filled (47 values, 47 not found). This is because there are no GPs in America! So we drop it from the analysis.

In [None]:
# Remove the "Gp Practice" rows. Reassigning (instead of inplace=True) and
# ignoring an already-missing label keeps this cell safe to re-run.
field_errors = field_errors.drop("Gp Practice", errors="ignore")

Median number of errors (mean is affected by extreme values)

In [None]:
# Median of the per-evaluation total error counts collected above.
np.median(total_errors_list)

Calc per field metrics

In [None]:
# Added items count as false positives, missed items as false negatives.
hallucinated = field_errors["Added- Hallucination"]
not_relevant = field_errors["Added- Not relevant"]
field_false_positives = hallucinated.add(not_relevant)

missed_severe = field_errors["Missed- Severe"]
missed_minor = field_errors["Missed- Minor"]
field_false_negatives = missed_severe.add(missed_minor)

# Every generated value that was not flagged as added is taken as correct.
field_true_positives = field_errors["Num Values"].sub(field_false_positives)

In [None]:
# Per-group recall, precision and their harmonic mean (F1).
field_recall = field_true_positives.div(field_true_positives + field_false_negatives)
field_precision = field_true_positives.div(
    field_true_positives + field_false_positives
)
field_f1 = (
    (field_precision * field_recall).mul(2).div(field_precision + field_recall)
)

Calc mean num elements and not found elements

In [None]:
# Average the accumulated counts over the number of evaluation frames.
num_evaluations = len(eval_dfs)
field_mean_num_elements = field_errors["Num Values"].div(num_evaluations)
field_mean_not_found = field_errors["Not Found"].div(num_evaluations)

In [None]:
# Assemble the per-group metrics side by side; the mapping keys become the
# column labels.
df_metrics = pd.concat(
    {
        "Average Number of Elements": field_mean_num_elements,
        "Proportion Not Found in Notes": field_mean_not_found,
        "Recall": field_recall,
        "Precision": field_precision,
        "F1": field_f1,
    },
    axis=1,
)

In [None]:
# Display the per-group metrics table.
df_metrics

Repeat same calculations but summed across all evaluations to give micro averages

In [None]:
# Column totals over all groups, used for the micro-averaged metrics.
total_errors = field_errors.sum()

num_evaluations = len(eval_dfs)
total_num_elements = total_errors["Num Values"] / num_evaluations
total_not_found = total_errors["Not Found"] / num_evaluations

total_hallucinated = total_errors["Added- Hallucination"]
total_not_relevant = total_errors["Added- Not relevant"]
total_false_positives = total_hallucinated + total_not_relevant

total_missed_severe = total_errors["Missed- Severe"]
total_missed_minor = total_errors["Missed- Minor"]
total_false_negatives = total_missed_severe + total_missed_minor

total_true_positives = total_errors["Num Values"] - total_false_positives

In [None]:
# Micro-averaged recall, precision and F1 from the summed counts.
num_relevant = total_true_positives + total_false_negatives
num_retrieved = total_true_positives + total_false_positives

micro_recall = total_true_positives / num_relevant
micro_precision = total_true_positives / num_retrieved
micro_f1 = 2 * (micro_precision * micro_recall) / (micro_precision + micro_recall)

In [None]:
# Column means of the per-group table, followed by the micro averages
# in the order [recall, precision, F1].
column_means = df_metrics.mean(axis=0)
print(column_means)
micro_averages = [micro_recall, micro_precision, micro_f1]
print(micro_averages)

## Inter annotator agreement

Get paired annotations

In [None]:
# Bucket the evaluation frames by the hadm_id attribute set at load time.
hadm_id_to_annotator_dfs = defaultdict(list)
for annotated_df in eval_dfs:
    frames_for_admission = hadm_id_to_annotator_dfs[annotated_df.hadm_id]
    frames_for_admission.append(annotated_df)

Each annotation is a 4-d vector of the number of errors of each type

In [None]:
# For every admission evaluated by exactly two annotators, pair up the two
# annotators' per-group error-count vectors (one int per ERROR_COLUMNS entry).
multi_annotated_fields: list[list[tuple[int, ...]]] = []
for hadm_id, annotator_dfs in hadm_id_to_annotator_dfs.items():
    if len(annotator_dfs) != 2:
        continue

    # Use the two frames belonging to THIS admission. (Indexing eval_dfs[0] and
    # eval_dfs[1] here would compare the same two frames for every admission.)
    # Select the error columns before summing so text columns are never summed.
    annotator_1_errors, annotator_2_errors = (
        annotator_df.groupby(grouping_key)[ERROR_COLUMNS].sum().astype(int)
        for annotator_df in annotator_dfs
    )
    # Keep only groups present for both annotators, in the same order, so the
    # positional pairing below compares like with like.
    annotator_1_errors, annotator_2_errors = annotator_1_errors.align(
        annotator_2_errors, join="inner", axis=0
    )
    annotator_1_fields = annotator_1_errors.values.tolist()
    annotator_2_fields = annotator_2_errors.values.tolist()

    multi_annotated_fields.extend(
        [tuple(annotator_1_field), tuple(annotator_2_field)]
        for annotator_1_field, annotator_2_field in zip(
            annotator_1_fields, annotator_2_fields
        )
    )
len(multi_annotated_fields)

Calc proportion of paired annotations that agree on whether a field has an error (Y/N)

In [None]:
# One minus the mean is_error value over all annotation pairs.
is_error_total = 0
for annotation_pair in multi_annotated_fields:
    is_error_total += is_error(annotation_pair[0], annotation_pair[1])
1 - is_error_total / len(multi_annotated_fields)

Calc proportion of annotations that exactly match

In [None]:
# One minus the mean delta_distance value over all annotation pairs.
delta_distance_total = 0
for annotation_pair in multi_annotated_fields:
    delta_distance_total += delta_distance(annotation_pair[0], annotation_pair[1])
1 - delta_distance_total / len(multi_annotated_fields)

Calc krippendorffs_alpha

In [None]:
# Krippendorff's alpha over the paired annotations, with delta_distance passed
# as the distance function.
calc_krippendorffs_alpha(multi_annotated_fields, delta_distance)

Visualization of distances

In [None]:
# Observed: L1 distance between the two annotations of the same field.
observed_distances = []
for annotation_pair in multi_annotated_fields:
    observed_distances.append(l1_distance(annotation_pair[0], annotation_pair[1]))

# Flatten the pairs into a single pool of annotations.
all_annotations = []
for annotation_pair in multi_annotated_fields:
    all_annotations.extend(annotation_pair)

# Expected: L1 distance between every ordered pair drawn from the pool
# (this includes each annotation paired with itself).
expected_distances = []
for annotation_1 in all_annotations:
    for annotation_2 in all_annotations:
        expected_distances.append(l1_distance(annotation_1, annotation_2))

In [None]:
# Plot paired distances and inter-items distances in a histogram normalized to 1
plt.hist(observed_distances, alpha=0.5, label="observed", density=True)
plt.hist(expected_distances, alpha=0.5, label="expected", density=True)
plt.legend(loc="upper right")
plt.show()

In [None]:
# Frequency of each distinct error-count vector across all pooled annotations.
Counter(all_annotations)