In [None]:
import json
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np

# from rouge_score import rouge_scorer
# from summac.model_summac import SummaCZS
from bert_score import BERTScorer
from medcat.cat import CAT
from scipy.optimize import linear_sum_assignment

In [None]:
# Input locations, all resolved against the current working directory:
# two JSON documents under ./output and the JSON schema under ../guidelines.
output_dir = Path.cwd() / "output"
output_gt_path = output_dir.joinpath("prsb_example_gt_manual.json")
output_pred_path = output_dir.joinpath("prsb_example_completion.json")
json_schema_path = Path.cwd().parent.joinpath(
    "guidelines", "eDischarge-Summary-v2.1-1st-Feb-21_pydantic.json"
)

In [None]:
# Decode explicitly as UTF-8: Path.read_text() otherwise falls back to the
# platform's default encoding, which can corrupt non-ASCII characters.
# NOTE(review): the original carried a trailing `# ["Completion"]` hint on the
# first two loads -- if the files wrap the payload under that key, index it here.
output_gt_json = json.loads(output_gt_path.read_text(encoding="utf-8"))
output_pred_json = json.loads(output_pred_path.read_text(encoding="utf-8"))
json_schema = json.loads(json_schema_path.read_text(encoding="utf-8"))

In [None]:
# Stand-in dataset: the single example pair repeated ten times. Every list
# element is a reference to the same dict object, not a copy.
pred_jsons = [output_pred_json] * 10
gt_jsons = [output_gt_json] * 10

In [None]:
def add_object_metrics_to_section_dict(
    object_metrics: Dict[str, float], metrics: Dict[str, float], section_name: str
):
    """Merge a nested section's metrics into ``metrics`` under prefixed keys.

    Each entry ``name -> value`` of ``object_metrics`` is stored in ``metrics``
    as ``"<section_name>__<name>" -> value``.

    Args:
        object_metrics: Metrics computed for the nested section.
        metrics: Accumulator dict; it is modified in place.
        section_name: Prefix identifying the nested section.

    Returns:
        The same ``metrics`` object that was passed in.
    """
    metrics.update(
        {f"{section_name}__{name}": value for name, value in object_metrics.items()}
    )
    return metrics


def evaluate_string_array(
    pred_array: List[str], gt_array: List[str], calc_metric: Callable[[str, str], float]
) -> float:
    """Score a predicted list of strings against the ground-truth list.

    Strings are paired one-to-one so that the summed ``calc_metric`` score is
    maximal, and the mean score of the matched pairs is returned. When the
    lists differ in length only as many pairs as the shorter list are formed.

    Args:
        pred_array: Predicted strings.
        gt_array: Ground-truth strings.
        calc_metric: Called as ``calc_metric(pred, gt)``; higher means closer.

    Returns:
        1 if both lists are empty, 0 if exactly one is empty, otherwise the
        mean score over the matched pairs.
    """
    n_pred, n_gt = len(pred_array), len(gt_array)
    if n_pred == 0 or n_gt == 0:
        return 1 if n_pred == n_gt else 0

    # One row per ground-truth string, one column per predicted string.
    score_rows = []
    for gt_item in gt_array:
        score_rows.append([calc_metric(pred_item, gt_item) for pred_item in pred_array])
    score_matrix = np.array(score_rows)

    row_idxs, col_idxs = linear_sum_assignment(score_matrix, maximize=True)
    return score_matrix[row_idxs, col_idxs].mean()


def evaluate_object_array(
    pred_array: List[Dict],
    gt_array: List[Dict],
    object_schema: Dict,
    calc_metric: Callable[[str, str], float],
) -> Dict[str, float]:
    """Score a predicted array of objects against the ground-truth array.

    Objects are paired one-to-one by maximising ``calc_metric`` on the first
    property declared in ``object_schema`` (the "main key"). Each matched pair
    is then scored with :func:`evaluate_json` and the scores are averaged per
    field across the pairs.

    Args:
        pred_array: Predicted objects.
        gt_array: Ground-truth objects.
        object_schema: Schema of one array element; must have ``"properties"``.
        calc_metric: Called as ``calc_metric(pred, gt)``; higher means closer.

    Returns:
        Field name -> score. If both arrays are empty every declared property
        scores 1; if exactly one is empty every declared property scores 0.
        Otherwise the keys are those produced by ``evaluate_json`` for the
        matched pairs. When the arrays differ in length, the surplus objects
        stay unmatched and do not contribute to the scores.
    """
    object_properties = list(object_schema["properties"].keys())
    if len(gt_array) == 0 and len(pred_array) == 0:
        return {field: 1 for field in object_properties}
    elif len(gt_array) == 0 or len(pred_array) == 0:
        return {field: 0 for field in object_properties}

    main_key = object_properties[0]
    # Rows index the ground truth, columns index the predictions.
    cost = np.array(
        [
            [calc_metric(pred.get(main_key), gt.get(main_key)) for pred in pred_array]
            for gt in gt_array
        ]
    )
    # linear_sum_assignment returns (row_ind, col_ind); with the layout above
    # that is (ground-truth indices, prediction indices). Unpacking in the
    # opposite order pairs the wrong objects and can index out of range when
    # the two arrays have different lengths.
    gt_idxs, pred_idxs = linear_sum_assignment(cost, maximize=True)

    array_metrics: Dict[str, List[float]] = defaultdict(list)
    for gt_idx, pred_idx in zip(gt_idxs, pred_idxs):
        pair_metrics = evaluate_json(
            pred_array[pred_idx], gt_array[gt_idx], object_schema, calc_metric
        )
        for element_name, metric_value in pair_metrics.items():
            if isinstance(metric_value, dict):
                for field, value in metric_value.items():
                    array_metrics[f"{element_name}__{field}"].append(value)
            else:
                array_metrics[element_name].append(metric_value)
    return {
        field: np.mean(metric_values) if len(metric_values) > 1 else metric_values[0]
        for field, metric_values in array_metrics.items()
    }


def evaluate_json(
    pred_json: Dict,
    gt_json: Dict,
    json_schema: Dict,
    calc_metric: Callable[[str, str], float],
) -> Dict[str, float]:
    """Score a predicted JSON object against the ground truth, field by field.

    Walks ``json_schema["properties"]`` and dispatches on each property's
    ``"type"``: nested objects recurse, arrays of objects or strings go through
    the array evaluators, and strings are scored directly with ``calc_metric``.
    Properties of any other type are skipped. Metrics from nested objects and
    object arrays are flattened into ``"<section>__<field>"`` keys.

    A section that is absent from either document, or present with a ``None``
    value, is treated as empty (``{}``, ``[]`` or ``None`` for strings) rather
    than raising.

    Args:
        pred_json: Predicted document.
        gt_json: Ground-truth document.
        json_schema: Schema with a ``"properties"`` mapping.
        calc_metric: Called as ``calc_metric(pred, gt)``; must accept empty
            or ``None`` values.

    Returns:
        Flat mapping from (possibly prefixed) field name to score.
    """
    metrics: Dict[str, float] = {}
    for section_name, section_schema in json_schema["properties"].items():
        section_type = section_schema["type"]
        pred_section = pred_json.get(section_name)
        gt_section = gt_json.get(section_name)

        if section_type == "object":
            object_metrics = evaluate_json(
                pred_section or {},
                gt_section or {},
                section_schema,
                calc_metric,
            )
            metrics = add_object_metrics_to_section_dict(
                object_metrics, metrics, section_name
            )
        elif section_type == "array":
            item_type = section_schema["items"]["type"]
            if item_type == "object":
                object_metrics = evaluate_object_array(
                    pred_section or [],
                    gt_section or [],
                    section_schema["items"],
                    calc_metric,
                )
                metrics = add_object_metrics_to_section_dict(
                    object_metrics, metrics, section_name
                )
            elif item_type == "string":
                metrics[section_name] = evaluate_string_array(
                    pred_section or [],
                    gt_section or [],
                    calc_metric,
                )
        elif section_type == "string":
            metrics[section_name] = calc_metric(pred_section, gt_section)
    return metrics


def average_per_field_metrics(field_metrics: Dict[str, float]) -> float:
    """Return the unweighted mean of all per-field scores.

    Raises:
        ZeroDivisionError: If ``field_metrics`` is empty.
    """
    scores = list(field_metrics.values())
    return sum(scores) / len(scores)


def run_per_field_dataset_metrics(
    pred_jsons: List[Dict],
    gt_jsons: List[Dict],
    metric_function: Callable[[str, str], float],
    schema: Optional[Dict] = None,
) -> Tuple[Dict[str, float], float]:
    """Evaluate a dataset of prediction/ground-truth pairs field by field.

    Each pair is scored with ``evaluate_json``; scores are then averaged per
    field across documents, and those per-field means are averaged again into
    a single overall score.

    Args:
        pred_jsons: Predicted documents.
        gt_jsons: Ground-truth documents, aligned by position with
            ``pred_jsons``. Pairing uses ``zip``, so extra items in the longer
            list are ignored.
        metric_function: Called as ``metric_function(pred, gt)``.
        schema: Schema to evaluate against. Defaults to the notebook-level
            ``json_schema`` so existing three-argument calls keep working.

    Returns:
        ``(metrics_per_field, overall_average)``. A field's mean is taken only
        over the documents in which that field produced a score.
    """
    active_schema = json_schema if schema is None else schema

    metric_list_per_field = defaultdict(list)
    for pred_json, gt_json in zip(pred_jsons, gt_jsons):
        document_metrics = evaluate_json(
            pred_json, gt_json, active_schema, metric_function
        )
        for field, value in document_metrics.items():
            metric_list_per_field[field].append(value)

    metrics_per_field = {
        field: np.array(values).mean()
        for field, values in metric_list_per_field.items()
    }
    return metrics_per_field, average_per_field_metrics(metrics_per_field)

In [None]:
def has_content_metric(pred_value: str, gt_value: str) -> float:
    """Check that prediction and ground truth agree on being empty or not.

    Returns:
        1 when both values are truthy or both are falsy, 0 when exactly one
        of them has content. The content itself is not compared.
    """
    pred_present = bool(pred_value)
    gt_present = bool(gt_value)
    return 1 if pred_present == gt_present else 0


# Run the presence-only metric over the dataset; print the overall mean and
# leave the per-field breakdown as the cell's displayed value.
has_content_per_field_metrics, average_has_content = run_per_field_dataset_metrics(
    pred_jsons=pred_jsons,
    gt_jsons=gt_jsons,
    metric_function=has_content_metric,
)
print(average_has_content)
has_content_per_field_metrics

In [None]:
def calc_bert_score(pred_value: str, gt_value: str, bert_scorer) -> float:
    """Score one prediction against its ground truth with BERTScore F1.

    Args:
        pred_value: Predicted text.
        gt_value: Ground-truth text.
        bert_scorer: Object exposing ``score(candidates, references)`` that
            returns a ``(precision, recall, f1)`` triple, e.g. a ``BERTScorer``.

    Returns:
        1.0 if both texts are empty, 0.0 if exactly one is empty, otherwise
        the mean F1 as a plain Python float.
    """
    if not pred_value and not gt_value:
        return 1.0
    if not pred_value or not gt_value:
        return 0.0
    _, _, f1 = bert_scorer.score([pred_value], [gt_value])
    # The scorer hands back a tensor-like object; convert it so callers get
    # the float promised by the signature rather than a framework type.
    return float(f1.mean())


# Build the scorer once and reuse it for every field comparison.
scorer = BERTScorer(
    lang="en",
    rescale_with_baseline=True,
)

In [None]:
# Run BERTScore over the dataset, binding the shared scorer into the
# two-argument metric callable the evaluator expects.
bert_score_per_field_metrics, average_bert_score = run_per_field_dataset_metrics(
    pred_jsons=pred_jsons,
    gt_jsons=gt_jsons,
    metric_function=lambda pred_value, gt_value: calc_bert_score(
        pred_value, gt_value, bert_scorer=scorer
    ),
)
print(average_bert_score)
bert_score_per_field_metrics

In [None]:
# MedCAT model pack, expected under ../models relative to the working directory.
MODEL_PATH = Path.cwd().parent.joinpath(
    "models", "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip"
)

cat = CAT.load_model_pack(MODEL_PATH)
# Remove the pipeline component registered as "Status"
# (presumably a meta-annotation step not needed for CUI matching -- confirm).
cat.pipe.force_remove("Status")
# Last expression: display the remaining spaCy pipeline components.
cat.pipe.spacy_nlp.pipeline

In [None]:
def calc_snomed_match(pred_value: str, gt_value: str, cat: CAT) -> float:
    """F1 overlap between the concept ids found in prediction and ground truth.

    Both texts are annotated with ``cat`` and compared as sets of ``cui``
    values taken from the returned entities.

    Args:
        pred_value: Predicted text; a falsy value yields an empty concept set.
        gt_value: Ground-truth text; a falsy value yields an empty concept set.
        cat: Annotator called as ``cat(text)``; the result must expose ``.ents``
            whose items carry ``._.cui``.

    Returns:
        The F1 of precision and recall over the two concept sets. Precision is
        taken as 1 when the prediction has no concepts and recall as 1 when the
        ground truth has none, so two concept-free texts score 1. Returns 0
        when precision and recall are both 0.
    """

    def extract_cuis(text):
        # Skip the annotator entirely for empty / missing text.
        if not text:
            return set()
        return {entity._.cui for entity in cat(text).ents}

    predicted = extract_cuis(pred_value)
    expected = extract_cuis(gt_value)
    shared = predicted & expected

    precision = len(shared) / len(predicted) if predicted else 1
    recall = len(shared) / len(expected) if expected else 1

    denominator = precision + recall
    if denominator == 0:
        return 0
    return 2 * precision * recall / denominator

In [None]:
# Run the SNOMED concept-overlap metric over the dataset, binding the loaded
# MedCAT model into the two-argument metric callable the evaluator expects.
snomed_per_field_metrics, average_snomed = run_per_field_dataset_metrics(
    pred_jsons=pred_jsons,
    gt_jsons=gt_jsons,
    metric_function=lambda pred_value, gt_value: calc_snomed_match(
        pred_value, gt_value, cat=cat
    ),
)
print(average_snomed)
snomed_per_field_metrics