In [None]:
import json
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List

import numpy as np
from scipy.optimize import linear_sum_assignment

# from rouge_score import rouge_scorer
# from summac.model_summac import SummaCZS

In [None]:
# All inputs are resolved relative to the directory the notebook is run from:
# the two example documents live in ./output, the schema in ../guidelines.
output_dir = Path.cwd().joinpath("output")
output_gt_path = output_dir.joinpath("prsb_example_gt_manual.json")
output_pred_path = output_dir.joinpath("prsb_example_completion.json")
json_schema_path = Path.cwd().parent.joinpath(
    "guidelines", "eDischarge-Summary-v2.1-1st-Feb-21_pydantic.json"
)

In [None]:
# Decode explicitly as UTF-8: `read_text()` without an encoding falls back to
# the locale's preferred encoding, which can mis-decode non-ASCII characters
# on platforms whose default is not UTF-8.
output_gt_json = json.loads(
    output_gt_path.read_text(encoding="utf-8")
)  # ["Completion"]
output_pred_json = json.loads(
    output_pred_path.read_text(encoding="utf-8")
)  # ["Completion"]
json_schema = json.loads(json_schema_path.read_text(encoding="utf-8"))

In [None]:
def has_content_metric(pred_value: str, gt_value: str) -> float:
    """Score whether prediction and ground truth agree on *having* content.

    Only the truthiness of each value is compared, never the text itself.

    Args:
        pred_value: Predicted field value.
        gt_value: Ground-truth field value.

    Returns:
        1 when both values are truthy or both are falsy, 0 when exactly one
        of them is truthy.
    """
    pred_has_content = bool(pred_value)
    gt_has_content = bool(gt_value)
    return 1 if pred_has_content == gt_has_content else 0


def add_object_metrics_to_section_dict(
    object_metrics: Dict[str, float], metrics: Dict[str, float], section_name: str
):
    """Merge nested scores into ``metrics`` under ``<section_name>__<metric>`` keys.

    Args:
        object_metrics: Scores computed for the fields of one nested section.
        metrics: Accumulator to extend; it is modified in place.
        section_name: Prefix identifying the section the scores belong to.

    Returns:
        The same ``metrics`` dict that was passed in, after the update.
    """
    metrics.update(
        (f"{section_name}__{metric_name}", metric_value)
        for metric_name, metric_value in object_metrics.items()
    )
    return metrics


def evaluate_string_array(
    pred_array: List[str], gt_array: List[str], calc_metric: Callable[[str, str], float]
) -> float:
    """Score a predicted list of strings against a ground-truth list.

    Elements are paired one-to-one so that the summed ``calc_metric`` score is
    maximal, and the mean score of the chosen pairs is returned. When the two
    lists differ in length only as many pairs as the shorter list are formed,
    so the surplus elements do not influence the result.

    Args:
        pred_array: Predicted strings.
        gt_array: Ground-truth strings.
        calc_metric: Called as ``calc_metric(pred, gt)``; higher means more similar.

    Returns:
        1 if both lists are empty, 0 if exactly one of them is empty,
        otherwise the mean score over the matched pairs.
    """
    n_pred, n_gt = len(pred_array), len(gt_array)
    if n_pred == 0 or n_gt == 0:
        return 1 if n_pred == n_gt else 0

    # One row per ground-truth element, one column per predicted element.
    score_rows = []
    for gt_item in gt_array:
        score_rows.append([calc_metric(pred_item, gt_item) for pred_item in pred_array])
    score_matrix = np.array(score_rows)

    gt_rows, pred_cols = linear_sum_assignment(score_matrix, maximize=True)
    matched_scores = score_matrix[gt_rows, pred_cols]
    return matched_scores.mean()


def evaluate_object_array(
    pred_array: List[Dict],
    gt_array: List[Dict],
    object_schema: Dict,
    calc_metric: Callable[[str, str], float],
) -> Dict[str, float]:
    """Score a predicted list of objects against a ground-truth list.

    Objects are paired one-to-one by maximising ``calc_metric`` on the first
    property listed in ``object_schema["properties"]``. Every matched pair is
    then scored field by field with ``evaluate_json`` and each field's scores
    are averaged over the pairs. When the lists differ in length only as many
    pairs as the shorter list are formed; surplus objects are not scored.

    Args:
        pred_array: Predicted objects.
        gt_array: Ground-truth objects.
        object_schema: JSON-schema fragment describing one array item; its
            ``"properties"`` mapping defines the fields to score.
        calc_metric: Called as ``calc_metric(pred, gt)``; higher means more similar.

    Returns:
        Mapping of field name to score. If both lists are empty every schema
        property scores 1; if exactly one is empty every property scores 0.
        NOTE(review): in those two cases the keys are the top-level property
        names, whereas matched pairs yield the keys produced by
        ``evaluate_json`` -- confirm that downstream averaging expects this.
    """
    object_properties = list(object_schema["properties"].keys())
    if not object_properties:
        # Nothing to match on and nothing to score.
        return {}
    if len(gt_array) == 0 and len(pred_array) == 0:
        return {field: 1 for field in object_properties}
    elif len(gt_array) == 0 or len(pred_array) == 0:
        return {field: 0 for field in object_properties}

    # The first schema property is used as the key for pairing objects.
    main_key = object_properties[0]
    # Rows index ground-truth objects, columns index predicted objects.
    cost = np.array(
        [
            [calc_metric(pred[main_key], gt[main_key]) for pred in pred_array]
            for gt in gt_array
        ]
    )
    # linear_sum_assignment returns (row indices, column indices), i.e.
    # (gt indices, pred indices) for this matrix layout. Unpacking them the
    # other way round pairs the wrong objects and can index out of range.
    gt_idxs, pred_idxs = linear_sum_assignment(cost, maximize=True)

    array_metrics: Dict[str, List[float]] = defaultdict(list)
    for gt_idx, pred_idx in zip(gt_idxs, pred_idxs):
        pair_metrics = evaluate_json(
            pred_array[pred_idx], gt_array[gt_idx], object_schema, calc_metric
        )
        for element_name, metric_value in pair_metrics.items():
            if isinstance(metric_value, dict):
                # Defensive: flatten a nested result should one be returned.
                for field, value in metric_value.items():
                    array_metrics[f"{element_name}__{field}"].append(value)
            else:
                array_metrics[element_name].append(metric_value)
    return {
        field: np.mean(metric_values) if len(metric_values) > 1 else metric_values[0]
        for field, metric_values in array_metrics.items()
    }


def evaluate_json(
    pred_json: Dict,
    gt_json: Dict,
    json_schema: Dict,
    calc_metric: Callable[[str, str], float],
):
    """Score a predicted JSON object against the ground truth, field by field.

    The properties of ``json_schema`` are visited in order and handled by type:

    * ``"string"``: scored directly with ``calc_metric``.
    * ``"object"``: scored recursively; results are stored under
      ``<section>__<field>`` keys.
    * ``"array"`` of objects: scored with ``evaluate_object_array``; results
      are stored under ``<section>__<field>`` keys.
    * ``"array"`` of strings: scored with ``evaluate_string_array``.

    Properties of any other type (and arrays of any other item type) are
    skipped and produce no entry.

    Args:
        pred_json: Predicted object.
        gt_json: Ground-truth object.
        json_schema: Schema whose ``"properties"`` mapping drives the traversal.
        calc_metric: Called as ``calc_metric(pred, gt)`` on string values.

    Returns:
        Flat mapping of (possibly prefixed) field name to score.
    """
    scores: Dict[str, float] = {}
    for name, schema in json_schema["properties"].items():
        kind = schema["type"]

        if kind == "string":
            scores[name] = calc_metric(pred_json[name], gt_json[name])
            continue

        if kind == "object":
            nested_scores = evaluate_json(
                pred_json[name], gt_json[name], schema, calc_metric
            )
            scores = add_object_metrics_to_section_dict(nested_scores, scores, name)
            continue

        if kind != "array":
            # Unsupported property type: nothing to score.
            continue

        item_schema = schema["items"]
        item_kind = item_schema["type"]
        if item_kind == "object":
            nested_scores = evaluate_object_array(
                pred_json[name], gt_json[name], item_schema, calc_metric
            )
            scores = add_object_metrics_to_section_dict(nested_scores, scores, name)
        elif item_kind == "string":
            scores[name] = evaluate_string_array(
                pred_json[name], gt_json[name], calc_metric
            )
    return scores

In [None]:
# Score the predicted summary against the manual ground truth, checking only
# whether each schema field is filled in on both sides or empty on both sides.
metrics_per_field = evaluate_json(
    pred_json=output_pred_json,
    gt_json=output_gt_json,
    json_schema=json_schema,
    calc_metric=has_content_metric,
)
# Overall score: unweighted mean over all per-field scores.
print(np.mean([*metrics_per_field.values()]))
metrics_per_field