In [None]:
import json
from pathlib import Path
from typing import Callable, Dict, List

import numpy as np
from scipy.optimize import linear_sum_assignment

# from rouge_score import rouge_scorer
# from summac.model_summac import SummaCZS

In [None]:
# All paths are resolved against the kernel's current working directory
# (assumed to be the folder containing this notebook — confirm when running elsewhere).
output_dir = Path.cwd() / "output"
output_gt_path = output_dir.joinpath("prsb_example_gt.json")
output_pred_path = output_dir.joinpath("prsb_example.json")
json_schema_path = Path.cwd().parent.joinpath(
    "guidelines", "eDischarge-Summary-v2.1-1st-Feb-21_schema.json"
)

In [None]:
# Ground-truth and predicted documents are read from the top-level "Completion" key.
# JSON text is UTF-8 (RFC 8259); pin the encoding instead of relying on the
# platform's locale default, which can mis-decode non-ASCII text (e.g. on Windows).
output_gt_json = json.loads(output_gt_path.read_text(encoding="utf-8"))["Completion"]
output_pred_json = json.loads(output_pred_path.read_text(encoding="utf-8"))["Completion"]
json_schema = json.loads(json_schema_path.read_text(encoding="utf-8"))

In [None]:
def dummy_metric(pred_value: str, gt_value: str) -> float:
    """Exact-match metric: 1.0 if the prediction equals the ground truth, else 0.0.

    Any two values comparable with ``==`` are accepted. The result is a float
    to match the annotated return type and the other metric scores it is
    averaged with.
    """
    return float(pred_value == gt_value)


def _get_field(container: Dict, name: str, field_type):
    """Return ``container[name]``, substituting an empty value when it is absent.

    A missing key, a ``None`` (JSON ``null``) value or a ``None`` container is
    replaced by ``{}`` for ``"object"``, ``[]`` for ``"array"`` and ``""`` for any
    other type, so an incomplete prediction is scored instead of raising.
    """
    value = (container or {}).get(name)
    if value is not None:
        return value
    # Equality checks (not dict lookup) so non-hashable schema types such as
    # ["string", "null"] do not raise.
    if field_type == "object":
        return {}
    if field_type == "array":
        return []
    return ""


def _mean_metric(values: list):
    """Average a non-empty list of per-pair metric values.

    Scalars are averaged with ``np.mean``. Dicts (produced for nested object or
    array fields) are averaged key by key, recursively, using the keys of the
    first entry.
    """
    if isinstance(values[0], dict):
        return {key: _mean_metric([value[key] for value in values]) for key in values[0]}
    return float(np.mean(values))


def evaluate_array(
    pred_array: List[Dict],
    gt_array: List[Dict],
    array_schema: Dict,
    calc_metric: Callable[[str, str], float],
) -> Dict[str, float]:
    """Score a predicted array of objects against the ground-truth array.

    Items are paired with ``scipy.optimize.linear_sum_assignment``, maximising
    ``calc_metric`` on the first property listed in the item schema. Each
    matched pair is scored field by field with ``evaluate_json``. The
    per-field scores are then averaged over all matched pairs.

    Args:
        pred_array: Predicted items (``None`` is treated as empty).
        gt_array: Ground-truth items (``None`` is treated as empty).
        array_schema: Schema node for the array. ``array_schema["items"]`` must
            be an object schema with ``properties``.
        calc_metric: Score for a (predicted, ground-truth) value pair; higher
            is better.

    Returns:
        Mapping from item field name to its averaged score. Object/array fields
        map to nested dicts. If both arrays are empty, every field scores 1. If
        exactly one is empty, every field scores 0. Fields of schema types that
        ``evaluate_json`` does not score are omitted.

    Note:
        Unmatched items (when the lengths differ) are not penalised; only the
        ``min(len(pred_array), len(gt_array))`` matched pairs contribute.
    """
    item_schema = array_schema["items"]
    array_fields = list(item_schema["properties"])
    if not array_fields:
        return {}
    if not gt_array and not pred_array:
        return {field: 1 for field in array_fields}
    if not gt_array or not pred_array:
        return {field: 0 for field in array_fields}

    main_key = array_fields[0]
    main_type = item_schema["properties"][main_key].get("type")
    # Rows index ground-truth items, columns index predicted items.
    similarity = np.array(
        [
            [
                calc_metric(
                    _get_field(pred, main_key, main_type),
                    _get_field(gt, main_key, main_type),
                )
                for pred in pred_array
            ]
            for gt in gt_array
        ]
    )
    # linear_sum_assignment returns (row_indices, col_indices), i.e. (gt, pred).
    gt_idxs, pred_idxs = linear_sum_assignment(similarity, maximize=True)

    per_field: Dict[str, list] = {}
    for gt_idx, pred_idx in zip(gt_idxs, pred_idxs):
        pair_metrics = evaluate_json(
            pred_array[pred_idx], gt_array[gt_idx], item_schema, calc_metric
        )
        for field, value in pair_metrics.items():
            per_field.setdefault(field, []).append(value)
    return {field: _mean_metric(values) for field, values in per_field.items()}


def evaluate_json(
    pred_json: Dict,
    gt_json: Dict,
    json_schema: Dict,
    calc_metric: Callable[[str, str], float],
) -> Dict:
    """Recursively score a predicted JSON document against its ground truth.

    The walk follows ``json_schema["properties"]``:

    - ``"object"`` properties recurse into ``evaluate_json``.
    - ``"array"`` properties go through ``evaluate_array``.
    - ``"string"`` properties are scored with ``calc_metric(pred, gt)``.
    - Properties of any other schema type are skipped.

    A property that is missing or ``null`` in either document is treated as
    empty (``{}``, ``[]`` or ``""``).

    Returns:
        Dict mirroring the schema's property tree, with a score at each leaf.
    """
    metrics = {}
    for section_name, section_schema in json_schema["properties"].items():
        section_type = section_schema["type"]
        pred_value = _get_field(pred_json, section_name, section_type)
        gt_value = _get_field(gt_json, section_name, section_type)
        if section_type == "object":
            metrics[section_name] = evaluate_json(
                pred_value, gt_value, section_schema, calc_metric
            )
        elif section_type == "array":
            metrics[section_name] = evaluate_array(
                pred_value, gt_value, section_schema, calc_metric
            )
        elif section_type == "string":
            metrics[section_name] = calc_metric(pred_value, gt_value)
    return metrics


# Score the predicted discharge summary against the ground truth, field by field.
discharge_summary_metrics = evaluate_json(
    pred_json=output_pred_json,
    gt_json=output_gt_json,
    json_schema=json_schema,
    calc_metric=dummy_metric,
)
discharge_summary_metrics