In [1]:
with open("hf.token", "r") as f:
    hftoken = f.read().strip()  

import os
cache_dir = "/mnt/c/Users/yc/.cache/huggingface"
os.environ['HF_HOME'] = cache_dir

from huggingface_hub import login
login(token=hftoken)  # Move token to environment variable

from ult import *

  from .autonotebook import tqdm as notebook_tqdm


## load model

In [2]:

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",          # Use NF4 (a highly-performant 4-bit format)
    bnb_4bit_compute_dtype=torch.bfloat16, # Use bfloat16 for faster computation
    bnb_4bit_use_double_quant=False,
)
device_map = {"": 0}
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    quantization_config=quantization_config,  # Pass the config here
    # We are NOT using load_in_8bit=True
)

Loading checkpoint shards: 100%|██████████| 19/19 [03:28<00:00, 10.96s/it]


In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Optional # Added List, Optional

def run_model(
    prompt: str,
    model,
    tokenizer,
    generation_config: Optional[Dict] = None, # Make config optional
    chat_history: Optional[List[Dict]] = None # Optional chat history
) -> str:
    """
    Runs the loaded Hugging Face model to generate text based on a prompt,
    handling attention masks and pad tokens correctly.

    Args:
        prompt: The latest user input text prompt for the model.
        model: The loaded Hugging Face CausalLM model.
        tokenizer: The loaded Hugging Face tokenizer.
        generation_config: A dictionary containing parameters for model.generate()
                           (e.g., max_new_tokens, temperature, do_sample).
        chat_history: An optional list of previous chat messages in the format
                      [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]

    Returns:
        The generated text string from the assistant.
    """
    # --- Configuration Handling ---
    if generation_config is None:
        generation_config = {}

    # Set default max_new_tokens if not provided, sensible default for testing
    if "max_new_tokens" not in generation_config:
        generation_config["max_new_tokens"] = 250 # Increased default slightly

    # --- Pad Token Handling ---
    # Ensure pad_token_id is set. Using eos_token_id is common for autoregressive models.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        # Also update the model config if possible (might not be needed for inference only)
        # model.config.pad_token_id = tokenizer.pad_token_id

    # Add pad_token_id to generation_config if not already present
    if "pad_token_id" not in generation_config:
         generation_config["pad_token_id"] = tokenizer.pad_token_id

    # --- Chat History and Input Formatting ---
    if chat_history is None:
        chat_history = []

    # Add the current user prompt to the history
    messages = chat_history + [{"role": "user", "content": prompt}]

    # Apply the chat template to format the input correctly for Mixtral Instruct
    # return_tensors="pt" ensures PyTorch tensors are returned
    # add_generation_prompt=True adds the necessary tokens to signal assistant response start
    # Explicitly handle padding and truncation if needed, though usually handled by generate
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
        # padding=True, # Consider adding if batching > 1, may not be needed here
        # truncation=True, # Consider if prompts can exceed model max length
        # max_length=model.config.max_position_embeddings # Or a specific length
    ).to(model.device)

    # Get the attention mask from the tokenizer output
    # `apply_chat_template` usually creates this when return_tensors='pt'
    attention_mask = inputs.attention_mask if hasattr(inputs, 'attention_mask') else torch.ones_like(inputs)


    # Get the length of the input tokens *before* generation
    input_length = inputs.shape[1]

    # --- Generation ---
    with torch.no_grad(): # Disable gradient calculations for inference
        outputs = model.generate(
            input_ids=inputs,
            attention_mask=attention_mask, # Explicitly pass the attention mask
            # pad_token_id=tokenizer.pad_token_id, # REMOVED: Pass via **generation_config
            **generation_config # Pass generation parameters including pad_token_id
        )

    # --- Decoding ---
    # Extract only the newly generated tokens (after the input)
    response_tokens = outputs[0][input_length:]

    # Decode the response tokens into a string, skipping special tokens like <eos>
    raw_output = tokenizer.decode(response_tokens, skip_special_tokens=True)

    return raw_output.strip() # Added strip() to remove leading/trailing whitespace


# --- Example Usage (requires model and tokenizer to be loaded) ---
# Assuming 'model' and 'tokenizer' are loaded from your 4-bit Mixtral script:

if __name__ == "__main__":

    # ---------------------------------------------------------------------

    test_prompt = "What is the difference between chemotherapy and immunotherapy for cancer treatment?"

    # Define generation parameters
    gen_config = {
        "max_new_tokens": 250,
        "temperature": 0.6, # Slightly lower temp for more focused answer
        "do_sample": True,
        "top_p": 0.9,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id # Ensure consistency
    }

    # --- Optional Chat History Example ---
    history = [
        # {"role": "user", "content": "What is cancer?"},
        # {"role": "assistant", "content": "Cancer is a disease where cells grow uncontrollably..."}
    ]
    # ------------------------------------

    print(f"\nRunning model with prompt: '{test_prompt}'")
    import time
    start_time = time.time()

    try:
        response = run_model(test_prompt, model, tokenizer, gen_config, chat_history=history)
        end_time = time.time()
        duration = end_time - start_time
        num_tokens = len(tokenizer.encode(response)) # Rough token count

        print("\n--- Model Response ---")
        print(response)
        print("--- End Response ---")
        print(f"\nGeneration took {duration:.2f} seconds.")
        if duration > 0:
            print(f"Approximate speed: {num_tokens / duration:.2f} tokens/second.")

    except Exception as e:
        print(f"An error occurred during generation: {e}")

        import traceback
        traceback.print_exc()



Running model with prompt: 'What is the difference between chemotherapy and immunotherapy for cancer treatment?'

--- Model Response ---
Chemotherapy and immunotherapy are both common types of cancer treatments, but they work in fundamentally different ways:

1. Chemotherapy: This is a type of cancer treatment that uses drugs to kill cancer cells or stop them from dividing. Chemotherapy can be administered intravenously, orally, or topically. It is often used to treat advanced-stage cancers, reduce the size of tumors before surgery, or kill any remaining cancer cells after other treatments. However, chemotherapy can also damage healthy cells, leading to side effects such as hair loss, nausea, and fatigue.

2. Immunotherapy: This is a relatively newer form of cancer treatment that uses the body's own immune system to fight cancer cells. Immunotherapy works by boosting, directing, or restoring the natural defenses of the immune system. There are several types of immunotherapy, including

## data

In [4]:

def txt_to_dict(file_path):
    data_dict = {}
    with open(file_path, 'r') as file:
        lines = file.readlines()
        for i in range(0, len(lines) - 1, 2):
            key = lines[i].strip()    # Odd line are key
            value = lines[i + 1].strip()  # Even line are value
            data_dict[key] = value

    return data_dict

txt_file_path = 'data/formaldef.txt'
formaldic = txt_to_dict(txt_file_path)
len(formaldic)

meddict={}
for k,v in formaldic.items():
    meddict[k.split('Listen to pronunciation')[0].split('(')[0]]=v
filename= 'data/filtered_medical_dictionary.csv'
eighth_grade_words=set()
with open(filename, 'r', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    next(reader)  # Skip the header row
    for row in reader:
        if row:  # Make sure row is not empty
            eighth_grade_words.add(row[0])  # Add the word (first column)
filtered_meddict = {word: explanation for word, explanation in meddict.items() 
                   if word in eighth_grade_words}
meddict=filtered_meddict
# load data
df = pd.read_csv('data/CORAL/coral-expert-curated-medical-oncology-reports-to-advance-language-model-inference-1.0/coral/unannotated/data/breastca_unannotated.csv')
df = df.sample(1, random_state=42) 
# text=df.iloc[0]['note_text']
len(df)

1

# prompt

In [5]:

        
def clean_model_output(text: str, fix_incomplete=True) -> str:
    """
    Clean up model-generated text with common fixes, not currently using
    """
    if not text:
        return ""
    
    # Basic cleanup
    text = text.strip()
    text = re.sub(r'\s+', ' ', text)  # Normalize all whitespace
    
    # Fix paragraph spacing
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    
    # Fix punctuation spacing
    text = re.sub(r'\s+([,.!?;:])', r'\1', text)
    text = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text)
    
    # Remove repetitive patterns (simple version)
    lines = text.split('\n')
    cleaned_lines = []
    for line in lines:
        if not cleaned_lines or line.strip() != cleaned_lines[-1].strip():
            cleaned_lines.append(line)
    text = '\n'.join(cleaned_lines)
    
    # Handle incomplete sentences
    if fix_incomplete and text and not text.endswith(('.', '!', '?')):
        sentences = re.split(r'[.!?]+', text)
        if len(sentences) > 1 and len(sentences[-1].strip()) < 10:
            # Remove likely incomplete last sentence
            text = '.'.join(sentences[:-1]) + '.'
    
    return text.strip()

def extract_medical_terms(text: str, meddict: Dict[str, str]) -> Dict[str, str]:

    found_terms = {}
    
    # Strategy 1: Single words
    words = re.findall(r'\b[A-Za-z]+(?:[-\'][A-Za-z]+)*\b', text)
    for word in words:
        definition = find_term_in_dict(word, meddict)
        if definition:
            found_terms[word] = definition
    
    # Strategy 2: Multi-word terms
    for n in range(2, 6):
        n_grams = get_n_grams(text, n)
        for phrase in n_grams:
            definition = find_term_in_dict(phrase, meddict)
            if definition:
                found_terms[phrase] = definition
    
    # Strategy 3: Medical abbreviations
    abbreviations = re.findall(r'\b[A-Z]{2,8}\b', text)
    for abbrev in abbreviations:
        definition = find_term_in_dict(abbrev, meddict)
        if definition:
            found_terms[abbrev] = definition
    
    # Strategy 4: Medical procedures and conditions with specific patterns
    medical_patterns = [
        r'\b\w+oscopy\b',          # bronchoscopy, endoscopy, etc.
        r'\b\w+ectomy\b',          # appendectomy, etc.
        r'\b\w+itis\b',            # bronchitis, arthritis, etc.
        r'\b\w+osis\b',            # fibrosis, stenosis, etc.
        r'\b\w+emia\b',            # anemia, septicemia, etc.
        r'\b\w+pathy\b',           # myopathy, neuropathy, etc.
        r'\b\w+malacia\b',         # tracheomalacia, etc.
    ]
    
    for pattern in medical_patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        for match in matches:
            definition = find_term_in_dict(match, meddict)
            if definition:
                found_terms[match] = definition
    
    # Strategy 5: Medication names
    medication_patterns = [
        r'\b\w+cillin\b',          # penicillin, amoxicillin, etc.
        r'\b\w+mycin\b',           # streptomycin, etc.
        r'\b\w+floxacin\b',        # levofloxacin, ciprofloxacin, etc.
        r'\b\w+sone\b',            # prednisone, cortisone, etc.
        r'\b\w+pam\b',             # lorazepam, etc.
    ]
    
    for pattern in medication_patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        for match in matches:
            definition = find_term_in_dict(match, meddict)
            if definition:
                found_terms[match] = definition
    
    return found_terms


def get_n_grams(text: str, n: int) -> List[str]:
    """Generate n-grams from text."""
    words = re.findall(r'\b[A-Za-z]+\b', text.lower())
    n_grams = []
    for i in range(len(words) - n + 1):
        phrase = ' '.join(words[i:i+n])
        n_grams.append(phrase)
    return n_grams


def find_term_in_dict(term: str, meddict: Dict[str, str]) -> Optional[str]:
    """Find term in medical dictionary."""
    search_formats = [
        term, term.lower(), term.upper(), term.title(), term.capitalize()
    ]
    
    for search_term in search_formats:
        if search_term in meddict:
            return meddict[search_term]
    
    # Partial matching
    for key in meddict.keys():
        if key.lower() == term.lower():
            return meddict[key]
    
    return None

def create_annotated_text(text: str, meddict: Dict[str, str]) -> str:

    # 1. Use the existing function to find all unique terms and their definitions.
    found_terms = extract_medical_terms(text, meddict)
    
    # 2. Sort terms by length in descending order to handle overlaps.
    # This is critical for terms like "cerebral palsy" and "palsy".
    sorted_terms = sorted(found_terms.keys(), key=len, reverse=True)
    
    annotated_text = text
    
    # 3. Iterate and replace.
    for term in sorted_terms:
        definition = found_terms[term]
        annotation = f"{term} [DEFINITION: {definition}]"
        pattern = r'\b' + re.escape(term) + r'\b'
        annotated_text = re.sub(pattern, annotation, annotated_text, count=1, flags=re.IGNORECASE)
        
    return annotated_text


def create_key_summary_prompt(original_text: str) -> str:
    messages = [
        {"role": "system", "content": """You are an expert medical information extractor with exceptional attention to detail. Your task is to carefully read the medical text and extract ONLY the key factual details that are explicitly mentioned. You must be extremely precise and never infer, assume, or add any information not directly stated in the text.
CRITICAL RULES:
1. Extract ONLY facts explicitly stated in the text
2. Use the EXACT wording from the original text when possible
3. Do not interpret, infer, or elaborate beyond what is written
4. If multiple items exist in a category, separate them with semicolons
5. Keep each section concise but complete
6. If a category is not mentioned or unclear, write "Not mentioned" for that section."""},
        {"role": "user", "content": f"""**Medical Text:**\n"{original_text}"\n\n**Format your response EXACTLY as follows:**\n\n**TREATMENTS RECEIVED:**\n[List treatments]\n\n**MEDICAL CONDITIONS:**\n[List conditions]\n\n**CANCER STAGE:**\n[List stage]\n\n**REFERRALS:**\n[List referrals]\n\n**CURRENT STATUS:**\n[List status]\n\n**NEXT STEPS/PLAN:**\n[List plan]"""}
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def create_explanation_prompt(annotated_text: str, keysummary: str) -> str:

    audience_instruction = "The determined audience for this letter is the **patient**. You must address them directly as 'you' throughout the entire letter."
    
    system_prompt = f"""### Persona
You are an experienced and compassionate Oncologist and medical educator. Your primary role is to translate complex medical information into clear, 8th-grade level English. Your tone should be professional, empathetic, and honest.

### Golden Rule: Radical Simplicity - Translate, Don't Transfer
Your single most important task is to convert medical terminology into simple explanations. Do not just define a term; replace it with an easy-to-understand explanation.

### Your Task:
Write a complete letter explaining the information from the medical text. Imagine you are sitting with the recipient and explaining this to them in person.

{audience_instruction}

### Letter Structure and Flow:
Organize your letter logically using the following question-based headers:
* **Empathetic Opening:** Start with a warm and supportive salutation.
* **Why you came today:** Briefly explain the purpose of the visit in plain words.
* **What you told us:** Summarize what the patient and family shared (symptoms, worries, questions).
* **What did we find** Explain what the team learned from today's visit (exam, blood work, scans) in simple terms.
* **What is the plan** List the next steps (treatment, monitoring, follow-up) in plain language.
* **Closing with Support:** End by reinforcing your support. Limit to one sentence maximum.

### STRICT NEGATIVE CONSTRAINTS:
* Do not say anything not present in the medical note.
* If cancer has spread (metastasis), DO NOT list the specific organs affected. Say "the cancer has spread to other parts of the body."
* Do not use fatalistic language. The focus MUST be on quality of life.

### Strict Output Formatting:
Provide ONLY the letter. Your output must start directly with "Dear [Patient Name]," and end with "Sincerely,\nYour Care Team at [Institution]". Use the question-based headers exactly as specified."""
    
    user_prompt = f"""### Original Medical Text:
"{annotated_text}"

### Internal Fact-Checking Reference (Translate these facts, do not copy them):
"{keysummary}"

Please provide the patient-focused explanation now. Your output must start directly with the salutation."""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)



def create_cleaning_prompt(raw_response: str) -> str:
    """
    Creates the correctly formatted prompt for Llama 3 to clean/revise a response.
    """
    system_prompt = f"""### Persona
You are an expert medical writer and editor. Your unique skill is communicating complex clinical information with absolute precision and clarity. Your primary directive is to preserve the original meaning without fail.

### Target Audience: patient

### Your Task
Revise the provided medical text to improve its quality in the following areas:
1.  **Clean Up Language:** Improve sentence structure and use professional language.
2.  **Reduce Repetition:** Eliminate redundant words and phrases.
3.  **Improve Flow:** Enhance the logical flow and transitions.
4.  **Define Medical Terms** For any complex medical term, replace it with a simple explanation. Example: "you had tachycardia."--> "you had a fast heart rate".
5.  **Remove formatting subtitles:** remove any section headers or subtitles from the text to make it read as a continuous letter.
### The Golden Rule: Preserve Clinical Meaning at All Costs
The revised text MUST be factually identical to the original.
* **DO NOT** alter any clinical facts, diagnoses, measurements, or timelines.
* **DO NOT** change the certainty of a statement.

### Instructions for Output
Provide your Revised Medical Text: The complete, revised version of the text."""
    user_prompt = f"""### Medical Text to Revise:
{raw_response}"""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    
    return tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
    )


terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]


KEY_SUMMARY_CONFIG = {
    "max_new_tokens": 500,
    "do_sample": False, 
    "repetition_penalty": 1.1,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": terminators, 
}

EXPLANATION_GENERATION_CONFIG = {
    "max_new_tokens": 1024,
    "temperature": 0.3,
    "top_p": 0.85,
    "repetition_penalty": 1.15,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": terminators, 
    "early_stopping": True,
}

CLEANING_CONFIG = {
    "max_new_tokens": 1500,
    "temperature": 0.1,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": terminators, 
    "early_stopping": True,
}



def run_model(prompt: str, model, tokenizer, generation_config: Dict) -> str:

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
        input_length = inputs["input_ids"].shape[1]
        
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generation_config
            )
        response_tokens = outputs[0][input_length:]
        raw_output = tokenizer.decode(response_tokens, skip_special_tokens=True)
        return raw_output
        


all_results = {}
LINE_WIDTH = 100 

for index, row in df.iterrows():
    text = row['note_text'] 
    
    print(f"\nProcessing row {index + 1}/{len(df)}...")
    
    # --- Step 1: Determine Audience ---
    audience = 'patient' 
    
    # --- Step 2: Annotate Text (Non-LLM step) ---
    annotated_text = create_annotated_text(text, meddict)
    
    # --- Step 3: Extract Key Summary ---
    summary_prompt = create_key_summary_prompt(text)
    keysummary = run_model(summary_prompt, model, tokenizer, KEY_SUMMARY_CONFIG)
    
    # --- Step 4: Generate Initial Explanation ---
    explanation_prompt = create_explanation_prompt(annotated_text, keysummary)
    explanation = run_model(explanation_prompt, model, tokenizer, EXPLANATION_GENERATION_CONFIG)

    # --- Step 5: Clean and Finalize the Explanation ---
    cleaning_prompt = create_cleaning_prompt(explanation)
    final_result = run_model(cleaning_prompt, model, tokenizer, CLEANING_CONFIG)
    
    # --- Store and Print Results ---
    row_result = {
        'original_text': text,
        # 'annotated_text': annotated_text,
        'keysummary': keysummary,
        'raw_explanation': explanation,
        'final_letter': final_result
    }
    
    all_results[index]=(row_result)
    print(f"\n{'='*60}")
    print(f"RESULTS FOR ROW {index + 1}")
    print(f"{'='*60}")
    
    for col in ['keysummary', 'raw_explanation', 'final_letter']:
        print(f"\n--- Column: {col} ---")
        original_text = row_result[col]
        wrapped_text = textwrap.fill(original_text, width=LINE_WIDTH)
        print(wrapped_text)


Processing row 84/1...


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


KeyboardInterrupt: 

In [13]:
results_df = pd.DataFrame(all_results).T
# results_df.drop(columns=['annotated_text'], inplace=True)
output_format = 'csv'

# --- File Saving Logic ---
base_filename = 'output/output'
extension = ''

if output_format == 'csv':
    extension = '.csv'
elif output_format == 'json':
    extension = '.json'
else:
    print(f"Error: Unsupported format '{output_format}'. Please choose 'csv' or 'json'.")
    exit() # Exits the script if the format is invalid


output_filename = base_filename + extension
counter = 1

output_dir = os.path.dirname(base_filename)
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

while os.path.exists(output_filename):
    output_filename = f"{base_filename}_{counter}{extension}"
    counter += 1

if output_format == 'csv':
    results_df.to_csv(output_filename, index=True)
elif output_format == 'json':
    results_df.to_json(output_filename, orient='index', indent=2)

print(f"DataFrame saved to '{output_filename}'")

DataFrame saved to 'output/output_8.csv'
