In [None]:
import json
import os
import pickle
from collections import defaultdict
from pathlib import Path

import pandas as pd
import spacy
from dotenv import load_dotenv

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.message import Message, Role
from discharge_summaries.schemas.mimic import BHC
from discharge_summaries.snomed.lookup import SnomedLookup
from discharge_summaries.snomed.phrase_matcher import SnomedPhraseMatcher


In [None]:
# Load environment variables before the Azure OpenAI settings are read with
# os.getenv further down (presumably from a local .env file — confirm).
load_dotenv()

In [None]:
# Data locations: everything lives under the "data" directory that sits next to
# the notebook's working directory.
DATA_DIR = Path.cwd().parent / "data"

# MIMIC-III v1.4 extracts.
MIMIC_III_DIR = DATA_DIR.joinpath("physionet.org", "files", "mimiciii", "1.4")
BHC_FPATH = MIMIC_III_DIR.joinpath("BHCS.json")
PHYSICIAN_NOTE_FPATH = MIMIC_III_DIR.joinpath("physician_notes_mimic.csv")

# SNOMED resources.
SNOMED_DIR = DATA_DIR.joinpath("snomed")
TUNED_PHRASE_MATCHER_FPATH = SNOMED_DIR.joinpath("tuned_snomed_phrase_matcher.pkl")

# Files and folders kept alongside the notebook itself.
PROMPT_MESSAGE_FPATH = Path.cwd().joinpath("prompt_message.txt")
EXAMPLE_DIR = Path.cwd().joinpath("example")

# Azure OpenAI deployment settings.
AZURE_ENGINE = "gpt-35-turbo"
AZURE_API_VERSION = "2023-07-01-preview"

In [None]:
# Load the tuned SNOMED phrase matcher from disk.
# NOTE: unpickling can execute arbitrary code — only load a matcher file that
# was produced by this project / a trusted source.
# The context manager closes the file handle, which the bare .open() call leaked.
with TUNED_PHRASE_MATCHER_FPATH.open("rb") as matcher_file:
    snomed_phrase_matcher = pickle.load(matcher_file)

In [None]:
# Build the SNOMED lookup from SNOMED_DIR; it is passed to add_parent_cui when
# the custom phrase matcher is populated further down.
snomed_lookup = SnomedLookup.load(SNOMED_DIR)

In [None]:
# Load the "en_core_sci_lg" spaCy pipeline with its NER component disabled;
# this pipeline is handed to SnomedPhraseMatcher further down.
nlp = spacy.load("en_core_sci_lg", disable=["ner"])


In [None]:
# Azure OpenAI chat client. Endpoint and key come from the environment; the
# remaining settings are fixed for this notebook.
azure_chat_settings = {
    "api_base": os.getenv("AZURE_OPENAI_ENDPOINT"),
    "api_key": os.getenv("AZURE_OPENAI_KEY"),
    "api_version": AZURE_API_VERSION,
    "engine": AZURE_ENGINE,
    "temperature": 0,
    "timeout": 20,
}
llm = AzureOpenAIChatModel(**azure_chat_settings)

In [None]:
# Parse the BHC JSON file into BHC objects (one per record in the file).
bhc_records = json.loads(BHC_FPATH.read_text())
bhcs = [BHC(**record) for record in bhc_records]

# Physician notes table; filtered by HADM_ID and ordered by CHARTTIME later on.
physician_notes = pd.read_csv(PHYSICIAN_NOTE_FPATH)


In [None]:
# Eyeball the first 100 assessment-and-plan paragraphs: index, text, then a
# line of asterisks as a separator.
separator = "*" * 80
for position, summary in enumerate(bhcs[:100]):
    print(position, summary.assessment_and_plan, separator, sep="\n")

In [None]:
# Pick one BHC as the running example and collect the physician notes that
# share its hospital admission id, ordered by chart time.
sample_bhc = bhcs[10]
is_sample_admission = physician_notes["HADM_ID"] == sample_bhc.hadm_id
sample_physician_notes = physician_notes.loc[is_sample_admission].sort_values("CHARTTIME")

In [None]:
# Save the sample BHC for reference: its full text, a line of asterisks, then
# the parsed fields as indented JSON.
# The output directory is created if missing, and the file is written in a
# single pass instead of being written and then reopened for appending.
# (json is already imported in the notebook's import cell.)
EXAMPLE_DIR.mkdir(parents=True, exist_ok=True)
bhc_json = json.dumps(sample_bhc.dict(), indent=4)
bhc_example_text = f"{sample_bhc.full_text}\n{'*'*80}\n{bhc_json}"
with (EXAMPLE_DIR / "bhc.txt").open("w") as f:
    f.write(bhc_example_text)

## First Paragraph

The first paragraph seems a bit long. TODO: check the average first-paragraph length.

In [None]:
# Few-shot examples: the assessment-and-plan text of the first three BHCs,
# separated by blank lines.
examples = "\n\n".join([bhc.assessment_and_plan for bhc in bhcs[:3]])

# System prompt describing the task and embedding the examples above.
first_para_system_prompt = f"""You are a consultant doctor completing a medical discharge summary.
Your task is to write the first paragraph of the summary.
The paragraph should be 30 words long.
The paragraph must include the patient's:
- Age
- Gender
- Past medical history
- Reason for hospital admission
This information can be found in the admission note provided by the user.

The following are examples of first paragraphs:
{examples}"""
first_para_system_message = Message(role=Role.SYSTEM, content=first_para_system_prompt)

In [None]:
# User prompt: the earliest note for the sample admission (row 0 after the
# CHARTTIME sort) followed by the request.
first_para_user_prompt = f"""Admission Note

{sample_physician_notes.iloc[0]["TEXT"]}
Please write the first paragraph of the discharge summary using the admission note and the requirements given in the system message.
"""
first_para_user_message = Message(role=Role.USER, content=first_para_user_prompt)

In [None]:
# Send the system + user prompts to the chat model and keep its reply.
first_para_messages = [first_para_system_message, first_para_user_message]
first_para_response = llm.query(first_para_messages)

In [None]:
# Dump the whole exchange (system, user, model reply) plus the reference MIMIC
# paragraph to example/first_para.txt, separated by lines of asterisks.
message_delimiter = "\n" + ("*" * 80) + "\n"
rendered_messages = []
for message in (first_para_system_message, first_para_user_message, first_para_response):
    rendered_messages.append(f"{message.role}\n{message.content}")
message_strings = message_delimiter.join(rendered_messages)
message_strings += f"\n{message_delimiter}\nMIMIC BHC\n{sample_bhc.assessment_and_plan}"
(EXAMPLE_DIR / "first_para.txt").write_text(message_strings)

## Findings

In [None]:
# Few-shot examples: for each of the first three BHCs, the problem-section
# headings for which the SNOMED phrase matcher returns something truthy, one
# heading per line; BHCs are separated by blank lines.
example_heading_blocks = []
for example_bhc in bhcs[:3]:
    matched_headings = [
        section.heading
        for section in example_bhc.problem_sections
        if snomed_phrase_matcher(section.heading)
    ]
    example_heading_blocks.append("\n".join(matched_headings))
examples = "\n\n".join(example_heading_blocks)

# System prompt describing the task and embedding the examples above.
findings_system_prompt = f"""You are a consultant doctor completing a medical discharge summary.
Your task is to list the main clinical findings made during the patient's stay.
Each finding should be on a new line.
Use Snomed CT preferred terms.
This information can be found in the physician note provided by the user.

The following are examples of previous patient's clinical findings:
{examples}"""
findings_system_message = Message(role=Role.SYSTEM, content=findings_system_prompt)

In [None]:
# User prompt: the latest note for the sample admission (last row after the
# CHARTTIME sort) followed by the request.
findings_user_prompt = f"""Physician Note

{sample_physician_notes.iloc[-1]["TEXT"]}
Please write the list findings using the physician note and the requirements given in the system message.
"""
findings_user_message = Message(role=Role.USER, content=findings_user_prompt)

In [None]:
# Ask the model for the findings list.
findings_response = llm.query([findings_system_message, findings_user_message])

# One finding per line. Blank / whitespace-only lines (and stray "\r" from
# Windows-style line endings) are dropped so they are not treated as findings
# by the matching and retrieval steps that consume this list.
gpt_findings = [
    line.strip()
    for line in findings_response.content.splitlines()
    if line.strip()
]

In [None]:
# Write the prompts, the model's findings (sorted) and the reference MIMIC
# headings (sorted) to example/findings.txt, separated by lines of asterisks.
message_delimiter = "\n" + ("*" * 80) + "\n"

output_strings = []
for message in (findings_system_message, findings_user_message):
    output_strings.append(f"{message.role}\n{message.content}")

gpt_heading_strings = "\n".join(sorted(gpt_findings))
output_strings.append(f"{Role.ASSISTANT}\n{gpt_heading_strings}")

mimic_headings = [para.heading for para in sample_bhc.problem_sections]
heading_strings = "\n".join(sorted(mimic_headings))
output_strings.append(f"Mimic Headings\n{heading_strings}")

(EXAMPLE_DIR / "findings.txt").write_text(message_delimiter.join(output_strings))

## Retrieval

In [None]:
# For each generated finding, collect the labels of the spans the tuned SNOMED
# matcher finds in it, then key those label lists by the finding text.
findings_cuis = []
for finding_spans in snomed_phrase_matcher.pipe(gpt_findings):
    findings_cuis.append([span.label_ for span in finding_spans])
finding_to_cuis = dict(zip(gpt_findings, findings_cuis))

In [None]:
# Build a fresh matcher and register every distinct CUI found in the findings
# as a parent CUI (labels are converted to int for add_parent_cui).
custom_phrase_matcher = SnomedPhraseMatcher(nlp, False)

unique_finding_cuis = set()
for cuis in findings_cuis:
    unique_finding_cuis.update(cuis)

for cui in unique_finding_cuis:
    custom_phrase_matcher.add_parent_cui(int(cui), snomed_lookup)

In [None]:
# Findings with no SNOMED match get a synthetic negative id ("-1", "-2", ...)
# and are registered in the custom matcher as a literal phrase.
# The id is stored in finding_to_cuis as the SAME string used as the matcher
# key: span labels are strings, so storing the int (as before) meant lookups
# by label could never find these findings.
unmatched_finding_cui = -1
# Iterate over a snapshot because the dict is updated inside the loop.
for finding, cuis in list(finding_to_cuis.items()):
    if cuis:
        continue
    placeholder_label = str(unmatched_finding_cui)
    custom_phrase_matcher._phrase_matcher.add(
        placeholder_label, [custom_phrase_matcher._nlp(finding)]
    )
    finding_to_cuis[finding] = [placeholder_label]
    unmatched_finding_cui -= 1
finding_to_cuis

In [None]:
# Run the custom matcher over the TEXT of every physician note for the sample admission.
# NOTE(review): if pipe() returns a lazy iterator it can only be consumed once — confirm.
physician_note_spans = custom_phrase_matcher.pipe(sample_physician_notes["TEXT"])

In [None]:
# Group every matched span by its label, remembering which note (by position in
# sample_physician_notes) it came from. `defaultdict` comes from the notebook's
# import cell.
cui_to_spans = defaultdict(list)
for note_idx, note_spans in enumerate(physician_note_spans):
    for span in note_spans:
        cui_to_spans[span.label_].append((note_idx, span))

# For each finding, gather the (note_idx, span) hits of all its CUIs, ordered by
# note and then by start position within the note.
# CUIs are normalised with str() because span labels are strings; this keeps
# placeholder ids for unmatched findings working even when stored as ints.
# .get() avoids inserting empty entries into cui_to_spans for CUIs with no hits.
finding_to_spans = {
    finding: sorted(
        (hit for cui in cuis for hit in cui_to_spans.get(str(cui), [])),
        key=lambda hit: (hit[0], hit[1].start),
    )
    for finding, cuis in finding_to_cuis.items()
}


In [None]:
# Flatten each (doc_idx, span) hit into a JSON-serialisable dict and save the
# per-finding lists to example/retrieval.json.
finding_to_spans_simplified = {}
for finding, spans in finding_to_spans.items():
    finding_to_spans_simplified[finding] = [
        {"text": span.text, "doc_idx": doc_idx, "start": span.start, "end": span.end}
        for doc_idx, span in spans
    ]

retrieval_json = json.dumps(finding_to_spans_simplified, indent=4)
(EXAMPLE_DIR / "retrieval.json").write_text(retrieval_json)

In [None]:
# Tokenise every non-empty assessment-and-plan paragraph.
# Uses the tokenizer of the pipeline loaded near the top of the notebook, so
# this cell no longer depends on a name that is only defined in a later cell
# (which broke Restart & Run All).
tokenized_bhc = list(
    nlp.tokenizer.pipe([bhc.assessment_and_plan for bhc in bhcs if bhc.assessment_and_plan])
)


In [None]:
import numpy as np

# Median number of tokens per tokenised assessment-and-plan paragraph.
bhc_token_counts = [len(doc) for doc in tokenized_bhc]
np.median(bhc_token_counts)

In [None]:
# Reuse the tokenizer of the pipeline already loaded near the top of the
# notebook. Re-importing spaCy and loading "en_core_sci_lg" a second time was
# redundant and slow, and it rebound `nlp` to a pipeline without the
# `disable=["ner"]` setting used everywhere else.
spacy_tokenizer = nlp.tokenizer

In [None]:
# Token count of a hand-picked sample first paragraph, for comparison with the
# requested paragraph length.
sample_first_paragraph = """Mr Known lastname 52368 is a 59M w HCV cirrhosis w grade II esophageal varices
admitted w coffee-ground emesis and melena concerning for UGIB,
s/p MICU stay for hypotension."""
len(nlp(sample_first_paragraph))