In [None]:
# Reload edited project modules (e.g. discharge_summaries) automatically before running code.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import pickle
from collections import defaultdict
from pathlib import Path

import nltk
import numpy as np
import pandas as pd
from bert_score import BERTScorer
from dotenv import load_dotenv
from medcat.cat import CAT
from rouge_score import rouge_scorer
from summac.model_summac import SummaCZS

from discharge_summaries.schemas.medcat import MedCATSpan
from discharge_summaries.schemas.mimic import Record
from discharge_summaries.schemas.output import Paragraph

In [None]:
# Populate os.environ from a local .env file (secrets are never hardcoded here).
load_dotenv()
# quiet=True keeps the download progress log out of the notebook output on every re-run.
nltk.download("punkt", quiet=True)

# NOTE(review): the UMLS and OpenAI settings below are not used by any of the
# cells below — confirm before removing.
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")

UMLS_BASE_URL = "https://uts-ws.nlm.nih.gov/rest"

DATA_DIR = Path.cwd().parent / "data"
# Ground-truth records (pickled list of Record dicts).
GT_DATA_PATH = DATA_DIR / "train.pkl"
# Model predictions to evaluate; the filename appears to be a run timestamp, so
# update it when evaluating a different run.
PRED_DATA_PATH = Path.cwd() / "output" / "2023_07_18_17_57.json"

# MedCAT model pack used to extract CUIs from paragraph headings.
MODEL_PATH = Path.cwd().parent / "models" / "umls_sm_pt2ch_533bab5115c6c2d6.zip"

OPEN_API_VERSION = "2023-05-15"
DEPLOYMENT_NAME = "gpt-35-turbo"

In [None]:
# Number of ground-truth records to evaluate; set to None to use the whole file.
N_SAMPLES = 10

# NOTE: pickle.load can execute arbitrary code — only load trusted, locally produced files.
with open(GT_DATA_PATH, "rb") as in_file:
    # Slice the raw records first so only the evaluated ones are validated into
    # Record objects (assumes the pickle holds a list/tuple of dicts).
    gt_dataset = [Record(**record) for record in pickle.load(in_file)[:N_SAMPLES]]
len(gt_dataset)

In [None]:
# Ground-truth BHCs (presumably "brief hospital course"): the full text and the
# paragraph-split version for every record.
gt_bhcs_full, gt_bhcs_paras = (
    [record.discharge_summary.bhc for record in gt_dataset],
    [record.discharge_summary.bhc_paragraphs for record in gt_dataset],
)
len(gt_bhcs_full), len(gt_bhcs_paras)

In [None]:
# Show the first GT BHC in full.
print(gt_bhcs_full[0])

In [None]:
# Predicted BHCs: the JSON file holds one list of paragraph dicts per BHC.
with PRED_DATA_PATH.open() as file_in:
    raw_pred_bhcs = json.load(file_in)
pred_bhcs_paras = [
    [Paragraph(**raw_para) for raw_para in raw_bhc] for raw_bhc in raw_pred_bhcs
]

# Flatten each predicted BHC into one text: every paragraph rendered as
# "# heading: text", paragraphs separated by a blank line.
pred_bhcs_full = [
    "\n\n".join([f"# {para.heading}: {para.text}" for para in bhc])
    for bhc in pred_bhcs_paras
]

In [None]:
# GT and predictions are paired positionally with zip() below, which silently stops at the shorter list.
len(gt_bhcs_paras), len(pred_bhcs_paras)

In [None]:
# Headings of the first GT BHC, skipping its first paragraph (the CUI maps below skip it too).
gt_sample = gt_bhcs_paras[0]
gt_headings = [paragraph.heading for paragraph in gt_sample[1:]]
sorted(gt_headings)

In [None]:
# Load the MedCAT model pack used below to link heading text to concept IDs (CUIs).
cat = CAT.load_model_pack(MODEL_PATH)
# Optional, currently disabled: restrict linking to the CUIs of the semantic type
# IDs listed in type_ids_filter. Uncomment to apply.
# NOTE(review): "T047" is presumably the UMLS "Disease or Syndrome" type — confirm.
# type_ids_filter = ["T047"]
# cui_filters = {
#     cui
#     for type_ids in type_ids_filter
#     for cui in cat.cdb.addl_info["type_id2cuis"][type_ids]
# }
# cat.cdb.config.linking["filters"]["cuis"] = cui_filters

In [None]:
def extract_cuis(text: str, cat: CAT) -> set[str]:
    """Return the set of concept IDs (CUIs) that MedCAT links in ``text``.

    When ``cat(text)`` yields a falsy result (e.g. ``None``), an empty set is
    returned so callers can always treat the result as a set.
    """
    doc = cat(text)
    if not doc:
        return set()
    cuis = set()
    for span in doc.ents:
        cuis.add(MedCATSpan.from_spacy_span(span, cat, context="").cui)
    return cuis


def extract_cuis_from_bhc_headings(bhc: list[Paragraph], cat: CAT) -> set[str]:
    """Union of the CUIs found in the headings of every paragraph of ``bhc``."""
    return set().union(*(extract_cuis(para.heading, cat) for para in bhc))

In [None]:
def cui_to_name(cui: str, cat: CAT) -> str:
    """Return the name MedCAT's concept database stores for ``cui``."""
    return cat.cdb.get_name(cui)


def cuis_to_names(cuis: set[str], cat: CAT) -> list[str]:
    """Return the names of ``cuis`` as a lexicographically sorted list.

    Any iterable of CUIs is accepted. The per-CUI lookup is delegated to
    ``cui_to_name`` so both helpers always resolve names the same way.
    """
    return sorted(cui_to_name(cui, cat) for cui in cuis)

## CUI Matching

In [None]:
# Per BHC, map each CUI found in a paragraph heading to that paragraph.
# The first GT paragraph is excluded; when several headings share a CUI,
# the last such paragraph wins.
gt_heading_cuis_to_paras = [
    dict(
        (cui, gt_para)
        for gt_para in gt_paras[1:]
        for cui in extract_cuis(gt_para.heading, cat)
    )
    for gt_paras in gt_bhcs_paras
]
pred_heading_cuis_to_paras = [
    dict(
        (cui, pred_para)
        for pred_para in pred_paras
        for cui in extract_cuis(pred_para.heading, cat)
    )
    for pred_paras in pred_bhcs_paras
]

In [None]:
# Show the second GT BHC in full.
print(gt_bhcs_full[1])

In [None]:
# Exploratory spot-check of what MedCAT links for a single word.
cat.get_entities("hand")

In [None]:
# Per BHC, print the GT heading concepts that no predicted heading shares,
# followed by all predicted heading concepts for comparison.
for idx, (gt_heading_cuis_to_para, pred_heading_cuis_to_para) in enumerate(
    zip(gt_heading_cuis_to_paras, pred_heading_cuis_to_paras)
):
    # dict key views support set difference directly; no copies needed.
    missed = gt_heading_cuis_to_para.keys() - pred_heading_cuis_to_para.keys()
    if missed:
        print(idx)
        # cuis_to_names already returns a sorted list.
        print(cuis_to_names(missed, cat))
        print(cuis_to_names(set(pred_heading_cuis_to_para), cat))

In [None]:
# For the first BHC only, print the GT headings with no exact (case-insensitive)
# counterpart among the predicted headings.
# NOTE(review): unlike the CUI maps above, this also checks the first GT paragraph.
for gt_sample, pred_sample in zip(gt_bhcs_paras[:1], pred_bhcs_paras[:1]):
    pred_heading_to_evidence = {
        para.heading.lower(): para.evidence for para in pred_sample
    }
    for gt_para in gt_sample:
        # NOTE(review): [2:] drops the first two characters of every GT heading,
        # presumably a "# " marker that predicted headings lack — confirm format.
        if gt_para.heading.lower()[2:] not in pred_heading_to_evidence:
            print(gt_para.heading)

## Heading CUI recall (prefix agreement)

In [None]:
# Heading CUI set per BHC, used for the recall computation below.
# NOTE(review): unlike gt_heading_cuis_to_paras, the first GT paragraph's heading
# is included here — confirm this is intended.
gt_cuis = [extract_cuis_from_bhc_headings(paras, cat) for paras in gt_bhcs_paras]
pred_cuis = [extract_cuis_from_bhc_headings(paras, cat) for paras in pred_bhcs_paras]

In [None]:
# Pooled recall of GT heading concepts: the share of GT heading CUIs that also
# appear among the predicted headings of the same BHC.
num_hits = 0
num_gts = 0

for gt, pred in zip(gt_cuis, pred_cuis):
    num_hits += len(gt & pred)
    num_gts += len(gt)
    # GT heading concepts the prediction missed.
    print(cuis_to_names(gt - pred, cat))

# NaN instead of ZeroDivisionError when no GT heading yields a CUI.
num_hits / num_gts if num_gts else float("nan")

In [None]:
# Print the text of every physician note of the last GT record.
for note in gt_dataset[-1].physician_notes:
    print(note.text)

## Document-Level Metrics

In [None]:
def average_text_length(texts: list[str]) -> float:
    """Mean number of whitespace-separated tokens per text.

    Args:
        texts: Texts to measure (any iterable of strings works).

    Returns:
        The mean token count as a plain float, or NaN for empty input. NaN is
        returned explicitly rather than via numpy, which would emit a
        "Mean of empty slice" RuntimeWarning.
    """
    lengths = [len(text.split()) for text in texts]
    if not lengths:
        return float("nan")
    return float(np.mean(lengths))


# Mean whitespace-token length: GT BHCs vs. flattened predicted BHCs.
average_text_length(gt_bhcs_full), average_text_length(pred_bhcs_full)

In [None]:
def calc_rouge_score(
    gold: list[str],
    pred: list[str],
    rouge_types: tuple[str, ...] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),
    use_stemmer: bool = True,
) -> pd.DataFrame:
    """Mean ROUGE precision/recall/F1 over position-aligned (gold, pred) pairs.

    Args:
        gold: Reference texts.
        pred: Generated texts, aligned by position with ``gold``. Pairs are
            formed with ``zip``, so extra items in the longer list are ignored.
        rouge_types: ROUGE variants to compute.
        use_stemmer: Passed through to ``rouge_scorer.RougeScorer``.

    Returns:
        DataFrame indexed by ROUGE variant with ``precision``, ``recall`` and
        ``f1`` columns holding the mean over all pairs (empty for no pairs).
    """
    scorer = rouge_scorer.RougeScorer(list(rouge_types), use_stemmer=use_stemmer)
    per_pair: dict[str, dict[str, list[float]]] = defaultdict(
        lambda: {"precision": [], "recall": [], "f1": []}
    )
    for sample_gold, sample_pred in zip(gold, pred):
        # RougeScorer.score takes (target, prediction): reference first.
        for metric, score in scorer.score(sample_gold, sample_pred).items():
            per_pair[metric]["precision"].append(score.precision)
            per_pair[metric]["recall"].append(score.recall)
            per_pair[metric]["f1"].append(score.fmeasure)

    # Build a separate dict of means instead of overwriting the lists in place,
    # so per_pair keeps matching its annotation.
    means = {
        metric: {name: np.mean(values) for name, values in stats.items()}
        for metric, stats in per_pair.items()
    }
    return pd.DataFrame(means).T


# Document-level ROUGE: GT BHCs as references, flattened predicted BHCs as predictions.
calc_rouge_score(gt_bhcs_full, pred_bhcs_full)

In [None]:
# Document-level BERTScore with IDF weights computed from the GT BHCs and scores
# rescaled with the library's baseline.
bert_scorer_full = BERTScorer(
    model_type="microsoft/deberta-xlarge-mnli",
    lang="en",
    rescale_with_baseline=True,
    idf=True,
    idf_sents=gt_bhcs_full,
)
# BERTScorer.score's signature is (cands, refs): predictions are the candidates
# and GT the references. Passing GT first (as before) swapped precision and recall.
P, R, F1 = bert_scorer_full.score(
    cands=pred_bhcs_full, refs=gt_bhcs_full, verbose=True
)
P.mean(), R.mean(), F1.mean()

## Paragraph-Level Metrics

In [None]:
# Align GT paragraphs with predicted paragraphs by shared heading CUIs.
# For every GT paragraph except the first, collect the predicted paragraphs whose
# heading shares at least one CUI with the GT heading; GT paragraphs without a
# match are dropped. gt_paragraphs, pred_paragraphs and evidence stay index-aligned.
gt_paragraphs = []
pred_paragraphs = []
evidence = []
num_candidate_gt_paras = 0

for gt_bhc_paras, pred_bhc_paras in zip(gt_bhcs_paras, pred_bhcs_paras):
    # Keep *all* heading CUIs of each predicted paragraph. The previous
    # next(iter(...)) raised StopIteration for a heading without CUIs and kept an
    # arbitrary CUI otherwise (set order depends on the hash seed), so the
    # alignment was not reproducible across kernel restarts.
    pred_paras_with_cuis = [
        (pred_para, extract_cuis(pred_para.heading, cat))
        for pred_para in pred_bhc_paras
    ]
    for gt_para in gt_bhc_paras[1:]:  # TODO add 1st para
        num_candidate_gt_paras += 1
        gt_para_cuis = extract_cuis(gt_para.heading, cat)
        # Predicted-document order; each predicted paragraph is used at most once
        # even if it shares several CUIs with the GT heading.
        matched = [
            pred_para
            for pred_para, pred_para_cuis in pred_paras_with_cuis
            if pred_para_cuis & gt_para_cuis
        ]
        if matched:
            gt_paragraphs.append(gt_para.text)
            pred_paragraphs.append("\n".join(pred_para.text for pred_para in matched))
            evidence.append(
                "\n\n".join(
                    extract for pred_para in matched for extract in pred_para.evidence
                )
            )

assert len(gt_paragraphs) == len(pred_paragraphs) == len(evidence)
# GT paragraphs eligible for matching (first paragraph excluded) vs. matched pairs.
num_candidate_gt_paras, len(gt_paragraphs), len(pred_paragraphs)

In [None]:
# Paragraph-level ROUGE over the CUI-aligned (GT, predicted) paragraph pairs.
calc_rouge_score(gt_paragraphs, pred_paragraphs)

In [None]:
# Mean whitespace-token length of the aligned GT vs. predicted paragraphs.
average_text_length(gt_paragraphs), average_text_length(pred_paragraphs)

In [None]:
# SummaC zero-shot consistency scorer ("vitc" model, paragraph granularity).
# Runs on CPU; pass device="cuda" instead when a GPU is available.
model_zs = SummaCZS(
    model_name="vitc",
    granularity="paragraph",
    device="cpu",
)
# TODO: add SummaCConv

In [None]:
# Score only the first aligned pair: evidence extracts as the source, the
# predicted paragraph as the generated text being checked.
score_zs1 = model_zs.score(evidence[:1], pred_paragraphs[:1])

In [None]:
# Consistency score(s) returned for the scored pair.
score_zs1["scores"]

In [None]:
# Show the evidence text that served as the SummaC source above.
print(evidence[0])

In [None]:
# NOTE(review): score_zs1["images"][0] is presumably SummaC's source-vs-generated
# score matrix for the first pair — confirm. Printed with 3 decimals and without
# scientific notation.
with np.printoptions(precision=3, suppress=True):
    print(score_zs1["images"][0])