# Imports

In [None]:
!pip install -q bert_score transformers pandas numpy torch pydantic tqdm bitsandbytes

In [None]:
# from langchain_huggingface import HuggingFacePipeline
# from langchain_core.prompts import PromptTemplate
import torch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline, BitsAndBytesConfig, BertModel, BertTokenizer
from bert_score import BERTScorer

import numpy as np
import pandas as pd
from typing import List, Any, Tuple
import json
import time
from tqdm import tqdm
import logging
import gc

# Functions

In [None]:
# evaluation functions
def calculate_bert(original: str, summary: str, scorer: "BERTScorer") -> float:
  '''Return the BERTScore F1 between a source text and its summary (relevance).

  Args:
    original: the source text, passed as the reference.
    summary: the generated text being scored, passed as the candidate.
    scorer: a constructed ``bert_score.BERTScorer``.

  Returns:
    The F1 value as a plain Python float (``scorer.score`` returns tensors).
  '''
  # BERTScorer.score expects (candidates, references). Swapping the two only
  # swaps precision and recall, so F1 matches the previous argument order.
  _, _, f1 = scorer.score([summary], [original])
  # Exactly one pair was scored, so unwrap the single element to honour the
  # declared float return type.
  return float(f1[0])

def nli(m: "AutoModelForSequenceClassification", t: "AutoTokenizer", original: str, summary_sentences: List[str], hyperparameters: dict) -> float:
  '''Score how well each summary sentence is supported by the source text.

  Each sentence is paired with ``original`` as an NLI (premise, hypothesis)
  pair. The class probabilities are collapsed into an expected class index,
  using weights 0, 1 and 2 for class ids 0, 1 and 2.

  Args:
    m: sequence-classification NLI model exposing ``.device``.
    t: tokenizer matching ``m``.
    original: premise text (the source note).
    summary_sentences: hypotheses to check; ``None`` raises ``TypeError``.
    hyperparameters: accepted for signature parity with the other
      evaluators; not used.

  Returns:
    Mean expected class index over all sentences as a Python float, or
    ``nan`` when ``summary_sentences`` is empty.
  '''
  # NOTE(review): the 0/1/2 weighting reads as "higher = more faithful" only
  # if id 0 is contradiction and id 2 is entailment; confirm against
  # m.config.id2label for the loaded checkpoint.
  sentences = list(summary_sentences)
  if not sentences:
    # np.mean([]) also yields nan, but emits a RuntimeWarning.
    return float("nan")
  label_weights = torch.tensor([0.0, 1.0, 2.0])
  claims = []
  for summary_sentence in sentences:
    input_tokens = t(original, summary_sentence, return_tensors="pt", truncation=True).to(m.device)
    with torch.no_grad():
      logits = m(**input_tokens).logits
    # Upcast and move to CPU so the weighting also works for half-precision
    # models (bfloat16 tensors cannot be converted to NumPy).
    probs = torch.softmax(logits, dim=-1).float().cpu()
    claims.append(float(probs[0] @ label_weights))
  return float(np.mean(claims))

def simplicity(m: "AutoModelForCausalLM", t: "AutoTokenizer", sp: str, summary: str, hyperparameters: dict) -> Tuple[Any, Any]:
  '''Ask the generation model, acting as a persona, to rate how easy ``summary`` is to read.

  Args:
    m: causal LM used as the judge.
    t: tokenizer matching ``m``.
    sp: system prompt describing the reader persona and output rules.
    summary: text to be judged (only the summary is sent; no source text).
    hyperparameters: keyword arguments forwarded to ``model.generate``.

  Returns:
    ``(output, input_tokens)`` as produced by ``generate``; the caller parses
    the JSON verdict from them (e.g. with ``process_output``).
  '''
  # The prompt previously asked to compare "against the source text", which
  # is never included; only the summary is provided.
  query = f"""
  Evaluate the provided summary based on your assigned role.

  Return your response ONLY as a JSON object that adheres strictly to the following schema.

  ### JSON SCHEMA ###
  ```json{{
    "simplicity_score": "INTEGER (1 to 5)",
    "readability_critique": "STRING (A brief, two-sentence explanation of why the score was assigned, focusing on vocabulary, sentence length, and flow.)",
    "most_confusing_term": "STRING (The single word or phrase that would be most challenging for your persona, or 'N/A' if none.)"
  }}```
  ### END JSON SCHEMA ###

  Summary:
  ---
  {summary}
  ---
  """
  o, i, _ = generate(m, t, query, hyperparameters, sp=sp)
  return o, i

In [None]:
# functions for generation
def generate(model: "AutoModelForCausalLM", tokenizer: "AutoTokenizer", query: str, hyperparameters: dict, sp: str=None) -> Tuple[Any, Any, str]:
  '''Run one chat-formatted generation.

  Args:
    model: causal LM exposing ``.device`` and ``.generate``.
    tokenizer: tokenizer providing ``apply_chat_template``.
    query: content of the user message.
    hyperparameters: extra keyword arguments forwarded to ``model.generate``
      (e.g. ``max_new_tokens``, ``temperature``).
    sp: optional system prompt; no system message is added when ``None``.

  Returns:
    ``(output, input_tokens, input_text)``: the ``model.generate`` result
    (requested with ``return_dict_in_generate=True``), the prompt token ids on
    ``model.device``, and the rendered prompt string.
  '''
  messages = []
  if sp is not None:
    messages.append({"role": "system", "content": sp})
  messages.append({"role": "user", "content": query})

  # The rendered string is only returned to the caller; generation uses the
  # token ids built below.
  input_text = tokenizer.apply_chat_template(
      messages,
      tokenize=False,
      add_generation_prompt=True
  )
  input_tokens = tokenizer.apply_chat_template(
      messages,
      tokenize=True,
      add_generation_prompt=True,
      return_tensors="pt"
  ).to(model.device)
  # A single unpadded prompt attends to every token. Passing the mask
  # explicitly stops generate() from inferring it from the pad token, which is
  # unreliable when the pad and eos ids coincide.
  attention_mask = torch.ones_like(input_tokens)

  with torch.no_grad():
    output = model.generate(input_ids=input_tokens,
                            attention_mask=attention_mask,
                            return_dict_in_generate=True,
                            output_scores=False,
                            **hyperparameters)
  return output, input_tokens, input_text


def process_output(o, i, model_final_tag, model_final_tag_end, model_output_start, model_output_end, tok=None):
  '''Extract and parse the JSON answer from a ``generate`` result.

  The function first decodes only the newly generated tokens. It keeps the
  text between ``model_final_tag`` and the first ``model_final_tag_end``
  after it, and removes newlines. It then keeps the text between the first
  ``model_output_start`` and the last ``model_output_end``, and parses that
  with ``json.loads``. Passing empty strings for the two output markers
  parses the whole final section.

  Args:
    o: generation output exposing ``.sequences`` (prompt + continuation).
    i: prompt token ids; ``len(i[0])`` leading tokens are dropped.
    model_final_tag: marker preceding the final answer.
    model_final_tag_end: marker ending the final answer.
    model_output_start: marker opening the JSON payload (may be "").
    model_output_end: marker closing the JSON payload (may be "").
    tok: tokenizer used for decoding; defaults to the module-level
      ``tokenizer`` so existing callers keep working.

  Returns:
    The parsed JSON value.

  Raises:
    ValueError: a marker is missing, or the payload is not valid JSON
      (``json.JSONDecodeError`` subclasses ``ValueError``).
  '''
  if tok is None:
    tok = tokenizer  # module-level tokenizer loaded elsewhere in the notebook
  # Default decode settings keep marker strings in the text; presumably the
  # gpt-oss channel tags are special tokens, so skip_special_tokens must stay off.
  generated = tok.decode(o.sequences[0][len(i[0]):])
  final_start = generated.index(model_final_tag) + len(model_final_tag)
  # Look for the end tag only after the final-channel tag, so an earlier
  # occurrence cannot yield an empty or inverted slice.
  final_end = generated.index(model_final_tag_end, final_start)
  final_text = generated[final_start:final_end].replace("\n", "")
  json_start = final_text.index(model_output_start) + len(model_output_start)
  json_end = final_text.rindex(model_output_end)
  return json.loads(final_text[json_start:json_end])

# Configs

In [None]:
# Prompts, output markers, readability levels and decoding settings shared by
# the generation and evaluation cells.
system_prompt = """
You are a highly reliable and expert Clinical Data Abstraction Agent specializing in natural language processing of Electronic Health Records (EHRs). Your primary goal is to transform unstructured clinical text into reliable, structured, and actionable data or patient-centric summaries, acting with the expertise of a medical writer and public health educator.

RULES:
1. Output Format (CRITICAL): The final and ONLY output must be a valid JSON object enclosed in ```json ... ``` (triple backticks). Do not include any text, reasoning, or verification output outside these fences.
2. Model Role: Strictly adhere to the functional role of a Data Abstraction Agent and a Medical Writer/Educator.
3. Traceability & Grounding: All extracted facts and summaries must be directly traceable to the provided clinical text. Do not invent or assume information. This rule is essential for liability mitigation.
4. Jargon Control: Use the vocabulary and complexity level strictly defined in the <readability> field.
5. Liability Mitigation: Before generating the final JSON, you must first output an internal verification step to ensure all CRITICAL actions are captured and correctly translated.
6. Handling Ambiguity/Absence: If a required data point is completely absent or ambiguous, populate that field with "N/A" or "Not Documented". Do not use placeholder text from the schema (e.g., do not output "[Medication Name]").
"""
# Positional placeholders, in order: readability level, clinical note,
# diagnoses, medications. Literal braces in the schema are doubled for
# str.format. The directives refer to the <readability> tag (previously
# misspelled <readbility>), and the example schema is valid JSON (no
# trailing comma).
query_prompt = """
---TASK INSTRUCTION---
Using the text provided in the <clinical_note> section to summarize the diagnoses and medications listed under <medications> and <diagnoses> sections, generate the required structured JSON output.

**Target Directives (Safety First):**
1. **Source Grounding & Verification (CRITICAL):** Before generating the final JSON, you **MUST** internally verify all facts and extracted actions against the source note. List all extracted actions and warning signs here, confirming their priority (CRITICAL, URGENT, ROUTINE).
2. **Reading Level (CRITICAL):** Generate all fields using the <readability> reading level.
3. **Diagnosis Breakdown (CRITICAL):** Provide the diagnoses as defined in the schema. The <readability> reading level MUST be strictly applied.
4. **Medication Fidelity (CRITICAL):** All extracted medications must include the new `status` field. All medication fields (`med_name`, `regimen`, `status`, `purpose`) must be populated. The <readability> reading level MUST be strictly applied to regimen and purpose medication fields.
5. **Jargon Guardrails (CRITICAL):** The basic reading level MUST be strictly applied to the `warning_signs` array.

---INTERNAL VERIFICATION---
[LLM must insert the verification list here, per Directive 1]
---END VERIFICATION---
FINAL COMMAND: The JSON object MUST contain every CRITICAL action item listed in the verification step above. Failure to reconcile is grounds for system termination.

---JSON SCHEMA DEFINITION---
{{
  "diagnosis_list": [
    {{
      "original": "from input <diagnoses>",
      "new": "rewritten based on <readability> level"
    }}
  ],

  "medication_list": [
    {{
      "med_name": "[Medication Name]",
      "regimen": "[Dose and Frequency]",
      "status": "[New | Continued | Discontinued]",
      "purpose": "[medication purpose]"
    }}
  ],
  "follow_up_instructions": [
    {{"action": "Schedule an appointment with [Primary Care Provider/Specialist]", "timeframe": "[e.g., within 7 days]"}},
    {{"action": "Take all new medications as prescribed", "timeframe": "Ongoing"}},
    {{"action": "Limit [Activity] and avoid [Food/Drug]", "timeframe": "[e.g., for 6 weeks]"}}
  ],
  "warning_signs": [
    "[List of 3-5 signs that require a return to the ED or immediate call to the doctor, written in simple language]"
  ]
}}
---END SCHEMA---

<readability>
{}
</readability>

<clinical_note>
{}
</clinical_note>

<diagnoses>
{}
</diagnoses>

<medications>
{}
</medications>
"""

# Markers used by process_output to cut the final answer and its JSON payload.
model_final_tag = "<|end|><|start|>assistant<|channel|>final"
model_final_tag_end = "<|return|>"
model_output_start = "```json"
model_output_end = "```"
# Readability level descriptions substituted into <readability>.
basic = "Simplest term, approx. 4.5-grade, e.g., 'Extra water (fluid) has built up inside your belly area.'"
intermediate = "7th-grade term, e.g., 'A buildup of fluid in the space inside your abdomen or peritoneal cavity.'"
advanced = "Specialist detail/MOA, e.g., 'Hepatic sinusoidal hypertension leads to splanchnic vasodilation and subsequent overflow due to RAAS activation.'"
output_eval_rules = "RULES\n1. Output Format (CRITICAL): The final and ONLY output must be a valid JSON object enclosed in ```json ... ``` (triple backticks). Do not include any text, reasoning, or verification output outside these fences."
generation_model_name = "openai/gpt-oss-20b"
evaluation_model_name = "facebook/bart-large-mnli"
generation_configs = dict(max_new_tokens=5000, temperature=0.5, do_sample=True, top_k=10)
evaluation_configs = dict(do_sample=False, max_new_tokens=5000)
# Evaluator personas, one per readability level; each ends with the same
# format instruction.
system_prompt_evaluation_basic = "You are a 4.5 grade student who can understand simple terms. Please evaluate the medical summary (listed under Summary:) to see if you would understand this. Please return in the requested format."
system_prompt_evaluation_intermediate = "You are a 7th-grade student who can understand the following level of clinical language: 'A buildup of fluid in the space inside your abdomen or peritoneal cavity.' Please evaluate the medical summary (listed under Summary:) to see if you would understand this. Please return in the requested format."
system_prompt_evaluation_advanced = "You are a specialist (physician or researcher) with advanced clinical language. Please evaluate the medical summary (listed under Summary:) to see if you would understand this. Please return in the requested format."

# Load Data

In [None]:
# Number of sampled notes processed per readability level.
num: int = 5

In [None]:
# Load the datathon export and draw a reproducible sample. The last `num`
# rows are kept as the working set and saved to sampled_data.csv so the
# evaluation cells can reload them.
full_data = pd.read_csv("MIMIC-IV Notes Datathon v4 20251107(Sheet1).csv", header=1)
sampled_data = full_data.sample(29, random_state=0)
# If label 0 is absent from the sample, .loc appends that row at the end
# (pandas setting-with-enlargement), so it falls within the last `num` rows
# kept below. If label 0 was sampled, the same values are written back.
# NOTE(review): confirm that forcing row 0 into the working set is intended.
sampled_data.loc[0, :] = full_data.iloc[0, :]
# .copy() makes `data` an independent frame rather than a slice of
# `sampled_data`, so resetting its index does not act on a view.
data = sampled_data.iloc[-num:, :].copy()
data = data.reset_index()
data.to_csv("sampled_data.csv")

# Generate

In [None]:
import os

# File logger for generation/evaluation errors. Re-running this cell must not
# attach another handler for the same file; otherwise every message would be
# written once per run of the cell.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log_path = os.path.abspath('app.log')
already_attached = any(
    isinstance(h, logging.FileHandler) and h.baseFilename == log_path
    for h in logger.handlers
)
if not already_attached:
  file_handler = logging.FileHandler(log_path)
  file_handler.setLevel(logging.DEBUG)  # log all messages to the file
  file_handler.setFormatter(formatter)
  logger.addHandler(file_handler)

In [None]:
# Generation model and its tokenizer; device_map="auto" lets transformers
# place the weights on the available devices.
tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
model = AutoModelForCausalLM.from_pretrained(
    generation_model_name,
    device_map="auto",
)

In [None]:
# One entry per (level, note): [parsed_json, seconds], or None when parsing failed.
processed_outputs = list()

In [None]:
# One generation per (readability level, note). Results are appended in
# level-major order (all `num` notes for basic, then intermediate, then
# advanced); the evaluation cells rely on that ordering.
for level in [basic, intermediate, advanced]:
  for ind in tqdm(data.index, total=num):
    example_query = query_prompt.format(level, data.loc[ind, "Brief Hospital Course"], data.loc[ind, "Diagnosis List"], data.loc[ind, "Medication List"])
    start = time.perf_counter()
    o, i, _ = generate(model, tokenizer, example_query, generation_configs, sp=system_prompt)
    end = time.perf_counter()
    try:
      example_output = process_output(o, i, model_final_tag, model_final_tag_end, model_output_start, model_output_end)
      processed_outputs.append([example_output, end-start])
    except Exception as fenced_error:
      # The model may have omitted or malformed the ```json fences; retry by
      # parsing the whole final-channel text.
      logger.debug(f"Fenced JSON parse failed, retrying without fences: {fenced_error}")
      try:
        example_output = process_output(o, i, model_final_tag, model_final_tag_end, "", "")
        processed_outputs.append([example_output, end-start])
      except Exception as e:
        # Record the failure (with traceback) but keep the run going.
        logger.exception(f"Error with processing output: {e}")
        processed_outputs.append(None)
    # Checkpoint after every note so finished results survive a crash.
    with open("original_output.json", "w") as checkpoint_file:
      json.dump(processed_outputs, checkpoint_file, indent=4)

# Evaluate

In [None]:
# Reload generation results and the sampled notes so evaluation can run in a
# fresh kernel.
# NOTE(review): the generation loop checkpoints to "original_output.json";
# confirm "example_output.json" is the intended input here.
with open("example_output.json") as saved_outputs:
  processed_outputs = json.load(saved_outputs)

data = pd.read_csv("sampled_data.csv")

In [None]:
# Reuse the generation model if it is still loaded from the generation
# section. Reloading would hold a second full copy of the weights while the
# old `model` is still referenced. A fresh kernel loads it here as before.
if "model" not in globals() or "tokenizer" not in globals():
  model = AutoModelForCausalLM.from_pretrained(generation_model_name, device_map="auto")
  tokenizer = AutoTokenizer.from_pretrained(generation_model_name)

In [None]:
# NLI model and tokenizer consumed by nli() for the faithfulness score.
nli_model = AutoModelForSequenceClassification.from_pretrained(evaluation_model_name, device_map="auto")
nli_tokenizer = AutoTokenizer.from_pretrained(evaluation_model_name)
label_dictionary = nli_model.config.id2label
# nli() weights class ids 0/1/2 as 0/1/2, which means "higher is more
# faithful" only for a contradiction/neutral/entailment ordering. Fail fast if
# the checkpoint orders its labels differently.
expected_labels = ["contradiction", "neutral", "entailment"]
found_labels = [str(label_dictionary.get(idx, "")).lower() for idx in range(len(label_dictionary))]
if found_labels != expected_labels:
  raise ValueError(f"Unexpected NLI label order {label_dictionary}; nli() expects {expected_labels} for ids 0-2")

In [None]:
# BERTScore scorer used by calculate_bert() for the relevance metric.
scorer = BERTScorer(
    model_type="bert-base-uncased",
)

In [None]:
# Flatten each parsed summary into one JSON string per top-level field
# (presumably diagnosis_list, medication_list, ... per the schema). nli()
# treats each string as one hypothesis. Failed items stay None so indices
# line up with processed_outputs.
all_outputs = []
for example_output in processed_outputs:
  if example_output is None:
    all_outputs.append(None)
    continue
  try:
    unrolled = [json.dumps(value, indent=4) for value in example_output[0].values()]
  except (AttributeError, IndexError, KeyError, TypeError) as e:
    # e.g. the model produced a JSON list/string instead of an object.
    print(e)
    all_outputs.append(None)
  else:
    all_outputs.append(unrolled)

In [None]:
# BERTScore F1 per processed output (None where scoring was not possible).
relevance_scores = list()

In [None]:
# BERTScore F1 of each summary against its source note. Outputs are stored
# level-major, so `ind % num` maps back to the note's row in `data`.
for ind, example_output in enumerate(processed_outputs):
  if example_output is None:
    # Generation/parsing failed for this item; there is nothing to score.
    relevance_scores.append(None)
    continue
  try:
    relevance = float(calculate_bert(data.loc[ind % num, "Brief Hospital Course"], json.dumps(example_output[0], indent=4), scorer))
  except Exception as e:
    print(e)
    relevance_scores.append(None)
  else:
    relevance_scores.append(relevance)

In [None]:
# NLI faithfulness score per processed output (None where unavailable).
correctness_scores = list()

In [None]:
# NLI faithfulness of each flattened summary against its source note
# (`index1 % num` maps the level-major index back to the row in `data`).
for index1, example_output in enumerate(all_outputs):
  if example_output is None:
    correctness_scores.append(None)
    continue
  try:
    correctness = float(nli(nli_model, nli_tokenizer, data.loc[index1 % num, "Brief Hospital Course"], example_output, evaluation_configs))
  except Exception as e:
    # Keep evaluating the remaining items, but surface why this one failed
    # (the previous bare except hid every error, including interrupts).
    print(e)
    correctness_scores.append(None)
  else:
    correctness_scores.append(correctness)

In [None]:
# Judge-assigned simplicity score per processed output (None where unavailable).
simplicity_scores = list()

In [None]:
# Evaluator persona per readability level, in the level-major order used by
# the generation loop (basic, intermediate, advanced).
eval_personas = [system_prompt_evaluation_basic, system_prompt_evaluation_intermediate, system_prompt_evaluation_advanced]
for index2, example_output in tqdm(enumerate(processed_outputs), total=len(processed_outputs)):
  # Indices beyond the third block fall back to the advanced persona.
  eval_template = eval_personas[min(index2 // num, len(eval_personas) - 1)]
  if example_output is None:
    simplicity_scores.append(None)
    continue
  try:
    # Rate only the parsed summary (element 0). Element 1 is the generation
    # latency and previously leaked into the text being judged.
    simple, inp = simplicity(model, tokenizer, eval_template+"\n"+output_eval_rules, json.dumps(example_output[0]), evaluation_configs)
    v = process_output(simple, inp, model_final_tag, model_final_tag_end, model_output_start, model_output_end)
    simplicity_scores.append(v["simplicity_score"])
  except Exception as e:
    print(e)
    simplicity_scores.append(None)

# Re-process for UI

In [None]:
# NOTE(review): this overrides the `num` used for generation/evaluation. The
# next cells treat processed_outputs as blocks of `num` items per level, so
# they line up only if generation also ran with num = 1; confirm.
num: int = 1

In [None]:
# Position of the note, within each level's block, to export for the UI.
example: int = 0

In [None]:
# Take the same note from each level's block of `num` outputs; element 0 of an
# entry is the parsed JSON.
# NOTE(review): these names shadow the readability-level strings from the
# configs cell. Re-running the generation loop after this cell would iterate
# over these dicts instead.
level_blocks = (processed_outputs[:num], processed_outputs[num:2 * num], processed_outputs[2 * num:])
basic = level_blocks[0][example][0]
intermediate = level_blocks[1][example][0]
advanced = level_blocks[2][example][0]

In [None]:
# Build the UI mapping: medication name (from the basic output) -> text at each
# readability level. A level's text is the kept fields joined by newlines, in
# the order they appear in that medication's dict.
output = {}
keep = ["regimen", "purpose"]
level_meds = zip(basic["medication_list"], intermediate["medication_list"], advanced["medication_list"])
for b_med, i_med, a_med in level_meds:
  per_level = {}
  for level_name, med in (("basic", b_med), ("intermediate", i_med), ("advanced", a_med)):
    per_level[level_name] = "\n".join(text for field, text in med.items() if field in keep)
  output[b_med["med_name"]] = per_level

In [None]:
# Add one entry per diagnosis, keyed by the original diagnosis text, holding
# the rewritten text at each readability level.
# NOTE(review): a diagnosis whose text equals a medication name overwrites it.
diagnosis_levels = zip(basic["diagnosis_list"], intermediate["diagnosis_list"], advanced["diagnosis_list"])
for b_diag, i_diag, a_diag in diagnosis_levels:
  output[b_diag["original"]] = {
      "basic": b_diag["new"],
      "intermediate": i_diag["new"],
      "advanced": a_diag["new"],
  }

In [None]:
# Persist the per-term, per-level text for the UI; the context manager closes
# (and flushes) the file deterministically.
with open("finalized_output.json", "w") as finalized_file:
  json.dump(output, finalized_file, indent=4)

In [None]:
# "<action> - <timeframe>" strings from the basic-level follow-up instructions.
list_of_actions = [f"{step['action']} - {step['timeframe']}" for step in basic["follow_up_instructions"]]

In [None]:
# Persist the follow-up actions; the context manager closes (and flushes) the
# file deterministically.
with open("final_actions.json", "w") as actions_file:
  json.dump(list_of_actions, actions_file)