## load model

In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# are re-imported before each cell execution; edits to local helper modules
# are then picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

# Point the Hugging Face cache at an existing download directory. This is set
# before `huggingface_hub` is imported so the library sees it when it resolves
# its cache location.
# NOTE(review): machine-specific absolute path -- adjust when running elsewhere.
cache_dir = "/mnt/c/Users/yc/.cache/huggingface"
os.environ['HF_HOME'] = cache_dir

from huggingface_hub import login

# Prefer a token supplied through the HF_TOKEN environment variable so the
# secret does not have to live next to the notebook; fall back to the local
# `hf.token` file (first line, whitespace-trimmed) when the variable is unset
# or empty.
hftoken = os.environ.get("HF_TOKEN", "").strip() or Path("hf.token").read_text().strip()
login(token=hftoken)

# NOTE(review): star import -- later cells rely on names it brings into scope
# (e.g. AutoTokenizer, torch, pd, csv, json, gc, txt_to_dict,
# run_model_with_cache_manual). Replace with explicit imports once the set of
# required names is confirmed.
from ult import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Checkpoint to load and where to place it.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
device_map = {"": 0}  # the empty-string key maps the entire model to device 0

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Weights are loaded in bfloat16 with the placement chosen above.
load_kwargs = dict(device_map=device_map, torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

## data

In [None]:
# --- Medical dictionary --------------------------------------------------
txt_file_path = 'data/formaldef.txt'
formaldic = txt_to_dict(txt_file_path)

# Normalise dictionary keys: keep only the text that precedes
# 'Listen to pronunciation' and the first '('.
# NOTE(review): the resulting key is not whitespace-stripped; if raw terms have
# a space before the marker they will only match word-list entries that carry
# the same trailing space -- confirm against the CSV before changing this.
meddict = {
    term.split('Listen to pronunciation')[0].split('(')[0]: definition
    for term, definition in formaldic.items()
}

# Word list used to filter the dictionary (first column of the CSV).
filename = 'data/filtered_medical_dictionary.csv'
# newline='' is what the csv module expects for files passed to csv.reader,
# so that newlines embedded in quoted fields are handled correctly.
with open(filename, 'r', encoding='utf-8', newline='') as csvfile:
    reader = csv.reader(csvfile)
    next(reader, None)  # skip the header row; tolerate an empty file
    eighth_grade_words = {row[0] for row in reader if row}  # ignore blank rows

# Keep only the dictionary entries whose term appears in the word list.
filtered_meddict = {
    word: explanation
    for word, explanation in meddict.items()
    if word in eighth_grade_words
}
meddict = filtered_meddict

# --- Clinical notes ------------------------------------------------------
df = pd.read_csv('data/CORAL/coral-expert-curated-medical-oncology-reports-to-advance-language-model-inference-1.0/coral/unannotated/data/breastca_unannotated.csv')
# Optional sub-sampling for quick experiments, e.g.:
#   df = df.sample(1, random_state=42)
#   df = df.iloc[[44, 45, 53, 70, 83]]
test_note = df.iloc[0]['note_text']

# Preview only the first rows instead of rendering the full frame of notes.
df.head()

Unnamed: 0,coral_idx,Sex,UCSFDerivedRaceEthnicity_X,BirthDate,note_text
44,184,Female,Latinx,1983-10-04,Medical Oncology Consult Note Video Consult ...
45,185,Female,Multi-Race/Ethnicity,1973-03-08,This is a shared service. Physician Statement...
53,193,Female,Multi-Race/Ethnicity,1979-06-20,ID: ***** ***** is a 39 y.o. premenopausal pat...
70,210,Female,Southwest Asian and North African,1974-04-05,Patient Name: ***** ***** ***** *****: 08/22/...
83,223,Female,Unknown/Declined,1960-10-12,We performed this consultation using real-time...


## Key-point extraction

In [None]:
# Task prompts, keyed by the name under which each extraction result is stored.
# Every prompt ends with an example of the JSON object the model must return;
# the example itself must be valid JSON (no trailing commas), because the model
# is told to follow it exactly and its answer is parsed with json.loads.
extraction_prompts = {

"Medication_Plan": """
TASK: Extract the 'Medication Plan'. Find the 'Assessment/Plan' section of the note, usually the final section.
Include all current and future medication plans for both cancer therapy and supportive treatment. Cancer treatment could be one or many in chemotherapy, hormone therapy, bone therapy, radiotherapy (eg. rad onc, xrt). Supportive treatment could be one or many in bowel regimen, pain medication, psychiatry medication, neuropathy or any blood transfusion plan. 

Include whether a medication is being started now (e.g,“will start”, “Rx sent”, “starting today” ), plan or discuss in the future after certain condition (e.g., “plan to start after radiation”, “discussed addition of…”), continue or maintained(“continue”), stop or change.

Do NOT include procedures, labs, imaging, or genetic testing.
For your information, a completed, finished, status post medication, s/p, means past treatments.
Include an alternative, second-line or clinical trials options if discussed.

Respond *only* with a JSON object using this exact schema:
{
    "The medication/treatment, one of chemotherapy, hormone therapy, bone therapy, radiotherapy": 
        {
        "summary":"the summary of this type of medication, including start/stop/continue if applicable",
        "Short term side_effects_discussed": "short term Side effects of this particular medications.",
        "Long term side_effects_discussed": "long term Side effects of this particular medications."
        }
}
""",

# NOTE(review): legacy prompt kept for comparison. The extraction loop iterates
# over every entry of this dict, so this prompt is also run for each note;
# remove the entry once the comparison is no longer needed.
"Medication_Plan old": """
TASK: Extract the 'Medication Plan'. Find the 'Assessment/Plan' section of the note, usually the final section.
Include future medication plans, changes to current meds (start, stop, continue), supportive meds, bowel regimen, and any blood transfusion plan.
Respond *only* with a JSON object using this exact schema:
{
    "medication_plan": "A summary of the complete medication plan."
}
""",

}


In [None]:
import copy
import re

# Markdown code-fence markers that the model sometimes wraps around its JSON.
_FENCE_OPEN = re.compile(r"^```(?:json)?\s*", re.IGNORECASE)
_FENCE_CLOSE = re.compile(r"\s*```$")


def _parse_json_answer(answer):
    """Parse a model answer that is expected to contain a single JSON value.

    A leading ```` ``` ````/```` ```json ```` fence and a trailing ```` ``` ````
    fence are removed as whole tokens (``str.strip("```json")`` would instead
    strip any of the characters `` ` j s o n`` from both ends). If the cleaned
    text still does not parse, the span from the first ``{`` to the last ``}``
    is tried, which tolerates stray text around the object.

    Returns the parsed value, or ``{"error": "Failed to parse JSON",
    "raw": answer}`` when no valid JSON could be recovered.
    """
    text = _FENCE_CLOSE.sub("", _FENCE_OPEN.sub("", answer.strip())).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            try:
                return json.loads(text[start:end + 1])
            except json.JSONDecodeError:
                pass
    return {"error": "Failed to parse JSON", "raw": answer}


# Generation settings and the generation helper do not depend on the note,
# so they are defined once instead of on every loop iteration.
gen_config = {
    "max_new_tokens": 512,
    "eos_token_id": tokenizer.eos_token_id,
    "do_sample": False,  # greedy decoding
}
run_model_function = run_model_with_cache_manual

# One JSON string per note, in the same order as the rows of `df`.
res = []

for row in df.itertuples():
    test_note = row.note_text

    # --- 1. Shared prefix: system prompt + note + a primed assistant reply ---
    # Each turn is closed with <|eot_id|> and each header is followed by a
    # blank line, as in the Llama 3 chat format, so the per-task user turn
    # appended later starts on a clean turn boundary.
    base_prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a medical data extraction expert. You will be given a long medical note. "
        "Your task is to answer a series of questions about it, one by one. "
        "Respond *only* with the valid JSON object requested. Do not add markdown backticks or any other text."
        "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "Here is the medical note:\n\n"
        f"--- BEGIN NOTE ---\n{test_note}\n--- END NOTE ---"
        "\n\nI will now ask you to extract specific sections. "
        "Please wait for my first extraction task."
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        '{"status": "Understood. I have read the note and am ready."}'
        "<|eot_id|>"
    )

    with torch.no_grad():
        # The prompt already starts with <|begin_of_text|>; disable the
        # tokenizer's own special tokens so BOS is not inserted twice.
        inputs = tokenizer(
            base_prompt, return_tensors="pt", add_special_tokens=False
        ).to(model.device)

        # One forward pass over the prefix to obtain its key/value cache.
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            use_cache=True,
        )
        base_cache = outputs.past_key_values

        # Release the prefill tensors before generation starts.
        del inputs, outputs
        torch.cuda.empty_cache()
        gc.collect()

    # --- 2. One "branch" per extraction task, each starting from the prefix ---
    extracted_data = {}

    for key, task in extraction_prompts.items():
        # No <|begin_of_text|> here: the cached prefix already contains it.
        # NOTE(review): confirm that run_model_with_cache_manual tokenizes this
        # prompt without adding special tokens, otherwise a BOS is inserted
        # mid-sequence.
        task_prompt = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            f"{task}"
            "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )

        # Give every branch its own copy of the prefix cache. Cache objects
        # returned by recent transformers versions are extended in place
        # during decoding; sharing one object would leak the previous task's
        # tokens into the next branch.
        # NOTE(review): redundant if run_model_with_cache_manual already copies
        # the cache internally -- verify and drop if so.
        branch_cache = copy.deepcopy(base_cache)

        answer, returned_cache = run_model_function(
            task_prompt,
            model,
            tokenizer,
            gen_config,
            kv_cache=branch_cache,
        )

        # Free the per-branch cache before the next task.
        del branch_cache, returned_cache
        torch.cuda.empty_cache()
        gc.collect()

        extracted_data[key] = _parse_json_answer(answer)

    # --- 3. Store this note's extractions as a JSON string ---
    res.append(json.dumps(extracted_data, indent=2))

In [None]:
# Combine every extraction result with the note it was produced from.
# `res` holds one JSON string per row of `df`, appended in row order by the
# extraction loop, so the two are paired positionally. zip() stops at the
# shorter of the two, which also copes with an interrupted extraction run.
records = []
for raw_json, row in zip(res, df.itertuples(index=False)):
    record = json.loads(raw_json)
    record['note'] = row.note_text
    record['coral_idx'] = row.coral_idx
    records.append(record)

resdf = pd.DataFrame(records)
resdf