## load model

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Hugging Face cache location. This is set before huggingface_hub is imported
# below so that the library sees it when it resolves its cache paths.
cache_dir = "/mnt/c/Users/yc/.cache/huggingface"
os.environ['HF_HOME'] = cache_dir

# Prefer a token supplied through the HF_TOKEN environment variable so the
# secret does not have to live next to the notebook; fall back to the local
# "hf.token" file for backward compatibility.
hftoken = os.environ.get("HF_TOKEN", "").strip()
if not hftoken:
    with open("hf.token", "r", encoding="utf-8") as f:
        hftoken = f.read().strip()

from huggingface_hub import login
login(token=hftoken)

# NOTE(review): the wildcard import supplies names used by later cells
# (e.g. torch, pd, csv, json, gc, AutoTokenizer, run_model) -- explicit
# imports would make those dependencies visible.
from ult import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Checkpoint to load and its device placement ({"": 0} is passed as device_map).
model_name = "meta-llama/Llama-3.1-8B-Instruct"
device_map = {"": 0}

# Keyword arguments shared by the model loader, collected in one place.
model_load_kwargs = dict(device_map=device_map, torch_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_kwargs)

Loading checkpoint shards: 100%|██████████| 4/4 [01:37<00:00, 24.32s/it]


## data

In [4]:
# Formal definitions, parsed from a text file by the project helper txt_to_dict
# (used below as a mapping of term -> definition).
txt_file_path = 'data/formaldef.txt'
formaldic = txt_to_dict(txt_file_path)
len(formaldic)

# Shorten every term to the text that precedes 'Listen to pronunciation' and
# any opening parenthesis.
meddict = {
    term.split('Listen to pronunciation')[0].split('(')[0]: definition
    for term, definition in formaldic.items()
}

# Collect the first column of every non-empty data row of the CSV
# (the header row is skipped).
filename = 'data/filtered_medical_dictionary.csv'
with open(filename, 'r', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    next(reader)
    eighth_grade_words = {row[0] for row in reader if row}

# Keep only the dictionary entries whose term appears in that word list.
filtered_meddict = {}
for word, explanation in meddict.items():
    if word in eighth_grade_words:
        filtered_meddict[word] = explanation
meddict = filtered_meddict

# Load the notes and select a fixed subset of rows by position.
df = pd.read_csv('data/CORAL/coral-expert-curated-medical-oncology-reports-to-advance-language-model-inference-1.0/coral/unannotated/data/breastca_unannotated.csv')
# Alternative selections kept for reference:
# df = df.sample(1, random_state=42)
# df=df.iloc[[44,45,53,70,83, 0,1,2,3,4]]
# df=df.iloc[[70,]]
# test_note=df.iloc[2]['note_text']
selected_positions = [44, 45, 53, 70, 83]
df = df.iloc[selected_positions]
df

Unnamed: 0,coral_idx,Sex,UCSFDerivedRaceEthnicity_X,BirthDate,note_text
44,184,Female,Latinx,1983-10-04,Medical Oncology Consult Note Video Consult ...
45,185,Female,Multi-Race/Ethnicity,1973-03-08,This is a shared service. Physician Statement...
53,193,Female,Multi-Race/Ethnicity,1979-06-20,ID: ***** ***** is a 39 y.o. premenopausal pat...
70,210,Female,Southwest Asian and North African,1974-04-05,Patient Name: ***** ***** ***** *****: 08/22/...
83,223,Female,Unknown/Declined,1960-10-12,We performed this consultation using real-time...


# keypoint

In [23]:

# Prompts for the first extraction pass, which is run against the full note.
# Each value is sent as one user turn. The model is told to answer "using this
# exact schema" and its reply is passed to json.loads, so every schema embedded
# here must itself be well-formed JSON (commas between members, no trailing
# comma, one string per value).
extraction_prompts = {

"Reason_for_Visit": """
TASK: Extract 'Reason for Visit'.
Respond *only* with a JSON object using this exact schema:
{
"Patient type": "either New patient or follow up",
"second opinion": "whether the visit is consultation/second opinion or not",
"in-person":" either Televisit or in-person. (note, video consult, televisit, telehealth are the same thing)",
"summary": "A brief summary of the reason for visit."
}
""",


"What_We_Found": """
TASK: Extract 'What We Found'. 
Respond *only* with a JSON object using this exact schema:
{
"Type_of_Cancer": "list the type of cancer",
"Stage_of_Cancer": "list the stage if it is written in the note",
"Distant Metastasis": "if there is distant metastasis (met). Yes, to where; No, local, not sure, need more evidence such as imaging.",
"Metastasis": "if there is met, Yes (to where), No, or Not sure",
"lab_summary": "Summary of key lab results.",
"findings": "Summary of new findings or disease status."
}
""",


"Treatment_Summary": """
TASK: Extract 'Treatment Summary'. 
Respond *only* with a JSON object using this exact schema:
{
"current_meds": "List of current oncologic medications or regimens.",
"recent_changes": "Any holds, dose reductions, or switches.",
"supportive_meds": "List of supportive medications."
}
""",


"Goals of care": """
TASK: Extract 'What We Discussed / Decided'.
Respond *only* with a JSON object using this exact schema:
{
"goals_of_treatment": "eg, cancer is not curable, but it's treatable and the goal is to extend the duration and maintain the quality of life",
"response_assessment": "How the cancer is responding to the treatment, not responding or responding."
}
""",
}

# Prompts for the second extraction pass, which is run against the extracted
# 'Assessment/Plan' text only. Each value is sent as one user turn and the
# reply is passed to json.loads, so every prompt ends with a well-formed JSON
# schema for the model to copy.
plan_extraction_prompts = {

"Medication_Plan_chatgpt": """
TASK: Extract the 'Medication Plan' from the given 'Assessment/Plan' section.
Include all current and future medication plans for both cancer therapy and supportive treatment. Cancer treatment could be one or many in chemotherapy, hormonal therapy, bone therapy, radiotherapy (eg. rad onc, xrt). Supportive treatment could be one or many in bowel regimen, pain medication, psychiatry medication, neuropathy or any blood transfusion plan. 

Include whether a medication is being started now (e.g,“will start”, “Rx sent”, “starting today” ), plan or discuss in the future after certain condition (e.g., “plan to start after radiation”, “discussed addition of…”), continue or maintained(“continue”), stop or change.

Include an alterative, second-line or clinical trials options if dicussed. 
Include an 'other treatment' section for any other medications that only briefly mentioned in the plan section.
Respond *only* with a JSON object using this exact schema:
{
    "The medication/treatment, one of chemotherapy, hormonal therapy, bone therapy, radiotherapy": 
{"summary":"the summary of this type of medication, including start/stop/cotinue if applicable",
"Short term side_effects_discussed": "short term Side effects of this particular medications.",
"Long term side_effects_discussed": "long term Side effects of this particular medications.."}
}
""",


"Medication_Plan": """
TASK: Extract the 'Medication Plan' from the given 'Assessment/Plan' section.
Include future medication plans, changes to current meds (start, stop, continue), supportive meds, bowel regimen, and any blood transfusion plan. Do not include past medications.
Respond *only* with a JSON object using this exact schema:
{
    "medication_plan": "A summary of the complete medication plan."
}
""",


"Therapy plan": """
TASK: Extract the 'therapy Plan' from the given 'Assessment/Plan' section.
Include chemotherapy, radiotherapy, hormonal therapy, and bone-therapy plans. Include future plans, changes to current therapy (start, stop, continue). Do not include past therapies.
Respond *only* with a JSON object using this exact schema:
{
    "therapy_plan": "A summary of the complete therapy plan."
}
""",

"radiotherapy plan": """
TASK: Extract the 'radiotherapy Plan' from the given 'Assessment/Plan' section.
For radiotherapy, include ANY statement that indicates 
it is being considered, recommended, or may be used, even if no explicit
start/continue/plan keywords are present.
Respond *only* with a JSON object using this exact schema:
{
    "radiotherapy_plan": "A summary of any radiotherapy that is planned, recommended, or being considered."
}
""",


"Procedure_Plan": """
TASK: Extract the 'Procedure Plan' from the given 'Assessment/Plan' section.
Include future procedures including surgery, radiation therapy, or interventional procedures such as biopsy, lumbar puncture, or Chemo Port Insertion. Do not include past procedures.
Do not include imaging plan, lab plan, or medication plan.
Respond *only* with a JSON object using this exact schema:
{
    "procedure_plan": "A summary of any planned procedures."
}
""",


"Imaging Plan": """
TASK: Extract the 'Imaging Plan' from the given 'Assessment/Plan' section.
Include all future imaging like CT, MRI, PET/CT, ultrasound,  DEXA scans, including timing and rationale if mentioned.
Do not include any procedure plan, lab plan, or medication plan. Do not include past Imaging.
Respond *only* with a JSON object using this exact schema:
{
    "imaging_plan": "A summary of any planned imaging."
}
""",


"Lab Plan": """
TASK: Extract the 'lab Plan' from the given 'Assessment/Plan' section.
Include future labs like CBC, CMP, tumor markers, coagulation profile. Specify frequency and rationale if mentioned.
Do not include any procedure plan, medication plan, or imaging plan. Do not include past labs.
Respond *only* with a JSON object using this exact schema:
{
    "lab_plan": "A summary of the future lab plan."
}
""",


"Genetic_Testing_Plan": """
TASK: Extract the 'Genetic Testing Plan' from the given 'Assessment/Plan' section.

1. Scope of section:
- First, find the 'Assessment/Plan' section of the note (usually the final section).
- Only look for plans that are clearly about FUTURE genetic or molecular TESTING.

2. What COUNTS as 'genetic or molecular testing':
- Diagnostic, prognostic, or predictive laboratory assays, such as:
  - Tumor genomic sequencing / NGS / panel testing
  - Germline genetic panels (e.g., hereditary cancer panel, BRCA testing)
  - Liquid biopsy / circulating tumor DNA (ctDNA) tests
  - Specific biomarker tests (e.g., PD-L1 testing, MSI testing, EGFR mutation testing)
  - Any plan to "send", "order", "check", or "obtain" a genetic, genomic, or molecular TEST

3. What MUST be EXCLUDED:
- DO NOT include any medications, systemic therapies, or treatment plans:
  - Chemotherapy, immunotherapy, targeted therapies (e.g., FGFR inhibitor, PARP inhibitor, CDK4/6 inhibitor, TKIs)
  - Hormonal therapy, radiation therapy, surgery
  - Clinical trial options, even if the trial involves targeted drugs or inhibitors
- DO NOT include tests that are already completed, historic, or only mentioned in past oncology history.
- DO NOT include imaging (CT, PET, MRI, X-ray, ultrasound) or routine labs (CBC, CMP).

4. If there is NO new genetic or molecular test planned in the Assessment/Plan:
- Set the value to a clear negative statement, for example:
  "No new genetic or molecular tests were planned during this visit."

5. Output format:
Respond *only* with a JSON object using this exact schema:
{
    "genetic_testing_plan": "A summary of any future planned genetic or molecular tests, or a clear statement that none are planned."
}
""",


"Referral": """
TASK: Extract 'Referral' from the given 'Assessment/Plan' section.

Respond *only* with a JSON object using this exact schema:
{
"Nutrition": "any nutrition referrals (such as diet optimization, appetite, weight maintenance)",
"Genetics": "eg, germline testing, family counseling",
"Specialty": "eg, Palliative care (symptom or pain management, goals of care), Radiation oncology, surgical oncology, Psychology, psychiatry for coping and mood support",
"Others": "Physical or occupational therapy, Social work, financial counseling"
}
""",


"follow up/next visit": """
TASK: Extract the 'follow up/next visit' from the given 'Assessment/Plan' section.

Respond *only* with a JSON object using this exact schema:
{
    "Next clinic visit": "(in-person or telehealth): timing and purpose"
}
""",


"Advance care planning": """
TASK: Extract the 'Advance care planning' from the given 'Assessment/Plan' section.
Include Advance directives, health-care proxy, code status (if appropriate).
Respond *only* with a JSON object using this exact schema:
{
    "Advance care": "A summary of any planned Advance care."
}
""",

}

In [24]:
# Prompts for the patient-facing explanation stage, keyed by section title.
# In this notebook they are written to results.txt for provenance; the loop
# that would send them to the model is currently commented out in the main
# processing cell.
explain_prompts = {

"Why you came today": """
TASK: Briefly state the purpose of today's visit using only information from the key points. 
""",


"What you told us": """
TASK: List the patient concerns, symptoms, or questions mentioned in the NOTE section.
""",


"What did we find": """
TASK: Explain the result of today's visit (exam, blood work, scans) in plain language. You need to include all 'key points' from KEYPOINTS section.

""",


"What is the plan": """
TASK: List the next steps (treatment, monitoring, follow-up) in plain language. Include all related information from KEYPOINTS section if the following are mentioned, including:
    1, Medication plan
    2, Procedure plan
    3, Imaging plan
    4, Lab plan
    5, Genetic Testing Plan
    6, Referral
""",


"Closing with Support": """
TASK: write an ending sentence to show your support. Limit to one sentence maximum.
""",

}

In [25]:
def _format_prompts_for_log(prompts):
    """Return ``str(prompts)`` with escaped newlines and quotes expanded.

    ``str()`` of a dict shows each prompt as a one-line repr; turning the
    ``\\n``, ``\\'`` and ``\\"`` escape sequences back into real characters
    makes the dump readable in a text file.
    """
    return (str(prompts)
            .replace('\\n', '\n')
            .replace("\\'", "'")
            .replace('\\"', '"'))


# (Re)create the results file and record the prompts used for this run.
# The encoding is explicit because the prompts contain non-ASCII characters
# (curly quotes), which would otherwise depend on the platform default.
with open('results.txt', 'w', encoding='utf-8') as f:
    for section_title, prompts in (
        ("extraction_prompts", extraction_prompts),
        ("plan_extraction_prompts", plan_extraction_prompts),
        ("explain_prompts", explain_prompts),
    ):
        f.write(section_title + "\n")
        f.write(_format_prompts_for_log(prompts) + "\n")
        f.write("=" * 60 + "\n")

    # Blank separator before the per-row results that later cells append.
    f.write("\n" * 5)

In [26]:
# Token ids treated as end-of-generation by the sampling configs below:
# the tokenizer's own EOS id plus the id of the "<|eot_id|>" token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Sampling settings for the (currently disabled) explanation stage.
# NOTE(review): early_stopping is documented as a beam-search option;
# confirm it has any effect with these sampling settings.
EXPLANATION_GENERATION_CONFIG = dict(
    max_new_tokens=666,
    temperature=0.3,
    top_p=0.85,
    repetition_penalty=1.2,
    no_repeat_ngram_size=3,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=terminators,
    early_stopping=True,
)

# Lower-temperature sampling settings for the (currently disabled) cleaning stage.
CLEANING_CONFIG = dict(
    max_new_tokens=666,
    temperature=0.1,
    top_p=0.9,
    repetition_penalty=1.2,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=terminators,
    early_stopping=True,
)


def _greedy_config(max_new_tokens):
    """Build a non-sampling config that stops on the tokenizer's EOS id only."""
    return dict(
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=False,
    )


# Per-task keypoint extraction.
KEYPOINT_CONFIG = _greedy_config(512)

# Extraction of the Assessment/Plan text; given a larger token budget.
assesment_and_plan_CONFIG = _greedy_config(2048)

In [None]:
import json


def _strip_code_fence(text):
    """Remove a surrounding Markdown code fence and optional ``json`` tag.

    ``str.strip("```json")`` treats its argument as a *set of characters*, not
    a prefix/suffix, so the fence is removed explicitly here instead.
    """
    text = text.strip()
    if text.startswith("```"):
        text = text[3:].lstrip()
    if text[:4].lower() == "json":
        text = text[4:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def _parse_json_answer(answer):
    """Parse a model reply as JSON; return the raw reply if it is not valid JSON."""
    try:
        return json.loads(_strip_code_fence(answer))
    except json.JSONDecodeError:
        return answer


def _build_extraction_base_prompt(document):
    """Build the shared system + document prefix for one extraction pass.

    The prefix ends with a primed assistant acknowledgement; task prompts are
    appended after it, one per generation call.

    NOTE(review): the primed assistant turn is not closed with <|eot_id|>
    before the next user header is appended, and the assistant header is
    followed by a single newline. The text also starts with a literal
    <|begin_of_text|>, which may be duplicated if the tokenizer adds its own
    BOS token. Confirm against the Llama 3.1 prompt format before changing.
    """
    return (
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"You are a medical data extraction expert. You will be given a long medical note. "
        f"Your task is to answer a series of questions about it, one by one. "
        f"Respond *only* with the valid JSON object requested. Do not add markdown backticks or any other text."
        f"<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        f"Here is the medical note:\n\n"
        f"--- BEGIN NOTE ---\n{document}\n--- END NOTE ---"
        f"\n\nI will now ask you to extract specific sections. "
        f"Please wait for my first extraction task."
        f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        f"{{\"status\": \"Understood. I have read the note and am ready.\"}}"
    )


def _prefill_kv_cache(prompt):
    """Run one forward pass over ``prompt`` and return its ``past_key_values``.

    Uses the notebook-level ``model`` and ``tokenizer``.
    """
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            use_cache=True
        )
        cache = outputs.past_key_values
        del inputs, outputs
    torch.cuda.empty_cache()
    gc.collect()
    return cache


def _run_extraction_tasks(prompts, base_cache, results):
    """Ask every task in ``prompts`` on top of ``base_cache``; store answers in ``results``.

    Each answer is stored under the prompt's key, parsed as JSON when possible
    and as the raw reply otherwise.

    NOTE(review): the same ``base_cache`` object is handed to every call;
    this assumes run_model_with_cache_manual does not extend it in place --
    verify against the helper.
    """
    for key, task in prompts.items():
        task_prompt = (
            f"<|start_header_id|>user<|end_header_id|>\n\n"
            f"{task}"
            f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        )
        answer, returned_cache = run_model_function(
            task_prompt,
            model,
            tokenizer,
            KEYPOINT_CONFIG,
            kv_cache=base_cache
        )
        del returned_cache
        torch.cuda.empty_cache()
        gc.collect()
        results[key] = _parse_json_answer(answer)


def _to_log_text(value):
    """Render ``value`` as indented JSON, or ``str(value)`` if it is not JSON-serialisable."""
    try:
        return json.dumps(value, indent=2)
    except (TypeError, ValueError):
        return str(value)


run_model_function = run_model_with_cache_manual

all_results = {}

# `index` is the DataFrame label of the row, `position` its 1-based place in
# this run, so the progress message reads e.g. "row 44 (1/5)".
for position, (index, row) in enumerate(df.iterrows(), start=1):
    print(f"\nProcessing row {index} ({position}/{len(df)})...")

    note_text = row['note_text']

    # Step 1: have the model return the Assessment/Plan portion of the note.
    assesment_and_plan,_=run_model(
    'here is a medical note\n\n'+note_text+'\n\n now, return me all the orignal text after the words like "Assessment and Plan" or "Assessment/Plan". ignore anything before that. ingore the line breaking characters.',
    model,
    tokenizer,
    assesment_and_plan_CONFIG)
    print('got assesment_and_plan')

    # Step 2: keypoints from extraction_prompts, asked against the full note.
    keypoints = {}
    base_cache = _prefill_kv_cache(_build_extraction_base_prompt(note_text))
    _run_extraction_tasks(extraction_prompts, base_cache, keypoints)
    print('keypoints from extraction_prompts')

    # Step 3: keypoints from plan_extraction_prompts, asked against the
    # extracted Assessment/Plan text only.
    base_cache = _prefill_kv_cache(_build_extraction_base_prompt(assesment_and_plan))
    _run_extraction_tasks(plan_extraction_prompts, base_cache, keypoints)
    print('keypoints from plan_extraction_prompts')

    # --- Explanation / cleaning stages (currently disabled) ---------------
    # explain_base_prompt = (
    #     f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    #     f"You are an experienced and compassionate Oncologist and medical educator. Your primary role is to translate complex medical information into clear, 8th-grade level English. "
    #     f"Your task is to answer a series of questions about it, one by one. "
    #     f'''### STRICT NEGATIVE CONSTRAINTS:
    #     * Do not say anything not present in the medical note.
    #     * If cancer has spread (metastasis), DO NOT list the specific organs affected. Say "the cancer has spread to other parts of the body."
    #     * Do not use fatalistic language. The focus MUST be on quality of life.
    #     *NO ADDING SIDE EFFECTS, Unless they appear in the keypoints exactly.
    #     '''
    #     f"<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    #     f"Here is the medical note:\n\n"
    #     f"--- BEGIN NOTE ---\n{note_text}\n--- END NOTE ---"
    #     f"Here is the keypoints extracted from the medical note:\n\n"
    #     f"--- BEGIN KEYPOINTS ---\n{keypoints}\n--- END KEYPOINTS ---"
    #     f"\n\nI will now ask you to extract specific sections. "
    #     f"Please wait for my first extraction task."
    #     f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    #     f"{{\"status\": \"Understood. I have read the note and am ready.\"}}"
    # )
    # base_cache = _prefill_kv_cache(explain_base_prompt)
    # explain = {}
    # _run_extraction_tasks(explain_prompts, base_cache, explain)
    # print('explain')
    #
    # explanation_prompt= create_explanation_prompt(note_text, keypoints)
    # explanation,_  = run_model_function(explanation_prompt, model, tokenizer, EXPLANATION_GENERATION_CONFIG)
    #
    # # clean up
    # print('clean')
    # cleaning_prompt = create_cleaning_prompt(explanation)
    # final_result,_  = run_model_function(cleaning_prompt, model, tokenizer, CLEANING_CONFIG)

    row_result = {
        'coral_idx': row['coral_idx'],
        'note_text': note_text,
        'assesment_and_plan': assesment_and_plan,
        'keypoints': keypoints,
        # 'explain': explain,
        # 'raw_result': explanation,
        # 'clean_result': final_result
    }

    all_results[index]=(row_result)

    # Append this row's results to the file created by the prompt-dump cell.
    with open('results.txt', 'a', encoding='utf-8') as f:
        f.write('\n' * 5)
        f.write(f"\n{'='*60}\n")
        f.write(f"RESULTS FOR ROW {index + 1}\n")
        f.write(f"{'='*60}\n")

        for col, value in row_result.items():
            f.write(f"\n--- Column: {col} ---\n")
            f.write(_to_log_text(value) + '\n')


Processing row 44/5...
got assesment_and_plan
keypoints from extraction_prompts
keypoints from plan_extraction_prompts

Processing row 45/5...
got assesment_and_plan
keypoints from extraction_prompts
keypoints from plan_extraction_prompts

Processing row 53/5...
got assesment_and_plan
keypoints from extraction_prompts
keypoints from plan_extraction_prompts

Processing row 70/5...
got assesment_and_plan
keypoints from extraction_prompts
keypoints from plan_extraction_prompts

Processing row 83/5...
got assesment_and_plan
keypoints from extraction_prompts
keypoints from plan_extraction_prompts


In [30]:
# Inspect the record left in `row_result` by the processing loop, i.e. the one
# built for the last note it handled.
row_result

{'coral_idx': 223,
 'note_text': 'We performed this consultation using real-time Telehealth tools, including a live video connection between my location and the patient\'s location. Prior to initiating the consultation, we obtained informed verbal consent to perform this consultation using Telehealth tools and answered all the questions about the Telehealth interaction.    Patient Name: ***** *****  ***** *****: 11/10/20    HPI: ***** ***** is a 60 y.o. female with a CHEK2 mutation, multiple sclerosis, and metastatic breast cancer with disease in bone, soft tissue, liver, and possibly meninges who is seeking recommendations for ongoing management and treatment. She is a patient of Dr. ***** at ***** in ***** *****.    Please see detailed oncologic history below in the problem list, summarized from ***** records and patient report.     ONCOLOGIC HISTORY:  Patient Active Problem List    Diagnosis Date Noted   \x07 Metastatic breast cancer (CMS code) 11/09/2020     Right breast cancer:  1

In [29]:
    # Scratch export: dump the most recent `row_result` (left over from the
    # processing loop, together with `index`) to resultstest.txt in the same
    # layout as results.txt, with escape sequences expanded for readability.
    with open('resultstest.txt', 'w', encoding='utf-8') as f:  # 'a' to append, 'w' to overwrite
        f.write('\n' * 5)
        f.write(f"\n{'='*60}\n")
        f.write(f"RESULTS FOR ROW {index + 1}\n")
        f.write(f"{'='*60}\n")

        for col in row_result.keys():
            f.write(f"\n--- Column: {col} ---\n")
            # Format this column's own value. Previously every column wrote
            # str(plan_extraction_prompts), so the file repeated the prompts
            # instead of showing the row's results.
            original_text = (str(row_result[col])
                     .replace('\\n', '\n')
                     .replace("\\'", "'")
                     .replace('\\"', '"'))
            f.write(original_text + "\n")